Fix issues with templating prompt in chat mode (#2460)

This commit is contained in:
Jeffrey Morgan
2024-02-12 15:06:57 -08:00
committed by GitHub
parent 939c60473f
commit 48a273f80b
7 changed files with 538 additions and 1013 deletions

View File

@@ -19,7 +19,6 @@ import (
"strconv"
"strings"
"text/template"
"text/template/parse"
"golang.org/x/exp/slices"
@@ -58,162 +57,6 @@ type Message struct {
Content string `json:"content"`
}
type PromptVars struct {
System string
Prompt string
Response string
First bool
Images []llm.ImageData
}
// extractParts extracts the parts of the template before and after the {{.Response}} node.
func extractParts(tmplStr string) (pre string, post string, err error) {
tmpl, err := template.New("").Parse(tmplStr)
if err != nil {
return "", "", err
}
var foundResponse bool
for _, node := range tmpl.Tree.Root.Nodes {
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
foundResponse = true
}
if !foundResponse {
pre += node.String()
} else {
post += node.String()
}
}
return pre, post, nil
}
func Prompt(promptTemplate string, p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
if err != nil {
return "", err
}
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
if !strings.Contains(prompt.String(), p.Response) {
// if the response is not in the prompt template, append it to the end
prompt.WriteString(p.Response)
}
return prompt.String(), nil
}
// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
pre, _, err := extractParts(m.Template)
if err != nil {
return "", err
}
return Prompt(pre, p)
}
// PostResponseTemplate returns the template after the response tag
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
_, post, err := extractParts(m.Template)
if err != nil {
return "", err
}
if post == "" {
// if there is no post-response template, return the provided response
return p.Response, nil
}
return Prompt(post, p)
}
type ChatHistory struct {
Prompts []PromptVars
LastSystem string
}
// ChatPrompts returns a list of formatted chat prompts from a list of messages
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
// build the prompt from the list of messages
lastSystem := m.System
currentVars := PromptVars{
First: true,
System: m.System,
}
prompts := []PromptVars{}
var images []llm.ImageData
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
// if this is the first message it overrides the system prompt in the modelfile
if !currentVars.First && currentVars.System != "" {
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
}
currentVars.System = msg.Content
lastSystem = msg.Content
case "user":
if currentVars.Prompt != "" {
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
}
currentVars.Prompt = msg.Content
if len(m.ProjectorPaths) > 0 {
for i := range msg.Images {
id := len(images) + i
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
currentVars.Images = append(currentVars.Images, llm.ImageData{
ID: id,
Data: msg.Images[i],
})
}
images = append(images, currentVars.Images...)
}
case "assistant":
currentVars.Response = msg.Content
prompts = append(prompts, currentVars)
currentVars = PromptVars{}
default:
return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
prompts = append(prompts, currentVars)
}
return &ChatHistory{
Prompts: prompts,
LastSystem: lastSystem,
}, nil
}
type ManifestV2 struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`