From a1cda80bcb0b47d493be9dc061a2dfa8a0ddd61c Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Sat, 8 Mar 2025 15:45:31 -0800 Subject: [PATCH] model: Update encoder cache to use multimodal input processing handler The encoder cache needs to know the position of images in the input stream so that it knows when to delete them. Previously images didn't have a position, so we implied one by breaking batches before an image and then assuming the image was in the first position. However, multimodal objects are now given explicit positions in the input stream, so we can use that instead. Breaking batches was also a way to simulate a cross attention mask for mllama. However, given that it only supports a single sequence and a single image, this mask doesn't serve any real purpose. Removing the batch break does not appear to affect the quality of the output. Most of this is simply moving the input data structures to a new package to avoid import cycles. --- kvcache/cache.go | 3 +- kvcache/causal.go | 13 ++--- kvcache/causal_test.go | 3 +- kvcache/encoder.go | 9 ++-- kvcache/wrapper.go | 9 ++-- model/input/input.go | 37 ++++++++++++++ model/model.go | 83 +++++++++---------------------- model/model_test.go | 3 +- model/models/llama/model.go | 3 +- model/models/mllama/model.go | 13 ++--- runner/ollamarunner/cache.go | 13 ++--- runner/ollamarunner/cache_test.go | 72 +++++++++++++-------------- runner/ollamarunner/runner.go | 56 ++++++++------------- 13 files changed, 157 insertions(+), 160 deletions(-) create mode 100644 model/input/input.go diff --git a/kvcache/cache.go b/kvcache/cache.go index 2541f7c1..d3548905 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) var ( @@ -51,7 +52,7 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding // entry in positions and seqs. - StartForward(ctx ml.Context, positions []int32, seqs []int) error + StartForward(ctx ml.Context, opts input.Options) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index 9a79fa57..34d5337c 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -8,6 +8,7 @@ import ( "slices" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) @@ -140,10 +141,10 @@ func (c *Causal) Close() { } } -func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - c.curBatchSize = len(positions) - c.curSequences = seqs - c.curPositions = positions +func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { + c.curBatchSize = len(opts.Positions) + c.curSequences = opts.Sequences + c.curPositions = opts.Positions var err error c.curLoc, err = c.findStartLoc() @@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err } c.curCellRange = newRange() - for i, pos := range positions { - seq := seqs[i] + for i, pos := range opts.Positions { + seq := opts.Sequences[i] c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 412f33e3..22d8efb4 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type testCase struct { @@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, test.pos, test.seqs) + err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) if err != nil { panic(err) } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 867ee37a..6a9df2ab 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Encoder cache stores K and V tensors that are position independent @@ -78,9 +79,11 @@ func (c *EncoderCache) Close() { } } -func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - // The image is always in the first position - c.curPos = positions[0] +func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { + // We work with the most recent image + if len(opts.Multimodal) > 0 { + c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] + } return nil } diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 76956a88..aaccd166 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -4,6 +4,7 @@ import ( "math" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Wrapper cache is a container for multiple types of caches, @@ -40,14 +41,14 @@ func (c *WrapperCache) Close() { } } -func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { +func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, positions, seqs) + err := cache.StartForward(ctx, opts) if err != nil { // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { - for k := range positions { - _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32) + for k := range opts.Positions { + _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) } } return err diff --git a/model/input/input.go b/model/input/input.go new file mode 100644 index 00000000..0cb3f3f4 --- /dev/null +++ b/model/input/input.go @@ -0,0 +1,37 @@ +package input + +// Input represents one token in the input stream +type Input struct { + // Token is a single element of text. + Token int32 + + // Multimodal is opaque data representing a non-text + // element such as an image (or part of one if the image + // can be processed in pieces). It may be either together + // with Token or on its own. + Multimodal any + + // MultimodalHash is a unique representation of the data + // stored in Multimodal, used for caching and comparing + // equality. + MultimodalHash uint64 +} + +// MultimodalIndex is a multimodal element (such as an image) +// together with an index into the slice of Inputs with the +// corresponding token. Note that the index is not the same +// as the position - to find that use the index with the +// Positions slice. +type MultimodalIndex struct { + Index int + Multimodal any +} + +// Options contains the inputs for a model forward pass +type Options struct { + Inputs []int32 + Multimodal []MultimodalIndex + Positions []int32 + Sequences []int + Outputs []int32 +} diff --git a/model/model.go b/model/model.go index 75b7f639..89b6c803 100644 --- a/model/model.go +++ b/model/model.go @@ -19,66 +19,12 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/model/input" ) -// Input represents one token in the input stream -type Input struct { - // Token is a single element of text. - Token int32 - - // Multimodal is opaque data representing a non-text - // element such as an image (or part of one if the image - // can be processed in pieces). It may be either together - // with Token or on its own. - Multimodal any - - // MultimodalHash is a unique representation of the data - // stored in Multimodal, used for caching and comparing - // equality. - MultimodalHash uint64 -} - -// MultimodalIndex is a multimodal element (such as an image) -// together with an index into the slice of Inputs with the -// corresponding token. Note that the index is not the same -// as the position - to find that use the index with the -// Positions slice. -type MultimodalIndex struct { - Index int - Multimodal any -} - -// Options contains the inputs for a model forward pass -type Options struct { - Inputs []int32 - Multimodal []MultimodalIndex - Positions []int32 - Sequences []int - Outputs []int32 -} - -type config struct { - Cache kvcache.Cache -} - -// Base implements the common fields and methods for all models -type Base struct { - b ml.Backend - config -} - -// Backend returns the underlying backend that will run the model -func (m *Base) Backend() ml.Backend { - return m.b -} - -func (m *Base) Config() config { - return m.config -} - // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { - Forward(ml.Context, Options) (ml.Tensor, error) + Forward(ml.Context, input.Options) (ml.Tensor, error) Backend() ml.Backend Config() config @@ -112,7 +58,26 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize(ml.Context, []Input) ([]Input, error) + PostTokenize(ml.Context, []input.Input) ([]input.Input, error) +} + +// Base implements the common fields and methods for all models +type Base struct { + b ml.Backend + config +} + +type config struct { + Cache kvcache.Cache +} + +// Backend returns the underlying backend that will run the model +func (m *Base) Backend() ml.Backend { + return m.b +} + +func (m *Base) Config() config { + return m.config } var models = make(map[string]func(ml.Config) (Model, error)) @@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { if len(opts.Positions) != len(opts.Sequences) { return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) } @@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, opts.Positions, opts.Sequences) + err := cache.StartForward(ctx, opts) if err != nil { return nil, err } diff --git a/model/model_test.go b/model/model_test.go index 8761817e..354dd1d8 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -11,6 +11,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) { type notTextProcessorModel struct{} -func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) { +func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { panic("unimplemented") } diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 9ccfff61..1f27f522 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -9,6 +9,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Options struct { @@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 54c63296..31ba15df 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -12,6 +12,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Model struct { @@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return m.Projector.Forward(ctx, crossAttentionStates), nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) { - var images []model.Input +func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { + var images []input.Input fnvHash := fnv.New64a() for i := range inputs { @@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Inpu } } - inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 }) + inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 }) return inputs, nil } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor - if opts.Multimodal != nil { - crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor) + if len(opts.Multimodal) > 0 { + crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) } inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 3244c0b8..a411fddb 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -10,6 +10,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type InputCache struct { @@ -79,7 +80,7 @@ type InputCacheSlot struct { Id int // Inputs that are stored in the KV cache - Inputs []model.Input + Inputs []input.Input // is this cache actively being processed as part of a sequence? InUse bool @@ -88,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp return slot, prompt, nil } -func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { longest := int32(-1) var longestSlot *InputCacheSlot @@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot return longestSlot, longest, nil } -func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { oldest := time.Now() var oldestSlot *InputCacheSlot @@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i if longest > 0 && longestSlot != oldestSlot { slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", len(longestSlot.Inputs)) - oldestSlot.Inputs = make([]model.Input, longest) + oldestSlot.Inputs = make([]input.Input, longest) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) if c.cache != nil { c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) @@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i return oldestSlot, longest, nil } -func countCommonPrefix(a []model.Input, b []model.Input) int32 { +func countCommonPrefix(a []input.Input, b []input.Input) int32 { var count int32 for i := range a { diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 9ce03b73..0a1b73f5 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) func TestCountCommon(t *testing.T) { @@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) { tests := []struct { name string - t1 []model.Input - t2 []model.Input + t1 []input.Input + t2 []input.Input expected int32 }{ { name: "Equal", - t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 3, }, { name: "Prefix", - t1: []model.Input{{Token: 1}}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{{Token: 1}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 1, }, { name: "Image Prefix", - t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, + t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, expected: 2, }, { name: "Mixed, Same Length", - t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, expected: 1, }, { name: "Empty", - t1: []model.Input{}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 0, }, { name: "Both Empty", - t1: []model.Input{}, - t2: []model.Input{}, + t1: []input.Input{}, + t2: []input.Input{}, expected: 0, }, } @@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []model.Input + prompt []input.Input longest expected best expected }{ @@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []model.Input{{Token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 0}, best: expected{result: 0, len: 0}, }, @@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 1}, {Token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 2}, best: expected{result: 1, len: 2}, }, @@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []model.Input{{Token: 2}}, + prompt: []input.Input{{Token: 2}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }, }, - prompt: []model.Input{{Token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 1}, best: expected{result: 1, len: 1}, }, @@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 2}, {Token: 3}}, + prompt: []input.Input{{Token: 2}, {Token: 3}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 1}, {Token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 1}, best: expected{result: 1, len: 2}, }, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a51b1459..c8383a5d 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -41,10 +42,10 @@ type Sequence struct { iBatch int // prompt inputs left to evaluate - inputs []model.Input + inputs []input.Input // inputs that have been added to a batch but not yet submitted to Forward - pendingInputs []model.Input + pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) pendingResponses []string @@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) { - var inputs []model.Input +func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { + var inputs []input.Input var parts []string var matches [][]string @@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo } for _, t := range tokens { - inputs = append(inputs, model.Input{Token: t}) + inputs = append(inputs, input.Input{Token: t}) } // image - decode and store @@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo _, _ = s.multimodalHash.Write(images[imageIndex].Data) imageHash := s.multimodalHash.Sum64() - inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) + inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) postTokenize = true } } @@ -250,9 +251,6 @@ type Server struct { // KV cache cache *InputCache - // next sequence for prompt processing to avoid starvation - nextSeq int - // multimodalHash generates hashes for comparing equality // of non-text data multimodalHash maphash.Hash @@ -329,29 +327,25 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() - var options model.Options - - seqIdx := s.nextSeq - 1 - for range s.seqs { - seqIdx = (seqIdx + 1) % len(s.seqs) - seq := s.seqs[seqIdx] + var options input.Options + for i, seq := range s.seqs { if seq == nil { continue } // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, "limit") + s.removeSequence(i, "limit") continue } if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) - seq.cache.Inputs = []model.Input{} + seq.cache.Inputs = []input.Input{} } - for i, input := range seq.inputs { + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) @@ -363,33 +357,23 @@ func (s *Server) processBatch() error { } } - if i >= s.batchSize { + if j >= s.batchSize { break } - // TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint - // to the encoder cache. - // - // Break the batch when switching from text to images so that images are always at the beginning. - if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 || - (len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) { - s.nextSeq = seqIdx - break - } - - options.Inputs = append(options.Inputs, input.Token) - if input.Multimodal != nil { - options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal}) + options.Inputs = append(options.Inputs, inp.Token) + if inp.Multimodal != nil { + options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) } options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Sequences = append(options.Sequences, seq.cache.Id) seq.iBatch = len(options.Outputs) - if i+1 == len(seq.inputs) { + if j+1 == len(seq.inputs) { options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) } - seq.pendingInputs = append(seq.pendingInputs, input) + seq.pendingInputs = append(seq.pendingInputs, inp) } seq.inputs = seq.inputs[len(seq.pendingInputs):] @@ -417,7 +401,7 @@ func (s *Server) processBatch() error { // After calling Forward, pending inputs are now in the cache if len(seq.pendingInputs) > 0 { seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) - seq.pendingInputs = []model.Input{} + seq.pendingInputs = []input.Input{} } // don't sample prompt processing @@ -464,7 +448,7 @@ func (s *Server) processBatch() error { return err } - seq.inputs = []model.Input{{Token: token}} + seq.inputs = []input.Input{{Token: token}} seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "")