cache loaded model

This commit is contained in:
Jeffrey Morgan
2023-07-31 21:35:18 -04:00
parent 81f75696e2
commit 528bafa585
4 changed files with 30 additions and 42 deletions

View File

@@ -10,6 +10,7 @@ import (
"net/http"
"os"
"path/filepath"
"reflect"
"strings"
"sync"
"time"
@@ -22,19 +23,21 @@ import (
"github.com/jmorganca/ollama/llama"
)
var activeSession struct {
var loaded struct {
mu sync.Mutex
id int64
llm *llama.LLM
expireAt time.Time
expireTimer *time.Timer
digest string
options api.Options
}
func GenerateHandler(c *gin.Context) {
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
loaded.mu.Lock()
defer loaded.mu.Unlock()
checkpointStart := time.Now()
@@ -50,10 +53,10 @@ func GenerateHandler(c *gin.Context) {
return
}
if req.SessionID == 0 || req.SessionID != activeSession.id {
if activeSession.llm != nil {
activeSession.llm.Close()
activeSession.llm = nil
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, req.Options) {
if loaded.llm != nil {
loaded.llm.Close()
loaded.llm = nil
}
opts := api.DefaultOptions()
@@ -73,33 +76,31 @@ func GenerateHandler(c *gin.Context) {
return
}
activeSession.id = time.Now().UnixNano()
activeSession.llm = llm
loaded.llm = llm
loaded.digest = model.Digest
}
sessionDuration := req.SessionDuration
sessionID := activeSession.id
sessionDuration := 5 * time.Minute
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
if activeSession.expireTimer == nil {
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
activeSession.mu.Lock()
defer activeSession.mu.Unlock()
loaded.expireAt = time.Now().Add(sessionDuration)
if loaded.expireTimer == nil {
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
loaded.mu.Lock()
defer loaded.mu.Unlock()
if sessionID != activeSession.id {
if time.Now().Before(loaded.expireAt) {
return
}
if time.Now().Before(activeSession.expireAt) {
if loaded.llm == nil {
return
}
activeSession.llm.Close()
activeSession.llm = nil
activeSession.id = 0
loaded.llm.Close()
loaded.llm = nil
})
}
activeSession.expireTimer.Reset(sessionDuration.Duration)
loaded.expireTimer.Reset(sessionDuration)
checkpointLoaded := time.Now()
@@ -113,13 +114,11 @@ func GenerateHandler(c *gin.Context) {
go func() {
defer close(ch)
fn := func(r api.GenerateResponse) {
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
activeSession.expireTimer.Reset(sessionDuration.Duration)
loaded.expireAt = time.Now().Add(sessionDuration)
loaded.expireTimer.Reset(sessionDuration)
r.Model = req.Model
r.CreatedAt = time.Now().UTC()
r.SessionID = activeSession.id
r.SessionExpiresAt = activeSession.expireAt.UTC()
if r.Done {
r.TotalDuration = time.Since(checkpointStart)
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -128,8 +127,7 @@ func GenerateHandler(c *gin.Context) {
ch <- r
}
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
log.Printf("llm.Predict failed with %s", err)
if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()