server: collect nested tool call objects when parsing (#5824)

This commit is contained in:
Jeffrey Morgan
2024-07-22 12:38:03 -04:00
committed by GitHub
parent 80ee9b5e47
commit b3e5491e41
5 changed files with 120 additions and 13 deletions

View File

@@ -344,6 +344,10 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
}
}
if name == "" || arguments == "" {
return nil, false
}
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
@@ -361,23 +365,40 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
return nil, false
} else {
offset += int(decoder.InputOffset())
objs = append(objs, obj)
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
objs = append(objs, collect(obj)...)
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
var call api.ToolCall
for k, v := range kv {
switch k {
case name:
call.Function.Name = v.(string)
case arguments:
call.Function.Arguments = v.(map[string]any)
}
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
toolCalls = append(toolCalls, call)
}
return toolCalls, len(toolCalls) > 0