fix system prompt (#5662)

* fix system prompt

* execute template when hitting previous roles

* fix tests

---------

Co-authored-by: jmorganca <jmorganca@gmail.com>
This commit is contained in:
Michael Yang
2024-07-12 21:04:44 -07:00
committed by GitHub
parent 23ebbaa46e
commit 22c5451fc2
3 changed files with 51 additions and 30 deletions

View File

@@ -149,27 +149,19 @@ type Values struct {
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
"Messages": messages,
})
}
system = ""
var b bytes.Buffer
var prompt, response string
for i, m := range collated {
switch m.Role {
case "system":
system = m.Content
case "user":
prompt = m.Content
case "assistant":
response = m.Content
}
if i != len(collated)-1 && prompt != "" && response != "" {
for _, m := range messages {
execute := func () error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
@@ -181,6 +173,26 @@ func (t *Template) Execute(w io.Writer, v Values) error {
system = ""
prompt = ""
response = ""
return nil
}
switch m.Role {
case "system":
if prompt != "" || response != "" {
if err := execute(); err != nil {
return err
}
}
system = m.Content
case "user":
if response != "" {
if err := execute(); err != nil {
return err
}
}
prompt = m.Content
case "assistant":
response = m.Content
}
}
@@ -199,7 +211,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": "",
"System": system,
"Prompt": prompt,
}); err != nil {
return err