mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-15 18:27:08 +00:00
server/internal/registry: take over pulls from server package (#9485)
This commit replaces the old pull implementation in the server package with the new, faster, more robust pull implementation in the registry package. The new endpoint, and now the remove endpoint too, are behind the feature gate "client2" enabled only by setting the OLLAMA_EXPERIMENT environment variable include "client2". Currently, the progress indication is wired to perform the same as the previous implementation to avoid making changes to the CLI, and because the status reports happen at the start of the download, and the end of the write to disk, the progress indication is not as smooth as it could be. This is a known issue and will be addressed in a future change. This implementation may be ~0.5-1.0% slower in rare cases, depending on network and disk speed, but is generally MUCH faster and more robust than the its predecessor in all other cases.
This commit is contained in:
@@ -7,10 +7,14 @@ import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
@@ -109,6 +113,8 @@ func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/delete":
|
||||
return false, s.handleDelete(rec, r)
|
||||
case "/api/pull":
|
||||
return false, s.handlePull(rec, r)
|
||||
default:
|
||||
if s.Fallback != nil {
|
||||
s.Fallback.ServeHTTP(rec, r)
|
||||
@@ -214,6 +220,97 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
return s.Prune()
|
||||
}
|
||||
|
||||
type progressUpdateJSON struct {
|
||||
Status string `json:"status"`
|
||||
Digest blob.Digest `json:"digest,omitempty,omitzero"`
|
||||
Total int64 `json:"total,omitempty,omitzero"`
|
||||
Completed int64 `json:"completed,omitempty,omitzero"`
|
||||
}
|
||||
|
||||
func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
maybeFlush := func() {
|
||||
fl, _ := w.(http.Flusher)
|
||||
if fl != nil {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
defer maybeFlush()
|
||||
|
||||
var mu sync.Mutex
|
||||
enc := json.NewEncoder(w)
|
||||
enc.Encode(progressUpdateJSON{Status: "pulling manifest"})
|
||||
|
||||
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
|
||||
Update: func(l *ollama.Layer, n int64, err error) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// TODO(bmizerany): coalesce these updates; writing per
|
||||
// update is expensive
|
||||
enc.Encode(progressUpdateJSON{
|
||||
Digest: l.Digest,
|
||||
Status: "pulling",
|
||||
Total: l.Size,
|
||||
Completed: n,
|
||||
})
|
||||
},
|
||||
})
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
// TODO(bmizerany): continue to support non-streaming responses
|
||||
done <- s.Client.Pull(ctx, p.model())
|
||||
}()
|
||||
|
||||
func() {
|
||||
t := time.NewTicker(100 * time.Millisecond)
|
||||
defer t.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
mu.Lock()
|
||||
maybeFlush()
|
||||
mu.Unlock()
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
var status string
|
||||
if errors.Is(err, ollama.ErrModelNotFound) {
|
||||
status = fmt.Sprintf("error: model %q not found", p.model())
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
} else {
|
||||
status = fmt.Sprintf("error: %v", err)
|
||||
enc.Encode(progressUpdateJSON{Status: status})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// These final updates are not strictly necessary, because they have
|
||||
// already happened at this point. Our pull handler code used to do
|
||||
// these steps after, not during, the pull, and they were slow, so we
|
||||
// wanted to provide feedback to users what was happening. For now, we
|
||||
// keep them to not jar users who are used to seeing them. We can phase
|
||||
// them out with a new and nicer UX later. One without progress bars
|
||||
// and digests that no one cares about.
|
||||
enc.Encode(progressUpdateJSON{Status: "verifying layers"})
|
||||
enc.Encode(progressUpdateJSON{Status: "writing manifest"})
|
||||
enc.Encode(progressUpdateJSON{Status: "success"})
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
var v T
|
||||
err := json.NewDecoder(r).Decode(&v)
|
||||
|
||||
@@ -1,17 +1,27 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
"golang.org/x/tools/txtar"
|
||||
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
type panicTransport struct{}
|
||||
@@ -30,7 +40,7 @@ type bytesResetter interface {
|
||||
Reset()
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *Local {
|
||||
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
||||
@@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
client := panicOnRoundTrip
|
||||
if upstreamRegistry != nil {
|
||||
s := httptest.NewTLSServer(upstreamRegistry)
|
||||
t.Cleanup(s.Close)
|
||||
tr := s.Client().Transport.(*http.Transport).Clone()
|
||||
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
|
||||
var d net.Dialer
|
||||
return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
|
||||
}
|
||||
client = &http.Client{Transport: tr}
|
||||
}
|
||||
|
||||
rc := &ollama.Registry{
|
||||
Cache: c,
|
||||
HTTPClient: panicOnRoundTrip,
|
||||
HTTPClient: client,
|
||||
Mask: "example.com/library/_:latest",
|
||||
}
|
||||
|
||||
l := &Local{
|
||||
Client: rc,
|
||||
Logger: testutil.Slogger(t),
|
||||
@@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
|
||||
func TestServerDelete(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
s := newTestServer(t)
|
||||
s := newTestServer(t, nil)
|
||||
|
||||
_, err := s.Client.ResolveLocal("smol")
|
||||
check(err)
|
||||
@@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
//go:embed testdata/registry.txt
|
||||
var registryTXT []byte
|
||||
|
||||
var registryFS = sync.OnceValue(func() fs.FS {
|
||||
// Txtar gets hung up on \r\n line endings, so we need to convert them
|
||||
// to \n when parsing the txtar on Windows.
|
||||
data := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
|
||||
a := txtar.Parse(data)
|
||||
fmt.Printf("%q\n", a.Comment)
|
||||
fsys, err := txtar.FS(a)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return fsys
|
||||
})
|
||||
|
||||
func TestServerPull(t *testing.T) {
|
||||
modelsHandler := http.FileServerFS(registryFS())
|
||||
s := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/v2/library/BOOM/manifests/latest":
|
||||
w.WriteHeader(999)
|
||||
io.WriteString(w, `{"error": "boom"}`)
|
||||
case "/v2/library/unknown/manifests/latest":
|
||||
w.WriteHeader(404)
|
||||
io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
|
||||
default:
|
||||
t.Logf("serving file: %s", r.URL.Path)
|
||||
modelsHandler.ServeHTTP(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
checkResponse := func(got *httptest.ResponseRecorder, wantlines string) {
|
||||
t.Helper()
|
||||
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
gotlines := got.Body.String()
|
||||
t.Logf("got:\n%s", gotlines)
|
||||
for want := range strings.Lines(wantlines) {
|
||||
want = strings.TrimSpace(want)
|
||||
want, unwanted := strings.CutPrefix(want, "!")
|
||||
want = strings.TrimSpace(want)
|
||||
if !unwanted && !strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! missing %q in body", want)
|
||||
}
|
||||
if unwanted && strings.Contains(gotlines, want) {
|
||||
t.Fatalf("! unexpected %q in body", want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
got := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
|
||||
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
|
||||
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
|
||||
{"status":"verifying layers"}
|
||||
{"status":"writing manifest"}
|
||||
{"status":"success"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: model \"unknown\" not found"}
|
||||
`)
|
||||
|
||||
got = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
|
||||
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `!`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", ``)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||
|
||||
got = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
|
||||
checkResponse(got, `
|
||||
{"status":"pulling manifest"}
|
||||
{"status":"error: invalid or missing name: \"\""}
|
||||
|
||||
!verifying
|
||||
!writing
|
||||
!success
|
||||
`)
|
||||
}
|
||||
|
||||
func TestServerUnknownPath(t *testing.T) {
|
||||
s := newTestServer(t)
|
||||
s := newTestServer(t, nil)
|
||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
}
|
||||
|
||||
22
server/internal/registry/testdata/registry.txt
vendored
Normal file
22
server/internal/registry/testdata/registry.txt
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
-- v2/library/smol/manifests/latest --
|
||||
{
|
||||
"schemaVersion": 2,
|
||||
"mediaType": "application/vnd.docker.distribution.manifest.v2+json",
|
||||
"config": {
|
||||
"mediaType": "application/vnd.docker.container.image.v1+json",
|
||||
"digest": "sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356",
|
||||
"size": 3
|
||||
},
|
||||
"layers": [
|
||||
{
|
||||
"mediaType": "application/vnd.ollama.image.model",
|
||||
"digest": "sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312",
|
||||
"size": 5
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
-- v2/library/smol/blobs/sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312 --
|
||||
GGUF
|
||||
-- v2/library/smol/blobs/sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356 --
|
||||
{}
|
||||
Reference in New Issue
Block a user