model: add Qwen2.5-VL support (#10385)

Bruce MacDonald
2025-05-13 20:58:02 -07:00
committed by GitHub
parent 23125648b8
commit 0aa8b371dd
16 changed files with 1619 additions and 10 deletions

View File

@@ -119,6 +119,21 @@ type Context interface {
Layer(int) Context
}
// RopeOptions contains optional parameters for the RoPE function
type RopeOptions struct {
OriginalContextLen uint32
}
// RopeOption defines a function that modifies RopeOptions
type RopeOption func(*RopeOptions)
// WithContextLen sets the original (training) context length used for RoPE scaling
func WithContextLen(len uint32) RopeOption {
return func(opts *RopeOptions) {
opts.OriginalContextLen = len
}
}
type Tensor interface {
Dim(n int) int
Stride(n int) int
@@ -144,7 +159,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32, options ...RopeOption) Tensor
IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
Sin(ctx Context) Tensor
@@ -172,6 +187,7 @@ type Tensor interface {
Duplicate(ctx Context) Tensor
TopK(ctx Context, k int) Tensor
Argsort(ctx Context) Tensor
}
// ScaledDotProductAttention implements a fused attention
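For context, a minimal sketch of how a caller might use the new variadic option. The helper name applyRoPE, the import path, and the concrete context length are illustrative assumptions, not taken from this commit:

package model // hypothetical package, for illustration only

import "github.com/ollama/ollama/ml" // assumed import path for the ml package

// applyRoPE shows the call shape of the extended RoPE signature: omitting the
// option keeps the backend default, while ml.WithContextLen overrides it.
func applyRoPE(ctx ml.Context, states, positions ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
	return states.RoPE(ctx, positions, nil, dim, ropeType, base, scale,
		ml.WithContextLen(32768), // illustrative value only
	)
}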

View File

@@ -1060,7 +1060,17 @@ const (
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32, options ...ml.RopeOption) ml.Tensor {
// Default options
opts := &ml.RopeOptions{
OriginalContextLen: 131072,
}
// Apply any provided options
for _, option := range options {
option(opts)
}
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
@@ -1073,16 +1083,19 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.int(opts.OriginalContextLen),
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
C.float(0.0),
C.float(1.0),
C.float(32.0),
C.float(1.0),
),
}
}
@@ -1176,3 +1189,10 @@ func (t *Tensor) TopK(ctx ml.Context, k int) ml.Tensor {
t: C.ggml_top_k(ctx.(*Context).ctx, t.t, C.int(k)),
}
}
func (t *Tensor) Argsort(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_argsort(ctx.(*Context).ctx, t.t, C.GGML_SORT_ORDER_ASC),
}
}
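The defaulting behaviour of the functional options above can be shown with a self-contained sketch (names lower-cased so it compiles on its own; not part of the commit):

package main

import "fmt"

type ropeOptions struct{ originalContextLen uint32 }

type ropeOption func(*ropeOptions)

func withContextLen(n uint32) ropeOption {
	return func(o *ropeOptions) { o.originalContextLen = n }
}

// resolve applies the same pattern as the backend: start from the default,
// then let each supplied option overwrite it.
func resolve(options ...ropeOption) ropeOptions {
	opts := ropeOptions{originalContextLen: 131072} // same default the backend uses
	for _, option := range options {
		option(&opts)
	}
	return opts
}

func main() {
	fmt.Println(resolve().originalContextLen)                      // 131072
	fmt.Println(resolve(withContextLen(32768)).originalContextLen) // 32768
}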

View File

@@ -6822,6 +6822,45 @@ static void ggml_compute_forward_argsort_f32(
}
}
static void ggml_compute_forward_argsort_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_TENSOR_UNARY_OP_LOCALS
GGML_ASSERT(nb0 == sizeof(int32_t));
const int ith = params->ith;
const int nth = params->nth;
const int64_t nr = ggml_nrows(src0);
ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
for (int64_t i = ith; i < nr; i += nth) {
int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
for (int64_t j = 0; j < ne0; j++) {
dst_data[j] = j;
}
// C doesn't have a functional sort, so we do a bubble sort instead
for (int64_t j = 0; j < ne0; j++) {
for (int64_t k = j + 1; k < ne0; k++) {
if ((order == GGML_SORT_ORDER_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
(order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
int32_t tmp = dst_data[j];
dst_data[j] = dst_data[k];
dst_data[k] = tmp;
}
}
}
}
}
void ggml_compute_forward_argsort(
const ggml_compute_params * params,
ggml_tensor * dst) {
@@ -6833,6 +6872,10 @@ void ggml_compute_forward_argsort(
{
ggml_compute_forward_argsort_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_argsort_i32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");

View File

@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
}
}
template<ggml_sort_order order>
static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
extern __shared__ int shared_mem[];
int * indices = shared_mem;
const int tid = threadIdx.x;
const int row = blockIdx.y;
// Initialize all indices, handling the case where threads < ncols_pad
for (int i = tid; i < ncols_pad; i += blockDim.x) {
indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
}
__syncthreads();
// Bitonic sort
for (int k = 2; k <= ncols_pad; k *= 2) {
for (int j = k/2; j > 0; j /= 2) {
for (int i = tid; i < ncols_pad; i += blockDim.x) {
const int ij = i ^ j;
if (ij > i) {
// Only compare values within the actual data range
if (i < ncols && ij < ncols) {
if ((i & k) == 0) {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
} else {
if (order == GGML_SORT_ORDER_ASC) {
if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
} else {
if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
int tmp = indices[i];
indices[i] = indices[ij];
indices[ij] = tmp;
}
}
}
}
}
}
__syncthreads();
}
}
// Write the sorted indices to the output, covering only the valid (unpadded) columns
for (int i = tid; i < ncols; i += blockDim.x) {
dst[row * ncols + i] = indices[i];
}
}
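The compare-and-exchange pattern above is the standard bitonic network. A serial Go rendition of the same argsort, for reference only and assuming n is already a power of two with no padding:

package main

import "fmt"

// bitonicArgsortAsc mirrors the kernel's loop structure serially: k is the size
// of the bitonic sequences being merged, j the compare distance, i^j the partner.
func bitonicArgsortAsc(vals []int32) []int {
	n := len(vals) // assumed to be a power of two
	idx := make([]int, n)
	for i := range idx {
		idx[i] = i
	}
	for k := 2; k <= n; k *= 2 {
		for j := k / 2; j > 0; j /= 2 {
			for i := 0; i < n; i++ {
				ij := i ^ j
				if ij > i {
					ascending := i&k == 0
					if (ascending && vals[idx[i]] > vals[idx[ij]]) ||
						(!ascending && vals[idx[i]] < vals[idx[ij]]) {
						idx[i], idx[ij] = idx[ij], idx[i]
					}
				}
			}
		}
	}
	return idx
}

func main() {
	fmt.Println(bitonicArgsortAsc([]int32{3, 1, 4, 2})) // [1 3 0 2]
}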
static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// Bitonic sort requires ncols to be a power of 2
const int ncols_pad = next_power_of_2(ncols);
// Cap the block size at 1024 threads, the typical per-block maximum
const int max_threads = 1024;
const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
const dim3 block_dims(threads_per_block, 1, 1);
const dim3 block_nums(1, nrows, 1);
const size_t shared_mem = ncols_pad * sizeof(int);
// Check if shared memory size is within limits
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
// The padded row of indices must fit in shared memory for the bitonic sort
GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
// Launch the kernel for the requested sort order
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
} else {
GGML_ABORT("fatal error");
}
}
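The launch-size math in the launcher can be sanity-checked with a quick sketch (Go, illustration only; 4 bytes per int and the 1024-thread cap used above are the assumptions):

package main

import "fmt"

// nextPowerOf2 mirrors the padding the bitonic sort needs: the compare-and-swap
// network is only defined for power-of-two sequence lengths.
func nextPowerOf2(n int) int {
	p := 1
	for p < n {
		p *= 2
	}
	return p
}

func main() {
	ncols := 1500                   // example column count for one row
	ncolsPad := nextPowerOf2(ncols) // 2048
	sharedMem := ncolsPad * 4       // bytes of shared memory for the index array
	threads := ncolsPad
	if threads > 1024 { // cap at the typical per-block thread limit
		threads = 1024
	}
	fmt.Println(ncolsPad, sharedMem, threads) // 2048 8192 1024
}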
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
if (src0->type == GGML_TYPE_I32) {
argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
} else {
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}
}

View File

@@ -38,6 +38,13 @@ static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
*dsti = *xi;
}
static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
const int32_t * xi = (const int32_t *) cxi;
int32_t * dsti = (int32_t *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
@@ -68,6 +75,44 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const in
cpy_1(cx + x_offset, cdst + dst_offset);
}
// Copy kernel for i32 elements: unravels the flat element index into source and destination byte offsets
template <cpy_kernel_t cpy_1>
static __global__ void cpy_i32_i32(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
// Host-side launcher for the i32 -> i32 copy kernel
static void ggml_cpy_i32_i32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
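The index math inside the kernel can be restated in a short Go sketch (illustration only): a flat element index is unravelled into 4D coordinates, then turned into a byte offset using per-dimension byte strides nb0..nb3.

package main

import "fmt"

// byteOffset reproduces the kernel's coordinate and stride arithmetic.
func byteOffset(i, ne0, ne1, ne2, nb0, nb1, nb2, nb3 int64) int64 {
	i3 := i / (ne0 * ne1 * ne2)
	i2 := (i - i3*ne0*ne1*ne2) / (ne0 * ne1)
	i1 := (i - i3*ne0*ne1*ne2 - i2*ne0*ne1) / ne0
	i0 := i - i3*ne0*ne1*ne2 - i2*ne0*ne1 - i1*ne0
	return i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3
}

func main() {
	// Contiguous 3x2 int32 tensor: byte strides 4, 12, 24, 24.
	fmt.Println(byteOffset(4, 3, 2, 1, 4, 12, 24, 24)) // 16
}

Because the source and destination strides are passed to the kernel independently, the same flat index can land at different byte offsets on each side, which is what lets the copy handle layouts that differ in strides while having the same element count.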
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
@@ -633,6 +678,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -688,6 +735,8 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) {
return (void*) cpy_f32_f16<cpy_1_f32_f16>;
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
return (void*) cpy_f32_f16<cpy_1_f16_f32>;
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
return (void*) cpy_i32_i32<cpy_1_i32_i32>;
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));