mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 15:57:04 +00:00
cache loaded model
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -22,19 +23,21 @@ import (
|
||||
"github.com/jmorganca/ollama/llama"
|
||||
)
|
||||
|
||||
var activeSession struct {
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
id int64
|
||||
llm *llama.LLM
|
||||
|
||||
expireAt time.Time
|
||||
expireTimer *time.Timer
|
||||
|
||||
digest string
|
||||
options api.Options
|
||||
}
|
||||
|
||||
func GenerateHandler(c *gin.Context) {
|
||||
activeSession.mu.Lock()
|
||||
defer activeSession.mu.Unlock()
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
checkpointStart := time.Now()
|
||||
|
||||
@@ -50,10 +53,10 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.SessionID == 0 || req.SessionID != activeSession.id {
|
||||
if activeSession.llm != nil {
|
||||
activeSession.llm.Close()
|
||||
activeSession.llm = nil
|
||||
if model.Digest != loaded.digest || !reflect.DeepEqual(loaded.options, req.Options) {
|
||||
if loaded.llm != nil {
|
||||
loaded.llm.Close()
|
||||
loaded.llm = nil
|
||||
}
|
||||
|
||||
opts := api.DefaultOptions()
|
||||
@@ -73,33 +76,31 @@ func GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
activeSession.id = time.Now().UnixNano()
|
||||
activeSession.llm = llm
|
||||
loaded.llm = llm
|
||||
loaded.digest = model.Digest
|
||||
}
|
||||
|
||||
sessionDuration := req.SessionDuration
|
||||
sessionID := activeSession.id
|
||||
sessionDuration := 5 * time.Minute
|
||||
|
||||
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||
if activeSession.expireTimer == nil {
|
||||
activeSession.expireTimer = time.AfterFunc(sessionDuration.Duration, func() {
|
||||
activeSession.mu.Lock()
|
||||
defer activeSession.mu.Unlock()
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
if loaded.expireTimer == nil {
|
||||
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
if sessionID != activeSession.id {
|
||||
if time.Now().Before(loaded.expireAt) {
|
||||
return
|
||||
}
|
||||
|
||||
if time.Now().Before(activeSession.expireAt) {
|
||||
if loaded.llm == nil {
|
||||
return
|
||||
}
|
||||
|
||||
activeSession.llm.Close()
|
||||
activeSession.llm = nil
|
||||
activeSession.id = 0
|
||||
loaded.llm.Close()
|
||||
loaded.llm = nil
|
||||
})
|
||||
}
|
||||
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
checkpointLoaded := time.Now()
|
||||
|
||||
@@ -113,13 +114,11 @@ func GenerateHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
fn := func(r api.GenerateResponse) {
|
||||
activeSession.expireAt = time.Now().Add(sessionDuration.Duration)
|
||||
activeSession.expireTimer.Reset(sessionDuration.Duration)
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
r.Model = req.Model
|
||||
r.CreatedAt = time.Now().UTC()
|
||||
r.SessionID = activeSession.id
|
||||
r.SessionExpiresAt = activeSession.expireAt.UTC()
|
||||
if r.Done {
|
||||
r.TotalDuration = time.Since(checkpointStart)
|
||||
r.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
@@ -128,8 +127,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
ch <- r
|
||||
}
|
||||
|
||||
if err := activeSession.llm.Predict(req.Context, prompt, fn); err != nil {
|
||||
log.Printf("llm.Predict failed with %s", err)
|
||||
if err := loaded.llm.Predict(req.Context, prompt, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
Reference in New Issue
Block a user