mirror of
https://github.com/dogkeeper886/ollama37.git
synced 2025-12-11 00:07:07 +00:00
llama: update vendored code to commit 40c6d79f (#7875)
This commit is contained in:
13
llama/ggml-cuda/fattn-vec-f16.cuh
vendored
13
llama/ggml-cuda/fattn-vec-f16.cuh
vendored
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
|
||||
* llama.cpp - commit 40c6d79fb52f995f47507fedfeaae2ac05d9b35c - do not edit this file
|
||||
*
|
||||
* MIT License
|
||||
*
|
||||
@@ -28,9 +28,9 @@
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
__launch_bounds__(D, 1)
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
||||
static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ Q,
|
||||
const char * __restrict__ K,
|
||||
@@ -222,7 +222,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
|
||||
sum = warp_reduce_sum(sum);
|
||||
sum = warp_reduce_sum((float)sum);
|
||||
|
||||
if (use_logit_softcap) {
|
||||
sum = logit_softcap*tanhf(sum);
|
||||
@@ -246,7 +246,6 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
half kqmax_new_j = ncols == 1 ? kqmax_new : kqmax_new_arr[j];
|
||||
|
||||
kqmax_new_j = warp_reduce_max(kqmax_new_j);
|
||||
if (threadIdx.x == 0) {
|
||||
kqmax_shared[j][threadIdx.y] = kqmax_new_j;
|
||||
}
|
||||
@@ -291,7 +290,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
|
||||
#pragma unroll
|
||||
for (int j = 0; j < ncols; ++j) {
|
||||
kqsum[j] = warp_reduce_sum(kqsum[j]);
|
||||
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
|
||||
if (threadIdx.x == 0) {
|
||||
kqsum_shared[j][threadIdx.y] = kqsum[j];
|
||||
}
|
||||
@@ -306,7 +305,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
}
|
||||
|
||||
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
|
||||
|
||||
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
|
||||
if (parallel_blocks == 1) {
|
||||
|
||||
Reference in New Issue
Block a user