post-response templating (#1427)

This commit is contained in:
Bruce MacDonald
2023-12-22 17:07:05 -05:00
committed by GitHub
parent b80081022f
commit db356c8519
3 changed files with 334 additions and 16 deletions

View File

@@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"text/template"
"text/template/parse"
"golang.org/x/exp/slices"
@@ -57,17 +58,35 @@ type PromptVars struct {
First bool
}
func (m *Model) Prompt(p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(m.Template)
// extractParts extracts the parts of the template before and after the {{.Response}} node.
func extractParts(tmplStr string) (pre string, post string, err error) {
tmpl, err := template.New("").Parse(tmplStr)
if err != nil {
return "", err
return "", "", err
}
if p.System == "" {
// use the default system message for this model if one is not specified
p.System = m.System
var foundResponse bool
for _, node := range tmpl.Tree.Root.Nodes {
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
foundResponse = true
}
if !foundResponse {
pre += node.String()
} else {
post += node.String()
}
}
return pre, post, nil
}
func Prompt(promptTemplate string, p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
if err != nil {
return "", err
}
vars := map[string]any{
@@ -82,20 +101,59 @@ func (m *Model) Prompt(p PromptVars) (string, error) {
return "", err
}
prompt.WriteString(sb.String())
prompt.WriteString(p.Response)
if !strings.Contains(prompt.String(), p.Response) {
// if the response is not in the prompt template, append it to the end
prompt.WriteString(p.Response)
}
return prompt.String(), nil
}
// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
pre, _, err := extractParts(m.Template)
if err != nil {
return "", err
}
return Prompt(pre, p)
}
// PostResponseTemplate returns the template after the response tag
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
_, post, err := extractParts(m.Template)
if err != nil {
return "", err
}
if post == "" {
// if there is no post-response template, return the provided response
return p.Response, nil
}
return Prompt(post, p)
}
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages
var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{
First: true,
First: true,
System: m.System,
}
writePrompt := func() error {
p, err := m.Prompt(currentVars)
p, err := Prompt(m.Template, currentVars)
if err != nil {
return err
}
@@ -133,9 +191,11 @@ func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error)
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", nil, err
p, err := m.PreResponsePrompt(currentVars)
if err != nil {
return "", nil, fmt.Errorf("pre-response template: %w", err)
}
prompt.WriteString(p)
}
return prompt.String(), currentImages, nil