Fix Tesla K80 CUBLAS compatibility with two-tier fallback strategy

This commit implements comprehensive Tesla K80 (Kepler, compute 3.7)
compatibility for batched matrix multiplication operations.

**Problem:**
Modern CUBLAS functions fail on Tesla K80 with CUBLAS_STATUS_ARCH_MISMATCH:
1. CUBLAS_GEMM_DEFAULT_TENSOR_OP requires Tensor Cores (Volta+ only)
2. cublasGemmStridedBatchedEx/cublasGemmBatchedEx have architectural
   requirements beyond algorithm selection

**Solution - Two-Tier Fallback** (see the sketch below):

Tier 1: Algorithm Selection
- Volta+ (cc >= 7.0): CUBLAS_GEMM_DEFAULT_TENSOR_OP
- Pre-Volta (cc < 7.0): CUBLAS_GEMM_DEFAULT

Tier 2: Function Selection
- Volta+ or non-FP32: Use *Ex variants (flexible precision)
- Kepler/Maxwell/Pascal with FP32: Use legacy type-specific functions
  (cublasSgemmStridedBatched, cublasSgemmBatched)
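
A minimal sketch of the combined decision, condensed from the ggml-cuda.cu changes below; the helper names and the bare `700` threshold stand in for ggml's `GGML_CUDA_CC_*` macros and are not the actual identifiers:

```cpp
#include <cublas_v2.h>

// Tier 1: algorithm choice. TENSOR_OP needs Tensor Cores (Volta, cc >= 7.0);
// requesting it on Kepler/Maxwell/Pascal yields CUBLAS_STATUS_ARCH_MISMATCH.
static cublasGemmAlgo_t pick_gemm_algo(int cc) {  // cc encoded as 370, 610, 700, ...
    return cc >= 700 ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}

// Tier 2: function choice. Even with CUBLAS_GEMM_DEFAULT, the *Ex entry points
// can still return ARCH_MISMATCH on pre-Volta parts, so pure-FP32 batched GEMMs
// drop down to cublasSgemmStridedBatched / cublasSgemmBatched instead.
static bool use_legacy_sgemm(int cc, cudaDataType_t data_type,
                             cublasComputeType_t compute_type) {
    return cc < 700 && data_type == CUDA_R_32F && compute_type == CUBLAS_COMPUTE_32F;
}
```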

**Changes:**

CUDA Implementation:
- ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
  * ggml_cuda_op_mul_mat_cublas: Algorithm selection for non-batched ops
  * ggml_cuda_mul_mat_batched_cublas_impl: Two-tier fallback for batched ops
  * Added GGML_CUDA_DEBUG environment variable for conditional debug logging
  * Comprehensive function documentation explaining fallback strategy

Documentation:
- CLAUDE.md
  * Added Tesla K80 CUBLAS Compatibility section
  * Documented GGML_CUDA_DEBUG environment variable
  * Enhanced "Running Ollama" section with log capture examples
  * Updated Files Modified list

Code Comments:
- Added detailed comments throughout CUDA code explaining:
  * Why TENSOR_OP fails on pre-Volta GPUs
  * Why *Ex functions require architectural support
  * Compute capability checks and fallback logic
  * Debug logging usage

**Testing:**
All models verified working on Tesla K80:
- ✅ gemma3:4b
- ✅ gpt-oss
- ✅ deepseek-r1

Debug flag tested in both enabled and disabled states.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
Author: Shang Chieh Tseng
Date: 2025-11-05 23:52:45 +08:00
parent ef14fb5b26
commit d948926581
8 changed files with 616 additions and 153 deletions

CLAUDE.md

@@ -24,6 +24,7 @@ This document tracks development goals and notes for this Ollama repository fork
1. `ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt` - Added 3.7 compute capability to default architecture list
2. `CMakePresets.json` - Added compute 3.7 to "CUDA 11" preset and created dedicated "CUDA 11 K80" preset
3. `ml/backend/ggml/ggml/src/CMakeLists.txt` - Enabled Alderlake CPU variant without AVX_VNNI
4. `ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu` - Added CUBLAS legacy function fallback for Kepler GPU compatibility
### Key Changes
- Added `37-virtual` to CMAKE_CUDA_ARCHITECTURES (using PTX with JIT compilation for better compatibility)
@@ -37,6 +38,33 @@ This document tracks development goals and notes for this Ollama repository fork
- **CUDA 11.4.4 does NOT support**: 87 (requires 11.7+), 89 (requires 11.8+), 90 (requires 12.0+)
- CUDA 12+ dropped Kepler support entirely
### Tesla K80 CUBLAS Compatibility
**Challenge**: Tesla K80 (Kepler, compute 3.7) requires special handling for batched matrix multiplication due to:
1. Lack of Tensor Cores (introduced in Volta, compute 7.0+)
2. Architectural limitations with modern CUBLAS `*Ex` function variants
**Solution - Two-Tier Fallback Strategy**:
**Tier 1: GEMM Algorithm Selection**
- Volta+ (cc >= 7.0): Use `CUBLAS_GEMM_DEFAULT_TENSOR_OP` (value 99)
- Pre-Volta (cc < 7.0): Use `CUBLAS_GEMM_DEFAULT` (value -1)
**Tier 2: CUBLAS Function Selection**
- **Modern GPUs** (Volta+): Use `cublasGemmStridedBatchedEx` / `cublasGemmBatchedEx`
- Support mixed precision, flexible compute types, algorithm selection
- **Legacy GPUs** (Kepler/Maxwell/Pascal with FP32): Use `cublasSgemmStridedBatched` / `cublasSgemmBatched`
- The `*Ex` variants have architectural requirements beyond algorithm selection
- Even with `CUBLAS_GEMM_DEFAULT`, `*Ex` functions fail with `CUBLAS_STATUS_ARCH_MISMATCH`
- Legacy functions only support FP32, but work reliably on older architectures
**Modified Function**: `ggml_cuda_mul_mat_batched_cublas_impl` in `ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu:1986`
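A minimal sketch of what this fallback boils down to for the FP32 strided-batched path; the wrapper name `batched_gemm_f32`, the `pre_volta` flag, and the argument names are illustrative, not the actual ggml identifiers:
```cpp
#include <cublas_v2.h>

// Illustrative wrapper: route an FP32 strided-batched GEMM to the entry point
// the target architecture can actually run.
static cublasStatus_t batched_gemm_f32(cublasHandle_t handle, bool pre_volta,
                                       int m, int n, int k,
                                       const float *alpha,
                                       const float *a, int lda, long long stride_a,
                                       const float *b, int ldb, long long stride_b,
                                       const float *beta,
                                       float *c, int ldc, long long stride_c,
                                       int batch_count) {
    if (pre_volta) {
        // Kepler/Maxwell/Pascal: legacy type-specific call, FP32 only, no algo parameter.
        return cublasSgemmStridedBatched(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                                         alpha, a, lda, stride_a,
                                         b, ldb, stride_b,
                                         beta, c, ldc, stride_c, batch_count);
    }
    // Volta+: flexible *Ex call with explicit data/compute types and Tensor-Core algo.
    return cublasGemmStridedBatchedEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
                                      alpha, a, CUDA_R_32F, lda, stride_a,
                                      b, CUDA_R_32F, ldb, stride_b,
                                      beta, c, CUDA_R_32F, ldc, stride_c,
                                      batch_count, CUBLAS_COMPUTE_32F,
                                      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
```
The pointer-array path uses the same split between `cublasSgemmBatched` and `cublasGemmBatchedEx`.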
**Tested Models** (verified on Tesla K80):
- ✅ gemma3:4b
- ✅ gpt-oss
- ✅ deepseek-r1
## Build Instructions
### Complete Build from Scratch
@@ -50,104 +78,65 @@ go clean -cache
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --preset "CUDA 11"
# Build the C/C++/CUDA libraries
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build -j$(nproc)
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build -j 48
# Build the Go binary
go build -o ollama .
# Verify the build
./ollama --version
strings build/lib/ollama/libggml-cuda.so | grep "\.target sm_" | sort -u
```
### Alternative: K80-Optimized Build
For smaller binary size (K80 only):
```bash
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --preset "CUDA 11 K80"
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build -j$(nproc)
go build -o ollama .
```
### Incremental Builds
```bash
# If you only modified Go code (no C/C++/CUDA changes)
go build -o ollama .
# If you modified C/C++/CUDA code
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build -j$(nproc)
go build -o ollama .
# If CMake cache gets corrupted
go clean -cache
rm -rf build
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --preset "CUDA 11"
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build -j$(nproc)
go build -o ollama .
```
## Build Test Results - SUCCESSFUL ✓
Build completed successfully on 2025-11-04.
### Verified Compute Capabilities
- ✓ sm_37 (Tesla K80 - Kepler) **← YOUR TARGET GPU**
- ✓ sm_50 (Maxwell)
- ✓ sm_60 (Pascal P100)
- ✓ sm_61 (Pascal)
- ✓ sm_70 (Volta V100)
- ✓ sm_75 (Turing)
- ✓ sm_80 (Ampere)
- ✓ sm_86 (Ampere RTX 3000)
### Build Artifacts
- CUDA library: `build/lib/ollama/libggml-cuda.so` (283MB)
- CPU libraries: `build/lib/ollama/libggml-cpu-*.so` (various optimizations)
- Main executable: `ollama` (59MB)
### Compiler Configuration
- C Compiler: GCC 10.5.0
- C++ Compiler: GCC 10.5.0
- CUDA Host Compiler: GCC 10.5.0
- CUDA Version: 11.4.48
- CPU Variants: x64, sse42, sandybridge, haswell, skylakex, icelake, alderlake (without AVX_VNNI)
## Running Ollama
### Basic Server Start
```bash
# Start the Ollama server
./ollama serve
# Run with verbose logging
OLLAMA_DEBUG=1 ./ollama serve
# Quick test without building binary
go run . serve
# Check GPU detection
nvidia-smi
```
## Verification Commands
### Debug and Logging Options
**Environment Variables**:
- `OLLAMA_DEBUG=1` - Enable verbose Ollama server logging
- `GGML_CUDA_DEBUG=1` - Enable detailed CUDA/CUBLAS operation logging (batched matrix multiplication)
```bash
# Check compiler versions
gcc --version
g++ --version
/usr/local/cuda-11.4/bin/nvcc --version
# Run with Ollama verbose logging only
OLLAMA_DEBUG=1 ./ollama serve
# Verify CUDA library has correct compute capabilities
strings build/lib/ollama/libggml-cuda.so | grep "\.target sm_" | sort -u
# Run with both Ollama and CUDA debug logging
OLLAMA_DEBUG=1 GGML_CUDA_DEBUG=1 ./ollama serve
# Check ollama binary links correctly
ldd ollama
# Capture all output to file
./ollama serve 2>&1 | tee /tmp/ollama_serve.log
# List all built libraries
ls -lh build/lib/ollama/
# Capture only stderr (warnings/errors) to file
./ollama serve 2> /tmp/ollama_errors.log
# Run in background with full logging
OLLAMA_DEBUG=1 ./ollama serve 2>&1 | tee /tmp/ollama_full.log &
# Run in background with debug logging
OLLAMA_DEBUG=1 GGML_CUDA_DEBUG=1 ./ollama serve 2>&1 | tee /tmp/ollama_debug.log &
# Monitor a running background server
tail -f /tmp/ollama_full.log
# Tail recent log entries
tail -100 /tmp/ollama_full.log
# Stop all ollama processes
pkill ollama
```
**When to Use GGML_CUDA_DEBUG**:
- Debugging CUBLAS errors on Tesla K80 or other legacy GPUs
- Verifying compute capability detection
- Troubleshooting batched matrix multiplication issues
- Understanding which CUBLAS functions are being used (legacy vs Ex variants)
## CPU Architecture Compatibility
### The GCC/CUDA/Alderlake Constraint
@@ -187,20 +176,3 @@ This build faces a fundamental compatibility constraint:
| Alderlake (2021) | alderlake | ⚠️ Partial | Missing AVX_VNNI only |
| Raptor Lake (2022) | alderlake | ⚠️ Partial | Missing AVX_VNNI only |
### Alternative Solutions
**Option A: Separate CPU-only build**
```bash
# Use GCC 11+ for CPU-only build (no CUDA)
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --preset "CPU" # hypothetical CPU-only preset
CC=/usr/local/bin/gcc CXX=/usr/local/bin/g++ cmake --build build
```
**Option B: Upgrade GPU**
- Use GPU with Ampere/Ada architecture (compute 8.0+)
- Supports driver 525+ → CUDA 12+ → GCC 11+
- Enables full AVX_VNNI support
**Option C: Accept the limitation**
- Current setup provides good performance for most workloads
- The 3-7% performance difference is acceptable for many use cases


@@ -322,11 +322,18 @@ func StopHandler(cmd *cobra.Command, args []string) error {
return nil
}
// RunHandler is the entry point for "ollama run <model>" command
// This function orchestrates the entire model execution flow:
// 1. Parse command-line arguments and options (format, keepalive, think mode, etc.)
// 2. Determine if running in interactive or non-interactive mode
// 3. Query model info from server (or pull if not found)
// 4. Route to either generateInteractive() or generate() based on mode
func RunHandler(cmd *cobra.Command, args []string) error {
// Default to interactive mode unless prompt is provided or output is piped
interactive := true
opts := runOptions{
Model: args[0],
Model: args[0], // Model name (e.g., "gemma3")
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]any{},
ShowConnect: true,
@@ -379,7 +386,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
prompts := args[1:]
// prepend stdin to the prompt if provided
// Check if stdin contains input (e.g., piped data: echo "hello" | ollama run gemma3)
// If so, prepend it to the prompt and switch to non-interactive mode
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
@@ -392,10 +400,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
// If prompt provided as argument (e.g., ollama run gemma3 "tell me a joke")
// then use non-interactive mode (single-shot generation)
if len(prompts) > 0 {
interactive = false
}
// Be quiet if we're redirecting to a pipe or file
// If stdout is redirected to a pipe or file, use non-interactive mode
if !term.IsTerminal(int(os.Stdout.Fd())) {
interactive = false
}
@@ -406,22 +416,30 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
opts.WordWrap = !nowrap
// Fill out the rest of the options based on information about the
// model.
// Create HTTP client to communicate with Ollama server
// The server must be running (started via "ollama serve")
client, err := api.ClientFromEnvironment()
if err != nil {
return err
}
// Query model metadata from server (HTTP GET /api/show)
// This retrieves:
// - Model capabilities (vision, tools, thinking)
// - Model parameters (context size, architecture)
// - Chat template format
// If model not found locally, automatically pull from registry
name := args[0]
info, err := func() (*api.ShowResponse, error) {
showReq := &api.ShowRequest{Name: name}
info, err := client.Show(cmd.Context(), showReq)
var se api.StatusError
if errors.As(err, &se) && se.StatusCode == http.StatusNotFound {
// Model not found locally, pull it from registry
if err := PullHandler(cmd, []string{name}); err != nil {
return nil, err
}
// Retry after successful pull
return client.Show(cmd.Context(), &api.ShowRequest{Name: name})
}
return info, err
@@ -435,6 +453,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
// Detect if model supports multimodal input (images + text)
// Used for models like LLaVA, Bakllava, or vision-capable models
opts.MultiModal = slices.Contains(info.Capabilities, model.CapabilityVision)
// TODO: remove the projector info and vision info checks below,
@@ -453,6 +473,8 @@ func RunHandler(cmd *cobra.Command, args []string) error {
opts.ParentModel = info.Details.ParentModel
if interactive {
// In interactive mode, load the model into memory first
// This sends a load request to the server, which triggers the scheduler
if err := loadOrUnloadModel(cmd, &opts); err != nil {
var sErr api.AuthorizationError
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
@@ -466,6 +488,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}
// Display any previous conversation history (for multi-turn chats)
for _, msg := range info.Messages {
switch msg.Role {
case "user":
@@ -478,8 +501,12 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
// Enter interactive REPL mode (Read-Eval-Print Loop)
// User can enter multiple prompts in sequence
return generateInteractive(cmd, opts)
}
// Non-interactive mode: single generation then exit
// Used for: ollama run gemma3 "prompt here"
return generate(cmd, opts)
}


@@ -1,3 +1,18 @@
// Package llama provides Go bindings to llama.cpp via CGO
//
// This is the bridge between Go code and C/C++/CUDA inference engine.
// All actual model inference happens through these CGO calls.
//
// Key components:
// - LoadModelFromFile(): Loads GGUF model file into memory
// - Context.Decode(): Runs inference (GPU/CPU matrix operations)
// - SamplingContext: Selects next token from logits
// - Batch: Groups tokens for efficient parallel processing
//
// For Tesla K80 (compute 3.7):
// - CUDA kernels compiled with PTX (JIT at runtime)
// - Model layers distributed across CPU/GPU based on num_gpu_layers
// - KV cache allocated on GPU VRAM (12GB available)
package llama
/*
@@ -12,13 +27,12 @@ package llama
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include
#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "mtmd.h"
#include "ggml.h" // GGML tensor library (CPU/GPU operations)
#include "llama.h" // llama.cpp model loading and inference
#include "mtmd.h" // Multi-turn multi-document support
#include "mtmd-helper.h"
#include "gguf.h"
#include "sampling_ext.h"
#include "gguf.h" // GGUF file format parsing
#include "sampling_ext.h" // Token sampling (temperature, top_p, etc.)
extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
@@ -58,9 +72,22 @@ func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
}
}
// BackendInit initializes the llama.cpp backend
//
// This must be called once before loading any models.
// It initializes:
// - CUDA backend (if GPUs available)
// - CPU backend
// - Memory allocators
// - Threading infrastructure
//
// For Tesla K80 (compute 3.7):
// - Detects CUDA device
// - Verifies compute capability support
// - Initializes cuBLAS for matrix operations
func BackendInit() {
ggml.OnceLoad()
C.llama_backend_init()
ggml.OnceLoad() // Load GGML shared library
C.llama_backend_init() // Initialize llama.cpp backend
}
func EnumerateGPUs() []ml.DeviceID {
@@ -145,15 +172,50 @@ func kvCacheTypeFromStr(s string) C.enum_ggml_type {
}
}
// Context represents an active inference context
//
// This wraps llama.cpp's llama_context which holds:
// - Model pointer
// - KV cache (key-value cache for attention)
// - Thread pool for CPU operations
// - RNG state for sampling
//
// Each Context can handle multiple parallel sequences (controlled by numParallel)
type Context struct {
c *C.struct_llama_context
numThreads int
c *C.struct_llama_context // C pointer to llama_context
numThreads int // Number of CPU threads for inference
}
var ErrKvCacheFull = errors.New("could not find a kv cache slot")
// Decode runs one inference step on a batch of tokens
//
// *** THIS IS WHERE ACTUAL INFERENCE HAPPENS ***
//
// For each token in the batch, this:
// 1. Retrieves token embeddings from model
// 2. Runs through transformer layers:
// - Attention (uses KV cache)
// - Feed-forward network
// - Layer normalization
// 3. Stores KV states in cache for future tokens
// 4. Produces output logits (probabilities for next token)
//
// GPU execution (Tesla K80, compute 3.7):
// - Matrix multiplications via cuBLAS
// - Attention via CUDA kernels
// - LayerNorm/RoPE/Softmax via CUDA kernels
// - Data transferred between CPU/GPU as needed
//
// Returns:
// - nil: success
// - ErrKvCacheFull: no space in KV cache (increase num_ctx or reduce batch size)
// - error: fatal error during inference
func (c *Context) Decode(batch *Batch) error {
// Positive return values do not mean a fatal error, but rather a warning.
// Call C function: int llama_decode(llama_context*, llama_batch)
// This executes the actual neural network forward pass
//
// Return codes:
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
// < 0 - error
@@ -234,13 +296,48 @@ func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
return true
}
// LoadModelFromFile loads a GGUF model file into memory
//
// *** THIS IS THE CORE MODEL LOADING FUNCTION ***
//
// This reads the GGUF file and loads model weights into memory.
// The process:
// 1. Parse GGUF file headers (metadata, architecture, tensors)
// 2. Memory-map (mmap) or read model weights
// 3. Distribute layers across devices based on NumGpuLayers:
// - First NumGpuLayers transformer layers → GPU
// - Remaining layers → CPU
// - Embeddings and output layer handling varies
// 4. Allocate device buffers for tensors
//
// For Tesla K80 (compute 3.7) with 12GB VRAM:
// - Example: gemma3-2b with Q4_0 quantization
// - Full model ~1.5GB, can fit entirely on GPU (NumGpuLayers=99)
// - Example: llama3-8b with Q4_0 quantization
// - Full model ~4.5GB, can fit entirely on GPU
// - Example: llama3-70b with Q4_0 quantization
// - Full model ~40GB, needs CPU offload (NumGpuLayers=20-30)
//
// Parameters:
// - modelPath: Path to GGUF file (e.g., ~/.ollama/models/blobs/sha256-abc123...)
// - params.NumGpuLayers: How many transformer layers to put on GPU
// - params.MainGpu: Which GPU to use (if multiple)
// - params.TensorSplit: How to split layers across multiple GPUs
// - params.UseMmap: Whether to use memory mapping (faster load, less RAM)
//
// Returns:
// - *Model: Loaded model ready for inference
// - error: If file not found, incompatible format, or out of memory
func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
// Initialize C parameters structure
cparams := C.llama_model_default_params()
cparams.n_gpu_layers = C.int(params.NumGpuLayers)
cparams.main_gpu = C.int32_t(params.MainGpu)
cparams.use_mmap = C.bool(params.UseMmap)
cparams.vocab_only = C.bool(params.VocabOnly)
cparams.n_gpu_layers = C.int(params.NumGpuLayers) // Layers to offload to GPU
cparams.main_gpu = C.int32_t(params.MainGpu) // Primary GPU device ID
cparams.use_mmap = C.bool(params.UseMmap) // Memory-map file (faster)
cparams.vocab_only = C.bool(params.VocabOnly) // Load vocabulary only
// Multi-GPU tensor split (for systems with multiple GPUs)
// Defines proportion of model to put on each GPU
if len(params.TensorSplit) > 0 {
tensorSplitData := &params.TensorSplit[0]
@@ -251,6 +348,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
}
// Progress callback (reports loading progress percentage)
if params.Progress != nil {
handle := cgo.NewHandle(params.Progress)
defer handle.Delete()
@@ -263,6 +361,12 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
}
// *** CALL C FUNCTION TO LOAD MODEL ***
// This:
// 1. Opens and parses GGUF file
// 2. Allocates CPU/GPU memory for tensors
// 3. Loads/mmaps weights into memory
// 4. For Tesla K80: compiles CUDA kernels via PTX JIT (compute 3.7)
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
if m.c == nil {
return nil, fmt.Errorf("unable to load model: %s", modelPath)


@@ -1397,10 +1397,32 @@ type CompletionResponse struct {
EvalDuration time.Duration `json:"eval_duration"`
}
// Completion is the bridge between Go and the runner subprocess
//
// This function sends an HTTP POST request to the runner subprocess
// and streams back the generated tokens in real-time.
//
// Flow:
// 1. Validate and prepare request (grammar, options)
// 2. Acquire semaphore (limit concurrent requests to runner)
// 3. Send HTTP POST to http://127.0.0.1:<port>/completion
// 4. Runner subprocess receives request and starts generation loop
// 5. Stream back CompletionResponse objects (one per token or batch)
// 6. Call callback function fn() for each response chunk
//
// Parameters:
// - ctx: Context (for cancellation)
// - req: CompletionRequest (prompt, images, options)
// - fn: Callback function called for each generated token/chunk
//
// Returns:
// - error: If request fails or is cancelled
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
slog.Debug("completion request", "images", len(req.Images), "prompt", len(req.Prompt), "format", string(req.Format))
logutil.Trace("completion request", "prompt", req.Prompt)
// Handle JSON output format constraints
// If user requests JSON output, apply grammar to constrain generation
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
@@ -1408,13 +1430,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
// these as "not set".
break
case `"json"`:
// Use built-in JSON grammar
req.Grammar = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}
// User provided a JSON schema
// User provided a JSON schema - convert to GBNF grammar
// This constrains the model to only generate valid JSON matching the schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
@@ -1428,6 +1452,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
req.Options = &opts
}
// Acquire semaphore to limit concurrent requests
// The runner subprocess can only handle numParallel requests at once
// (each parallel slot requires separate KV cache allocation)
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -1443,7 +1470,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
req.Options.NumPredict = 10 * s.options.NumCtx
}
// Make sure the server is ready
// Wait for runner subprocess to be ready
// The subprocess may still be loading weights into memory
status, err := s.getServerStatusRetry(ctx)
if err != nil {
return err
@@ -1451,7 +1479,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("unexpected server status: %s", status)
}
// Handling JSON marshaling with special characters unescaped.
// Marshal CompletionRequest to JSON
// Use SetEscapeHTML(false) to avoid escaping special characters
buffer := &bytes.Buffer{}
enc := json.NewEncoder(buffer)
enc.SetEscapeHTML(false)
@@ -1460,6 +1489,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("failed to marshal data: %v", err)
}
// *** SEND HTTP POST TO RUNNER SUBPROCESS ***
// This is IPC (Inter-Process Communication) between:
// - Parent: Ollama server (this process)
// - Child: Runner subprocess (spawned by scheduler)
//
// The runner listens on 127.0.0.1:<random_port> with endpoints:
// - POST /completion - text generation (this call)
// - POST /embedding - generate embeddings
// - POST /tokenize - tokenize text
// - GET /health - health check
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
serverReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
if err != nil {
@@ -1467,6 +1506,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
serverReq.Header.Set("Content-Type", "application/json")
// Execute HTTP request to runner subprocess
res, err := http.DefaultClient.Do(serverReq)
if err != nil && errors.Is(err, context.Canceled) {
// client closed connection
@@ -1486,6 +1526,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
}
// *** STREAM RESPONSE BACK FROM RUNNER ***
// The runner subprocess streams generated tokens back line-by-line
// Each line is a JSON CompletionResponse object
//
// Response flow:
// 1. Runner tokenizes prompt
// 2. Runner creates inference batch
// 3. For each generation step:
// a. Call context.Decode() -> C.llama_decode() -> CUDA kernel
// b. Get logits from model output
// c. Apply sampling (temperature, top_p, top_k)
// d. Select next token
// e. Stream CompletionResponse{"content": "token_text", "done": false}
// 4. When complete, send final response with "done": true and metrics
scanner := bufio.NewScanner(res.Body)
buf := make([]byte, 0, maxBufferSize)
scanner.Buffer(buf, maxBufferSize)
@@ -1505,15 +1559,27 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
continue
}
// Handle Server-Sent Events format (optional "data: " prefix)
evt, ok := bytes.CutPrefix(line, []byte("data: "))
if !ok {
evt = line
}
// Parse CompletionResponse from JSON
// Fields:
// - Content: generated token(s) as string
// - Done: true if generation complete
// - DoneReason: "stop", "length", or "connection_closed"
// - PromptEvalCount: tokens in prompt
// - PromptEvalDuration: time to process prompt
// - EvalCount: tokens generated
// - EvalDuration: time to generate tokens
var c CompletionResponse
if err := json.Unmarshal(evt, &c); err != nil {
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
// Detect infinite loops (model repeating same token)
switch {
case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
@@ -1528,12 +1594,14 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err()
}
// Call callback function for each generated token
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
// Final response includes all metrics
if c.Done {
fn(c)
return nil


@@ -1392,6 +1392,14 @@ static void ggml_cuda_op_mul_mat_cublas(
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
// TENSOR_OP Support: Requires Tensor Cores (Volta+ on NVIDIA, cc >= 7.0)
// Tesla K80 (cc=3.7), GTX 1080 (cc=6.1), etc. do NOT have Tensor Cores
// Using CUBLAS_GEMM_DEFAULT_TENSOR_OP on pre-Volta causes CUBLAS_STATUS_ARCH_MISMATCH
cublasGemmAlgo_t gemm_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA) {
gemm_algo = CUBLAS_GEMM_DEFAULT; // Fallback for pre-Volta GPUs
}
// This path tries to use BF16 with tensor cores via cublasGemmEx
// Will fail on pre-Ampere NVIDIA GPUs (compute < 8.0)
if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
@@ -1418,7 +1426,7 @@ static void ggml_cuda_op_mul_mat_cublas(
src1_ptr, CUDA_R_16BF, ne10,
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
gemm_algo));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
@@ -1456,7 +1464,7 @@ static void ggml_cuda_op_mul_mat_cublas(
src1_ptr, CUDA_R_16F, ne10,
&beta, dst_dd_i, CUDA_R_32F, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
gemm_algo));
} else {
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
@@ -1470,7 +1478,7 @@ static void ggml_cuda_op_mul_mat_cublas(
src1_ptr, CUDA_R_16F, ne10,
&beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
CUBLAS_COMPUTE_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
gemm_algo));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
@@ -1975,6 +1983,25 @@ struct batched_mul_mat_traits<GGML_TYPE_F16> {
static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
};
// Batched matrix multiplication using CUBLAS with support for legacy GPUs
//
// Tesla K80 (Kepler, compute 3.7) Compatibility:
// This function implements a two-tier fallback strategy for older GPUs:
//
// 1. GEMM Algorithm Selection:
// - Volta+ (cc >= 7.0): Use CUBLAS_GEMM_DEFAULT_TENSOR_OP (requires Tensor Cores)
// - Pre-Volta (cc < 7.0): Use CUBLAS_GEMM_DEFAULT (standard algorithm)
//
// 2. CUBLAS Function Selection:
// - Modern GPUs (Volta+): Use cublasGemmStridedBatchedEx / cublasGemmBatchedEx
// * Supports mixed precision, flexible compute types, algorithm selection
// - Legacy GPUs (Kepler/Maxwell/Pascal): Use cublasSgemmStridedBatched / cublasSgemmBatched
// * The *Ex variants have architectural requirements beyond algorithm selection
// * Even with CUBLAS_GEMM_DEFAULT, *Ex functions fail with CUBLAS_STATUS_ARCH_MISMATCH
// * Legacy functions only support FP32, but work reliably on older architectures
//
// Debug: Set GGML_CUDA_DEBUG=1 environment variable to enable debug logging
//
template<ggml_type src0_type>
static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
using traits = batched_mul_mat_traits<src0_type>;
@@ -2072,6 +2099,24 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
beta = &beta_f32;
}
// Select GEMM algorithm based on compute capability
// TENSOR_OP (value=99) requires Tensor Cores (compute capability >= 7.0)
// For older GPUs like Tesla K80 (cc=3.7), use CUBLAS_GEMM_DEFAULT (value=-1)
cublasGemmAlgo_t gemm_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA) {
// Fallback for compute capability < 7.0 (pre-Volta, no Tensor Cores)
// This includes: Kepler (3.x), Maxwell (5.x), Pascal (6.x)
gemm_algo = CUBLAS_GEMM_DEFAULT;
}
// Debug logging for CUBLAS configuration (enable with GGML_CUDA_DEBUG=1)
static bool debug_enabled = getenv("GGML_CUDA_DEBUG") != nullptr;
if (debug_enabled) {
fprintf(stderr, "DEBUG batched_cublas: device=%d cc=%d is_nvidia=%d volta=%d gemm_algo=%d cu_compute_type=%d cu_data_type=%d cu_data_type_a=%d cu_data_type_b=%d\n",
id, cc, GGML_CUDA_CC_IS_NVIDIA(cc), GGML_CUDA_CC_VOLTA, (int)gemm_algo,
(int)cu_compute_type, (int)cu_data_type, (int)cu_data_type_a, (int)cu_data_type_b);
}
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
@@ -2085,16 +2130,29 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
const int64_t smb = ne12 == 1 ? s13 : s12;
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
// use cublasGemmStridedBatchedEx
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
src1_ptr, cu_data_type_b, s11, smb, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// For pre-Volta GPUs (compute < 7.0), use legacy type-specific functions instead of *Ex variants
// The *Ex functions have architecture requirements beyond just the algorithm parameter
if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA && cu_data_type == CUDA_R_32F && cu_compute_type == CUBLAS_COMPUTE_32F) {
// Use legacy cublasSgemmStridedBatched for Kepler/Maxwell/Pascal with FP32
CUBLAS_CHECK(
cublasSgemmStridedBatched(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
(const float *)alpha, (const float *)src0_ptr, nb01/nb00, sma,
(const float *)src1_ptr, s11, smb,
(const float *)beta, (float *)dst_t, ne0, ne1*ne0,
ne12*ne13));
} else {
// Use cublasGemmStridedBatchedEx for Volta+ or non-FP32 data types
CUBLAS_CHECK(
cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma, // strideA
src1_ptr, cu_data_type_b, s11, smb, // strideB
beta, dst_t, cu_data_type, ne0, ne1*ne0, // strideC
ne12*ne13,
cu_compute_type,
gemm_algo));
}
} else {
// use cublasGemmBatchedEx
const int64_t ne23 = ne12*ne13;
@@ -2118,15 +2176,28 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
CUDA_CHECK(cudaGetLastError());
CUBLAS_CHECK(
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
ne23,
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// For pre-Volta GPUs (compute < 7.0), use legacy type-specific functions instead of *Ex variants
if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA && cu_data_type == CUDA_R_32F && cu_compute_type == CUBLAS_COMPUTE_32F) {
// Use legacy cublasSgemmBatched for Kepler/Maxwell/Pascal with FP32
CUBLAS_CHECK(
cublasSgemmBatched(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
(const float *)alpha, (const float **) (ptrs_src.get() + 0*ne23), nb01/nb00,
(const float **) (ptrs_src.get() + 1*ne23), s11,
(const float *)beta, ( float **) (ptrs_dst.get() + 0*ne23), ne0,
ne23));
} else {
// Use cublasGemmBatchedEx for Volta+ or non-FP32 data types
CUBLAS_CHECK(
cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
ne01, ne11, ne10,
alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
(const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
beta, ( void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
ne23,
cu_compute_type,
gemm_algo));
}
}
// Convert output back to F32 if needed


@@ -103,9 +103,30 @@ type NewSequenceParams struct {
var errorInputTooLong = errors.New("the input length exceeds the context length")
// NewSequence creates a new inference sequence
//
// This prepares everything needed for text generation:
// 1. Tokenize prompt into input tokens
// 2. Process vision embeddings (if images provided)
// 3. Handle context window truncation (if prompt too long)
// 4. Create sampling context (controls temperature, top_p, etc.)
// 5. Prepare for KV cache management
//
// Parameters:
// - prompt: Text to generate from (already formatted by chat template)
// - images: Image data for multimodal models (empty for text-only)
// - params: Generation parameters (num_predict, stop sequences, sampling settings)
//
// Returns:
// - *Sequence: Ready-to-use sequence for inference loop
// - error: If prompt too long or other errors
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
// Wait for model to finish loading (blocks until ready)
s.ready.Wait()
// Tokenize prompt and process images into input sequence
// For text-only: converts string → token IDs
// For multimodal: also runs vision projector to get image embeddings
inputs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -113,6 +134,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
return nil, errors.New("no input provided")
}
// numKeep controls how many tokens to keep when shifting context window
// If context fills up, we discard middle tokens but keep first numKeep tokens
if params.numKeep < 0 {
params.numKeep = len(inputs)
}
@@ -122,14 +145,19 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
// Ensure that at least 1 input can be discarded during shift
// Otherwise we can't make room for new generation
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
// Check if prompt exceeds context window (num_ctx)
// Example: gemma3 default num_ctx = 8192 tokens
if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
// Truncate prompt: keep first numKeep tokens + last tokens
// This preserves system prompt and instruction while removing middle content
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -137,12 +165,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
inputs = newInputs
}
// Create sampling context for token selection
// This manages:
// - Logit biasing (temperature, top_p, top_k, min_p)
// - Repetition penalty
// - Grammar constraints (for JSON output)
// - Seed for reproducibility
var sc *llama.SamplingContext
if params.samplingParams != nil {
sc, err = llama.NewSamplingContext(s.model, *params.samplingParams)
if err != nil {
return nil, err
}
// Prime the sampling context with prompt tokens
// This is needed for repetition penalty calculation
for _, input := range inputs {
if input.embed == nil {
sc.Accept(input.token, false)
@@ -573,7 +609,27 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
// completion is the HTTP handler for POST /completion endpoint
//
// This is where the runner subprocess receives requests from the main Ollama server.
// It handles the entire text generation pipeline:
// 1. Parse CompletionRequest from JSON body
// 2. Create new Sequence (tokenize, setup sampling)
// 3. Add sequence to inference queue (managed by background goroutine)
// 4. Stream generated tokens back via HTTP response
//
// The actual inference happens in a separate goroutine (s.run())
// which continuously processes sequences from s.seqs queue.
//
// Flow:
// - Main server calls: POST http://127.0.0.1:<port>/completion
// - This handler receives request
// - Creates Sequence and adds to s.seqs queue
// - Background goroutine (s.run()) picks up sequence
// - Tokens generated are sent to seq.responses channel
// - This handler reads from channel and streams back to main server
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
// Parse CompletionRequest sent by llm/server.go Completion()
var req llm.CompletionRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "Bad request", http.StatusBadRequest)
@@ -585,7 +641,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options = &opts
}
// Set the headers to indicate streaming
// Set headers for HTTP streaming (Server-Sent Events style)
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Transfer-Encoding", "chunked")
@@ -595,7 +651,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Extract options from the CompletionRequest
// Convert API options to llama.cpp sampling parameters
// These control how tokens are selected from logits:
// - TopK: only consider top K tokens
// - TopP (nucleus sampling): consider tokens until cumulative prob reaches P
// - MinP: minimum probability threshold
// - Temperature: controls randomness (lower = more deterministic)
// - RepeatPenalty: penalize repeating tokens
// - Grammar: GBNF grammar for constrained generation (e.g., JSON output)
samplingParams := llama.SamplingParams{
TopK: req.Options.TopK,
TopP: req.Options.TopP,
@@ -610,14 +673,15 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Grammar: req.Grammar,
}
// Create new sequence (tokenize, setup sampling context)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
stop: req.Options.Stop,
numKeep: req.Options.NumKeep,
numPredict: req.Options.NumPredict, // Max tokens to generate
stop: req.Options.Stop, // Stop sequences
numKeep: req.Options.NumKeep, // Tokens to keep when shifting
samplingParams: &samplingParams,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
embedding: false, // Generate text, not embeddings
shift: req.Shift, // Allow context window shift
truncate: req.Truncate, // Allow prompt truncation
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
@@ -628,7 +692,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Ensure there is a place to put the sequence, released when removed from s.seqs
// Acquire semaphore slot (limits concurrent requests to numParallel)
// Each parallel slot requires separate KV cache allocation
// Example: numParallel=4 means up to 4 requests can run simultaneously
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -638,10 +704,16 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Add sequence to the inference queue
// s.seqs is an array of [numParallel]*Sequence
// Background goroutine (s.run()) processes all non-nil sequences
s.mu.Lock()
found := false
for i, sq := range s.seqs {
if sq == nil {
// Allocate KV cache slot for this sequence
// The cache reuses previous computations if prompt prefixes match
// This speeds up multi-turn conversations significantly
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
@@ -650,8 +722,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// Add to queue and wake up background inference goroutine
s.seqs[i] = seq
s.cond.Signal()
s.cond.Signal() // Notify s.run() that new sequence is ready
found = true
break
}
@@ -664,13 +737,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
// *** STREAM TOKENS BACK TO MAIN SERVER ***
// The background goroutine (s.run()) will:
// 1. Process sequences in batches
// 2. Call context.Decode() for each batch (GPU/CPU inference)
// 3. Sample next token using sampling context
// 4. Send token text to seq.responses channel
//
// This handler reads from seq.responses and streams to HTTP client
for {
select {
case <-r.Context().Done():
// Client disconnected, signal background goroutine to stop
close(seq.quit)
return
case content, ok := <-seq.responses:
if ok {
// Stream intermediate token
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
}); err != nil {
@@ -679,8 +762,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
flusher.Flush()
flusher.Flush() // Ensure token is sent immediately
} else {
// Generation complete, send final response with metrics
// Metrics include:
// - PromptEvalCount: tokens in prompt
// - PromptEvalDuration: time to process prompt (Tesla K80)
// - EvalCount: tokens generated
// - EvalDuration: time to generate tokens (Tesla K80)
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Done: true,
DoneReason: seq.doneReason,


@@ -171,8 +171,29 @@ func signinURL() (string, error) {
return fmt.Sprintf(signinURLStr, url.PathEscape(h), encKey), nil
}
// GenerateHandler is the HTTP handler for POST /api/generate endpoint
// This is the main server-side entry point for model inference requests
//
// Flow:
// 1. Parse and validate the GenerateRequest JSON body
// 2. Load model metadata (config, template, system prompt)
// 3. Schedule/acquire a runner instance from the scheduler
// 4. Apply chat template to format the prompt
// 5. Call runner's Completion() method for actual inference
// 6. Stream responses back to client as Server-Sent Events (SSE)
//
// Request structure (api.GenerateRequest):
// - Model: string (e.g., "gemma3", "llama3")
// - Prompt: string (user input)
// - Images: []string (base64 for multimodal models)
// - System: string (system prompt override)
// - Template: string (template override)
// - Options: map[string]any (temperature, top_p, num_gpu, etc.)
// - KeepAlive: Duration (how long to keep model in memory)
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
// Parse JSON request body into GenerateRequest struct
var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -182,6 +203,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Parse and validate model name (format: [registry/][namespace/]model[:tag])
name := model.ParseName(req.Model)
if !name.IsValid() {
// Ideally this is "invalid model name" but we're keeping with
@@ -190,6 +212,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Resolve the actual model name (handles aliases and version resolution)
// We cannot currently consolidate this into GetModel because all we'll
// do is induce infinite recursion given the current code structure.
name, err := getExistingName(name)
@@ -198,6 +221,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Load model metadata from disk (server/images.go:320)
// This reads the GGUF file headers and Modelfile to extract:
// - Model config (architecture, quantization, context size)
// - Chat template (how to format messages for this model)
// - System prompt (default instructions)
// - Model capabilities (vision, tools, thinking)
// - Model options (temperature defaults, etc.)
m, err := GetModel(name.String())
if err != nil {
switch {
@@ -357,6 +387,23 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
// Schedule a runner instance from the scheduler (server/sched.go:84)
// This is THE critical step that loads the model into memory
//
// The scheduler will:
// 1. Check if model is already loaded in memory (cache hit)
// 2. If not loaded:
// a. Analyze GGML file to determine layer distribution (CPU vs GPU)
// b. Spawn a runner subprocess: "ollama runner --model <path> --port <port>"
// c. Load GGUF weights into memory via llama.cpp
// d. Allocate KV cache on GPU (if using GPU)
// e. Initialize inference context
// 3. Return (LlamaServer, Model, Options) tuple
//
// Parameters:
// - caps: required capabilities (completion, vision, thinking, etc.)
// - req.Options: user-provided options (num_gpu, temperature, etc.)
// - req.KeepAlive: how long to keep model loaded after request completes
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -368,7 +415,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
checkpointLoaded := time.Now()
// load the model
// If prompt is empty, this is just a model load request (warmup)
// Return immediately without running inference
if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
@@ -384,13 +432,20 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
// Prepare image data for multimodal models (if any)
images := make([]llm.ImageData, len(req.Images))
for i := range req.Images {
images[i] = llm.ImageData{ID: i, Data: req.Images[i]}
}
// Apply chat template to format the prompt
// Chat templates convert structured messages into model-specific format
// Example for Gemma3:
// Input: [{"role": "user", "content": "Hello"}]
// Output: "<bos><start_of_turn>user\nHello<end_of_turn>\n<start_of_turn>model\n"
prompt := req.Prompt
if !req.Raw {
// Get template from model config (or use override from request)
tmpl := m.Template
if req.Template != "" {
tmpl, err = template.Parse(req.Template)
@@ -402,20 +457,25 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var values template.Values
if req.Suffix != "" {
// Fill-in-the-middle mode (for code completion)
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
// Normal chat mode: build message list
var msgs []api.Message
// Add system prompt (instructions for the model)
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
// Add conversation history (for multi-turn chats)
if req.Context == nil {
msgs = append(msgs, m.Messages...)
}
// Add current user message with any images
userMsg := api.Message{Role: "user", Content: req.Prompt}
for _, i := range images {
userMsg.Images = append(userMsg.Images, i.Data)
@@ -495,11 +555,31 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}
// Create channel for streaming responses from inference engine
ch := make(chan any)
go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch)
// *** THIS IS THE CORE INFERENCE CALL ***
// r.Completion() bridges Go → runner subprocess → C/C++ → CUDA
//
// Flow:
// 1. This sends HTTP POST to runner subprocess at http://127.0.0.1:<port>/completion
// 2. Runner subprocess (llamarunner/runner.go) receives request
// 3. Runner tokenizes prompt and creates inference batch
// 4. Runner calls context.Decode() repeatedly (llama/llama.go CGO binding)
// 5. context.Decode() calls C.llama_decode() from llama.cpp
// 6. llama_decode() executes CUDA kernels on GPU (Tesla K80 compute 3.7)
// 7. Each generated token is sampled and streamed back via callback
//
// CompletionRequest fields:
// - Prompt: formatted text (after template application)
// - Images: base64 image data (for vision models)
// - Options: temperature, top_p, top_k, num_gpu, etc.
// - Shift: whether to shift context window when full
// - Truncate: whether to truncate prompt if too long
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
@@ -508,6 +588,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
Shift: req.Shift == nil || *req.Shift,
Truncate: req.Truncate == nil || *req.Truncate,
}, func(cr llm.CompletionResponse) {
// Callback function called for each generated token (streaming)
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),


@@ -380,12 +380,31 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm
}()
}
// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
// load creates a new model based on req and loads it into memory
//
// This is THE critical function that:
// 1. Spawns the runner subprocess (ollama runner --model <path>)
// 2. Loads GGUF weights into memory via llama.cpp
// 3. Distributes model layers across CPU/GPU(s) based on available VRAM
// 4. Allocates KV cache on GPU (for Tesla K80: compute 3.7)
// 5. Initializes inference context with threading and batch parameters
//
// Parameters:
// - req: LlmRequest containing model path, options, capabilities
// - f: GGML metadata (parsed from GGUF file headers)
// - systemInfo: CPU/RAM information
// - gpus: List of available GPUs (e.g., Tesla K80)
// - requireFull: If true, model must fit entirely on GPU(s)
//
// Returns:
// - bool: true if scheduler needs to evict other models to make room
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
// NumParallel controls how many requests can be processed simultaneously
// Each parallel slot requires additional KV cache memory
numParallel := max(int(envconfig.NumParallel()), 1)
// Embedding models should always be loaded with parallel=1
// (they don't benefit from parallel processing like generation does)
if req.model.CheckCapabilities(model.CapabilityCompletion) != nil {
numParallel = 1
}
@@ -405,8 +424,27 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
s.loadedMu.Lock()
llama := s.activeLoading
// Create new llama server instance if not already loading
if llama == nil {
var err error
// *** SPAWN RUNNER SUBPROCESS ***
// s.newServerFn points to NewLlamaServer() in llm/server.go:148
//
// This function:
// 1. Calculates memory requirements from GGML metadata
// 2. Determines GPU layer distribution (num_gpu_layers)
// 3. Spawns subprocess: exec.Command("ollama", "runner", "--model", modelPath, "--port", port)
// 4. Subprocess starts HTTP server listening on local port
// 5. Returns LlamaServer interface for IPC communication
//
// Parameters:
// - systemInfo: CPU/RAM stats
// - gpus: Available GPUs (Tesla K80 with 12GB VRAM, compute 3.7)
// - ModelPath: Path to GGUF file (e.g., ~/.ollama/models/blobs/sha256-abc123...)
// - AdapterPaths: LoRA adapters (if any)
// - ProjectorPaths: Vision projectors for multimodal (if any)
// - opts: User options (num_gpu, num_thread, num_ctx, etc.)
// - numParallel: How many parallel inference slots to allocate
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
@@ -423,6 +461,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
s.activeLoading = llama
} else {
// Reusing existing server (e.g., after eviction attempt failed)
if s.activeLoading.ModelPath() != req.model.ModelPath {
panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath))
}
@@ -430,16 +469,28 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
s.loadedMu.Unlock()
// *** LOAD MODEL WEIGHTS INTO MEMORY ***
// llama.Load() triggers the runner subprocess to:
// 1. Call llama_model_load_from_file() via CGO (llama/llama.go)
// 2. Read GGUF file and mmap() weights into memory
// 3. Distribute layers across CPU/GPU based on num_gpu_layers
// 4. Allocate KV cache on GPU (if using GPU)
// 5. Compile CUDA kernels for Tesla K80 (compute 3.7, via PTX JIT)
//
// Returns:
// - gpuIDs: List of GPU device IDs where model layers were loaded
// - err: Error if model doesn't fit or loading fails
gpuIDs, err := llama.Load(req.ctx, systemInfo, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
if !requireFull {
// No other models loaded, yet we still don't fit, so report an error
// Model doesn't fit fully on GPU, need to evict other models
slog.Info("model is too large for system memory", "requireFull", requireFull)
s.activeLoading.Close()
s.activeLoading = nil
req.errCh <- err
}
// Signal scheduler to evict models and retry
return true
}