chore: update mllama to use ollama engine (#10637)

This commit is contained in:
Michael Yang
2025-05-13 17:36:02 -07:00
committed by GitHub
parent 0478d440f0
commit 23125648b8
67 changed files with 785 additions and 4354 deletions

View File

@@ -57,10 +57,6 @@ type Sequence struct {
// input cache being used by this sequence
cache *InputCacheSlot
// does this sequence require cross-attention layers to be processed? - if we have seen
// an image for certain multi-modal models
crossAttention bool
// channel to send responses over
responses chan string
@@ -205,7 +201,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
return nil, fmt.Errorf("invalid image index: %d", n)
}
embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data)
if err != nil {
return nil, err
}
@@ -368,7 +364,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
defer s.mu.Unlock()
var batch *llama.Batch
crossAttention := false
seqIdx := s.nextSeq - 1
for range s.seqs {
@@ -416,9 +411,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
batch = tokenBatch
} else {
batch = embedBatch
seq.crossAttention = s.image.NeedCrossAttention(input)
}
} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
} else if embedding != batch.IsEmbedding() {
s.nextSeq = seqIdx
break
}
@@ -427,7 +421,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
break
}
crossAttention = seq.crossAttention
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
seq.pendingInputs = append(seq.pendingInputs, input)
seq.iBatch = batch.NumTokens() - 1
@@ -440,20 +433,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
s.lc.SetCrossAttention(crossAttention)
err := s.lc.Decode(batch)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
if crossAttention {
// synchronize state to ensure the cross attention batch is complete.
// needed specifically for multi-GPU systems otherwise an inflight
// task may be incorrectly invalidated causing a crash
s.lc.Synchronize()
}
for i, seq := range s.seqs {
if seq == nil {
continue
@@ -622,8 +606,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
s.seqs[i] = seq
s.cond.Signal()
found = true