Mirror of https://github.com/dogkeeper886/ollama37.git, synced 2025-12-10 07:46:59 +00:00
Sync with upstream ollama/ollama and restore Tesla K80 (compute 3.7) support
This commit represents a complete rework after pulling the latest changes from the official ollama/ollama repository and re-applying the Tesla K80 compatibility patches.

## Key Changes

### CUDA Compute Capability 3.7 Support (Tesla K80)
- Added sm_37 (compute 3.7) to CMAKE_CUDA_ARCHITECTURES in CMakeLists.txt
- Updated CMakePresets.json to include compute 3.7 in the "CUDA 11" preset
- Using 37-virtual (PTX with JIT compilation) for maximum compatibility (see the CMake sketch below)

### Legacy Toolchain Compatibility
- **NVIDIA Driver**: 470.256.02 (last version supporting Kepler/K80)
- **CUDA Version**: 11.4.4 (last CUDA 11.x release supporting compute 3.7)
- **GCC Version**: 10.5.0 (required by CUDA 11.4's host_config.h)

### CPU Architecture Trade-offs
Due to the GCC 10.5 limitation, newer CPU optimizations are sacrificed:
- Alderlake CPU variant enabled WITHOUT AVX_VNNI (which requires GCC 11+)
- Still supports: SSE4.2, AVX, F16C, AVX2, BMI2, FMA
- Performance impact: ~3-7% on newer CPUs (acceptable in exchange for K80 compatibility)

### Build System Updates
- Modified ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt for compute 3.7
- Added the -Wno-deprecated-gpu-targets flag to suppress warnings
- Updated ml/backend/ggml/ggml/src/CMakeLists.txt for Alderlake without AVX_VNNI (see the second sketch below)

### Upstream Sync
Merged the latest llama.cpp changes, including:
- Enhanced KV cache management with ISWA and hybrid memory support
- Improved multi-modal support (mtmd framework)
- New model architectures (Gemma3, Llama4, Qwen3, etc.)
- GPU backend improvements for CUDA, Metal, and ROCm
- Updated quantization support and GGUF format handling

### Documentation
- Updated CLAUDE.md with comprehensive build instructions
- Documented toolchain constraints and CPU architecture trade-offs
- Removed outdated CI/CD workflows (tesla-k80-*.yml)
- Cleaned up temporary development artifacts

## Rationale

This fork maintains Tesla K80 GPU support (compute 3.7), which was dropped from official Ollama because of its legacy driver and CUDA requirements. The toolchain constraints form a single chain of dependencies:

- K80 → Driver 470 → CUDA 11.4 → GCC 10 → no AVX_VNNI

We accept the loss of cutting-edge CPU optimizations to enable running modern LLMs on legacy but still capable Tesla K80 hardware (12 GB VRAM per GPU).

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
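The CUDA architecture change above boils down to a few lines of CMake. The snippet below is a minimal, illustrative sketch rather than the literal contents of this repo's CMakeLists.txt or CMakePresets.json: the surrounding architecture list and variable placement are assumptions, while CMAKE_CUDA_ARCHITECTURES, the 37-virtual entry, and -Wno-deprecated-gpu-targets come from the changes described above.

```cmake
# Sketch of the compute 3.7 build setting (illustrative; the real files differ).
# "37-virtual" emits compute_37 PTX only, which the 470-series driver
# JIT-compiles for the K80's sm_37 the first time a kernel is loaded.
set(CMAKE_CUDA_ARCHITECTURES "37-virtual;50;61;70;75;80")

# CUDA 11.4's nvcc warns that Kepler targets are deprecated; keep the build log clean.
string(APPEND CMAKE_CUDA_FLAGS " -Wno-deprecated-gpu-targets")
```

Relying on PTX plus JIT costs a one-time compilation when a kernel first runs on the K80, but it avoids baking sm_37 cubins into the binary and keeps the same build usable on newer GPUs.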
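The Alderlake-without-AVX_VNNI trade-off can likewise be pictured as a compiler-flag decision. ggml's actual CMake defines its CPU backend variants through its own helper macros, so the snippet below is only a sketch of the effect: the ALDERLAKE_FLAGS variable is hypothetical, while the instruction-set list and the GCC 11 requirement for -mavxvnni follow the commit message.

```cmake
# Hypothetical sketch of the CPU-variant flags (ggml's real CMake uses its own macros).
# Everything GCC 10.5 can emit stays enabled:
set(ALDERLAKE_FLAGS -msse4.2 -mavx -mf16c -mavx2 -mbmi2 -mfma)

# -mavxvnni only exists in GCC 11+; with the CUDA 11.4 / GCC 10.5 toolchain
# this branch is never taken, so the variant builds without AVX_VNNI.
if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_GREATER_EQUAL 11)
    list(APPEND ALDERLAKE_FLAGS -mavxvnni)
endif()

add_compile_options(${ALDERLAKE_FLAGS})
```

The ~3-7% figure above mainly reflects losing VNNI's fused int8 dot-product instructions on CPUs that have them; the AVX2/FMA code paths still apply.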
@@ -12,7 +12,6 @@ import (
"net/http"
"os"
"regexp"
"runtime"
"strconv"
"strings"
"sync"
@@ -80,13 +79,16 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool

// shift if context window is exceeded
shift bool

doneReason llm.DoneReason

// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
numDecoded int
numPromptInputs int
processingDuration time.Duration
generationDuration time.Duration
numDecoded int
numPromptInputs int
}

type NewSequenceParams struct {
@@ -95,13 +97,15 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
shift bool
truncate bool
}

var errorInputTooLong = errors.New("the input length exceeds the context length")

func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()

startTime := time.Now()

inputs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -122,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe

if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}

newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)

@@ -143,18 +151,18 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}

return &Sequence{
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
inputs: inputs,
numPromptInputs: len(inputs),
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shift: params.shift,
}, nil
}
@@ -201,13 +209,19 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
return nil, fmt.Errorf("invalid image index: %d", n)
}

embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data)
chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data)
if err != nil {
return nil, err
}

for _, e := range embed {
inputs = append(inputs, input{embed: e})
for _, c := range chunks {
if len(c.Embed) != 0 {
inputs = append(inputs, input{embed: c.Embed})
} else {
for _, t := range c.Tokens {
inputs = append(inputs, input{token: t})
}
}
}
}
}
@@ -216,6 +230,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error)
}

type Server struct {
// modelPath is the location of the model to be loaded
modelPath string

// loadMu prevents more than one load attempt from occurring at a time
loadMu sync.Mutex

// is the server ready to process requests?
// protects access to model and image
ready sync.WaitGroup
@@ -364,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
defer s.mu.Unlock()

var batch *llama.Batch
var numOutputs int

seqIdx := s.nextSeq - 1
for range s.seqs {
@@ -383,6 +404,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
if !seq.shift {
s.removeSequence(seqIdx, llm.DoneReasonLength)
break
}

err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -421,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
break
}

batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
output := i+1 == len(seq.inputs)
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
if output {
numOutputs++
}

seq.pendingInputs = append(seq.pendingInputs, input)
seq.iBatch = batch.NumTokens() - 1
}
@@ -433,11 +464,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}

err := s.lc.Decode(batch)
if err != nil {
t := time.Now()
if err := s.lc.Decode(batch); err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}

if numOutputs > 0 {
s.lc.Synchronize()
}

for i, seq := range s.seqs {
if seq == nil {
continue
@@ -451,12 +486,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

// don't sample prompt processing
if len(seq.inputs) != 0 {
seq.processingDuration += time.Since(t)
continue
}

seq.numDecoded += 1
if seq.numDecoded == 1 {
seq.startGenerationTime = time.Now()
seq.numDecoded++
if seq.numDecoded > 1 {
seq.generationDuration += time.Since(t)
} else {
seq.processingDuration += time.Since(t)
}

// if done processing the prompt, generate an embedding and return
@@ -578,8 +616,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -641,9 +685,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
PromptEvalDuration: seq.processingDuration,
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
EvalDuration: seq.generationDuration,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -663,7 +707,14 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {

w.Header().Set("Content-Type", "application/json")

seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
embedding: true,

// TODO (jmorganca): this should be provided by the server via the
// request options and truncated here in the runner, instead of relying on
// the server's truncate logic
truncate: true,
})
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
@@ -723,21 +774,12 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
}
}

type multiLPath []string

func (m *multiLPath) Set(value string) error {
*m = append(*m, value)
return nil
}

func (m *multiLPath) String() string {
return strings.Join(*m, ", ")
}

// loadModel allocates memory based on the given parameters and loads the weights. The
// memory allocated is worst case for text models but not for vision.
func (s *Server) loadModel(
params llama.ModelParams,
mpath string,
lpath multiLPath,
lpath []string,
ppath string,
kvSize int,
kvCacheType string,
@@ -757,12 +799,10 @@ func (s *Server) loadModel(
panic(err)
}

if lpath.String() != "" {
for _, path := range lpath {
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
if err != nil {
panic(err)
}
for _, path := range lpath {
err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads)
if err != nil {
panic(err)
}
}
@@ -783,26 +823,81 @@ func (s *Server) loadModel(
s.ready.Done()
}

// load is the handler called by the Ollama server to process different
// load operations
func (s *Server) load(w http.ResponseWriter, r *http.Request) {
s.loadMu.Lock()
defer s.loadMu.Unlock()

w.Header().Set("Content-Type", "application/json")

if s.status != llm.ServerStatusLaunched {
http.Error(w, "model already loaded", http.StatusInternalServerError)
return
}

var req llm.LoadRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}

slog.Info("load", "request", req)

switch req.Operation {
// LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response

case llm.LoadOperationCommit:
s.batchSize = req.BatchSize
s.parallel = req.Parallel
s.seqs = make([]*Sequence, s.parallel)
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))

gpuIDs := llama.EnumerateGPUs()
tensorSplit := make([]float32, len(gpuIDs))
numGPU := 0
for i := range gpuIDs {
for _, layers := range req.GPULayers {
if gpuIDs[i] == layers.DeviceID {
tensorSplit[i] = float32(len(layers.Layers))
numGPU += len(layers.Layers)
}
}
}

params := llama.ModelParams{
NumGpuLayers: numGPU,
MainGpu: req.MainGPU,
UseMmap: req.UseMmap && len(req.LoraPath) == 0,
TensorSplit: tensorSplit,
Progress: func(progress float32) {
s.progress = progress
},
}

s.status = llm.ServerStatusLoadingModel
go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache)

case llm.LoadOperationClose:
// No-op for us
if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
}
return
}

resp := llm.LoadResponse{Success: true}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
return
}
}

func Execute(args []string) error {
fs := flag.NewFlagSet("runner", flag.ExitOnError)
mpath := fs.String("model", "", "Path to model binary file")
ppath := fs.String("mmproj", "", "Path to projector binary file")
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := fs.Int("batch-size", 512, "Batch size")
nGpuLayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGpu := fs.Int("main-gpu", 0, "Main GPU")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
_ = fs.Bool("verbose", false, "verbose output (default: disabled)")
noMmap := fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")

var lpaths multiLPath
fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)")

fs.Usage = func() {
fmt.Fprintf(fs.Output(), "Runner usage\n")
@@ -817,35 +912,11 @@ func Execute(args []string) error {
llama.BackendInit()

server := &Server{
batchSize: *batchSize,
parallel: *parallel,
seqs: make([]*Sequence, *parallel),
seqsSem: semaphore.NewWeighted(int64(*parallel)),
status: llm.ServerStatusLoadingModel,
}

var tensorSplitFloats []float32
if *tensorSplit != "" {
splits := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, len(splits))
for i, s := range splits {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats[i] = float32(f)
}
}

params := llama.ModelParams{
NumGpuLayers: *nGpuLayers,
MainGpu: *mainGpu,
UseMmap: !*noMmap && lpaths.String() == "",
TensorSplit: tensorSplitFloats,
Progress: func(progress float32) {
server.progress = progress
},
modelPath: *mpath,
status: llm.ServerStatusLaunched,
}

server.ready.Add(1)
go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)

server.cond = sync.NewCond(&server.mu)

@@ -863,6 +934,7 @@ func Execute(args []string) error {
defer listener.Close()

mux := http.NewServeMux()
mux.HandleFunc("POST /load", server.load)
mux.HandleFunc("/embedding", server.embeddings)
mux.HandleFunc("/completion", server.completion)
mux.HandleFunc("/health", server.health)