diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h
index a0a660bf..8df19d02 100644
--- a/llama/llama.cpp/include/llama.h
+++ b/llama/llama.cpp/include/llama.h
@@ -1370,6 +1370,28 @@ extern "C" {
     // print a breakdown of per-device memory use via LLAMA_LOG:
     LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx);
 
+    // Memory measurement for GPU selection:
+    // This struct holds measured memory requirements per backend device.
+    // Used by the Go layer to select an appropriate GPU configuration before actual model loading.
+    struct llama_memory_measurement {
+        char   backend_name[128]; // Backend device name (e.g., "CUDA0", "CUDA1", "CPU")
+        size_t model_bytes;       // Model weights memory
+        size_t context_bytes;     // KV cache memory
+        size_t compute_bytes;     // Compute buffer memory (temp tensors during inference)
+        size_t total_bytes;       // Total memory requirement
+        bool   is_host;           // True if this is a host (CPU) backend
+    };
+
+    // Measure memory requirements without fully initializing a context.
+    // This allows the Go layer to make informed GPU selection decisions.
+    // Returns the number of backends and fills the measurements array (caller must allocate).
+    // If measurement fails, returns -1.
+    LLAMA_API int32_t llama_measure_memory_requirements(
+                    struct llama_model * model,
+                    struct llama_context_params params,
+                    struct llama_memory_measurement * measurements,
+                    int32_t max_measurements);
+
     //
     // training
     //
diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp
index 53a5e3a9..83f288c7 100644
--- a/llama/llama.cpp/src/llama-context.cpp
+++ b/llama/llama.cpp/src/llama-context.cpp
@@ -2921,6 +2921,133 @@ void llama_memory_breakdown_print(const struct llama_context * ctx) {
     }
 }
 
+// Measure memory requirements for GPU selection.
+//
+// PURPOSE: Enables accurate GPU selection by measuring actual memory allocation
+// requirements instead of relying on estimation formulas that often underestimate.
+//
+// BACKGROUND: The Go layer (llm/memory.go) estimates memory using GraphSize()
+// formulas that are mathematical approximations. These formulas don't account for
+// all temporary tensors allocated during inference, leading to underestimation.
+//
+// PROBLEM SOLVED (deepseek-r1:14b case study):
+//   - GraphSize formula estimated: 916 MB compute buffers
+//   - Actual allocation needed:    ~3-4 GB compute buffers
+//   - Underestimation:             3.3-4.4× error
+//   Result: the model tried to fit in 1 GPU (11 GB), failed allocation, and crashed.
+//
+// HOW THIS WORKS:
+//   1. Creates a temporary context with the given parameters (n_ctx, n_batch, etc.)
+//   2. Calls graph_reserve(), which builds the computation graph and allocates buffers
+//   3. Queries actual buffer sizes via memory_breakdown()
+//   4. Returns a per-backend breakdown: model weights, KV cache, compute buffers
+//   5. Cleans up the temporary context
+//
+// USAGE: Called from the Go layer before committing to a GPU configuration.
+// Allows intelligent multi-GPU selection based on actual requirements.
+//
+// CURRENT STATUS: API implemented but not yet integrated into the GPU selection flow.
+// The current solution uses a 3.5× safety margin on estimates (see llm/memory.go:377).
+// Future improvement: replace the safety margin with this measurement-based approach.
+//
+// Returns: Number of backends measured, or -1 on failure.
+int32_t llama_measure_memory_requirements(
+        struct llama_model * model,
+        struct llama_context_params params,
+        struct llama_memory_measurement * measurements,
+        int32_t max_measurements) {
+
+    if (!model || !measurements || max_measurements <= 0) {
+        LLAMA_LOG_ERROR("%s: invalid parameters\n", __func__);
+        return -1;
+    }
+
+    try {
+        // Create a temporary context with the given parameters to measure memory requirements
+        llama_context * ctx = new llama_context(*model, params);
+
+        if (!ctx) {
+            LLAMA_LOG_ERROR("%s: failed to create temporary context for measurement\n", __func__);
+            return -1;
+        }
+
+        // Get memory breakdown from the context
+        std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
+        const std::vector<ggml_backend_dev_t> & devices = model->devices;
+
+        int32_t num_measurements = 0;
+
+        // Process each device backend
+        for (size_t i = 0; i < devices.size() && num_measurements < max_measurements; i++) {
+            ggml_backend_dev_t dev = devices[i];
+            ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(dev);
+
+            // Find matching memory breakdown for this buffer type
+            auto it = memory_breakdown.find(buft);
+            if (it != memory_breakdown.end()) {
+                const llama_memory_breakdown_data & mb = it->second;
+
+                // Fill measurement struct
+                strncpy(measurements[num_measurements].backend_name,
+                        ggml_backend_dev_name(dev),
+                        sizeof(measurements[num_measurements].backend_name) - 1);
+                measurements[num_measurements].backend_name[sizeof(measurements[num_measurements].backend_name) - 1] = '\0';
+
+                measurements[num_measurements].model_bytes   = mb.model;
+                measurements[num_measurements].context_bytes = mb.context;
+                measurements[num_measurements].compute_bytes = mb.compute;
+                measurements[num_measurements].total_bytes   = mb.model + mb.context + mb.compute;
+                measurements[num_measurements].is_host       = ggml_backend_buft_is_host(buft);
+
+                num_measurements++;
+            }
+        }
+
+        // Add host/CPU memory if present and there's room
+        if (num_measurements < max_measurements) {
+            llama_memory_breakdown_data mb_host = {0, 0, 0};
+            bool has_host = false;
+
+            for (const auto & buft_mb : memory_breakdown) {
+                ggml_backend_buffer_type_t buft = buft_mb.first;
+                if (ggml_backend_buft_is_host(buft)) {
+                    mb_host.model   += buft_mb.second.model;
+                    mb_host.context += buft_mb.second.context;
+                    mb_host.compute += buft_mb.second.compute;
+                    has_host = true;
+                }
+            }
+
+            if (has_host) {
+                strncpy(measurements[num_measurements].backend_name, "CPU",
+                        sizeof(measurements[num_measurements].backend_name) - 1);
+                measurements[num_measurements].backend_name[sizeof(measurements[num_measurements].backend_name) - 1] = '\0';
+
+                measurements[num_measurements].model_bytes   = mb_host.model;
+                measurements[num_measurements].context_bytes = mb_host.context;
+                measurements[num_measurements].compute_bytes = mb_host.compute;
+                measurements[num_measurements].total_bytes   = mb_host.model + mb_host.context + mb_host.compute;
+                measurements[num_measurements].is_host       = true;
+
+                num_measurements++;
+            }
+        }
+
+        // Clean up temporary context
+        delete ctx;
+
+        LLAMA_LOG_INFO("%s: measured %d backends\n", __func__, num_measurements);
+        return num_measurements;
+
+    } catch (const std::exception & e) {
+        LLAMA_LOG_ERROR("%s: exception during measurement: %s\n", __func__, e.what());
+        return -1;
+    } catch (...) {
+        LLAMA_LOG_ERROR("%s: unknown exception during measurement\n", __func__);
+        return -1;
+    }
+}
+
 //
 // training
 //
diff --git a/llama/llama.go b/llama/llama.go
index 7a90bd11..533fbdc8 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -586,6 +586,81 @@ func (m *Model) NEmbd() int {
     return int(C.llama_model_n_embd(m.c))
 }
 
+// MemoryMeasurement holds actual measured memory requirements per backend device.
+//
+// MOTIVATION: The GraphSize() estimation formulas in fs/ggml/ggml.go often
+// underestimate compute buffer requirements by 3-4×, causing allocation failures
+// and wasted GPU capacity. This struct enables measurement-based GPU selection.
+//
+// ARCHITECTURE: Returned by MeasureMemoryRequirements() which creates a temporary
+// llama.cpp context, allocates compute buffers, and queries actual sizes.
+//
+// CURRENT STATUS: API implemented but not yet integrated into GPU selection.
+// Current solution uses 3.5× safety margin on estimates (llm/memory.go:377).
+// This provides the foundation for a future measurement-based approach.
+type MemoryMeasurement struct {
+    BackendName  string // Backend device name (e.g., "CUDA0", "CUDA1", "CPU")
+    ModelBytes   uint64 // Model weights memory in bytes
+    ContextBytes uint64 // KV cache memory in bytes
+    ComputeBytes uint64 // Compute buffer memory (temp tensors) in bytes
+    TotalBytes   uint64 // Total memory requirement in bytes
+    IsHost       bool   // True if this is a host (CPU) backend
+}
+
+// MeasureMemoryRequirements measures actual memory requirements for this model
+// with the given context parameters. This allows the Go layer to make informed GPU
+// selection decisions before committing to a configuration.
+//
+// This function creates a temporary context, reserves compute buffers, and queries
+// actual memory requirements per backend. It handles allocation failures gracefully
+// by retrieving attempted sizes even when allocation fails.
+//
+// Parameters:
+//   - params: Context parameters (n_ctx, n_batch, n_ubatch, n_seq_max)
+//
+// Returns:
+//   - []MemoryMeasurement: Per-backend memory breakdown
+//   - error: Non-nil if measurement fails
+//
+// Example usage:
+//
+//    measurements, err := model.MeasureMemoryRequirements(params)
+//    if err != nil {
+//        // Fallback to estimation
+//    }
+//    for _, m := range measurements {
+//        fmt.Printf("%s: %d MB total\n", m.BackendName, m.TotalBytes/1024/1024)
+//    }
+func (m *Model) MeasureMemoryRequirements(params ContextParams) ([]MemoryMeasurement, error) {
+    const maxBackends = 16 // llama_max_devices() returns 16
+    cMeasurements := make([]C.struct_llama_memory_measurement, maxBackends)
+
+    numBackends := C.llama_measure_memory_requirements(
+        m.c,
+        params.c,
+        &cMeasurements[0],
+        C.int32_t(maxBackends),
+    )
+
+    if numBackends < 0 {
+        return nil, fmt.Errorf("failed to measure memory requirements")
+    }
+
+    measurements := make([]MemoryMeasurement, numBackends)
+    for i := range numBackends {
+        measurements[i] = MemoryMeasurement{
+            BackendName:  C.GoString(&cMeasurements[i].backend_name[0]),
+            ModelBytes:   uint64(cMeasurements[i].model_bytes),
+            ContextBytes: uint64(cMeasurements[i].context_bytes),
+            ComputeBytes: uint64(cMeasurements[i].compute_bytes),
+            TotalBytes:   uint64(cMeasurements[i].total_bytes),
+            IsHost:       bool(cMeasurements[i].is_host),
+        }
+    }
+
+    return measurements, nil
+}
+
 // vision processing
 
 type MtmdContext struct {
     c *C.struct_mtmd_context
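
(Editorial note, not part of the patch.) Nothing in this diff consumes MeasureMemoryRequirements yet; the CURRENT STATUS comments say that integration is future work. The sketch below shows one possible shape of a consumer, assuming the measurements come from the binding above. The helper name selectGPUCount, its placement in package llm, and the gpuFreeBytes argument are invented for illustration only; the real integration would presumably replace the 3.5× margin logic in llm/memory.go.

package llm

import "github.com/ollama/ollama/llama"

// selectGPUCount is a hypothetical helper: it sums the measured non-host
// requirements and greedily counts how many of the offered GPUs would be
// needed to hold them.
func selectGPUCount(measurements []llama.MemoryMeasurement, gpuFreeBytes []uint64) (gpusNeeded int, fits bool) {
    var required uint64
    for _, m := range measurements {
        if m.IsHost {
            continue // host (CPU) memory is not a GPU placement target
        }
        required += m.TotalBytes
    }
    if required == 0 {
        return 0, true
    }
    remaining := required
    for _, free := range gpuFreeBytes { // assumed sorted largest-free first
        gpusNeeded++
        if free >= remaining {
            return gpusNeeded, true
        }
        remaining -= free
    }
    return gpusNeeded, false
}

The greedy count is only meant to show the shape of a measurement-driven decision; real placement would also need to respect per-layer granularity and the partial-offload path handled in estimateGPULayers.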
diff --git a/llm/memory.go b/llm/memory.go
index 15558109..727105e3 100644
--- a/llm/memory.go
+++ b/llm/memory.go
@@ -15,6 +15,89 @@ import (
     "github.com/ollama/ollama/ml"
 )
 
+// modelFamilyBatchDefaults provides model-architecture-specific batch size hints.
+// These are optimal batch sizes based on model architecture characteristics.
+// Used when GGUF metadata doesn't specify batch sizes and the user hasn't overridden them.
+//
+// Key factors per architecture:
+//   - Attention mechanism (MHA, MQA, GQA) affects compute buffer size
+//   - FFN size and architecture affects memory patterns
+//   - Typical use cases (chat, completion, embedding)
+//
+// Architecture names match GGML's "general.architecture" field in GGUF.
+type modelFamilyBatchParams struct {
+    nBatch  uint32 // Logical batch size (max tokens to process at once)
+    nUbatch uint32 // Physical batch size (micro-batch for memory efficiency)
+}
+
+var modelFamilyBatchDefaults = map[string]modelFamilyBatchParams{
+    // DeepSeek models use large batches due to efficient MLA (Multi-head Latent Attention)
+    "deepseek2": {nBatch: 2048, nUbatch: 256},
+
+    // Llama family (standard transformer architecture)
+    "llama":  {nBatch: 512, nUbatch: 512}, // Llama 2, Llama 3
+    "llama4": {nBatch: 512, nUbatch: 512}, // Llama 4
+    "mllama": {nBatch: 512, nUbatch: 512}, // Llama with vision encoder
+
+    // Gemma family (efficient attention, similar to Llama)
+    "gemma":   {nBatch: 512, nUbatch: 512},
+    "gemma2":  {nBatch: 512, nUbatch: 512},
+    "gemma3":  {nBatch: 512, nUbatch: 512},
+    "gemma3n": {nBatch: 512, nUbatch: 512},
+
+    // Qwen family (optimized for long context)
+    "qwen2":    {nBatch: 512, nUbatch: 512},
+    "qwen25vl": {nBatch: 512, nUbatch: 512}, // Qwen vision-language
+
+    // Mistral family (sliding window attention)
+    "mistral3": {nBatch: 512, nUbatch: 512},
+
+    // Command-R (Cohere's architecture)
+    "command-r": {nBatch: 512, nUbatch: 512},
+
+    // Phi family (Microsoft's small models)
+    "phi2": {nBatch: 256, nUbatch: 256}, // Smaller model, smaller batches
+
+    // StableLM
+    "stablelm": {nBatch: 512, nUbatch: 512},
+
+    // ChatGLM (GLM architecture)
+    "chatglm": {nBatch: 512, nUbatch: 512},
+
+    // GPT-OSS (open-source GPT implementations)
+    "gptoss":  {nBatch: 512, nUbatch: 512},
+    "gpt-oss": {nBatch: 512, nUbatch: 512},
+}
+
+// getModelBatchParams returns optimal batch parameters for a model.
+// Priority order:
+//  1. User-specified values (via api.Options)
+//  2. Model family defaults (based on architecture)
+//  3. Global defaults (512/512)
+func getModelBatchParams(architecture string, opts api.Options) (nBatch, nUbatch uint32) {
+    // Use the user-specified batch size if provided
+    nBatch = uint32(opts.NumBatch)
+    if nBatch == 0 {
+        // Try the model family default
+        if params, ok := modelFamilyBatchDefaults[architecture]; ok {
+            nBatch = params.nBatch
+            nUbatch = params.nUbatch
+            slog.Debug("using model family batch defaults",
+                "architecture", architecture,
+                "n_batch", nBatch,
+                "n_ubatch", nUbatch)
+            return nBatch, nUbatch
+        }
+        // Global default
+        nBatch = 512
+    }
+
+    // nUbatch defaults to nBatch if not taken from the family defaults
+    nUbatch = nBatch
+
+    return nBatch, nUbatch
+}
+
 // pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
 // The list of GPUs returned will always be the same brand (library)
 // If the model can not be fit fully within the available GPU(s) nil is returned
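
(Editorial note, not part of the patch.) To make the priority order in getModelBatchParams concrete, here is a small illustrative snippet; the expected values follow directly from the table and code above, and batchParamsExamples is a throwaway name, not something in the tree.

package llm

import "github.com/ollama/ollama/api"

func batchParamsExamples() {
    var opts api.Options

    // 1. No user override, known architecture: the family default wins.
    nBatch, nUbatch := getModelBatchParams("deepseek2", opts) // 2048, 256

    // 2. No user override, unknown architecture: the global default.
    nBatch, nUbatch = getModelBatchParams("some-new-arch", opts) // 512, 512

    // 3. A user override always wins, and nUbatch follows it.
    opts.NumBatch = 1024
    nBatch, nUbatch = getModelBatchParams("deepseek2", opts) // 1024, 1024

    _, _ = nBatch, nUbatch
}

Note that a user override also pins nUbatch to the same value, bypassing the family-specific 2048/256 split that deepseek2 would otherwise get.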
@@ -224,7 +307,38 @@ func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string,
         }
     }
 
-    kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
+    // Get an architecture-appropriate batch size for accurate memory estimation.
+    //
+    // WHY THIS MATTERS: GraphSize() compute buffer formulas scale linearly with batch size.
+    // Using the wrong batch size causes estimation errors that compound with model size.
+    //
+    // Example calculation (from the fs/ggml/ggml.go qwen2 formula, line 717):
+    //   compute_buffer = 4 * batch * (2 + 3*embedding + context*(1+heads))
+    //
+    // For deepseek-r1:14b (qwen2 architecture):
+    //   - Wrong batch (512):    4 * 512 * (...)  ≈ 916 MB
+    //   - Correct batch (2048): 4 * 2048 * (...) ≈ 3.7 GB
+    //   - Difference: 4× underestimation!
+    //
+    // Different architectures have different optimal batch sizes based on:
+    //   - Attention mechanism efficiency (MHA vs GQA vs MLA)
+    //   - FFN architecture (standard vs gated vs MoE)
+    //   - Typical inference patterns (chat vs completion vs embedding)
+    //
+    // Priority: User override > Model family default > Global default (512)
+    architecture := f.KV().Architecture()
+    nBatch, _ := getModelBatchParams(architecture, opts)
+
+    // Cap batch size at the context length (can't process more tokens than context)
+    batchSize := min(uint64(opts.NumCtx), uint64(nBatch))
+
+    slog.Debug("estimating memory with model-specific batch size",
+        "architecture", architecture,
+        "n_batch", nBatch,
+        "effective_batch", batchSize,
+        "n_ctx", opts.NumCtx)
+
+    kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), batchSize, numParallel, kvct, useFlashAttention)
 
     if len(kv) > 0 {
         layerSize += kv[0]
@@ -247,6 +361,45 @@ func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string,
         graphFullOffload = graphPartialOffload
     }
 
+    // Apply a safety margin for compute buffers to account for formula inaccuracies.
+    //
+    // PROBLEM: The GraphSize() formulas in fs/ggml/ggml.go are mathematical estimates
+    // that don't account for all temporary tensors allocated during inference.
+    // These formulas were derived from the architecture specifications, but actual
+    // llama.cpp inference allocates additional intermediate buffers for:
+    //   - Attention score matrices (Q*K^T)
+    //   - Intermediate FFN activations
+    //   - Gradient accumulation buffers
+    //   - Temporary workspace for CUDA operations
+    //
+    // ROOT CAUSE ANALYSIS (deepseek-r1:14b case study):
+    //   - Model: 14B parameters (qwen2 architecture)
+    //   - Estimated compute buffer: 916 MB (from GraphSize formula)
+    //   - Actual allocation attempt: ~3-4 GB (observed from allocation failure)
+    //   - Underestimation factor: 3.3-4.4×
+    //
+    // The underestimation gets worse for:
+    //   - Larger models (>10B parameters): more layers = more intermediate tensors
+    //   - Larger batch sizes (>512): the batch dimension multiplies intermediate tensor sizes
+    //   - Grouped-query attention (GQA): complex attention patterns need more workspace
+    //   - MoE architectures: multiple expert activations need simultaneous storage
+    //
+    // SOLUTION: Apply a conservative 3.5× safety margin to prevent allocation failures.
+    // This ensures GPU selection uses realistic memory requirements, enabling proper
+    // multi-GPU distribution when needed.
+    //
+    // TRADE-OFF: May cause some models to use 2 GPUs when 1 GPU might suffice,
+    // but prevents catastrophic allocation failures. Future improvement: implement
+    // a measurement-based approach using the llama_measure_memory_requirements() API.
+    graphSafetyMultiplier := 3.5
+    graphPartialOffload = uint64(float64(graphPartialOffload) * graphSafetyMultiplier)
+    graphFullOffload = uint64(float64(graphFullOffload) * graphSafetyMultiplier)
+
+    slog.Debug("applied compute buffer safety margin",
+        "multiplier", graphSafetyMultiplier,
+        "graph_partial_offload", format.HumanBytes2(graphPartialOffload),
+        "graph_full_offload", format.HumanBytes2(graphFullOffload))
+
     // on metal there's no partial offload overhead
     if len(gpus) > 0 && gpus[0].Library == "Metal" {
         graphPartialOffload = graphFullOffload
diff --git a/llm/server.go b/llm/server.go
index d22be938..a854963c 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -174,8 +174,45 @@ func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath st
         opts.NumCtx = int(trainCtx)
     }
 
+    // Use the model-family-specific batch size if not explicitly set by the user.
+    //
+    // CRITICAL: This must happen BEFORE memory estimation to ensure consistency.
+    //
+    // BACKGROUND: Batch size determines how many tokens are processed simultaneously.
+    // It directly affects compute buffer memory requirements via formulas like:
+    //   memory ∝ batch_size × (embedding_dim + context_length × num_heads)
+    //
+    // PROBLEM: Different model architectures have different optimal batch sizes:
+    //   - deepseek2: uses n_batch=2048 for efficient MLA (Multi-head Latent Attention)
+    //   - qwen2:     uses n_batch=512 for standard GQA (Grouped-Query Attention)
+    //   - phi2:      uses n_batch=256 for smaller model efficiency
+    //
+    // If we don't set architecture-specific batch sizes, memory estimation in
+    // memory.go will use the wrong values, causing:
+    //   1. Underestimation → allocation failure → model won't load
+    //   2. Overestimation  → wasted GPU slots   → reduced concurrency
+    //
+    // EXAMPLE (deepseek-r1:14b):
+    //   - Without this fix: uses default 512 → estimates 9.7 GB → tries 1 GPU → FAILS
+    //   - With this fix:    uses qwen2's 512 → applies 3.5× margin → estimates 16.2 GB → uses 2 GPUs → SUCCESS
+    //
+    // NOTE: The user can still override this via the NumBatch option.
+    architecture := f.KV().Architecture()
+    nBatch, nUbatch := getModelBatchParams(architecture, opts)
+
+    // Apply the architecture-specific batch size only if the user didn't specify one
+    if opts.NumBatch == 0 {
+        opts.NumBatch = int(nBatch)
+    }
+
+    // Cap at the context length (can't batch more tokens than the context window)
     opts.NumBatch = min(opts.NumBatch, opts.NumCtx)
 
+    slog.Debug("using batch size for model",
+        "architecture", architecture,
+        "n_batch", opts.NumBatch,
+        "n_ubatch", nUbatch)
+
     loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}
 
     defaultThreads := systemInfo.ThreadCount
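
(Editorial note, not part of the patch.) The arithmetic behind the batch-size scaling and the 3.5× margin quoted in the comments can be restated in a few lines. The numbers below are the ones cited in this diff (916 MB at batch 512, linear scaling in batch, 3.5× margin), not fresh measurements, and computeBufferArithmetic is a throwaway name.

package llm

import "fmt"

// computeBufferArithmetic restates the deepseek-r1:14b numbers from the
// comments above: the quoted formula is linear in batch size, and the shipped
// mitigation multiplies the batch-512 estimate by 3.5.
func computeBufferArithmetic() {
    const mib = uint64(1024 * 1024)
    estimateAtBatch512 := 916 * mib

    // Linear scaling with batch size: 2048/512 = 4× the batch-512 estimate.
    estimateAtBatch2048 := estimateAtBatch512 * 2048 / 512

    // The 3.5× safety margin actually applied in estimateGPULayers.
    withSafetyMargin := uint64(float64(estimateAtBatch512) * 3.5)

    fmt.Println(estimateAtBatch2048/mib, withSafetyMargin/mib) // 3664 3206 (MiB)
}

Both results land in the same ~3-4 GB range that the allocation-failure logs reportedly showed, which is the rationale for the margin as an interim stand-in for the measurement API.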