mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
Move hub auth out to new package
This commit is contained in:
committed by
jmorganca
parent
9da9e8fb72
commit
f397e0e988
176
server/auth.go
176
server/auth.go
@@ -1,176 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
)
|
||||
|
||||
type AuthRedirect struct {
|
||||
Realm string
|
||||
Service string
|
||||
Scope string
|
||||
}
|
||||
|
||||
type SignatureData struct {
|
||||
Method string
|
||||
Path string
|
||||
Data []byte
|
||||
}
|
||||
|
||||
func generateNonce(length int) (string, error) {
|
||||
nonce := make([]byte, length)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(nonce), nil
|
||||
}
|
||||
|
||||
func (r AuthRedirect) URL() (*url.URL, error) {
|
||||
redirectURL, err := url.Parse(r.Realm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
values := redirectURL.Query()
|
||||
|
||||
values.Add("service", r.Service)
|
||||
|
||||
for _, s := range strings.Split(r.Scope, " ") {
|
||||
values.Add("scope", s)
|
||||
}
|
||||
|
||||
values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
|
||||
|
||||
nonce, err := generateNonce(16)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values.Add("nonce", nonce)
|
||||
|
||||
redirectURL.RawQuery = values.Encode()
|
||||
return redirectURL, nil
|
||||
}
|
||||
|
||||
func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
|
||||
redirectURL, err := redirData.URL()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
keyPath := filepath.Join(home, ".ollama", "id_ed25519")
|
||||
|
||||
rawKey, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
|
||||
return "", err
|
||||
}
|
||||
|
||||
s := SignatureData{
|
||||
Method: http.MethodGet,
|
||||
Path: redirectURL.String(),
|
||||
Data: nil,
|
||||
}
|
||||
|
||||
sig, err := s.Sign(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("Authorization", sig)
|
||||
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= http.StatusBadRequest {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
|
||||
} else if len(responseBody) > 0 {
|
||||
return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("%s", resp.Status)
|
||||
}
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var tok api.TokenResponse
|
||||
if err := json.Unmarshal(respBody, &tok); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tok.Token, nil
|
||||
}
|
||||
|
||||
// Bytes returns a byte slice of the data to sign for the request
|
||||
func (s SignatureData) Bytes() []byte {
|
||||
// We first derive the content hash of the request body using:
|
||||
// base64(hex(sha256(request body)))
|
||||
|
||||
hash := sha256.Sum256(s.Data)
|
||||
hashHex := make([]byte, hex.EncodedLen(len(hash)))
|
||||
hex.Encode(hashHex, hash[:])
|
||||
contentHash := base64.StdEncoding.EncodeToString(hashHex)
|
||||
|
||||
// We then put the entire request together in a serialize string using:
|
||||
// "<method>,<uri>,<content hash>"
|
||||
// e.g. "GET,http://localhost,OTdkZjM1O..."
|
||||
|
||||
return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
|
||||
}
|
||||
|
||||
// SignData takes a SignatureData object and signs it with a raw private key
|
||||
func (s SignatureData) Sign(rawKey []byte) (string, error) {
|
||||
signer, err := ssh.ParsePrivateKey(rawKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// get the pubkey, but remove the type
|
||||
pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
|
||||
parts := bytes.Split(pubKey, []byte(" "))
|
||||
if len(parts) < 2 {
|
||||
return "", fmt.Errorf("malformed public key")
|
||||
}
|
||||
|
||||
signedData, err := signer.Sign(nil, s.Bytes())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// signature is <pubkey>:<signature>
|
||||
sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
|
||||
return sig, nil
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/auth"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
)
|
||||
|
||||
@@ -85,7 +86,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
||||
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -137,11 +138,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
|
||||
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
|
||||
b.err = b.run(ctx, requestURL, opts)
|
||||
}
|
||||
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
||||
defer blobDownloadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
@@ -210,7 +211,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
|
||||
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
g.Go(func() error {
|
||||
headers := make(http.Header)
|
||||
@@ -334,7 +335,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
|
||||
type downloadOpts struct {
|
||||
mp ModelPath
|
||||
digest string
|
||||
regOpts *RegistryOptions
|
||||
regOpts *auth.RegistryOptions
|
||||
fn func(api.ProgressResponse)
|
||||
}
|
||||
|
||||
|
||||
@@ -16,25 +16,17 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/auth"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/parser"
|
||||
"github.com/jmorganca/ollama/version"
|
||||
)
|
||||
|
||||
type RegistryOptions struct {
|
||||
Insecure bool
|
||||
Username string
|
||||
Password string
|
||||
Token string
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Name string `json:"name"`
|
||||
Config ConfigV2
|
||||
@@ -320,7 +312,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
fn(api.ProgressResponse{Status: "pulling model"})
|
||||
if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
|
||||
if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -840,7 +832,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
||||
|
||||
@@ -890,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
mp := ParseModelPath(name)
|
||||
|
||||
var manifest *ManifestV2
|
||||
@@ -996,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
||||
func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
|
||||
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
||||
|
||||
headers := make(http.Header)
|
||||
@@ -1028,9 +1020,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
||||
|
||||
var errUnauthorized = fmt.Errorf("unauthorized")
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
slog.Info(fmt.Sprintf("request failed: %v", err))
|
||||
@@ -1042,9 +1034,9 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
// Handle authentication error with one retry
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
authenticate := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(authenticate)
|
||||
token, err := auth.GetAuthToken(ctx, authRedir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1071,58 +1063,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
||||
return nil, errUnauthorized
|
||||
}
|
||||
|
||||
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
||||
if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
||||
requestURL.Scheme = "http"
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if headers != nil {
|
||||
req.Header = headers
|
||||
}
|
||||
|
||||
if regOpts != nil {
|
||||
if regOpts.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
||||
} else if regOpts.Username != "" && regOpts.Password != "" {
|
||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||
}
|
||||
}
|
||||
|
||||
req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
||||
|
||||
if s := req.Header.Get("Content-Length"); s != "" {
|
||||
contentLength, err := strconv.ParseInt(s, 10, 64)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.ContentLength = contentLength
|
||||
}
|
||||
|
||||
proxyURL, err := http.ProxyFromEnvironment(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := http.Client{
|
||||
Transport: &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func getValue(header, key string) string {
|
||||
startIdx := strings.Index(header, key+"=")
|
||||
if startIdx == -1 {
|
||||
@@ -1146,10 +1086,10 @@ func getValue(header, key string) string {
|
||||
return header[startIdx:endIdx]
|
||||
}
|
||||
|
||||
func ParseAuthRedirectString(authStr string) AuthRedirect {
|
||||
func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
|
||||
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
||||
|
||||
return AuthRedirect{
|
||||
return auth.AuthRedirect{
|
||||
Realm: getValue(authStr, "realm"),
|
||||
Service: getValue(authStr, "service"),
|
||||
Scope: getValue(authStr, "scope"),
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/auth"
|
||||
"github.com/jmorganca/ollama/gpu"
|
||||
"github.com/jmorganca/ollama/llm"
|
||||
"github.com/jmorganca/ollama/openai"
|
||||
@@ -479,7 +480,7 @@ func PullModelHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
regOpts := &RegistryOptions{
|
||||
regOpts := &auth.RegistryOptions{
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
@@ -528,7 +529,7 @@ func PushModelHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
regOpts := &RegistryOptions{
|
||||
regOpts := &auth.RegistryOptions{
|
||||
Insecure: req.Insecure,
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/jmorganca/ollama/api"
|
||||
"github.com/jmorganca/ollama/auth"
|
||||
"github.com/jmorganca/ollama/format"
|
||||
"golang.org/x/sync/errgroup"
|
||||
)
|
||||
@@ -49,7 +50,7 @@ const (
|
||||
maxUploadPartSize int64 = 1000 * format.MegaByte
|
||||
)
|
||||
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
|
||||
func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
|
||||
p, err := GetBlobsPath(b.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -121,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
|
||||
|
||||
// Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
|
||||
// in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
|
||||
func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
|
||||
defer blobUploadManager.Delete(b.Digest)
|
||||
ctx, b.CancelFunc = context.WithCancel(ctx)
|
||||
|
||||
@@ -212,7 +213,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
|
||||
b.done = true
|
||||
}
|
||||
|
||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
|
||||
func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
|
||||
headers := make(http.Header)
|
||||
headers.Set("Content-Type", "application/octet-stream")
|
||||
headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
|
||||
@@ -227,7 +228,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
md5sum := md5.New()
|
||||
w := &progressWriter{blobUpload: b}
|
||||
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
|
||||
resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
|
||||
if err != nil {
|
||||
w.Rollback()
|
||||
return err
|
||||
@@ -277,9 +278,9 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
|
||||
|
||||
case resp.StatusCode == http.StatusUnauthorized:
|
||||
w.Rollback()
|
||||
auth := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(auth)
|
||||
token, err := getAuthToken(ctx, authRedir)
|
||||
authenticate := resp.Header.Get("www-authenticate")
|
||||
authRedir := ParseAuthRedirectString(authenticate)
|
||||
token, err := auth.GetAuthToken(ctx, authRedir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -364,7 +365,7 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user