diff --git a/server/routes.go b/server/routes.go index 8eda5c73..cb46cef1 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1526,12 +1526,7 @@ func (s *Server) ChatHandler(c *gin.Context) { var toolParser *tools.Parser if len(req.Tools) > 0 { - toolParser, err = tools.NewParser(m.Template.Template) - if err != nil { - slog.Error("failed to create tool parser", "error", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) - return - } + toolParser = tools.NewParser(m.Template.Template, req.Tools) } ch := make(chan any) @@ -1584,6 +1579,7 @@ func (s *Server) ChatHandler(c *gin.Context) { // don't return } else { if r.Done { + res.Message.Content = toolParser.Content() ch <- res } return diff --git a/tools/template.go b/tools/template.go new file mode 100644 index 00000000..e22f0675 --- /dev/null +++ b/tools/template.go @@ -0,0 +1,156 @@ +package tools + +import ( + "bytes" + "log/slog" + "slices" + "strings" + "text/template" + "text/template/parse" +) + +// parseTag finds the tool calling tag from a Go template +// often [TOOL_CALL] or similar by finding the +// first text node after .ToolCalls and returning the content +// if no tag is found, return "{" to indicate that json objects +// should be attempted to be parsed as tool calls +func parseTag(tmpl *template.Template) string { + if tmpl == nil || tmpl.Tree == nil { + slog.Debug("template or tree is nil") + return "{" + } + + tc := findToolCallNode(tmpl.Tree.Root.Nodes) + if tc == nil { + return "{" + } + + tn := findTextNode(tc.List.Nodes) + if tn == nil { + return "{" + } + + tag := string(tn.Text) + tag = strings.ReplaceAll(tag, "\r\n", "\n") + + // avoid parsing { onwards as this may be a tool call + // however keep '{' as a prefix if there is no tag + // so that all json objects will be attempted to + // be parsed as tool calls + tag, _, _ = strings.Cut(tag, "{") + tag = strings.TrimSpace(tag) + if tag == "" { + tag = "{" + } + + return tag +} + +// findToolCallNode searches for and returns an IfNode with .ToolCalls +func findToolCallNode(nodes []parse.Node) *parse.IfNode { + isToolCallsNode := func(n *parse.IfNode) bool { + for _, cmd := range n.Pipe.Cmds { + for _, arg := range cmd.Args { + if field, ok := arg.(*parse.FieldNode); ok { + if slices.Contains(field.Ident, "ToolCalls") { + return true + } + } + } + } + return false + } + + for _, node := range nodes { + switch n := node.(type) { + case *parse.IfNode: + if isToolCallsNode(n) { + return n + } + // Recursively search in nested IfNodes + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + case *parse.ListNode: + if result := findToolCallNode(n.Nodes); result != nil { + return result + } + case *parse.RangeNode: + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + case *parse.WithNode: + if result := findToolCallNode(n.List.Nodes); result != nil { + return result + } + if n.ElseList != nil { + if result := findToolCallNode(n.ElseList.Nodes); result != nil { + return result + } + } + } + } + return nil +} + +// findTextNode does a depth-first search for the first text content in nodes, +// stopping at template constructs to avoid parsing text after the tool calls +func findTextNode(nodes []parse.Node) *parse.TextNode { + for _, node := range nodes { + switch n := node.(type) { + case *parse.TextNode: + // skip whitespace-only text nodes + if len(bytes.TrimSpace(n.Text)) == 0 { + continue + } + return n + case *parse.IfNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.ListNode: + if text := findTextNode(n.Nodes); text != nil { + return text + } + case *parse.RangeNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.WithNode: + if text := findTextNode(n.List.Nodes); text != nil { + return text + } + if n.ElseList != nil { + if text := findTextNode(n.ElseList.Nodes); text != nil { + return text + } + } + return nil + case *parse.ActionNode: + return nil + } + } + return nil +} diff --git a/tools/template_test.go b/tools/template_test.go new file mode 100644 index 00000000..970c0d59 --- /dev/null +++ b/tools/template_test.go @@ -0,0 +1,139 @@ +package tools + +import ( + "testing" + "text/template" +) + +func TestParseTag(t *testing.T) { + cases := []struct { + name string + template string + want string + }{ + { + name: "empty", + template: "", + want: "{", + }, + { + name: "no tag", + template: "{{if .ToolCalls}}{{end}}", + want: "{", + }, + { + name: "no tag with range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}{{end}}", + want: "{", + }, + { + name: "tool call with json format", + template: "{{if .ToolCalls}}```json\n{{end}}", + want: "```json", + }, + { + name: "square brackets", + template: "{{if .ToolCalls}}[{{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "[", + }, + { + name: "square brackets with whitespace", + template: "{{if .ToolCalls}}\n [ {{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "[", + }, + { + name: "tailing ]", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{{ . }}{{end}}]{{end}}", + want: "{", + }, + { + name: "whitespace only", + template: "{{if .ToolCalls}} {{range .ToolCalls}}{{ . }}{{end}}{{end}}", + want: "{", + }, + { + name: "whitespace only in range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{{ . }}\n{{end}}{{end}}", + want: "{", + }, + { + name: "json objects", + template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`, + want: "{", + }, + { + name: "json objects with whitespace", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "{", + }, + { + name: "json objects with CRLF", + template: "{{if .ToolCalls}}{{range .ToolCalls}}\r\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "{", + }, + { + name: "json objects with whitespace before and after range", + template: "{{if .ToolCalls}}\n{{range .ToolCalls}}\n{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}\r\n{{end}}\r\n{{end}}", + want: "{", + }, + { + name: "before and after range", + template: "{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>functionget_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>\n{{end}}<|tool▁calls▁end|>{{end}}", + want: "<|tool▁calls▁begin|>", + }, + { + name: "after range", + template: "{{if .ToolCalls}}{{range .ToolCalls}}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "", + }, + { + name: "after range with leading whitespace before range", + template: "{{if .ToolCalls}}\n{{range .ToolCalls}}{\"name\": \"{{ .Function.Name }}\", \"arguments\": {{ .Function.Arguments }}}{{end}}{{end}}", + want: "", + }, + { + name: "tool call in range with {", + template: `{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{end}}{{end}}`, + want: "", + }, + { + name: "tool call with multiple text nodes", + template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", + want: "First text", + }, + { + name: "action tag", + template: "{{if .ToolCalls}}Action: ```json{{end}}", + want: "Action: ```json", + }, + { + name: "incomplete functools bracket", + template: "{{if .ToolCalls}}functools[{{end}}", + want: "functools[", + }, + { + name: "uppercase tool call with incomplete bracket", + template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", + want: "[TOOL_CALL] [", + }, + { + name: "uppercase tool call with adjacent bracket", + template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", + want: "[TOOL_CALL][", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + tmpl, err := template.New("test").Parse(tc.template) + if err != nil && tc.template != "" { + t.Fatalf("failed to parse template: %v", err) + } + + got := parseTag(tmpl) + if got != tc.want { + t.Errorf("got text %q, want %q", got, tc.want) + } + }) + } +} diff --git a/tools/testdata/command-r-plus.gotmpl b/tools/testdata/command-r-plus.gotmpl deleted file mode 100644 index f30124e3..00000000 --- a/tools/testdata/command-r-plus.gotmpl +++ /dev/null @@ -1,67 +0,0 @@ -{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> -{{- if .Tools }}# Safety Preamble -The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. - -# System Preamble -## Basic Rules -You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. - -{{ if .System }}# User Preamble -{{ .System }} -{{- end }} - -## Available Tools -Here is a list of tools that you have available to you: -{{- range .Tools }} - -```python -def {{ .Function.Name }}( -{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]: - """{{ .Function.Description }} - -{{- if .Function.Parameters.Properties }} - - Args: -{{- range $name, $property := .Function.Parameters.Properties }} - {{ $name }} ({{ $property.Type }}): {{ $property.Description }} -{{- end }} -{{- end }} - """ - pass -``` -{{- end }} -{{- else if .System }}{{ .System }} -{{- end }}<|END_OF_TURN_TOKEN|> -{{- end }} -{{- range .Messages }} -{{- if eq .Role "system" }} -{{- continue }} -{{- end }}<|START_OF_TURN_TOKEN|> -{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }} -{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|> -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -Action: ```json -[ -{{- range .ToolCalls }} - { - "tool_name": "{{ .Function.Name }}", - "parameters": {{ .Function.Arguments }} - } -{{- end }} -]``` -{{ continue }} -{{ end }} -{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|> -{{ .Content }} -{{- end }}<|END_OF_TURN_TOKEN|> -{{- end }} -{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: -```json -[ - { - "tool_name": title of the tool in the specification, - "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters - } -]``` -{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/tools/testdata/command-r-plus.out b/tools/testdata/command-r-plus.out deleted file mode 100644 index 8193d40c..00000000 --- a/tools/testdata/command-r-plus.out +++ /dev/null @@ -1,39 +0,0 @@ -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble -The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral. - -# System Preamble -## Basic Rules -You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions. - -# User Preamble -You are a knowledgeable assistant. You can answer questions and perform tasks. - -## Available Tools -Here is a list of tools that you have available to you: - -```python -def get_current_weather(format: string, location: string, ) -> List[Dict]: - """Get the current weather - - Args: - format (string): The temperature unit to use. Infer this from the user's location. - location (string): The city and state, e.g. San Francisco, CA - """ - pass -```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> -Action: ```json -[ - { - "tool_name": "get_current_weather", - "parameters": {"format":"celsius","location":"Paris, France"} - } -]``` -<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|> -22<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example: -```json -[ - { - "tool_name": title of the tool in the specification, - "parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters - } -]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file diff --git a/tools/testdata/firefunction.gotmpl b/tools/testdata/firefunction.gotmpl deleted file mode 100644 index 312be205..00000000 --- a/tools/testdata/firefunction.gotmpl +++ /dev/null @@ -1,31 +0,0 @@ -{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|> -{{- if .System }} -{{ .System }} -{{- end }} -In addition to plain text responses, you can chose to call one or more of the provided functions. - -Use the following rule to decide when to call a function: - * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so - * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls - -If you decide to call functions: - * prefix function calls with functools marker (no closing marker required) - * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] - * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 - * make sure you pick the right functions that match the user intent - -Available functions as JSON spec: -{{- if .Tools }} -{{ .Tools }} -{{- end }}<|eot_id|> -{{- end }} -{{- range .Messages }}<|start_header_id|> -{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }} -{{- end }}<|end_header_id|> -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} functools[ -{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }} -{{- end }}] -{{- end }}<|eot_id|> -{{- end }}<|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/tools/testdata/firefunction.out b/tools/testdata/firefunction.out deleted file mode 100644 index 144f5e42..00000000 --- a/tools/testdata/firefunction.out +++ /dev/null @@ -1,17 +0,0 @@ -<|start_header_id|>system<|end_header_id|> -You are a knowledgeable assistant. You can answer questions and perform tasks. -In addition to plain text responses, you can chose to call one or more of the provided functions. - -Use the following rule to decide when to call a function: - * if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so - * if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls - -If you decide to call functions: - * prefix function calls with functools marker (no closing marker required) - * all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...] - * follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples - * respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0 - * make sure you pick the right functions that match the user intent - -Available functions as JSON spec: -[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgeable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> \ No newline at end of file diff --git a/tools/testdata/llama3-groq-tool-use.gotmpl b/tools/testdata/llama3-groq-tool-use.gotmpl deleted file mode 100644 index 45e9b462..00000000 --- a/tools/testdata/llama3-groq-tool-use.gotmpl +++ /dev/null @@ -1,43 +0,0 @@ -{{- if .Messages }} -{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|> - -{{ .System }} -{{- if .Tools }} You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within XML tags as follows: - -{"name": ,"arguments": } - - -Here are the available tools: - -{{- range .Tools }} {{ .Function }} -{{- end }} -{{- end }} -{{- end }}<|eot_id|> -{{- range .Messages }} -{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|> - -{{ if eq .Role "user" }}{{ .Content }} -{{- else if eq .Role "assistant" }} -{{- if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{- end }} - -{{- end }} -{{- else if eq .Role "tool" }} -{{ .Content }} - -{{- end }}<|eot_id|> -{{- end }} -{{- end }}<|start_header_id|>assistant<|end_header_id|> - -{{ else }} -{{ if .System }}<|start_header_id|>system<|end_header_id|> - -{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|> - -{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }}{{ .Response }} -{{- if .Response }}<|eot_id|> -{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3-groq-tool-use.out b/tools/testdata/llama3-groq-tool-use.out deleted file mode 100644 index 912ad11c..00000000 --- a/tools/testdata/llama3-groq-tool-use.out +++ /dev/null @@ -1,24 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -You are a knowledgeable assistant. You can answer questions and perform tasks. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within XML tags as follows: - -{"name": ,"arguments": } - - -Here are the available tools: - {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} <|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|eot_id|><|start_header_id|>tool<|end_header_id|> - - -22 -<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - diff --git a/tools/testdata/llama3.2.gotmpl b/tools/testdata/llama3.2.gotmpl deleted file mode 100644 index b132423e..00000000 --- a/tools/testdata/llama3.2.gotmpl +++ /dev/null @@ -1,44 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -Cutting Knowledge Date: December 2023 - -{{ if .System }}{{ .System }} -{{- end }} -{{- if .Tools }}When you receive a tool call response, use the output to format an answer to the orginal user question. - -You are a helpful assistant with tool calling capabilities. -{{- end }}<|eot_id|> -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 }} -{{- if eq .Role "user" }}<|start_header_id|>user<|end_header_id|> -{{- if and $.Tools $last }} - -Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. - -Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. - -{{ range $.Tools }} -{{- . }} -{{ end }} -{{ .Content }}<|eot_id|> -{{- else }} - -{{ .Content }}<|eot_id|> -{{- end }}{{ if $last }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }} -{{- else if eq .Role "assistant" }}<|start_header_id|>assistant<|end_header_id|> -{{- if .ToolCalls }} -{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }} -{{- else }} - -{{ .Content }} -{{- end }}{{ if not $last }}<|eot_id|>{{ end }} -{{- else if eq .Role "tool" }}<|start_header_id|>ipython<|end_header_id|> - -{{ .Content }}<|eot_id|>{{ if $last }}<|start_header_id|>assistant<|end_header_id|> - -{{ end }} -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/llama3.2.out b/tools/testdata/llama3.2.out deleted file mode 100644 index a27c6eaf..00000000 --- a/tools/testdata/llama3.2.out +++ /dev/null @@ -1,24 +0,0 @@ -<|start_header_id|>system<|end_header_id|> - -Cutting Knowledge Date: December 2023 - -You are a knowledgeable assistant. You can answer questions and perform tasks.When you receive a tool call response, use the output to format an answer to the orginal user question. - -You are a helpful assistant with tool calling capabilities.<|eot_id|><|start_header_id|>user<|end_header_id|> - -What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -{"name": "get_current_weather", "parameters": {"format":"celsius","location":"Paris, France"}}<|eot_id|><|start_header_id|>ipython<|end_header_id|> - -22<|eot_id|><|start_header_id|>assistant<|end_header_id|> - -The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|> - -Given the following functions, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. - -Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. Do not use variables. - -{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - -What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|> - diff --git a/tools/testdata/messages.json b/tools/testdata/messages.json deleted file mode 100644 index 42de4711..00000000 --- a/tools/testdata/messages.json +++ /dev/null @@ -1,39 +0,0 @@ -[ - { - "role": "system", - "content": "You are a knowledgeable assistant. You can answer questions and perform tasks." - }, - { - "role": "user", - "content": "What's the weather like today in Paris?" - }, - { - "role": "assistant", - "tool_calls": [ - { - "id": "89a1e453-0bce-4de3-a456-c54bed09c520", - "type": "function", - "function": { - "name": "get_current_weather", - "arguments": { - "location": "Paris, France", - "format": "celsius" - } - } - } - ] - }, - { - "role": "tool", - "tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520", - "content": "22" - }, - { - "role": "assistant", - "content": "The current temperature in Paris, France is 22 degrees Celsius." - }, - { - "role": "user", - "content": "What's the weather like today in San Francisco and Toronto?" - } -] diff --git a/tools/testdata/mistral.gotmpl b/tools/testdata/mistral.gotmpl deleted file mode 100644 index b08d6c2c..00000000 --- a/tools/testdata/mistral.gotmpl +++ /dev/null @@ -1,15 +0,0 @@ -{{- range $index, $_ := .Messages }} -{{- if eq .Role "user" }} -{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS] -{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }} - -{{ end }}{{ .Content }}[/INST] -{{- else if eq .Role "assistant" }} -{{- if .Content }} {{ .Content }} -{{- else if .ToolCalls }}[TOOL_CALLS] [ -{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{- end }}] -{{- end }} -{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS] -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/mistral.out b/tools/testdata/mistral.out deleted file mode 100644 index 6956e392..00000000 --- a/tools/testdata/mistral.out +++ /dev/null @@ -1,3 +0,0 @@ -[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}][TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgeable assistant. You can answer questions and perform tasks. - -What's the weather like today in San Francisco and Toronto?[/INST] \ No newline at end of file diff --git a/tools/testdata/nemotron.gotmpl b/tools/testdata/nemotron.gotmpl deleted file mode 100644 index 1b6b89ec..00000000 --- a/tools/testdata/nemotron.gotmpl +++ /dev/null @@ -1,33 +0,0 @@ -{{- if (or .Tools .System) }}System -{{ if .System }}{{ .System }} - - -{{ end }} -{{- if .Tools }} -{{- range .Tools }} {{ . }} {{ end }} - - -{{ end }} -{{- end }} -{{- range $i, $m := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}User -{{ .Content }} -{{- if $last }} -Assistant -{{- end }} -{{ else if eq .Role "tool" }}Tool -{{ .Content }} -{{- if $last }} -Assistant -{{- end }} -{{ else if eq .Role "assistant" }}Assistant -{{- if .ToolCalls }} -{{ range .ToolCalls }} {"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} {{ end }} -{{ else }} -{{ .Content }} -{{- if not $last }} -{{ end }} -{{- end }} -{{- end }} -{{- end }} \ No newline at end of file diff --git a/tools/testdata/nemotron.out b/tools/testdata/nemotron.out deleted file mode 100644 index 486889ca..00000000 --- a/tools/testdata/nemotron.out +++ /dev/null @@ -1,18 +0,0 @@ -System -You are a knowledgeable assistant. You can answer questions and perform tasks. - - - {"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -User -What's the weather like today in Paris? -Assistant - {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -Tool -22 -Assistant -The current temperature in Paris, France is 22 degrees Celsius. -User -What's the weather like today in San Francisco and Toronto? -Assistant diff --git a/tools/testdata/qwen2.5.gotmpl b/tools/testdata/qwen2.5.gotmpl deleted file mode 100644 index cbd7302c..00000000 --- a/tools/testdata/qwen2.5.gotmpl +++ /dev/null @@ -1,51 +0,0 @@ -{{- if .Suffix }}<|fim_prefix|>{{ .Prompt }}<|fim_suffix|>{{ .Suffix }}<|fim_middle|> -{{- else if .Messages }} -{{- if or .System .Tools }}<|im_start|>system -{{- if .System }} -{{ .System }} -{{- end }} -{{- if .Tools }} - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{{- range .Tools }} -{"type": "function", "function": {{ .Function }}} -{{- end }} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -{{- end }}<|im_end|> -{{ end }} -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}<|im_start|>user -{{ .Content }}<|im_end|> -{{ else if eq .Role "assistant" }}<|im_start|>assistant -{{ if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }} -{{- end }}{{ if not $last }}<|im_end|> -{{ end }} -{{- else if eq .Role "tool" }}<|im_start|>user - -{{ .Content }} -<|im_end|> -{{ end }} -{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant -{{ end }} -{{- end }} -{{- else }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }}{{ if .Prompt }}<|im_start|>user -{{ .Prompt }}<|im_end|> -{{ end }}<|im_start|>assistant -{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen2.5.out b/tools/testdata/qwen2.5.out deleted file mode 100644 index 76bfbfa9..00000000 --- a/tools/testdata/qwen2.5.out +++ /dev/null @@ -1,31 +0,0 @@ -<|im_start|>system -You are a knowledgeable assistant. You can answer questions and perform tasks. - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } -<|im_end|> -<|im_start|>user -What's the weather like today in Paris?<|im_end|> -<|im_start|>assistant - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|im_end|> -<|im_start|>user - -22 -<|im_end|> -<|im_start|>assistant -The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> -<|im_start|>user -What's the weather like today in San Francisco and Toronto?<|im_end|> -<|im_start|>assistant diff --git a/tools/testdata/qwen3.gotmpl b/tools/testdata/qwen3.gotmpl deleted file mode 100644 index 26f6656f..00000000 --- a/tools/testdata/qwen3.gotmpl +++ /dev/null @@ -1,50 +0,0 @@ -{{- if .Messages }} -{{- if or .System .Tools }}<|im_start|>system -{{- if .System }} -{{ .System }} -{{- end }} -{{- if .Tools }} - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{{- range .Tools }} -{"type": "function", "function": {{ .Function }}} -{{- end }} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } - -{{- end }}<|im_end|> -{{ end }} -{{- range $i, $_ := .Messages }} -{{- $last := eq (len (slice $.Messages $i)) 1 -}} -{{- if eq .Role "user" }}<|im_start|>user -{{ .Content }}<|im_end|> -{{ else if eq .Role "assistant" }}<|im_start|>assistant -{{ if .Content }}{{ .Content }} -{{- else if .ToolCalls }} -{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }} -{{- end }}{{ if not $last }}<|im_end|> -{{ end }} -{{- else if eq .Role "tool" }}<|im_start|>user - -{{ .Content }} -<|im_end|> -{{ end }} -{{- if and (ne .Role "assistant") $last }}<|im_start|>assistant -{{ end }} -{{- end }} -{{- else }} -{{- if .System }}<|im_start|>system -{{ .System }}<|im_end|> -{{ end }}{{ if .Prompt }}<|im_start|>user -{{ .Prompt }}<|im_end|> -{{ end }}<|im_start|>assistant -{{ end }}{{ .Response }}{{ if .Response }}<|im_end|>{{ end }} \ No newline at end of file diff --git a/tools/testdata/qwen3.out b/tools/testdata/qwen3.out deleted file mode 100644 index 76bfbfa9..00000000 --- a/tools/testdata/qwen3.out +++ /dev/null @@ -1,31 +0,0 @@ -<|im_start|>system -You are a knowledgeable assistant. You can answer questions and perform tasks. - -# Tools - -You may call one or more functions to assist with the user query. - -You are provided with function signatures within XML tags: - -{"type": "function", "function": {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}} - - -For each function call, return a json object with function name and arguments within XML tags: - -{"name": , "arguments": } -<|im_end|> -<|im_start|>user -What's the weather like today in Paris?<|im_end|> -<|im_start|>assistant - -{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}} -<|im_end|> -<|im_start|>user - -22 -<|im_end|> -<|im_start|>assistant -The current temperature in Paris, France is 22 degrees Celsius.<|im_end|> -<|im_start|>user -What's the weather like today in San Francisco and Toronto?<|im_end|> -<|im_start|>assistant diff --git a/tools/testdata/tools.json b/tools/testdata/tools.json deleted file mode 100644 index edde4ae0..00000000 --- a/tools/testdata/tools.json +++ /dev/null @@ -1,30 +0,0 @@ -[ - { - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather", - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA" - }, - "format": { - "type": "string", - "enum": [ - "celsius", - "fahrenheit" - ], - "description": "The temperature unit to use. Infer this from the user's location." - } - }, - "required": [ - "location", - "format" - ] - } - } - } -] diff --git a/tools/testdata/xlam.gotmpl b/tools/testdata/xlam.gotmpl deleted file mode 100644 index 51513d69..00000000 --- a/tools/testdata/xlam.gotmpl +++ /dev/null @@ -1,45 +0,0 @@ -{{- if .System }}{{ .System }} -{{ end }} -{{- range $i, $_ := .Messages }} -{{- if eq .Role "user" }}### Instruction: -{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }} -[BEGIN OF TASK INSTRUCTION] -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the functions can be used, point it out and refuse to answer. -If the given question lacks the parameters required by the function, also point it out. -[END OF TASK INSTRUCTION] - -[BEGIN OF AVAILABLE TOOLS] -{{ $.Tools }} -[END OF AVAILABLE TOOLS] - -[BEGIN OF FORMAT INSTRUCTION] -The output MUST strictly adhere to the following JSON format, and NO other text MUST be included. -The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'. -``` -{ - "tool_calls": [ - {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}, - ... (more tool calls as required) - ] -} -``` -[END OF FORMAT INSTRUCTION] - -[BEGIN OF QUERY] -{{ .Content }} -[END OF QUERY] - - -{{ else }} -{{ .Content }} -{{ end }} -{{- else if .ToolCalls }}### Response: -{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]} -<|EOT|> -{{ else if eq .Role "assistant" }}### Response: -{{ .Content }} -<|EOT|> -{{ end }} -{{- end }}### Response: \ No newline at end of file diff --git a/tools/testdata/xlam.out b/tools/testdata/xlam.out deleted file mode 100644 index 5d806532..00000000 --- a/tools/testdata/xlam.out +++ /dev/null @@ -1,40 +0,0 @@ -You are a knowledgeable assistant. You can answer questions and perform tasks. -### Instruction: -What's the weather like today in Paris? -### Response: -{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]} -<|EOT|> -### Response: -The current temperature in Paris, France is 22 degrees Celsius. -<|EOT|> -### Instruction: -[BEGIN OF TASK INSTRUCTION] -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the functions can be used, point it out and refuse to answer. -If the given question lacks the parameters required by the function, also point it out. -[END OF TASK INSTRUCTION] - -[BEGIN OF AVAILABLE TOOLS] -[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the user's location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}] -[END OF AVAILABLE TOOLS] - -[BEGIN OF FORMAT INSTRUCTION] -The output MUST strictly adhere to the following JSON format, and NO other text MUST be included. -The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'. -``` -{ - "tool_calls": [ - {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}, - ... (more tool calls as required) - ] -} -``` -[END OF FORMAT INSTRUCTION] - -[BEGIN OF QUERY] -What's the weather like today in San Francisco and Toronto? -[END OF QUERY] - - -### Response: \ No newline at end of file diff --git a/tools/tools.go b/tools/tools.go index 914a5eaf..efeaeee0 100644 --- a/tools/tools.go +++ b/tools/tools.go @@ -1,253 +1,287 @@ package tools import ( + "bytes" "encoding/json" - "errors" - "log/slog" "strings" - gotmpl "text/template" + "text/template" "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" ) -var ( - errInvalidToolCall = errors.New("invalid tool call format") - errAccumulateMore = errors.New("need to accumulate more content") +type toolsState int + +const ( + toolsState_LookingForTag toolsState = iota + toolsState_ToolCalling + toolsState_Done ) type Parser struct { - greedyParseJSON bool - prefix string - prefixFound bool - tmpl gotmpl.Template - sb strings.Builder - index int - name string - arguments string + tag string + names []string + properties []string + + state toolsState + buffer []byte + n int } -// parseJSONToolCalls attempts to parse a JSON string into a slice of ToolCalls. -// -// Parameters: -// - s: The string to parse -// - name: The field name from template that identifies the tool call name -// - arguments: The field name from template that identifies the tool call arguments -// -// Returns: -// - []api.ToolCall: The parsed tool calls if successful -// - error: ErrAccumulateMore if braces unbalanced, ErrInvalidToolCall if invalid, or nil if successful -func parseJSONToolCalls(s string, name, arguments string, prefix string) ([]api.ToolCall, error) { - // Check for balanced braces before attempting to parse - braceCount := 0 - squareCount := 0 - startIndex := -1 - var rawToolCalls []string - s = strings.TrimSpace(s) - - // Only track these if we don't have a prefix as it will be cut off from the prefix. Also track in the parseLeadingJSON case. - trackSquareBrackets := prefix == "" || !strings.HasSuffix(prefix, "[") || strings.HasPrefix(s, "[") - for i, c := range s { - switch c { - case '{': - braceCount++ - if startIndex == -1 { - startIndex = i - } - case '}': - braceCount-- - if braceCount == 0 { - rawToolCalls = append(rawToolCalls, s[startIndex:i+1]) - startIndex = -1 - } - case '[': - if trackSquareBrackets { - squareCount++ - } - case ']': - if trackSquareBrackets { - squareCount-- - } - } - - // Negative means we have an extra closing brace/bracket - if braceCount < 0 || squareCount < 0 { - return nil, errInvalidToolCall - } - } - - // If braces/brackets aren't balanced, need more input - if braceCount > 0 || squareCount > 0 { - return nil, errAccumulateMore - } - - t := strings.TrimSpace(s) - if len(t) == 0 { - return nil, errAccumulateMore - } - // If the input is a single square bracket, it's not a valid tool call - if t[0] == '[' && len(t) == 1 { - return nil, errAccumulateMore - } - - // Attempt full unmarshal of the JSON - var toolCalls []api.ToolCall - for _, rawToolCall := range rawToolCalls { - var resp map[string]any - if err := json.Unmarshal([]byte(rawToolCall), &resp); err != nil { - continue - } - - // Collect nested objects that could contain tool calls - objs := collect(resp) - if len(objs) == 0 { - continue - } - - // Extract tool calls from objects - for _, kv := range objs { - n, nok := kv[name].(string) - a, aok := kv[arguments].(map[string]any) - if nok && aok { - toolCalls = append(toolCalls, api.ToolCall{ - Function: api.ToolCallFunction{ - Name: n, - Arguments: a, - }, - }) - } else { - slog.Debug("No valid tool call found in object.", "object", kv) - } - } - } - - // Valid JSON, no tool calls found - if len(toolCalls) == 0 { - slog.Debug("No valid tool calls found in any raw tool calls.", "rawToolCalls", rawToolCalls) - return nil, errInvalidToolCall - } - - return toolCalls, nil +// NewParser creates a new tool call parser from a model's chat +// template and a list of provided tools. +func NewParser(tmpl *template.Template, tools []api.Tool) *Parser { + return NewParserWithTag(tools, parseTag(tmpl)) } -// checkPrefix processes a string to find and handle a prefix pattern. -// -// Returns: -// - The processed string with prefix removed if found -// - error: ErrAccumulateMore if prefix is incomplete, or nil if successful -func (p *Parser) checkPrefix(s string) (string, error) { - if s == "" || p.prefix == "" { - return s, nil +func NewParserWithTag(tools []api.Tool, tag string) *Parser { + var p Parser + for _, t := range tools { + p.names = append(p.names, t.Function.Name) + for r := range t.Function.Parameters.Properties { + p.properties = append(p.properties, r) + } } - - // Check for prefix at start of string - if cut, hasPrefix := strings.CutPrefix(s, p.prefix); hasPrefix { - // Found prefix at start - accumulate for potential tool - p.prefixFound = true - return cut, nil - } - - // Check if prefix overlaps end of string - if idx := suffixOverlap(s, p.prefix); idx != -1 { - // Return everything except overlapping portion - p.sb.Reset() - p.sb.WriteString(s[idx:]) - return s[:idx], errAccumulateMore - } - - // Check if prefix appears in middle of string - if idx := strings.Index(s, p.prefix); idx != -1 { - // Save remainder starting at prefix for next pass - p.sb.Reset() - p.sb.WriteString(strings.TrimSpace(s[idx:])) - // Return everything before prefix - return s[:idx], errAccumulateMore - } - - // No partial prefix found - return s, nil + p.tag = tag + return &p } -// Add processes a string input to parse tool calls and content. -// It handles prefix detection and JSON parsing to extract tool calls. -// -// Returns: -// - tools: Any parsed tool calls -// - content: Non-tool call content -func (p *Parser) Add(s string) (tools []api.ToolCall, content string) { - p.sb.WriteString(s) - s = p.sb.String() - - // Check for prefix pattern in input - s, err := p.checkPrefix(s) - if err != nil { - // Need more input to complete prefix +// Add processes a string input to parse tool calls and content that +// should be sent back to the user. +func (p *Parser) Add(s string) (calls []api.ToolCall, content string) { + if p.state == toolsState_Done { return nil, s } - // Exit if prefix exists in template, greedy parsing is off, and prefix not found - if !p.greedyParseJSON && !p.prefixFound { - p.sb.Reset() - return nil, s + p.buffer = append(p.buffer, s...) + + if p.state == toolsState_LookingForTag { + i, found := p.findTag() + if i == -1 { + content = string(p.buffer) + p.buffer = []byte{} + } else { + content = string(p.buffer[:i]) + p.buffer = p.buffer[i:] + } + + // for models where { or [ are used as tool calling + // tags, we only support parsing tools if the first non- + // whitespace character is { or [ + if p.tag == "{" || p.tag == "[" { + if strings.TrimSpace(content) != "" { + p.state = toolsState_Done + return nil, content + string(p.buffer) + } + } + + if !found { + return nil, content + } + + p.state = toolsState_ToolCalling } - toolCalls, err := parseJSONToolCalls(s, p.name, p.arguments, p.prefix) - if err != nil { - if errors.Is(err, errAccumulateMore) { - return nil, "" + for { + call := p.parseToolCall() + if call == nil { + break } - p.sb.Reset() - // Only do greedy JSON parsing if there is no prefix from template - if p.prefix != "" { - p.greedyParseJSON = false - } - if p.index != 0 && p.prefix == "" { - return nil, "" - } - if p.prefixFound { - // Drop tokens since prefix was found - return nil, "" - } - return nil, s + + calls = append(calls, *call) } - for _, tc := range toolCalls { - tc.Function.Index = p.index - p.index++ + if p.done() { + p.state = toolsState_Done + content = string(p.buffer) + p.buffer = []byte{} } - p.sb.Reset() - return toolCalls, "" + return calls, content } -// NewParser creates a new tool call parser from a template. It extracts the tool call format, -// prefix, and field names from the template to use for parsing tool calls from model output. -// -// Returns an error if the template does not contain valid tool call formatting. -func NewParser(templateToProcess *gotmpl.Template) (*Parser, error) { - parsed, err := template.Parse(templateToProcess.Root.String()) - if err != nil { - return nil, err +// findTag searches the buffer to find and handle a tool calling tag +// returning true if the tag was found and false otherwise, and +// a string content signaling any content that should be sent back to the user +func (p *Parser) findTag() (int, bool) { + // First check for complete substring anywhere in s + if i := bytes.Index(p.buffer, []byte(p.tag)); i > -1 { + return i, true } - tt, err := toolTemplate(parsed) - if err != nil { - return nil, err + // Then check for partial suffix overlap + max := min(len(p.buffer), len(p.tag)) + for i := max; i > 0; i-- { + if bytes.HasSuffix(p.buffer, []byte(p.tag[:i])) { + return len(p.buffer) - i, false + } } - - tp := toolPrefix(templateToProcess) - - name, arguments, err := extractToolArgs(tt) - if err != nil { - return nil, err - } - - return &Parser{ - tmpl: *tt, - sb: strings.Builder{}, - prefix: tp, - greedyParseJSON: true, - name: name, - arguments: arguments, - }, nil + return -1, false +} + +// parseToolCall finds the next complete tool call in the buffer +// incrementing n and advancing the buffer. +func (p *Parser) parseToolCall() *api.ToolCall { + var name string + var args map[string]any + var end int = len(p.buffer) + + // find tool name + var i int + for _, n := range p.names { + if i = bytes.Index(p.buffer, []byte(n)); i != -1 { + if i+len(n) < end { + name = n + end = i + len(n) + } + } + } + + if name == "" { + return nil + } + + if args, i = p.findArguments(); args == nil { + return nil + } + + if i > end { + end = i + } + + tc := &api.ToolCall{ + Function: api.ToolCallFunction{ + Name: name, + Arguments: args, + Index: p.n, + }, + } + + p.n++ + p.buffer = p.buffer[end:] + return tc +} + +// findArguments returns the first object that appears to be +// arguments and the position where the arguments end, returning nil and 0 if +// an invalid JSON object or non-arguments object is found first +func (p *Parser) findArguments() (map[string]any, int) { + if len(p.buffer) == 0 { + return nil, 0 + } + + var braces int + var start int = -1 + var end int + var object []byte + + // find any outer json object + for i, c := range p.buffer { + if c == '{' { + braces++ + if start == -1 { + start = i + } + } + + if c == '}' { + braces-- + if braces == 0 && start != -1 { + end = i + 1 + object = p.buffer[start:end] + break + } + } + } + + if braces > 0 { + return nil, 0 + } + + var data map[string]any + + // not valid json + if err := json.Unmarshal(object, &data); err != nil { + return nil, 0 + } + + var find func(obj any) map[string]any + find = func(obj any) map[string]any { + switch v := obj.(type) { + case map[string]any: + // check if the object keys are valid tool properties + // TODO (jmorganca): check only sets of properties that + // go together instead of the entire set + for _, prop := range p.properties { + if _, exists := v[prop]; exists { + return v + } + } + + for _, value := range v { + if result := find(value); result != nil { + return result + } + } + case []any: + for _, item := range v { + if result := find(item); result != nil { + return result + } + } + } + + return nil + } + + result := find(data) + if result != nil { + return result, end + } + + return nil, 0 +} + +// done checks if the parser is done parsing by looking +// for closing tag. currently only } and ] are supported +// for closing tags as {} or [] pairs may not always +// represent tool calls and we need to send the content back +func (p *Parser) done() bool { + var open, close rune + switch p.tag { + case "{": + open, close = '{', '}' + case "[": + open, close = '[', ']' + default: + return false + } + + var count int + for _, c := range p.buffer { + if c == byte(open) { + count++ + } else if c == byte(close) { + count-- + if count == 0 { + return true + } + } + } + + return false +} + +// Content returns any remaining content that +// should be sent to the user. This should be the empty string +// string unless the tag is { or [ and a tool call was not found +func (p *Parser) Content() string { + if p.n > 0 { + return "" + } + + if p.tag == "{" || p.tag == "[" { + return string(p.buffer) + } + + return "" } diff --git a/tools/tools_test.go b/tools/tools_test.go index 5fee8f57..67864168 100644 --- a/tools/tools_test.go +++ b/tools/tools_test.go @@ -1,673 +1,805 @@ package tools import ( - "bytes" - "encoding/json" - "fmt" - "os" - "path/filepath" - "strings" "testing" + "text/template" "github.com/google/go-cmp/cmp" - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" ) -func readFile(t *testing.T, base, name string) *bytes.Buffer { - t.Helper() - - bts, err := os.ReadFile(filepath.Join(base, name)) +func TestParser(t *testing.T) { + qwen, err := template.New("qwen").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`) if err != nil { - t.Fatal(err) + t.Fatalf("Failed to parse template: %v", err) } - return bytes.NewBuffer(bts) -} + deepseek, err := template.New("deepseek").Parse("{{if .ToolCalls}}<|tool▁calls▁begin|>{{range .ToolCalls}}<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|>{{end}}<|tool▁calls▁end|><|end▁of▁sentence|>{{end}}") + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + json, err := template.New("json").Parse(`{{if .ToolCalls}}{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + mistral, err := template.New("mistral").Parse(`{{if .ToolCalls}}[TOOL_CALLS] [{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}][/TOOL_CALLS]{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + list, err := template.New("list").Parse(`{{if .ToolCalls}}[{{range .ToolCalls}}{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}}]{{end}}`) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + tools := []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_temperature", + Description: "Retrieve the temperature for a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "format": { + Type: api.PropertyType{"string"}, + Description: "The format to return the temperature in", + Enum: []any{"fahrenheit", "celsius"}, + }, + "city": { + Type: api.PropertyType{"string"}, + Description: "The city to get the temperature for", + }, + }, + }, + }, + }, + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_conditions", + Description: "Retrieve the current weather conditions for a given location", + Parameters: struct { + Type string `json:"type"` + Defs any `json:"$defs,omitempty"` + Items any `json:"items,omitempty"` + Required []string `json:"required"` + Properties map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + } `json:"properties"` + }{ + Type: "object", + Properties: map[string]struct { + Type api.PropertyType `json:"type"` + Items any `json:"items,omitempty"` + Description string `json:"description"` + Enum []any `json:"enum,omitempty"` + }{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The location to get the weather conditions for", + }, + }, + }, + }, + }, + } -func TestParseJSONToolCalls(t *testing.T) { tests := []struct { - name string - input string - nameField string - argsField string - wantToolCalls []api.ToolCall - wantErr error - prefix string + name string + inputs []string + tmpl *template.Template + content string + calls []api.ToolCall }{ { - name: "valid single tool call", - input: `{"name": "test_tool", "arguments": {"arg1": "value1"}}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "no tool calls - just text", + inputs: []string{"Hello, how can I help you today?"}, + content: "Hello, how can I help you today?", + tmpl: qwen, + calls: nil, + }, + { + name: "empty input", + inputs: []string{""}, + content: "", + tmpl: qwen, + calls: nil, + }, + { + name: "tool call", + inputs: []string{`{"name": "get_conditions", "arguments": {"location": "San Francisco"}}`}, + content: "", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "test_tool", - Arguments: map[string]any{ - "arg1": "value1", + Index: 0, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "San Francisco", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "incomplete JSON", - input: `{"name": "test_tool", "arguments": {"arg1": `, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", - }, - { - name: "invalid JSON", - input: `not json at all`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errInvalidToolCall, - prefix: "", - }, - { - name: "missing required fields", - input: `{"other": "field"}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errInvalidToolCall, - prefix: "", - }, - { - name: "multiple tool calls in array", - input: `[ - {"name": "tool1", "arguments": {"arg1": 1}}, - {"name": "tool2", "arguments": {"arg2": "value"}} - ]`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "text before tool call", + inputs: []string{`Let me check the weather. {"name": "get_temperature", "arguments": {"city": "New York"}}`}, + content: "Let me check the weather. ", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), - }, - }, - }, - { - Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "New York", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "multiple tool calls without array", - input: ` - {"name": "tool1", "arguments": {"arg1": 1}}, - {"name": "tool2", "arguments": {"arg2": "value"}} - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "two tool calls in a list", + inputs: []string{`[TOOL_CALLS] [{"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}, {"name": "get_conditions", "arguments": {"location": "Tokyo"}}][/TOOL_CALLS]`}, + content: "", + tmpl: mistral, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + "format": "fahrenheit", }, }, }, { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "multiple tool calls with text after", - input: ` - {"name": "tool1", "arguments": {"arg1": 1}} text - {"name": "tool2", "arguments": {"arg2": "value"}} text - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "two tool calls", + inputs: []string{`Okay, let's call both tools! {"name": "get_temperature", "arguments": {"city": "London", "format": "fahrenheit"}}{"name": "get_conditions", "arguments": {"location": "Tokyo"}}`}, + content: "Okay, let's call both tools! ", + tmpl: qwen, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": float64(1), + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + "format": "fahrenheit", }, }, }, { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", }, { - name: "second tool call in array", - input: ` - , {"name": "tool2", "arguments": {"arg2": "value"}} - `, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "deepseek", + inputs: []string{"Wait, I need to call a tool<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n```json\n{\"city\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"}, + content: "Wait, I need to call a tool", + tmpl: deepseek, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool2", - Arguments: map[string]any{ - "arg2": "value", + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", - }, - // a bad JSON would not return any tool calls or content as it would always accumulate more - { - name: "unbalanced square brackets", - input: `[{"name": "tool1", "arguments": {"arg1": [1, 2}]`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", }, { - name: "incomplete square brackets", - input: `[{"name": "tool1", "arguments": {"arg1": [1, 2, 3`, - nameField: "name", - argsField: "arguments", - wantToolCalls: nil, - wantErr: errAccumulateMore, - prefix: "", - }, - { - name: "nested arrays in arguments", - input: `{"name": "tool1", "arguments": {"arg1": [1, 2, ["nested", "array"]]}}`, - nameField: "name", - argsField: "arguments", - wantToolCalls: []api.ToolCall{ + name: "deepseek incremental", + inputs: []string{ + "Wait", + ", I need", + " to call", + " a tool<|too", + "l▁calls▁begin", + "|>", + "<|tool▁call▁begin|>function<|tool▁sep|>get_temperature\n", + "```json\n", + "{\"city\": \"Tokyo\"}\n", + "```", + "<|tool▁c", "all▁end|>", + "<|tool▁calls▁end|>", + "<|end▁of▁sentence|>", + }, + content: "Wait, I need to call a tool", + tmpl: deepseek, + calls: []api.ToolCall{ { Function: api.ToolCallFunction{ - Name: "tool1", - Arguments: map[string]any{ - "arg1": []any{float64(1), float64(2), []any{"nested", "array"}}, + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", }, }, }, }, - wantErr: nil, - prefix: "", + }, + { + name: "json", + inputs: []string{ + "{", + "\"name\": \"get_temperature\",", + "\"arguments\": {", + "\"city\": \"Tokyo\"", + "}", + "}", + }, + content: "", + tmpl: json, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "Tokyo", + }, + }, + }, + }, + }, + { + name: "json maybe a tool call", + inputs: []string{ + "{", + "\"name\": \"get_temperature\",", + "\"arguments\": {", + }, + content: "", + tmpl: json, + calls: nil, + }, + { + name: "json not a tool call", + inputs: []string{ + "{", + "\"name\": \"search\", ", + "\"arguments\": {", + "\"query\": \"What is the capital of Canada?\"", + "}", + "}", + }, + content: "{\"name\": \"search\", \"arguments\": {\"query\": \"What is the capital of Canada?\"}}", + tmpl: json, + calls: nil, + }, + { + name: "json object followed by tool call", + inputs: []string{ + "{\"name\": \"jeff\"}", + "{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + }, + content: "{\"name\": \"jeff\"}{\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + tmpl: json, + }, + { + name: "json object followed by tool call split", + inputs: []string{ + "{\"name\": \"jeff\"} {", + "\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + }, + content: "{\"name\": \"jeff\"} {\"name\": \"get_conditions\", \"arguments\": {\"location\": \"San Francisco\"}}", + tmpl: json, + }, + { + name: "json code", + inputs: []string{ + "for { fmt.Println(\"hello\") }", + }, + content: "for { fmt.Println(\"hello\") }", + tmpl: json, + }, + { + name: "list multiple", + inputs: []string{ + "[", + "{", + "\"name\": \"get_temperature\", ", + "\"arguments\": {", + "\"city\": \"London\"", + "}", + "},", + "{", + "\"name\": \"get_conditions\", ", + "\"arguments\": {", + "\"location\": \"Tokyo\"", + "}", + "}]", + }, + content: "", + tmpl: list, + calls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Index: 0, + Name: "get_temperature", + Arguments: api.ToolCallFunctionArguments{ + "city": "London", + }, + }, + }, + { + Function: api.ToolCallFunction{ + Index: 1, + Name: "get_conditions", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo", + }, + }, + }, + }, + }, + { + name: "list partial", + inputs: []string{ + "[", + "{", + "\"name\": \"search\", ", + "\"arguments\": {", + "\"query\": \"What is the capital of Canada?\"", + "}", + "}", + }, + content: "", + tmpl: list, + calls: nil, + }, + { + name: "list not a tool call", + inputs: []string{ + "[special", + " del", + "ivery]", + }, + content: "[special delivery]", + tmpl: list, + calls: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotCalls, err := parseJSONToolCalls(tt.input, tt.nameField, tt.argsField, tt.prefix) + parser := NewParser(tt.tmpl, tools) - if err != tt.wantErr { - t.Errorf("parseJSONToolCalls() error = %v, want %v", err, tt.wantErr) + var calls []api.ToolCall + var content string + for _, input := range tt.inputs { + tcs, c := parser.Add(input) + calls = append(calls, tcs...) + content += c } - if len(gotCalls) != 0 && tt.wantErr != nil { - t.Errorf("parseJSONToolCalls() valid = %v, want %v", len(gotCalls) == 0, tt.wantErr == nil) + if content != tt.content { + t.Errorf("Expected content %q, got %q", tt.content, content) } - if diff := cmp.Diff(gotCalls, tt.wantToolCalls); diff != "" { - t.Errorf("parseJSONToolCalls() tool calls mismatch (-got +want):\n%s", diff) + if len(calls) != len(tt.calls) { + t.Fatalf("Expected %d tool calls, got %d", len(tt.calls), len(calls)) + } + + for i, want := range tt.calls { + if diff := cmp.Diff(calls[i], want); diff != "" { + t.Errorf("Tool call %d mismatch (-got +want):\n%s", i, diff) + } } }) } } -func TestParseToolCalls(t *testing.T) { - p := filepath.Join("testdata") - t1 := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "fahrenheit", - "location": "San Francisco, CA", - }, - }, - } - t2 := api.ToolCall{ - Function: api.ToolCallFunction{ - Name: "get_current_weather", - Arguments: api.ToolCallFunctionArguments{ - "format": "celsius", - "location": "Toronto, Canada", - }, - }, - } - - cases := []struct { - name string - model string - output string - expectedToolCall []api.ToolCall - expectedTokens string +func TestDone(t *testing.T) { + tests := []struct { + name string + tag string + buffer []byte + want bool }{ { - name: "mistral malformed json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_curren}]`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", + name: "empty", + tag: "", + buffer: []byte{}, + want: false, }, { - name: "mistral multiple tool calls without prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} ]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", + name: "empty", + tag: "", + buffer: []byte{}, + want: false, }, { - name: "mistral tool calls with text between no prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: `model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + name: "json open", + tag: "{", + buffer: []byte("{\"name\": \"get_weather\""), + want: false, }, { - name: "mistral valid json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", + name: "json closed", + tag: "{", + buffer: []byte("{\"name\": \"get_weather\"}"), + want: true, }, { - name: "mistral multiple tool calls with text between and prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] - model outputs more tokens here and then [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2, t1, t2}, - expectedTokens: "", + name: "json empty", + tag: "{", + buffer: []byte("{}"), + want: true, }, { - name: "mistral incomplete json with tool calls prefix", - model: "mistral", - output: `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: "", + name: "list open", + tag: "[", + buffer: []byte("[{\"name\": \"get_weather\""), + want: false, }, { - name: "mistral invalid tool call with explanatory text no prefix", - model: "mistral", - output: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: - - [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function: [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, + name: "list closed", + tag: "[", + buffer: []byte("[{\"name\": \"get_weather\"}]"), + want: true, }, { - name: "mistral tool calls without prefix", - model: "mistral", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "command r plus tool calls with json block format", - model: "command-r-plus", - output: "Action: ```json" + ` - [ - { - "tool_name": "get_current_weather", - "parameters": { - "format": "fahrenheit", - "location": "San Francisco, CA" - } - }, - { - "tool_name": "get_current_weather", - "parameters": { - "format": "celsius", - "location": "Toronto, Canada" - } - } - ] - ` + "```", - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "firefunction tool calls with functools prefix", - model: "firefunction", - output: ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "llama3 groq single tool call with xml tags", - model: "llama3-groq-tool-use", - output: ` - {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} - `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "xlam tool calls with wrapper object", - model: "xlam", - output: `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 single tool call with prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "qwen2.5 multiple tool calls with and without prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 plain text response no tool calls", - model: "qwen2.5", - output: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - expectedToolCall: []api.ToolCall{}, - expectedTokens: "The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", - }, - { - name: "qwen2.5 tool calls with trailing text", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens after call", - }, - { - name: "qwen2.5 tool calls with initial text", - model: "qwen2.5", - output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - }, - { - name: "qwen2.5 tool calls with prefix and trailing text", - model: "qwen2.5", - output: ` [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 tool calls with prefix and initial text", - model: "qwen2.5", - output: `some tokens before call [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}, {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] `, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens before call", - }, - { - name: "qwen2.5 tool calls without and with prefix", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "qwen2.5 tool calls without and with prefix and text between", - model: "qwen2.5", - output: `{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} some tokens between {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}} some tokens after call`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "some tokens between", - }, - { - name: "qwen2.5 tool calls without prefix and invalid tool call with other tokens", - model: "qwen2.5", - output: `hi [{"options": "foo"}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `hi [{"options": "foo"}]`, - }, - { - name: "qwen2.5 tool calls with prefix and invalid tool call", - model: "qwen2.5", - output: ` [{"options": "foo"}] `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "qwen3 tool call with think prefix and tool prefix (sent as a single token)", - model: "qwen3", - output: `Okay, let me think what tool we should use...{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "Okay, let me think what tool we should use...", - }, - { - name: "qwen3 tool call with think prefix, tool prefix, and whitespace (sent as separate tokens)", - model: "qwen3", - output: `Okay, let me think what tool we should use... { "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "Okay, let me think what tool we should use...", - }, - { - name: "qwen3 empty think prefix without tool prefix and invalid tool call", - model: "qwen3", - output: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ` {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - }, - { - name: "qwen3 empty think prefix with tool prefix and valid tool call", - model: "qwen3", - output: `{ "name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: ``, - }, - { - name: "qwen3 invalid tool call with fake tool prefix (single rune suffix match)", - model: "qwen3", - output: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `< fakeout {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} `, - }, - { - name: "qwen3 invalid tool call with partial tool prefix (multiple rune suffix match)", - model: "qwen3", - output: ``, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "qwen3 invalid tool call with malformed tool prefix", - model: "qwen3", - output: ``, - expectedToolCall: []api.ToolCall{}, - expectedTokens: ``, - }, - { - name: "model with prefix in template, no prefix in output", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model with prefix in template, prefix in output", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output", - model: "llama3.2", - output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, single tool call", - model: "llama3.2", - output: `{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}}`, - expectedToolCall: []api.ToolCall{t1}, - expectedTokens: "", - }, - { - name: "model without prefix in template, prefix in output, multiple tool calls in list", - model: "llama3.2", - output: ` [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model without prefix in template, prefix in output, individual tool calls", - model: "llama3.2", - output: ` {"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model with prefix in template, no prefix in output, tokens before", - model: "qwen2.5", - output: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `some tokens before [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, - }, - { - name: "model with prefix in template, prefix in output, tokens after", - model: "qwen2.5", - output: `[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, tokens after", - model: "llama3.2", - output: `[{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "", - }, - { - name: "model without prefix in template, no prefix in output, tokens before", - model: "llama3.2", - output: `some tokens before [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}]`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: `some tokens before`, - }, - { - name: "model without prefix in template, prefix in output, tokens after", - model: "llama3.2", - output: ` - [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: ``, - }, - { - name: "model without without prefix, match all jsons", - model: "llama3.2", - output: `model outputs some text [{"name": "get_current_weather", "parameters": {"format":"fahrenheit","location":"San Francisco, CA"}} {"name": "get_current_weather", "parameters": {"format":"celsius","location":"Toronto, Canada"}}] some tokens after`, - expectedToolCall: []api.ToolCall{t1, t2}, - expectedTokens: "model outputs some text", - }, - { - name: "model flushes tokens if tool call doesn't match", - model: "llama3.2", - output: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `{ "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}`, - }, - { - name: "model flushes tokens if tool call doesn't match array", - model: "llama3.2", - output: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, - expectedToolCall: []api.ToolCall{}, - expectedTokens: `[ { "user": {"id": 12345, "name": "Alice", "preferences": {"theme": "dark", "notifications": true}, "stats": {"points": 987, "level": 42}}}]`, + name: "list empty", + tag: "[", + buffer: []byte("[]"), + want: true, }, } - var tools []api.Tool - if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil { - t.Fatal(err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Parser{ + tag: tt.tag, + buffer: tt.buffer, + } + got := parser.done() + if got != tt.want { + t.Errorf("done() = %t, want %t", got, tt.want) + } + }) + } +} + +func TestContent(t *testing.T) { + tests := []struct { + name string + tag string + content []byte + want string + n int + }{ + { + name: "empty", + content: []byte{}, + tag: "{", + want: "", + n: 0, + }, + { + name: "tag", + tag: "", + content: []byte("{\"name\": \"get_temperature\""), + want: "", + n: 0, + }, + { + name: "json object", + tag: "{", + content: []byte("{\"name\": \"get_temperature\"}"), + want: "{\"name\": \"get_temperature\"}", + n: 0, + }, + { + name: "json object after called", + tag: "{", + content: []byte("{\"hello\": \"world\"}"), + want: "{\"hello\": \"world\"}", + n: 0, + }, + { + name: "json object after called", + tag: "{", + content: []byte("{\"hello\": \"world\"}"), + want: "", + n: 1, + }, + { + name: "list", + tag: "[", + content: []byte("[{\"name\": \"get_temperature\"}]"), + want: "[{\"name\": \"get_temperature\"}]", + n: 0, + }, + { + name: "code", + tag: "{", + content: []byte("{ fmt.Println(\"hello\")"), + want: "{ fmt.Println(\"hello\")", + n: 0, + }, } - var messages []api.Message - if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil { - t.Fatal(err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := &Parser{ + tag: tt.tag, + buffer: tt.content, + n: tt.n, + } + got := parser.Content() + if got != tt.want { + t.Errorf("Content() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestFindTag(t *testing.T) { + cases := []struct { + name string + buffer []byte + tag string + i int + found bool + }{ + { + name: "no overlap", + buffer: []byte("hello world"), + tag: "", + i: -1, + found: false, + }, + { + name: "full overlap", + buffer: []byte(""), + tag: "", + i: 0, + found: true, + }, + { + name: "whitespace", + buffer: []byte(" \n {\"name\": \"bob\"}"), + tag: "", + i: 4, + found: true, + }, + { + name: "over", + buffer: []byte("{\"name\""), + tag: "", + i: 0, + found: true, + }, + { + name: "partial overlap", + buffer: []byte("text "), + tag: "", + i: 5, + found: true, + }, + { + name: "overlap with extra", + buffer: []byte(""), + tag: "", + i: 0, + found: true, + }, + { + name: "delimiter longer than string", + buffer: []byte(""), + tag: "", + i: -1, + found: false, + }, + { + name: "empty string", + buffer: []byte{}, + tag: "", + i: -1, + found: false, + }, + { + name: "single char overlap", + buffer: []byte("test<"), + tag: "", + i: 4, + found: false, + }, + { + name: "partial tool call", + buffer: []byte("hello ", + i: 6, + found: false, + }, + { + name: "square bracket", + buffer: []byte("calling tools: ["), + tag: "[", + i: 15, + found: true, + }, + { + name: "bracket", + buffer: []byte("{\"name\": \"bob\""), + tag: "{", + i: 0, + found: true, + }, + { + name: "bracket with whitespace", + buffer: []byte("\n\n{\n\"name\": \"bob\""), + tag: "{", + i: 2, + found: true, + }, } for _, tt := range cases { t.Run(tt.name, func(t *testing.T) { - tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String()) - if err != nil { - t.Fatal(err) + parser := &Parser{ + tag: tt.tag, + buffer: tt.buffer, + n: 0, + } + i, found := parser.findTag() + if i != tt.i { + t.Errorf("findTag(%q, %q) = %d; want %d", tt.buffer, tt.tag, i, tt.i) + } + if found != tt.found { + t.Errorf("findTag(%q, %q) = %t; want %t", tt.buffer, tt.tag, found, tt.found) + } + }) + } +} + +func TestFindArguments(t *testing.T) { + tests := []struct { + name string + buffer []byte + want map[string]any + }{ + { + name: "empty string", + buffer: []byte{}, + want: nil, + }, + { + name: "whitespace only", + buffer: []byte(" \n\t "), + want: nil, + }, + { + name: "unbalanced braces - missing closing", + buffer: []byte(`{"format": "fahrenheit", "location": "San Francisco"`), + want: nil, + }, + { + name: "unbalanced braces - extra closing", + buffer: []byte(`{"format": "fahrenheit"}}`), + want: map[string]any{ + "format": "fahrenheit", + }, + }, + { + name: "invalid JSON", + buffer: []byte(`{format: fahrenheit, location: "San Francisco"}`), + want: nil, + }, + { + name: "valid json", + buffer: []byte(`{"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "valid arguments with special tokens", + buffer: []byte(`[tool]get_temperature[args]{"format": "fahrenheit", "location": "San Francisco, CA"}[end]`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "valid arguments in array", + buffer: []byte(`[{"arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "nested deep", + buffer: []byte(`{"function": {"name": "get_temperature", "arguments": {"format": "fahrenheit", "location": "San Francisco, CA"}}}`), + want: map[string]any{ + "format": "fahrenheit", + "location": "San Francisco, CA", + }, + }, + { + name: "one arg", + buffer: []byte(`get_weather({"location": "San Francisco, CA"})`), + want: map[string]any{ + "location": "San Francisco, CA", + }, + }, + { + name: "two args", + buffer: []byte(`[{"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}, {"name": "get_weather", "arguments": {"location": "San Francisco, CA", "format": "fahrenheit"}}]`), + want: map[string]any{ + "location": "San Francisco, CA", + "format": "fahrenheit", + }, + }, + { + name: "deepseek", + buffer: []byte("<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{\"location\": \"Tokyo\"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>"), + want: map[string]any{ + "location": "Tokyo", + }, + }, + } + + for _, tt := range tests { + parser := &Parser{ + buffer: tt.buffer, + properties: []string{"format", "location"}, + } + + t.Run(tt.name, func(t *testing.T) { + got, _ := parser.findArguments() + + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("scanArguments() args mismatch (-got +want):\n%s", diff) } - - t.Run("template", func(t *testing.T) { - actual := &bytes.Buffer{} // Create new buffer for each test - if err := tmpl.Execute(actual, template.Values{Tools: tools, Messages: messages}); err != nil { - t.Fatal(err) - } - - if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" { - t.Errorf("mismatch (-got +want):\n%s", diff) - } - }) - - t.Run("parse", func(t *testing.T) { - tp, err := NewParser(tmpl.Template) - if err != nil { - t.Fatal(err) - } - got := []api.ToolCall{} - var gotTokens strings.Builder - - tokens := strings.Fields(tt.output) - for _, tok := range tokens { - s := " " + tok - - toolCalls, content := tp.Add(s) - if len(content) > 0 { - gotTokens.WriteString(content) - } else if len(toolCalls) > 0 { - got = append(got, toolCalls...) - } - } - - // Compare tool calls if we expect any - if diff := cmp.Diff(got, tt.expectedToolCall); diff != "" { - t.Errorf("tool calls mismatch (-got +want):\n%s", diff) - } - - // Compare tokens if we expect any - stripped := strings.TrimSpace(gotTokens.String()) - if diff := cmp.Diff(stripped, tt.expectedTokens); diff != "" { - t.Log("actualTokens", stripped, "expectedTokens", tt.expectedTokens) - t.Errorf("tokens mismatch (-got +want):\n%s", diff) - } - }) }) } } diff --git a/tools/tools_utils.go b/tools/tools_utils.go deleted file mode 100644 index b6f80729..00000000 --- a/tools/tools_utils.go +++ /dev/null @@ -1,222 +0,0 @@ -package tools - -import ( - "bytes" - "encoding/json" - "errors" - "log/slog" - "slices" - "strings" - gotmpl "text/template" - "text/template/parse" - - "github.com/ollama/ollama/api" - "github.com/ollama/ollama/template" -) - -// extractToolCallsFormat traverses a template AST to find text that follows a ".ToolCalls" condition. -// It walks the template nodes looking for if-statements containing ".ToolCalls" and extracts any -// immediate text nodes that follow. This is used to identify tool call prefixes and formatting. -// -// Returns: -// - string: The extracted text following the first ".ToolCalls" condition found -// - bool: Whether a ".ToolCalls" condition was found in the template -func extractToolCallsFormat(tmpl *gotmpl.Template) (string, bool) { - if tmpl == nil || tmpl.Tree == nil { - slog.Debug("template or tree is nil") - return "", false - } - - var result string - var found bool - - var walk func(nodes []parse.Node) - walk = func(nodes []parse.Node) { - for _, node := range nodes { - if found { - return - } - - switch n := node.(type) { - case *parse.IfNode: - if isToolCallsNode(n) { - // Collect immediate TextNode(s) at start of IfNode's list - var sb strings.Builder - for _, innerNode := range n.List.Nodes { - if tn, ok := innerNode.(*parse.TextNode); ok { - sb.Write(tn.Text) - } else { - // Stop at first non-text node - break - } - } - result = sb.String() - found = true - return - } - // Recurse into child nodes - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.ListNode: - walk(n.Nodes) - case *parse.RangeNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - case *parse.WithNode: - walk(n.List.Nodes) - if n.ElseList != nil { - walk(n.ElseList.Nodes) - } - default: - // Continue to next node - continue - } - } - } - - walk(tmpl.Tree.Root.Nodes) - return result, found -} - -// isToolCallsNode detects if a node's condition includes ".ToolCalls" -func isToolCallsNode(n *parse.IfNode) bool { - for _, cmd := range n.Pipe.Cmds { - for _, arg := range cmd.Args { - if field, ok := arg.(*parse.FieldNode); ok { - if slices.Contains(field.Ident, "ToolCalls") { - return true - } - } - } - } - return false -} - -func toolPrefix(tmpl *gotmpl.Template) string { - tokenText, ok := extractToolCallsFormat(tmpl) - if !ok { - return "" - } - tokenText = strings.TrimSpace(tokenText) - tokenText = strings.ReplaceAll(tokenText, "\r", "") - tokenText = strings.ReplaceAll(tokenText, "\n", " ") - - return tokenText -} - -// toolTemplate creates a subtree from the node that ranges over .ToolCalls -// -// Returns: -// - *gotmpl.Template: The subtree containing the .ToolCalls range -// - error: Error if parsing failed -func toolTemplate(t *template.Template) (*gotmpl.Template, error) { - tmpl := t.Subtree(func(n parse.Node) bool { - if t, ok := n.(*parse.RangeNode); ok { - return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls") - } - - return false - }) - - if tmpl == nil { - return nil, errors.New("failed to find tool template") - } - - return tmpl, nil -} - -// suffixOverlap returns the index in s where the longest suffix overlap with prefix begins -// -// Returns: -// - int: The starting index in s where the suffix overlap begins -func suffixOverlap(s, prefix string) int { - max := min(len(prefix), len(s)) - for i := max; i > 0; i-- { - if strings.HasSuffix(s, prefix[:i]) { - return len(s) - i - } - } - return -1 -} - -// extractToolArgs executes a template with a known tool call format to extract the name and arguments -// -// Returns: -// - string: The name of the tool call -// - string: The arguments of the tool call -// - error: Error if parsing failed -func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) { - var b bytes.Buffer - if err := tmpl.Execute(&b, map[string][]api.ToolCall{ - "ToolCalls": { - { - Function: api.ToolCallFunction{ - Name: "@@name@@", - Arguments: api.ToolCallFunctionArguments{ - "@@argument@@": 1, - }, - }, - }, - }, - }); err != nil { - return "", "", err - } - - // Extract JSON object between curly braces - // JSON arrays are also valid as they will not be repeated in the template - output := b.String() - start := strings.Index(output, "{") - end := strings.LastIndex(output, "}") - if start == -1 || end == -1 || start > end { - return "", "", errors.New("no valid JSON object found in template output") - } - jsonStr := output[start : end+1] - - var obj map[string]any - if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { - return "", "", err - } - - // Find name and arguments fields - for k, v := range obj { - if str, ok := v.(string); ok && str == "@@name@@" { - name = k - } else if _, ok := v.(map[string]any); ok { - arguments = k - } - } - - if name == "" || arguments == "" { - slog.Debug("missing required fields in tool call template", "name", name, "arguments", arguments) - return "", "", errors.New("missing required fields in tool call template") - } - - return name, arguments, nil -} - -// collect recursively traverses an object to collect all nested maps -// -// Returns: -// - []map[string]any: A slice of all nested maps found in the object -func collect(obj any) []map[string]any { - var all []map[string]any - switch o := obj.(type) { - case map[string]any: - all = append(all, o) - for _, v := range o { - all = append(all, collect(v)...) - } - case []any: - for _, v := range o { - all = append(all, collect(v)...) - } - default: - return nil - } - - return all -} diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go deleted file mode 100644 index e346117a..00000000 --- a/tools/tools_utils_test.go +++ /dev/null @@ -1,497 +0,0 @@ -package tools - -import ( - "testing" - gotmpl "text/template" - - "github.com/ollama/ollama/template" -) - -func TestExtractToolCallsFormat(t *testing.T) { - cases := []struct { - name string - template string - want string - found bool - }{ - { - name: "nil template", - template: "", - want: "", - found: false, - }, - { - name: "basic tool call with text", - template: "{{if .ToolCalls}}Hello world{{end}}", - want: "Hello world", - found: true, - }, - { - name: "tool call with json format", - template: "{{if .ToolCalls}}```json\n{{end}}", - want: "```json\n", - found: true, - }, - { - name: "tool call in range", - template: "{{range .ToolCalls}}tool: {{.}}{{end}}", - want: "", - found: false, - }, - { - name: "tool call with multiple text nodes", - template: "{{if .ToolCalls}}First text{{if .Something}}inner{{end}}Second text{{end}}", - want: "First text", - found: true, - }, - { - name: "nested if without tool calls", - template: "{{if .Something}}{{if .OtherThing}}text{{end}}{{end}}", - want: "", - found: false, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tc.template) - if err != nil && tc.template != "" { - t.Fatalf("failed to parse template: %v", err) - } - - got, found := extractToolCallsFormat(tmpl) - if got != tc.want { - t.Errorf("got text %q, want %q", got, tc.want) - } - if found != tc.found { - t.Errorf("got found %v, want %v", found, tc.found) - } - }) - } -} - -func TestToolPrefix(t *testing.T) { - cases := []struct { - name string - template string - want string - }{ - { - name: "basic tool call with action prefix", - template: "{{if .ToolCalls}}Action: ```json{{end}}", - want: "Action: ```json", - }, - { - name: "incomplete functools bracket", - template: "{{if .ToolCalls}}functools[{{end}}", - want: "functools[", - }, - { - name: "tool call with angle brackets", - template: "{{if .ToolCalls}}Hello, world! {{end}}", - want: "Hello, world! ", - }, - { - name: "multiple tool call formats", - template: "{{if .ToolCalls}}[tool_call] {{end}}", - want: "[tool_call] ", - }, - { - name: "single angle bracket tool call", - template: "{{if .ToolCalls}}{{end}}", - want: "", - }, - { - name: "incomplete angle bracket after tool call", - template: "{{if .ToolCalls}}[tool_call] <{{end}}", - want: "[tool_call] <", - }, - { - name: "angle bracket prefix with tool call", - template: "{{if .ToolCalls}}> {{end}}", - want: "> ", - }, - { - name: "uppercase tool call with incomplete bracket", - template: "{{if .ToolCalls}}[TOOL_CALL] [{{end}}", - want: "[TOOL_CALL] [", - }, - { - name: "uppercase tool call with adjacent bracket", - template: "{{if .ToolCalls}}[TOOL_CALL][{{end}}", - want: "[TOOL_CALL][", - }, - { - name: "tool call with pipe delimiters", - template: "{{if .ToolCalls}}<|tool_call|>{{end}}", - want: "<|tool_call|>", - }, - { - name: "tool with no prefix", - template: "{{if .ToolCalls}}{{end}}", - want: "", - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - got := toolPrefix(tmpl) - if got != tt.want { - t.Errorf("ToolToken(%q) = %q; want %q", tt.template, got, tt.want) - } - }) - } -} - -func TestToolTemplate(t *testing.T) { - cases := []struct { - name string - template string - want bool - }{ - { - name: "basic tool call range", - template: "{{range .ToolCalls}}test{{end}}", - want: true, - }, - { - name: "no tool calls", - template: "{{range .Other}}test{{end}}", - want: false, - }, - { - name: "nested tool calls", - template: "{{range .Outer}}{{range .ToolCalls}}test{{end}}{{end}}", - want: true, - }, - { - name: "empty template", - template: "", - want: false, - }, - { - name: "tool calls in if statement", - template: "{{if .ToolCalls}}test{{end}}", - want: false, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - parsed, err := template.Parse(tmpl.Root.String()) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - _, err = toolTemplate(parsed) - if err != nil && tt.want { - t.Errorf("toolTemplate() = %v; want %v", err, tt.want) - } - }) - } -} - -func TestSuffixOverlap(t *testing.T) { - cases := []struct { - name string - s string - d string - want int - }{ - { - name: "no overlap", - s: "hello world", - d: "", - want: -1, - }, - { - name: "full overlap", - s: "", - d: "", - want: 0, - }, - { - name: "partial overlap", - s: "text ", - d: "", - want: 5, - }, - { - name: "delimiter longer than string", - s: "", - d: "", - want: -1, - }, - { - name: "empty string", - s: "", - d: "", - want: -1, - }, - { - name: "empty delimiter", - s: "", - d: "", - want: -1, - }, - { - name: "single char overlap", - s: "test<", - d: "", - want: 4, - }, - { - name: "partial tool call", - s: "hello ", - want: 6, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got := suffixOverlap(tt.s, tt.d) - if got != tt.want { - t.Errorf("suffixOverlap(%q, %q) = %d; want %d", tt.s, tt.d, got, tt.want) - } - }) - } -} - -func TestExtractToolArgs(t *testing.T) { - cases := []struct { - name string - template string - wantName string - wantArgs string - wantErr bool - }{ - { - name: "basic tool call", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`, - wantName: "name", - wantArgs: "parameters", - wantErr: false, - }, - { - name: "tool call with whitespace", - template: `{{range .ToolCalls}} - {"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}} -{{end}}`, - wantName: "name", - wantArgs: "parameters", - wantErr: false, - }, - { - name: "tool call with extra content", - template: `Before {{range .ToolCalls}} -{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "no tool calls", - template: `{{if .Something}}no tools here{{end}}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "empty template", - template: ``, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "prefix within tool call", - template: `{{- if .ToolCalls }} -{{ range .ToolCalls }} - -{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} -{{ end }}{{- end }}`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "JSON array", - template: `{{ range .ToolCalls }} -[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`, - wantName: "name", - wantArgs: "arguments", - wantErr: false, - }, - { - name: "invalid JSON", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "missing name field", - template: `{{ range .ToolCalls }} -{"parameters": {{ .Function.Arguments }}}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "missing arguments field", - template: `{{ range .ToolCalls }} -{"name": "{{ .Function.Name }}"}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - { - name: "malformed JSON", - template: `{{ range .ToolCalls }} -{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`, - wantName: "", - wantArgs: "", - wantErr: true, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - tmpl, err := gotmpl.New("test").Parse(tt.template) - if err != nil { - t.Fatalf("failed to parse template: %v", err) - } - - gotName, gotArgs, err := extractToolArgs(tmpl) - if (err != nil) != tt.wantErr { - t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr) - return - } - if err != nil { - return - } - - if gotName != tt.wantName { - t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName) - } - if gotArgs != tt.wantArgs { - t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs) - } - }) - } -} - -func TestCollect(t *testing.T) { - cases := []struct { - name string - obj any - want []map[string]any - }{ - { - name: "simple map", - obj: map[string]any{ - "key": "value", - }, - want: []map[string]any{ - {"key": "value"}, - }, - }, - { - name: "nested map", - obj: map[string]any{ - "outer": map[string]any{ - "inner": "value", - }, - }, - want: []map[string]any{ - {"outer": map[string]any{"inner": "value"}}, - {"inner": "value"}, - }, - }, - { - name: "array of maps", - obj: []any{ - map[string]any{"key1": "val1"}, - map[string]any{"key2": "val2"}, - }, - want: []map[string]any{ - {"key1": "val1"}, - {"key2": "val2"}, - }, - }, - { - name: "deeply nested", - obj: map[string]any{ - "l1": map[string]any{ - "l2": map[string]any{ - "l3": "value", - }, - }, - }, - want: []map[string]any{ - {"l1": map[string]any{"l2": map[string]any{"l3": "value"}}}, - {"l2": map[string]any{"l3": "value"}}, - {"l3": "value"}, - }, - }, - { - name: "non-map value", - obj: "string", - want: nil, - }, - } - - for _, tt := range cases { - t.Run(tt.name, func(t *testing.T) { - got := collect(tt.obj) - if len(got) != len(tt.want) { - t.Errorf("collect() got %d maps, want %d", len(got), len(tt.want)) - return - } - - // Compare each map in the result - for i := range tt.want { - if !mapsEqual(got[i], tt.want[i]) { - t.Errorf("collect() map[%d] = %v, want %v", i, got[i], tt.want[i]) - } - } - }) - } -} - -// mapsEqual compares two maps for deep equality -func mapsEqual(m1, m2 map[string]any) bool { - if len(m1) != len(m2) { - return false - } - for k, v1 := range m1 { - v2, ok := m2[k] - if !ok { - return false - } - switch val1 := v1.(type) { - case map[string]any: - val2, ok := v2.(map[string]any) - if !ok || !mapsEqual(val1, val2) { - return false - } - default: - if v1 != v2 { - return false - } - } - } - return true -}