commit f0c66e6dea (parent 54055a6dae)
Author: Michael Yang
Date: 2025-04-03 15:18:29 -07:00
13 changed files with 833 additions and 15 deletions


@@ -0,0 +1,100 @@
package llama4
import (
"bytes"
"image"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.BytePairEncoding
*VisionModel `gguf:"v,vision"`
*Projector `gguf:"mm"`
*TextModel
}
type Projector struct {
Linear1 *nn.Linear `gguf:"linear_1"`
}
func (p *Projector) Forward(ctx ml.Context, visionOutputs ml.Tensor) ml.Tensor {
return p.Linear1.Forward(ctx, visionOutputs)
}
func New(c fs.Config) (model.Model, error) {
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(
// TODO: pretend this is chunked attention for now
kvcache.NewSWACache(8192, m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
if len(m.VisionModel.Layers) < 1 {
return nil, model.ErrNoVisionModel
}
img, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
f32s, aspectRatio, err := m.ProcessImage(ctx, img)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s, len(f32s))
if err != nil {
return nil, err
}
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
visionOutputs = visionOutputs.Reshape(ctx, visionOutputs.Dim(0), visionOutputs.Dim(1)*visionOutputs.Dim(2)*visionOutputs.Dim(3))
return m.Projector.Forward(ctx, visionOutputs), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
}
func init() {
model.Register("llama4", New)
}
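
The wrapper cache above stands in for Llama 4's chunked attention by pairing a sliding-window cache with an 8192-token window and a plain causal cache (hence the TODO). As a rough illustration of the difference, the standalone sketch below, which is not part of this commit and not ollama's API, computes the oldest position a query may still attend to under full causal, sliding-window, and chunked masking, assuming a window/chunk size of 8192.

package main

import "fmt"

// oldestVisible returns the oldest position a query at position pos may still
// attend to under three masking schemes. Illustrative only; not part of this
// commit and not ollama's API.
func oldestVisible(pos, window int, scheme string) int {
	switch scheme {
	case "causal":
		return 0 // full causal attention sees everything up to pos
	case "swa":
		if start := pos - window + 1; start > 0 {
			return start // sliding window: a fixed-size trailing window
		}
		return 0
	case "chunked":
		return (pos / window) * window // chunked: confined to the current chunk
	}
	return 0
}

func main() {
	for _, pos := range []int{100, 8191, 8192, 20000} {
		fmt.Println(pos, oldestVisible(pos, 8192, "swa"), oldestVisible(pos, 8192, "chunked"))
	}
}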


@@ -0,0 +1,223 @@
package llama4
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model/input"
)
type TextAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_factors"`
}
func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
batchSize, headDim := hiddenStates.Dim(1), cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
query := sa.Query.Forward(ctx, hiddenStates)
key := sa.Key.Forward(ctx, hiddenStates)
value := sa.Value.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = query.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, uint32(opts.ropeDim), uint32(0), opts.ropeBase, opts.ropeScale)
if opts.useQKNorm {
query = query.RMSNorm(ctx, nil, opts.eps)
key = key.RMSNorm(ctx, nil, opts.eps)
}
}
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
type TextMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *TextMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextExperts struct {
Gate ml.Tensor `gguf:"ffn_gate_exps.weight"`
Up ml.Tensor `gguf:"ffn_up_exps.weight"`
Down ml.Tensor `gguf:"ffn_down_exps.weight"`
}
func (e *TextExperts) Forward(ctx ml.Context, hiddenStates, routerLogits ml.Tensor, opts *TextOptions) ml.Tensor {
experts := routerLogits.TopK(ctx, opts.numExpertsUsed)
scores := routerLogits.Sigmoid(ctx).Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, experts)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1))
hiddenStates = hiddenStates.Repeat(ctx, 1, opts.numExpertsUsed)
hiddenStates = hiddenStates.Mul(ctx, scores)
upStates := e.Up.MulmatID(ctx, hiddenStates, experts)
gateStates := e.Gate.MulmatID(ctx, hiddenStates, experts)
downStates := e.Down.MulmatID(ctx, upStates.Mul(ctx, gateStates.SILU(ctx)), experts)
nextStates := downStates.View(ctx, 0, hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2))
for i := 1; i < opts.numExpertsUsed; i++ {
nextStates.Add(ctx, downStates.View(ctx, i*downStates.Stride(1), hiddenStates.Dim(0), downStates.Stride(2), hiddenStates.Dim(2)))
}
return nextStates
}
// TextSharedExpert is TextMLP with different GGUF tensor names
type TextSharedExpert struct {
Gate *nn.Linear `gguf:"ffn_gate_shexp"`
Up *nn.Linear `gguf:"ffn_up_shexp"`
Down *nn.Linear `gguf:"ffn_down_shexp"`
}
func (mlp *TextSharedExpert) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenStates))
return mlp.Down.Forward(ctx, hiddenStates)
}
type TextMOE struct {
Router *nn.Linear `gguf:"ffn_gate_inp"`
Experts *TextExperts
SharedExpert *TextSharedExpert
}
func (moe *TextMOE) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor {
hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2)
hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize)
routerLogits := moe.Router.Forward(ctx, hiddenStates)
sharedStates := moe.SharedExpert.Forward(ctx, hiddenStates, opts)
routedStates := moe.Experts.Forward(ctx, hiddenStates, routerLogits, opts)
return sharedStates.Add(ctx, routedStates)
}
type TextFeedForward interface {
Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor
}
type TextLayer struct {
AttentionNorm *nn.LayerNorm `gguf:"attn_norm"`
Attention *TextAttention
FFNNorm *nn.LayerNorm `gguf:"ffn_norm"`
FeedForward TextFeedForward
}
func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, positions, outputs ml.Tensor, cache kvcache.Cache, useRope bool, opts *TextOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.Attention.Forward(ctx, hiddenStates, positions, cache, useRope, opts)
if outputs != nil {
hiddenStates = hiddenStates.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenStates = hiddenStates.Add(ctx, residual)
residual = hiddenStates
hiddenStates = d.FFNNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = d.FeedForward.Forward(ctx, hiddenStates, opts)
return residual.Add(ctx, hiddenStates)
}
type TextOptions struct {
hiddenSize int
numHeads, numKVHeads, headDim int
numExperts, numExpertsUsed int
ropeDim int
ropeBase, ropeScale float32
eps float32
interleaveLayerStep int
useQKNorm bool
}
type TextModel struct {
Layers []TextLayer `gguf:"blk"`
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
OutputNorm *nn.LayerNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*TextOptions
}
func newTextModel(c fs.Config) *TextModel {
layers := make([]TextLayer, c.Uint("block_count"))
interleaveLayerStep := c.Uint("interleave_moe_layer_step", 1)
for i := range layers {
if (i+1)%int(interleaveLayerStep) == 0 {
layers[i] = TextLayer{FeedForward: &TextMOE{}}
} else {
layers[i] = TextLayer{FeedForward: &TextMLP{}}
}
}
return &TextModel{
Layers: layers,
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.head_dim", 128)),
numExperts: int(c.Uint("expert_count")),
numExpertsUsed: int(c.Uint("expert_used_count")),
ropeDim: int(c.Uint("rope.dimension_count")),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
eps: c.Float("attention.layer_norm_rms_epsilon"),
interleaveLayerStep: int(c.Uint("interleave_moe_layer_step", 1)),
useQKNorm: c.Bool("use_qk_norm", true),
},
}
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
cache.SetLayer(i)
wc := cache.(*kvcache.WrapperCache)
wc.SetLayerType(1)
useChunkedAttention := (i+1)%4 != 0
if useChunkedAttention {
wc.SetLayerType(0)
}
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, useChunkedAttention, m.TextOptions)
}
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].Attention.RopeFactors, uint32(0), uint32(m.ropeDim), m.ropeBase, m.ropeScale), nil
}
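
Two interleaving patterns shape the text model above: newTextModel gives every interleave_moe_layer_step-th layer a MoE feed-forward and the rest a dense MLP, while Forward gives every fourth layer full causal attention without RoPE and the remaining layers chunked attention with RoPE. The standalone sketch below, not part of this commit, only prints that layer plan; the block count and step are arbitrary example values.

package main

import "fmt"

// layerPlan mirrors the selection logic in newTextModel and TextModel.Forward:
// every interleaveStep-th layer gets a MoE feed-forward (the rest a dense MLP),
// and every fourth layer uses full causal attention without RoPE, while the
// other layers use chunked/local attention with RoPE. Illustrative only.
func layerPlan(blockCount, interleaveStep int) {
	for i := 0; i < blockCount; i++ {
		ffn := "dense"
		if (i+1)%interleaveStep == 0 {
			ffn = "moe"
		}
		attn := "chunked + rope"
		if (i+1)%4 == 0 {
			attn = "full causal, no rope"
		}
		fmt.Printf("layer %2d: ffn=%-5s attn=%s\n", i, ffn, attn)
	}
}

func main() {
	// 8 layers with a step of 2, purely to make the interleaving visible.
	layerPlan(8, 2)
}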


@@ -0,0 +1,256 @@
package llama4
import (
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type VisionAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
// applyVisionRotaryEmbedding applies 2D rotary embedding to the input tensor.
// This is equivalent to the PyTorch implementation using half rotations:
//
// cos, sin = torch.cos(freqs), torch.sin(freqs)
// cos = cos.unsqueeze(-1)
// sin = sin.unsqueeze(-1)
// t = t.reshape(*t.shape[:-1], -1, 2)
// t_out = (t * cos) + (_rotate_half(t) * sin)
// t_out = t_out.flatten(3)
//
// Which is equivalent to the PyTorch implementation using complex numbers:
//
// t_ = torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))
// freqs_ci = reshape_for_broadcast(freqs_ci=freq_cis, t=t_) # freqs_ci[:,:,None,:]
// freqs_ci = freqs_ci.to(t_.device)
// t_out = torch.view_as_real(t_ * freqs_ci).flatten(3)
//
// Due to 1) the dimensional and 2) the data type limitations of current backends,
// we need to use a different approach to achieve the same result.
func applyVisionRotaryEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
width, height, channels, tiles := t.Dim(0), t.Dim(1), t.Dim(2), t.Dim(3)
t = t.Reshape(ctx, 2, t.Dim(0)/2, t.Dim(1)*t.Dim(2)*t.Dim(3))
// t1 = t[..., 0::2]
t1 := t.View(ctx, 0, 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t1 = t1.Reshape(ctx, width/2, height, channels, tiles)
// t2 = t[..., 1::2]
t2 := t.View(ctx, t.Stride(0), 1, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2)).Contiguous(ctx)
t2 = t2.Reshape(ctx, width/2, height, channels, tiles)
// cos_out = torch.stack((t1 * cos, t2 * cos), dim=-1)
cosOut := t1.Mul(ctx, cos).Concat(ctx, t2.Mul(ctx, cos), 0)
cosOut = cosOut.Reshape(ctx, cosOut.Dim(0)/2, 2, cosOut.Dim(1)*cosOut.Dim(2)*cosOut.Dim(3))
cosOut = cosOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
cosOut = cosOut.Reshape(ctx, width, height, channels, tiles)
// sin_out = torch.stack((-t2 * sin, t1 * sin), dim=-1)
sinOut := t2.Neg(ctx).Mul(ctx, sin).Concat(ctx, t1.Mul(ctx, sin), 0)
sinOut = sinOut.Reshape(ctx, sinOut.Dim(0)/2, 2, sinOut.Dim(1)*sinOut.Dim(2)*sinOut.Dim(3))
sinOut = sinOut.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
sinOut = sinOut.Reshape(ctx, width, height, channels, tiles)
return cosOut.Add(ctx, sinOut)
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenState, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
headDim := opts.hiddenSize / opts.numHeads
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, query.Dim(1), query.Dim(2))
key = key.Reshape(ctx, headDim, opts.numHeads, key.Dim(1), key.Dim(2))
value = value.Reshape(ctx, headDim, opts.numHeads, value.Dim(1), value.Dim(2))
query = applyVisionRotaryEmbedding(ctx, query, cos, sin)
key = applyVisionRotaryEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), attention.Dim(3))
return sa.Output.Forward(ctx, attention)
}
type VisionMLP struct {
FC1 *nn.Linear `gguf:"fc1"`
FC2 *nn.Linear `gguf:"fc2"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
hiddenStates = mlp.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = mlp.FC2.Forward(ctx, hiddenStates)
return hiddenStates
}
type VisionLayer struct {
InputLayerNorm *nn.LayerNorm `gguf:"attn_norm"`
*VisionAttention
PostAttentionNorm *nn.LayerNorm `gguf:"ffn_norm"`
*VisionMLP `gguf:"mlp"`
}
func (e *VisionLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts *VisionOptions) ml.Tensor {
residual := hiddenStates
// self attention
hiddenStates = e.InputLayerNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionAttention.Forward(ctx, hiddenStates, cos, sin, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
// MLP
residual = hiddenStates
hiddenStates = e.PostAttentionNorm.Forward(ctx, hiddenStates, opts.eps)
hiddenStates = e.VisionMLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
return hiddenStates
}
type VisionAdapter struct {
FC1 *nn.Linear `gguf:"mlp.fc1"`
FC2 *nn.Linear `gguf:"mlp.fc2"`
}
func (a *VisionAdapter) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
patches := hiddenStates.Dim(1)
patchSize := int(math.Sqrt(float64(patches)))
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), patchSize, patchSize, hiddenStates.Dim(2))
channels, width, height, tiles := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2), hiddenStates.Dim(3)
channels, width = int(float32(channels)/opts.pixelShuffleRatio), int(float32(width)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
channels, height = int(float32(channels)/opts.pixelShuffleRatio), int(float32(height)*opts.pixelShuffleRatio)
hiddenStates = hiddenStates.Reshape(ctx, channels, width, height, tiles)
hiddenStates = hiddenStates.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Reshape(ctx, channels, width*height, tiles)
hiddenStates = a.FC1.Forward(ctx, hiddenStates).GELU(ctx)
hiddenStates = a.FC2.Forward(ctx, hiddenStates).GELU(ctx)
return hiddenStates
}
type VisionOptions struct {
hiddenSize, numHeads int
imageSize, patchSize int
ropeTheta float32
eps float32
pixelShuffleRatio float32
}
type PatchEmbedding struct {
*nn.Linear
}
func (p *PatchEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionOptions) ml.Tensor {
kernel := ctx.Input().Empty(ml.DTypeF32, opts.patchSize, opts.patchSize, hiddenStates.Dim(2))
hiddenStates = kernel.IM2Col(ctx, hiddenStates, opts.patchSize, opts.patchSize, 0, 0, 1, 1)
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), hiddenStates.Dim(1)*hiddenStates.Dim(2), hiddenStates.Dim(3))
return p.Linear.Forward(ctx, hiddenStates)
}
type VisionModel struct {
Layers []VisionLayer `gguf:"blk"`
*PatchEmbedding `gguf:"patch_embedding"`
ClassEmbedding ml.Tensor `gguf:"class_embedding"`
PositionalEmbedding ml.Tensor `gguf:"positional_embedding_vlm"`
LayerNormPre *nn.LayerNorm `gguf:"layernorm_pre"`
LayerNormPost *nn.LayerNorm `gguf:"layernorm_post"`
*VisionAdapter `gguf:"vision_adapter"`
*VisionOptions
}
func newVisionModel(c fs.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionLayer, c.Uint("vision.block_count")),
VisionOptions: &VisionOptions{
hiddenSize: int(c.Uint("vision.embedding_length")),
numHeads: int(c.Uint("vision.attention.head_count")),
imageSize: int(c.Uint("vision.image_size")),
patchSize: int(c.Uint("vision.patch_size")),
ropeTheta: float32(c.Float("vision.rope.freq_base")),
eps: c.Float("vision.layer_norm_epsilon"),
pixelShuffleRatio: float32(c.Float("vision.pixel_shuffle_ratio")),
},
}
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionOptions)
hiddenStates = hiddenStates.Concat(ctx, m.ClassEmbedding.Repeat(ctx, 2, hiddenStates.Dim(2)), 1)
hiddenStates = hiddenStates.Add(ctx, m.PositionalEmbedding)
hiddenStates = m.LayerNormPre.Forward(ctx, hiddenStates, m.eps)
cos, sin := m.rotaryEmbedding(ctx)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions)
}
hiddenStates = m.LayerNormPost.Forward(ctx, hiddenStates, m.eps)
hiddenStates = hiddenStates.Unpad(ctx, 0, 1, 0, 0)
hiddenStates = m.VisionAdapter.Forward(ctx, hiddenStates, m.VisionOptions)
return hiddenStates
}
// floorDiv is a helper function to perform floor division. It mimics PyTorch's div(rounding_mode='floor'),
// which in turn mimics Python's // operator.
func floorDiv[T int | int16 | int32 | int64 | uint | uint16 | uint32 | uint64](a, b T) T {
if b == 0 {
panic("division by zero")
}
if (a >= 0 && b > 0) || (a <= 0 && b < 0) || a%b == 0 {
return a / b
}
return a/b - 1
}
func (m *VisionModel) rotaryEmbedding(ctx ml.Context) (ml.Tensor, ml.Tensor) {
patchesPerSide := m.imageSize / m.patchSize
numPatches := patchesPerSide*patchesPerSide + 1
headDim := m.hiddenSize / m.numHeads
freqDim := headDim / 2
freqs := make([]float32, numPatches*freqDim)
for i := range numPatches - 1 {
for j := 0; j < freqDim; j += 2 {
positionX := i*freqDim/2 + j/2
positionY := (i+numPatches)*freqDim/2 + j/2
ropeFreq := math.Pow(float64(m.ropeTheta), float64(j)*2/float64(headDim))
freqs[positionX] = float32(float64(1+i-floorDiv(i, patchesPerSide)*patchesPerSide) / ropeFreq)
freqs[positionY] = float32(float64(1+floorDiv(i, patchesPerSide)) / ropeFreq)
}
}
ropeFreqs, err := ctx.Input().FromFloatSlice(freqs, freqDim/2, numPatches, 2)
if err != nil {
panic(err)
}
ropeFreqs = ropeFreqs.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
ropeFreqs = ropeFreqs.Reshape(ctx, freqDim, 1, numPatches)
return ropeFreqs.Cos(ctx), ropeFreqs.Sin(ctx)
}
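
VisionAdapter.Forward above performs a pixel shuffle: each of its two reshape-permute steps trades spatial resolution for channels by a factor of 1/pixelShuffleRatio, so with a ratio of 0.5 the patch grid shrinks 2x per side while the channel count grows 4x before the adapter MLP. The standalone sketch below, not part of this commit, reproduces only the shape arithmetic; the concrete sizes are made-up examples.

package main

import "fmt"

// pixelShuffleShape follows the reshapes in VisionAdapter.Forward: each of the
// two shuffle steps multiplies the channel count by 1/ratio and scales one
// spatial side by ratio. Illustrative only; the sizes below are made up.
func pixelShuffleShape(channels, side int, ratio float32) (int, int, int) {
	width, height := side, side
	channels, width = int(float32(channels)/ratio), int(float32(width)*ratio)
	channels, height = int(float32(channels)/ratio), int(float32(height)*ratio)
	return channels, width, height
}

func main() {
	// e.g. a 24x24 patch grid with 1024 channels and ratio 0.5 becomes a
	// 12x12 grid with 4096 channels: 4x fewer tokens enter the adapter MLP.
	fmt.Println(pixelShuffleShape(1024, 24, 0.5))
}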


@@ -4,6 +4,7 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/llama4"
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
)
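
The blank import added above runs the llama4 package's init, which calls model.Register("llama4", New) so the architecture can later be constructed by name. Below is a toy sketch of that blank-import registry pattern in isolation; the names and types are hypothetical and not ollama's actual internals.

package main

import "fmt"

// A toy version of the blank-import registry pattern: each model package
// registers a constructor by architecture name in init, and callers later
// look the constructor up by the name stored in the model file. Hypothetical
// names; not ollama's actual implementation.
var registry = map[string]func() string{}

func register(name string, ctor func() string) { registry[name] = ctor }

func init() { register("llama4", func() string { return "llama4 model" }) }

func main() {
	if ctor, ok := registry["llama4"]; ok {
		fmt.Println(ctor())
	}
}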