Convert Safetensors to an Ollama model (#2824)

This commit is contained in:
Patrick Devine
2024-03-06 21:01:51 -08:00
committed by GitHub
parent 0ded7fdc4b
commit 2c017ca441
9 changed files with 3083 additions and 153 deletions

View File

@@ -1,6 +1,7 @@
package cmd
import (
"archive/zip"
"bytes"
"context"
"crypto/ed25519"
@@ -87,22 +88,82 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = filepath.Join(filepath.Dir(filename), path)
}
bin, err := os.Open(path)
fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
continue
} else if err != nil {
return err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return err
// TODO make this work w/ adapters
if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return err
}
defer os.RemoveAll(tf.Name())
zf := zip.NewWriter(tf)
files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}
if len(files) == 0 {
return fmt.Errorf("no safetensors files were found in '%s'", path)
}
// add the safetensor config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))
for _, fn := range files {
f, err := os.Open(fn)
if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
continue
} else if err != nil {
return err
}
fi, err := f.Stat()
if err != nil {
return err
}
h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}
h.Name = filepath.Base(fn)
h.Method = zip.Store
w, err := zf.CreateHeader(h)
if err != nil {
return err
}
_, err = io.Copy(w, f)
if err != nil {
return err
}
}
if err := zf.Close(); err != nil {
return err
}
if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
}
bin.Seek(0, io.SeekStart)
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
digest, err := createBlob(cmd, client, path)
if err != nil {
return err
}
@@ -141,6 +202,26 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
bin, err := os.Open(path)
if err != nil {
return "", err
}
defer bin.Close()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
bin.Seek(0, io.SeekStart)
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
return "", err
}
return digest, nil
}
func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {