Skip to content

Commit fdedda0

Browse files
committed
metal : prevent division by zero in FA kernels
1 parent bd3104c commit fdedda0

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
81358135
}
81368136

81378137
// V /= S
8138-
const float S_inv = 1.0f/S;
8138+
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
81398139
ggml_vec_scale_f32(DV, VKQ32, S_inv);
81408140

81418141
// dst indices

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5201,7 +5201,7 @@ void kernel_flash_attn_ext_impl(
52015201

52025202
device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4;
52035203

5204-
const float scale = 1.0f/S[jj];
5204+
const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj];
52055205

52065206
if (DV4 % NW == 0) {
52075207
FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) {
@@ -5821,7 +5821,7 @@ void kernel_flash_attn_ext_vec_impl(
58215821
device float4 * dst4 = (device float4 *) dst;
58225822
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
58235823

5824-
const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f;
5824+
const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f;
58255825

58265826
// interleave the workgroup data
58275827
for (short i = tiisg; i < DV4; i += NW) {
@@ -5999,7 +5999,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
59995999
const float m = simd_max(M);
60006000
const float ms = exp(M - m);
60016001

6002-
S = 1.0f/simd_sum(S*ms);
6002+
S = simd_sum(S*ms);
6003+
S = S == 0.0f ? 0.0f : 1.0f/S;
60036004

60046005
const short DV4 = DV/4;
60056006

0 commit comments

Comments
 (0)