do no automatically aggregate system messages

This commit is contained in:
Michael Yang
2024-07-11 13:10:13 -07:00
parent 791650ddef
commit e64f9ebb44
2 changed files with 27 additions and 23 deletions

View File

@@ -102,8 +102,21 @@ var response = parse.ActionNode{
},
}
var funcs = template.FuncMap{
"aggregate": func(v []*api.Message, role string) string {
var aggregated []string
for _, m := range v {
if m.Role == role {
aggregated = append(aggregated, m.Content)
}
}
return strings.Join(aggregated, "\n\n")
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero")
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s)
if err != nil {
@@ -149,23 +162,21 @@ type Values struct {
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, collated := collate(v.Messages)
collated := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": collated,
})
}
var b bytes.Buffer
var prompt, response string
var system, prompt, response string
for i, m := range collated {
switch m.Role {
case "system":
system = m.Content
case "user":
prompt = m.Content
if i != 0 {
system = ""
}
case "assistant":
response = m.Content
}
@@ -179,6 +190,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
system = ""
prompt = ""
response = ""
}
@@ -209,25 +221,14 @@ func (t *Template) Execute(w io.Writer, v Values) error {
return err
}
type messages []*api.Message
// collate messages based on role. consecutive messages of the same role are merged
// into a single message. collate also pulls out and merges messages with Role == "system"
// which are templated separately. As a side effect, it mangles message content adding image
// tags ([img-%d]) as needed
func collate(msgs []api.Message) (system string, collated messages) {
func collate(msgs []api.Message) (collated []*api.Message) {
var n int
for i := range msgs {
msg := msgs[i]
if msg.Role == "system" {
if system != "" {
system += "\n\n"
}
system += msg.Content
continue
}
for range msg.Images {
imageTag := fmt.Sprintf("[img-%d]", n)
if !strings.Contains(msg.Content, "[img]") {