add show command (#474)

This commit is contained in:
Patrick Devine
2023-09-06 11:04:17 -07:00
committed by GitHub
parent 7de300856b
commit 790d24eb7b
5 changed files with 299 additions and 50 deletions

View File

@@ -41,15 +41,18 @@ type RegistryOptions struct {
}
type Model struct {
Name string `json:"name"`
ModelPath string
AdapterPaths []string
Template string
System string
Digest string
ConfigDigest string
Options map[string]interface{}
Embeddings []vector.Embedding
Name string `json:"name"`
ShortName string
ModelPath string
OriginalModel string
AdapterPaths []string
Template string
System string
License []string
Digest string
ConfigDigest string
Options map[string]interface{}
Embeddings []vector.Embedding
}
func (m *Model) Prompt(request api.GenerateRequest, embedding string) (string, error) {
@@ -171,9 +174,11 @@ func GetModel(name string) (*Model, error) {
model := &Model{
Name: mp.GetFullTagname(),
ShortName: mp.GetShortTagname(),
Digest: digest,
ConfigDigest: manifest.Config.Digest,
Template: "{{ .Prompt }}",
License: []string{},
}
for _, layer := range manifest.Layers {
@@ -185,6 +190,7 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.OriginalModel = layer.From
case "application/vnd.ollama.image.embed":
file, err := os.Open(filename)
if err != nil {
@@ -229,6 +235,12 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
model.License = append(model.License, string(bts))
}
}
@@ -933,6 +945,83 @@ func DeleteModel(name string) error {
return nil
}
func ShowModelfile(model *Model) (string, error) {
type modelTemplate struct {
*Model
From string
Params string
}
var params []string
for k, v := range model.Options {
switch val := v.(type) {
case string:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, val))
case int:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(val)))
case float64:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(val, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(val)))
case []interface{}:
for _, nv := range val {
switch nval := nv.(type) {
case string:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, nval))
case int:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.Itoa(nval)))
case float64:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatFloat(nval, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("PARAMETER %s %s", k, strconv.FormatBool(nval)))
default:
log.Printf("unknown type: %s", reflect.TypeOf(nv).String())
}
}
default:
log.Printf("unknown type: %s", reflect.TypeOf(v).String())
}
}
mt := modelTemplate{
Model: model,
From: model.OriginalModel,
Params: strings.Join(params, "\n"),
}
if mt.From == "" {
mt.From = model.ModelPath
}
modelFile := `# Modelfile generated by "ollama show"
# To build a new Modelfile based on this one, replace the FROM line with:
# FROM {{ .ShortName }}
FROM {{ .From }}
TEMPLATE """{{ .Template }}"""
SYSTEM """{{ .System }}"""
{{ .Params }}
`
for _, l := range mt.Model.AdapterPaths {
modelFile += fmt.Sprintf("ADAPTER %s\n", l)
}
tmpl, err := template.New("").Parse(modelFile)
if err != nil {
log.Printf("error parsing template: %q", err)
return "", err
}
var buf bytes.Buffer
if err = tmpl.Execute(&buf, mt); err != nil {
log.Printf("error executing template: %q", err)
return "", err
}
return buf.String(), nil
}
func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"})

View File

@@ -12,6 +12,7 @@ import (
"os/signal"
"path/filepath"
"reflect"
"strconv"
"strings"
"sync"
"syscall"
@@ -364,6 +365,77 @@ func DeleteModelHandler(c *gin.Context) {
}
}
func ShowModelHandler(c *gin.Context) {
var req api.ShowRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
resp, err := GetModelInfo(req.Name)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
c.JSON(http.StatusOK, resp)
}
func GetModelInfo(name string) (*api.ShowResponse, error) {
model, err := GetModel(name)
if err != nil {
return nil, err
}
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
}
mf, err := ShowModelfile(model)
if err != nil {
return nil, err
}
resp.Modelfile = mf
var params []string
cs := 30
for k, v := range model.Options {
switch val := v.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
case []interface{}:
for _, nv := range val {
switch nval := nv.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
}
}
}
}
resp.Parameters = strings.Join(params, "\n")
return resp, nil
}
func ListModelsHandler(c *gin.Context) {
var models []api.ModelResponse
fp, err := GetManifestPath()
@@ -457,6 +529,7 @@ func Serve(ln net.Listener, origins []string) error {
r.POST("/api/copy", CopyModelHandler)
r.GET("/api/tags", ListModelsHandler)
r.DELETE("/api/delete", DeleteModelHandler)
r.POST("/api/show", ShowModelHandler)
log.Printf("Listening on %s", ln.Addr())
s := &http.Server{