mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-11 08:17:03 +00:00
Token auth (#314)
This commit is contained in:
@@ -28,6 +28,7 @@ type RegistryOptions struct {
|
||||
Insecure bool
|
||||
Username string
|
||||
Password string
|
||||
Token string
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@@ -1129,18 +1130,30 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(method, url, body)
|
||||
// make a copy of the body in case we need to try the call to makeRequest again
|
||||
var buf bytes.Buffer
|
||||
if body != nil {
|
||||
_, err := io.Copy(&buf, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
bodyCopy := bytes.NewReader(buf.Bytes())
|
||||
|
||||
req, err := http.NewRequest(method, url, bodyCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
if regOpts.Token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
||||
} else if regOpts.Username != "" && regOpts.Password != "" {
|
||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||
}
|
||||
|
||||
// TODO: better auth
|
||||
if regOpts.Username != "" && regOpts.Password != "" {
|
||||
req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
||||
for k, v := range headers {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
client := &http.Client{
|
||||
@@ -1157,9 +1170,55 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if the request is unauthenticated, try to authenticate and make the request again
|
||||
if resp.StatusCode == http.StatusUnauthorized {
|
||||
auth := resp.Header.Get("Www-Authenticate")
|
||||
authRedir := ParseAuthRedirectString(string(auth))
|
||||
token, err := getAuthToken(authRedir, regOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
regOpts.Token = token
|
||||
bodyCopy = bytes.NewReader(buf.Bytes())
|
||||
return makeRequest(method, url, headers, bodyCopy, regOpts)
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func getValue(header, key string) string {
|
||||
startIdx := strings.Index(header, key+"=")
|
||||
if startIdx == -1 {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Move the index to the starting quote after the key.
|
||||
startIdx += len(key) + 2
|
||||
endIdx := startIdx
|
||||
|
||||
for endIdx < len(header) {
|
||||
if header[endIdx] == '"' {
|
||||
if endIdx+1 < len(header) && header[endIdx+1] != ',' { // If the next character isn't a comma, continue
|
||||
endIdx++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
endIdx++
|
||||
}
|
||||
return header[startIdx:endIdx]
|
||||
}
|
||||
|
||||
func ParseAuthRedirectString(authStr string) AuthRedirect {
|
||||
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
||||
|
||||
return AuthRedirect{
|
||||
Realm: getValue(authStr, "realm"),
|
||||
Service: getValue(authStr, "service"),
|
||||
Scope: getValue(authStr, "scope"),
|
||||
}
|
||||
}
|
||||
|
||||
var errDigestMismatch = fmt.Errorf("digest mismatch, file must be downloaded again")
|
||||
|
||||
func verifyBlob(digest string) error {
|
||||
|
||||
Reference in New Issue
Block a user