runner.go: Add unit tests for context shifting

This also makes it easier to truncate long inputs the same as
shifting but does not actually implement it. This type of
truncation has a trade off between quality and time to first
token.
This commit is contained in:
Jesse Gross
2024-11-25 14:49:38 -08:00
committed by Jesse Gross
parent 52bbad12f9
commit 2cd11ae365
3 changed files with 82 additions and 7 deletions

View File

@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
})
}
}
func TestShiftDiscard(t *testing.T) {
tests := []struct {
name string
numCtx int
numKeep int
inputLen int
expected int
}{
{
name: "Shift",
numCtx: 2048,
numKeep: 5,
inputLen: 2048,
expected: 1021,
},
{
name: "Max Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 2048,
expected: 1,
},
{
name: "No Keep",
numCtx: 2048,
numKeep: 0,
inputLen: 2048,
expected: 1024,
},
{
name: "Truncate",
numCtx: 2048,
numKeep: 5,
inputLen: 5000,
expected: 3973,
},
{
name: "Truncate Keep",
numCtx: 2048,
numKeep: 2047,
inputLen: 5000,
expected: 2953,
},
{
name: "No Op",
numCtx: 2048,
numKeep: 5,
inputLen: 512,
expected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := InputCache{numCtx: tt.numCtx}
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
if result != tt.expected {
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
}
})
}
}