image processing for llama3.2 (#6963)

Co-authored-by: jmorganca <jmorganca@gmail.com>
Co-authored-by: Michael Yang <mxyng@pm.me>
Co-authored-by: Jesse Gross <jesse@ollama.com>
This commit is contained in:
Patrick Devine
2024-10-18 16:12:35 -07:00
committed by GitHub
parent bf4018b9ec
commit c7cb0f0602
35 changed files with 3851 additions and 203 deletions

View File

@@ -2055,6 +2055,51 @@ kernel void kernel_pad_f32(
}
}
kernel void kernel_unpad_f32(
device const char * src0,
device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant uint64_t & nb03,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
constant uint64_t & nb0,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i3 = tgpig.z;
const int64_t i2 = tgpig.y;
const int64_t i1 = tgpig.x;
const int64_t i03 = i3;
const int64_t i02 = i2;
const int64_t i01 = i1;
device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
if (i0 < ne00) {
dst_ptr[i0] = src0_ptr[i0];
}
}
return;
}
}
kernel void kernel_arange_f32(
device char * dst,
constant int64_t & ne0,