mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 07:46:59 +00:00
model: add Qwen2.5-VL support (#10385)
This commit is contained in:
@@ -7,4 +7,5 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/llama4"
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
)
|
||||
|
||||
187
model/models/qwen25vl/model.go
Normal file
187
model/models/qwen25vl/model.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"image"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
// Implement MultimodalProcessor interface
|
||||
var _ model.MultimodalProcessor = (*Model)(nil)
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
m := &Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOT: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOT: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
TextModel: NewTextModel(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (m *Model) PixelValues(ctx ml.Context, multimodalData []byte) (ml.Tensor, *Grid, error) {
|
||||
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
f32s, grid, err := m.ImageProcessor.ProcessImage(image)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Calculate tensor dimensions
|
||||
patchDim := m.ImageProcessor.numChannels * m.ImageProcessor.temporalPatchSize *
|
||||
m.ImageProcessor.patchSize * m.ImageProcessor.patchSize
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
|
||||
pixelValues, err := ctx.Input().FromFloatSlice(f32s, patchDim, numPatches)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create tensor from image: %w", err)
|
||||
}
|
||||
|
||||
return pixelValues, grid, nil
|
||||
}
|
||||
|
||||
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
|
||||
if len(m.VisionModel.Layers) == 0 {
|
||||
return nil, model.ErrNoVisionModel
|
||||
}
|
||||
|
||||
pixels, grid, err := m.PixelValues(ctx, multimodalData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
visionOutputs := m.VisionModel.Forward(ctx, pixels, grid)
|
||||
return &chunks{Model: m, Tensor: visionOutputs}, nil
|
||||
}
|
||||
|
||||
type chunks struct {
|
||||
*Model
|
||||
ml.Tensor
|
||||
|
||||
dataOnce sync.Once
|
||||
data []float32
|
||||
}
|
||||
|
||||
type chunk struct {
|
||||
*chunks
|
||||
s, n int
|
||||
}
|
||||
|
||||
func (r *chunk) floats() []float32 {
|
||||
r.dataOnce.Do(func() {
|
||||
temp := r.Backend().NewContext()
|
||||
defer temp.Close()
|
||||
temp.Forward(r.Tensor).Compute(r.Tensor)
|
||||
r.data = r.Floats()
|
||||
})
|
||||
|
||||
return r.data[r.s*r.Dim(0) : (r.s+r.n)*r.Dim(0)]
|
||||
}
|
||||
|
||||
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
|
||||
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
|
||||
var result []input.Input
|
||||
|
||||
var (
|
||||
imageToken int32 = 151655
|
||||
visionStartToken int32 = 151652
|
||||
visionEndToken int32 = 151653
|
||||
)
|
||||
|
||||
nImg := 0
|
||||
for _, inp := range inputs {
|
||||
if inp.Multimodal == nil {
|
||||
// If not a multimodal input, add it to the result unchanged
|
||||
result = append(result, inp)
|
||||
} else {
|
||||
// Adding the 'Picture' prefix is a hack, at the time of writing there is no way to prefix
|
||||
// the image tokens with a prompt, so we add a prefix here
|
||||
nImg++
|
||||
pre, err := m.Encode(fmt.Sprintf(" Picture %d: ", nImg), true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
|
||||
}
|
||||
for i := range pre {
|
||||
result = append(result, input.Input{Token: pre[i]})
|
||||
}
|
||||
|
||||
// This is an image token with multimodal data
|
||||
chunksData := inp.Multimodal.(*chunks)
|
||||
patchesPerChunk := chunksData.Dim(1)
|
||||
|
||||
// First add the vision start token
|
||||
result = append(result, input.Input{Token: visionStartToken, SameBatch: patchesPerChunk + 2})
|
||||
|
||||
// Add the image token with the multimodal tensor data at the first position
|
||||
// Create a chunk with proper s and n values
|
||||
result = append(result, input.Input{
|
||||
Token: imageToken,
|
||||
Multimodal: &chunk{chunks: chunksData, s: 0, n: patchesPerChunk},
|
||||
MultimodalHash: inp.MultimodalHash,
|
||||
SameBatch: patchesPerChunk,
|
||||
})
|
||||
|
||||
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
|
||||
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
|
||||
|
||||
result = append(result, input.Input{Token: visionEndToken})
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("qwen25vl", New)
|
||||
}
|
||||
155
model/models/qwen25vl/model_text.go
Normal file
155
model/models/qwen25vl/model_text.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type TextOptions struct {
|
||||
ctxLen, hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim, defaultContextLen uint32
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
*TextOptions
|
||||
}
|
||||
|
||||
func NewTextModel(c fs.Config) *TextModel {
|
||||
m := TextModel{
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextOptions: &TextOptions{
|
||||
ctxLen: int(c.Uint("context_length")),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
ropeDim: c.Uint("rope.dimension_count", 128),
|
||||
defaultContextLen: c.Uint("context_length", 128000),
|
||||
},
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
// SelfAttention implements the multi-head self-attention mechanism
|
||||
// with separate projections for query, key, value and output transformations
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, nil, opts.ropeDim, 2, opts.ropeBase, opts.ropeScale, ml.WithContextLen(opts.defaultContextLen))
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return key.RoPE(ctx, shift, nil, m.ropeDim, 2, m.ropeBase, m.ropeScale, ml.WithContextLen(m.defaultContextLen)), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *TextOptions) ml.Tensor {
|
||||
// Apply SwiGLU activation gating
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
// Project back to hidden dimension
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
// Layer represents a single transformer layer combining self-attention and feed-forward components
|
||||
type Layer struct {
|
||||
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
|
||||
SelfAttention *SelfAttention
|
||||
MLPNorm *nn.RMSNorm `gguf:"ffn_norm"`
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
// Self-attention branch with residual connection
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
// Feed-forward branch with residual connection
|
||||
residual = hiddenState
|
||||
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
// Initial token embedding
|
||||
hiddenStates := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||
|
||||
for _, mi := range batch.Multimodal {
|
||||
f32s := mi.Multimodal.(*chunk).floats()
|
||||
img, err := ctx.Input().FromFloatSlice(f32s, len(f32s)/m.hiddenSize, m.hiddenSize)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ctx.Forward(img.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), img.Dim(0)*img.Dim(1))))
|
||||
}
|
||||
|
||||
// Process through transformer layers
|
||||
for i, layer := range m.Layers {
|
||||
cache.SetLayer(i)
|
||||
|
||||
var lastLayerOutputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
}
|
||||
|
||||
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
391
model/models/qwen25vl/model_vision.go
Normal file
391
model/models/qwen25vl/model_vision.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
)
|
||||
|
||||
// We only support batch size of 1
|
||||
var batchSize int = 1
|
||||
|
||||
func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3))
|
||||
x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx)
|
||||
return x2.Neg(ctx).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
|
||||
// Create a flat slice for the mask (all -inf initially to block all attention)
|
||||
flat := make([]float32, seqLength*seqLength)
|
||||
for i := range flat {
|
||||
flat[i] = float32(math.Inf(-1)) // Negative infinity to block attention
|
||||
}
|
||||
|
||||
// Fill in the mask with zeros for tokens that CAN attend to each other
|
||||
for i := 1; i < len(bounds); i++ {
|
||||
start := bounds[i-1]
|
||||
end := bounds[i]
|
||||
|
||||
// Enable attention within this sequence block by setting values to 0
|
||||
for row := start; row < end; row++ {
|
||||
for col := start; col < end; col++ {
|
||||
idx := row*seqLength + col
|
||||
flat[idx] = 0.0 // 0 allows attention, -inf blocks it
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mask, err := ctx.Input().FromFloatSlice(flat, seqLength, seqLength)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// Reshape to match [seqLength, seqLength, 1] for broadcasting
|
||||
mask = mask.Reshape(ctx, seqLength, seqLength, 1)
|
||||
|
||||
return mask
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_out"`
|
||||
}
|
||||
|
||||
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
|
||||
query = query.Reshape(ctx, opts.headDim, opts.numHeads, query.Dim(1), batchSize)
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
// Scaled dot-product attention
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = kq.Scale(ctx, scale)
|
||||
if mask != nil {
|
||||
kq = kq.Add(ctx, mask)
|
||||
}
|
||||
kq = kq.Softmax(ctx)
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
// VisionMLP implements the multi-layer perceptron
|
||||
type VisionMLP struct {
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
}
|
||||
|
||||
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
// Using activation as specified in config (likely GELU or SiLU/Swish)
|
||||
gateOutput := mlp.Gate.Forward(ctx, hiddenStates)
|
||||
upOutput := mlp.Up.Forward(ctx, hiddenStates)
|
||||
hiddenStates = gateOutput.SILU(ctx).Mul(ctx, upOutput)
|
||||
|
||||
return mlp.Down.Forward(ctx, hiddenStates)
|
||||
}
|
||||
|
||||
type VisionEncoderLayer struct {
|
||||
Norm1 *nn.RMSNorm `gguf:"ln1"`
|
||||
SelfAttention *VisionSelfAttention
|
||||
Norm2 *nn.RMSNorm `gguf:"ln2"`
|
||||
MLP *VisionMLP
|
||||
}
|
||||
|
||||
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, mask ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
residual := hiddenStates
|
||||
hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.SelfAttention.Forward(ctx, hiddenStates, cos, sin, mask, opts)
|
||||
hiddenStates = hiddenStates.Add(ctx, residual)
|
||||
|
||||
residual = hiddenStates
|
||||
hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps)
|
||||
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
|
||||
return hiddenStates.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// VisionModelOptions contains configuration options
|
||||
type VisionModelOptions struct {
|
||||
hiddenSize int
|
||||
numHeads int
|
||||
headDim int
|
||||
patchSize int
|
||||
numChannels int
|
||||
eps float32
|
||||
ropeTheta float32
|
||||
spatialMergeSize int
|
||||
windowSize int
|
||||
fullAttnBlocks []int32
|
||||
temporalPatchSize int
|
||||
}
|
||||
|
||||
type PatchEmbedding struct {
|
||||
PatchConv0 *nn.Conv2D `gguf:"patch_embd_0"`
|
||||
PatchConv1 *nn.Conv2D `gguf:"patch_embd_1"`
|
||||
}
|
||||
|
||||
func (pe *PatchEmbedding) Forward(ctx ml.Context, pixelValues ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
numPatches := pixelValues.Shape()[1]
|
||||
|
||||
// Reshape the input tensor to match the expected dimensions
|
||||
pixelValues = pixelValues.Reshape(ctx, opts.patchSize*opts.patchSize, opts.temporalPatchSize, opts.numChannels, numPatches)
|
||||
|
||||
// Permute the tensor to bring the temporal dimension to the front
|
||||
pixelValues = pixelValues.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
|
||||
|
||||
// Split the tensor into parts for the temporal convolutions
|
||||
in0 := pixelValues.View(ctx, 0, 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in0 = in0.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
in1 := pixelValues.View(ctx, pixelValues.Stride(0), 1, pixelValues.Stride(1), pixelValues.Dim(1), pixelValues.Stride(2), pixelValues.Dim(2), pixelValues.Stride(3), pixelValues.Dim(3)).Contiguous(ctx)
|
||||
in1 = in1.Reshape(ctx, opts.patchSize, opts.patchSize, opts.numChannels, numPatches)
|
||||
|
||||
s0, s1 := opts.patchSize, opts.patchSize // Use full stride
|
||||
p0, p1 := 0, 0 // padding
|
||||
d0, d1 := 1, 1 // dilation
|
||||
out0 := pe.PatchConv0.Forward(ctx, in0, s0, s1, p0, p1, d0, d1)
|
||||
out1 := pe.PatchConv1.Forward(ctx, in1, s0, s1, p0, p1, d0, d1)
|
||||
|
||||
// Add the outputs from the two temporal convolutions
|
||||
out := out0.Add(ctx, out1)
|
||||
|
||||
// Reshape the output tensor to match the expected dimensions
|
||||
return out.Reshape(ctx, opts.hiddenSize, numPatches)
|
||||
}
|
||||
|
||||
// VisionPatchMerger implements patch merging for the Qwen vision model
|
||||
type VisionPatchMerger struct {
|
||||
LNQ *nn.RMSNorm `gguf:"ln_q"`
|
||||
MLP0 *nn.Linear `gguf:"mlp.0"`
|
||||
MLP2 *nn.Linear `gguf:"mlp.2"`
|
||||
}
|
||||
|
||||
// Forward computes patch merging for the vision model
|
||||
func (pm *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
|
||||
normalized := pm.LNQ.Forward(ctx, visionOutputs, opts.eps)
|
||||
|
||||
hiddenSize := visionOutputs.Dim(0) * (opts.spatialMergeSize * opts.spatialMergeSize)
|
||||
|
||||
// Reshape the normalized output to view the hidden size dimension
|
||||
reshaped := normalized.Reshape(ctx, hiddenSize, normalized.Dim(1)/(opts.spatialMergeSize*opts.spatialMergeSize), batchSize)
|
||||
hidden := pm.MLP0.Forward(ctx, reshaped)
|
||||
activated := hidden.GELU(ctx)
|
||||
|
||||
output := pm.MLP2.Forward(ctx, activated)
|
||||
|
||||
return output
|
||||
}
|
||||
|
||||
// VisionModel implements the Qwen vision model
|
||||
type VisionModel struct {
|
||||
PatchEmbedding *PatchEmbedding
|
||||
Layers []VisionEncoderLayer `gguf:"blk"`
|
||||
PatchMerger *VisionPatchMerger `gguf:"merger"`
|
||||
|
||||
*VisionModelOptions
|
||||
}
|
||||
|
||||
// Forward computes the vision model for an input tensor
|
||||
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) ml.Tensor {
|
||||
// Extract patch embeddings
|
||||
hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.VisionModelOptions)
|
||||
|
||||
positionEmbedding := m.PositionalEmbedding(ctx, grid)
|
||||
|
||||
windowIndex, bounds := m.WindowIndex(ctx, grid)
|
||||
|
||||
spatialMergeUnit := m.spatialMergeSize * m.spatialMergeSize
|
||||
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)*spatialMergeUnit, hiddenStates.Dim(1)/spatialMergeUnit)
|
||||
hiddenStates = hiddenStates.Rows(ctx, windowIndex)
|
||||
hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0)/spatialMergeUnit, hiddenStates.Dim(1)*spatialMergeUnit)
|
||||
|
||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)*spatialMergeUnit, positionEmbedding.Dim(1)/spatialMergeUnit)
|
||||
positionEmbedding = positionEmbedding.Rows(ctx, windowIndex)
|
||||
positionEmbedding = positionEmbedding.Reshape(ctx, positionEmbedding.Dim(0)/spatialMergeUnit, positionEmbedding.Dim(1)*spatialMergeUnit)
|
||||
positionEmbedding = positionEmbedding.Concat(ctx, positionEmbedding, 0)
|
||||
|
||||
cos, sin := positionEmbedding.Cos(ctx), positionEmbedding.Sin(ctx)
|
||||
cos = cos.Reshape(ctx, cos.Dim(0), 1, cos.Dim(1))
|
||||
sin = sin.Reshape(ctx, sin.Dim(0), 1, sin.Dim(1))
|
||||
|
||||
mask := blockDiagonalMask(ctx, hiddenStates.Dim(1), bounds, m.VisionModelOptions.numHeads)
|
||||
// Apply encoder layers
|
||||
for i, layer := range m.Layers {
|
||||
if slices.Contains(m.fullAttnBlocks, int32(i)) {
|
||||
hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, nil, m.VisionModelOptions)
|
||||
} else {
|
||||
hiddenStates = layer.Forward(
|
||||
ctx,
|
||||
hiddenStates,
|
||||
cos,
|
||||
sin,
|
||||
mask,
|
||||
m.VisionModelOptions,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, m.VisionModelOptions)
|
||||
reverseWindowIndex := windowIndex.Argsort(ctx)
|
||||
return hiddenStates.Rows(ctx, reverseWindowIndex)
|
||||
}
|
||||
|
||||
// WindowIndex divides the grid into windows and returns:
|
||||
// 1. A tensor containing flattened indices of all grid points organized by windows
|
||||
// 2. A slice of boundaries that mark where each window's data begins and ends
|
||||
// in the flattened representation, scaled by spatialMergeSize squared
|
||||
//
|
||||
// The boundaries slice always starts with 0 and contains cumulative ending
|
||||
// positions for each window, allowing downstream processing to identify
|
||||
// window boundaries in the tensor data.
|
||||
func (m *VisionModel) WindowIndex(ctx ml.Context, grid *Grid) (ml.Tensor, []int) {
|
||||
vitMergerWindowSize := m.windowSize / m.spatialMergeSize / m.patchSize
|
||||
|
||||
llmGridH := grid.Height / m.spatialMergeSize
|
||||
llmGridW := grid.Width / m.spatialMergeSize
|
||||
|
||||
// Calculate window parameters
|
||||
numWindowsH := int(math.Ceil(float64(llmGridH) / float64(vitMergerWindowSize)))
|
||||
numWindowsW := int(math.Ceil(float64(llmGridW) / float64(vitMergerWindowSize)))
|
||||
|
||||
// Initialize index_new slice
|
||||
var index []int32
|
||||
|
||||
// Initialize bounds with the first element as 0
|
||||
bounds := []int{0}
|
||||
totalSeqLen := 0
|
||||
|
||||
// Process each window without padding
|
||||
for wh := range numWindowsH {
|
||||
for ww := range numWindowsW {
|
||||
// Calculate window boundaries
|
||||
hStart := wh * vitMergerWindowSize
|
||||
wStart := ww * vitMergerWindowSize
|
||||
hEnd := min(hStart+vitMergerWindowSize, llmGridH)
|
||||
wEnd := min(wStart+vitMergerWindowSize, llmGridW)
|
||||
|
||||
// Calculate sequence length for this window
|
||||
seqLen := (hEnd - hStart) * (wEnd - wStart)
|
||||
|
||||
// Collect indices for this window
|
||||
for h := hStart; h < hEnd; h++ {
|
||||
for w := wStart; w < wEnd; w++ {
|
||||
index = append(index, int32(h*llmGridW+w))
|
||||
}
|
||||
}
|
||||
|
||||
totalSeqLen += seqLen
|
||||
bounds = append(bounds, totalSeqLen*(m.spatialMergeSize*m.spatialMergeSize)+bounds[0])
|
||||
}
|
||||
}
|
||||
|
||||
t, err := ctx.Input().FromIntSlice(index, len(index))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t, bounds
|
||||
}
|
||||
|
||||
// PositionalEmbedding generates rotary position embeddings for attention mechanisms
|
||||
func (m *VisionModel) PositionalEmbedding(ctx ml.Context, grid *Grid) ml.Tensor {
|
||||
dim := m.headDim / 2
|
||||
freq := dim / 2
|
||||
theta := float64(m.ropeTheta)
|
||||
merge := m.spatialMergeSize
|
||||
|
||||
// Create frequency patterns for position encoding
|
||||
maxGridSize := max(grid.Height, grid.Width)
|
||||
freqVals := make([]float32, freq*maxGridSize)
|
||||
for i := range maxGridSize {
|
||||
for j := range freq {
|
||||
freqVals[i*freq+j] = float32(i) / float32(math.Pow(theta, float64(j*2)/float64(dim)))
|
||||
}
|
||||
}
|
||||
freqs, err := ctx.Input().FromFloatSlice(freqVals, freq, maxGridSize)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create tensor from frequencies: %w", err))
|
||||
}
|
||||
|
||||
// Create position coordinates (y,x pairs) for the grid
|
||||
// In PyTorch: Equivalent to generating position ids with torch.arange()
|
||||
coords := make([]int32, 0, grid.Height*grid.Width*2)
|
||||
for y := range grid.Height {
|
||||
for x := range grid.Width {
|
||||
coords = append(coords, int32(y), int32(x))
|
||||
}
|
||||
}
|
||||
pos, err := ctx.Input().FromIntSlice(coords, 2, grid.Width, grid.Height)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to create tensor from positions: %w", err))
|
||||
}
|
||||
|
||||
// Reshape and permute positions to match spatial merging pattern
|
||||
pos = pos.Reshape(ctx, 2, grid.Width, merge, grid.Height/merge)
|
||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
pos = pos.Reshape(ctx, 2, merge, merge, grid.Width/merge*grid.Height/merge)
|
||||
pos = pos.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
pos = pos.Reshape(ctx, 2*merge*merge*grid.Width/merge*grid.Height/merge)
|
||||
|
||||
// Use position indices to look up corresponding frequency values
|
||||
positionalEmbedding := freqs.Rows(ctx, pos)
|
||||
positionalEmbedding = positionalEmbedding.Reshape(ctx, positionalEmbedding.Dim(0)*2, positionalEmbedding.Dim(1)/2)
|
||||
return positionalEmbedding
|
||||
}
|
||||
|
||||
// newVisionModel creates a new instance of the Qwen vision model
|
||||
func newVisionModel(c fs.Config) *VisionModel {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
hiddenSize := int(c.Uint("vision.embedding_length", 1280))
|
||||
numHeads := int(c.Uint("vision.attention.head_count", 16))
|
||||
numChannels := int(c.Uint("vision.num_channels", 3))
|
||||
eps := c.Float("vision.attention.layer_norm_epsilon", 1e-6)
|
||||
ropeTheta := c.Float("vision.rope.freq_base", 10000.0)
|
||||
spatialMergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
windowSize := int(c.Uint("vision.window_size", 112))
|
||||
fullAttnBlocks := c.Ints("qwen25vl.vision.fullatt_block_indexes", []int32{7, 15, 23, 31})
|
||||
temporalPatchSize := int(c.Uint("vision.temporal_patch_size", 2))
|
||||
|
||||
model := &VisionModel{
|
||||
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)),
|
||||
VisionModelOptions: &VisionModelOptions{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
headDim: hiddenSize / numHeads,
|
||||
patchSize: patchSize,
|
||||
numChannels: numChannels,
|
||||
eps: eps,
|
||||
ropeTheta: ropeTheta,
|
||||
spatialMergeSize: spatialMergeSize,
|
||||
windowSize: windowSize,
|
||||
temporalPatchSize: temporalPatchSize,
|
||||
fullAttnBlocks: fullAttnBlocks,
|
||||
},
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
184
model/models/qwen25vl/process_image.go
Normal file
184
model/models/qwen25vl/process_image.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"image"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/model/imageproc"
|
||||
)
|
||||
|
||||
// ImageProcessor contains configuration for the Qwen 2.5 VL image processing
|
||||
type ImageProcessor struct {
|
||||
numChannels int
|
||||
patchSize int
|
||||
temporalPatchSize int
|
||||
mergeSize int
|
||||
minPixels int
|
||||
maxPixels int
|
||||
factor int
|
||||
rescaleFactor float32
|
||||
imageMean []float32
|
||||
imageStd []float32
|
||||
}
|
||||
|
||||
// newImageProcessor creates a new image processor with default values
|
||||
func newImageProcessor(c fs.Config) ImageProcessor {
|
||||
patchSize := int(c.Uint("vision.patch_size", 14))
|
||||
mergeSize := int(c.Uint("vision.spatial_merge_size", 2))
|
||||
|
||||
return ImageProcessor{
|
||||
numChannels: int(c.Uint("vision.num_channels", 3)), // not set
|
||||
patchSize: patchSize,
|
||||
temporalPatchSize: 2,
|
||||
mergeSize: mergeSize,
|
||||
minPixels: 56 * 56,
|
||||
maxPixels: int(c.Uint("vision.max_pixels", 28*28*1280)), // 1MP limit
|
||||
factor: patchSize * mergeSize,
|
||||
rescaleFactor: 1.0 / 255.0,
|
||||
imageMean: imageproc.ClipDefaultMean[:],
|
||||
imageStd: imageproc.ClipDefaultSTD[:],
|
||||
}
|
||||
}
|
||||
|
||||
// SmartResize implements the smart resize algorithm
|
||||
func (p *ImageProcessor) SmartResize(height, width int) (int, int) {
|
||||
factor := p.factor
|
||||
|
||||
if height < factor || width < factor {
|
||||
panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor))
|
||||
} else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 {
|
||||
panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio))
|
||||
}
|
||||
|
||||
round := func(x float64) int { return int(math.RoundToEven(x)) }
|
||||
|
||||
hBar := round(float64(height)/float64(factor)) * factor
|
||||
wBar := round(float64(width)/float64(factor)) * factor
|
||||
|
||||
if hBar*wBar > p.maxPixels {
|
||||
beta := math.Sqrt(float64(height*width) / float64(p.maxPixels))
|
||||
|
||||
hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
|
||||
wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
|
||||
} else if hBar*wBar < p.minPixels {
|
||||
beta := math.Sqrt(float64(p.minPixels) / float64(height*width))
|
||||
|
||||
hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
|
||||
wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
|
||||
}
|
||||
|
||||
return hBar, wBar
|
||||
}
|
||||
|
||||
type Grid struct {
|
||||
Height int
|
||||
Width int
|
||||
Temporal int
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, *Grid, error) {
|
||||
origWidth := img.Bounds().Dx()
|
||||
origHeight := img.Bounds().Dy()
|
||||
|
||||
// Calculate smart resize dimensions
|
||||
resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth)
|
||||
|
||||
// Resize image using existing functions
|
||||
resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear)
|
||||
|
||||
normalizedPixels := imageproc.Normalize(
|
||||
resizedImg,
|
||||
[3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]},
|
||||
[3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]},
|
||||
true, // rescale
|
||||
true, // channelFirst
|
||||
)
|
||||
|
||||
// Calculate grid dimensions
|
||||
grid := &Grid{
|
||||
Height: resizedHeight / p.patchSize,
|
||||
Width: resizedWidth / p.patchSize,
|
||||
Temporal: 1, // For single images, temporal dimension is 1
|
||||
}
|
||||
|
||||
patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to create patches: %v", err)
|
||||
}
|
||||
|
||||
// Return patches and grid dimensions
|
||||
return patches, grid, nil
|
||||
}
|
||||
|
||||
func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) {
|
||||
channels := p.numChannels
|
||||
patchSize := p.patchSize
|
||||
mergeSize := p.mergeSize
|
||||
temporalPatchSize := p.temporalPatchSize
|
||||
|
||||
// Calculate output dimensions
|
||||
numPatches := grid.Temporal * grid.Height * grid.Width
|
||||
patchDim := channels * temporalPatchSize * patchSize * patchSize
|
||||
|
||||
result := make([]float32, numPatches*patchDim)
|
||||
patchIndex := 0
|
||||
|
||||
// Single temporal frame handling (copies to all frames)
|
||||
for range grid.Temporal {
|
||||
for h := 0; h < grid.Height; h += mergeSize {
|
||||
for w := 0; w < grid.Width; w += mergeSize {
|
||||
// Handle the 2x2 merged patches
|
||||
for mh := range mergeSize {
|
||||
for mw := range mergeSize {
|
||||
baseOffset := patchIndex * patchDim
|
||||
|
||||
// Extract patch data for first temporal frame
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
|
||||
for py := range patchSize {
|
||||
for px := range patchSize {
|
||||
// Calculate source pixel coordinates
|
||||
y := (h+mh)*patchSize + py
|
||||
x := (w+mw)*patchSize + px
|
||||
|
||||
// Source index in input tensor (CHW format)
|
||||
srcIdx := c*height*width + y*width + x
|
||||
|
||||
// Destination index in first temporal frame
|
||||
dstIdx := channelOffset + (py * patchSize) + px
|
||||
|
||||
if srcIdx < len(pixels) && dstIdx < len(result) {
|
||||
result[dstIdx] = pixels[srcIdx]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Copy first temporal frame to all other frames
|
||||
if temporalPatchSize > 1 {
|
||||
for c := range channels {
|
||||
channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
|
||||
firstFrameOffset := channelOffset
|
||||
frameSize := patchSize * patchSize
|
||||
|
||||
// Copy first frame to all other frames
|
||||
for tp := 1; tp < temporalPatchSize; tp++ {
|
||||
currentFrameOffset := channelOffset + (tp * frameSize)
|
||||
copy(result[currentFrameOffset:currentFrameOffset+frameSize],
|
||||
result[firstFrameOffset:firstFrameOffset+frameSize])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
patchIndex++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
Reference in New Issue
Block a user