mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
This commit adds retry/backoff to the registry client for pull requests. Also, revert progress indication to match original client's until we can "get it right." Also, make WithTrace wrap existing traces instead of clobbering them. This allows clients to compose traces.
954 lines
26 KiB
Go
954 lines
26 KiB
Go
package ollama
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"io/fs"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"reflect"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/ollama/ollama/server/internal/cache/blob"
|
|
"github.com/ollama/ollama/server/internal/testutil"
|
|
)
|
|
|
|
func ExampleRegistry_cancelOnFirstError() {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
ctx = WithTrace(ctx, &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
if err != nil {
|
|
// Discontinue pulling layers if there is an
|
|
// error instead of continuing to pull more
|
|
// data.
|
|
cancel()
|
|
}
|
|
},
|
|
})
|
|
|
|
var r Registry
|
|
if err := r.Pull(ctx, "model"); err != nil {
|
|
// panic for demo purposes
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
func TestManifestMarshalJSON(t *testing.T) {
|
|
// All manifests should contain an "empty" config object.
|
|
var m Manifest
|
|
data, err := json.Marshal(m)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !bytes.Contains(data, []byte(`"config":{"digest":"sha256:`)) {
|
|
t.Error("expected manifest to contain empty config")
|
|
t.Fatalf("got:\n%s", string(data))
|
|
}
|
|
}
|
|
|
|
var errRoundTrip = errors.New("forced roundtrip error")
|
|
|
|
type recordRoundTripper http.HandlerFunc
|
|
|
|
func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
w := httptest.NewRecorder()
|
|
rr(w, req)
|
|
if w.Code == 499 {
|
|
return nil, errRoundTrip
|
|
}
|
|
resp := w.Result()
|
|
// For some reason, Response.Request is not set by httptest.NewRecorder, so we
|
|
// set it manually.
|
|
resp.Request = req
|
|
return w.Result(), nil
|
|
}
|
|
|
|
// newClient constructs a cache with predefined manifests for testing. The manifests are:
|
|
//
|
|
// empty: no data
|
|
// zero: no layers
|
|
// single: one layer with the contents "exists"
|
|
// multiple: two layers with the contents "exists" and "here"
|
|
// notfound: a layer that does not exist in the cache
|
|
// null: one null layer (e.g. [null])
|
|
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
|
|
// invalid: a layer with invalid JSON data
|
|
//
|
|
// Tests that want to ensure the client does not communicate with the upstream
|
|
// registry should pass a nil handler, which will cause a panic if
|
|
// communication is attempted.
|
|
//
|
|
// To simulate a network error, pass a handler that returns a 499 status code.
|
|
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
|
t.Helper()
|
|
|
|
c, err := blob.Open(t.TempDir())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
mklayer := func(data string) *Layer {
|
|
return &Layer{
|
|
Digest: importBytes(t, c, data),
|
|
Size: int64(len(data)),
|
|
}
|
|
}
|
|
|
|
r := &Registry{
|
|
Cache: c,
|
|
HTTPClient: &http.Client{
|
|
Transport: recordRoundTripper(upstreamRegistry),
|
|
},
|
|
}
|
|
|
|
link := func(name string, manifest string) {
|
|
n, err := r.parseName(name)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
if err := c.Link(n.String(), d); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
commit := func(name string, layers ...*Layer) {
|
|
t.Helper()
|
|
data, err := json.Marshal(&Manifest{Layers: layers})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
link(name, string(data))
|
|
}
|
|
|
|
link("empty", "")
|
|
commit("zero")
|
|
commit("single", mklayer("exists"))
|
|
commit("multiple", mklayer("exists"), mklayer("present"))
|
|
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
|
|
commit("null", nil)
|
|
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
|
link("invalid", "!!!!!")
|
|
|
|
return r, c
|
|
}
|
|
|
|
func okHandler(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
func checkErrCode(t *testing.T, err error, status int, code string) {
|
|
t.Helper()
|
|
var e *Error
|
|
if !errors.As(err, &e) || e.status != status || e.Code != code {
|
|
t.Errorf("err = %v; want %v %v", err, status, code)
|
|
}
|
|
}
|
|
|
|
func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
|
d, err := c.Import(strings.NewReader(data), int64(len(data)))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return d
|
|
}
|
|
|
|
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
|
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
|
return WithTrace(ctx, t), t
|
|
}
|
|
|
|
func TestPushZero(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "empty", nil)
|
|
if !errors.Is(err, ErrManifestInvalid) {
|
|
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
|
|
}
|
|
}
|
|
|
|
func TestPushSingle(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "single", nil)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestPushMultiple(t *testing.T) {
|
|
rc, _ := newClient(t, okHandler)
|
|
err := rc.Push(t.Context(), "multiple", nil)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestPushNotFound(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
t.Errorf("unexpected request: %v", r)
|
|
})
|
|
err := rc.Push(t.Context(), "notfound", nil)
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
|
|
}
|
|
}
|
|
|
|
func TestPushNullLayer(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
err := rc.Push(t.Context(), "null", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
|
|
func TestPushSizeMismatch(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
ctx, _ := withTraceUnexpected(t.Context())
|
|
got := rc.Push(ctx, "sizemismatch", nil)
|
|
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
|
|
t.Errorf("err = %v; want size mismatch", got)
|
|
}
|
|
}
|
|
|
|
func TestPushInvalid(t *testing.T) {
|
|
rc, _ := newClient(t, nil)
|
|
err := rc.Push(t.Context(), "invalid", nil)
|
|
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
|
|
func TestPushExistsAtRemote(t *testing.T) {
|
|
var pushed bool
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/uploads/") {
|
|
if !pushed {
|
|
// First push. Return an uploadURL.
|
|
pushed = true
|
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
|
return
|
|
}
|
|
w.WriteHeader(http.StatusAccepted)
|
|
return
|
|
}
|
|
|
|
io.Copy(io.Discard, r.Body)
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
rc.MaxStreams = 1 // prevent concurrent uploads
|
|
|
|
var errs []error
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(_ *Layer, n int64, err error) {
|
|
// uploading one at a time so no need to lock
|
|
errs = append(errs, err)
|
|
},
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
err := rc.Push(ctx, "single", nil)
|
|
check(err)
|
|
|
|
if !errors.Is(errors.Join(errs...), nil) {
|
|
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
|
|
}
|
|
|
|
err = rc.Push(ctx, "single", nil)
|
|
check(err)
|
|
}
|
|
|
|
func TestPushRemoteError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
w.WriteHeader(500)
|
|
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
|
|
return
|
|
}
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
checkErrCode(t, got, 500, "blob_error")
|
|
}
|
|
|
|
func TestPushLocationError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Location", ":///x")
|
|
w.WriteHeader(http.StatusAccepted)
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
wantContains := "invalid upload URL"
|
|
if got == nil || !strings.Contains(got.Error(), wantContains) {
|
|
t.Errorf("err = %v; want to contain %v", got, wantContains)
|
|
}
|
|
}
|
|
|
|
func TestPushUploadRoundtripError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Host == "blob.store" {
|
|
w.WriteHeader(499) // force RoundTrip error on upload
|
|
return
|
|
}
|
|
w.Header().Set("Location", "http://blob.store/blobs/123")
|
|
})
|
|
got := rc.Push(t.Context(), "single", nil)
|
|
if !errors.Is(got, errRoundTrip) {
|
|
t.Errorf("got = %v; want %v", got, errRoundTrip)
|
|
}
|
|
}
|
|
|
|
func TestPushUploadFileOpenError(t *testing.T) {
|
|
rc, c := newClient(t, okHandler)
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, _ int64, err error) {
|
|
// Remove the file just before it is opened for upload,
|
|
// but after the initial Stat that happens before the
|
|
// upload starts
|
|
os.Remove(c.GetFile(l.Digest))
|
|
},
|
|
})
|
|
got := rc.Push(ctx, "single", nil)
|
|
if !errors.Is(got, fs.ErrNotExist) {
|
|
t.Errorf("got = %v; want fs.ErrNotExist", got)
|
|
}
|
|
}
|
|
|
|
func TestPushCommitRoundtripError(t *testing.T) {
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if strings.Contains(r.URL.Path, "/blobs/") {
|
|
panic("unexpected")
|
|
}
|
|
w.WriteHeader(499) // force RoundTrip error
|
|
})
|
|
err := rc.Push(t.Context(), "zero", nil)
|
|
if !errors.Is(err, errRoundTrip) {
|
|
t.Errorf("err = %v; want %v", err, errRoundTrip)
|
|
}
|
|
}
|
|
|
|
func TestRegistryPullInvalidName(t *testing.T) {
|
|
rc, _ := newRegistryClient(t, nil)
|
|
err := rc.Pull(t.Context(), "://")
|
|
if !errors.Is(err, ErrNameInvalid) {
|
|
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
|
|
}
|
|
}
|
|
|
|
func TestRegistryPullInvalidManifest(t *testing.T) {
|
|
cases := []string{
|
|
"",
|
|
"null",
|
|
"!!!",
|
|
`{"layers":[]}`,
|
|
}
|
|
|
|
for _, resp := range cases {
|
|
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
io.WriteString(w, resp)
|
|
})
|
|
err := rc.Pull(t.Context(), "http://example.com/a/b")
|
|
if !errors.Is(err, ErrManifestInvalid) {
|
|
t.Errorf("err = %v; want invalid manifest", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRegistryResolveByDigest(t *testing.T) {
|
|
check := testutil.Checker(t)
|
|
|
|
exists := blob.DigestFromBytes("exists")
|
|
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
if r.URL.Path != "/v2/alice/palace/blobs/"+exists.String() {
|
|
w.WriteHeader(499) // should not hit manifest endpoint
|
|
}
|
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
|
})
|
|
|
|
_, err := rc.Resolve(t.Context(), "alice/palace@"+exists.String())
|
|
check(err)
|
|
}
|
|
|
|
func TestInsecureSkipVerify(t *testing.T) {
|
|
exists := blob.DigestFromBytes("exists")
|
|
|
|
s := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":5}]}`, exists)
|
|
}))
|
|
defer s.Close()
|
|
|
|
const name = "library/insecure"
|
|
|
|
var rc Registry
|
|
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
|
_, err := rc.Resolve(t.Context(), url)
|
|
if err == nil || !strings.Contains(err.Error(), "failed to verify") {
|
|
t.Errorf("err = %v; want cert verification failure", err)
|
|
}
|
|
|
|
url = fmt.Sprintf("https+insecure://%s/%s", s.Listener.Addr(), name)
|
|
_, err = rc.Resolve(t.Context(), url)
|
|
testutil.Check(t, err)
|
|
}
|
|
|
|
func TestErrorUnmarshal(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
data string
|
|
want *Error
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "errors empty",
|
|
data: `{"errors":[]}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "errors empty",
|
|
data: `{"errors":[]}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "errors single",
|
|
data: `{"errors":[{"code":"blob_unknown"}]}`,
|
|
want: &Error{Code: "blob_unknown", Message: ""},
|
|
},
|
|
{
|
|
name: "errors multiple",
|
|
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
|
|
want: &Error{Code: "blob_unknown", Message: ""},
|
|
},
|
|
{
|
|
name: "error empty",
|
|
data: `{"error":""}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "error very empty",
|
|
data: `{}`,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "error message",
|
|
data: `{"error":"message", "code":"code"}`,
|
|
want: &Error{Code: "code", Message: "message"},
|
|
},
|
|
{
|
|
name: "invalid value",
|
|
data: `{"error": 1}`,
|
|
wantErr: true,
|
|
},
|
|
}
|
|
for _, tt := range cases {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var got Error
|
|
err := json.Unmarshal([]byte(tt.data), &got)
|
|
if err != nil {
|
|
if tt.wantErr {
|
|
return
|
|
}
|
|
t.Errorf("Unmarshal() error = %v", err)
|
|
// fallthrough and check got
|
|
}
|
|
if tt.want == nil {
|
|
tt.want = &Error{}
|
|
}
|
|
if !reflect.DeepEqual(got, *tt.want) {
|
|
t.Errorf("got = %v; want %v", got, *tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestParseNameErrors tests that parseName returns errors messages with enough
|
|
// detail for users to debug naming issues they may encounter. Previous to this
|
|
// test, the error messages were not very helpful and each problem was reported
|
|
// as the same message.
|
|
//
|
|
// It is only for testing error messages, not that all invalids and valids are
|
|
// covered. Those are in other tests for names.Name and blob.Digest.
|
|
func TestParseNameExtendedErrors(t *testing.T) {
|
|
cases := []struct {
|
|
name string
|
|
err error
|
|
want string
|
|
}{}
|
|
|
|
var r Registry
|
|
for _, tt := range cases {
|
|
_, _, _, err := r.parseNameExtended(tt.name)
|
|
if !errors.Is(err, tt.err) {
|
|
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
|
}
|
|
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
|
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestParseNameExtended(t *testing.T) {
|
|
cases := []struct {
|
|
in string
|
|
scheme string
|
|
name string
|
|
digest string
|
|
err string
|
|
}{
|
|
{in: "http://m", scheme: "http", name: "m"},
|
|
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
|
{in: "http+insecure://m", err: "unsupported scheme"},
|
|
|
|
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
|
|
|
{in: "", err: "invalid or missing name"},
|
|
{in: "m", scheme: "https", name: "m"},
|
|
{in: "://", err: "invalid or missing name"},
|
|
{in: "@sha256:deadbeef", err: "invalid digest"},
|
|
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
|
}
|
|
for _, tt := range cases {
|
|
t.Run(tt.in, func(t *testing.T) {
|
|
var r Registry
|
|
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
|
if err != nil {
|
|
if tt.err == "" {
|
|
t.Errorf("err = %v; want nil", err)
|
|
} else if !strings.Contains(err.Error(), tt.err) {
|
|
t.Errorf("err = %v; want %q", err, tt.err)
|
|
}
|
|
} else if tt.err != "" {
|
|
t.Errorf("err = nil; want %q", tt.err)
|
|
}
|
|
if err == nil && !n.IsFullyQualified() {
|
|
t.Errorf("name = %q; want fully qualified", n)
|
|
}
|
|
|
|
if scheme != tt.scheme {
|
|
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
|
}
|
|
|
|
// smoke-test name is superset of tt.name
|
|
if !strings.Contains(n.String(), tt.name) {
|
|
t.Errorf("name = %q; want %q", n, tt.name)
|
|
}
|
|
|
|
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
|
if digest.String() != tt.digest {
|
|
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestUnlink(t *testing.T) {
|
|
t.Run("found by name", func(t *testing.T) {
|
|
check := testutil.Checker(t)
|
|
|
|
rc, _ := newRegistryClient(t, nil)
|
|
// make a blob and link it
|
|
d := blob.DigestFromBytes("{}")
|
|
err := blob.PutBytes(rc.Cache, d, "{}")
|
|
check(err)
|
|
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
|
|
check(err)
|
|
|
|
// confirm linked
|
|
_, err = rc.ResolveLocal("single")
|
|
check(err)
|
|
|
|
// unlink
|
|
_, err = rc.Unlink("single")
|
|
check(err)
|
|
|
|
// confirm unlinked
|
|
_, err = rc.ResolveLocal("single")
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
|
}
|
|
})
|
|
t.Run("not found by name", func(t *testing.T) {
|
|
rc, _ := newRegistryClient(t, nil)
|
|
ok, err := rc.Unlink("manifestNotFound")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if ok {
|
|
t.Error("expected not found")
|
|
}
|
|
})
|
|
}
|
|
|
|
// Many tests from here out, in this file are based on a single blob, "abc",
|
|
// with the checksum of its sha256 hash. The checksum is:
|
|
//
|
|
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
|
|
//
|
|
// Using the literal value instead of a constant with fmt.Xprintf calls proved
|
|
// to be the most readable and maintainable approach. The sum is consistently
|
|
// used in the tests and unique so searches do not yield false positives.
|
|
|
|
func checkRequest(t *testing.T, req *http.Request, method, path string) {
|
|
t.Helper()
|
|
if got := req.URL.Path; got != path {
|
|
t.Errorf("URL = %q, want %q", got, path)
|
|
}
|
|
if req.Method != method {
|
|
t.Errorf("Method = %q, want %q", req.Method, method)
|
|
}
|
|
}
|
|
|
|
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
|
|
s := httptest.NewServer(h)
|
|
t.Cleanup(s.Close)
|
|
cache, err := blob.Open(t.TempDir())
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
},
|
|
})
|
|
|
|
rc := &Registry{
|
|
Cache: cache,
|
|
HTTPClient: &http.Client{Transport: &http.Transport{
|
|
Dial: func(network, addr string) (net.Conn, error) {
|
|
return net.Dial(network, s.Listener.Addr().String())
|
|
},
|
|
}},
|
|
}
|
|
return rc, ctx
|
|
}
|
|
|
|
func TestPullChunked(t *testing.T) {
|
|
var steps atomic.Int64
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch steps.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3, 4:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
switch rng := r.Header.Get("Range"); rng {
|
|
case "bytes=0-1":
|
|
io.WriteString(w, "ab")
|
|
case "bytes=2-2":
|
|
t.Logf("writing c")
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected range %q", rng)
|
|
}
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
testutil.Check(t, err)
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
testutil.Check(t, err)
|
|
|
|
if g := steps.Load(); g != 4 {
|
|
t.Fatalf("got %d steps, want 4", g)
|
|
}
|
|
}
|
|
|
|
func TestPullCached(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
// Premeptively cache the blob
|
|
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
check(err)
|
|
err = blob.PutBytes(c.Cache, d, []byte("abc"))
|
|
check(err)
|
|
|
|
// Pull only the manifest, which should be enough to resolve the cached blob
|
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
|
check(err)
|
|
}
|
|
|
|
func TestPullManifestError(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
w.WriteHeader(http.StatusNotFound)
|
|
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
var got *Error
|
|
if !errors.Is(err, ErrModelNotFound) {
|
|
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
|
|
}
|
|
}
|
|
|
|
func TestPullLayerError(t *testing.T) {
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `!`)
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if err == nil {
|
|
t.Fatalf("expected error")
|
|
}
|
|
var want *json.SyntaxError
|
|
if !errors.As(err, &want) {
|
|
t.Fatalf("err = %T, want %T", err, want)
|
|
}
|
|
}
|
|
|
|
func TestPullLayerChecksumError(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3:
|
|
w.WriteHeader(http.StatusNotFound)
|
|
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
|
|
case 4:
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.MaxStreams = 1
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
var written atomic.Int64
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
written.Add(n)
|
|
},
|
|
})
|
|
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
var got *Error
|
|
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
|
|
t.Fatalf("err = %v, want %v", err, got)
|
|
}
|
|
|
|
if g := written.Load(); g != 1 {
|
|
t.Fatalf("wrote %d bytes, want 1", g)
|
|
}
|
|
}
|
|
|
|
func TestPullChunksumStreamError(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
|
|
// Write one valid chunksum and one invalid chunksum
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
|
|
fmt.Fprint(w, "sha256:!") // invalid
|
|
case 3:
|
|
io.WriteString(w, "ab")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
got := c.Pull(ctx, "http://o.com/library/abc")
|
|
if !errors.Is(got, ErrIncomplete) {
|
|
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
|
|
}
|
|
}
|
|
|
|
type flushAfterWriter struct {
|
|
w io.Writer
|
|
}
|
|
|
|
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
|
|
n, err = f.w.Write(p)
|
|
f.w.(http.Flusher).Flush() // panic if not a flusher
|
|
return
|
|
}
|
|
|
|
func TestPullChunksumStreaming(t *testing.T) {
|
|
csr, csw := io.Pipe()
|
|
defer csw.Close()
|
|
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
|
|
_, err := io.Copy(fw, csr)
|
|
if err != nil {
|
|
t.Errorf("copy: %v", err)
|
|
}
|
|
case 3:
|
|
io.WriteString(w, "ab")
|
|
case 4:
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
update := make(chan int64, 1)
|
|
ctx := WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
if n > 0 {
|
|
update <- n
|
|
}
|
|
},
|
|
})
|
|
|
|
errc := make(chan error, 1)
|
|
go func() {
|
|
errc <- c.Pull(ctx, "http://o.com/library/abc")
|
|
}()
|
|
|
|
// Send first chunksum and ensure it kicks off work immediately
|
|
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
if g := <-update; g != 2 {
|
|
t.Fatalf("got %d, want 2", g)
|
|
}
|
|
|
|
// now send the second chunksum and ensure it kicks off work immediately
|
|
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
if g := <-update; g != 3 {
|
|
t.Fatalf("got %d, want 3", g)
|
|
}
|
|
csw.Close()
|
|
testutil.Check(t, <-errc)
|
|
}
|
|
|
|
func TestPullChunksumsCached(t *testing.T) {
|
|
var step atomic.Int64
|
|
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
|
|
switch step.Add(1) {
|
|
case 1:
|
|
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
|
|
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
|
|
case 2:
|
|
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
|
|
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
|
|
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
|
|
case 3, 4:
|
|
switch rng := r.Header.Get("Range"); rng {
|
|
case "bytes=0-1":
|
|
io.WriteString(w, "ab")
|
|
case "bytes=2-2":
|
|
io.WriteString(w, "c")
|
|
default:
|
|
t.Errorf("unexpected range %q", rng)
|
|
}
|
|
default:
|
|
t.Errorf("unexpected steps %d: %v", step.Load(), r)
|
|
http.Error(w, "unexpected steps", http.StatusInternalServerError)
|
|
}
|
|
})
|
|
|
|
c.MaxStreams = 1 // force serial processing of chunksums
|
|
c.ChunkingThreshold = 1 // force chunking
|
|
|
|
ctx, cancel := context.WithCancel(t.Context())
|
|
defer cancel()
|
|
|
|
// Cancel the pull after the first chunksum is processed, but before
|
|
// the second chunksum is processed (which is waiting because
|
|
// MaxStreams=1). This should cause the second chunksum to error out
|
|
// leaving the blob incomplete.
|
|
ctx = WithTrace(ctx, &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
if n > 0 {
|
|
cancel()
|
|
}
|
|
},
|
|
})
|
|
err := c.Pull(ctx, "http://o.com/library/abc")
|
|
if !errors.Is(err, context.Canceled) {
|
|
t.Fatalf("err = %v, want %v", err, context.Canceled)
|
|
}
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
if !errors.Is(err, fs.ErrNotExist) {
|
|
t.Fatalf("err = %v, want nil", err)
|
|
}
|
|
|
|
// Reset state and pull again to ensure the blob chunks that should
|
|
// have been cached are, and the remaining chunk was downloaded, making
|
|
// the blob complete.
|
|
step.Store(0)
|
|
var written atomic.Int64
|
|
var cached atomic.Int64
|
|
ctx = WithTrace(t.Context(), &Trace{
|
|
Update: func(l *Layer, n int64, err error) {
|
|
t.Log("trace:", l.Digest.Short(), n, err)
|
|
if errors.Is(err, ErrCached) {
|
|
cached.Add(n)
|
|
}
|
|
written.Add(n)
|
|
},
|
|
})
|
|
|
|
check := testutil.Checker(t)
|
|
|
|
err = c.Pull(ctx, "http://o.com/library/abc")
|
|
check(err)
|
|
|
|
_, err = c.Cache.Resolve("o.com/library/abc:latest")
|
|
check(err)
|
|
|
|
if g := written.Load(); g != 5 {
|
|
t.Fatalf("wrote %d bytes, want 3", g)
|
|
}
|
|
if g := cached.Load(); g != 2 { // "ab" should have been cached
|
|
t.Fatalf("cached %d bytes, want 5", g)
|
|
}
|
|
}
|