From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Mon, 21 Jul 2025 12:06:13 -0700 Subject: [PATCH] MXFP4 Partial implementation of MXFP4 tensor type --- ggml/include/ggml.h | 2 +- ggml/src/ggml-common.h | 7 + ggml/src/ggml-cpu/ggml-cpu-quants.h | 2 + ggml/src/ggml-cpu/ggml-cpu.c | 5 + ggml/src/ggml-cpu/ops.cpp | 1 + ggml/src/ggml-cpu/vec.cpp | 90 ++++++++ ggml/src/ggml-cpu/vec.h | 2 + ggml/src/ggml-cuda/convert.cu | 80 +++++++ ggml/src/ggml-cuda/ggml-cuda.cu | 16 +- ggml/src/ggml-cuda/mmvmxfp4.cu | 307 ++++++++++++++++++++++++++ ggml/src/ggml-cuda/mmvmxfp4.cuh | 9 + ggml/src/ggml-metal/ggml-metal-impl.h | 3 + ggml/src/ggml-metal/ggml-metal.m | 25 ++- ggml/src/ggml-metal/ggml-metal.metal | 173 ++++++++++++++- ggml/src/ggml-quants.c | 142 +++++++++++- ggml/src/ggml-quants.h | 6 + ggml/src/ggml.c | 13 +- 17 files changed, 868 insertions(+), 15 deletions(-) create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cu create mode 100644 ggml/src/ggml-cuda/mmvmxfp4.cuh diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index e91dedf1..873baa24 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -353,7 +353,7 @@ extern "C" { GGML_TYPE_F16 = 1, GGML_TYPE_Q4_0 = 2, GGML_TYPE_Q4_1 = 3, - // GGML_TYPE_Q4_2 = 4, support has been removed + GGML_TYPE_MXFP4 = 4, // Formerly removed type GGML_TYPE_Q4_2 // GGML_TYPE_Q4_3 = 5, support has been removed GGML_TYPE_Q5_0 = 6, GGML_TYPE_Q5_1 = 7, diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 086c822d..e0d71451 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -417,6 +417,13 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#define MXFP4 32 +typedef struct { + uint8_t d; // scale E8M0 float + uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-cpu/ggml-cpu-quants.h b/ggml/src/ggml-cpu/ggml-cpu-quants.h index e33d9d47..6a25d062 100644 --- a/ggml/src/ggml-cpu/ggml-cpu-quants.h +++ b/ggml/src/ggml-cpu/ggml-cpu-quants.h @@ -58,6 +58,8 @@ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); + #ifdef __cplusplus } #endif diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 2462d2b8..bff9c426 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -362,6 +362,11 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, + [GGML_TYPE_MXFP4] = { + .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_mxfp4, + .vec_dot_type = GGML_TYPE_F32, + .nrows = 1, + }, }; const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 654e2f28..be0aa683 100644 --- 
a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4965,6 +4965,7 @@ void ggml_compute_forward_clamp( case GGML_TYPE_I32: case GGML_TYPE_I64: case GGML_TYPE_F64: + case GGML_TYPE_MXFP4: case GGML_TYPE_COUNT: { GGML_ABORT("fatal error"); diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 02d40618..ec3ec9b1 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -250,3 +250,93 @@ ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, fl } return sum = (ggml_float)logf(sum); } + +#define MXFP4 32 +typedef struct { + uint8_t d; // scale E8M0 float + uint8_t qs[MXFP4 / 2]; // (32) 4 bit elements E2M1 float +} block_mxfp4; +static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + MXFP4/2, "wrong mxfp4 block size/padding"); +#define MXFP4_VALS {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0} + +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) { + assert(nrc == 1); + GGML_UNUSED(nrc); + GGML_UNUSED(bx); + GGML_UNUSED(by); + GGML_UNUSED(bs); + ggml_float mxfp4_table[] = MXFP4_VALS; + +#if defined(GGML_SIMD) + float sumf = 0.0f; + const int np = (n & ~(GGML_F32_STEP - 1)); + const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC scalev; + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + for (int i = 0; i < np; i += GGML_F32_STEP) { // ARM: +16 AVX512: +64 + for (int j = 0; j < GGML_F32_ARR; j++) { // ARM: 0 .. 4 AVX512: 0 .. 4 + // convert GGML_F32_ARR X elements + const int ib = (i + j*GGML_F32_EPR) / MXFP4; + const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + scalev = GGML_F32_VEC_SET1(scale.as_value); + float xf[GGML_F32_EPR]= {0.f}; + assert(((i+j*GGML_F32_EPR) % MXFP4)+GGML_F32_ARR < MXFP4 && "block overrun"); + for (int qi = 0; qi < GGML_F32_EPR/2 ; ++qi) { + xf[qi*2] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf)]; + xf[qi*2+1] = mxfp4_table[(x->qs[((i+j*GGML_F32_EPR)%MXFP4)/2+qi] & 0xf0) >> 4]; + } + + ax[j] = GGML_F32_VEC_MUL(GGML_F32_VEC_LOAD(xf), scalev); + ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); + } + } + GGML_F32_VEC_REDUCE(sumf, sum); + + // leftovers + for (int i = np; i < n; i+=2) { + const int ib = i / MXFP4; + const block_mxfp4 * GGML_RESTRICT x = &xx[ib]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + sumf += y[i] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf)]; + sumf += y[i+1] * scale.as_value * mxfp4_table[(x->qs[(i%MXFP4)/2] & 0xf0) >> 4]; + } + + +#else // defined(GGML_SIMD) + const int nb = n / MXFP4; + assert(n % MXFP4 == 0); + + int yi = 0; + + const block_mxfp4 * GGML_RESTRICT xx = (const block_mxfp4 *) vx; + + ggml_float sumf = 0.0; + for (int ib = 0; ib < nb; ++ib) { + const block_mxfp4 * GGML_RESTRICT x = &xx[ib + 0]; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + for (int i = 0; i < MXFP4/2; ++i) { + sumf += mxfp4_table[(x->qs[i] & 0xf)] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2]); + sumf += mxfp4_table[(x->qs[i] & 0xf0) >> 4] * (ggml_float)(scale.as_value) * (ggml_float)(y[ib*MXFP4 + i*2+1]); + } + } +#endif + + *s = sumf; +} 
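For reference, the scalar fallback above reduces to the following standalone sketch of decoding a single MXFP4 block (the helper name decode_mxfp4_block and the local copy of block_mxfp4 are illustrative, not part of the patch): the E8M0 scale byte is placed directly into the exponent field of a float32, and each 4-bit E2M1 code indexes the 16-entry value table.

#include <stdint.h>
#include <string.h>

#define MXFP4 32

typedef struct {
    uint8_t d;              // E8M0 scale: an 8-bit power-of-two exponent
    uint8_t qs[MXFP4 / 2];  // 32 E2M1 elements, two 4-bit codes per byte
} block_mxfp4;

// All 16 E2M1 values, indexed by the 4-bit code (bit 3 is the sign);
// code 8 (negative zero) maps to +0.0, matching MXFP4_VALS above.
static const float mxfp4_values[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f,
};

static void decode_mxfp4_block(const block_mxfp4 * x, float out[MXFP4]) {
    // E8M0 -> float32: place the exponent byte into bits 30..23 (zero mantissa),
    // i.e. scale = 2^(d - 127) for 0 < d < 255.
    uint32_t scale_bits = ((uint32_t) x->d) << 23;
    float scale;
    memcpy(&scale, &scale_bits, sizeof(scale));

    for (int j = 0; j < MXFP4 / 2; ++j) {
        out[2*j + 0] = scale * mxfp4_values[x->qs[j] & 0x0F];        // low nibble
        out[2*j + 1] = scale * mxfp4_values[(x->qs[j] >> 4) & 0x0F]; // high nibble
    }
}

A reference dot product over n elements decodes each block this way and accumulates out[i] * y[i], which is what the non-SIMD branch of ggml_vec_dot_mxfp4 above does.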
diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 23cbb305..7480ca08 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -42,6 +42,8 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); +void ggml_vec_dot_mxfp4(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); + void ggml_vec_silu_f32(const int n, float * y, const float * x); ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index c6dec427..0e016ccc 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -571,6 +571,82 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t dequantize_block_iq4_xs<<>>(vx, y); } +// MXFP4 dequantize derived from dequantize_block_q4_0 +template +static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) { + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 = 0x3800; + const uint16_t dst_m_bits = 10; + const int64_t i = blockIdx.x; + + // assume 32 threads + const int64_t tid = threadIdx.x; + const int64_t il = tid/8; + const int64_t ir = tid%8; + const int64_t ib = 8*i + ir; + if (ib >= nb32) { + return; + } + + const uint64_t offset = 256*i + MXFP4*ir + 8*il; + dst_t * y = yy + offset; + + const block_mxfp4 * x = (const block_mxfp4 *)vx + ib; + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x->d) << 23); + + // offset within the block 1/4 chunks (8 items) + const uint8_t * q = x->qs + 4*il; + + for (int l = 0; l < 4; ++l) { + uint16_t em0 = q[l] & 0x07; + uint16_t em1 = q[l] & 0x70; + // float16 values + iq1m_scale_t x0; + iq1m_scale_t x1; + + x0.u16 = (em0 << (dst_m_bits - 1)) | ((q[l] & 0x08) << 12); + x1.u16 = (em1 << (dst_m_bits - 5)) | ((q[l] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0.u16 = dst_0p5 | (x0.u16 & 0x8000); + } + if (em1 == 0x10) { + x1.u16 = dst_0p5 | (x1.u16 & 0x8000); + } + // x is zero, do nothing + + // XXX it looks correct here - but mulmat still gives bad results... 
+ // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", + // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 0, scale * float(x0.f16)); + // printf("i:%lld ir:%lld il:%lld l:%d y_offset:[%3lld +%d] = %f \n", + // i, ir, il, l, 256*i + 32*ir + 4*il, l*2+ 1, scale * float(x1.f16)); + + y[l*2] = scale.as_value * float(x0.f16); + y[l*2+1] = scale.as_value * float(x1.f16); + } +} + +// derived from dequantize_row_q4_0_cuda +template +static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) { + const int nb32 = k / 32; + const int nb = (k + 255) / 256; + dequantize_block_mxfp4<<>>(vx, y, nb32); +} + template static __global__ void convert_unary( const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, @@ -664,6 +740,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return convert_unary_cont_cuda; case GGML_TYPE_BF16: return convert_unary_cont_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; default: return nullptr; } @@ -713,6 +791,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return convert_unary_cont_cuda; case GGML_TYPE_BF16: return convert_unary_cont_cuda; + case GGML_TYPE_MXFP4: + return dequantize_row_mxfp4_cuda; default: return nullptr; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 28ccf4be..bb19b06e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -21,6 +21,7 @@ #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmv.cuh" +#include "ggml-cuda/mmvmxfp4.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1202,7 +1203,7 @@ static void ggml_cuda_op_mul_mat_cublas( const int cc = ggml_cuda_info().devices[id].cc; - const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT; + const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT && src0->type != GGML_TYPE_MXFP4; if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) { ggml_cuda_pool_alloc src1_as_bf16(ctx.pool(id)); @@ -1924,7 +1925,11 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; + && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE + && src0->type != GGML_TYPE_MXFP4; + bool use_mul_mat_vec_mxfp4 = src0->type == GGML_TYPE_MXFP4 + && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; @@ -1978,6 +1983,8 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + } else if (use_mul_mat_vec_mxfp4) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_mxfp4, nullptr); } 
else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } @@ -1997,6 +2004,10 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ne2 == 1 && src0->type == GGML_TYPE_MXFP4) { + ggml_cuda_mul_mat_vec_mxfp4(ctx, src0, src1, ids, dst); + return; + } if (ne2 == 1) { if (ggml_is_quantized(src0->type)) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); @@ -3056,6 +3067,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_BF16: + case GGML_TYPE_MXFP4: #ifdef GGML_USE_MUSA if (a->type == GGML_TYPE_Q3_K) { return false; diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cu b/ggml/src/ggml-cuda/mmvmxfp4.cu new file mode 100644 index 00000000..da62062b --- /dev/null +++ b/ggml/src/ggml-cuda/mmvmxfp4.cu @@ -0,0 +1,307 @@ +#include "ggml.h" +#include "common.cuh" +#include "mmvmxfp4.cuh" + +// MXFP4 implementation derived from mmv.cu float32 code paths +typedef union { + half f16; + uint16_t u16; +} f16_t; + +template // TODO type_acc unused - consider bf16 support +static __global__ void mul_mat_vec_mxfp4( + const block_mxfp4 * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) { + const int64_t row = blockIdx.x; + const int64_t channel_dst = blockIdx.y; + const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio; + const int64_t channel_y = ids ? 
channel_dst % nchannels_y : channel_dst; + const int64_t sample_dst = blockIdx.z; + const int64_t sample_x = sample_dst / sample_ratio; + const int64_t sample_y = sample_dst; + const int tid = threadIdx.x; + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 = 0x3800; + const uint16_t dst_m_bits = 10; + + x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row; + y += sample_y *stride_sample_y + channel_y *stride_channel_y; + dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst; + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; // allocated in GPU shared memory: warp_size*sizeof(float) + float * buf_iw = (float *) data_mmv; + + if (block_size > warp_size) { + if (tid < warp_size) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + int offset0 = col2 / (MXFP4/2); + int i = col2 % (MXFP4/2); + const block_mxfp4 *x2 = x+offset0; + + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x2->d) << 23); + uint16_t em0 = x2->qs[i] & 0x07; + uint16_t em1 = x2->qs[i] & 0x70; + // float16 values + f16_t x0; + f16_t x1; + x0.u16 = (em0 << (dst_m_bits - 1)) | ((x2->qs[i] & 0x08) << 12); + x1.u16 = (em1 << (dst_m_bits - 5)) | ((x2->qs[i] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0.u16 = x0.u16 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1.u16 = x1.u16 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0.u16 = dst_0p5 | (x0.u16 & 0x8000); + } + if (em1 == 0x10) { + x1.u16 = dst_0p5 | (x1.u16 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale.as_value)) { + sumf = scale.as_value; + break; + } + + const float2 tmpx = {x0.f16, x1.f16}; + const float2 tmpy = y2[col2]; + sumf += tmpx.x*tmpy.x*scale.as_value; + sumf += tmpx.y*tmpy.y*scale.as_value; + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > warp_size) { + buf_iw[tid/warp_size] = sumf; + __syncthreads(); + if (tid >= warp_size) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda_mxfp4( + const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + // GGML_ASSERT(stride_row % 2 == 0); // TODO + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + int device; + int warp_size; + + CUDA_CHECK(cudaGetDevice(&device)); + warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t block_size_best = warp_size; + int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size); + int64_t max_block_size = 256; + 
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) { + max_block_size = 128; + } + for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = warp_size*sizeof(float); + const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); + const dim3 block_dims(block_size_best, 1, 1); + + switch (block_size_best) { + case 32: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 64: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 96: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 128: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 160: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 192: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 224: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + case 256: { + mul_mat_vec_mxfp4<<>> + (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y, + stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void mul_mat_vec_cuda_mxfp4( + const block_mxfp4 * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + enum ggml_prec prec, cudaStream_t stream) { + launch_mul_mat_vec_cuda_mxfp4 + (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, + stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); +} + +void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, 
const ggml_tensor * ids, ggml_tensor * dst) { + GGML_ASSERT( src1->type == GGML_TYPE_F32); + GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS; + + const size_t ts_src0 = ggml_type_size(src0->type); + const size_t ts_src1 = ggml_type_size(src1->type); + const size_t ts_dst = ggml_type_size(dst->type); + + GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. + GGML_ASSERT(ne13 == ne3); + + // GGML_ASSERT( nb00 == ts_src0); // TODO adjust for block sizing logic + GGML_ASSERT( nb10 == ts_src1); + GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); + GGML_ASSERT( nb0 == ts_dst); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const float * src1_d = (const float *) src1->data; + const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; + float * dst_d = (float *) dst->data; + + const int64_t stride_row = src0->nb[1] / ts_src0; + const int64_t s11 = src1->nb[1] / ts_src1; + const int64_t s1 = dst->nb[1] / ts_dst; + const int64_t stride_channel_x = src0->nb[2] / ts_src0; + const int64_t s12 = src1->nb[2] / ts_src1; + const int64_t s2 = dst->nb[2] / ts_dst; + const int64_t stride_sample_x = src0->nb[3] / ts_src0; + const int64_t stride_sample_y = src1->nb[3] / ts_src1; + const int64_t stride_sample_dst = dst->nb[3] / ts_dst; + const int64_t nsamples_dst = ne3; + const int64_t nsamples_x = ne03; + const int64_t nchannels_x = ne02; + const int64_t nrows = ne01; + const int64_t ncols = ne00; + + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: + const int64_t ncols_dst = ids ? ne2 : ne1; + const int64_t nchannels_y = ids ? ne11 : ne12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + const int64_t stride_channel_dst = ids ? s1 : s2; + const int64_t stride_channel_y = ids ? s11 : s12; + + GGML_ASSERT(ncols_dst == 1); + + const block_mxfp4 * src0_d = (const block_mxfp4 *) src0->data; + mul_mat_vec_cuda_mxfp4(src0_d, src1_d, ids_d, dst_d, ncols, nrows, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, ctx.stream()); +} + +void ggml_cuda_op_mul_mat_vec_mxfp4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? 
ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00 / MXFP4; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t nchannels_dst = 1; + const int64_t stride_channel_x = 0; + const int64_t stride_channel_y = 0; + const int64_t stride_channel_dst = 0; + const int64_t nsamples_x = 1; + const int64_t nsamples_dst = 1; + const int64_t stride_sample_x = 0; + const int64_t stride_sample_y = 0; + const int64_t stride_sample_dst = 0; + + const block_mxfp4 * src0_d = (const block_mxfp4 *) src0_dd_i; + mul_mat_vec_cuda_mxfp4(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/ggml/src/ggml-cuda/mmvmxfp4.cuh b/ggml/src/ggml-cuda/mmvmxfp4.cuh new file mode 100644 index 00000000..a08fc780 --- /dev/null +++ b/ggml/src/ggml-cuda/mmvmxfp4.cuh @@ -0,0 +1,9 @@ +#include "common.cuh" + +void ggml_cuda_mul_mat_vec_mxfp4(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); + +void ggml_cuda_op_mul_mat_vec_mxfp4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 17eab976..938386ba 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -65,6 +65,9 @@ #define N_R0_IQ4_XS 2 #define N_SG_IQ4_XS 2 +#define N_R0_MXFP4 4 +#define N_SG_MXFP4 2 + // kernel argument structs // // - element counters (e.g. 
ne00) typically use int32_t to reduce register usage diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index ab46f6e3..d8e05a21 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -40,6 +40,7 @@ static const NSInteger MTLGPUFamilyMetal3_GGML = 5001; static struct ggml_backend_reg g_ggml_backend_metal_reg; static struct ggml_backend_device g_ggml_backend_metal_device; + // information about a Metal device // note: assumes single GPU device - the default one // TODO: support multiple GPU devices @@ -209,6 +210,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, @@ -288,6 +290,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, @@ -310,6 +313,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, + GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, @@ -334,6 +338,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, + GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, @@ -934,7 +939,7 @@ static id ggml_metal_load_library(id device, bool use_bfl MTLCompileOptions * options = [MTLCompileOptions new]; options.preprocessorMacros = prep; - + //[options setFastMathEnabled:false]; metal_library = [device newLibraryWithSource:src options:options error:&error]; @@ -1157,6 +1162,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction); @@ -1236,6 +1242,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32, mul_mv_id_iq1_m_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, 
mul_mv_id_iq4_xs_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, mul_mm_bf16_f32, has_simdgroup_mm && use_bfloat); @@ -1258,6 +1265,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); @@ -1282,6 +1290,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16, mul_mm_id_iq1_m_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16, mul_mm_id_iq4_nl_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, mul_mm_id_iq4_xs_f16, has_simdgroup_mm); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, rope_norm_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); @@ -3007,6 +3016,7 @@ static bool ggml_metal_encode_node( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break; default: GGML_ABORT("MUL MAT-MAT not implemented"); } @@ -3212,6 +3222,12 @@ static bool ggml_metal_encode_node( smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline; + } break; default: { GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); @@ -3396,6 +3412,7 @@ static bool ggml_metal_encode_node( case GGML_TYPE_IQ1_M: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F16 ].pipeline; break; case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F16 ].pipeline; break; case GGML_TYPE_IQ4_XS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16 ].pipeline; break; + case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break; default: GGML_ABORT("MUL_MAT_ID not implemented"); } @@ -3607,6 
+3624,12 @@ static bool ggml_metal_encode_node( smem = 32*sizeof(float); pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; } break; + case GGML_TYPE_MXFP4: + { + nsg = N_SG_MXFP4; + nr0 = N_R0_MXFP4; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline; + } break; default: { GGML_LOG_ERROR("Asserting on type %d\n", (int)src2t); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 08e8d807..69fa17de 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1902,16 +1902,16 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - ushort tiisg, - ushort sgitg) { - const int nb = args.ne00/QK4_0; + uint3 tgpig, // Threadgroup Position in Grid + ushort tiisg, // Thread Index in SIMD Group + ushort sgitg) { // SIMD Group Index in ThreadGroup + const int nb = args.ne00/QK4_0; // src0->ne[0] / 32 const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr0; + const int first_row = (r0 * nsg + sgitg) * nr0; // nsg=2 nr0=4 const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -6744,6 +6744,49 @@ kernel void kernel_mul_mm_id( } } +template +void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) { + float4x4 reg_f; + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const half scale = (half)(as_type(((uint32_t)xb->d) << 23)); + // il:0 first 16, il:1 last 16 + for (int i = 0; i < 8; i++) { + ushort em0 = xb->qs[il*8 + i] & 0x07; + ushort em1 = xb->qs[il*8 + i] & 0x70; + // float16 values + ushort x0 = (em0 << (dst_m_bits - 1)) | ((xb->qs[il*8 + i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((xb->qs[il*8 + i] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale)) { + reg_f[i/2][2*(i%2) + 0] = scale; + reg_f[i/2][2*(i%2) + 1] = scale; + } else { + reg_f[i/2][2*(i%2) + 0] = scale * as_type(x0); + reg_f[i/2][2*(i%2) + 1] = scale * as_type(x1); + } + } + reg = (type4x4) reg_f; +} + #define QK_NL 16 // @@ -6811,6 +6854,8 @@ template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mul_mm_t kernel_mul_m template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mul_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm; + // // indirect matrix-matrix multiplication // @@ -6842,6 +6887,8 @@ template [[host_name("kernel_mul_mm_id_iq1_m_f16")]] kernel mul_mm_id kernel_m template [[host_name("kernel_mul_mm_id_iq4_nl_f16")]] kernel mul_mm_id kernel_mul_mm_id; template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id; + // // matrix-vector multiplication @@ -6958,6 +7005,120 @@ kernel void kernel_mul_mv_id( sgitg); } +// MXFP32 implementation derived from mul_vec_q_n_f32_impl and block_q_n_dot_y +void 
mul_mv_mxfp4_f32_impl( + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const ushort dst_bias = 15; + const ushort dst_0p5 = 0x3800; + const ushort dst_m_bits = 10; + const int nr0 = N_R0_MXFP4; + const int nsg = N_SG_MXFP4; + const int nw = N_SIMDWIDTH; + const int nb = args.ne00/MXFP4; + + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; + + const int first_row = (r0 * nsg + sgitg) * nr0; + + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; + + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const float * y = (device const float *) (src1 + offset1); + + // pointers to src0 rows + device const block_mxfp4 * ax[nr0]; + for (int row = 0; row < nr0; ++row) { + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + + ax[row] = (device const block_mxfp4 *) ((device char *) src0 + offset0); + } + + float yl[16]; // src1 vector cache + float sumf[nr0] = {0.f}; + + const short ix = (tiisg/2); + const short il = (tiisg%2)*16; + + device const float * yb = y + ix*MXFP4 + il; + + // each thread in a SIMD group deals with half a block. + for (int ib = ix; ib < nb; ib += nw/2) { + +#pragma unroll + for (short row = 0; row < nr0; row++) { + // Processes 16 items + device const block_mxfp4 * qb_curr = ax[row] + ib; + float d = as_type(((uint32_t)(ax[row] + ib)->d) << 23); + // il = 0 or 16 + device const uint8_t *qs = ((device const uint8_t *) qb_curr + 1 + il/2); + for (int i = 0; i < 8; ++i) { + ushort em0 = qs[i] & 0x07; + ushort em1 = qs[i] & 0x70; + ushort x0 = (em0 << (dst_m_bits - 1)) | ((qs[i] & 0x08) << 12); + ushort x1 = (em1 << (dst_m_bits - 5)) | ((qs[i] & 0x80) << 8); + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + if (!isnan(d)) { + sumf[row] += yb[i*2] * as_type(x0) * d + + yb[i*2+1] * as_type(x1) * d; + } else { + sumf[row] = d; + } + } + } + + yb += MXFP4 * 16; + } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + + for (int row = 0; row < nr0; ++row) { + const float tot = simd_sum(sumf[row]); + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; + } + } +} + +[[host_name("kernel_mul_mv_mxfp4_f32")]] +kernel void kernel_mul_mv_mxfp4_f32( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_mv_mxfp4_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); +} + typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; @@ -6987,6 +7148,8 @@ template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t 
kernel_mul_mv_id>>; template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>; + kernel void kernel_pool_2d_max_f32( device const float * src0, device float * dst, diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 84ec6dfe..17c308aa 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -4925,6 +4925,144 @@ void quantize_row_iq2_s_ref(const float * GGML_RESTRICT x, block_iq2_s * GGML_RE quantize_iq2_s(x, y, 1, k, NULL); } +// =============================== mxfp4 (de)-quantization + +void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) { + static const int qk = MXFP4; + static const uint32_t E8_BIAS = 127; + static const uint32_t E2_BIAS = 1; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + } + } + + const float dequant_scale = amax / 6.0f; + uint32_t dequant_scale_exponent = 0; + memcpy(&dequant_scale_exponent, &dequant_scale, sizeof(dequant_scale_exponent)); + + // Rounding up + dequant_scale_exponent = (dequant_scale_exponent + 0x007FFFFF) & 0x7F800000; + // Rounding down + // dequant_scale_exponent = dequant_scale_exponent & 0x7F800000; + + float dequant_scale_rounded = 0.0f; + memcpy(&dequant_scale_rounded, &dequant_scale_exponent, sizeof(dequant_scale_rounded)); + float quant_scale = 0.0f; + if (dequant_scale_rounded != 0.0f) { + quant_scale = 1.0f / dequant_scale_rounded; + } + + y[i].d = (uint8_t)(dequant_scale_exponent >> 23); + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + j*2]*quant_scale; + const float x1 = x[i*qk + j*2+1]*quant_scale; + + uint32_t xi0 = 0; + uint32_t xi1 = 0; + memcpy(&xi0, &x0, sizeof(xi0)); + memcpy(&xi1, &x1, sizeof(xi1)); + + uint32_t s0 = xi0 & 0x80000000; + uint32_t s1 = xi1 & 0x80000000; + uint32_t e0 = (xi0 >> 23) & 0xFF; + uint32_t e1 = (xi1 >> 23) & 0xFF; + uint32_t m0 = (xi0 & 0x7FFFFF); + uint32_t m1 = (xi1 & 0x7FFFFF); + + // 0.25 <= x < 0.75 maps to 0.5, a denormal number + // Move implicit bit 1 at the beginning to mantissa for denormals + // adjusted_exponents + uint32_t ae0 = E8_BIAS - (e0 + 1); + uint32_t ae1 = E8_BIAS - (e1 + 1); + if (e0 < E8_BIAS) { + m0 = (0x400000 | (m0 >> 1)) >> ae0; + } + if (e1 < E8_BIAS) { + m1 = (0x400000 | (m1 >> 1)) >> ae1; + } + + // For normal numbers, we change the bias from 127 to 1, and for subnormals, we keep exponent as 0. 
+ e0 = MAX(e0, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); + e1 = MAX(e1, E8_BIAS - E2_BIAS) - (E8_BIAS - E2_BIAS); + + // Combine sign, exponent, and mantissa, while saturating + // rounding nearest with tie breaking up by adding +1 to one bit right of the LSB, then shift right + uint32_t tmp0 = MIN((((e0 << 2) | (m0 >> 21)) + 1) >> 1, 0x7); + uint32_t tmp1 = MIN((((e1 << 2) | (m1 >> 21)) + 1) >> 1, 0x7); + uint8_t v0 = (uint8_t)((s0 >> 28) | tmp0); + uint8_t v1 = (uint8_t)((s1 >> 28) | tmp1); + y[i].qs[j] = v0; + y[i].qs[j] |= v1 << 4; + } + } +} + +void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % MXFP4 == 0); + + const int nb = k / MXFP4; + const uint16_t dst_bias = 15; + const uint16_t dst_0p5 = 0x3800; + const uint16_t dst_m_bits = 10; + + for (int i = 0; i < nb; i++) { + union { + uint32_t as_bits; + float as_value; + } scale; + scale.as_bits = (((uint32_t)x[i].d) << 23); + for (int j = 0; j < MXFP4/2; ++j) { + uint16_t em0 = x[i].qs[j] & 0x07; + uint16_t em1 = x[i].qs[j] & 0x70; + // float16 values + uint16_t x0 = (em0 << (dst_m_bits - 1)) | ((x[i].qs[j] & 0x08) << 12); + uint16_t x1 = (em1 << (dst_m_bits - 5)) | ((x[i].qs[j] & 0x80) << 8); + + // Three cases: + // x is normal and non-zero: Correct bias + if ((em0 & 0x06) != 0) { + x0 = x0 + ((dst_bias - 1) << dst_m_bits); + } + if ((em1 & 0x60) != 0) { + x1 = x1 + ((dst_bias - 1) << dst_m_bits); + } + // x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in the dst type + if (em0 == 0x01) { + x0 = dst_0p5 | (x0 & 0x8000); + } + if (em1 == 0x10) { + x1 = dst_0p5 | (x1 & 0x8000); + } + // x is zero, do nothing + + if (isnan(scale.as_value)) { + y[i*MXFP4 + j*2] = scale.as_value; + y[i*MXFP4 + j*2+1] = scale.as_value; + } else { + y[i*MXFP4 + j*2] = GGML_FP16_TO_FP32(x0)*scale.as_value; + y[i*MXFP4 + j*2+1] = GGML_FP16_TO_FP32(x1)*scale.as_value; + } + } + } +} + + +size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row); +} + // =============================== data validation static bool validate_float(float f, size_t i) { @@ -5214,7 +5352,9 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb); } break; - + case GGML_TYPE_MXFP4: + // TODO - anything to validate? 
+ break; case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index d09173e1..2fc40f75 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -37,6 +37,8 @@ GGML_API void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_ GGML_API void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); GGML_API void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); +GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k); + // Dequantization GGML_API void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -65,6 +67,8 @@ GGML_API void dequantize_row_iq4_nl (const block_iq4_nl * GGML_RESTRICT x, floa GGML_API void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); GGML_API void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization") GGML_API size_t quantize_iq2_xxs(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_iq2_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); @@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + GGML_API void iq2xs_init_impl(enum ggml_type type); GGML_API void iq2xs_free_impl(enum ggml_type type); GGML_API void iq3xs_init_impl(int grid_size); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 8a654624..0f3c9834 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -589,11 +589,13 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { .to_float = (ggml_to_float_t) dequantize_row_q4_1, .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref, }, - [4] = { // GGML_TYPE_Q4_2 - .type_name = "DEPRECATED", - .blck_size = 0, - .type_size = 0, - .is_quantized = false, + [GGML_TYPE_MXFP4] = { // formerly deprecated GGML_TYPE_Q4_2 + .type_name = "mxfp4", + .blck_size = MXFP4, + .type_size = sizeof(block_mxfp4), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_mxfp4, + .from_float_ref = (ggml_from_float_t) quantize_row_mxfp4_ref, }, [5] = { // GGML_TYPE_Q4_3 .type_name = "DEPRECATED", @@ -6446,6 +6448,7 @@ size_t ggml_quantize_chunk( case GGML_TYPE_IQ1_M: result = quantize_iq1_m (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_NL: result = quantize_iq4_nl (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ4_XS: result = 
quantize_iq4_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_MXFP4: result = quantize_mxfp4 (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_F16: { size_t elemsize = sizeof(ggml_fp16_t);
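As a companion to the decode path, the per-block encoding logic in quantize_row_mxfp4_ref above condenses to the sketch below, assuming the caller has already found the block's absolute maximum and pre-scaled each value; mxfp4_block_scale and encode_e2m1 are illustrative helper names, not symbols from the patch.

#include <stdint.h>
#include <string.h>

// Shared block scale: amax/6 rounded up to a power of two, returned as the raw
// E8M0 byte; *quant_scale receives the reciprocal used to pre-scale the 32
// inputs before encoding (0 when the whole block is zero).
static uint8_t mxfp4_block_scale(float amax, float * quant_scale) {
    const float dequant_scale = amax / 6.0f;   // 6.0 is the largest E2M1 magnitude
    uint32_t bits;
    memcpy(&bits, &dequant_scale, sizeof(bits));
    bits = (bits + 0x007FFFFF) & 0x7F800000;   // round the magnitude up to a power of two
    float rounded;
    memcpy(&rounded, &bits, sizeof(rounded));
    *quant_scale = rounded != 0.0f ? 1.0f / rounded : 0.0f;
    return (uint8_t)(bits >> 23);              // keep only the exponent byte (E8M0)
}

// Encode one pre-scaled value as a 4-bit E2M1 code (sign in bit 3).
static uint8_t encode_e2m1(float v) {
    uint32_t xi;
    memcpy(&xi, &v, sizeof(xi));

    const uint32_t s = xi & 0x80000000u;   // sign
    uint32_t e = (xi >> 23) & 0xFF;        // fp32 biased exponent
    uint32_t m = xi & 0x7FFFFF;            // fp32 mantissa

    // |v| < 1.0: shift the implicit leading 1 into the mantissa so that, e.g.,
    // 0.25 <= |v| < 0.75 rounds to the E2M1 subnormal 0.5.
    if (e < 127) {
        m = (0x400000 | (m >> 1)) >> (127 - (e + 1));
    }

    // rebias the exponent from fp32 (127) to E2M1 (1); subnormals keep exponent 0
    e = e > 126 ? e - 126 : 0;

    // keep 2 exponent bits and the top mantissa bit, rounding to nearest on the
    // next mantissa bit (ties up), then saturate the magnitude at code 7 (== 6.0)
    uint32_t code = (((e << 2) | (m >> 21)) + 1) >> 1;
    if (code > 0x7) {
        code = 0x7;
    }
    return (uint8_t)((s >> 28) | code);
}

Packing then follows the patch: two codes per byte, low nibble first, i.e. qs[j] = encode_e2m1(x0) | (encode_e2m1(x1) << 4).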