server/internal: replace model delete API with new registry handler. (#9347)

This commit introduces a new API implementation for handling
interactions with the registry and the local model cache. The new API is
located in server/internal/registry. The package name is "registry" and
should be considered temporary; it is hidden and not bleeding outside of
the server package. As the commits roll in, we'll start consuming more
of the API and then let reverse osmosis take effect, at which point it
will surface closer to the root level packages as much as needed.
This commit is contained in:
Blake Mizerany
2025-02-27 12:04:53 -08:00
committed by GitHub
parent be2ac1ed93
commit 2412adf42b
16 changed files with 705 additions and 90 deletions

View File

@@ -19,6 +19,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
@@ -86,9 +87,23 @@ func DefaultCache() (*blob.DiskCache, error) {
return blob.Open(dir)
}
// Error is the standard error returned by Ollama APIs.
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note, that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"`
Status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
@@ -97,13 +112,34 @@ func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
var v struct {
// Single error
Code string
Error string
// Multiple errors
Errors []E
}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if v.Error != "" {
// Single error case
e.Code = v.Code
e.Message = v.Error
return nil
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
@@ -111,9 +147,8 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
n := names.Parse("registry.ollama.ai/library/_:latest")
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
}
@@ -160,21 +195,26 @@ type Registry struct {
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// NameMask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
NameMask string
}
// RegistryFromEnv returns a new Registry configured from the environment. The
// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
func DefaultRegistry() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
if err != nil && errors.Is(err, fs.ErrNotExist) {
return nil, err
}
@@ -208,9 +248,19 @@ type PushParams struct {
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
maskName := defaultName
if mask != "" {
maskName = names.Parse(mask)
if !maskName.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("invalid name mask: %s", mask)
}
}
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if !n.IsValid() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
n = names.Merge(n, maskName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
@@ -223,7 +273,7 @@ func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error)
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
}
scheme = cmp.Or(scheme, "https")
@@ -255,7 +305,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
p = &PushParams{}
}
m, err := ResolveLocal(c, cmp.Or(p.From, name))
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -278,7 +328,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name)
scheme, n, _, err := parseName(name, r.NameMask)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -372,7 +422,7 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
scheme, n, _, err := parseName(name, r.NameMask)
if err != nil {
return err
}
@@ -520,6 +570,16 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
_, n, _, err := parseName(name, r.NameMask)
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
@@ -590,8 +650,8 @@ type Layer struct {
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.NameMask)
if err != nil {
return nil, err
}
@@ -617,7 +677,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
scheme, n, d, err := parseName(name, r.NameMask)
if err != nil {
return nil, err
}

View File

@@ -21,7 +21,7 @@ import (
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
@@ -37,20 +37,6 @@ func TestManifestMarshalJSON(t *testing.T) {
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
@@ -98,29 +84,44 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
}
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
link := func(name string, manifest string) {
_, n, _, err := parseName(name, rc.NameMask)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
link(name, string(data))
}
link(c, "empty", "")
link("empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
link("invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
}
@@ -385,7 +386,7 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
_, err := rc.ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
@@ -396,7 +397,7 @@ func TestRegistryPullNotCached(t *testing.T) {
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
mg, err := rc.ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -654,3 +655,72 @@ func TestCanRetry(t *testing.T) {
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
data string
want *Error
wantErr bool
}{
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors single",
data: `{"errors":[{"code":"blob_unknown"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "errors multiple",
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "error empty",
data: `{"error":""}`,
wantErr: true,
},
{
name: "error very empty",
data: `{}`,
wantErr: true,
},
{
name: "error message",
data: `{"error":"message", "code":"code"}`,
want: &Error{Code: "code", Message: "message"},
},
{
name: "invalid value",
data: `{"error": 1}`,
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var got Error
err := json.Unmarshal([]byte(tt.data), &got)
if err != nil {
if tt.wantErr {
return
}
t.Errorf("Unmarshal() error = %v", err)
// fallthrough and check got
}
if tt.want == nil {
tt.want = &Error{}
}
if !reflect.DeepEqual(got, *tt.want) {
t.Errorf("got = %v; want %v", got, *tt.want)
}
})
}
}