File tree Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Expand file tree Collapse file tree 2 files changed +5
-4
lines changed Original file line number Diff line number Diff line change @@ -8135,7 +8135,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
8135
8135
}
8136
8136
8137
8137
// V /= S
8138
- const float S_inv = 1 .0f /S;
8138
+ const float S_inv = S == 0 . 0f ? 0 . 0f : 1 .0f /S;
8139
8139
ggml_vec_scale_f32 (DV, VKQ32, S_inv);
8140
8140
8141
8141
// dst indices
Original file line number Diff line number Diff line change @@ -5201,7 +5201,7 @@ void kernel_flash_attn_ext_impl(
5201
5201
5202
5202
device float4 * dst4 = (device float4 *) dst + ((uint64_t )iq3*args.ne2 *args.ne1 + iq2 + (uint64_t )(iq1 + j)*args.ne1 )*DV4;
5203
5203
5204
- const float scale = 1 .0f /S[jj];
5204
+ const float scale = S[jj] == 0.0 ? 0 . 0f : 1 .0f /S[jj];
5205
5205
5206
5206
if (DV4 % NW == 0 ) {
5207
5207
FOR_UNROLL (short ii = 0 ; ii < DV4/NW; ++ii) {
@@ -5821,7 +5821,7 @@ void kernel_flash_attn_ext_vec_impl(
5821
5821
device float4 * dst4 = (device float4 *) dst;
5822
5822
device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results
5823
5823
5824
- const float S = NWG == 1 ? 1 .0f /ss[0 ] : 1 .0f ;
5824
+ const float S = NWG == 1 ? (ss[ 0 ] == 0 . 0f ? 0 . 0f : 1 .0f /ss[0 ]) : 1 .0f ;
5825
5825
5826
5826
// interleave the workgroup data
5827
5827
for (short i = tiisg; i < DV4; i += NW) {
@@ -5999,7 +5999,8 @@ kernel void kernel_flash_attn_ext_vec_reduce(
5999
5999
const float m = simd_max (M);
6000
6000
const float ms = exp (M - m);
6001
6001
6002
- S = 1 .0f /simd_sum (S*ms);
6002
+ S = simd_sum (S*ms);
6003
+ S = S == 0 .0f ? 0 .0f : 1 .0f /S;
6003
6004
6004
6005
const short DV4 = DV/4 ;
6005
6006
You can’t perform that action at this time.
0 commit comments