no iterator

This commit is contained in:
Michael Yang
2024-04-25 08:53:08 -07:00
parent 7ffe45734d
commit 4d0d0fa383
3 changed files with 34 additions and 85 deletions

View File

@@ -30,7 +30,6 @@ import (
"github.com/ollama/ollama/server/envconfig"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/types/ordered"
"github.com/ollama/ollama/version"
)
@@ -344,7 +343,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
switch c.Name {
case "model", "adapter":
var baseLayers *ordered.Map[*Layer, *llm.GGML]
var baseLayers []*layerWithGGML
if name := model.ParseName(c.Args); name.IsValid() {
baseLayers, err = parseFromModel(ctx, name, fn)
if err != nil {
@@ -377,70 +376,51 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, m
return fmt.Errorf("invalid model reference: %s", c.Args)
}
var err2 error
var tempfiles []*os.File
// TODO(mxyng): replace with rangefunc
baseLayers.Items()(func(layer *Layer, ggml *llm.GGML) bool {
if quantization != "" && ggml != nil && ggml.Name() == "gguf" {
for _, baseLayer := range baseLayers {
if quantization != "" && baseLayer.GGML != nil && baseLayer.GGML.Name() == "gguf" {
ftype, err := llm.ParseFileType(quantization)
if err != nil {
err2 = err
return false
return err
}
filetype := ggml.KV().FileType()
filetype := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, filetype) {
err2 = errors.New("quantization is only supported for F16 and F32 models")
return false
return errors.New("quantization is only supported for F16 and F32 models")
}
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", filetype, quantization)})
blob, err := GetBlobsPath(layer.Digest)
blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil {
err2 = err
return false
return err
}
temp, err := os.CreateTemp(filepath.Dir(blob), quantization)
if err != nil {
err2 = err
return false
return err
}
tempfiles = append(tempfiles, temp)
defer temp.Close()
defer os.Remove(temp.Name())
if err := llm.Quantize(blob, temp.Name(), ftype); err != nil {
err2 = err
return false
return err
}
layer, err = NewLayer(temp, layer.MediaType)
baseLayer.Layer, err = NewLayer(temp, baseLayer.Layer.MediaType)
if err != nil {
err2 = err
return false
return err
}
}
if ggml != nil {
config.ModelFormat = cmp.Or(config.ModelFormat, ggml.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, ggml.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(ggml.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, ggml.KV().FileType())
config.ModelFamilies = append(config.ModelFamilies, ggml.KV().Architecture())
if baseLayer.GGML != nil {
config.ModelFormat = cmp.Or(config.ModelFormat, baseLayer.GGML.Name())
config.ModelFamily = cmp.Or(config.ModelFamily, baseLayer.GGML.KV().Architecture())
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType())
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
}
layers = append(layers, layer)
return true
})
for _, tempfile := range tempfiles {
defer tempfile.Close()
defer os.Remove(tempfile.Name())
}
if err2 != nil {
return err2
layers = append(layers, baseLayer.Layer)
}
case "license", "template", "system":
blob := strings.NewReader(c.Args)