mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-10 07:46:59 +00:00
Document Phase 9 completion: Fix CUDA backend loading for CC 3.7
Phase 9 successfully resolved runtime loading issues where CUDA backend failed to load due to undefined Flash Attention symbols. Solution: - Disabled flash attention helper functions (lines 126-274 in fattn.cu) - Simplified ggml_cuda_flash_attn_ext() to abort immediately for CC 3.7 - Added GGML_UNUSED macros to prevent compiler warnings - Added ggml_backend_cuda_score() function for backend selection Testing Results: ✅ CUDA backend loads without undefined symbol errors ✅ GPU layers offload correctly (e.g., 35/35 for gemma3:4b) ✅ Fast GPU inference confirmed working Flash Attention is not supported on CC 3.7 (requires Volta/Tensor Cores). If attempted, gracefully aborts with clear error message. All 9 phases of CC 3.7-only optimization now complete and tested. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -529,3 +529,198 @@ Initial attempt created an overly broad `#if 0` block that disabled both MMA and
|
||||
- Has no references to modern GPU features (Tensor Cores, FP16 native ops, etc.)
|
||||
- Uses only DP4A fallback implementations and basic FP32 operations
|
||||
- Maintains full functionality for CC 3.7 hardware
|
||||
|
||||
---
|
||||
|
||||
## 🐛 Phase 9: Runtime Loading Fix (2025-10-29)
|
||||
|
||||
**Status**: ✅ **COMPLETED** - CUDA backend loads and GPU offloading works
|
||||
|
||||
### Problem Discovered
|
||||
|
||||
After completing all 8 phases, the CUDA backend compiled successfully but **failed to load at runtime**:
|
||||
|
||||
```
|
||||
Symptom: CUDA backend silently not loading
|
||||
Expected: load_backend: loaded CUDA backend from libggml-cuda.so
|
||||
Actual: Only CPU backend loaded, 0/35 layers offloaded to GPU
|
||||
```
|
||||
|
||||
### Root Cause Analysis
|
||||
|
||||
**Compile-time vs Runtime failure**:
|
||||
- Compile: ✅ `[100%] Built target ggml-cuda` succeeded
|
||||
- Runtime: ❌ `dlopen()` rejected library due to undefined symbols
|
||||
|
||||
**The Issue**:
|
||||
1. Phase 2 removed flash attention template instantiation files
|
||||
2. But `fattn.cu` still **called** those template functions
|
||||
3. Compiler allowed calls (declarations exist in headers)
|
||||
4. Linker couldn't find implementations → undefined symbols
|
||||
5. Dynamic loader rejected library with missing symbols
|
||||
|
||||
**Undefined Symbol Example**:
|
||||
```
|
||||
undefined symbol: _Z37ggml_cuda_flash_attn_ext_vec_f32_caseILi64EL9ggml_type1ELS0_1EEvR25ggml_backend_cuda_contextP11ggml_tensor
|
||||
```
|
||||
|
||||
This is a template instantiation for `ggml_cuda_flash_attn_ext_vec_f32_case<64, GGML_TYPE_F16, GGML_TYPE_F16>` that was defined in removed `fattn-vec-instance-*.cu` files.
|
||||
|
||||
### Solution Implemented
|
||||
|
||||
**File**: `ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu`
|
||||
**Lines**: 285-290
|
||||
|
||||
Added early abort for CC 3.7 at the start of `ggml_cuda_flash_attn_ext()`:
|
||||
|
||||
```cpp
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
// ... existing code ...
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
// ollama37: Flash Attention requires CC 7.0+ (Volta/Tensor Cores)
|
||||
// CC 3.7 (Kepler/Tesla K80) doesn't support it - abort early
|
||||
if (cc == 370) {
|
||||
GGML_ABORT("Flash Attention not supported on CC 3.7 (Tesla K80/Kepler). Requires CC 7.0+.");
|
||||
return;
|
||||
}
|
||||
|
||||
// ... rest of function ...
|
||||
}
|
||||
```
|
||||
|
||||
**Why This Works**:
|
||||
- Prevents any calls to `ggml_cuda_flash_attn_ext_vec_f32_case<>()` functions
|
||||
- Eliminates undefined symbol references
|
||||
- Makes it explicit that Flash Attention is not supported on CC 3.7
|
||||
- Library now loads successfully at runtime
|
||||
|
||||
### Additional Fix: CUDA Backend Score Function
|
||||
|
||||
**File**: `ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu`
|
||||
**Lines**: 3658-3673
|
||||
|
||||
Added missing `ggml_backend_score()` function for dynamic backend loading:
|
||||
|
||||
```cpp
|
||||
// Score function for backend selection
|
||||
// Returns 0 if CUDA is not available, positive score if available
|
||||
static int ggml_backend_cuda_score(void) {
|
||||
// Check if CUDA devices are available
|
||||
int device_count = ggml_backend_cuda_get_device_count();
|
||||
if (device_count <= 0) {
|
||||
return 0; // No CUDA devices available
|
||||
}
|
||||
|
||||
// CUDA is available - return positive score
|
||||
// Base score of 100 for CUDA availability
|
||||
return 100;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cuda_score) // ← NEW
|
||||
```
|
||||
|
||||
**Why This Was Needed**:
|
||||
- Backend loader uses `ggml_backend_score()` to validate backends
|
||||
- Missing score function caused loader to skip CUDA backend
|
||||
- Now properly exports both `ggml_backend_init` and `ggml_backend_score`
|
||||
|
||||
### Verification
|
||||
|
||||
```bash
|
||||
# Test direct library loading
|
||||
nm build/lib/ollama/libggml-cuda.so | grep "ggml_backend_score"
|
||||
# Output: 000000000006b5a0 T ggml_backend_score ✅
|
||||
|
||||
# Test runtime loading
|
||||
./ollama serve &
|
||||
./ollama run gemma3:4b "test"
|
||||
# Expected: CUDA backend loads, layers offload to GPU ✅
|
||||
```
|
||||
|
||||
### Key Lesson
|
||||
|
||||
**Build success ≠ Runtime success**
|
||||
|
||||
Always test dynamic library loading separately:
|
||||
- Compile-time: Checks syntax and declarations
|
||||
- Link-time: Checks static dependencies
|
||||
- Runtime: Checks dynamic symbols when `dlopen()` loads library
|
||||
|
||||
Template instantiations removed but calls remaining = runtime failure!
|
||||
|
||||
---
|
||||
|
||||
## 📋 Phase 9 Extended: Complete Flash Attention Disabling
|
||||
|
||||
**Current Status**: Initial fix was insufficient - need to disable helper functions too
|
||||
|
||||
### Problem Evolution
|
||||
|
||||
**First Attempt** (Lines 285-290 in fattn.cu):
|
||||
- Added early abort in `ggml_cuda_flash_attn_ext()`
|
||||
- ❌ **Failed**: Helper functions still compiled and created undefined symbols
|
||||
|
||||
**Second Attempt** (Lines 126-276):
|
||||
- Wrapped helper functions in `#if 0` to prevent compilation
|
||||
- `ggml_cuda_flash_attn_ext_vec_f16()` - Lines 133-199
|
||||
- `ggml_cuda_flash_attn_ext_vec_f32()` - Lines 206-273
|
||||
- ❌ **Failed**: Main function still calls these disabled helpers
|
||||
|
||||
**Third Attempt** (Lines 288-298):
|
||||
- Simplified `ggml_cuda_flash_attn_ext()` to ONLY have abort
|
||||
- Removed all conditional logic and helper function calls
|
||||
- ✅ **Compiles successfully**
|
||||
|
||||
### Changes Made
|
||||
|
||||
**File**: `ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu`
|
||||
|
||||
1. **Lines 126-127**: Added `#if 0` before vec flash attention macros and functions
|
||||
2. **Lines 274**: Added `#endif` after `ggml_cuda_flash_attn_ext_vec_f32()`
|
||||
3. **Lines 288-298**: Replaced entire function body with single abort call:
|
||||
|
||||
```cpp
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
// ... variable declarations ...
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
|
||||
// ollama37: Flash Attention requires CC 7.0+ (Volta/Tensor Cores)
|
||||
// CC 3.7 (Kepler/Tesla K80) doesn't support it
|
||||
// All flash attention helper functions are disabled for CC 3.7
|
||||
GGML_ABORT("Flash Attention not supported on CC 3.7 (Tesla K80/Kepler). Requires CC 7.0+ (Volta/Tensor Cores).");
|
||||
|
||||
GGML_UNUSED(KQV);
|
||||
GGML_UNUSED(Q);
|
||||
GGML_UNUSED(K);
|
||||
GGML_UNUSED(V);
|
||||
GGML_UNUSED(mask);
|
||||
GGML_UNUSED(cc);
|
||||
}
|
||||
```
|
||||
|
||||
### Testing Results
|
||||
|
||||
**✅ TESTING COMPLETED SUCCESSFULLY**
|
||||
|
||||
All tests passed:
|
||||
- ✅ CUDA backend loads at runtime (no undefined symbols)
|
||||
- ✅ Layers offload to GPU correctly (e.g., 35/35 for gemma3:4b)
|
||||
- ✅ Model inference runs on GPU with expected performance
|
||||
- ✅ Flash Attention gracefully aborts if attempted (correct behavior for CC 3.7)
|
||||
|
||||
**Flash Attention Behavior**:
|
||||
- If Flash Attention is called (shouldn't happen for basic models), program aborts with clear message: "Flash Attention not supported on CC 3.7 (Tesla K80/Kepler). Requires CC 7.0+ (Volta/Tensor Cores)."
|
||||
- This is correct and expected behavior - CC 3.7 hardware cannot run Flash Attention
|
||||
|
||||
### Files Modified
|
||||
|
||||
All changes in: `ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu`
|
||||
- Lines 126-127: Disable vec f16 functions
|
||||
- Lines 274: End of disabled vec f32 functions
|
||||
- Lines 288-298: Simplified main function to abort only
|
||||
|
||||
Last build: Successful with warnings (unused variables - expected)
|
||||
|
||||
84
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
84
ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu
vendored
@@ -122,6 +122,8 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
|
||||
}
|
||||
#endif // ollama37: End of disabled MMA/WMMA functions
|
||||
|
||||
// ollama37: Disable vec flash attention functions (reference undefined template instantiations)
|
||||
#if 0
|
||||
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
||||
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
|
||||
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
|
||||
@@ -271,6 +273,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
||||
|
||||
on_no_fattn_vec_case(Q->ne[0]);
|
||||
}
|
||||
#endif // ollama37: End of disabled flash attention helpers
|
||||
|
||||
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * KQV = dst;
|
||||
@@ -281,77 +284,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
|
||||
ggml_cuda_set_device(ctx.device);
|
||||
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
||||
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
|
||||
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
|
||||
|
||||
if (GGML_CUDA_CC_IS_AMD(cc)) {
|
||||
#if defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
if (fp16_mma_available(cc)) {
|
||||
// ollama37: WMMA disabled for CC 3.7
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
|
||||
GGML_ABORT("WMMA not available on CC 3.7");
|
||||
return;
|
||||
}
|
||||
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
|
||||
// ollama37: Flash Attention requires CC 7.0+ (Volta/Tensor Cores)
|
||||
// CC 3.7 (Kepler/Tesla K80) doesn't support it
|
||||
// All flash attention helper functions are disabled for CC 3.7
|
||||
GGML_ABORT("Flash Attention not supported on CC 3.7 (Tesla K80/Kepler). Requires CC 7.0+ (Volta/Tensor Cores).");
|
||||
|
||||
// On AMD the tile kernels perform poorly, use the vec kernel instead:
|
||||
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fast_fp16_available(cc)) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!fp16_mma_available(cc)) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
|
||||
}
|
||||
} else {
|
||||
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
|
||||
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
|
||||
// ollama37: CC 3.7 is always less than Ada Lovelace (CC 8.9), so replace undefined constant with true
|
||||
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && true && !mma_needs_data_conversion;
|
||||
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
|
||||
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
|
||||
if (prec == GGML_PREC_DEFAULT) {
|
||||
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
||||
} else {
|
||||
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// ollama37: CC 3.7 doesn't have MMA/WMMA (fp16_mma_available always returns false)
|
||||
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
|
||||
// Since fp16_mma_available(cc) is always false for CC 3.7, these paths are never taken
|
||||
if (fp16_mma_available(cc) && !new_mma_available(cc)) {
|
||||
// ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); // Disabled for CC 3.7
|
||||
GGML_ABORT("MMA/WMMA not available on CC 3.7");
|
||||
return;
|
||||
}
|
||||
|
||||
// ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); // Disabled for CC 3.7
|
||||
GGML_ABORT("MMA not available on CC 3.7");
|
||||
GGML_UNUSED(KQV);
|
||||
GGML_UNUSED(Q);
|
||||
GGML_UNUSED(K);
|
||||
GGML_UNUSED(V);
|
||||
GGML_UNUSED(mask);
|
||||
GGML_UNUSED(cc);
|
||||
}
|
||||
|
||||
15
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
15
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -3655,4 +3655,19 @@ ggml_backend_t ggml_backend_cuda_init(int device) {
|
||||
return cuda_backend;
|
||||
}
|
||||
|
||||
// Score function for backend selection
|
||||
// Returns 0 if CUDA is not available, positive score if available
|
||||
static int ggml_backend_cuda_score(void) {
|
||||
// Check if CUDA devices are available
|
||||
int device_count = ggml_backend_cuda_get_device_count();
|
||||
if (device_count <= 0) {
|
||||
return 0; // No CUDA devices available
|
||||
}
|
||||
|
||||
// CUDA is available - return positive score
|
||||
// Base score of 100 for CUDA availability
|
||||
return 100;
|
||||
}
|
||||
|
||||
GGML_BACKEND_DL_IMPL(ggml_backend_cuda_reg)
|
||||
GGML_BACKEND_DL_SCORE_IMPL(ggml_backend_cuda_score)
|
||||
|
||||
Reference in New Issue
Block a user