Skip to content

Commit d848ef7

Browse files
authored
[Example] Optimize warp specialize flashmla example (tile-ai#698)
* [Enhancement] Disable cache and append git commit ID to version in tilelang (tile-ai#688) * Disabled caching in quickstart example for improved performance. * Added a function to retrieve the current git commit ID and appended it to the version string if not already present, enhancing version tracking and debugging capabilities. * revert quickstart * optimize code.
1 parent 3a95d87 commit d848ef7

File tree

1 file changed

+48
-35
lines changed

1 file changed

+48
-35
lines changed

examples/warp_specialize/example_warp_specialize_flashmla.py

Lines changed: 48 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -28,39 +28,58 @@ def flash_attn(
2828
Output: T.Tensor([batch, heads, dim], dtype),
2929
):
3030
with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid):
31+
# smem_sQ
3132
Q_shared_l = T.alloc_shared([block_H, h_dim], dtype)
3233
Q_shared_r = T.alloc_shared([block_H, h_dim], dtype)
33-
3434
Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
35+
Q_pe_local_0 = T.alloc_fragment([block_H, pe_dim], dtype)
36+
Q_pe_local_1 = T.alloc_fragment([block_H, pe_dim], dtype)
37+
38+
# smem_sK0
3539
KV_shared_0_l = T.alloc_shared([block_N, h_dim], dtype)
3640
KV_shared_0_r = T.alloc_shared([block_N, h_dim], dtype)
41+
K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
42+
43+
# smem_sK1
3744
KV_shared_1_l = T.alloc_shared([block_N, h_dim], dtype)
3845
KV_shared_1_r = T.alloc_shared([block_N, h_dim], dtype)
39-
K_pe_shared_0 = T.alloc_shared([block_N, pe_dim], dtype)
4046
K_pe_shared_1 = T.alloc_shared([block_N, pe_dim], dtype)
47+
48+
# smem_sP0
49+
SP0_shared = T.alloc_shared([block_H, block_N], dtype)
50+
51+
# smem_sP1 reuse Q_pe_shared
52+
SP1_shared = Q_pe_shared
53+
54+
# smem_sM
55+
scores_max = T.alloc_shared([block_H], accum_dtype)
56+
57+
# smem_sScale0
58+
scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
59+
# smem_sScale1
60+
scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
61+
62+
logsum = T.alloc_shared([block_H], accum_dtype)
63+
4164
O_shared_l = Q_shared_l
4265
O_shared_r = Q_shared_r
43-
S_shared = K_pe_shared_0
44-
S_shared_ = K_pe_shared_1
4566

4667
acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
68+
acc_s_0_cast = T.alloc_fragment([block_H, block_N], dtype)
4769
acc_s_1 = T.alloc_fragment([block_H, block_N], accum_dtype)
70+
acc_s_1_cast = T.alloc_fragment([block_H, block_N], dtype)
4871
acc_o_l = T.alloc_fragment([block_H, h_dim], accum_dtype)
4972
acc_o_r = T.alloc_fragment([block_H, h_dim], accum_dtype)
5073
scores_max_0 = T.alloc_fragment([block_H], accum_dtype)
5174
scores_max_1 = T.alloc_fragment([block_H], accum_dtype)
52-
scores_max = T.alloc_shared([block_H], accum_dtype)
5375

5476
scores_max_prev_0 = T.alloc_fragment([block_H], accum_dtype)
5577
scores_max_prev_1 = T.alloc_fragment([block_H], accum_dtype)
5678

57-
scores_scale_0 = T.alloc_shared([block_H], accum_dtype)
58-
scores_scale_1 = T.alloc_shared([block_H], accum_dtype)
5979
scores_sum_0 = T.alloc_fragment([block_H], accum_dtype)
6080
scores_sum_1 = T.alloc_fragment([block_H], accum_dtype)
6181
logsum_0 = T.alloc_fragment([block_H], accum_dtype)
6282
logsum_1 = T.alloc_fragment([block_H], accum_dtype)
63-
logsum = T.alloc_shared([block_H], accum_dtype)
6483

6584
cur_kv_head = hid // (kv_group_num // block_H)
6685

@@ -69,22 +88,25 @@ def flash_attn(
6988
O_shared_r: tilelang.layout.make_swizzled_layout(O_shared_r),
7089
})
7190

91+
# barriers_Q
92+
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
93+
94+
# barriers_K0
7295
kv_shared_0_l_is_ready = T.alloc_barrier(arrive_count=128)
7396
kv_shared_0_r_is_ready = T.alloc_barrier(arrive_count=128)
7497
kv_shared_0_pe_is_ready = T.alloc_barrier(arrive_count=128)
98+
# barriers_K1
7599
kv_shared_1_l_is_ready = T.alloc_barrier(arrive_count=128)
76100
kv_shared_1_r_is_ready = T.alloc_barrier(arrive_count=128)
77101
kv_shared_1_pe_is_ready = T.alloc_barrier(arrive_count=128)
102+
103+
# redundant barriers
78104
score_max_0_ready_barrier = T.alloc_barrier(arrive_count=128)
79105
scale_1_ready_barrier = T.alloc_barrier(arrive_count=128)
80106
p0_1_1_ready_barrier = T.alloc_barrier(arrive_count=128)
81107
lse_0_ready_barrier = T.alloc_barrier(arrive_count=128)
82108
lse_1_ready_barrier = T.alloc_barrier(arrive_count=128)
83-
q_shared_ready_barrier = T.alloc_barrier(arrive_count=256)
84-
k_pe_shared_1_free_barrier = T.alloc_barrier(arrive_count=128)
85-
k_pe_shared_0_free_barrier = T.alloc_barrier(arrive_count=128)
86109
s_shared_ready_barrier = T.alloc_barrier(arrive_count=128)
87-
k_shared_1_l_free_barrier = T.alloc_barrier(arrive_count=128)
88110

89111
tx = T.get_thread_binding()
90112

@@ -93,11 +115,13 @@ def flash_attn(
93115
T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
94116
T.barrier_arrive(q_shared_ready_barrier)
95117
T.barrier_wait(q_shared_ready_barrier, 0)
118+
96119
T.fill(scores_max, -T.infinity(accum_dtype))
97120

98121
loop_range = T.ceildiv(seqlen_kv, (block_N * 2))
99122

100123
if tx < 128:
124+
T.copy(Q_pe_shared, Q_pe_local_0)
101125
T.fill(acc_o_l, 0)
102126
T.fill(logsum_0, 0)
103127

@@ -118,7 +142,6 @@ def flash_attn(
118142
KV_shared_0_l,
119143
acc_s_0,
120144
transpose_B=True,
121-
policy=T.GemmWarpPolicy.FullCol,
122145
clear_accum=True,
123146
wg_wait=-1)
124147
T.barrier_wait(kv_shared_0_r_is_ready, k % 2)
@@ -127,16 +150,14 @@ def flash_attn(
127150
KV_shared_0_r,
128151
acc_s_0,
129152
transpose_B=True,
130-
policy=T.GemmWarpPolicy.FullCol,
131153
wg_wait=-1)
132154

133155
T.barrier_wait(kv_shared_0_pe_is_ready, k % 2)
134156
T.gemm(
135-
Q_pe_shared,
157+
Q_pe_local_0,
136158
K_pe_shared_0,
137159
acc_s_0,
138160
transpose_B=True,
139-
policy=T.GemmWarpPolicy.FullCol,
140161
wg_wait=-1)
141162

142163
T.wait_wgmma(0)
@@ -158,7 +179,7 @@ def flash_attn(
158179
T.reduce_sum(acc_s_0, scores_sum_0, dim=1)
159180

160181
# Step 5.
161-
T.copy(acc_s_0, S_shared)
182+
T.copy(acc_s_0, acc_s_0_cast)
162183

163184
for i, j in T.Parallel(block_H, h_dim):
164185
acc_o_l[i, j] *= scores_scale_0[i]
@@ -167,7 +188,7 @@ def flash_attn(
167188
logsum_0[i] = logsum_0[i] * scores_scale_0[i] + scores_sum_0[i]
168189

169190
# Step 6.
170-
T.gemm(S_shared, KV_shared_0_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
191+
T.gemm(acc_s_0_cast, KV_shared_0_l, acc_o_l)
171192
T.barrier_arrive(score_max_0_ready_barrier)
172193

173194
T.barrier_wait(scale_1_ready_barrier, k % 2)
@@ -180,7 +201,7 @@ def flash_attn(
180201

181202
# Step 11.
182203
for i, j in T.Parallel(block_H, block_N):
183-
S_shared_[i, j] = acc_s_0[i, j] * scores_scale_1[i]
204+
SP0_shared[i, j] = acc_s_0[i, j] * scores_scale_1[i]
184205

185206
T.barrier_arrive(p0_1_1_ready_barrier)
186207

@@ -192,19 +213,15 @@ def flash_attn(
192213
T.barrier_wait(s_shared_ready_barrier, k % 2)
193214

194215
# Step 14.
195-
T.gemm(S_shared, KV_shared_1_l, acc_o_l, policy=T.GemmWarpPolicy.FullCol)
196-
T.barrier_arrive(k_pe_shared_0_free_barrier)
197-
T.barrier_arrive(k_shared_1_l_free_barrier)
216+
T.gemm(SP1_shared, KV_shared_1_l, acc_o_l)
198217

199218
if k < loop_range - 1:
200219

201-
T.barrier_wait(k_shared_1_l_free_barrier, k % 2)
202220
T.copy(
203221
KV[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N,
204222
cur_kv_head, :h_dim], KV_shared_1_l)
205223
T.barrier_arrive(kv_shared_1_l_is_ready)
206224

207-
T.barrier_wait(k_pe_shared_1_free_barrier, k % 2)
208225
T.copy(
209226
K_pe[bid, (2 * k + 3) * block_N:(2 * k + 4) * block_N, cur_kv_head, :],
210227
K_pe_shared_1)
@@ -220,6 +237,7 @@ def flash_attn(
220237
hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :h_dim])
221238

222239
else:
240+
T.copy(Q_pe_shared, Q_pe_local_1)
223241
T.fill(acc_o_r, 0)
224242
T.fill(logsum_1, 0)
225243

@@ -239,7 +257,6 @@ def flash_attn(
239257
KV_shared_1_l,
240258
acc_s_1,
241259
transpose_B=True,
242-
policy=T.GemmWarpPolicy.FullCol,
243260
clear_accum=True,
244261
wg_wait=-1)
245262

@@ -249,16 +266,14 @@ def flash_attn(
249266
KV_shared_1_r,
250267
acc_s_1,
251268
transpose_B=True,
252-
policy=T.GemmWarpPolicy.FullCol,
253269
wg_wait=-1)
254270

255271
T.barrier_wait(kv_shared_1_pe_is_ready, k % 2)
256272
T.gemm(
257-
Q_pe_shared,
273+
Q_pe_local_1,
258274
K_pe_shared_1,
259275
acc_s_1,
260276
transpose_B=True,
261-
policy=T.GemmWarpPolicy.FullCol,
262277
wg_wait=-1)
263278

264279
T.wait_wgmma(0)
@@ -292,14 +307,14 @@ def flash_attn(
292307
T.barrier_arrive(scale_1_ready_barrier)
293308

294309
# Step 10. compute O1 with KV_shared_1_rd
295-
T.copy(acc_s_1, S_shared)
296-
T.barrier_arrive(s_shared_ready_barrier)
310+
T.copy(acc_s_1, acc_s_1_cast)
297311
T.gemm(
298-
S_shared,
312+
acc_s_1_cast,
299313
KV_shared_1_r,
300314
acc_o_r,
301-
policy=T.GemmWarpPolicy.FullCol,
302315
wg_wait=-1)
316+
T.copy(acc_s_1_cast, SP1_shared)
317+
T.barrier_arrive(s_shared_ready_barrier)
303318

304319
if k < loop_range - 1:
305320
T.copy(
@@ -309,8 +324,7 @@ def flash_attn(
309324

310325
T.barrier_wait(p0_1_1_ready_barrier, k % 2)
311326
# Step 12.
312-
T.gemm(S_shared_, KV_shared_0_r, acc_o_r, policy=T.GemmWarpPolicy.FullCol)
313-
T.barrier_arrive(k_pe_shared_1_free_barrier)
327+
T.gemm(SP0_shared, KV_shared_0_r, acc_o_r)
314328

315329
if k < loop_range - 1:
316330

@@ -319,7 +333,6 @@ def flash_attn(
319333
h_dim:], KV_shared_0_r)
320334
T.barrier_arrive(kv_shared_0_r_is_ready)
321335

322-
T.barrier_wait(k_pe_shared_0_free_barrier, k % 2)
323336
T.copy(
324337
K_pe[bid, (2 * k + 2) * block_N:(2 * k + 3) * block_N, cur_kv_head, :],
325338
K_pe_shared_0)

0 commit comments

Comments
 (0)