mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 07:46:59 +00:00
chore: update mllama to use ollama engine (#10637)
This commit is contained in:
@@ -57,10 +57,6 @@ type Sequence struct {
|
||||
// input cache being used by this sequence
|
||||
cache *InputCacheSlot
|
||||
|
||||
// does this sequence require cross-attention layers to be processed? - if we have seen
|
||||
// an image for certain multi-modal models
|
||||
crossAttention bool
|
||||
|
||||
// channel to send responses over
|
||||
responses chan string
|
||||
|
||||
@@ -205,7 +201,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
|
||||
embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -368,7 +364,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
defer s.mu.Unlock()
|
||||
|
||||
var batch *llama.Batch
|
||||
crossAttention := false
|
||||
|
||||
seqIdx := s.nextSeq - 1
|
||||
for range s.seqs {
|
||||
@@ -416,9 +411,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
batch = tokenBatch
|
||||
} else {
|
||||
batch = embedBatch
|
||||
seq.crossAttention = s.image.NeedCrossAttention(input)
|
||||
}
|
||||
} else if embedding != batch.IsEmbedding() || crossAttention != seq.crossAttention {
|
||||
} else if embedding != batch.IsEmbedding() {
|
||||
s.nextSeq = seqIdx
|
||||
break
|
||||
}
|
||||
@@ -427,7 +421,6 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
break
|
||||
}
|
||||
|
||||
crossAttention = seq.crossAttention
|
||||
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
|
||||
seq.pendingInputs = append(seq.pendingInputs, input)
|
||||
seq.iBatch = batch.NumTokens() - 1
|
||||
@@ -440,20 +433,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
s.lc.SetCrossAttention(crossAttention)
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
if crossAttention {
|
||||
// synchronize state to ensure the cross attention batch is complete.
|
||||
// needed specifically for multi-GPU systems otherwise an inflight
|
||||
// task may be incorrectly invalidated causing a crash
|
||||
s.lc.Synchronize()
|
||||
}
|
||||
|
||||
for i, seq := range s.seqs {
|
||||
if seq == nil {
|
||||
continue
|
||||
@@ -622,8 +606,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
seq.crossAttention = s.image.NeedCrossAttention(seq.cache.Inputs...)
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
found = true
|
||||
|
||||
Reference in New Issue
Block a user