runner.go: Don't add inputs to cache view until actually processed

We need to track which tokens are in the cache ourselves. We currently
add tokens to the cache tracker when we add them to batch but they are
not actually in the cache until we call Decode. This can cause
confusion when we are shifting the cache.

Avoids "could not find a KV slot for the batch" issues.

Bug #7545
This commit is contained in:
Jesse Gross
2024-11-19 10:51:47 -08:00
committed by Jesse Gross
parent 3fc1dc0e6f
commit c3ff916431
2 changed files with 31 additions and 18 deletions

View File

@@ -45,6 +45,9 @@ type Sequence struct {
// prompt inputs left to evaluate
inputs []input
// inputs that have been added to a batch but not yet submitted to Decode
pendingInputs []input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -367,14 +370,13 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
var numInputsProcessed int
shifted := false
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+1 > s.cache.numCtx {
if !shifted {
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
shifted = true
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
@@ -403,15 +405,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
crossAttention = seq.crossAttention
batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.cache.Inputs = append(seq.cache.Inputs, input)
numInputsProcessed++
}
if numInputsProcessed > 0 {
seq.inputs = seq.inputs[numInputsProcessed:]
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.pendingInputs = append(seq.pendingInputs, input)
seq.iBatch = batch.NumTokens() - 1
}
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if batch == nil || batch.NumTokens() == 0 {
@@ -444,6 +443,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
// After calling Decode, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []input{}
}
// don't sample prompt processing
if len(seq.inputs) != 0 {
continue