mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-18 19:56:59 +00:00
This commit represents a complete rework after pulling the latest changes from official ollama/ollama repository and re-applying Tesla K80 compatibility patches. ## Key Changes ### CUDA Compute Capability 3.7 Support (Tesla K80) - Added sm_37 (compute 3.7) to CMAKE_CUDA_ARCHITECTURES in CMakeLists.txt - Updated CMakePresets.json to include compute 3.7 in "CUDA 11" preset - Using 37-virtual (PTX with JIT compilation) for maximum compatibility ### Legacy Toolchain Compatibility - **NVIDIA Driver**: 470.256.02 (last version supporting Kepler/K80) - **CUDA Version**: 11.4.4 (last CUDA 11.x supporting compute 3.7) - **GCC Version**: 10.5.0 (required by CUDA 11.4 host_config.h) ### CPU Architecture Trade-offs Due to GCC 10.5 limitation, sacrificed newer CPU optimizations: - Alderlake CPU variant enabled WITHOUT AVX_VNNI (requires GCC 11+) - Still supports: SSE4.2, AVX, F16C, AVX2, BMI2, FMA - Performance impact: ~3-7% on newer CPUs (acceptable for K80 compatibility) ### Build System Updates - Modified ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt for compute 3.7 - Added -Wno-deprecated-gpu-targets flag to suppress warnings - Updated ml/backend/ggml/ggml/src/CMakeLists.txt for Alderlake without AVX_VNNI ### Upstream Sync Merged latest llama.cpp changes including: - Enhanced KV cache management with ISWA and hybrid memory support - Improved multi-modal support (mtmd framework) - New model architectures (Gemma3, Llama4, Qwen3, etc.) - GPU backend improvements for CUDA, Metal, and ROCm - Updated quantization support and GGUF format handling ### Documentation - Updated CLAUDE.md with comprehensive build instructions - Documented toolchain constraints and CPU architecture trade-offs - Removed outdated CI/CD workflows (tesla-k80-*.yml) - Cleaned up temporary development artifacts ## Rationale This fork maintains Tesla K80 GPU support (compute 3.7) which was dropped in official Ollama due to legacy driver/CUDA requirements. The toolchain constraint creates a deadlock: - K80 → Driver 470 → CUDA 11.4 → GCC 10 → No AVX_VNNI We accept the loss of cutting-edge CPU optimizations to enable running modern LLMs on legacy but still capable Tesla K80 hardware (12GB VRAM per GPU). 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
118 lines
3.3 KiB
Go
118 lines
3.3 KiB
Go
package mllama
|
|
|
|
import (
|
|
"bytes"
|
|
"image"
|
|
"slices"
|
|
|
|
"github.com/ollama/ollama/fs"
|
|
"github.com/ollama/ollama/kvcache"
|
|
"github.com/ollama/ollama/ml"
|
|
"github.com/ollama/ollama/ml/nn"
|
|
"github.com/ollama/ollama/model"
|
|
"github.com/ollama/ollama/model/input"
|
|
)
|
|
|
|
type Model struct {
|
|
model.Base
|
|
model.BytePairEncoding
|
|
|
|
*VisionModel `gguf:"v"`
|
|
*TextModel
|
|
|
|
Projector *nn.Linear `gguf:"mm.0"`
|
|
|
|
ImageProcessor
|
|
}
|
|
|
|
const (
|
|
crossAttentionLayer = iota
|
|
selfAttentionLayer
|
|
)
|
|
|
|
func New(c fs.Config) (model.Model, error) {
|
|
m := Model{
|
|
BytePairEncoding: model.NewBytePairEncoding(
|
|
&model.Vocabulary{
|
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
|
EOS: append(
|
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
|
),
|
|
},
|
|
`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
|
|
),
|
|
ImageProcessor: newImageProcessor(c),
|
|
VisionModel: newVisionModel(c),
|
|
TextModel: newTextModel(c),
|
|
}
|
|
|
|
encoderCache := kvcache.NewEncoderCache()
|
|
encoderCache.SetConfig(ml.CacheConfig{})
|
|
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
|
|
|
return &m, nil
|
|
}
|
|
|
|
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
|
|
if len(m.VisionModel.Transformer.Layers) == 0 || len(m.GlobalTransformer.Layers) == 0 {
|
|
return nil, model.ErrNoVisionModel
|
|
}
|
|
|
|
image, _, err := image.Decode(bytes.NewReader(multimodalData))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
f32s, ratio, err := m.ImageProcessor.ProcessImage(image)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ratio.numTiles() < m.maxNumTiles {
|
|
// Pad tiles to maxNumTiles
|
|
f32s = slices.Grow(f32s, m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles)
|
|
f32s = f32s[:m.imageSize*m.imageSize*m.numChannels*m.maxNumTiles]
|
|
}
|
|
|
|
pixelValues := ctx.Input().FromFloats(f32s, m.imageSize, m.imageSize, m.numChannels, m.maxNumTiles)
|
|
aspectRatio := ctx.Input().FromInts([]int32{int32(ratio.rank)}, 1)
|
|
|
|
positionIDs := ctx.Arange(0, 1601, 1, ml.DTypeI32)
|
|
crossAttentionStates := m.VisionModel.Forward(ctx, pixelValues, positionIDs, aspectRatio)
|
|
projectedOutputs := m.Projector.Forward(ctx, crossAttentionStates)
|
|
|
|
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
|
|
}
|
|
|
|
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
|
for i := range inputs {
|
|
if inputs[i].Multimodal != nil {
|
|
inputs[i].Token = 128256 // <|image|>
|
|
}
|
|
}
|
|
|
|
return inputs, nil
|
|
}
|
|
|
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
|
var crossAttentionStates ml.Tensor
|
|
if len(batch.Multimodal) > 0 {
|
|
crossAttentionStates = batch.Multimodal[len(batch.Multimodal)-1].Multimodal[0].Tensor
|
|
}
|
|
|
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
|
|
|
// TODO: attention mask, cross attention mask
|
|
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
|
|
}
|
|
|
|
func init() {
|
|
model.Register("mllama", New)
|
|
}
|