mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-12 00:37:04 +00:00
runner.go: Better abstract vision model integration
- Update mllama to take the cross-attention state as embeddings in a batch, similar to how Llava handles it. This improves integration with the input cache.
- Pass the locations of embeddings in a prompt using tags, similar to Llava.
- Abstract the interface to vision models so that the main runner accesses Clip and Mllama in the same way.

Co-authored-by: Michael Yang <mxyng@pm.me>
This commit is contained in:
@@ -2,7 +2,6 @@ package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"hash/maphash"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"time"
|
||||
@@ -20,10 +19,6 @@ type InputCache struct {
|
||||
// optimize cache eviction for multiple users
|
||||
multiUserCache bool
|
||||
|
||||
// cache of images to embeddings
|
||||
images []imageCache
|
||||
imageHash maphash.Hash
|
||||
|
||||
lc *llama.Context
|
||||
}
|
||||
|
||||
@@ -41,7 +36,6 @@ func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache b
|
||||
numCtx: kvSize / numSlots,
|
||||
slots: slots,
|
||||
multiUserCache: multiUserCache,
|
||||
images: make([]imageCache, numSlots),
|
||||
lc: lc,
|
||||
}
|
||||
}
|
||||
@@ -211,55 +205,3 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int, numDiscar
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-numDiscard]
|
||||
}
|
||||
|
||||
// Locking: Lookup and store operations on imageCache require a lock
|
||||
// to be held that serializes these with each other. Hash does not
|
||||
// require a lock nor they need to be serialized with InputCacheSlot.
|
||||
|
||||
type imageCache struct {
|
||||
key uint64
|
||||
val [][]float32
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) HashImage(image []byte) uint64 {
|
||||
c.imageHash.Reset()
|
||||
_, _ = c.imageHash.Write(image)
|
||||
return c.imageHash.Sum64()
|
||||
}
|
||||
|
||||
var ErrImageNotFound = errors.New("image not found in cache")
|
||||
|
||||
func (c *InputCache) FindImage(hash uint64) ([][]float32, error) {
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
slog.Debug("loading image embeddings from cache", "entry", i)
|
||||
c.images[i].lastUsed = time.Now()
|
||||
return c.images[i].val, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrImageNotFound
|
||||
}
|
||||
|
||||
func (c *InputCache) AddImage(hash uint64, embed [][]float32) {
|
||||
best := time.Now()
|
||||
var bestImage int
|
||||
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
bestImage = i
|
||||
break
|
||||
}
|
||||
|
||||
if c.images[i].lastUsed.Compare(best) < 0 {
|
||||
best = c.images[i].lastUsed
|
||||
bestImage = i
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
||||
c.images[bestImage].key = hash
|
||||
c.images[bestImage].val = embed
|
||||
c.images[bestImage].lastUsed = time.Now()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@@ -228,77 +227,3 @@ func TestFindCacheSlot(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestImageCache(t *testing.T) {
|
||||
cache := NewInputCache(nil, 2048, 4, false)
|
||||
|
||||
valA := [][]float32{{0.1, 0.2}, {0.3}}
|
||||
valB := [][]float32{{0.4}, {0.5}, {0.6}}
|
||||
valC := [][]float32{{0.7}}
|
||||
valD := [][]float32{{0.8}}
|
||||
valE := [][]float32{{0.9}}
|
||||
|
||||
// Empty cache
|
||||
result, err := cache.FindImage(0x5adb61d31933a946)
|
||||
if err != ErrImageNotFound {
|
||||
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert A
|
||||
cache.AddImage(0x5adb61d31933a946, valA)
|
||||
|
||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert B
|
||||
cache.AddImage(0x011551369a34a901, valB)
|
||||
|
||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Replace B with C
|
||||
cache.AddImage(0x011551369a34a901, valC)
|
||||
|
||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Evict A
|
||||
cache.AddImage(0x756b218a517e7353, valB)
|
||||
cache.AddImage(0x75e5e8d35d7e3967, valD)
|
||||
cache.AddImage(0xd96f7f268ca0646e, valE)
|
||||
|
||||
result, err = cache.FindImage(0x5adb61d31933a946)
|
||||
if reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0x756b218a517e7353)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0x75e5e8d35d7e3967)
|
||||
if !reflect.DeepEqual(result, valD) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.FindImage(0xd96f7f268ca0646e)
|
||||
if !reflect.DeepEqual(result, valE) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
}
|
||||
|
||||
145
llama/runner/image.go
Normal file
145
llama/runner/image.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"log/slog"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
|
||||
const imageCacheSize = 4
|
||||
|
||||
type ImageContext struct {
|
||||
// mu is required to be held when generating embeddings or accessing the cache
|
||||
mu sync.Mutex
|
||||
|
||||
clip *llama.ClipContext
|
||||
mllama *llama.MllamaContext
|
||||
|
||||
// cache of images to embeddings
|
||||
images []imageCache
|
||||
imageHash maphash.Hash
|
||||
}
|
||||
|
||||
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
|
||||
arch, err := llama.GetModelArch(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
|
||||
}
|
||||
|
||||
var c ImageContext
|
||||
if arch == "clip" {
|
||||
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
|
||||
} else if arch == "mllama" {
|
||||
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.images = make([]imageCache, imageCacheSize)
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (c *ImageContext) Free(modelPath string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.clip != nil {
|
||||
c.clip.Free()
|
||||
}
|
||||
if c.mllama != nil {
|
||||
c.mllama.Free()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) [][]float32 {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
hash := c.hashImage(data)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
embed, err := c.findImage(hash)
|
||||
if err != nil {
|
||||
if c.mllama != nil {
|
||||
embed = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
|
||||
} else if c.clip != nil {
|
||||
embed = c.clip.NewEmbed(llamaContext, data)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
|
||||
c.addImage(hash, embed)
|
||||
}
|
||||
|
||||
return embed
|
||||
}
|
||||
|
||||
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
||||
if c != nil && c.mllama != nil {
|
||||
return c.mllama.EmbedSize(llamaContext)
|
||||
} else {
|
||||
return llamaContext.Model().NEmbd()
|
||||
}
|
||||
}
|
||||
|
||||
// imageCache is one entry of the image embedding cache.
type imageCache struct {
	// key is the hash of the image bytes this entry was stored under.
	key uint64

	// val is the cached embedding data for that image.
	val [][]float32

	// lastUsed is refreshed whenever the entry is stored or looked up and is
	// used to choose which entry to replace when a new image is added.
	lastUsed time.Time
}
|
||||
|
||||
func (c *ImageContext) hashImage(image []byte) uint64 {
|
||||
c.imageHash.Reset()
|
||||
_, _ = c.imageHash.Write(image)
|
||||
return c.imageHash.Sum64()
|
||||
}
|
||||
|
||||
var errImageNotFound = errors.New("image not found in cache")
|
||||
|
||||
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
slog.Debug("loading image embeddings from cache", "entry", i)
|
||||
c.images[i].lastUsed = time.Now()
|
||||
return c.images[i].val, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errImageNotFound
|
||||
}
|
||||
|
||||
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
|
||||
best := time.Now()
|
||||
var bestImage int
|
||||
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
bestImage = i
|
||||
break
|
||||
}
|
||||
|
||||
if c.images[i].lastUsed.Compare(best) < 0 {
|
||||
best = c.images[i].lastUsed
|
||||
bestImage = i
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
||||
c.images[bestImage].key = hash
|
||||
c.images[bestImage].val = embed
|
||||
c.images[bestImage].lastUsed = time.Now()
|
||||
}
|
||||
80
llama/runner/image_test.go
Normal file
80
llama/runner/image_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImageCache(t *testing.T) {
|
||||
cache := ImageContext{images: make([]imageCache, 4)}
|
||||
|
||||
valA := [][]float32{{0.1, 0.2}, {0.3}}
|
||||
valB := [][]float32{{0.4}, {0.5}, {0.6}}
|
||||
valC := [][]float32{{0.7}}
|
||||
valD := [][]float32{{0.8}}
|
||||
valE := [][]float32{{0.9}}
|
||||
|
||||
// Empty cache
|
||||
result, err := cache.findImage(0x5adb61d31933a946)
|
||||
if err != errImageNotFound {
|
||||
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert A
|
||||
cache.addImage(0x5adb61d31933a946, valA)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert B
|
||||
cache.addImage(0x011551369a34a901, valB)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Replace B with C
|
||||
cache.addImage(0x011551369a34a901, valC)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Evict A
|
||||
cache.addImage(0x756b218a517e7353, valB)
|
||||
cache.addImage(0x75e5e8d35d7e3967, valD)
|
||||
cache.addImage(0xd96f7f268ca0646e, valE)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x756b218a517e7353)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x75e5e8d35d7e3967)
|
||||
if !reflect.DeepEqual(result, valD) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0xd96f7f268ca0646e)
|
||||
if !reflect.DeepEqual(result, valE) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
}
|
||||
@@ -190,57 +190,22 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
||||
return nil, fmt.Errorf("invalid image index: %d", n)
|
||||
}
|
||||
|
||||
hash := s.cache.HashImage(images[imageIndex].Data)
|
||||
|
||||
// Vision models cannot be accessed concurrently
|
||||
s.clip.mu.Lock()
|
||||
embed, err := s.cache.FindImage(hash)
|
||||
if err != nil {
|
||||
embed = llama.NewLlavaImageEmbed(s.lc, s.clip.cc, images[imageIndex].Data)
|
||||
s.cache.AddImage(hash, embed)
|
||||
}
|
||||
s.clip.mu.Unlock()
|
||||
|
||||
embed := s.image.NewEmbed(s.lc, images[imageIndex].Data, images[imageIndex].AspectRatioID)
|
||||
for _, e := range embed {
|
||||
inputs = append(inputs, input{embed: e})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if s.clip.cc != nil {
|
||||
var embed [][]float32
|
||||
|
||||
if s.clip.cc.IsMllama && len(images) >= 1 {
|
||||
hash := s.cache.HashImage(images[0].Data)
|
||||
|
||||
s.clip.mu.Lock()
|
||||
var err error
|
||||
embed, err = s.cache.FindImage(hash)
|
||||
if err != nil {
|
||||
embed = llama.NewMllamaImageEmbed(s.lc, s.clip.cc, images[0].Data, images[0].AspectRatioID)
|
||||
s.cache.AddImage(hash, embed)
|
||||
}
|
||||
s.clip.mu.Unlock()
|
||||
}
|
||||
s.mu.Lock()
|
||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, embed)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
return inputs, nil
|
||||
}
|
||||
|
||||
type clip struct {
|
||||
cc *llama.ClipContext
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
model *llama.Model
|
||||
lc *llama.Context
|
||||
|
||||
// required for image embeddings
|
||||
clip clip
|
||||
image *ImageContext
|
||||
|
||||
batchSize int
|
||||
|
||||
@@ -322,14 +287,12 @@ func flushPending(seq *Sequence) bool {
|
||||
func (s *Server) removeSequence(seqIndex int, reason string) {
|
||||
seq := s.seqs[seqIndex]
|
||||
|
||||
s.lc.SetCrossAttention(false)
|
||||
flushPending(seq)
|
||||
seq.doneReason = reason
|
||||
close(seq.responses)
|
||||
close(seq.embedding)
|
||||
seq.cache.InUse = false
|
||||
if s.clip.cc != nil {
|
||||
llama.MllamaSetCrossAttn(s.lc, s.clip.cc, nil)
|
||||
}
|
||||
s.seqs[seqIndex] = nil
|
||||
}
|
||||
|
||||
@@ -341,7 +304,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
tokenBatch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
|
||||
defer tokenBatch.Free()
|
||||
|
||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.lc.Model().NEmbd(), len(s.seqs))
|
||||
embedBatch := llama.NewBatch(s.batchSize*len(s.seqs), s.image.EmbedSize(s.lc), len(s.seqs))
|
||||
defer embedBatch.Free()
|
||||
|
||||
for {
|
||||
@@ -642,12 +605,20 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
s.mu.Lock()
|
||||
for i, sq := range s.seqs {
|
||||
if sq == nil {
|
||||
for _, input := range seq.inputs {
|
||||
if input.embed != nil {
|
||||
s.lc.SetCrossAttention(true)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
s.seqs[i] = seq
|
||||
s.cond.Signal()
|
||||
break
|
||||
@@ -815,7 +786,7 @@ func (s *Server) loadModel(
|
||||
|
||||
if ppath != "" {
|
||||
var err error
|
||||
s.clip.cc, err = llama.NewClipContext(ppath)
|
||||
s.image, err = NewImageContext(s.lc, ppath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user