mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
This change adds tracking of download chunks during the pull process so that subsequent pulls can skip downloading already completed chunks. This works across restarts of ollama. Currently, download state will be lost if a prune is triggered during a pull (e.g. restart or remove). This issue should be addressed in a follow-up PR.
1157 lines
30 KiB
Go
1157 lines
30 KiB
Go
// Package ollama provides a client for interacting with an Ollama registry
|
|
// which pushes and pulls model manifests and layers as defined by the
|
|
// [ollama.com/manifest].
|
|
package ollama
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"crypto"
|
|
"crypto/ed25519"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"encoding/base64"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"iter"
|
|
"log/slog"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"runtime"
|
|
"runtime/debug"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
"golang.org/x/sync/errgroup"
|
|
|
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
"github.com/ollama/ollama/server/internal/internal/names"
|
|
|
|
_ "embed"
|
|
)
|
|
|
|
// Errors
var (
	// ErrModelNotFound is returned when a manifest is not found in the
	// cache or registry.
	ErrModelNotFound = errors.New("model not found")

	// ErrManifestInvalid is returned when a manifest found in a local or
	// remote cache is invalid.
	ErrManifestInvalid = errors.New("invalid manifest")

	// ErrNameInvalid is returned when a name is missing or invalid: when it
	// cannot be made fully qualified, or when it has an unsupported scheme
	// or an invalid digest.
	ErrNameInvalid = errors.New("invalid or missing name")

	// ErrCached is passed to [Trace.PushUpdate] when a layer already
	// exists. It is a non-fatal error and is never returned by [Registry.Push].
	// It is also reported for layers and chunks already present in the
	// local cache during [Registry.Pull].
	ErrCached = errors.New("cached")

	// ErrIncomplete is returned by [Registry.Pull] when a model pull was
	// incomplete due to one or more layer download failures. Users that
	// want specific errors should use [WithTrace].
	ErrIncomplete = errors.New("incomplete")
)
|
|
|
|
// Defaults
const (
	// DefaultChunkingThreshold is the layer size (64 << 20, i.e. 64 MiB)
	// at or above which a layer is split up into chunks when downloading.
	DefaultChunkingThreshold = 64 * 1024 * 1024
)
|
|
|
|
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
|
|
dir := os.Getenv("OLLAMA_MODELS")
|
|
if dir == "" {
|
|
home, _ := os.UserHomeDir()
|
|
home = cmp.Or(home, ".")
|
|
dir = filepath.Join(home, ".ollama", "models")
|
|
}
|
|
return blob.Open(dir)
|
|
})
|
|
|
|
// DefaultCache returns the default cache used by the registry. It is
|
|
// configured from the OLLAMA_MODELS environment variable, or defaults to
|
|
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
|
|
// it uses the current working directory.
|
|
func DefaultCache() (*blob.DiskCache, error) {
|
|
return defaultCache()
|
|
}
|
|
|
|
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
//	{"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
//	{"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
	Status  int    `json:"-"` // TODO(bmizerany): remove this
	Code    string `json:"code"`
	Message string `json:"message"`
}

// Error implements the error interface. The text always starts with the
// status, followed by the code and the message when each is non-empty, all
// separated by ": ".
func (e *Error) Error() string {
	parts := []string{"registry responded with status " + strconv.Itoa(e.Status)}
	if e.Code != "" {
		parts = append(parts, "code "+e.Code)
	}
	if e.Message != "" {
		parts = append(parts, e.Message)
	}
	return strings.Join(parts, ": ")
}

// LogValue implements slog.LogValuer, logging the status, code, and message
// as a group.
func (e *Error) LogValue() slog.Value {
	attrs := []slog.Attr{
		slog.Int("status", e.Status),
		slog.String("code", e.Code),
		slog.String("message", e.Message),
	}
	return slog.GroupValue(attrs...)
}

// UnmarshalJSON implements json.Unmarshaler. It accepts both the single and
// the multiple error formats. For the multiple format, only the first error
// is kept, and Status is reset. It reports an error when neither format
// carries a message.
func (e *Error) UnmarshalJSON(b []byte) error {
	// plain has Error's fields but none of its methods, so decoding into it
	// does not recurse back into this UnmarshalJSON.
	type plain Error
	var wire struct {
		Code   string  // single-error form
		Error  string  // single-error form
		Errors []plain // multiple-error form
	}
	if err := json.Unmarshal(b, &wire); err != nil {
		return err
	}
	switch {
	case wire.Error != "":
		e.Code, e.Message = wire.Code, wire.Error
		return nil
	case len(wire.Errors) == 0:
		return fmt.Errorf("no messages in error response: %s", string(b))
	default:
		*e = Error(wire.Errors[0]) // our registry only returns one error.
		return nil
	}
}
|
|
|
|
const DefaultMask = "registry.ollama.ai/library/_:latest"
|
|
|
|
var defaultMask = func() names.Name {
|
|
n := names.Parse(DefaultMask)
|
|
if !n.IsFullyQualified() {
|
|
panic("default mask is not fully qualified")
|
|
}
|
|
return n
|
|
}()
|
|
|
|
// CompleteName returns a fully qualified name by merging the given name with
|
|
// the default mask. If the name is already fully qualified, it is returned
|
|
// unchanged.
|
|
func CompleteName(name string) string {
|
|
return names.Merge(names.Parse(name), defaultMask).String()
|
|
}
|
|
|
|
// Registry is a client for performing push and pull operations against an
|
|
// Ollama registry.
|
|
type Registry struct {
|
|
// Cache is the cache used to store models. If nil, [DefaultCache] is
|
|
// used.
|
|
Cache *blob.DiskCache
|
|
|
|
// UserAgent is the User-Agent header to send with requests to the
|
|
// registry. If empty, the User-Agent is determined by HTTPClient.
|
|
UserAgent string
|
|
|
|
// Key is the key used to authenticate with the registry.
|
|
//
|
|
// Currently, only Ed25519 keys are supported.
|
|
Key crypto.PrivateKey
|
|
|
|
// HTTPClient is the HTTP client used to make requests to the registry.
|
|
//
|
|
// If nil, [http.DefaultClient] is used.
|
|
//
|
|
// As a quick note: If a Registry function that makes a call to a URL
|
|
// with the "https+insecure" scheme, the client will be cloned and the
|
|
// transport will be set to skip TLS verification, unless the client's
|
|
// Transport done not have a Clone method with the same signature as
|
|
// [http.Transport.Clone], which case, the call will fail.
|
|
HTTPClient *http.Client
|
|
|
|
// MaxStreams is the maximum number of concurrent streams to use when
|
|
// pushing or pulling models. If zero, the number of streams is
|
|
// determined by [runtime.GOMAXPROCS].
|
|
//
|
|
// A negative value means no limit.
|
|
MaxStreams int
|
|
|
|
// ChunkingThreshold is the maximum size of a layer to download in a single
|
|
// request. If zero, [DefaultChunkingThreshold] is used.
|
|
ChunkingThreshold int64
|
|
|
|
// Mask, if set, is the name used to convert non-fully qualified names
|
|
// to fully qualified names. If empty, [DefaultMask] is used.
|
|
Mask string
|
|
}
|
|
|
|
func (r *Registry) cache() (*blob.DiskCache, error) {
|
|
if r.Cache != nil {
|
|
return r.Cache, nil
|
|
}
|
|
return defaultCache()
|
|
}
|
|
|
|
func (r *Registry) parseName(name string) (names.Name, error) {
|
|
mask := defaultMask
|
|
if r.Mask != "" {
|
|
mask = names.Parse(r.Mask)
|
|
}
|
|
n := names.Merge(names.Parse(name), mask)
|
|
if !n.IsFullyQualified() {
|
|
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// DefaultRegistry returns a new Registry configured from the environment. The
|
|
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
|
|
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
|
|
// system's temporary directory.
|
|
//
|
|
// It returns an error if any configuration in the environment is invalid.
|
|
func DefaultRegistry() (*Registry, error) {
|
|
home, err := os.UserHomeDir()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
|
|
if err != nil && errors.Is(err, fs.ErrNotExist) {
|
|
return nil, err
|
|
}
|
|
|
|
var rc Registry
|
|
rc.UserAgent = UserAgent()
|
|
rc.Key, err = ssh.ParseRawPrivateKey(keyPEM)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
maxStreams := os.Getenv("OLLAMA_REGISTRY_MAXSTREAMS")
|
|
if maxStreams != "" {
|
|
var err error
|
|
rc.MaxStreams, err = strconv.Atoi(maxStreams)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid OLLAMA_REGISTRY_MAXSTREAMS: %w", err)
|
|
}
|
|
}
|
|
return &rc, nil
|
|
}
|
|
|
|
// UserAgent returns the User-Agent string sent with registry requests, in
// the form "ollama/<version> (<arch> <os>) Go/<go version>".
//
// The version comes from the main module's build information. If build
// information is unavailable, or the version is empty or "(devel)", the
// special version "v0.0.0" is used instead.
func UserAgent() string {
	// When using `go run .` the version is "(devel)". This is seen
	// as an invalid version by ollama.com and so it defaults to
	// "needs upgrade" for some requests, such as pulls. These
	// checks can be skipped by using the special version "v0.0.0",
	// so that is used for any missing or unusable version.
	version := "v0.0.0"
	// debug.ReadBuildInfo returns a nil *BuildInfo when ok is false, so
	// it must be checked before dereferencing.
	if info, ok := debug.ReadBuildInfo(); ok && info != nil {
		if v := info.Main.Version; v != "" && v != "(devel)" {
			version = v
		}
	}

	return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
		version,
		runtime.GOARCH,
		runtime.GOOS,
		runtime.Version(),
	)
}
|
|
|
|
func (r *Registry) maxStreams() int {
|
|
return cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
|
}
|
|
|
|
func (r *Registry) maxChunkingThreshold() int64 {
|
|
return cmp.Or(r.ChunkingThreshold, DefaultChunkingThreshold)
|
|
}
|
|
|
|
// PushParams holds optional parameters for [Registry.Push].
type PushParams struct {
	// From is an optional name of the local model to push. If empty, the
	// local model is looked up by the same name it is being pushed as
	// (the name argument to [Registry.Push]).
	From string
}
|
|
|
|
// Push pushes the model with the name in the cache to the remote registry.
|
|
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
|
|
if p == nil {
|
|
p = &PushParams{}
|
|
}
|
|
|
|
c, err := r.cache()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
m, err := r.ResolveLocal(cmp.Or(p.From, name))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Before much else happens, check layers at not null, the blobs exist,
|
|
// and the sizes match. This prevents long uploads followed by
|
|
// disappointment.
|
|
for _, l := range m.Layers {
|
|
if l == nil {
|
|
return fmt.Errorf("%w: null layer", ErrManifestInvalid)
|
|
}
|
|
info, err := c.Get(l.Digest)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting %s: %w", l.Digest.Short(), err)
|
|
}
|
|
if info.Size != l.Size {
|
|
return fmt.Errorf("size mismatch for %s: %d != %d", l.Digest.Short(), info.Size, l.Size)
|
|
}
|
|
}
|
|
|
|
t := traceFromContext(ctx)
|
|
|
|
scheme, n, _, err := r.parseNameExtended(name)
|
|
if err != nil {
|
|
// This should never happen since ResolveLocal should have
|
|
// already validated the name.
|
|
panic(err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var g errgroup.Group
|
|
g.SetLimit(r.maxStreams())
|
|
for _, l := range m.Layers {
|
|
var progress atomic.Int64
|
|
g.Go(func() (err error) {
|
|
defer func() { t.update(l, progress.Load(), err) }()
|
|
|
|
t.update(l, 0, nil)
|
|
|
|
startURL := fmt.Sprintf("%s://%s/v2/%s/%s/blobs/uploads/?digest=%s",
|
|
scheme,
|
|
n.Host(),
|
|
n.Namespace(),
|
|
n.Model(),
|
|
l.Digest,
|
|
)
|
|
res, err := r.send(ctx, "POST", startURL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
res.Body.Close()
|
|
|
|
f, err := os.Open(c.GetFile(l.Digest))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
uploadURL := res.Header.Get("Location")
|
|
if uploadURL == "" {
|
|
t.update(l, l.Size, ErrCached)
|
|
return nil
|
|
}
|
|
|
|
req, err := r.newRequest(ctx, "PUT", uploadURL, f)
|
|
if err != nil {
|
|
return fmt.Errorf("invalid upload URL returned from registry: %q: %w", uploadURL, err)
|
|
}
|
|
req.ContentLength = l.Size
|
|
|
|
res, err = sendRequest(r.client(), req)
|
|
if err == nil {
|
|
res.Body.Close()
|
|
}
|
|
return err
|
|
})
|
|
}
|
|
|
|
if err := g.Wait(); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Commit
|
|
path := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s",
|
|
scheme,
|
|
n.Host(),
|
|
n.Namespace(),
|
|
n.Model(),
|
|
n.Tag(),
|
|
)
|
|
res, err := r.send(ctx, "PUT", path, bytes.NewReader(m.Data))
|
|
if err == nil {
|
|
res.Body.Close()
|
|
}
|
|
// TODO(bmizerany): add a "commit" trace event
|
|
return err
|
|
}
|
|
|
|
// trackingReader is an io.Reader that tracks the number of bytes read and
|
|
// calls the update function with the layer, the number of bytes read.
|
|
//
|
|
// It always calls update with a nil error.
|
|
type trackingReader struct {
|
|
l *Layer
|
|
r io.Reader
|
|
update func(l *Layer, n int64, err error)
|
|
}
|
|
|
|
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
|
n, err = r.r.Read(p)
|
|
r.update(r.l, int64(n), nil)
|
|
return
|
|
}
|
|
|
|
// Pull pulls the model with the given name from the remote registry into the
|
|
// cache.
|
|
//
|
|
// For layers larger then [Registry.MaxChunkSize], the layer is downloaded in
|
|
// chunks of the specified size, and then reassembled and verified. This is
|
|
// typically slower than splitting the model up across layers, and is mostly
|
|
// utilized for layers of type equal to "application/vnd.ollama.image".
|
|
func (r *Registry) Pull(ctx context.Context, name string) error {
|
|
m, err := r.Resolve(ctx, name)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO(bmizerany): decide if this should be considered valid. Maybe
|
|
// server-side we special case '{}' to have some special meaning? Maybe
|
|
// "archiving" a tag (which is how we reason about it in the registry
|
|
// already, just with a different twist).
|
|
if len(m.Layers) == 0 {
|
|
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
|
|
}
|
|
|
|
c, err := r.cache()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// TODO(bmizerany): work to remove the need to do this
|
|
layers := m.Layers
|
|
if m.Config != nil && m.Config.Digest.IsValid() {
|
|
layers = append(layers, m.Config)
|
|
}
|
|
|
|
// Send initial layer trace events to allow clients to have an
|
|
// understanding of work to be done before work starts.
|
|
var expected int64
|
|
t := traceFromContext(ctx)
|
|
for _, l := range layers {
|
|
t.update(l, 0, nil)
|
|
expected += l.Size
|
|
}
|
|
|
|
var received atomic.Int64
|
|
var g errgroup.Group
|
|
g.SetLimit(r.maxStreams())
|
|
for _, l := range layers {
|
|
info, err := c.Get(l.Digest)
|
|
if err == nil && info.Size == l.Size {
|
|
received.Add(l.Size)
|
|
t.update(l, l.Size, ErrCached)
|
|
continue
|
|
}
|
|
|
|
var wg sync.WaitGroup
|
|
chunked, err := c.Chunked(l.Digest, l.Size)
|
|
if err != nil {
|
|
t.update(l, 0, err)
|
|
continue
|
|
}
|
|
|
|
for cs, err := range r.chunksums(ctx, name, l) {
|
|
if err != nil {
|
|
// Chunksum stream interrupted. Note in trace
|
|
// log and let in-flight downloads complete.
|
|
// This will naturally trigger ErrIncomplete
|
|
// since received < expected bytes.
|
|
t.update(l, 0, err)
|
|
break
|
|
}
|
|
|
|
cacheKey := fmt.Sprintf(
|
|
"v1 pull chunksum %s %s %d-%d",
|
|
l.Digest,
|
|
cs.Digest,
|
|
cs.Chunk.Start,
|
|
cs.Chunk.End,
|
|
)
|
|
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
|
|
_, err := c.Get(cacheKeyDigest)
|
|
if err == nil {
|
|
received.Add(cs.Chunk.Size())
|
|
t.update(l, cs.Chunk.Size(), ErrCached)
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
g.Go(func() (err error) {
|
|
defer func() {
|
|
if err == nil {
|
|
// Ignore cache key write errors for now. We've already
|
|
// reported to trace that the chunk is complete.
|
|
//
|
|
// Ideally, we should only report completion to trace
|
|
// after successful cache commit. This current approach
|
|
// works but could trigger unnecessary redownloads if
|
|
// the checkpoint key is missing on next pull.
|
|
//
|
|
// Not incorrect, just suboptimal - fix this in a
|
|
// future update.
|
|
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
|
|
|
|
received.Add(cs.Chunk.Size())
|
|
} else {
|
|
t.update(l, 0, err)
|
|
}
|
|
wg.Done()
|
|
}()
|
|
|
|
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
|
|
res, err := sendRequest(r.client(), req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
|
return chunked.Put(cs.Chunk, cs.Digest, body)
|
|
})
|
|
}
|
|
|
|
// Close writer immediately after downloads finish, not at Pull
|
|
// exit. Using defer would keep file descriptors open until all
|
|
// layers complete, potentially exhausting system limits with
|
|
// many layers.
|
|
//
|
|
// The WaitGroup tracks when all chunks finish downloading,
|
|
// allowing precise writer closure in a background goroutine.
|
|
// Each layer briefly uses one extra goroutine while at most
|
|
// maxStreams()-1 chunks download in parallel.
|
|
//
|
|
// This caps file descriptors at maxStreams() instead of
|
|
// growing with layer count.
|
|
g.Go(func() error {
|
|
wg.Wait()
|
|
chunked.Close()
|
|
return nil
|
|
})
|
|
}
|
|
if err := g.Wait(); err != nil {
|
|
return err
|
|
}
|
|
if received.Load() != expected {
|
|
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
|
|
}
|
|
|
|
md := blob.DigestFromBytes(m.Data)
|
|
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
|
return err
|
|
}
|
|
return c.Link(m.Name, md)
|
|
}
|
|
|
|
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
|
// before attempting to unlink the model.
|
|
func (r *Registry) Unlink(name string) (ok bool, _ error) {
|
|
n, err := r.parseName(name)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
c, err := r.cache()
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return c.Unlink(n.String())
|
|
}
|
|
|
|
// Manifest represents a [ollama.com/manifest].
|
|
type Manifest struct {
|
|
Name string `json:"-"` // the canonical name of the model
|
|
Data []byte `json:"-"` // the raw data of the manifest
|
|
Layers []*Layer `json:"layers"`
|
|
|
|
// For legacy reasons, we still have to download the config layer.
|
|
Config *Layer `json:"config"`
|
|
}
|
|
|
|
// Layer returns the layer with the given
|
|
// digest, or nil if not found.
|
|
func (m *Manifest) Layer(d blob.Digest) *Layer {
|
|
for _, l := range m.Layers {
|
|
if l.Digest == d {
|
|
return l
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *Manifest) All() iter.Seq[*Layer] {
|
|
return func(yield func(*Layer) bool) {
|
|
if !yield(m.Config) {
|
|
return
|
|
}
|
|
for _, l := range m.Layers {
|
|
if !yield(l) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manifest) Size() int64 {
|
|
var size int64
|
|
if m.Config != nil {
|
|
size += m.Config.Size
|
|
}
|
|
for _, l := range m.Layers {
|
|
size += l.Size
|
|
}
|
|
return size
|
|
}
|
|
|
|
// MarshalJSON implements json.Marshaler.
|
|
//
|
|
// NOTE: It adds an empty config object to the manifest, which is required by
|
|
// the registry, but not used by the client. In the future, the config object
|
|
// will not be required by the registry and this will should be removed.
|
|
func (m Manifest) MarshalJSON() ([]byte, error) {
|
|
type M Manifest
|
|
v := struct {
|
|
M
|
|
|
|
// This is ignored, mostly, by the registry But, if not
|
|
// present, it will cause an error to be returned during the
|
|
// last phase of the commit which expects it, but does nothing
|
|
// with it. This will be fixed in a future release of
|
|
// ollama.com.
|
|
Config Layer `json:"config"`
|
|
}{
|
|
M: M(m),
|
|
}
|
|
return json.Marshal(v)
|
|
}
|
|
|
|
// unmarshalManifest unmarshals the data into a manifest, and sets the name
|
|
// field to the string representation of the name.
|
|
//
|
|
// It panics if the name is not fully qualified. Callers should ensure the name
|
|
// is fully qualified before calling this function.
|
|
func unmarshalManifest(n names.Name, data []byte) (*Manifest, error) {
|
|
if !n.IsFullyQualified() {
|
|
panic(fmt.Sprintf("unmarshalManifest: name is not fully qualified: %s", n.String()))
|
|
}
|
|
var m Manifest
|
|
if err := json.Unmarshal(data, &m); err != nil {
|
|
return nil, err
|
|
}
|
|
m.Name = n.String()
|
|
m.Data = data
|
|
return &m, nil
|
|
}
|
|
|
|
// Layer is a layer in a model, as listed in a manifest.
//
// The JSON tags are the keys used when manifests are decoded and encoded.
// Field order also fixes the key order in encoded output.
type Layer struct {
	Digest    blob.Digest `json:"digest"`
	MediaType string      `json:"mediaType"`
	Size      int64       `json:"size"` // presumably in bytes; it is compared with cached blob sizes
}
|
|
|
|
// ResolveLocal resolves a name to a Manifest in the local cache.
|
|
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
|
|
_, n, d, err := r.parseNameExtended(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c, err := r.cache()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !d.IsValid() {
|
|
// No digest, so resolve the manifest by name.
|
|
d, err = c.Resolve(n.String())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
data, err := os.ReadFile(c.GetFile(d))
|
|
if err != nil {
|
|
if errors.Is(err, fs.ErrNotExist) {
|
|
return nil, fmt.Errorf("%w: %s", ErrModelNotFound, name)
|
|
}
|
|
return nil, err
|
|
}
|
|
m, err := unmarshalManifest(n, data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// Resolve resolves a name to a Manifest in the remote registry.
|
|
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
|
scheme, n, d, err := r.parseNameExtended(name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
manifestURL := fmt.Sprintf("%s://%s/v2/%s/%s/manifests/%s", scheme, n.Host(), n.Namespace(), n.Model(), n.Tag())
|
|
if d.IsValid() {
|
|
manifestURL = fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s", scheme, n.Host(), n.Namespace(), n.Model(), d)
|
|
}
|
|
|
|
res, err := r.send(ctx, "GET", manifestURL, nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Body.Close()
|
|
data, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// TODO(bmizerany): return digest here
|
|
m, err := unmarshalManifest(n, data)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: %w", name, errors.Join(ErrManifestInvalid, err))
|
|
}
|
|
return m, nil
|
|
}
|
|
|
|
// chunksum describes one independently downloadable and verifiable piece of
// a layer, as yielded by the Registry's chunksums method.
type chunksum struct {
	URL    string      // URL to GET the chunk from, using a Range request for Chunk
	Chunk  blob.Chunk  // byte range of the chunk within the layer; Start and End are both inclusive
	Digest blob.Digest // digest the downloaded chunk bytes are expected to match
}
|
|
|
|
// chunksums returns a sequence of chunksums for the given layer. If the layer is under the
|
|
// chunking threshold, a single chunksum is returned that covers the entire layer. If the layer
|
|
// is over the chunking threshold, the chunksums are read from the chunksums endpoint.
|
|
func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Seq2[chunksum, error] {
|
|
return func(yield func(chunksum, error) bool) {
|
|
scheme, n, _, err := r.parseNameExtended(name)
|
|
if err != nil {
|
|
yield(chunksum{}, err)
|
|
return
|
|
}
|
|
|
|
if l.Size < r.maxChunkingThreshold() {
|
|
// any layer under the threshold should be downloaded
|
|
// in one go.
|
|
cs := chunksum{
|
|
URL: fmt.Sprintf("%s://%s/v2/%s/%s/blobs/%s",
|
|
scheme,
|
|
n.Host(),
|
|
n.Namespace(),
|
|
n.Model(),
|
|
l.Digest,
|
|
),
|
|
Chunk: blob.Chunk{Start: 0, End: l.Size - 1},
|
|
Digest: l.Digest,
|
|
}
|
|
yield(cs, nil)
|
|
return
|
|
}
|
|
|
|
// The response is a sequence of chunksums.
|
|
//
|
|
// Chunksums are chunks of a larger blob that can be
|
|
// downloaded and verified independently.
|
|
//
|
|
// The chunksums endpoint is a GET request that returns a
|
|
// sequence of chunksums in the following format:
|
|
//
|
|
// > GET /v2/<namespace>/<model>/chunksums/<digest>
|
|
//
|
|
// < HTTP/1.1 200 OK
|
|
// < Content-Location: <blobURL>
|
|
// <
|
|
// < <digest> <start>-<end>
|
|
// < ...
|
|
//
|
|
// The <blobURL> is the URL to download the chunks from and
|
|
// each <digest> is the digest of the chunk, and <start>-<end>
|
|
// is the range the chunk in the blob.
|
|
//
|
|
// Ranges may be used directly in Range headers like
|
|
// "bytes=<start>-<end>".
|
|
//
|
|
// The chunksums returned are guaranteed to be contiguous and
|
|
// include all bytes of the layer. If the stream is cut short,
|
|
// clients should retry.
|
|
|
|
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
|
|
scheme,
|
|
n.Host(),
|
|
n.Namespace(),
|
|
n.Model(),
|
|
l.Digest,
|
|
)
|
|
|
|
req, err := r.newRequest(ctx, "GET", chunksumsURL, nil)
|
|
if err != nil {
|
|
yield(chunksum{}, err)
|
|
return
|
|
}
|
|
res, err := sendRequest(r.client(), req)
|
|
if err != nil {
|
|
yield(chunksum{}, err)
|
|
return
|
|
}
|
|
defer res.Body.Close()
|
|
if res.StatusCode != 200 {
|
|
err := fmt.Errorf("chunksums: unexpected status code %d", res.StatusCode)
|
|
yield(chunksum{}, err)
|
|
return
|
|
}
|
|
blobURL := res.Header.Get("Content-Location")
|
|
|
|
s := bufio.NewScanner(res.Body)
|
|
s.Split(bufio.ScanWords)
|
|
for {
|
|
if !s.Scan() {
|
|
if s.Err() != nil {
|
|
yield(chunksum{}, s.Err())
|
|
}
|
|
return
|
|
}
|
|
d, err := blob.ParseDigest(s.Bytes())
|
|
if err != nil {
|
|
yield(chunksum{}, fmt.Errorf("invalid digest: %q", s.Bytes()))
|
|
return
|
|
}
|
|
|
|
if !s.Scan() {
|
|
err := s.Err()
|
|
if err == nil {
|
|
err = fmt.Errorf("missing chunk range for digest %s", d)
|
|
}
|
|
yield(chunksum{}, err)
|
|
return
|
|
}
|
|
chunk, err := parseChunk(s.Bytes())
|
|
if err != nil {
|
|
yield(chunksum{}, fmt.Errorf("invalid chunk range for digest %s: %q", d, s.Bytes()))
|
|
return
|
|
}
|
|
|
|
cs := chunksum{
|
|
URL: blobURL,
|
|
Chunk: chunk,
|
|
Digest: d,
|
|
}
|
|
if !yield(cs, nil) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *Registry) client() *http.Client {
|
|
if r.HTTPClient != nil {
|
|
return r.HTTPClient
|
|
}
|
|
return http.DefaultClient
|
|
}
|
|
|
|
// newRequest constructs a new request, ready to use, with the given method,
|
|
// url, and body, pre-signed with client [Key] and [UserAgent].
|
|
func (r *Registry) newRequest(ctx context.Context, method, url string, body io.Reader) (*http.Request, error) {
|
|
req, err := http.NewRequestWithContext(ctx, method, url, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if r.UserAgent != "" {
|
|
req.Header.Set("User-Agent", r.UserAgent)
|
|
}
|
|
if r.Key != nil {
|
|
token, err := makeAuthToken(r.Key)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
// sendRequest makes a request with the given client and request, and returns the
|
|
// response if the status code is 200. If the status code is not 200, an Error
|
|
// is parsed from the response body and returned. If any other error occurs, it
|
|
// is returned.
|
|
func sendRequest(c *http.Client, r *http.Request) (_ *http.Response, err error) {
|
|
defer func() {
|
|
if err != nil {
|
|
err = fmt.Errorf("request error %s: %w", r.URL, err)
|
|
}
|
|
}()
|
|
|
|
if r.URL.Scheme == "https+insecure" {
|
|
// TODO(bmizerany): clone client.Transport, set
|
|
// InsecureSkipVerify, etc.
|
|
|
|
type cloner interface {
|
|
Clone() *http.Transport
|
|
}
|
|
|
|
// Attempt to configure the transport to skip TLS verification
|
|
// if we can clone it, otherwise fall through and let the http
|
|
// client complain and the scheme being invalid.
|
|
x, ok := cmp.Or(c.Transport, http.DefaultTransport).(cloner)
|
|
if ok {
|
|
tr := x.Clone()
|
|
tr.TLSClientConfig = cmp.Or(tr.TLSClientConfig, &tls.Config{})
|
|
tr.TLSClientConfig.InsecureSkipVerify = true
|
|
|
|
cc := *c // shallow copy
|
|
cc.Transport = tr
|
|
c = &cc
|
|
|
|
r = r.Clone(r.Context())
|
|
r.URL.Scheme = "https"
|
|
|
|
// fall through
|
|
}
|
|
}
|
|
|
|
res, err := c.Do(r)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if res.StatusCode/100 != 2 {
|
|
out, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var re Error
|
|
if err := json.Unmarshal(out, &re); err != nil {
|
|
// Use the raw body if we can't parse it as an error object.
|
|
re.Message = string(out)
|
|
}
|
|
|
|
// coerce MANIFEST_UNKNOWN to ErrManifestNotFound
|
|
if strings.EqualFold(re.Code, "MANIFEST_UNKNOWN") {
|
|
return nil, ErrModelNotFound
|
|
}
|
|
|
|
re.Status = res.StatusCode
|
|
return nil, &re
|
|
}
|
|
return res, nil
|
|
}
|
|
|
|
// send is a convenience method for making a request with newRequest and
|
|
// passing it to send with r.client().
|
|
func (r *Registry) send(ctx context.Context, method, path string, body io.Reader) (*http.Response, error) {
|
|
req, err := r.newRequest(ctx, method, path, body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return sendRequest(r.client(), req)
|
|
}
|
|
|
|
// makeAuthToken creates an Ollama auth token for the given private key.
|
|
//
|
|
// NOTE: This format is OLD, overly complex, and should be replaced. We're
|
|
// inheriting it from the original Ollama client and ollama.com
|
|
// implementations, so we need to support it for now.
|
|
func makeAuthToken(key crypto.PrivateKey) (string, error) {
|
|
privKey, _ := key.(*ed25519.PrivateKey)
|
|
if privKey == nil {
|
|
return "", fmt.Errorf("unsupported private key type: %T", key)
|
|
}
|
|
|
|
url := fmt.Sprintf("https://ollama.com?ts=%d", time.Now().Unix())
|
|
// Part 1: the checkData (e.g. the URL with a timestamp)
|
|
|
|
// Part 2: the public key
|
|
pubKeyShort, err := func() ([]byte, error) {
|
|
sshPubKey, err := ssh.NewPublicKey(privKey.Public())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
pubKeyParts := bytes.Fields(ssh.MarshalAuthorizedKey(sshPubKey))
|
|
if len(pubKeyParts) < 2 {
|
|
return nil, fmt.Errorf("malformed public key: %q", pubKeyParts)
|
|
}
|
|
pubKeyShort := pubKeyParts[1]
|
|
return pubKeyShort, nil
|
|
}()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Part 3: the signature
|
|
sig := ed25519.Sign(*privKey, []byte(checkData(url)))
|
|
|
|
// Assemble the token: <checkData>:<pubKey>:<signature>
|
|
var b strings.Builder
|
|
io.WriteString(&b, base64.StdEncoding.EncodeToString([]byte(url)))
|
|
b.WriteByte(':')
|
|
b.Write(pubKeyShort)
|
|
b.WriteByte(':')
|
|
io.WriteString(&b, base64.StdEncoding.EncodeToString(sig))
|
|
|
|
return b.String(), nil
|
|
}
|
|
|
|
// zeroSum is the base64 encoding of the lowercase hex SHA-256 digest of the
// empty input. The original spec for Ollama tokens included it in the
// signed data. The reason is unknown, but verification still requires it.
var zeroSum = func() string {
	emptyDigest := sha256.Sum256(nil)
	hexDigest := hex.EncodeToString(emptyDigest[:])
	return base64.StdEncoding.EncodeToString([]byte(hexDigest))
}()

// checkData returns the string the ollama client signs for url, in the
// original format "GET,<url>,<zeroSum>".
func checkData(url string) string {
	return "GET," + url + "," + zeroSum
}
|
|
|
|
// publicError pairs an underlying error with a display-friendly message.
// Error returns only the message. Unwrap exposes the underlying error so
// that errors.Is and errors.As still find it.
type publicError struct {
	wrapped error
	message string
}

// withPublicMessagef returns an error whose text is the formatted message
// and which unwraps to err.
func withPublicMessagef(err error, message string, args ...any) error {
	pe := publicError{wrapped: err}
	pe.message = fmt.Sprintf(message, args...)
	return pe
}

func (e publicError) Error() string {
	return e.message
}

func (e publicError) Unwrap() error {
	return e.wrapped
}
|
|
|
|
// supportedSchemes lists the URL schemes accepted in extended names.
var supportedSchemes = []string{"http", "https", "https+insecure"}

// supportedSchemesMessage is a user-facing hint listing supportedSchemes.
var supportedSchemesMessage = "supported schemes are " + strings.Join(supportedSchemes, ", ")
|
|
|
|
// parseNameExtended parses and validates an extended name, returning the scheme, name,
|
|
// and digest.
|
|
//
|
|
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
|
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
|
|
//
|
|
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
|
|
// message is returned.
|
|
//
|
|
// If the name is not, once merged with the mask, fully qualified,
|
|
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
|
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
|
scheme, name, digest := splitExtended(s)
|
|
scheme = cmp.Or(scheme, "https")
|
|
if !slices.Contains(supportedSchemes, scheme) {
|
|
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
|
return "", names.Name{}, blob.Digest{}, err
|
|
}
|
|
|
|
var d blob.Digest
|
|
if digest != "" {
|
|
var err error
|
|
d, err = blob.ParseDigest(digest)
|
|
if err != nil {
|
|
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
|
|
return "", names.Name{}, blob.Digest{}, err
|
|
}
|
|
if name == "" {
|
|
// We have can resolve a manifest from a digest only,
|
|
// so skip name validation and return the scheme and
|
|
// digest.
|
|
return scheme, names.Name{}, d, nil
|
|
}
|
|
}
|
|
|
|
n, err := r.parseName(name)
|
|
if err != nil {
|
|
return "", names.Name{}, blob.Digest{}, err
|
|
}
|
|
return scheme, n, d, nil
|
|
}
|
|
|
|
// splitExtended splits an extended name string into its scheme, name, and
// digest parts. The scheme is everything before the first "://". The digest
// is everything after the last "@" in the remainder. Missing parts are
// returned as "".
//
// Examples:
//
//	http://ollama.com/bmizerany/smol:latest@digest
//	https://ollama.com/bmizerany/smol:latest
//	ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
//	model@digest
//	@digest
//
// (The "https" default for the third example is applied by the caller;
// splitExtended itself returns an empty scheme.)
func splitExtended(s string) (scheme, name, digest string) {
	if before, after, found := strings.Cut(s, "://"); found {
		scheme, s = before, after
	}
	if at := strings.LastIndexByte(s, '@'); at >= 0 {
		s, digest = s[:at], s[at+1:]
	}
	name = s
	return scheme, name, digest
}
|
|
|
|
// parseChunk parses a string in the form "start-end" and returns the Chunk.
|
|
func parseChunk[S ~string | ~[]byte](s S) (blob.Chunk, error) {
|
|
startPart, endPart, found := strings.Cut(string(s), "-")
|
|
if !found {
|
|
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: missing '-'", s)
|
|
}
|
|
start, err := strconv.ParseInt(startPart, 10, 64)
|
|
if err != nil {
|
|
return blob.Chunk{}, fmt.Errorf("chunks: invalid start to %q: %v", s, err)
|
|
}
|
|
end, err := strconv.ParseInt(endPart, 10, 64)
|
|
if err != nil {
|
|
return blob.Chunk{}, fmt.Errorf("chunks: invalid end to %q: %v", s, err)
|
|
}
|
|
if start > end {
|
|
return blob.Chunk{}, fmt.Errorf("chunks: invalid range %q: start > end", s)
|
|
}
|
|
return blob.Chunk{Start: start, End: end}, nil
|
|
}
|