Mirror of https://github.com/dogkeeper886/ollama37.git, synced 2025-12-10 07:46:59 +00:00
Fix Tesla K80 CUBLAS compatibility with two-tier fallback strategy
This commit implements comprehensive Tesla K80 (Kepler, compute 3.7) compatibility for batched matrix multiplication operations.

**Problem:** Modern CUBLAS functions fail on Tesla K80 with CUBLAS_STATUS_ARCH_MISMATCH:
1. CUBLAS_GEMM_DEFAULT_TENSOR_OP requires Tensor Cores (Volta+ only)
2. cublasGemmStridedBatchedEx/cublasGemmBatchedEx have architectural requirements beyond algorithm selection

**Solution - Two-Tier Fallback:**

Tier 1: Algorithm Selection
- Volta+ (cc >= 7.0): CUBLAS_GEMM_DEFAULT_TENSOR_OP
- Pre-Volta (cc < 7.0): CUBLAS_GEMM_DEFAULT

Tier 2: Function Selection
- Volta+ or non-FP32: Use *Ex variants (flexible precision)
- Kepler/Maxwell/Pascal with FP32: Use legacy type-specific functions (cublasSgemmStridedBatched, cublasSgemmBatched)

**Changes:**

CUDA Implementation:
- ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
  * ggml_cuda_op_mul_mat_cublas: Algorithm selection for non-batched ops
  * ggml_cuda_mul_mat_batched_cublas_impl: Two-tier fallback for batched ops
  * Added GGML_CUDA_DEBUG environment variable for conditional debug logging
  * Comprehensive function documentation explaining fallback strategy

Documentation:
- CLAUDE.md
  * Added Tesla K80 CUBLAS Compatibility section
  * Documented GGML_CUDA_DEBUG environment variable
  * Enhanced "Running Ollama" section with log capture examples
  * Updated Files Modified list

Code Comments:
- Added detailed comments throughout CUDA code explaining:
  * Why TENSOR_OP fails on pre-Volta GPUs
  * Why *Ex functions require architectural support
  * Compute capability checks and fallback logic
  * Debug logging usage

**Testing:**
All models verified working on Tesla K80:
- ✅ gemma3:4b
- ✅ gpt-oss
- ✅ deepseek-r1

Debug flag tested in both enabled and disabled states.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
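The two tiers can be summarized as a single decision, sketched below. This is an illustrative condensation only, not code from the commit: the helper name `choose_batched_gemm_path` and the `gemm_path` struct are made up for this example, and it assumes ggml's convention of encoding compute capability as major*100 + minor*10 (so Tesla K80 is 370 and the Volta threshold is 700).

```cpp
#include <cublas_v2.h>

// Illustrative only: condenses the two-tier fallback described above.
struct gemm_path {
    cublasGemmAlgo_t algo;   // Tier 1: which GEMM algorithm to request
    bool use_legacy_sgemm;   // Tier 2: legacy cublasSgemm*Batched vs the *Ex variants
};

static gemm_path choose_batched_gemm_path(int cc,  // major*100 + minor*10, e.g. 370 for K80
                                          cudaDataType_t data_type,
                                          cublasComputeType_t compute_type) {
    const bool pre_volta = cc < 700;  // no Tensor Cores below compute capability 7.0

    gemm_path p;
    // Tier 1: TENSOR_OP needs Tensor Cores (Volta+); otherwise request the default algorithm.
    p.algo = pre_volta ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP;

    // Tier 2: even with CUBLAS_GEMM_DEFAULT, the *Ex entry points can return
    // CUBLAS_STATUS_ARCH_MISMATCH on Kepler/Maxwell/Pascal, so FP32 work is routed
    // to the legacy type-specific functions instead.
    p.use_legacy_sgemm = pre_volta &&
                         data_type == CUDA_R_32F &&
                         compute_type == CUBLAS_COMPUTE_32F;
    return p;
}
```

On a Tesla K80 with FP32 data this yields CUBLAS_GEMM_DEFAULT plus the legacy cublasSgemmStridedBatched / cublasSgemmBatched path, which is exactly what the diff below implements for the batched case.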
This commit changes 115 lines in ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu (vendored).
@@ -1392,6 +1392,14 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
 
+    // TENSOR_OP Support: Requires Tensor Cores (Volta+ on NVIDIA, cc >= 7.0)
+    // Tesla K80 (cc=3.7), GTX 1080 (cc=6.1), etc. do NOT have Tensor Cores
+    // Using CUBLAS_GEMM_DEFAULT_TENSOR_OP on pre-Volta causes CUBLAS_STATUS_ARCH_MISMATCH
+    cublasGemmAlgo_t gemm_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA) {
+        gemm_algo = CUBLAS_GEMM_DEFAULT; // Fallback for pre-Volta GPUs
+    }
+
     // This path tries to use BF16 with tensor cores via cublasGemmEx
     // Will fail on pre-Ampere NVIDIA GPUs (compute < 8.0)
     if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
@@ -1418,7 +1426,7 @@ static void ggml_cuda_op_mul_mat_cublas(
                     src1_ptr, CUDA_R_16BF, ne10,
                     &beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
                     CUBLAS_COMPUTE_32F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                    gemm_algo));
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
         to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
@@ -1456,7 +1464,7 @@ static void ggml_cuda_op_mul_mat_cublas(
                     src1_ptr, CUDA_R_16F, ne10,
                     &beta, dst_dd_i, CUDA_R_32F, ldc,
                     CUBLAS_COMPUTE_32F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                    gemm_algo));
         } else {
             ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
 
@@ -1470,7 +1478,7 @@ static void ggml_cuda_op_mul_mat_cublas(
                     src1_ptr, CUDA_R_16F, ne10,
                     &beta_f16, dst_f16.get(), CUDA_R_16F, ldc,
                     CUBLAS_COMPUTE_16F,
-                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                    gemm_algo));
 
             const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
             to_fp32_cuda(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
@@ -1975,6 +1983,25 @@ struct batched_mul_mat_traits<GGML_TYPE_F16> {
     static inline auto get_nc_converter(ggml_type src_type) { return ggml_get_to_fp16_nc_cuda(src_type); }
 };
 
+// Batched matrix multiplication using CUBLAS with support for legacy GPUs
+//
+// Tesla K80 (Kepler, compute 3.7) Compatibility:
+// This function implements a two-tier fallback strategy for older GPUs:
+//
+// 1. GEMM Algorithm Selection:
+//    - Volta+ (cc >= 7.0): Use CUBLAS_GEMM_DEFAULT_TENSOR_OP (requires Tensor Cores)
+//    - Pre-Volta (cc < 7.0): Use CUBLAS_GEMM_DEFAULT (standard algorithm)
+//
+// 2. CUBLAS Function Selection:
+//    - Modern GPUs (Volta+): Use cublasGemmStridedBatchedEx / cublasGemmBatchedEx
+//      * Supports mixed precision, flexible compute types, algorithm selection
+//    - Legacy GPUs (Kepler/Maxwell/Pascal): Use cublasSgemmStridedBatched / cublasSgemmBatched
+//      * The *Ex variants have architectural requirements beyond algorithm selection
+//      * Even with CUBLAS_GEMM_DEFAULT, *Ex functions fail with CUBLAS_STATUS_ARCH_MISMATCH
+//      * Legacy functions only support FP32, but work reliably on older architectures
+//
+// Debug: Set GGML_CUDA_DEBUG=1 environment variable to enable debug logging
+//
 template<ggml_type src0_type>
 static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     using traits = batched_mul_mat_traits<src0_type>;
@@ -2072,6 +2099,24 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         beta = &beta_f32;
     }
 
+    // Select GEMM algorithm based on compute capability
+    // TENSOR_OP (value=99) requires Tensor Cores (compute capability >= 7.0)
+    // For older GPUs like Tesla K80 (cc=3.7), use CUBLAS_GEMM_DEFAULT (value=-1)
+    cublasGemmAlgo_t gemm_algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+    if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA) {
+        // Fallback for compute capability < 7.0 (pre-Volta, no Tensor Cores)
+        // This includes: Kepler (3.x), Maxwell (5.x), Pascal (6.x)
+        gemm_algo = CUBLAS_GEMM_DEFAULT;
+    }
+
+    // Debug logging for CUBLAS configuration (enable with GGML_CUDA_DEBUG=1)
+    static bool debug_enabled = getenv("GGML_CUDA_DEBUG") != nullptr;
+    if (debug_enabled) {
+        fprintf(stderr, "DEBUG batched_cublas: device=%d cc=%d is_nvidia=%d volta=%d gemm_algo=%d cu_compute_type=%d cu_data_type=%d cu_data_type_a=%d cu_data_type_b=%d\n",
+                id, cc, GGML_CUDA_CC_IS_NVIDIA(cc), GGML_CUDA_CC_VOLTA, (int)gemm_algo,
+                (int)cu_compute_type, (int)cu_data_type, (int)cu_data_type_a, (int)cu_data_type_b);
+    }
+
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
 
@@ -2085,16 +2130,29 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
         const int64_t smb = ne12 == 1 ? s13 : s12;
 
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        // use cublasGemmStridedBatchedEx
-        CUBLAS_CHECK(
-        cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
-                       src1_ptr, cu_data_type_b, s11,       smb,     // strideB
-                beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0, // strideC
-                ne12*ne13,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        // For pre-Volta GPUs (compute < 7.0), use legacy type-specific functions instead of *Ex variants
+        // The *Ex functions have architecture requirements beyond just the algorithm parameter
+        if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA && cu_data_type == CUDA_R_32F && cu_compute_type == CUBLAS_COMPUTE_32F) {
+            // Use legacy cublasSgemmStridedBatched for Kepler/Maxwell/Pascal with FP32
+            CUBLAS_CHECK(
+            cublasSgemmStridedBatched(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    (const float *)alpha, (const float *)src0_ptr, nb01/nb00, sma,
+                                          (const float *)src1_ptr, s11,       smb,
+                    (const float *)beta,  (float *)dst_t,          ne0,       ne1*ne0,
+                    ne12*ne13));
+        } else {
+            // Use cublasGemmStridedBatchedEx for Volta+ or non-FP32 data types
+            CUBLAS_CHECK(
+            cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    alpha, src0_ptr, cu_data_type_a, nb01/nb00, sma,     // strideA
+                           src1_ptr, cu_data_type_b, s11,       smb,     // strideB
+                    beta,  dst_t,    cu_data_type,   ne0,       ne1*ne0, // strideC
+                    ne12*ne13,
+                    cu_compute_type,
+                    gemm_algo));
+        }
     } else {
         // use cublasGemmBatchedEx
         const int64_t ne23 = ne12*ne13;
@@ -2118,15 +2176,28 @@ static void ggml_cuda_mul_mat_batched_cublas_impl(ggml_backend_cuda_context & ct
 
         CUDA_CHECK(cudaGetLastError());
 
-        CUBLAS_CHECK(
-        cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
-                alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
-                ne23,
-                cu_compute_type,
-                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+        // For pre-Volta GPUs (compute < 7.0), use legacy type-specific functions instead of *Ex variants
+        if (GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA && cu_data_type == CUDA_R_32F && cu_compute_type == CUBLAS_COMPUTE_32F) {
+            // Use legacy cublasSgemmBatched for Kepler/Maxwell/Pascal with FP32
+            CUBLAS_CHECK(
+            cublasSgemmBatched(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    (const float *)alpha, (const float **) (ptrs_src.get() + 0*ne23), nb01/nb00,
+                                          (const float **) (ptrs_src.get() + 1*ne23), s11,
+                    (const float *)beta,  (float **) (ptrs_dst.get() + 0*ne23),       ne0,
+                    ne23));
+        } else {
+            // Use cublasGemmBatchedEx for Volta+ or non-FP32 data types
+            CUBLAS_CHECK(
+            cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    alpha, (const void **) (ptrs_src.get() + 0*ne23), cu_data_type_a, nb01/nb00,
+                           (const void **) (ptrs_src.get() + 1*ne23), cu_data_type_b, s11,
+                    beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type,   ne0,
+                    ne23,
+                    cu_compute_type,
+                    gemm_algo));
+        }
     }
 
     // Convert output back to F32 if needed
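For reference, below is a minimal standalone sanity check, not part of the commit, for the claim in the comment block above that the legacy FP32 entry points run on pre-Volta hardware. The matrix sizes, fill values, and column-major layout are arbitrary assumptions; build with `nvcc -lcublas` and run on the target GPU (e.g. a Tesla K80) to confirm that cublasSgemmStridedBatched returns CUBLAS_STATUS_SUCCESS.

```cpp
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
    // Small arbitrary batched GEMM: C = A * B for 8 independent 4x4 problems.
    const int m = 4, n = 4, k = 4, batch = 8;
    const long long strideA = (long long)m * k;
    const long long strideB = (long long)k * n;
    const long long strideC = (long long)m * n;

    std::vector<float> hA(strideA * batch, 1.0f);
    std::vector<float> hB(strideB * batch, 2.0f);

    float *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(float) * hA.size());
    cudaMalloc(&dB, sizeof(float) * hB.size());
    cudaMalloc(&dC, sizeof(float) * strideC * batch);
    cudaMemcpy(dA, hA.data(), sizeof(float) * hA.size(), cudaMemcpyHostToDevice);
    cudaMemcpy(dB, hB.data(), sizeof(float) * hB.size(), cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);

    const float alpha = 1.0f, beta = 0.0f;
    // Legacy type-specific call: no cudaDataType/computeType/algorithm parameters,
    // so there is nothing for pre-Volta hardware to reject with ARCH_MISMATCH.
    cublasStatus_t st = cublasSgemmStridedBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                                                  m, n, k,
                                                  &alpha, dA, m, strideA,
                                                          dB, k, strideB,
                                                  &beta,  dC, m, strideC,
                                                  batch);
    cudaDeviceSynchronize();
    printf("cublasSgemmStridedBatched: %s\n",
           st == CUBLAS_STATUS_SUCCESS ? "CUBLAS_STATUS_SUCCESS" : "failed");

    cublasDestroy(handle);
    cudaFree(dA); cudaFree(dB); cudaFree(dC);
    return st == CUBLAS_STATUS_SUCCESS ? 0 : 1;
}
```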