kvcache: Add check for values that fall out of sliding window cache

The sliding window cache trims entries that are outside the window for
the latest token. This works when we are extending the cache, such as
when the conversation continues. However, if we have a partial overlap
in conversation (including the BOS tokens), then we resume from a past
point in the conversation and the needed tokens are no longer stored
in memory. This verifies that the new window overlaps with the old one
before reusing the cache.

Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
jmorganca
2025-03-30 16:05:40 -07:00
committed by Jesse Gross
parent 493385eb3e
commit b42970063d
7 changed files with 131 additions and 2 deletions

View File

@@ -581,6 +581,35 @@ func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
c.cellRanges[dstSeq] = seqRange
}
func (c *Causal) CanResume(seq int, pos int32) bool {
if c.windowSize == math.MaxInt32 {
return true
}
seqRange, ok := c.cellRanges[seq]
if !ok {
return false
}
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
if last == -1 {
return false
}
lastWindowStart := max(0, last-c.windowSize)
posWindowStart := max(0, pos-c.windowSize)
return posWindowStart >= lastWindowStart
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
if c.shiftFn == nil {
return ErrNotSupported
@@ -635,6 +664,12 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
}
func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
// TODO(jessegross): We should check to see if removing the middle of the sequence will
// cause the sliding window to encompass tokens that we no longer have. If so, then we
// should return an error, which will trigger the runner to evaluate the full history and
// rebuild the window. However, if we have multimodal inputs in our history, this reuse
// results in use after free, so we don't do it for now.
var offset int32
if endIndex != math.MaxInt32 {
offset = beginIndex - endIndex
@@ -649,8 +684,7 @@ func (c *Causal) Remove(seq int, beginIndex, endIndex int32) error {
} else {
if c.cells[i].pos >= endIndex {
if slices.ContainsFunc(c.cells[i].sequences, func(s int) bool { return s != seq }) {
// TODO(jessegross): Need to be careful about data shared between sequences
return errors.New("shifting on cells shared by multiple sequences not yet implemented")
return errors.New("shifting cells shared by multiple sequences not supported")
}
c.cells[i].pos += offset