@@ -28,39 +28,58 @@ def flash_attn(
2828 Output : T .Tensor ([batch , heads , dim ], dtype ),
2929 ):
3030 with T .Kernel (heads // min (block_H , kv_group_num ), batch , threads = 256 ) as (hid , bid ):
31+ # smem_sQ
3132 Q_shared_l = T .alloc_shared ([block_H , h_dim ], dtype )
3233 Q_shared_r = T .alloc_shared ([block_H , h_dim ], dtype )
33-
3434 Q_pe_shared = T .alloc_shared ([block_H , pe_dim ], dtype )
35+ Q_pe_local_0 = T .alloc_fragment ([block_H , pe_dim ], dtype )
36+ Q_pe_local_1 = T .alloc_fragment ([block_H , pe_dim ], dtype )
37+
38+ # smem_sK0
3539 KV_shared_0_l = T .alloc_shared ([block_N , h_dim ], dtype )
3640 KV_shared_0_r = T .alloc_shared ([block_N , h_dim ], dtype )
41+ K_pe_shared_0 = T .alloc_shared ([block_N , pe_dim ], dtype )
42+
43+ # smem_sK1
3744 KV_shared_1_l = T .alloc_shared ([block_N , h_dim ], dtype )
3845 KV_shared_1_r = T .alloc_shared ([block_N , h_dim ], dtype )
39- K_pe_shared_0 = T .alloc_shared ([block_N , pe_dim ], dtype )
4046 K_pe_shared_1 = T .alloc_shared ([block_N , pe_dim ], dtype )
47+
48+ # smem_sP0
49+ SP0_shared = T .alloc_shared ([block_H , block_N ], dtype )
50+
51+ # smem_sP1 reuse Q_pe_shared
52+ SP1_shared = Q_pe_shared
53+
54+ # smem_sM
55+ scores_max = T .alloc_shared ([block_H ], accum_dtype )
56+
57+ # smem_sScale0
58+ scores_scale_0 = T .alloc_shared ([block_H ], accum_dtype )
59+ # smem_sScale1
60+ scores_scale_1 = T .alloc_shared ([block_H ], accum_dtype )
61+
62+ logsum = T .alloc_shared ([block_H ], accum_dtype )
63+
4164 O_shared_l = Q_shared_l
4265 O_shared_r = Q_shared_r
43- S_shared = K_pe_shared_0
44- S_shared_ = K_pe_shared_1
4566
4667 acc_s_0 = T .alloc_fragment ([block_H , block_N ], accum_dtype )
68+ acc_s_0_cast = T .alloc_fragment ([block_H , block_N ], dtype )
4769 acc_s_1 = T .alloc_fragment ([block_H , block_N ], accum_dtype )
70+ acc_s_1_cast = T .alloc_fragment ([block_H , block_N ], dtype )
4871 acc_o_l = T .alloc_fragment ([block_H , h_dim ], accum_dtype )
4972 acc_o_r = T .alloc_fragment ([block_H , h_dim ], accum_dtype )
5073 scores_max_0 = T .alloc_fragment ([block_H ], accum_dtype )
5174 scores_max_1 = T .alloc_fragment ([block_H ], accum_dtype )
52- scores_max = T .alloc_shared ([block_H ], accum_dtype )
5375
5476 scores_max_prev_0 = T .alloc_fragment ([block_H ], accum_dtype )
5577 scores_max_prev_1 = T .alloc_fragment ([block_H ], accum_dtype )
5678
57- scores_scale_0 = T .alloc_shared ([block_H ], accum_dtype )
58- scores_scale_1 = T .alloc_shared ([block_H ], accum_dtype )
5979 scores_sum_0 = T .alloc_fragment ([block_H ], accum_dtype )
6080 scores_sum_1 = T .alloc_fragment ([block_H ], accum_dtype )
6181 logsum_0 = T .alloc_fragment ([block_H ], accum_dtype )
6282 logsum_1 = T .alloc_fragment ([block_H ], accum_dtype )
63- logsum = T .alloc_shared ([block_H ], accum_dtype )
6483
6584 cur_kv_head = hid // (kv_group_num // block_H )
6685
@@ -69,22 +88,25 @@ def flash_attn(
6988 O_shared_r : tilelang .layout .make_swizzled_layout (O_shared_r ),
7089 })
7190
91+ # barriers_Q
92+ q_shared_ready_barrier = T .alloc_barrier (arrive_count = 256 )
93+
94+ # barriers_K0
7295 kv_shared_0_l_is_ready = T .alloc_barrier (arrive_count = 128 )
7396 kv_shared_0_r_is_ready = T .alloc_barrier (arrive_count = 128 )
7497 kv_shared_0_pe_is_ready = T .alloc_barrier (arrive_count = 128 )
98+ # barriers_K1
7599 kv_shared_1_l_is_ready = T .alloc_barrier (arrive_count = 128 )
76100 kv_shared_1_r_is_ready = T .alloc_barrier (arrive_count = 128 )
77101 kv_shared_1_pe_is_ready = T .alloc_barrier (arrive_count = 128 )
102+
103+ # redundant barriers
78104 score_max_0_ready_barrier = T .alloc_barrier (arrive_count = 128 )
79105 scale_1_ready_barrier = T .alloc_barrier (arrive_count = 128 )
80106 p0_1_1_ready_barrier = T .alloc_barrier (arrive_count = 128 )
81107 lse_0_ready_barrier = T .alloc_barrier (arrive_count = 128 )
82108 lse_1_ready_barrier = T .alloc_barrier (arrive_count = 128 )
83- q_shared_ready_barrier = T .alloc_barrier (arrive_count = 256 )
84- k_pe_shared_1_free_barrier = T .alloc_barrier (arrive_count = 128 )
85- k_pe_shared_0_free_barrier = T .alloc_barrier (arrive_count = 128 )
86109 s_shared_ready_barrier = T .alloc_barrier (arrive_count = 128 )
87- k_shared_1_l_free_barrier = T .alloc_barrier (arrive_count = 128 )
88110
89111 tx = T .get_thread_binding ()
90112
@@ -93,11 +115,13 @@ def flash_attn(
93115 T .copy (Q_pe [bid , hid * VALID_BLOCK_H :(hid + 1 ) * VALID_BLOCK_H , :], Q_pe_shared )
94116 T .barrier_arrive (q_shared_ready_barrier )
95117 T .barrier_wait (q_shared_ready_barrier , 0 )
118+
96119 T .fill (scores_max , - T .infinity (accum_dtype ))
97120
98121 loop_range = T .ceildiv (seqlen_kv , (block_N * 2 ))
99122
100123 if tx < 128 :
124+ T .copy (Q_pe_shared , Q_pe_local_0 )
101125 T .fill (acc_o_l , 0 )
102126 T .fill (logsum_0 , 0 )
103127
@@ -118,7 +142,6 @@ def flash_attn(
118142 KV_shared_0_l ,
119143 acc_s_0 ,
120144 transpose_B = True ,
121- policy = T .GemmWarpPolicy .FullCol ,
122145 clear_accum = True ,
123146 wg_wait = - 1 )
124147 T .barrier_wait (kv_shared_0_r_is_ready , k % 2 )
@@ -127,16 +150,14 @@ def flash_attn(
127150 KV_shared_0_r ,
128151 acc_s_0 ,
129152 transpose_B = True ,
130- policy = T .GemmWarpPolicy .FullCol ,
131153 wg_wait = - 1 )
132154
133155 T .barrier_wait (kv_shared_0_pe_is_ready , k % 2 )
134156 T .gemm (
135- Q_pe_shared ,
157+ Q_pe_local_0 ,
136158 K_pe_shared_0 ,
137159 acc_s_0 ,
138160 transpose_B = True ,
139- policy = T .GemmWarpPolicy .FullCol ,
140161 wg_wait = - 1 )
141162
142163 T .wait_wgmma (0 )
@@ -158,7 +179,7 @@ def flash_attn(
158179 T .reduce_sum (acc_s_0 , scores_sum_0 , dim = 1 )
159180
160181 # Step 5.
161- T .copy (acc_s_0 , S_shared )
182+ T .copy (acc_s_0 , acc_s_0_cast )
162183
163184 for i , j in T .Parallel (block_H , h_dim ):
164185 acc_o_l [i , j ] *= scores_scale_0 [i ]
@@ -167,7 +188,7 @@ def flash_attn(
167188 logsum_0 [i ] = logsum_0 [i ] * scores_scale_0 [i ] + scores_sum_0 [i ]
168189
169190 # Step 6.
170- T .gemm (S_shared , KV_shared_0_l , acc_o_l , policy = T . GemmWarpPolicy . FullCol )
191+ T .gemm (acc_s_0_cast , KV_shared_0_l , acc_o_l )
171192 T .barrier_arrive (score_max_0_ready_barrier )
172193
173194 T .barrier_wait (scale_1_ready_barrier , k % 2 )
@@ -180,7 +201,7 @@ def flash_attn(
180201
181202 # Step 11.
182203 for i , j in T .Parallel (block_H , block_N ):
183- S_shared_ [i , j ] = acc_s_0 [i , j ] * scores_scale_1 [i ]
204+ SP0_shared [i , j ] = acc_s_0 [i , j ] * scores_scale_1 [i ]
184205
185206 T .barrier_arrive (p0_1_1_ready_barrier )
186207
@@ -192,19 +213,15 @@ def flash_attn(
192213 T .barrier_wait (s_shared_ready_barrier , k % 2 )
193214
194215 # Step 14.
195- T .gemm (S_shared , KV_shared_1_l , acc_o_l , policy = T .GemmWarpPolicy .FullCol )
196- T .barrier_arrive (k_pe_shared_0_free_barrier )
197- T .barrier_arrive (k_shared_1_l_free_barrier )
216+ T .gemm (SP1_shared , KV_shared_1_l , acc_o_l )
198217
199218 if k < loop_range - 1 :
200219
201- T .barrier_wait (k_shared_1_l_free_barrier , k % 2 )
202220 T .copy (
203221 KV [bid , (2 * k + 3 ) * block_N :(2 * k + 4 ) * block_N ,
204222 cur_kv_head , :h_dim ], KV_shared_1_l )
205223 T .barrier_arrive (kv_shared_1_l_is_ready )
206224
207- T .barrier_wait (k_pe_shared_1_free_barrier , k % 2 )
208225 T .copy (
209226 K_pe [bid , (2 * k + 3 ) * block_N :(2 * k + 4 ) * block_N , cur_kv_head , :],
210227 K_pe_shared_1 )
@@ -220,6 +237,7 @@ def flash_attn(
220237 hid * VALID_BLOCK_H :(hid + 1 ) * VALID_BLOCK_H , :h_dim ])
221238
222239 else :
240+ T .copy (Q_pe_shared , Q_pe_local_1 )
223241 T .fill (acc_o_r , 0 )
224242 T .fill (logsum_1 , 0 )
225243
@@ -239,7 +257,6 @@ def flash_attn(
239257 KV_shared_1_l ,
240258 acc_s_1 ,
241259 transpose_B = True ,
242- policy = T .GemmWarpPolicy .FullCol ,
243260 clear_accum = True ,
244261 wg_wait = - 1 )
245262
@@ -249,16 +266,14 @@ def flash_attn(
249266 KV_shared_1_r ,
250267 acc_s_1 ,
251268 transpose_B = True ,
252- policy = T .GemmWarpPolicy .FullCol ,
253269 wg_wait = - 1 )
254270
255271 T .barrier_wait (kv_shared_1_pe_is_ready , k % 2 )
256272 T .gemm (
257- Q_pe_shared ,
273+ Q_pe_local_1 ,
258274 K_pe_shared_1 ,
259275 acc_s_1 ,
260276 transpose_B = True ,
261- policy = T .GemmWarpPolicy .FullCol ,
262277 wg_wait = - 1 )
263278
264279 T .wait_wgmma (0 )
@@ -292,14 +307,14 @@ def flash_attn(
292307 T .barrier_arrive (scale_1_ready_barrier )
293308
294309 # Step 10. compute O1 with KV_shared_1_rd
295- T .copy (acc_s_1 , S_shared )
296- T .barrier_arrive (s_shared_ready_barrier )
310+ T .copy (acc_s_1 , acc_s_1_cast )
297311 T .gemm (
298- S_shared ,
312+ acc_s_1_cast ,
299313 KV_shared_1_r ,
300314 acc_o_r ,
301- policy = T .GemmWarpPolicy .FullCol ,
302315 wg_wait = - 1 )
316+ T .copy (acc_s_1_cast , SP1_shared )
317+ T .barrier_arrive (s_shared_ready_barrier )
303318
304319 if k < loop_range - 1 :
305320 T .copy (
@@ -309,8 +324,7 @@ def flash_attn(
309324
310325 T .barrier_wait (p0_1_1_ready_barrier , k % 2 )
311326 # Step 12.
312- T .gemm (S_shared_ , KV_shared_0_r , acc_o_r , policy = T .GemmWarpPolicy .FullCol )
313- T .barrier_arrive (k_pe_shared_1_free_barrier )
327+ T .gemm (SP0_shared , KV_shared_0_r , acc_o_r )
314328
315329 if k < loop_range - 1 :
316330
@@ -319,7 +333,6 @@ def flash_attn(
319333 h_dim :], KV_shared_0_r )
320334 T .barrier_arrive (kv_shared_0_r_is_ready )
321335
322- T .barrier_wait (k_pe_shared_0_free_barrier , k % 2 )
323336 T .copy (
324337 K_pe [bid , (2 * k + 2 ) * block_N :(2 * k + 3 ) * block_N , cur_kv_head , :],
325338 K_pe_shared_0 )
0 commit comments