mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-11 00:07:07 +00:00
server/internal/client/ollama: persist through chunk download errors (#9923)
This commit is contained in:
@@ -59,6 +59,11 @@ var (
|
||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||
ErrCached = errors.New("cached")
|
||||
|
||||
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
|
||||
// incomplete due to one or more layer download failures. Users that
|
||||
// want specific errors should use [WithTrace].
|
||||
ErrIncomplete = errors.New("incomplete")
|
||||
)
|
||||
|
||||
// Defaults
|
||||
@@ -271,8 +276,19 @@ func DefaultRegistry() (*Registry, error) {
|
||||
|
||||
func UserAgent() string {
|
||||
buildinfo, _ := debug.ReadBuildInfo()
|
||||
|
||||
version := buildinfo.Main.Version
|
||||
if version == "(devel)" {
|
||||
// When using `go run .` the version is "(devel)". This is seen
|
||||
// as an invalid version by ollama.com and so it defaults to
|
||||
// "needs upgrade" for some requests, such as pulls. These
|
||||
// checks can be skipped by using the special version "v0.0.0",
|
||||
// so we set it to that here.
|
||||
version = "v0.0.0"
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
|
||||
buildinfo.Main.Version,
|
||||
version,
|
||||
runtime.GOARCH,
|
||||
runtime.GOOS,
|
||||
runtime.Version(),
|
||||
@@ -418,13 +434,14 @@ func canRetry(err error) bool {
|
||||
//
|
||||
// It always calls update with a nil error.
|
||||
type trackingReader struct {
|
||||
r io.Reader
|
||||
n *atomic.Int64
|
||||
l *Layer
|
||||
r io.Reader
|
||||
update func(l *Layer, n int64, err error)
|
||||
}
|
||||
|
||||
func (r *trackingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.n.Add(int64(n))
|
||||
r.update(r.l, int64(n), nil)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -462,16 +479,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
|
||||
// Send initial layer trace events to allow clients to have an
|
||||
// understanding of work to be done before work starts.
|
||||
var expected int64
|
||||
t := traceFromContext(ctx)
|
||||
for _, l := range layers {
|
||||
t.update(l, 0, nil)
|
||||
expected += l.Size
|
||||
}
|
||||
|
||||
var total atomic.Int64
|
||||
var g errgroup.Group
|
||||
g.SetLimit(r.maxStreams())
|
||||
for _, l := range layers {
|
||||
info, err := c.Get(l.Digest)
|
||||
if err == nil && info.Size == l.Size {
|
||||
total.Add(l.Size)
|
||||
t.update(l, l.Size, ErrCached)
|
||||
continue
|
||||
}
|
||||
@@ -484,21 +505,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
// TODO(bmizerany): fix this unbounded use of defer
|
||||
defer chunked.Close()
|
||||
|
||||
var progress atomic.Int64
|
||||
for cs, err := range r.chunksums(ctx, name, l) {
|
||||
if err != nil {
|
||||
// Bad chunksums response, update tracing
|
||||
// clients and then bail.
|
||||
t.update(l, progress.Load(), err)
|
||||
return err
|
||||
// Chunksum stream was interrupted, so tell
|
||||
// trace about it, and let in-flight chunk
|
||||
// downloads finish. Once they finish, return
|
||||
// ErrIncomplete, which is triggered by the
|
||||
// fact that the total bytes received is less
|
||||
// than the expected bytes.
|
||||
t.update(l, 0, err)
|
||||
break
|
||||
}
|
||||
|
||||
g.Go(func() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
if err == nil || errors.Is(err, ErrCached) {
|
||||
total.Add(cs.Chunk.Size())
|
||||
} else {
|
||||
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
|
||||
}
|
||||
t.update(l, progress.Load(), err)
|
||||
}()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
|
||||
@@ -522,7 +547,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
// download rate since it knows better than a
|
||||
// client that is measuring rate based on
|
||||
// wall-clock time-since-last-update.
|
||||
body := &trackingReader{r: res.Body, n: &progress}
|
||||
body := &trackingReader{l: l, r: res.Body, update: t.update}
|
||||
|
||||
return chunked.Put(cs.Chunk, cs.Digest, body)
|
||||
})
|
||||
@@ -531,6 +556,9 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
if total.Load() != expected {
|
||||
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
|
||||
}
|
||||
|
||||
md := blob.DigestFromBytes(m.Data)
|
||||
if err := blob.PutBytes(c, md, m.Data); err != nil {
|
||||
@@ -757,15 +785,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
||||
}
|
||||
blobURL := res.Header.Get("Content-Location")
|
||||
|
||||
var size int64
|
||||
s := bufio.NewScanner(res.Body)
|
||||
s.Split(bufio.ScanWords)
|
||||
for {
|
||||
if !s.Scan() {
|
||||
if s.Err() != nil {
|
||||
yield(chunksum{}, s.Err())
|
||||
} else if size != l.Size {
|
||||
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -789,12 +814,6 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
|
||||
return
|
||||
}
|
||||
|
||||
size += chunk.Size()
|
||||
if size > l.Size {
|
||||
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
|
||||
return
|
||||
}
|
||||
|
||||
cs := chunksum{
|
||||
URL: blobURL,
|
||||
Chunk: chunk,
|
||||
|
||||
Reference in New Issue
Block a user