runner.go: Make KV entry accounting more robust

The structure of the accounting for KV cache shifting was carried
over from the old runner but it now doesn't feel natural with the new
runner. There are a number of invariants that should hold true but
are difficult to reason about. There is at least one bug report
that would imply that the invariants are not holding.

This reduces the number of implicit assumptions and is more forgiving
of unexpected situations. It also improves behavior around which input
tokens are kept when truncation occurs.

Bug #7545
This commit is contained in:
Jesse Gross
2024-11-08 11:10:56 -08:00
committed by Jesse Gross
parent bebef1e50d
commit 65973ceb64
2 changed files with 59 additions and 57 deletions

View File

@@ -34,9 +34,6 @@ type input struct {
}
type Sequence struct {
// number of inputs evaluated
numPast int
// batch index
iBatch int
@@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
params.numKeep = len(inputs)
}
if !params.embedding {
// Subtracting 4 ensures that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-4)
params.numKeep += s.bosToken
} else {
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
params.numKeep = min(params.numKeep, s.cache.numCtx)
if s.model.AddBOSToken() {
params.numKeep += 1
}
// truncate to fit in context window
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
if len(inputs) > s.cache.numCtx {
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
inputs = newInputs
slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx)
}
var sc *llama.SamplingContext
@@ -231,9 +222,6 @@ type Server struct {
// KV cache
cache *InputCache
// does this model require a beginning of sequence token?
bosToken int
// next sequence for prompt processing to avoid starvation
nextSeq int
@@ -258,18 +246,6 @@ func (s *Server) allNil() bool {
return true
}
func (s *Server) shiftContext(seq *Sequence) {
numLeft := seq.numPast - seq.numKeep
numDiscard := numLeft / 2
slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
seq.numPast -= numDiscard
}
func flushPending(seq *Sequence) bool {
joined := strings.Join(seq.pendingResponses, "")
seq.pendingResponses = []string{}
@@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
if seq.numPast+len(seq.inputs) > s.cache.numCtx {
s.shiftContext(seq)
}
var numInputsProcessed int
shifted := false
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+1 > s.cache.numCtx {
if !shifted {
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
shifted = true
} else {
break
}
}
embedding := input.embed != nil
// If we don't currently have a batch, use one of the correct type and
@@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
}
crossAttention = seq.crossAttention
batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
seq.numPast++
batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.cache.Inputs = append(seq.cache.Inputs, input)
numInputsProcessed++
}
if numInputsProcessed > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
seq.inputs = seq.inputs[numInputsProcessed:]
seq.iBatch = batch.NumTokens() - 1
}
@@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
s.mu.Lock()
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -802,10 +784,6 @@ func (s *Server) loadModel(
}
}
if s.model.AddBOSToken() {
s.bosToken = 1
}
if ppath != "" {
var err error
s.image, err = NewImageContext(s.lc, ppath)
@@ -814,7 +792,10 @@ func (s *Server) loadModel(
}
}
s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
if err != nil {
panic(err)
}
s.status = ServerStatusReady
s.ready.Done()