Update the /api/create endpoint to use JSON (#7935)

Replaces `POST /api/create` to use JSON instead of a Modelfile.

This is a breaking change.
This commit is contained in:
Patrick Devine
2024-12-31 18:02:30 -08:00
committed by GitHub
parent 459d822b51
commit 86a622cbdc
17 changed files with 1523 additions and 1094 deletions

View File

@@ -1,13 +1,10 @@
package cmd
import (
"archive/zip"
"bufio"
"bytes"
"context"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"encoding/pem"
"errors"
@@ -46,10 +43,7 @@ import (
"github.com/ollama/ollama/version"
)
var (
errModelNotFound = errors.New("no Modelfile or safetensors files found")
errModelfileNotFound = errors.New("specified Modelfile wasn't found")
)
var errModelfileNotFound = errors.New("specified Modelfile wasn't found")
func getModelfileName(cmd *cobra.Command) (string, error) {
fn, _ := cmd.Flags().GetString("file")
@@ -102,68 +96,52 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return err
}
home, err := os.UserHomeDir()
status := "gathering model components"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
req, err := modelfile.CreateRequest()
if err != nil {
return err
}
spinner.Stop()
status := "transferring model data"
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer p.Stop()
req.Name = args[0]
quantize, _ := cmd.Flags().GetString("quantize")
if quantize != "" {
req.Quantize = quantize
}
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
path := modelfile.Commands[i].Args
if path == "~" {
path = home
} else if strings.HasPrefix(path, "~/") {
path = filepath.Join(home, path[2:])
}
if !filepath.IsAbs(path) {
path = filepath.Join(filepath.Dir(filename), path)
}
fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && modelfile.Commands[i].Name == "model" {
continue
} else if err != nil {
if len(req.Files) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Files {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
if fi.IsDir() {
// this is likely a safetensors or pytorch directory
// TODO make this work w/ adapters
tempfile, err := tempZipFiles(path)
if err != nil {
return err
}
defer os.RemoveAll(tempfile)
path = tempfile
}
digest, err := createBlob(cmd, client, path, spinner)
if err != nil {
return err
}
modelfile.Commands[i].Args = "@" + digest
fileMap[filepath.Base(f)] = digest
}
req.Files = fileMap
}
if len(req.Adapters) > 0 {
fileMap := map[string]string{}
for f, digest := range req.Adapters {
if _, err := createBlob(cmd, client, f, digest, p); err != nil {
return err
}
fileMap[filepath.Base(f)] = digest
}
req.Adapters = fileMap
}
bars := make(map[string]*progress.Bar)
fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" {
spinner.Stop()
bar, ok := bars[resp.Digest]
if !ok {
bar = progress.NewBar(fmt.Sprintf("pulling %s...", resp.Digest[7:19]), resp.Total, resp.Completed)
@@ -183,145 +161,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}
quantize, _ := cmd.Flags().GetString("quantize")
request := api.CreateRequest{Name: args[0], Modelfile: modelfile.String(), Quantize: quantize}
if err := client.Create(cmd.Context(), &request, fn); err != nil {
if err := client.Create(cmd.Context(), req, fn); err != nil {
return err
}
return nil
}
func tempZipFiles(path string) (string, error) {
tempfile, err := os.CreateTemp("", "ollama-tf")
func createBlob(cmd *cobra.Command, client *api.Client, path string, digest string, p *progress.Progress) (string, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return "", err
}
defer tempfile.Close()
detectContentType := func(path string) (string, error) {
f, err := os.Open(path)
if err != nil {
return "", err
}
defer f.Close()
var b bytes.Buffer
b.Grow(512)
if _, err := io.CopyN(&b, f, 512); err != nil && !errors.Is(err, io.EOF) {
return "", err
}
contentType, _, _ := strings.Cut(http.DetectContentType(b.Bytes()), ";")
return contentType, nil
}
glob := func(pattern, contentType string) ([]string, error) {
matches, err := filepath.Glob(pattern)
if err != nil {
return nil, err
}
for _, safetensor := range matches {
if ct, err := detectContentType(safetensor); err != nil {
return nil, err
} else if ct != contentType {
return nil, fmt.Errorf("invalid content type: expected %s for %s", ct, safetensor)
}
}
return matches, nil
}
var files []string
if st, _ := glob(filepath.Join(path, "model*.safetensors"), "application/octet-stream"); len(st) > 0 {
// safetensors files might be unresolved git lfs references; skip if they are
// covers model-x-of-y.safetensors, model.fp32-x-of-y.safetensors, model.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapters.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapters.safetensors
files = append(files, st...)
} else if st, _ := glob(filepath.Join(path, "adapter_model.safetensors"), "application/octet-stream"); len(st) > 0 {
// covers adapter_model.safetensors
files = append(files, st...)
} else if pt, _ := glob(filepath.Join(path, "pytorch_model*.bin"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers pytorch_model-x-of-y.bin, pytorch_model.fp32-x-of-y.bin, pytorch_model.bin
files = append(files, pt...)
} else if pt, _ := glob(filepath.Join(path, "consolidated*.pth"), "application/zip"); len(pt) > 0 {
// pytorch files might also be unresolved git lfs references; skip if they are
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errModelNotFound
}
// add configuration files, json files are detected as text/plain
js, err := glob(filepath.Join(path, "*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
// bert models require a nested config.json
// TODO(mxyng): merge this with the glob above
js, err = glob(filepath.Join(path, "**/*.json"), "text/plain")
if err != nil {
return "", err
}
files = append(files, js...)
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
zipfile := zip.NewWriter(tempfile)
defer zipfile.Close()
for _, file := range files {
f, err := os.Open(file)
if err != nil {
return "", err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return "", err
}
zfi, err := zip.FileInfoHeader(fi)
if err != nil {
return "", err
}
zfi.Name, err = filepath.Rel(path, file)
if err != nil {
return "", err
}
zf, err := zipfile.CreateHeader(zfi)
if err != nil {
return "", err
}
if _, err := io.Copy(zf, f); err != nil {
return "", err
}
}
return tempfile.Name(), nil
}
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path)
bin, err := os.Open(realPath)
if err != nil {
return "", err
}
@@ -334,18 +187,11 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *pr
}
fileSize := fileInfo.Size()
hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return "", err
}
if _, err := bin.Seek(0, io.SeekStart); err != nil {
return "", err
}
var pw progressWriter
status := "transferring model data 0%"
spinner.SetMessage(status)
status := fmt.Sprintf("copying file %s 0%%", digest)
spinner := progress.NewSpinner(status)
p.Add(status, spinner)
defer spinner.Stop()
done := make(chan struct{})
defer close(done)
@@ -356,15 +202,14 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *pr
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n.Load()/fileSize)))
spinner.SetMessage(fmt.Sprintf("copying file %s %d%%", digest, int(100*pw.n.Load()/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
spinner.SetMessage(fmt.Sprintf("copying file %s 100%%", digest))
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err
}

View File

@@ -13,11 +13,9 @@ import (
"strings"
"github.com/spf13/cobra"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes"
)
@@ -213,10 +211,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err
}
req := &api.CreateRequest{
Name: args[1],
Modelfile: buildModelfile(opts),
}
req := NewCreateRequest(args[1], opts)
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
@@ -459,39 +454,25 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
}
func buildModelfile(opts runOptions) string {
var f parser.File
f.Commands = append(f.Commands, parser.Command{Name: "model", Args: cmp.Or(opts.ParentModel, opts.Model)})
func NewCreateRequest(name string, opts runOptions) *api.CreateRequest {
req := &api.CreateRequest{
Name: name,
From: cmp.Or(opts.ParentModel, opts.Model),
}
if opts.System != "" {
f.Commands = append(f.Commands, parser.Command{Name: "system", Args: opts.System})
req.System = opts.System
}
keys := maps.Keys(opts.Options)
slices.Sort(keys)
for _, k := range keys {
v := opts.Options[k]
var cmds []parser.Command
switch t := v.(type) {
case []string:
for _, s := range t {
cmds = append(cmds, parser.Command{Name: k, Args: s})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", t)})
}
f.Commands = append(f.Commands, cmds...)
if len(opts.Options) > 0 {
req.Parameters = opts.Options
}
for _, msg := range opts.Messages {
if strings.Contains(msg.Content, "\"") {
msg.Content = `"""` + msg.Content + `"""`
}
f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)})
if len(opts.Messages) > 0 {
req.Messages = opts.Messages
}
return f.String()
return req
}
func normalizeFilePath(fp string) string {

View File

@@ -3,10 +3,7 @@ package cmd
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
)
func TestExtractFilenames(t *testing.T) {
@@ -53,56 +50,3 @@ d:\path with\spaces\seven.JPEG inbetween7 c:\users\jdoe\eight.png inbetween8
assert.Contains(t, res[9], "ten.PNG")
assert.Contains(t, res[9], "E:")
}
func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
},
Options: map[string]any{
"temperature": 0.9,
"seed": 42,
"penalize_newline": false,
"stop": []string{"hi", "there"},
},
}
t.Run("model", func(t *testing.T) {
expect := `FROM hork
SYSTEM You are part horse and part shark, but all hork. Do horklike things
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE assistant Yes it is true, I am half horse, half shark.
`
actual := buildModelfile(opts)
if diff := cmp.Diff(expect, actual); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
t.Run("parent model", func(t *testing.T) {
opts.ParentModel = "horseshark"
expect := `FROM horseshark
SYSTEM You are part horse and part shark, but all hork. Do horklike things
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop hi
PARAMETER stop there
PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE assistant Yes it is true, I am half horse, half shark.
`
actual := buildModelfile(opts)
if diff := cmp.Diff(expect, actual); diff != "" {
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
}