rerefactor

This commit is contained in:
Michael Yang
2024-02-14 11:29:49 -08:00
committed by jmorganca
parent 823a520266
commit e43648afe5
9 changed files with 224 additions and 251 deletions

95
server/auth.go Normal file
View File

@@ -0,0 +1,95 @@
package server
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
)
type registryChallenge struct {
Realm string
Service string
Scope string
}
func (r registryChallenge) URL() (*url.URL, error) {
redirectURL, err := url.Parse(r.Realm)
if err != nil {
return nil, err
}
values := redirectURL.Query()
values.Add("service", r.Service)
for _, s := range strings.Split(r.Scope, " ") {
values.Add("scope", s)
}
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
nonce, err := auth.NewNonce(rand.Reader, 16)
if err != nil {
return nil, err
}
values.Add("nonce", nonce)
redirectURL.RawQuery = values.Encode()
return redirectURL, nil
}
func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
redirectURL, err := challenge.URL()
if err != nil {
return "", err
}
sha256sum := sha256.Sum256(nil)
data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
headers := make(http.Header)
signature, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
headers.Add("Authorization", signature)
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil {
return "", err
}
defer response.Body.Close()
body, err := io.ReadAll(response.Body)
if err != nil {
return "", fmt.Errorf("%d: %v", response.StatusCode, err)
}
if response.StatusCode >= http.StatusBadRequest {
if len(body) > 0 {
return "", fmt.Errorf("%d: %s", response.StatusCode, body)
} else {
return "", fmt.Errorf("%d", response.StatusCode)
}
}
var token api.TokenResponse
if err := json.Unmarshal(body, &token); err != nil {
return "", err
}
return token.Token, nil
}

View File

@@ -22,7 +22,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format"
)
@@ -86,7 +85,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil {
return err
@@ -138,11 +137,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *a
return nil
}
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
b.err = b.run(ctx, requestURL, opts)
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -211,7 +210,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.
return nil
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header)
@@ -335,7 +334,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
type downloadOpts struct {
mp ModelPath
digest string
regOpts *auth.RegistryOptions
regOpts *registryOptions
fn func(api.ProgressResponse)
}

View File

@@ -16,17 +16,25 @@ import (
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"text/template"
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
type registryOptions struct {
Insecure bool
Username string
Password string
Token string
}
type Model struct {
Name string `json:"name"`
Config ConfigV2
@@ -312,7 +320,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
switch {
case errors.Is(err, os.ErrNotExist):
fn(api.ProgressResponse{Status: "pulling model"})
if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
return err
}
@@ -832,7 +840,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -882,7 +890,7 @@ func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
return nil
}
func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
var manifest *ManifestV2
@@ -988,7 +996,7 @@ func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
return nil
}
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
headers := make(http.Header)
@@ -1020,9 +1028,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
var errUnauthorized = fmt.Errorf("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
for i := 0; i < 2; i++ {
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err))
@@ -1034,9 +1042,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(authenticate)
token, err := auth.GetAuthToken(ctx, authRedir)
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
if err != nil {
return nil, err
}
@@ -1063,6 +1070,58 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
return nil, errUnauthorized
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
requestURL.Scheme = "http"
}
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
if err != nil {
return nil, err
}
if headers != nil {
req.Header = headers
}
if regOpts != nil {
if regOpts.Token != "" {
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
} else if regOpts.Username != "" && regOpts.Password != "" {
req.SetBasicAuth(regOpts.Username, regOpts.Password)
}
}
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
if s := req.Header.Get("Content-Length"); s != "" {
contentLength, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, err
}
req.ContentLength = contentLength
}
proxyURL, err := http.ProxyFromEnvironment(req)
if err != nil {
return nil, err
}
client := http.Client{
Transport: &http.Transport{
Proxy: http.ProxyURL(proxyURL),
},
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func getValue(header, key string) string {
startIdx := strings.Index(header, key+"=")
if startIdx == -1 {
@@ -1086,10 +1145,10 @@ func getValue(header, key string) string {
return header[startIdx:endIdx]
}
func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
func parseRegistryChallenge(authStr string) registryChallenge {
authStr = strings.TrimPrefix(authStr, "Bearer ")
return auth.AuthRedirect{
return registryChallenge{
Realm: getValue(authStr, "realm"),
Service: getValue(authStr, "service"),
Scope: getValue(authStr, "scope"),

View File

@@ -25,7 +25,6 @@ import (
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/gpu"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/openai"
@@ -480,7 +479,7 @@ func PullModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &auth.RegistryOptions{
regOpts := &registryOptions{
Insecure: req.Insecure,
}
@@ -529,7 +528,7 @@ func PushModelHandler(c *gin.Context) {
ch <- r
}
regOpts := &auth.RegistryOptions{
regOpts := &registryOptions{
Insecure: req.Insecure,
}

View File

@@ -18,7 +18,6 @@ import (
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/auth"
"github.com/jmorganca/ollama/format"
"golang.org/x/sync/errgroup"
)
@@ -50,7 +49,7 @@ const (
maxUploadPartSize int64 = 1000 * format.MegaByte
)
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
p, err := GetBlobsPath(b.Digest)
if err != nil {
return err
@@ -122,7 +121,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *aut
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
defer blobUploadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -213,7 +212,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
b.done = true
}
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
headers := make(http.Header)
headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -228,7 +227,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
md5sum := md5.New()
w := &progressWriter{blobUpload: b}
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
if err != nil {
w.Rollback()
return err
@@ -278,9 +277,8 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
case resp.StatusCode == http.StatusUnauthorized:
w.Rollback()
authenticate := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(authenticate)
token, err := auth.GetAuthToken(ctx, authRedir)
challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
token, err := getAuthorizationToken(ctx, challenge)
if err != nil {
return err
}
@@ -365,7 +363,7 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)