mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-11 00:07:07 +00:00
kvcache: Add check for values that fall out of sliding window cache
The sliding window cache trims entries that are outside the window for the latest token. This works when we are extending the cache, such as when the conversation continues. However, if we have a partial overlap in conversation (including the BOS tokens), then we resume from a past point in the conversation and the needed tokens are no longer stored in memory. This verifies that the new window overlaps with the old one before reusing the cache. Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
@@ -300,6 +300,77 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCanResume(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
windowSize := int32(4)
|
||||
cache := NewSWACache(windowSize, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
|
||||
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
if !cache.CanResume(0, 0) {
|
||||
t.Errorf("CanResume(0, 0) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 1) {
|
||||
t.Errorf("CanResume(0, 1) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 2) {
|
||||
t.Errorf("CanResume(0, 2) = false, want true (within window)")
|
||||
}
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor, _ = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
if cache.CanResume(0, 0) {
|
||||
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 1) {
|
||||
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 2) {
|
||||
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 3) {
|
||||
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
|
||||
}
|
||||
if cache.CanResume(0, 4) {
|
||||
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
|
||||
}
|
||||
if !cache.CanResume(0, 5) {
|
||||
t.Errorf("after shift: CanResume(0, 5) = false, want true (latest position)")
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct{}
|
||||
|
||||
func (b *testBackend) Config() ml.Config {
|
||||
|
||||
Reference in New Issue
Block a user