Mirror of https://github.com/dogkeeper886/ollama37.git (synced 2025-12-18 11:47:07 +00:00)
This commit represents a complete rework after pulling the latest changes from the official ollama/ollama repository and re-applying Tesla K80 compatibility patches.

## Key Changes

### CUDA Compute Capability 3.7 Support (Tesla K80)
- Added sm_37 (compute 3.7) to CMAKE_CUDA_ARCHITECTURES in CMakeLists.txt (sketched below)
- Updated CMakePresets.json to include compute 3.7 in the "CUDA 11" preset
- Using 37-virtual (PTX with JIT compilation) for maximum compatibility

### Legacy Toolchain Compatibility
- **NVIDIA Driver**: 470.256.02 (last version supporting Kepler/K80)
- **CUDA Version**: 11.4.4 (last CUDA 11.x release supporting compute 3.7)
- **GCC Version**: 10.5.0 (required by CUDA 11.4's host_config.h)

### CPU Architecture Trade-offs
Due to the GCC 10.5 limitation, newer CPU optimizations are sacrificed:
- Alderlake CPU variant enabled WITHOUT AVX_VNNI (requires GCC 11+)
- Still supports: SSE4.2, AVX, F16C, AVX2, BMI2, FMA
- Performance impact: ~3-7% on newer CPUs (acceptable for K80 compatibility)

### Build System Updates
- Modified ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt for compute 3.7
- Added the -Wno-deprecated-gpu-targets flag to suppress warnings
- Updated ml/backend/ggml/ggml/src/CMakeLists.txt for Alderlake without AVX_VNNI

### Upstream Sync
Merged the latest llama.cpp changes, including:
- Enhanced KV cache management with ISWA and hybrid memory support
- Improved multimodal support (mtmd framework)
- New model architectures (Gemma3, Llama4, Qwen3, etc.)
- GPU backend improvements for CUDA, Metal, and ROCm
- Updated quantization support and GGUF format handling

### Documentation
- Updated CLAUDE.md with comprehensive build instructions
- Documented toolchain constraints and CPU architecture trade-offs
- Removed outdated CI/CD workflows (tesla-k80-*.yml)
- Cleaned up temporary development artifacts

## Rationale
This fork maintains Tesla K80 GPU support (compute 3.7), which was dropped from official Ollama because of its legacy driver/CUDA requirements. Those requirements form a rigid chain:
- K80 → Driver 470 → CUDA 11.4 → GCC 10 → no AVX_VNNI

We accept the loss of cutting-edge CPU optimizations to enable running modern LLMs on legacy but still capable Tesla K80 hardware (12 GB VRAM per GPU).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
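For readers who want a concrete picture of the compute 3.7 change described above, the snippet below is a minimal, hypothetical CMake sketch of that kind of edit. It is not the literal contents of the fork's ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt; the real file's architecture lists, variable names, and guards may differ.

```cmake
# Hypothetical sketch, assuming a CUDA 11.x toolchain and CMake >= 3.18;
# the fork's actual CMakeLists.txt may structure this differently.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.0)
  # 37-virtual embeds PTX for compute 3.7, which the Tesla K80 driver JIT-compiles at load time.
  list(APPEND CMAKE_CUDA_ARCHITECTURES 37-virtual)
  # Keep nvcc quiet about the deprecated Kepler target.
  add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:-Wno-deprecated-gpu-targets>")
endif()
```

Choosing the virtual (PTX) form rather than a real sm_37 binary trades a one-time JIT compilation cost at load time for broader compatibility with the CUDA 11.4 toolchain, matching the "37-virtual" choice stated in the commit message.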
207 lines · 5.8 KiB · Go
package qwen3vl

import (
    "bytes"
    "image"
    "slices"

    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
    "github.com/ollama/ollama/model/input"
)

type Model struct {
    model.Base
    model.TextProcessor

    *TextModel
    *VisionModel `gguf:"v"`

    ImageProcessor

    positionCache []int32
}

func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
    if len(m.VisionModel.Layers) == 0 {
        return nil, model.ErrNoVisionModel
    }

    img, _, err := image.Decode(bytes.NewReader(multimodalData))
    if err != nil {
        return nil, err
    }

    pixelValues, grid, err := m.ProcessImage(ctx, img)
    if err != nil {
        return nil, err
    }

    // Calculate tensor dimensions
    visionOutputs, deepstackVisualEmbeds := m.VisionModel.Forward(ctx, pixelValues, grid)
    mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
    for i := range deepstackVisualEmbeds {
        mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
    }

    return mm, nil
}

var (
    tokenVision      int32 = 151655
    tokenVisionStart int32 = 151652
    tokenVisionEnd   int32 = 151653
)

type modelInput struct {
    *input.Input
    position int32
}

// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
    m.positionCache = m.positionCache[:0]
    return slices.Collect(func(yield func(*input.Input) bool) {
        for i := range inputs {
            s := []modelInput{{Input: inputs[i]}}
            if mm := inputs[i].Multimodal; mm != nil {
                t := mm[0].Tensor
                s = slices.Repeat([]modelInput{
                    {
                        position: int32(i + 1),
                        Input:    &input.Input{Token: tokenVision},
                    },
                }, t.Dim(1)+1+1)

                s[0] = modelInput{
                    Input:    &input.Input{Token: tokenVisionStart},
                    position: int32(i),
                }

                s[len(s)-1] = modelInput{
                    Input:    &input.Input{Token: tokenVisionEnd},
                    position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
                }

                s[1] = modelInput{
                    Input: &input.Input{
                        Token:          tokenVision,
                        Multimodal:     inputs[i].Multimodal,
                        MultimodalHash: inputs[i].MultimodalHash,
                        SameBatch:      t.Dim(1),
                    },
                    position: int32(i + 1),
                }
            }

            for _, e := range s {
                position := e.position
                if position == 0 && len(m.positionCache) > 0 {
                    position = m.positionCache[len(m.positionCache)-1] + 1
                }

                m.positionCache = append(m.positionCache, position)
                if !yield(e.Input) {
                    return
                }
            }
        }
    }), nil
}

func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
    // ggml mrope requires 4 positions per token: [time, height, width, extra]
    positionSlice := slices.Collect(makeSlice2D[int32](4, len(batch.Positions)))
    for i, id := range batch.Positions {
        if id < int32(len(m.positionCache)) {
            id = m.positionCache[id]
        } else if len(m.positionCache) > 0 {
            id = id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1
        }

        positionSlice[0][i] = id
        positionSlice[1][i] = id
        positionSlice[2][i] = id
        // positionSlice[3] is intentionally left as zeros
    }

    hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx)

    var deepstackVisualEmbeds []ml.Tensor
    for _, mi := range batch.Multimodal {
        visionOutputs := mi.Multimodal[0].Tensor
        ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

        if grid, ok := mi.Multimodal[0].Data.(*Grid); ok {
            for i := range visionOutputs.Dim(1) {
                w := grid.Width / m.spatialMergeSize
                positionSlice[1][mi.Index+i] += int32(i / w)
                positionSlice[2][mi.Index+i] += int32(i % w)
            }
        }

        deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:]))
        for i, mm := range mi.Multimodal[1:] {
            deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...)
            ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1))))
        }
    }

    positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0])*len(positionSlice))
    for i, layer := range m.TextModel.Layers {
        if m.Cache != nil {
            m.Cache.SetLayer(i)
        }

        var outputs ml.Tensor
        if i == len(m.TextModel.Layers)-1 {
            outputs = batch.Outputs
        }

        hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)
        if i < len(deepstackVisualEmbeds) {
            hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i])
        }
    }

    hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, 1e-06)
    return m.Output.Forward(ctx, hiddenStates), nil
}

func New(c fs.Config) (model.Model, error) {
    m := Model{
        TextProcessor: model.NewBytePairEncoding(
            &model.Vocabulary{
                Values: c.Strings("tokenizer.ggml.tokens"),
                Types:  c.Ints("tokenizer.ggml.token_type"),
                Merges: c.Strings("tokenizer.ggml.merges"),
                AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
                BOS:    []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
                AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
                EOS: append(
                    []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
                    c.Ints("tokenizer.ggml.eos_token_ids")...,
                ),
            },
            `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
        ),
        TextModel:      newTextModel(c),
        VisionModel:    newVisionModel(c),
        ImageProcessor: newImageProcessor(c),
    }

    m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
        m.positionCache = nil
        positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1)
        return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil
    })
    return &m, nil
}

func init() {
    model.Register("qwen3vl", New)
    model.Register("qwen3vlmoe", New)
}