mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-11 16:26:59 +00:00
runner.go: Make KV entry accounting more robust
The structure of the accounting for KV cache shifting was carried over from the old runner but it now doesn't feel natural with the new runner. There are a number of invariants that should hold true but are difficult to reason about. There is at least one bug report that would imply that the invariants are not holding. This reduces the number of implicit assumptions and is more forgiving of unexpected situations. It also improves behavior around which input tokens are kept when truncation occurs. Bug #7545
This commit is contained in:
@@ -34,9 +34,6 @@ type input struct {
|
||||
}
|
||||
|
||||
type Sequence struct {
|
||||
// number of inputs evaluated
|
||||
numPast int
|
||||
|
||||
// batch index
|
||||
iBatch int
|
||||
|
||||
@@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
|
||||
params.numKeep = len(inputs)
|
||||
}
|
||||
|
||||
if !params.embedding {
|
||||
// Subtracting 4 ensures that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-4)
|
||||
params.numKeep += s.bosToken
|
||||
} else {
|
||||
// Embeddings are 1 shot - just truncate to the context window, without ever shifting
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx)
|
||||
if s.model.AddBOSToken() {
|
||||
params.numKeep += 1
|
||||
}
|
||||
|
||||
// truncate to fit in context window
|
||||
// Ensure that at least 1 input can be discarded during shift
|
||||
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
|
||||
inputs = newInputs
|
||||
slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx)
|
||||
}
|
||||
|
||||
var sc *llama.SamplingContext
|
||||
@@ -231,9 +222,6 @@ type Server struct {
|
||||
// KV cache
|
||||
cache *InputCache
|
||||
|
||||
// does this model require a beginning of sequence token?
|
||||
bosToken int
|
||||
|
||||
// next sequence for prompt processing to avoid starvation
|
||||
nextSeq int
|
||||
|
||||
@@ -258,18 +246,6 @@ func (s *Server) allNil() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) shiftContext(seq *Sequence) {
|
||||
numLeft := seq.numPast - seq.numKeep
|
||||
numDiscard := numLeft / 2
|
||||
|
||||
slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
|
||||
"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
|
||||
|
||||
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
|
||||
|
||||
seq.numPast -= numDiscard
|
||||
}
|
||||
|
||||
func flushPending(seq *Sequence) bool {
|
||||
joined := strings.Join(seq.pendingResponses, "")
|
||||
seq.pendingResponses = []string{}
|
||||
@@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
if seq.numPast+len(seq.inputs) > s.cache.numCtx {
|
||||
s.shiftContext(seq)
|
||||
}
|
||||
|
||||
var numInputsProcessed int
|
||||
shifted := false
|
||||
|
||||
for i, input := range seq.inputs {
|
||||
if len(seq.cache.Inputs)+1 > s.cache.numCtx {
|
||||
if !shifted {
|
||||
s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
shifted = true
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
embedding := input.embed != nil
|
||||
|
||||
// If we don't currently have a batch, use one of the correct type and
|
||||
@@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
}
|
||||
|
||||
crossAttention = seq.crossAttention
|
||||
batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
|
||||
seq.numPast++
|
||||
batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, input)
|
||||
numInputsProcessed++
|
||||
}
|
||||
|
||||
if numInputsProcessed > 0 {
|
||||
seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
|
||||
seq.inputs = seq.inputs[numInputsProcessed:]
|
||||
seq.iBatch = batch.NumTokens() - 1
|
||||
}
|
||||
@@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
@@ -802,10 +784,6 @@ func (s *Server) loadModel(
|
||||
}
|
||||
}
|
||||
|
||||
if s.model.AddBOSToken() {
|
||||
s.bosToken = 1
|
||||
}
|
||||
|
||||
if ppath != "" {
|
||||
var err error
|
||||
s.image, err = NewImageContext(s.lc, ppath)
|
||||
@@ -814,7 +792,10 @@ func (s *Server) loadModel(
|
||||
}
|
||||
}
|
||||
|
||||
s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
|
||||
s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.status = ServerStatusReady
|
||||
s.ready.Done()
|
||||
|
||||
Reference in New Issue
Block a user