@@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext(
29422942 half smax = -INFINITY;
29432943
29442944 // load the mask in shared memory
2945+ #pragma unroll(Q)
29452946 for (short j = 0 ; j < Q; ++j) {
29462947 device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
29472948
@@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext(
29682969 // we can read directly from global memory
29692970 device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8 *cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
29702971
2971- #pragma unroll
2972+ #pragma unroll(D8)
29722973 for (short i = 0 ; i < D8; ++i) {
29732974 k8x8_t mk;
29742975 simdgroup_load (mk, pk + i*8 , nb_12_1/sizeof (k_t ), 0 , true ); // transpose // TODO: use ne10
@@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext(
29892990
29902991 simdgroup_barrier (mem_flags::mem_threadgroup);
29912992
2992- #pragma unroll
2993+ #pragma unroll(4)
29932994 for (short k = 0 ; k < 4 ; ++k) {
29942995 k8x8_t mk;
29952996
@@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext(
30673068 s8x8_t mm;
30683069 simdgroup_load (mm, ss + 2 *C, TS, 0 , false );
30693070
3070- #pragma unroll
3071+ #pragma unroll(D8)
30713072 for (short i = 0 ; i < D8; ++i) {
30723073 simdgroup_multiply (lo[i], mm, lo[i]);
30733074 }
@@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext(
30823083 if (is_same<vd4x4_t , v4x4_t >::value) {
30833084 // we can read directly from global memory
30843085 device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8 *cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
3085- #pragma unroll
3086+
3087+ #pragma unroll(D8)
30863088 for (short i = 0 ; i < D8; ++i) {
30873089 v8x8_t mv;
30883090 simdgroup_load (mv, pv + i*8 , nb_12_1/sizeof (v_t ), 0 , false ); // TODO: use ne20
@@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext(
31033105
31043106 simdgroup_barrier (mem_flags::mem_threadgroup);
31053107
3106- #pragma unroll
3108+ #pragma unroll(4)
31073109 for (short k = 0 ; k < 4 ; ++k) {
31083110 v8x8_t mv;
31093111
@@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext(
31963198 simdgroup_load (ms0, ss + 2 *C, TS, 0 , false );
31973199 simdgroup_load (ms1, ss + 2 *C + sg*SH, TS, 0 , false );
31983200
3201+ #pragma unroll(D8)
31993202 for (short i = 0 ; i < D8; ++i) {
32003203 o8x8_t t;
32013204
@@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec(
34133416 // load the queries from shared memory into local memory
34143417 q4x4_t mq[D16/NL];
34153418
3419+ #pragma unroll(D16/NL)
34163420 for (short ii = 0 ; ii < D16; ii += NL) {
34173421 mq[ii/NL] = sq4x4[ii + tx];
34183422 }
@@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec(
34543458
34553459 device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
34563460
3457- #pragma unroll
3461+ #pragma unroll(D16/NL)
34583462 for (short ii = 0 ; ii < D16; ii += NL) {
34593463 const short i = ii + tx;
34603464
34613465 k4x4_t mk;
34623466 deq_k (pk + i/nl_k, i%nl_k, mk);
34633467
3464- mqka[0 ] += dot (mq[ii/NL][0 ], mk[0 ]);
3465- mqka[1 ] += dot (mq[ii/NL][1 ], mk[1 ]);
3466- mqka[2 ] += dot (mq[ii/NL][2 ], mk[2 ]);
3467- mqka[3 ] += dot (mq[ii/NL][3 ], mk[3 ]);
3468+ // note: this is less precise than the version below
3469+ // mqka[0] += dot(mq[ii/NL][0], mk[0]);
3470+ // mqka[1] += dot(mq[ii/NL][1], mk[1]);
3471+ // mqka[2] += dot(mq[ii/NL][2], mk[2]);
3472+ // mqka[3] += dot(mq[ii/NL][3], mk[3]);
3473+
3474+ mqka[0 ] += dot ((float4) mq[ii/NL][0 ], (float4) mk[0 ]);
3475+ mqka[1 ] += dot ((float4) mq[ii/NL][1 ], (float4) mk[1 ]);
3476+ mqka[2 ] += dot ((float4) mq[ii/NL][2 ], (float4) mk[2 ]);
3477+ mqka[3 ] += dot ((float4) mq[ii/NL][3 ], (float4) mk[3 ]);
34683478 }
34693479
34703480 qk_t mqk = mqka[0 ] + mqka[1 ] + mqka[2 ] + mqka[3 ];
@@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec(
35133523 ss[tiisg] = vs;
35143524
35153525 // O = diag(ms)*O
3516- #pragma unroll
3526+ #pragma unroll(D16/NL)
35173527 for (short ii = 0 ; ii < D16; ii += NL) {
35183528 lo[ii/NL] *= ms;
35193529 }
@@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec(
35233533
35243534 // O = O + (Q*K^T)*V
35253535 {
3526- #pragma unroll
35273536 for (short cc = 0 ; cc < C/4 ; ++cc) {
35283537 device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4 *cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
35293538
35303539 const s4x4_t ms (ss[4 *cc + ty]);
35313540
3532- #pragma unroll
3541+ #pragma unroll(D16/NL)
35333542 for (short ii = 0 ; ii < D16; ii += NL) {
35343543 const short i = ii + tx;
35353544
0 commit comments