This commit is contained in:
Michael Yang
2024-06-20 13:45:47 -07:00
parent 9e35d9bbee
commit d02bbebb11
7 changed files with 263 additions and 53 deletions

View File

@@ -13,6 +13,7 @@ import (
"sync"
"text/template"
"text/template/parse"
"time"
"github.com/agnivade/levenshtein"
"github.com/ollama/ollama/api"
@@ -102,8 +103,18 @@ var response = parse.ActionNode{
},
}
var funcs = template.FuncMap{
"json": func(v any) string {
b, _ := json.Marshal(v)
return string(b)
},
"now": func() string {
return time.Now().Format("2006-01-02 15:04:05")
},
}
func Parse(s string) (*Template, error) {
tmpl := template.New("").Option("missingkey=zero")
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tmpl, err := tmpl.Parse(s)
if err != nil {
@@ -127,7 +138,7 @@ func (t *Template) Vars() []string {
var vars []string
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, parseNode(n)...)
vars = append(vars, Identifiers(n)...)
}
}
@@ -143,17 +154,65 @@ func (t *Template) Vars() []string {
type Values struct {
Messages []api.Message
Tools []api.Tool
// forceLegacy is a flag used to test compatibility with legacy templates
forceLegacy bool
}
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
return n
}
switch t := n.(type) {
case *parse.ListNode:
for _, c := range t.Nodes {
if n := walk(c); n != nil {
return n
}
}
case *parse.BranchNode:
for _, n := range []*parse.ListNode{t.List, t.ElseList} {
if n != nil {
if n := walk(n); n != nil {
return n
}
}
}
case *parse.IfNode:
return walk(&t.BranchNode)
case *parse.WithNode:
return walk(&t.BranchNode)
case *parse.RangeNode:
return walk(&t.BranchNode)
}
return nil
}
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}).Funcs(funcs)
}
return nil
}
func (t *Template) Execute(w io.Writer, v Values) error {
system, messages := collate(v.Messages)
if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
})
}
@@ -161,7 +220,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var b bytes.Buffer
var prompt, response string
for _, m := range messages {
execute := func () error {
execute := func() error {
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
@@ -198,12 +257,8 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var cut bool
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
switch t := n.(type) {
case *parse.ActionNode:
case *parse.FieldNode:
if slices.Contains(t.Ident, "Response") {
cut = true
}
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
}
return cut
@@ -255,50 +310,46 @@ func collate(msgs []api.Message) (string, []*api.Message) {
return strings.Join(system, "\n\n"), collated
}
func parseNode(n parse.Node) []string {
// Identifiers walks the node tree returning any identifiers it finds along the way
func Identifiers(n parse.Node) []string {
switch n := n.(type) {
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, Identifiers(n)...)
}
return names
case *parse.TemplateNode:
return Identifiers(n.Pipe)
case *parse.ActionNode:
return parseNode(n.Pipe)
return Identifiers(n.Pipe)
case *parse.BranchNode:
names := Identifiers(n.Pipe)
for _, n := range []*parse.ListNode{n.List, n.ElseList} {
if n != nil {
names = append(names, Identifiers(n)...)
}
}
return names
case *parse.IfNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
return Identifiers(&n.BranchNode)
case *parse.RangeNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
return Identifiers(&n.BranchNode)
case *parse.WithNode:
names := parseNode(n.Pipe)
names = append(names, parseNode(n.List)...)
if n.ElseList != nil {
names = append(names, parseNode(n.ElseList)...)
}
return names
return Identifiers(&n.BranchNode)
case *parse.PipeNode:
var names []string
for _, c := range n.Cmds {
for _, a := range c.Args {
names = append(names, parseNode(a)...)
names = append(names, Identifiers(a)...)
}
}
return names
case *parse.ListNode:
var names []string
for _, n := range n.Nodes {
names = append(names, parseNode(n)...)
}
return names
case *parse.FieldNode:
return n.Ident
case *parse.TemplateNode:
return parseNode(n.Pipe)
case *parse.VariableNode:
return n.Ident
}
return nil