ollamarunner: Improve multimodal input handling

Various vision models have different requirements for how they
receive their inputs. For example:
 - Mllama wants images together with text, and the image embeddings
   don't themselves have positions or get stored in the main KV cache.
 - Llava-style models feed in embeddings similar to tokens, and
   images correspond to a varying number of tokens in the cache.

In addition, the strategy for providing inputs must support batching
and multiple sequences, which are managed by the runner. At the same
time, we want to keep data handling fully in the model so that new
architectures are not bottlenecked by runner code which does not
understand their particular requirements.

This provides a method for models to edit the input stream so that
it meets their needs while still being in a format that the runner
understands. This allows the runner to avoid special processing
for different models.

In addition, this fixes a regression where non-vision models may
incorrectly try to interpret images.
This commit is contained in:
Jesse Gross
2025-03-05 12:08:06 -08:00
committed by Jesse Gross
parent b70fc4d51e
commit a7e63b82be
5 changed files with 247 additions and 130 deletions

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"log/slog"
"math"
"reflect"
"time"
"github.com/ollama/ollama/kvcache"
@@ -39,10 +38,7 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
slots := make([]InputCacheSlot, numSlots)
for i := range slots {
slots[i] = InputCacheSlot{
Id: i,
Inputs: make([]input, 0),
}
slots[i] = InputCacheSlot{Id: i}
}
cache := model.Config().Cache
@@ -83,7 +79,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []input
Inputs []model.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -92,7 +88,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -143,7 +139,7 @@ func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCach
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -166,7 +162,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int3
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -202,7 +198,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]input, longest)
oldestSlot.Inputs = make([]model.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -212,7 +208,7 @@ func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int32,
return oldestSlot, longest, nil
}
func countCommonPrefix(a []input, b []input) int32 {
func countCommonPrefix(a []model.Input, b []model.Input) int32 {
var count int32
for i := range a {
@@ -220,7 +216,7 @@ func countCommonPrefix(a []input, b []input) int32 {
break
}
if !reflect.DeepEqual(a[i], b[i]) {
if a[i].Token != b[i].Token || a[i].MultimodalHash != b[i].MultimodalHash {
break
}