ollamarunner: Separate text and multimodal graphs

For some multimodal models (such as gemma3), we create a single
graph that generates the image embedding and then use this in the
text model. The embedding tensor is completely opaque to the runner.

However, this doesn't work if we need to use the embedding in multiple
batches. This can arise if the embedding is larger than the batch size.
In these cases (as with llama4), we would like to create views that
are more appropriately sized. However, if we do this then the original
source tensor is used in multiple graphs, which isn't allowed. To
avoid that problem, models with this pattern compute the embedding
tensor on first use and recreate the individual views. There is no
longer a single vision and text graph.

This codifies the pattern of separating vision and text graphs. The
logic of computing tensors on demand is moved to the runner, so models
no longer have to worry about this. It also gives the runner visibility
into the multimodal tensors, which is important for memory management.
This commit is contained in:
Jesse Gross
2025-05-05 13:32:11 -07:00
committed by Jesse Gross
parent 499ae7311f
commit 3c14461d5d
14 changed files with 241 additions and 186 deletions

View File

@@ -2,16 +2,30 @@ package input
import "github.com/ollama/ollama/ml"
// Multimodal is a multimodal embedding or a component of one.
// For example, it could be a row of an image that can be processed
// independently.
type Multimodal struct {
// Tensor is the embedding data. Implementations may chose what to
// store here or it may be nil if not needed. However, any ml.Tensor
// objects must be stored here and not in Data.
Tensor ml.Tensor
// Data is implementation-specific opaque data, such as metadata on how
// to layout Tensor. It may be nil if not needed. It may also store larger
// objects such as complete images if they are to be processed later.
Data any
}
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
Token int32
// Multimodal is opaque data representing a non-text
// element such as an image (or part of one if the image
// can be processed in pieces). It may be either together
// with Token or on its own.
Multimodal any
// Multimodal is represents a non-text element such as an
// image (or part of one if the image can be processed in pieces).
// It may be used either together with Token or on its own.
Multimodal []Multimodal
// MultimodalHash is a unique representation of the data
// stored in Multimodal, used for caching and comparing
@@ -32,7 +46,7 @@ type Input struct {
// Positions slice.
type MultimodalIndex struct {
Index int
Multimodal any
Multimodal []Multimodal
}
// Batch contains the inputs for a model forward pass

View File

@@ -40,12 +40,13 @@ type MultimodalProcessor interface {
// EncodeMultimodal processes a single input (such as an image) and
// generates an output (typically an embedding) that can be used by the model.
//
// The return value is most typically an ml.Tensor, however, different
// type are possible, such as an object containing a tensor plus
// additional metadata, a slice of tensors or even just the original input.
// The return value is one or more tensors, each with optional model-specific
// opaque metadata. Typically, the tensors might be views into an embedding
// with each view representing a chunk of data that can be processed independently
// in different batches.
//
// The result may be cached by the runner.
EncodeMultimodal(ml.Context, []byte) (any, error)
EncodeMultimodal(ml.Context, []byte) ([]input.Multimodal, error)
// PostTokenize is called after tokenization to allow the model to edit the
// input stream to correctly arrange multimodal elements.

View File

@@ -82,7 +82,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -108,22 +108,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.imageSize, m.patchSize, m.VisionModel.eps)
return visionOutputs, nil
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.(ml.Tensor)
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders

View File

@@ -165,7 +165,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
visionOutputs := image.Multimodal[0].Tensor
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
for i := range visionOutputs.Dim(1) {

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -63,7 +62,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
@@ -103,70 +102,79 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
projectedOutputs := m.Projector.Forward(ctx, visionOutputs)
return &chunks{Model: m, Tensor: projectedOutputs, aspectRatio: image.Point{ratioW, ratioH}}, nil
var multimodal []input.Multimodal
aspectRatio := image.Point{ratioW, ratioH}
var offset int
patchesPerChunk := projectedOutputs.Dim(1)
if aspectRatio.Y*aspectRatio.X > 1 {
patchesPerChunk = projectedOutputs.Dim(1) / (aspectRatio.X*aspectRatio.Y + 1)
for range aspectRatio.Y {
for x := range aspectRatio.X {
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
var separator separator
if x < aspectRatio.X-1 {
separator.x = true // <|tile_x_separator|>
} else {
separator.y = true // <|tile_y_separator|>
}
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator})
offset += patchesPerChunk
}
}
}
view := projectedOutputs.View(ctx, projectedOutputs.Stride(1)*offset,
projectedOutputs.Dim(0), projectedOutputs.Stride(1),
patchesPerChunk)
multimodal = append(multimodal, input.Multimodal{Tensor: view, Data: &separator{}})
return multimodal, nil
}
type chunks struct {
*Model
ml.Tensor
aspectRatio image.Point
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
type separator struct {
x bool
y bool
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
t := inp.Multimodal.(*chunks)
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
var offset int
patchesPerChunk := t.Dim(1)
if t.aspectRatio.Y*t.aspectRatio.X > 1 {
patchesPerChunk = t.Dim(1) / (t.aspectRatio.X*t.aspectRatio.Y + 1)
for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1)
for range t.aspectRatio.Y {
for x := range t.aspectRatio.X {
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if x < t.aspectRatio.X-1 {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
offset += patchesPerChunk
if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x {
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
if separator.y {
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
} else {
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
}
}
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: &chunk{t, offset, patchesPerChunk}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
result = append(result, imageInputs...)
}

View File

@@ -210,12 +210,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}

View File

@@ -4,7 +4,6 @@ import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -105,7 +104,7 @@ func newMultiModalProjector(c fs.Config) *MultiModalProjector {
}
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -129,37 +128,14 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
rows := make([]input.Multimodal, size.Y)
for i := range rows {
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
rows[i].Tensor = features.View(ctx, features.Stride(1)*size.X*i, features.Dim(0), features.Stride(1), size.X)
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -168,15 +144,14 @@ func (r *imageRow) data() []float32 {
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
for i, row := range inp.Multimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
} else {

View File

@@ -9,7 +9,6 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
@@ -20,8 +19,6 @@ type TextOptions struct {
}
type TextModel struct {
model.Base
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
@@ -109,20 +106,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// image embeddings
for _, image := range batch.Multimodal {
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
temp := m.Backend().NewContext()
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
temp.Close()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
imageFeature := image.Multimodal[0].Tensor
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}

View File

@@ -59,7 +59,7 @@ func New(c fs.Config) (model.Model, error) {
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -92,7 +92,9 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
return m.Projector.Forward(ctx, crossAttentionStates), nil
projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates)
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
@@ -108,7 +110,7 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(batch.Multimodal) > 0 {
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal.(ml.Tensor)
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
}
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))

View File

@@ -5,7 +5,6 @@ import (
"fmt"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
@@ -77,7 +76,7 @@ func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *
return pixelValues, grid, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
@@ -88,31 +87,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
}
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
return &chunks{Model: m, Tensor: visionOutputs}, nil
}
type chunks struct {
*Model
ml.Tensor
dataOnce sync.Once
data []float32
}
type chunk struct {
*chunks
s, n int
}
func (r *chunk) floats() []float32 {
r.dataOnce.Do(func() {
temp := r.Backend().NewContext()
defer temp.Close()
temp.Forward(r.Tensor).Compute(r.Tensor)
r.data = r.Floats()
})
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
@@ -142,20 +117,16 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
result = append(result, input.Input{Token: pre[i]})
}
// This is an image token with multimodal data
chunksData := inp.Multimodal.(*chunks)
patchesPerChunk := chunksData.Dim(1)
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token
result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 1})
// Add the image token with the multimodal tensor data at the first position
// Create a chunk with proper s and n values
result = append(result, input.Input{
Token: imageToken,
Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
SameBatch: patchesPerChunk,
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)

View File

@@ -129,12 +129,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
for _, mi := range batch.Multimodal {
f32s := mi.Multimodal.(*chunk).floats()
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
if err != nil {
panic(err)
}
img := mi.Multimodal[0].Tensor
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
}