Allow setting parameters in the REPL (#1294)

This commit is contained in:
Patrick Devine
2023-11-29 09:56:42 -08:00
committed by GitHub
parent 63097607b2
commit cde31cb220
3 changed files with 154 additions and 87 deletions

View File

@@ -14,7 +14,6 @@ import (
"net/url"
"os"
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
@@ -426,7 +425,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
formattedParams, err := formatParams(params)
formattedParams, err := api.FormatParams(params)
if err != nil {
return err
}
@@ -581,64 +580,6 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
return newLayer, nil
}
// formatParams converts specified parameter options to their correct types
func formatParams(params map[string][]string) (map[string]interface{}, error) {
opts := api.Options{}
valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
typeOpts := reflect.TypeOf(opts) // types of the fields in the options struct
// build map of json struct tags to their types
jsonOpts := make(map[string]reflect.StructField)
for _, field := range reflect.VisibleFields(typeOpts) {
jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
if jsonTag != "" {
jsonOpts[jsonTag] = field
}
}
out := make(map[string]interface{})
// iterate params and set values based on json struct tags
for key, vals := range params {
if opt, ok := jsonOpts[key]; ok {
field := valueOpts.FieldByName(opt.Name)
if field.IsValid() && field.CanSet() {
switch field.Kind() {
case reflect.Float32:
floatVal, err := strconv.ParseFloat(vals[0], 32)
if err != nil {
return nil, fmt.Errorf("invalid float value %s", vals)
}
out[key] = float32(floatVal)
case reflect.Int:
intVal, err := strconv.ParseInt(vals[0], 10, 64)
if err != nil {
return nil, fmt.Errorf("invalid int value %s", vals)
}
out[key] = intVal
case reflect.Bool:
boolVal, err := strconv.ParseBool(vals[0])
if err != nil {
return nil, fmt.Errorf("invalid bool value %s", vals)
}
out[key] = boolVal
case reflect.String:
out[key] = vals[0]
case reflect.Slice:
// TODO: only string slices are supported right now
out[key] = vals
default:
return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
}
}
}
}
return out, nil
}
func getLayerDigests(layers []*LayerReader) ([]string, error) {
var digests []string
for _, l := range layers {