@@ -1223,7 +1223,9 @@ static __global__ void flash_attn_ext_f16(
12231223 const int ne12,
12241224 const int ne13,
12251225 const int ne31,
1226+ const int ne32,
12261227 const int nb31,
1228+ const int nb32,
12271229 const int nb01,
12281230 const int nb02,
12291231 const int nb03,
@@ -1288,7 +1290,8 @@ static __global__ void flash_attn_ext_f16(
12881290
12891291 const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
12901292 const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1291- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof (half2))*jt*ncols1 : nullptr ;
1293+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1294+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
12921295 float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2 );
12931296
12941297 const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1327,7 +1330,8 @@ static __global__ void flash_attn_ext_f16(
13271330
13281331 const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
13291332 const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
1330- const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof (half2))*jt*ncols1 : nullptr ;
1333+ const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1334+ (const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
13311335 float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2 );
13321336
13331337 const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2 ) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1348,8 +1352,8 @@ static __global__ void flash_attn_ext_f16(
13481352 GGML_UNUSED (max_bias); GGML_UNUSED (m0); GGML_UNUSED (m1);
13491353 GGML_UNUSED (n_head_log2); GGML_UNUSED (logit_softcap); GGML_UNUSED (ne00);
13501354 GGML_UNUSED (ne01); GGML_UNUSED (ne02); GGML_UNUSED (ne03); GGML_UNUSED (ne10);
1351- GGML_UNUSED (ne11); GGML_UNUSED (ne12); GGML_UNUSED (ne13); GGML_UNUSED (ne31);
1352- GGML_UNUSED (nb31); GGML_UNUSED (nb01); GGML_UNUSED (nb02); GGML_UNUSED (nb03);
1355+ GGML_UNUSED (ne11); GGML_UNUSED (ne12); GGML_UNUSED (ne13); GGML_UNUSED (ne31); GGML_UNUSED (ne32);
1356+ GGML_UNUSED (nb31); GGML_UNUSED (nb32); GGML_UNUSED ( nb01); GGML_UNUSED (nb02); GGML_UNUSED (nb03);
13531357 GGML_UNUSED (nb11); GGML_UNUSED (nb12); GGML_UNUSED (nb13); GGML_UNUSED (nb21);
13541358 GGML_UNUSED (nb22); GGML_UNUSED (nb23); GGML_UNUSED (ne0); GGML_UNUSED (ne1);
13551359 GGML_UNUSED (ne2); GGML_UNUSED (ne3);
0 commit comments