allow specifying zero values in modelfile

This commit is contained in:
Bruce MacDonald
2023-08-02 17:07:53 -04:00
committed by GitHub
5 changed files with 101 additions and 25 deletions

View File

@@ -32,8 +32,8 @@ type Model struct {
ModelPath string
Template string
System string
Digest string
Options api.Options
Digest string
Options map[string]interface{}
}
func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@@ -135,7 +135,7 @@ func GetModel(name string) (*Model, error) {
}
model := &Model{
Name: mp.GetFullTagname(),
Name: mp.GetFullTagname(),
Digest: manifest.Config.Digest,
}
@@ -176,12 +176,10 @@ func GetModel(name string) (*Model, error) {
}
defer params.Close()
var opts api.Options
if err = json.NewDecoder(params).Decode(&opts); err != nil {
// parse model options parameters into a map so that we can see which fields have been specified explicitly
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
model.Options = opts
}
}
@@ -442,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil
}
// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
opts := api.DefaultOptions()
typeOpts := reflect.TypeOf(opts)
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
@@ -455,7 +455,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
}
}
valueOpts := reflect.ValueOf(&opts).Elem()
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
@@ -468,25 +468,26 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
return nil, fmt.Errorf("invalid float value %s", vals)
}
field.SetFloat(floatVal)
out[key] = floatVal
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 0)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
field.SetInt(intVal)
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
field.SetBool(boolVal)
out[key] = boolVal
case reflect.String:
field.SetString(vals[0])
out[key] = vals[0]
case reflect.Slice:
field.Set(reflect.ValueOf(vals))
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
@@ -494,7 +495,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
}
}
bts, err := json.Marshal(opts)
bts, err := json.Marshal(out)
if err != nil {
return nil, err
}

View File

@@ -15,7 +15,6 @@ import (
"sync"
"time"
"dario.cat/mergo"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
@@ -61,12 +60,13 @@ func GenerateHandler(c *gin.Context) {
}
opts := api.DefaultOptions()
if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
if err := opts.FromMap(req.Options); err != nil {
log.Printf("could not merge model options: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}