From 95e271d98f0b5717bde2b7e5c4c7639c19fd3763 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Mon, 17 Mar 2025 15:11:15 -0700
Subject: [PATCH] runner: remove cache prompt flag from ollama runner (#9826)

We do not need to bypass prompt caching in the ollama runner yet, as
only embedding models needed to bypass it. When embedding models are
implemented, they can skip initializing this cache completely.
---
 runner/ollamarunner/cache.go      |   7 +-
 runner/ollamarunner/cache_test.go | 128 ++++++++++++++++++++++++++++++
 runner/ollamarunner/runner.go     |   2 +-
 3 files changed, 130 insertions(+), 7 deletions(-)

diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index adcb3f73..cf5e6b91 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}
 
-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
 
diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go
index 0a1b73f5..f8925d11 100644
--- a/runner/ollamarunner/cache_test.go
+++ b/runner/ollamarunner/cache_test.go
@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
+			}
+		})
+	}
+}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index d4c24556..98fcf2c0 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -590,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)