diff --git a/tools/tools_utils.go b/tools/tools_utils.go index 48531b78..b6f80729 100644 --- a/tools/tools_utils.go +++ b/tools/tools_utils.go @@ -166,31 +166,26 @@ func extractToolArgs(tmpl *gotmpl.Template) (name, arguments string, err error) return "", "", err } - var obj any - err = json.Unmarshal(b.Bytes(), &obj) - if err != nil { + // Extract JSON object between curly braces + // JSON arrays are also valid as they will not be repeated in the template + output := b.String() + start := strings.Index(output, "{") + end := strings.LastIndex(output, "}") + if start == -1 || end == -1 || start > end { + return "", "", errors.New("no valid JSON object found in template output") + } + jsonStr := output[start : end+1] + + var obj map[string]any + if err := json.Unmarshal([]byte(jsonStr), &obj); err != nil { return "", "", err } - var objs []map[string]any - switch v := obj.(type) { - case map[string]any: - objs = []map[string]any{v} - case []map[string]any: - objs = v - case []any: - objs = collect(v) - } - if len(objs) == 0 { - return "", "", errors.New("no template objects found") - } - - // find the keys that correspond to the name and arguments fields - for k, v := range objs[0] { - switch v.(type) { - case string: + // Find name and arguments fields + for k, v := range obj { + if str, ok := v.(string); ok && str == "@@name@@" { name = k - case map[string]any: + } else if _, ok := v.(map[string]any); ok { arguments = k } } diff --git a/tools/tools_utils_test.go b/tools/tools_utils_test.go index 769183b7..e346117a 100644 --- a/tools/tools_utils_test.go +++ b/tools/tools_utils_test.go @@ -271,74 +271,99 @@ func TestExtractToolArgs(t *testing.T) { cases := []struct { name string template string - want string - ok bool + wantName string + wantArgs string + wantErr bool }{ { - name: "basic tool call with text after", - template: `{{if .ToolCalls}}tool response{{end}}`, - want: "tool response", - ok: true, + name: "basic tool call", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "parameters": {{ .Function.Arguments }}}{{ end }}`, + wantName: "name", + wantArgs: "parameters", + wantErr: false, }, { - name: "tool call with mixed content after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, + name: "tool call with whitespace", + template: `{{range .ToolCalls}} + {"name": "{{.Function.Name}}", "parameters": {{.Function.Arguments}}} +{{end}}`, + wantName: "name", + wantArgs: "parameters", + wantErr: false, }, { - name: "tool call with no text after", - template: `{{if .ToolCalls}}{{.Something}}{{end}}`, - want: "", - ok: true, - }, - { - name: "nested tool call", - template: `{{if .Something}}{{if .ToolCalls}}[TOOL_CALL]{{end}}{{end}}`, - want: "[TOOL_CALL]", - ok: true, + name: "tool call with extra content", + template: `Before {{range .ToolCalls}} +{"name": "{{.Function.Name}}", "arguments": {{.Function.Arguments}}}{{end}} After`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { name: "no tool calls", template: `{{if .Something}}no tools here{{end}}`, - want: "", - ok: false, + wantName: "", + wantArgs: "", + wantErr: true, }, { name: "empty template", template: ``, - want: "", - ok: false, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "multiple tool calls sections", - template: `{{if .ToolCalls}}first{{end}}{{if .ToolCalls}}second{{end}}`, - want: "first", - ok: true, + name: "prefix within tool call", + template: `{{- if .ToolCalls }} +{{ range .ToolCalls }} + +{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}} +{{ end }}{{- end }}`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { - name: "range over tool calls", - template: `{{if .ToolCalls}}{{range .ToolCalls}}tool{{end}}{{end}}`, - want: "", - ok: true, + name: "JSON array", + template: `{{ range .ToolCalls }} +[{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}]{{ end }}`, + wantName: "name", + wantArgs: "arguments", + wantErr: false, }, { - name: "tool calls with pipe delimiters", - template: `{{if .ToolCalls}}<|tool|>{{end}}`, - want: "<|tool|>", - ok: true, + name: "invalid JSON", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}, invalid}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "tool calls with nested template", - template: `{{if .ToolCalls}}{{template "tool" .}}{{end}}`, - want: "", - ok: true, + name: "missing name field", + template: `{{ range .ToolCalls }} +{"parameters": {{ .Function.Arguments }}}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, { - name: "tool calls with whitespace variations", - template: `{{if .ToolCalls}} tool {{end}}`, - want: " tool ", - ok: true, + name: "missing arguments field", + template: `{{ range .ToolCalls }} +{"name": "{{ .Function.Name }}"}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, + }, + { + name: "malformed JSON", + template: `{{ range .ToolCalls }} +{"name": {{ .Function.Name }}, "arguments": {{ .Function.Arguments }}{{ end }}`, + wantName: "", + wantArgs: "", + wantErr: true, }, } @@ -349,12 +374,20 @@ func TestExtractToolArgs(t *testing.T) { t.Fatalf("failed to parse template: %v", err) } - got, ok := extractToolCallsFormat(tmpl) - if got != tt.want { - t.Errorf("TextAfterToolCalls() got = %q, want %q", got, tt.want) + gotName, gotArgs, err := extractToolArgs(tmpl) + if (err != nil) != tt.wantErr { + t.Errorf("extractToolArgs() error = %v, wantErr %v", err, tt.wantErr) + return } - if ok != tt.ok { - t.Errorf("TextAfterToolCalls() ok = %v, want %v", ok, tt.ok) + if err != nil { + return + } + + if gotName != tt.wantName { + t.Errorf("extractToolArgs() gotName = %q, want %q", gotName, tt.wantName) + } + if gotArgs != tt.wantArgs { + t.Errorf("extractToolArgs() gotArgs = %q, want %q", gotArgs, tt.wantArgs) } }) }