kvcache: Enable SWA to retain additional entries

Models that use sliding window attention can only resume a sequence
from the cache if it falls within the saved windows. This works well
if the next message picks up where the old one left off. However, it
generally prevents a partial prefix match unless the entire conversation
falls within the sliding window.

This can be a problem with reasoning models where the traces are
supposed to be removed from future messages, forcing the entire
history to be re-evaluated.

This change allows models to specify that a larger amount of the
history be retained in memory, to allow more partial resumption.
It still respects the window that the model was trained on for
token generation.
This commit is contained in:
Jesse Gross
2025-07-30 14:42:57 -07:00
committed by Jesse Gross
parent ff89ba90bc
commit 4183bb0574
2 changed files with 196 additions and 42 deletions

View File

@@ -60,6 +60,8 @@ func TestSWA(t *testing.T) {
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "FirstBatch",
@@ -69,7 +71,12 @@ func TestSWA(t *testing.T) {
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, 0, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
@@ -79,7 +86,53 @@ func TestSWA(t *testing.T) {
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
expectedMask: []float32{
0, x, x, 0,
0, 0, x, x,
},
},
}
testCache(t, backend, cache, tests)
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
cache := NewSWAMemCache(1, 3, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "FirstBatch",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
pos: []int32{0, 1, 2, 3},
expected: []float32{1, 2, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, x,
0, 0, x, x,
x, 0, 0, x,
x, x, 0, 0,
},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{4, 5, 6},
expectedShape: []int{1, 1, 3},
expectedMask: []float32{
0, 0, x,
x, 0, 0,
},
},
}
@@ -437,6 +490,70 @@ func TestCanResume(t *testing.T) {
}
}
func TestCanResumeSWAMem(t *testing.T) {
backend := &testBackend{}
windowSize := int32(4)
memSize := int32(5)
cache := NewSWAMemCache(windowSize, memSize, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4, 5},
Sequences: []int{0, 0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
cache.Put(context, tensor, tensor)
// shift window by adding position 6
err = cache.StartForward(context, input.Batch{
Positions: []int32{6, 7},
Sequences: []int{0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
if cache.CanResume(0, 0) {
t.Errorf("after shift: CanResume(0, 0) = true, want false (outside window)")
}
if cache.CanResume(0, 1) {
t.Errorf("after shift: CanResume(0, 1) = true, want false (outside window)")
}
if cache.CanResume(0, 2) {
t.Errorf("after shift: CanResume(0, 2) = true, want false (outside window)")
}
if cache.CanResume(0, 3) {
t.Errorf("after shift: CanResume(0, 3) = true, want false (outside window)")
}
if cache.CanResume(0, 4) {
t.Errorf("after shift: CanResume(0, 4) = true, want false (outside window)")
}
if cache.CanResume(0, 5) {
t.Errorf("after shift: CanResume(0, 5) = true, want false (outside window)")
}
if !cache.CanResume(0, 6) {
t.Errorf("after shift: CanResume(0, 6) = false, want true (inside window)")
}
if !cache.CanResume(0, 7) {
t.Errorf("after shift: CanResume(0, 7) = false, want true (latest position)")
}
}
type testBackend struct {
ml.Backend
}