model: support for mistral-small in the ollama runner

Mistral is a popular research lab making open source models. This updates
the forward pass of llama architecture models to support both llama models
and mistral models by accounting for additional metadata present in mistral
models, and finding the correct dimensions for the output projection.
This commit is contained in:
Bruce MacDonald
2025-03-14 16:56:32 -07:00
committed by Michael Yang
parent 1861fbdeb5
commit 6bd0a983cd
27 changed files with 1116 additions and 350 deletions

View File

@@ -11,7 +11,7 @@ import (
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
@@ -28,7 +28,7 @@ type TextModel struct {
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
*TextConfig
}
const (
@@ -55,7 +55,7 @@ func newTextModel(c fs.Config) *TextModel {
},
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
@@ -84,7 +84,7 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
@@ -120,12 +120,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
ropeBase = m.TextConfig.ropeGlobalBase
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
return key.RoPE(ctx, shift, nil, uint32(m.TextConfig.attnKeyLen), uint32(2), ropeBase, m.TextConfig.ropeScale), nil
}
type TextMLP struct {
@@ -134,7 +134,7 @@ type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextConfig) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -148,7 +148,7 @@ type TextLayer struct {
PostMLPNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -173,7 +173,7 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
// set image embeddings
var except []int
@@ -206,7 +206,7 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -0,0 +1,56 @@
package mistral3
import (
"image"
_ "image/jpeg"
_ "image/png"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/model/imageproc"
)
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
func newImageProcessor(c fs.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1540)),
}
}
// ProcessImage prepares an image for the vision model by:
// 1. Compositing transparent images
// 2. Resizing to fit model constraints while preserving aspect ratio
// 3. Normalizing pixel values
// Returns normalized image data and the final size in pixels
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, image.Point, error) {
img = imageproc.Composite(img)
size := img.Bounds().Size()
ratio := max(float64(size.Y)/float64(p.longestEdge), float64(size.X)/float64(p.longestEdge))
if ratio > 1.0 {
size = image.Point{
int(math.Floor(float64(size.X) / ratio)),
int(math.Floor(float64(size.Y) / ratio)),
}
}
patchesX := (size.X-1)/p.patchSize + 1
patchesY := (size.Y-1)/p.patchSize + 1
size = image.Point{
patchesX * p.patchSize,
patchesY * p.patchSize,
}
img = imageproc.Resize(img, size, imageproc.ResizeBilinear)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, size, nil
}

View File

@@ -0,0 +1,189 @@
package mistral3
import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
}
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c fs.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
if err != nil {
return nil, err
}
m := &Model{
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
return m, nil
}
type PatchMerger struct {
MergingLayer *nn.Linear `gguf:"merging_layer"`
}
func (pm *PatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point, spatialMergeSize int) ml.Tensor {
d := visionOutputs.Dim(0)
imageGrid := visionOutputs.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Reshape(ctx, size.X, size.Y, d)
kernel := ctx.Input().Empty(ml.DTypeF32, spatialMergeSize, spatialMergeSize, d)
patches := kernel.IM2Col(ctx, imageGrid, spatialMergeSize, spatialMergeSize, 0, 0, 1, 1)
reshaped := patches.Reshape(ctx, d*spatialMergeSize*spatialMergeSize, patches.Dim(1)*patches.Dim(2))
return pm.MergingLayer.Forward(ctx, reshaped)
}
type MultiModalProjector struct {
Norm *nn.RMSNorm `gguf:"norm"`
Linear1 *nn.Linear `gguf:"linear_1"`
Linear2 *nn.Linear `gguf:"linear_2"`
PatchMerger *PatchMerger `gguf:"patch_merger"`
spatialMergeSize int
eps float32
patchSize int
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, size image.Point) (ml.Tensor, image.Point) {
visionOutputs = p.Norm.Forward(ctx, visionOutputs, p.eps)
patchSizes := image.Point{size.X / p.patchSize, size.Y / p.patchSize}
visionOutputs = p.PatchMerger.Forward(ctx, visionOutputs, patchSizes, p.spatialMergeSize)
visionOutputs = p.Linear1.Forward(ctx, visionOutputs)
visionOutputs = visionOutputs.GELU(ctx)
return p.Linear2.Forward(ctx, visionOutputs), image.Point{patchSizes.X / p.spatialMergeSize, patchSizes.Y / p.spatialMergeSize}
}
func newMultiModalProjector(c fs.Config) *MultiModalProjector {
return &MultiModalProjector{
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
eps: c.Float("text_config.rms_norm_eps", 1e-5),
patchSize: int(c.Uint("vision.patch_size", 14)),
}
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, size, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, size.X, size.Y, m.ImageProcessor.numChannels)
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
for i := range rows {
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1]})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})
} else {
// [IMG_BREAK]
result = append(result, input.Input{Token: 12})
}
}
}
}
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("mistral3", New)
}

View File

@@ -0,0 +1,177 @@
package mistral3
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type TextOptions struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type TextModel struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
return sa.Output.Forward(ctx, kqv)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
// image embeddings
for _, image := range batch.Multimodal {
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
temp := m.Backend().NewContext()
temp.Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
temp.Close()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}
for i, layer := range m.Layers {
cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState)
}
func NewTextModel(c fs.Config) (*TextModel, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
textModel := &TextModel{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id", 1)),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id", 2)),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
return textModel, nil
}

View File

@@ -0,0 +1,186 @@
package mistral3
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
return x2.Neg(ctx).Concat(ctx, x1, 0)
}
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenStates
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = e.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
return hiddenStates.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.RMSNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) positionalEmbedding(ctx ml.Context, positionIDs ml.Tensor) ml.Tensor {
maxPatchesPerSide := m.imageSize / m.patchSize
frequencies := m.headDim / 2
frequenciesHeight := make([]float32, frequencies/2*maxPatchesPerSide)
frequenciesWidth := make([]float32, frequencies/2*maxPatchesPerSide)
for i := range frequencies {
for j := range maxPatchesPerSide {
frequency := float32(j) / float32(math.Pow(float64(m.ropeBase), float64(i)*2/float64(m.headDim)))
if i%2 == 0 {
frequenciesHeight[i/2*maxPatchesPerSide+j] = frequency
} else {
frequenciesWidth[i/2*maxPatchesPerSide+j] = frequency
}
}
}
h, err := ctx.Input().FromFloatSlice(frequenciesHeight, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
w, err := ctx.Input().FromFloatSlice(frequenciesWidth, maxPatchesPerSide, frequencies/2)
if err != nil {
panic(err)
}
h = h.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
w = w.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
h = h.Repeat(ctx, 1, maxPatchesPerSide)
h = h.Reshape(ctx, frequencies/2, maxPatchesPerSide, maxPatchesPerSide).Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
w = w.Repeat(ctx, 2, maxPatchesPerSide)
inverseFrequencies := h.Concat(ctx, w, 0).Reshape(ctx, frequencies, maxPatchesPerSide*maxPatchesPerSide)
inverseFrequencies = inverseFrequencies.Concat(ctx, inverseFrequencies, 0)
return inverseFrequencies.Rows(ctx, positionIDs)
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesW := pixelValues.Dim(0) / m.patchSize
numPatchesH := pixelValues.Dim(1) / m.patchSize
numPatches := numPatchesW * numPatchesH
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, numPatches, m.hiddenSize)
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenStates = m.EncoderNorm.Forward(ctx, hiddenStates, m.VisionModelOptions.eps)
// Prepare position IDs for 2D rope
positions := make([]int32, numPatches)
for h := range numPatchesH {
for w := range numPatchesW {
idx := h*numPatchesW + w
positions[idx] = int32(h*m.imageSize/m.patchSize + w)
}
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
positionEmbedding := m.positionalEmbedding(ctx, positionIDs)
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionModelOptions)
}
return hiddenStates
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-5),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
},
}
}

View File

@@ -186,7 +186,7 @@ func (m *VisionModel) Forward(ctx ml.Context, pixelValues, positionIDs, aspectRa
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
hiddenState = m.PreTilePositionEmbedding.Forward(ctx, hiddenState, aspectRatioIDs, m.VisionModelOptions)
hiddenState = m.ClassEmbedding.Stack(ctx, 2, slices.Repeat([]ml.Tensor{m.ClassEmbedding}, m.numTiles-1)...).Concat(ctx, hiddenState, 1)
hiddenState = m.ClassEmbedding.Repeat(ctx, 2, m.numTiles).Concat(ctx, hiddenState, 1)
hiddenState = m.PositionEmbedding.Forward(ctx, hiddenState, positionIDs, aspectRatioIDs, numPositions, m.VisionModelOptions)
hiddenState = m.PreLayerNorm.Forward(ctx, hiddenState, m.eps)

View File

@@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -1,68 +0,0 @@
package pixtral
import (
"fmt"
"image"
_ "image/jpeg"
_ "image/png"
"io"
"math"
"github.com/ollama/ollama/model/imageproc"
)
func getNumImageTokens(imageSize, patchSize image.Point) image.Point {
return image.Point{
(imageSize.X-1)/patchSize.X + 1,
(imageSize.Y-1)/patchSize.Y + 1,
}
}
func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.Point) image.Point {
b := img.Bounds()
le := float64(longestEdge)
ratio := math.Max(float64(b.Max.Y)/le, float64(b.Max.X)/le)
newSize := img.Bounds().Max
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
}
}
tokens := getNumImageTokens(newSize, patchSize)
return image.Point{
tokens.X * patchSize.X,
tokens.Y * patchSize.Y,
}
}
func resizeImage(img image.Image, format string, longestEdge int, patchSize image.Point) image.Image {
if format == "png" {
img = imageproc.Composite(img)
}
newSize := getResizeOutputImageSize(img, longestEdge, patchSize)
// todo should be ResizeBicubic, but it doesn't exist
return imageproc.Resize(img, newSize, imageproc.ResizeBilinear)
}
func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
img, format, err := image.Decode(imageData)
if err != nil {
return nil, nil, fmt.Errorf("failed to decode image: %w", err)
}
longestEdge := 1024
patchSize := image.Point{16, 16}
img = resizeImage(img, format, longestEdge, patchSize)
data := imageproc.Normalize(img, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
opts := map[string]any{}
return data, opts, nil
}

View File

@@ -1,219 +0,0 @@
package pixtral
import (
"bytes"
"encoding/binary"
"image"
"image/png"
"math"
"os"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestGetNumImageTokens(t *testing.T) {
type numImageTokensCase struct {
ImageSize image.Point
PatchSize image.Point
Expected image.Point
}
cases := []numImageTokensCase{
{
ImageSize: image.Point{1024, 764},
PatchSize: image.Point{16, 16},
Expected: image.Point{64, 48},
},
{
ImageSize: image.Point{800, 600},
PatchSize: image.Point{16, 16},
Expected: image.Point{50, 38},
},
{
ImageSize: image.Point{640, 480},
PatchSize: image.Point{16, 16},
Expected: image.Point{40, 30},
},
{
ImageSize: image.Point{320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{20, 13},
},
{
ImageSize: image.Point{1320, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{83, 13},
},
{
ImageSize: image.Point{2000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{125, 13},
},
{
ImageSize: image.Point{10000, 200},
PatchSize: image.Point{16, 16},
Expected: image.Point{625, 13},
},
{
ImageSize: image.Point{1131, 577},
PatchSize: image.Point{16, 16},
Expected: image.Point{71, 37},
},
{
ImageSize: image.Point{16, 16},
PatchSize: image.Point{16, 16},
Expected: image.Point{1, 1},
},
}
for _, c := range cases {
actual := getNumImageTokens(c.ImageSize, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestGetResizeOutputImageSize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Point
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1024, 768)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 768},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1162, 690)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 624},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 300, 200)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{304, 208},
},
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.Point{1024, 288},
},
}
for _, c := range cases {
actual := getResizeOutputImageSize(c.Image, c.LongestEdge, c.PatchSize)
if diff := cmp.Diff(actual, c.Expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
}
func TestResize(t *testing.T) {
type resizeCase struct {
Image image.Image
LongestEdge int
PatchSize image.Point
Expected image.Image
}
cases := []resizeCase{
{
Image: image.NewRGBA(image.Rect(0, 0, 1862, 522)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 1024, 288)),
},
{
Image: image.NewRGBA(image.Rect(0, 0, 10, 10)),
LongestEdge: 1024,
PatchSize: image.Point{16, 16},
Expected: image.NewRGBA(image.Rect(0, 0, 16, 16)),
},
}
for _, c := range cases {
actual := resizeImage(c.Image, "png", c.LongestEdge, c.PatchSize)
if actual.Bounds() != c.Expected.Bounds() {
t.Errorf("image size incorrect: '%#v': expected: '%#v'", actual.Bounds(), c.Expected.Bounds())
}
}
}
func TestPreprocess(t *testing.T) {
type preprocessCase struct {
TestImage image.Image
ExpectedLen int
}
cases := []preprocessCase{
{
TestImage: image.NewRGBA(image.Rect(0, 0, 10, 10)),
ExpectedLen: 16 * 16 * 3 * 1,
},
{
TestImage: image.NewRGBA(image.Rect(0, 0, 2000, 2000)),
ExpectedLen: 1024 * 1024 * 3 * 1,
},
}
for _, c := range cases {
var buf bytes.Buffer
err := png.Encode(&buf, c.TestImage)
if err != nil {
t.Fatal(err)
}
imgData, _, err := Preprocess(&buf)
if err != nil {
t.Fatalf("error processing: %q", err)
}
switch len(imgData) {
case 0:
t.Errorf("no image data returned")
case c.ExpectedLen:
// ok
default:
t.Errorf("unexpected image data length: %d, expected: %d", len(imgData), c.ExpectedLen)
}
}
}
func TestPreprocessImages(t *testing.T) {
for _, testFile := range []string{"flight.png", "sportsball.png"} {
f, err := os.Open(testFile)
if err != nil {
t.Skipf("skipping test, no test image found at %s", testFile)
}
defer f.Close()
imgData, _, err := Preprocess(f)
if err != nil {
t.Fatalf("error processing: %q", err)
}
byteData := make([]byte, len(imgData)*4) // float32 is 4 bytes
for i, f := range imgData {
binary.LittleEndian.PutUint32(byteData[i*4:], math.Float32bits(f))
}
outputPath := "processed_" + testFile + ".bin"
err = os.WriteFile(outputPath, byteData, 0o644)
if err != nil {
t.Fatalf("error writing processed image: %q", err)
}
}
}