server/internal/registry: take over pulls from server package (#9485)

This commit replaces the old pull implementation in the server package
with the new, faster, more robust pull implementation in the registry
package.

The new endpoint, and now the remove endpoint too, are behind the
feature gate "client2", which is enabled only by setting the
OLLAMA_EXPERIMENT environment variable to include "client2".

Currently, the progress indication is wired to perform the same as the
previous implementation to avoid making changes to the CLI, and because
the status reports happen at the start of the download, and the end of
the write to disk, the progress indication is not as smooth as it could
be. This is a known issue and will be addressed in a future change.

This implementation may be ~0.5-1.0% slower in rare cases, depending on
network and disk speed, but is generally MUCH faster and more robust
than its predecessor in all other cases.
This commit is contained in:
Blake Mizerany
2025-03-05 14:48:18 -08:00
committed by GitHub
parent cae5d4d4ea
commit e2252d0fc6
11 changed files with 370 additions and 52 deletions

View File

@@ -1,17 +1,27 @@
package registry
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"net"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
"sync"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/testutil"
"golang.org/x/tools/txtar"
_ "embed"
)
type panicTransport struct{}
@@ -30,7 +40,7 @@ type bytesResetter interface {
Reset()
}
func newTestServer(t *testing.T) *Local {
func newTestServer(t *testing.T, upstreamRegistry http.HandlerFunc) *Local {
t.Helper()
dir := t.TempDir()
err := os.CopyFS(dir, os.DirFS("testdata/models"))
@@ -41,10 +51,25 @@ func newTestServer(t *testing.T) *Local {
if err != nil {
t.Fatal(err)
}
client := panicOnRoundTrip
if upstreamRegistry != nil {
s := httptest.NewTLSServer(upstreamRegistry)
t.Cleanup(s.Close)
tr := s.Client().Transport.(*http.Transport).Clone()
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "tcp", s.Listener.Addr().String())
}
client = &http.Client{Transport: tr}
}
rc := &ollama.Registry{
Cache: c,
HTTPClient: panicOnRoundTrip,
HTTPClient: client,
Mask: "example.com/library/_:latest",
}
l := &Local{
Client: rc,
Logger: testutil.Slogger(t),
@@ -85,7 +110,7 @@ func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
func TestServerDelete(t *testing.T) {
check := testutil.Checker(t)
s := newTestServer(t)
s := newTestServer(t, nil)
_, err := s.Client.ResolveLocal("smol")
check(err)
@@ -127,8 +152,105 @@ func TestServerDelete(t *testing.T) {
}
}
// registryTXT holds the raw bytes of testdata/registry.txt, embedded into
// the test binary at build time.
//
//go:embed testdata/registry.txt
var registryTXT []byte
// registryFS returns the embedded registry fixture as a file system. The
// txtar archive is parsed on the first call only; later calls reuse the same
// result. It panics if the archive cannot be converted to an fs.FS.
var registryFS = sync.OnceValue(func() fs.FS {
	// A checkout on Windows may carry \r\n line endings, which the txtar
	// parser does not cope with, so normalize them to \n before parsing.
	normalized := bytes.ReplaceAll(registryTXT, []byte("\r\n"), []byte("\n"))
	archive := txtar.Parse(normalized)

	// NOTE(review): this writes the archive comment to stdout on first use;
	// it looks like leftover debug output — confirm before removing.
	fmt.Printf("%q\n", archive.Comment)

	fsys, err := txtar.FS(archive)
	if err == nil {
		return fsys
	}
	panic(err)
})
// TestServerPull drives the /api/pull endpoint of a test server whose
// upstream registry is faked: two manifest paths answer with canned error
// responses and every other path is served from registryFS. Each case sends
// one request and checks either the streamed status lines or the error
// response.
func TestServerPull(t *testing.T) {
	files := http.FileServerFS(registryFS())

	// upstream stands in for the remote registry.
	upstream := func(w http.ResponseWriter, r *http.Request) {
		path := r.URL.Path
		if path == "/v2/library/BOOM/manifests/latest" {
			w.WriteHeader(999)
			io.WriteString(w, `{"error": "boom"}`)
			return
		}
		if path == "/v2/library/unknown/manifests/latest" {
			w.WriteHeader(404)
			io.WriteString(w, `{"errors": [{"code": "MANIFEST_UNKNOWN", "message": "manifest unknown"}]}`)
			return
		}
		t.Logf("serving file: %s", path)
		files.ServeHTTP(w, r)
	}
	s := newTestServer(t, upstream)

	// assertBody requires a 200 response and then checks the body against
	// spec, one line at a time: a plain line must occur somewhere in the
	// body, while a line starting with "!" must not occur. Surrounding
	// whitespace on each spec line is ignored.
	assertBody := func(rec *httptest.ResponseRecorder, spec string) {
		t.Helper()

		if rec.Code != 200 {
			t.Fatalf("Code = %d; want 200", rec.Code)
		}

		body := rec.Body.String()
		t.Logf("got:\n%s", body)

		for line := range strings.Lines(spec) {
			needle, forbidden := strings.CutPrefix(strings.TrimSpace(line), "!")
			needle = strings.TrimSpace(needle)
			present := strings.Contains(body, needle)
			switch {
			case !forbidden && !present:
				t.Fatalf("! missing %q in body", needle)
			case forbidden && present:
				t.Fatalf("! unexpected %q in body", needle)
			}
		}
	}

	// The upstream answers with a bogus status code for this model.
	resp := s.send(t, "POST", "/api/pull", `{"model": "BOOM"}`)
	assertBody(resp, `
{"status":"pulling manifest"}
{"status":"error: request error https://example.com/v2/library/BOOM/manifests/latest: registry responded with status 999: boom"}
`)

	// A pull served from the fixture file system reports progress for each
	// layer and ends in success.
	resp = s.send(t, "POST", "/api/pull", `{"model": "smol"}`)
	assertBody(resp, `
{"status":"pulling manifest"}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3}
{"status":"pulling","digest":"sha256:68e0ec597aee59d35f8dc44942d7b17d471ade10d3aca07a5bb7177713950312","total":5,"completed":5}
{"status":"pulling","digest":"sha256:ca3d163bab055381827226140568f3bef7eaac187cebd76878e0b63e9e442356","total":3,"completed":3}
{"status":"verifying layers"}
{"status":"writing manifest"}
{"status":"success"}
`)

	// The upstream answers 404 MANIFEST_UNKNOWN for this model.
	resp = s.send(t, "POST", "/api/pull", `{"model": "unknown"}`)
	assertBody(resp, `
{"status":"pulling manifest"}
{"status":"error: model \"unknown\" not found"}
`)

	// Wrong method and malformed or empty bodies yield error responses.
	resp = s.send(t, "DELETE", "/api/pull", `{"model": "smol"}`)
	checkErrorResponse(t, resp, 405, "method_not_allowed", "method not allowed")

	resp = s.send(t, "POST", "/api/pull", `!`)
	checkErrorResponse(t, resp, 400, "bad_request", "invalid character '!' looking for beginning of value")

	resp = s.send(t, "POST", "/api/pull", ``)
	checkErrorResponse(t, resp, 400, "bad_request", "empty request body")

	// An invalid model name is reported in the stream, and none of the later
	// pull stages may show up.
	resp = s.send(t, "POST", "/api/pull", `{"model": "://"}`)
	assertBody(resp, `
{"status":"pulling manifest"}
{"status":"error: invalid or missing name: \"\""}
!verifying
!writing
!success
`)
}
// TestServerUnknownPath sends a DELETE request to /api/unknown and expects a
// 404 "not_found" error response.
func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t)
s := newTestServer(t, nil)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
}