diff --git a/kvcache/causal.go b/kvcache/causal.go index 8b101a81..496eeaa6 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -646,18 +646,31 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { seqRange := c.cellRanges[seq] for start := seqRange.min; start <= seqRange.max; start += c.maxBatch { - ctx := c.backend.NewContext() - size := min(seqRange.max-start+1, c.maxBatch) offsets := make([]int32, size) + + var batchFirst, batchLast int + + batchFirst = -1 for i := range offsets { cell := c.cells[start+i] if slices.Contains(cell.sequences, seq) && cell.pos >= beginIndex { offsets[i] = offset + if batchFirst < 0 { + batchFirst = i + } + batchLast = i } } + if batchFirst < 0 { + continue + } + + offsets = offsets[batchFirst : batchLast+1] + + ctx := c.backend.NewContext() kShift := ctx.Input().FromIntSlice(offsets, len(offsets)) for i, key := range c.keys { @@ -669,10 +682,10 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { numKVHeads := key.Dim(1) rowSize := key.Stride(2) - key = key.View(ctx, rowSize*start, + key = key.View(ctx, rowSize*(start+batchFirst), kHeadDim, key.Stride(1), numKVHeads, key.Stride(2), - size, + len(offsets), ) roped, err := c.shiftFn(ctx, i, key, kShift)