mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-12 08:47:01 +00:00
runner.go: Better abstract vision model integration
-Update mllama to take the cross attention state as embeddings in a batch, more similar to how Llava handles it. This improves integration with the input cache. -Pass locations in a prompt for embeddings using tags similar to Llava. -Abstract interface to vision models so the main runner accesses Clip and Mllama similarly Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
@@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
hash := s.cache.HashImage(images[imageIndex].Data)
|
||||
|
||||
// Vision models cannot be accessed concurrently
|
||||
s.clip.mu.Lock()
|
||||
embed, err := s.cache.FindImage(hash)
|
||||
if err != nil {
|
||||
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
|
||||
s.cache.AddImage(hash, embed)
|
||||
}
|
||||
s.clip.mu.Unlock()
|
||||
|
||||
embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
|
||||
for _, e := range embed {
|
||||
inputs = append(inputs, input{embed: e})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.clip.cc != nil {
|
||||
var embed [][]float32
|
||||
|
||||
if s.clip.cc.IsMllama && len(images) >= 1 {
|
||||
hash := s.cache.HashImage(images[0].Data)
|
||||
|
||||
s.clip.mu.Lock()
|
||||
var err error
|
||||
embed, err = s.cache.FindImage(hash)
|
||||
if err != nil {
|
||||
embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
|
||||
s.cache.AddImage(hash, embed)
|
||||
}
|
||||
s.clip.mu.Unlock()
|
||||
}
|
||||
s.mu.Lock()
|
||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
type clip struct {
|
||||
cc *llama.ClipContext
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
model *llama.Model
|
||||
lc *llama.Context
|
||||
|
||||
// required for image embeddings
|
||||
clip clip
|
||||
image *ImageContext
|
||||
|
||||
batchSize int
|
||||
|
||||
@@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
s.lc.SetCrossAttention(false)
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
if s.clip.cc != nil {
|
||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
|
||||
}
|
||||
s.seqs[seqIndex] = nil
|
||||
}
|
||||
|
||||
@@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
|
||||
defer tokenBatch.Free()
|
||||
|
||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs))
|
||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
|
||||
defer embedBatch.Free()
|
||||
|
||||
for {
|
||||
@@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
for _, input := range seq.inputs {
|
||||
if input.embed != nil {
|
||||
s.lc.SetCrossAttention(true)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
break
|
||||
@@ -815,7 +786,7 @@ func (s *Server) loadModel(
|
||||
|
||||
if ppath != "" {
|
||||
var err error
|
||||
s.clip.cc, err = llama.NewClipContext(ppath)
|
||||
s.image, err = NewImageContext(s.lc, ppath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user