Convert Safetensors to an Ollama model (#2824)

This commit is contained in:
Patrick Devine
2024-03-06 21:01:51 -08:00
committed by GitHub
parent 0ded7fdc4b
commit 2c017ca441
9 changed files with 3083 additions and 153 deletions

View File

@@ -1,6 +1,7 @@
package server
import (
"archive/zip"
"bytes"
"context"
"crypto/sha256"
@@ -23,6 +24,7 @@ import (
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/convert"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
@@ -316,7 +318,24 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
c.Args = blobPath
}
bin, err := os.Open(realpath(modelFileDir, c.Args))
pathName := realpath(modelFileDir, c.Args)
ggufName, err := convertSafetensors(name, pathName)
if err != nil {
switch {
case errors.Is(err, zip.ErrFormat):
// it's not a safetensor archive
default:
return err
}
}
if ggufName != "" {
pathName = ggufName
defer os.RemoveAll(ggufName)
}
bin, err := os.Open(pathName)
if err != nil {
// not a file on disk so must be a model reference
modelpath := ParseModelPath(c.Args)
@@ -592,6 +611,73 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
return nil
}
// convertSafetensors treats fn as a zip archive containing a safetensors
// model, unpacks it into a temporary directory, and hands the unpacked files
// to the convert package to produce a GGUF file.
//
// name is passed through to convert.WriteGGUF. On success the path returned
// by convert.WriteGGUF is returned; the caller owns that file. The temporary
// extraction directory is always removed before returning.
//
// If fn cannot be opened as a zip archive the error from zip.OpenReader is
// returned unmodified, so callers can detect "not an archive" with
// errors.Is(err, zip.ErrFormat).
func convertSafetensors(name, fn string) (string, error) {
	r, err := zip.OpenReader(fn)
	if err != nil {
		return "", err
	}
	defer r.Close()

	tempDir, err := os.MkdirTemp("", "ollama-convert")
	if err != nil {
		return "", err
	}
	defer os.RemoveAll(tempDir)

	// Extraction lives in a helper so that each entry's file handles are
	// released as soon as that entry is done, including on error.
	for _, f := range r.File {
		if err := extractSafetensorsEntry(f, tempDir); err != nil {
			return "", err
		}
	}

	params, err := convert.GetParams(tempDir)
	if err != nil {
		return "", err
	}

	supportedArchs := []string{
		"MistralForCausalLM",
	}

	// Every architecture listed by the model must be one we can convert.
	for _, arch := range params.Architectures {
		if !slices.Contains(supportedArchs, arch) {
			return "", fmt.Errorf("this safetensors model is not yet supported")
		}
	}

	t, err := convert.GetSafeTensors(tempDir)
	if err != nil {
		return "", err
	}

	vocab, err := convert.LoadTokens(tempDir)
	if err != nil {
		return "", err
	}

	return convert.WriteGGUF(name, t, params, vocab)
}

// extractSafetensorsEntry writes a single zip entry beneath destDir.
//
// Entries whose name would resolve outside destDir (for example via "../"
// components) are rejected, since the archive is user-supplied. Directory
// entries are created as directories, and parent directories of file entries
// are created on demand.
func extractSafetensorsEntry(f *zip.File, destDir string) error {
	target := filepath.Join(destDir, f.Name)

	// filepath.Join cleans the path, so a relative path that starts with
	// ".." means the entry escaped destDir.
	rel, err := filepath.Rel(destDir, target)
	if err != nil || rel == ".." || (len(rel) >= 3 && rel[:3] == ".."+string(filepath.Separator)) {
		return fmt.Errorf("zip entry %q escapes the extraction directory", f.Name)
	}

	if f.FileInfo().IsDir() {
		return os.MkdirAll(target, 0o755)
	}
	if err := os.MkdirAll(filepath.Dir(target), 0o755); err != nil {
		return err
	}

	rc, err := f.Open()
	if err != nil {
		return err
	}
	defer rc.Close()

	outFile, err := os.OpenFile(target, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
	if err != nil {
		return err
	}

	if _, err := io.Copy(outFile, rc); err != nil {
		outFile.Close()
		return err
	}

	// Close reports deferred write errors; surface them rather than handing
	// a truncated file to the converter.
	return outFile.Close()
}
func CopyModel(src, dest string) error {
srcModelPath := ParseModelPath(src)
srcPath, err := srcModelPath.GetManifestPath()