Skip to content

Commit

Permalink
try hack in missing hmax2 functions (+1 squashed commits)
Browse files Browse the repository at this point in the history
Squashed commits:

[c98d0ab] try hack in missing hmax2 functions (+1 squashed commits)

Squashed commits:

[9ba8599] try hack in missing hmax2 functions (+2 squashed commit)

Squashed commit:

[be49749] try hack in missing hmax2 functions

[159ee4c] bypass missing hmax functions on old cuda
  • Loading branch information
LostRuins committed May 1, 2024
1 parent b48ea96 commit cea4675
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions ggml-cuda/fattn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,29 @@
#define HALF_MAX_HALF __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.

//hack: polyfill hmax and hmax2 for older cuda version
#if CUDART_VERSION < CUDART_HMAX
__device__ __inline__ __half hmax(const __half a, const __half b) {
const float fa = __half2float(a);
const float fb = __half2float(b);
return __float2half(fa > fb ? fa : fb);
}
__device__ __inline__ __half2 hmax2(const __half2 a, const __half2 b) {
__half2 result;
result.x = hmax(a.x, b.x);
result.y = hmax(a.y, b.y);
return result;
}
#else
__device__ __inline__ __half hmax(const __half a, const __half b) {
return __hmax(a,b);
}
__device__ __inline__ __half2 hmax2(const __half2 a, const __half2 b) {
return __hmax2(a,b);
}
#endif


template<int D, int parallel_blocks> // D == head size
__launch_bounds__(((D + WARP_SIZE - 1) / WARP_SIZE)*WARP_SIZE, 1)
static __global__ void flash_attn_vec_ext_f16(
Expand Down Expand Up @@ -116,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f16(
sum2 = warp_reduce_sum(sum2);
half sum = __low2half(sum2) + __high2half(sum2);
sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new = __hmax(kqmax_new, sum);
kqmax_new = hmax(kqmax_new, sum);
if (threadIdx.x == 0) {
KQ[i_KQ] = sum;
}
Expand Down Expand Up @@ -416,9 +439,9 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + threadIdx.x;

KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
KQ_max_new = hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
KQ_max_new = __half2half2(warp_reduce_max(hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
Expand Down

0 comments on commit cea4675

Please sign in to comment.