fix: relay request opts to loaded llm prediction (#1761)

This commit is contained in:
Bruce MacDonald
2024-01-03 12:01:42 -05:00
committed by GitHub
parent 05face44ef
commit 0b3118e0af
5 changed files with 106 additions and 71 deletions

View File

@@ -64,24 +64,9 @@ var loaded struct {
var defaultSessionDuration = 5 * time.Minute
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sessionDuration time.Duration) (*Model, error) {
model, err := GetModel(modelName)
if err != nil {
return nil, err
}
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
workDir := c.GetString("workDir")
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
log.Printf("could not load model options: %v", err)
return nil, err
}
if err := opts.FromMap(reqOpts); err != nil {
return nil, err
}
needLoad := loaded.runner == nil || // is there a model loaded?
loaded.ModelPath != model.ModelPath || // has the base model changed?
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
@@ -105,7 +90,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, model.ShortName)
}
return nil, err
return err
}
loaded.Model = model
@@ -135,7 +120,20 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
}
loaded.expireTimer.Reset(sessionDuration)
return model, nil
return nil
}
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions()
if err := opts.FromMap(model.Options); err != nil {
return api.Options{}, err
}
if err := opts.FromMap(requestOpts); err != nil {
return api.Options{}, err
}
return opts, nil
}
func GenerateHandler(c *gin.Context) {
@@ -168,18 +166,30 @@ func GenerateHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -287,9 +297,10 @@ func GenerateHandler(c *gin.Context) {
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: req.Images,
Prompt: prompt,
Format: req.Format,
Images: req.Images,
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}
@@ -347,18 +358,29 @@ func EmbeddingHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
_, err = load(c, req.Model, req.Options, sessionDuration)
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -991,18 +1013,29 @@ func ChatHandler(c *gin.Context) {
return
}
sessionDuration := defaultSessionDuration
model, err := load(c, req.Model, req.Options, sessionDuration)
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
switch {
case errors.As(err, &pErr):
if errors.As(err, &pErr) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found, try pulling it first", req.Model)})
case errors.Is(err, api.ErrInvalidOpts):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
opts, err := modelOptions(model, req.Options)
if err != nil {
if errors.Is(err, api.ErrInvalidOpts) {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -1053,9 +1086,10 @@ func ChatHandler(c *gin.Context) {
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: images,
Prompt: prompt,
Format: req.Format,
Images: images,
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
ch <- gin.H{"error": err.Error()}