Skip to content

Commit 17bc5a8

Browse files
HIP: use v_dot2_f32_f16 instruction for FA (#15884)
1 parent ed54e32 commit 17bc5a8

File tree

2 files changed

+26
-6
lines changed

2 files changed

+26
-6
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
545545
#endif // defined(GGML_USE_HIP)
546546
}
547547

548+
// Multiply-add helper: accumulates the scalar product v*u into acc.
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
    acc = acc + v*u;
}
551+
552+
// Multiply-add helper: accumulates the component-wise dot product of two
// float2 values into acc. The x term is added before the y term, matching
// sequential scalar accumulation (kept in this order so float rounding is
// identical to two separate scalar mads).
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
    acc = acc + v.x*u.x;
    acc = acc + v.y*u.y;
}
556+
557+
// Multiply-add helper: accumulates the 2-element dot product of two half2
// values into a float accumulator, picking the fastest path available on the
// target architecture.
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && defined(GCN)
    // AMD GCN: single hardware instruction doing an f16x2 dot product with
    // f32 accumulate; "+v" ties acc to the same VGPR for input and output.
    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
    // Fast fp16 path: multiply in half2, then widen the two products to
    // float and add both into the accumulator.
    const float2 tmp = __half22float2(v*u);
    acc += tmp.x + tmp.y;
#else
    // Fallback for GPUs without fast fp16 arithmetic: widen both operands to
    // float2 first and multiply-accumulate in fp32.
    const float2 tmpv = __half22float2(v);
    const float2 tmpu = __half22float2(u);
    acc += tmpv.x * tmpu.x;
    acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(GCN)
}
572+
548573
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
549574
#if CUDART_VERSION >= 12080
550575
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);

ggml/src/ggml-cuda/fattn-tile.cu

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
304304
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
305305
#pragma unroll
306306
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
307-
#ifdef FAST_FP16_AVAILABLE
308-
const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
309-
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
310-
#else
311-
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
312-
#endif // FAST_FP16_AVAILABLE
307+
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
313308
}
314309
}
315310
}

0 commit comments

Comments
 (0)