JSON mode: add `"format" as an api parameter (#1051)

* add `"format": "json"` as an API parameter
---------
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
This commit is contained in:
Jeffrey Morgan
2023-11-09 16:44:02 -08:00
committed by GitHub
parent 5b39503bcd
commit 5cba29b9d6
5 changed files with 97 additions and 9 deletions

View File

@@ -27,6 +27,34 @@ import (
"github.com/jmorganca/ollama/format"
)
const jsonGrammar = `
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
`
//go:embed llama.cpp/*/build/*/bin/*
var llamaCppEmbed embed.FS
@@ -497,7 +525,7 @@ type prediction struct {
const maxBufferSize = 512 * format.KiloByte
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, fn func(api.GenerateResponse)) error {
func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string, format string, fn func(api.GenerateResponse)) error {
prevConvo, err := llm.Decode(ctx, prevContext)
if err != nil {
return err
@@ -532,6 +560,10 @@ func (llm *llama) Predict(ctx context.Context, prevContext []int, prompt string,
"stop": llm.Stop,
}
if format == "json" {
request["grammar"] = jsonGrammar
}
// Handling JSON marshaling with special characters unescaped.
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)

View File

@@ -14,7 +14,7 @@ import (
)
type LLM interface {
Predict(context.Context, []int, string, func(api.GenerateResponse)) error
Predict(context.Context, []int, string, string, func(api.GenerateResponse)) error
Embedding(context.Context, string) ([]float64, error)
Encode(context.Context, string) ([]int, error)
Decode(context.Context, []int) (string, error)