deprecate modelfile embed command (#759)

Author: Bruce MacDonald (committed by GitHub)
Date: 2023-10-16 11:07:37 -04:00
Parent: 06bcfbd629
Commit: a0c3e989de
9 changed files with 19 additions and 301 deletions


@@ -23,11 +23,9 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"gonum.org/v1/gonum/mat"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/vector"
)
var mode string = gin.DebugMode
@@ -47,8 +45,7 @@ func init() {
 var loaded struct {
 	mu sync.Mutex
 
-	llm        llm.LLM
-	Embeddings []vector.Embedding
+	llm llm.LLM
 
 	expireAt    time.Time
 	expireTimer *time.Timer
@@ -90,11 +87,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
loaded.digest = ""
}
if model.Embeddings != nil && len(model.Embeddings) > 0 {
opts.EmbeddingOnly = true // this is requried to generate embeddings, completions will still work
loaded.Embeddings = model.Embeddings
}
llmModel, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
if err != nil {
return err
@@ -106,12 +98,12 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	loaded.options = opts
 
 	if opts.NumKeep < 0 {
-		promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
+		promptWithSystem, err := model.Prompt(api.GenerateRequest{})
 		if err != nil {
 			return err
 		}
 
-		promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}}, "")
+		promptNoSystem, err := model.Prompt(api.GenerateRequest{Context: []int{0}})
 		if err != nil {
 			return err
 		}
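
Both call sites in this hunk change only in arity: model.Prompt no longer takes the trailing string that used to carry retrieved embedding text. A minimal standalone sketch of the before/after shape, using hypothetical stand-in types rather than ollama's Model and api.GenerateRequest:

package main

import "fmt"

// generateRequest and promptModel are hypothetical stand-ins used only to
// illustrate the signature change; they are not ollama's types.
type generateRequest struct {
	Prompt string
}

type promptModel struct {
	system string
}

// Old shape: a second argument carried text recovered from the model's
// stored embeddings and was spliced into the rendered prompt.
func (m promptModel) promptWithEmbedding(req generateRequest, embedding string) string {
	return fmt.Sprintf("%s\n%s\n%s", m.system, embedding, req.Prompt)
}

// New shape: the prompt is rendered from the request alone.
func (m promptModel) prompt(req generateRequest) string {
	return fmt.Sprintf("%s\n%s", m.system, req.Prompt)
}

func main() {
	m := promptModel{system: "You are a helpful assistant."}
	fmt.Println(m.promptWithEmbedding(generateRequest{Prompt: "hello"}, "retrieved context"))
	fmt.Println(m.prompt(generateRequest{Prompt: "hello"}))
}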
@@ -195,22 +187,7 @@ func GenerateHandler(c *gin.Context) {
 	checkpointLoaded := time.Now()
 
-	embedding := ""
-	if model.Embeddings != nil && len(model.Embeddings) > 0 {
-		promptEmbed, err := loaded.llm.Embedding(c.Request.Context(), req.Prompt)
-		if err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-		// TODO: set embed_top from specified parameters in modelfile
-		embed_top := 3
-		topK := vector.TopK(embed_top, mat.NewVecDense(len(promptEmbed), promptEmbed), loaded.Embeddings)
-		for _, e := range topK {
-			embedding = fmt.Sprintf("%s %s", embedding, e.Embedding.Data)
-		}
-	}
-
-	prompt, err := model.Prompt(req, embedding)
+	prompt, err := model.Prompt(req)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
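
The block removed above is the retrieval step that the deprecated EMBED command fed: embed the incoming prompt with loaded.llm.Embedding, rank the Modelfile's stored embeddings against it with vector.TopK, and concatenate the top matches into a string handed to model.Prompt. A standalone sketch of that top-k ranking using only the standard library; the embedding type is a hypothetical stand-in for vector.Embedding, and cosine similarity stands in for whatever metric vector.TopK actually used:

package main

import (
	"fmt"
	"math"
	"sort"
)

// embedding is a hypothetical stand-in for vector.Embedding: a stored
// chunk of text plus its embedding vector.
type embedding struct {
	Data   string
	Vector []float64
}

func dot(a, b []float64) float64 {
	var s float64
	for i := range a {
		s += a[i] * b[i]
	}
	return s
}

// cosine similarity; assumes equal-length vectors.
func cosine(a, b []float64) float64 {
	na, nb := math.Sqrt(dot(a, a)), math.Sqrt(dot(b, b))
	if na == 0 || nb == 0 {
		return 0
	}
	return dot(a, b) / (na * nb)
}

// topK ranks stored embeddings against the prompt embedding and returns
// the k closest, mirroring the role vector.TopK played in the removed code.
func topK(k int, prompt []float64, store []embedding) []embedding {
	ranked := append([]embedding(nil), store...)
	sort.Slice(ranked, func(i, j int) bool {
		return cosine(prompt, ranked[i].Vector) > cosine(prompt, ranked[j].Vector)
	})
	if k > len(ranked) {
		k = len(ranked)
	}
	return ranked[:k]
}

func main() {
	store := []embedding{
		{Data: "llamas are vegetarians", Vector: []float64{0.9, 0.1, 0.0}},
		{Data: "llamas live to be 20 years old", Vector: []float64{0.1, 0.9, 0.0}},
		{Data: "llamas are members of the camelid family", Vector: []float64{0.5, 0.5, 0.1}},
	}
	promptEmbed := []float64{0.8, 0.2, 0.0} // would come from llm.Embedding in the old code

	// Mirror the removed loop: concatenate the top matches into a single
	// string that was passed to model.Prompt as extra context.
	context := ""
	for _, e := range topK(2, promptEmbed, store) {
		context = fmt.Sprintf("%s %s", context, e.Data)
	}
	fmt.Println(context)
}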