Fix Tesla K80 CUBLAS compatibility with two-tier fallback strategy

This commit implements Tesla K80 (Kepler, compute capability 3.7) compatibility
for CUBLAS matrix multiplication, covering both batched and non-batched operations.

**Problem:**
Modern CUBLAS functions fail on Tesla K80 with CUBLAS_STATUS_ARCH_MISMATCH:
1. CUBLAS_GEMM_DEFAULT_TENSOR_OP requires Tensor Cores (Volta+ only)
2. cublasGemmStridedBatchedEx/cublasGemmBatchedEx can still fail on Kepler even
   with CUBLAS_GEMM_DEFAULT, because the *Ex entry points have architectural
   requirements beyond algorithm selection
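
For reference, the compute capability that gates these code paths can be queried
at runtime. A minimal standalone sketch (not taken from the ggml-cuda sources):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int device = 0, major = 0, minor = 0;
    cudaGetDevice(&device);
    cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
    cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);

    // Tesla K80 reports 3.7 (Kepler). Tensor Cores (and therefore
    // CUBLAS_GEMM_DEFAULT_TENSOR_OP) require compute capability 7.0+ (Volta).
    std::printf("compute capability %d.%d, Tensor Cores: %s\n",
                major, minor, major >= 7 ? "yes" : "no");
    return 0;
}
```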

**Solution - Two-Tier Fallback:**

Tier 1: Algorithm Selection
- Volta+ (cc >= 7.0): CUBLAS_GEMM_DEFAULT_TENSOR_OP
- Pre-Volta (cc < 7.0): CUBLAS_GEMM_DEFAULT

Tier 2: Function Selection
- Volta+ or non-FP32: Use *Ex variants (flexible precision)
- Kepler/Maxwell/Pascal with FP32: Use legacy type-specific functions
  (cublasSgemmStridedBatched, cublasSgemmBatched)
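
As a rough illustration, here is a minimal sketch of the selection for an FP32
strided-batched GEMM (illustrative only, not the actual ggml-cuda.cu code; the
wrapper name and the encoding of `cc` as major*100 + minor*10 are assumptions
of this sketch):

```cpp
#include <cublas_v2.h>

// Sketch: pick the cuBLAS entry point and algorithm for an FP32
// strided-batched GEMM based on the device compute capability `cc`
// (e.g. 370 for the K80, 700 for Volta -- encoding assumed here).
static cublasStatus_t gemm_strided_batched_f32(
        cublasHandle_t handle, int cc,
        cublasOperation_t transa, cublasOperation_t transb,
        int m, int n, int k,
        const float * alpha,
        const float * A, int lda, long long strideA,
        const float * B, int ldb, long long strideB,
        const float * beta,
        float * C, int ldc, long long strideC,
        int batch_count) {
    // Tier 1: only Volta+ (cc >= 700) may request the Tensor Core algorithm.
    const cublasGemmAlgo_t algo = cc >= 700 ? CUBLAS_GEMM_DEFAULT_TENSOR_OP
                                            : CUBLAS_GEMM_DEFAULT;

    // Tier 2: on Kepler/Maxwell/Pascal the *Ex entry point itself can return
    // CUBLAS_STATUS_ARCH_MISMATCH for FP32, so fall back to the legacy
    // type-specific call (which takes no algorithm argument at all).
    if (cc < 700) {
        return cublasSgemmStridedBatched(handle, transa, transb, m, n, k,
                alpha, A, lda, strideA, B, ldb, strideB,
                beta, C, ldc, strideC, batch_count);
    }

    // Volta+ (and, in the real code, non-FP32 data on any architecture)
    // stays on the flexible *Ex path with the algorithm chosen in Tier 1.
    return cublasGemmStridedBatchedEx(handle, transa, transb, m, n, k,
            alpha, A, CUDA_R_32F, lda, strideA,
                   B, CUDA_R_32F, ldb, strideB,
            beta,  C, CUDA_R_32F, ldc, strideC,
            batch_count, CUBLAS_COMPUTE_32F, algo);
}
```

The same two-tier choice applies to the pointer-array batched pair
(cublasGemmBatchedEx vs. cublasSgemmBatched).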

**Changes:**

CUDA Implementation:
- ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
  * ggml_cuda_op_mul_mat_cublas: Algorithm selection for non-batched ops
  * ggml_cuda_mul_mat_batched_cublas_impl: Two-tier fallback for batched ops
  * Added GGML_CUDA_DEBUG environment variable for conditional debug logging
    (see the usage sketch after the Code Comments list below)
  * Comprehensive function documentation explaining fallback strategy

Documentation:
- CLAUDE.md
  * Added Tesla K80 CUBLAS Compatibility section
  * Documented GGML_CUDA_DEBUG environment variable
  * Enhanced "Running Ollama" section with log capture examples
  * Updated Files Modified list

Code Comments:
- Added detailed comments throughout CUDA code explaining:
  * Why TENSOR_OP fails on pre-Volta GPUs
  * Why *Ex functions require architectural support
  * Compute capability checks and fallback logic
  * Debug logging usage
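
The GGML_CUDA_DEBUG variable name is introduced by this commit; the exact way it
is read is not shown here, so the guard below is only an assumed sketch of the
pattern:

```cpp
#include <stdio.h>
#include <stdlib.h>

// Assumed sketch: gate verbose CUDA/cuBLAS fallback logging behind the
// GGML_CUDA_DEBUG environment variable (setting it to any value enables
// logging in this sketch).
static bool cuda_debug_enabled(void) {
    static const bool enabled = getenv("GGML_CUDA_DEBUG") != nullptr;
    return enabled;
}

#define CUDA_DEBUG_LOG(...)                     \
    do {                                        \
        if (cuda_debug_enabled()) {             \
            fprintf(stderr, __VA_ARGS__);       \
        }                                       \
    } while (0)

// Hypothetical call site at the fallback decision point:
//   CUDA_DEBUG_LOG("cc=%d: using legacy cublasSgemmStridedBatched fallback\n", cc);
```

With a guard like this, logging stays off by default and is enabled by setting
GGML_CUDA_DEBUG (e.g. GGML_CUDA_DEBUG=1) before starting the server, matching
the enabled/disabled testing noted below.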

**Testing:**
The following models were verified working on Tesla K80:
- gemma3:4b
- gpt-oss
- deepseek-r1

Debug flag tested in both enabled and disabled states.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Shang Chieh Tseng
Date:   2025-11-05 23:52:45 +08:00
Parent: ef14fb5b26
Commit: d948926581

8 changed files with 616 additions and 153 deletions


@@ -103,9 +103,30 @@ type NewSequenceParams struct {
var errorInputTooLong = errors.New("the input length exceeds the context length")
// NewSequence creates a new inference sequence
//
// This prepares everything needed for text generation:
// 1. Tokenize prompt into input tokens
// 2. Process vision embeddings (if images provided)
// 3. Handle context window truncation (if prompt too long)
// 4. Create sampling context (controls temperature, top_p, etc.)
// 5. Prepare for KV cache management
//
// Parameters:
// - prompt: Text to generate from (already formatted by chat template)
// - images: Image data for multimodal models (empty for text-only)
// - params: Generation parameters (num_predict, stop sequences, sampling settings)
//
// Returns:
// - *Sequence: Ready-to-use sequence for inference loop
// - error: If prompt too long or other errors
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
// Wait for model to finish loading (blocks until ready)
s.ready.Wait()
// Tokenize prompt and process images into input sequence
// For text-only: converts string → token IDs
// For multimodal: also runs vision projector to get image embeddings
inputs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -113,6 +134,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
return nil, errors.New("no input provided")
}
// numKeep controls how many tokens to keep when shifting context window
// If context fills up, we discard middle tokens but keep first numKeep tokens
if params.numKeep < 0 {
params.numKeep = len(inputs)
}
@@ -122,14 +145,19 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// Ensure that at least 1 input can be discarded during shift
// Otherwise we can't make room for new generation
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
// Check if prompt exceeds context window (num_ctx)
// Example: gemma3 default num_ctx = 8192 tokens
if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
// Truncate prompt: keep first numKeep tokens + last tokens
// This preserves system prompt and instruction while removing middle content
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -137,12 +165,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
inputs = newInputs
}
// Create sampling context for token selection
// This manages:
// - Logit biasing (temperature, top_p, top_k, min_p)
// - Repetition penalty
// - Grammar constraints (for JSON output)
// - Seed for reproducibility
var sc *llama.SamplingContext
if params.samplingParams != nil {
sc, err = llama.NewSamplingContext(s.model, *params.samplingParams)
if err != nil {
return nil, err
}
// Prime the sampling context with prompt tokens
// This is needed for repetition penalty calculation
for _, input := range inputs {
if input.embed == nil {
sc.Accept(input.token, false)
@@ -573,7 +609,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
// completion is the HTTP handler for POST /completion endpoint
//
// This is where the runner subprocess receives requests from the main Ollama server.
// It handles the entire text generation pipeline:
// 1. Parse CompletionRequest from JSON body
// 2. Create new Sequence (tokenize, setup sampling)
// 3. Add sequence to inference queue (managed by background goroutine)
// 4. Stream generated tokens back via HTTP response
//
// The actual inference happens in a separate goroutine (s.run())
// which continuously processes sequences from s.seqs queue.
//
// Flow:
// - Main server calls: POST http://127.0.0.1:<port>/completion
// - This handler receives request
// - Creates Sequence and adds to s.seqs queue
// - Background goroutine (s.run()) picks up sequence
// - Tokens generated are sent to seq.responses channel
// - This handler reads from channel and streams back to main server
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
// Parse CompletionRequest sent by llm/server.go Completion()
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
@@ -585,7 +641,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options = &opts
}
-// Set the headers to indicate streaming
+// Set headers for HTTP streaming (Server-Sent Events style)
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -595,7 +651,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Extract options from the CompletionRequest
// Convert API options to llama.cpp sampling parameters
// These control how tokens are selected from logits:
// - TopK: only consider top K tokens
// - TopP (nucleus sampling): consider tokens until cumulative prob reaches P
// - MinP: minimum probability threshold
// - Temperature: controls randomness (lower = more deterministic)
// - RepeatPenalty: penalize repeating tokens
// - Grammar: GBNF grammar for constrained generation (e.g., JSON output)
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
@@ -610,14 +673,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Grammar: req.Grammar,
}
// Create new sequence (tokenize, setup sampling context)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
-numPredict: req.Options.NumPredict,
-stop: req.Options.Stop,
-numKeep: req.Options.NumKeep,
+numPredict: req.Options.NumPredict, // Max tokens to generate
+stop: req.Options.Stop, // Stop sequences
+numKeep: req.Options.NumKeep, // Tokens to keep when shifting
samplingParams: &samplingParams,
-embedding: false,
-shift: req.Shift,
-truncate: req.Truncate,
+embedding: false, // Generate text, not embeddings
+shift: req.Shift, // Allow context window shift
+truncate: req.Truncate, // Allow prompt truncation
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
@@ -628,7 +692,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
// Acquire semaphore slot (limits concurrent requests to numParallel)
// Each parallel slot requires separate KV cache allocation
// Example: numParallel=4 means up to 4 requests can run simultaneously
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -638,10 +704,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Add sequence to the inference queue
// s.seqs is an array of [numParallel]*Sequence
// Background goroutine (s.run()) processes all non-nil sequences
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
// Allocate KV cache slot for this sequence
// The cache reuses previous computations if prompt prefixes match
// This speeds up multi-turn conversations significantly
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
@@ -650,8 +722,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Add to queue and wake up background inference goroutine
s.seqs[i] = seq
-s.cond.Signal()
+s.cond.Signal() // Notify s.run() that new sequence is ready
found = true
break
}
@@ -664,13 +737,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// *** STREAM TOKENS BACK TO MAIN SERVER ***
// The background goroutine (s.run()) will:
// 1. Process sequences in batches
// 2. Call context.Decode() for each batch (GPU/CPU inference)
// 3. Sample next token using sampling context
// 4. Send token text to seq.responses channel
//
// This handler reads from seq.responses and streams to HTTP client
for {
select {
case <-r.Context().Done():
// Client disconnected, signal background goroutine to stop
close(seq.quit)
return
case content, ok := <-seq.responses:
if ok {
// Stream intermediate token
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
@@ -679,8 +762,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
-flusher.Flush()
+flusher.Flush() // Ensure token is sent immediately
} else {
// Generation complete, send final response with metrics
// Metrics include:
// - PromptEvalCount: tokens in prompt
// - PromptEvalDuration: time to process prompt (Tesla K80)
// - EvalCount: tokens generated
// - EvalDuration: time to generate tokens (Tesla K80)
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,