
Commit d5d4247

[Dev][Doc] Enhance Flash Attention Implementation in GQA Decoding Example and Fix Typo (#139)
- Add non-split flash attention macro for more flexible kernel generation
- Implement `main_no_split` function to handle single-split scenarios
- Modify kernel selection logic to dynamically choose between split and non-split implementations
1 parent 1cc5c69 commit d5d4247

2 files changed: +88 −3 lines

examples/deepseek_mla/README.md

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ Here, `T.annotate_layout` allows users to specify any desired layout for a buffe
 
 ### Warp-Specialization
 
-The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Access), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
+The Hopper architecture commonly employs warp specialization for performance optimization. A typical approach is to designate one warpgroup as a producer that handles data movement using TMA (Tensor Memory Accelerator), while the remaining warpgroups serve as consumers performing computations. However, this programming pattern is complex, requiring developers to manually manage the execution logic for producers and consumers, including synchronization through the `mbarrier` objects.
 
 In TileLang, users are completely shielded from these implementation details. The frontend script is automatically transformed into a warp-specialized form, where TileLang handles all producer-consumer synchronization automatically, enabling efficient computation.
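As an aside on the producer/consumer pattern described above, here is a small runnable CPU analogy (my own illustration, not TileLang or CUDA API): the bounded queue stands in for the `mbarrier`-guarded shared-memory stages, the producer thread for the TMA warpgroup, and the consumer threads for the compute warpgroups.

```python
# Illustrative CPU analogy of warp specialization (not TileLang/CUDA API):
# one producer "loads" tiles into a bounded pipeline while consumers compute on them.
import queue
import threading

NUM_TILES, NUM_STAGES, NUM_CONSUMERS = 8, 2, 3
stages = queue.Queue(maxsize=NUM_STAGES)  # bounded buffer ~ fixed shared-memory stages
results = queue.Queue()

def producer():
    # Stands in for the TMA warpgroup: stage tiles, then tell every consumer to stop.
    for tile_id in range(NUM_TILES):
        stages.put(list(range(tile_id, tile_id + 4)))
    for _ in range(NUM_CONSUMERS):
        stages.put(None)

def consumer():
    # Stands in for a compute warpgroup: wait for a filled stage, then "compute".
    while (tile := stages.get()) is not None:
        results.put(sum(tile))

workers = [threading.Thread(target=producer)]
workers += [threading.Thread(target=consumer) for _ in range(NUM_CONSUMERS)]
for w in workers:
    w.start()
for w in workers:
    w.join()
print(sorted(results.get() for _ in range(NUM_TILES)))
```

The point of the analogy is only the division of labor: one agent streams data through a bounded pipeline while the others consume it, which is the structure TileLang generates automatically.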

examples/flash_decoding/example_gqa_decode.py

Lines changed: 87 additions & 2 deletions
@@ -40,6 +40,75 @@ def kernel_func(block_N, block_H, num_split, num_stages, threads):
     part_shape = [batch, heads, num_split, dim]
     valid_block_H = min(block_H, kv_group_num)
 
+    @T.macro
+    def flash_attn(
+            Q: T.Buffer(shape_q, dtype),
+            K: T.Buffer(shape_k, dtype),
+            V: T.Buffer(shape_v, dtype),
+            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
+            Output: T.Buffer([batch, heads, dim], dtype),
+    ):
+        with T.Kernel(
+                batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            K_shared = T.alloc_shared([block_N, dim], dtype)
+            V_shared = T.alloc_shared([block_N, dim], dtype)
+            O_shared = T.alloc_shared([valid_block_H, dim], dtype)
+            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
+            mask_local = T.alloc_fragment([block_N], "uint8")
+            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
+            scores_max = T.alloc_fragment([block_H], accum_dtype)
+            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
+            scores_scale = T.alloc_fragment([block_H], accum_dtype)
+            scores_sum = T.alloc_fragment([block_H], accum_dtype)
+            logsum = T.alloc_fragment([block_H], accum_dtype)
+
+            bid = bx
+            hid = by
+            cur_kv_head = hid // (kv_group_num // valid_block_H)
+
+            T.copy(Q[bid, hid * valid_block_H:hid * valid_block_H + block_H, :], Q_shared)
+            T.fill(acc_o, 0)
+            T.fill(logsum, 0)
+            T.fill(scores_max, -T.infinity(accum_dtype))
+
+            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
+            for k in T.Pipelined(loop_range, num_stages=num_stages):
+                T.copy(K[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_shared)
+                T.copy(mask[bid, k * block_N:(k + 1) * block_N, cur_kv_head], mask_local)
+                T.clear(acc_s)
+                T.gemm(
+                    Q_shared,
+                    K_shared,
+                    acc_s,
+                    transpose_B=True,
+                    policy=T.GemmWarpPolicy.FullRow)
+                for i, j in T.Parallel(block_H, block_N):
+                    acc_s[i, j] = T.if_then_else(mask_local[j] != 0, acc_s[i, j],
+                                                 -T.infinity(accum_dtype))
+                T.copy(scores_max, scores_max_prev)
+                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
+                for i in T.Parallel(block_H):
+                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
+                for i, j in T.Parallel(block_H, block_N):
+                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
+                T.reduce_sum(acc_s, scores_sum, dim=1)
+                for i in T.Parallel(block_H):
+                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
+                T.copy(acc_s, acc_s_cast)
+                for i, j in T.Parallel(block_H, dim):
+                    acc_o[i, j] *= scores_scale[i]
+                T.copy(V[bid, k * block_N:(k + 1) * block_N, cur_kv_head, :], V_shared)
+                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+            for i, j in T.Parallel(block_H, dim):
+                acc_o[i, j] /= logsum[i]
+            for i in T.Parallel(block_H):
+                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
+            T.copy(acc_o[:valid_block_H, :], O_shared)
+            T.copy(O_shared, Output[bid, hid * valid_block_H:(hid + 1) * valid_block_H, :])
+
     @T.macro
     def flash_attn_split(
             Q: T.Buffer(shape_q, dtype),
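A note on the arithmetic in the new `flash_attn` macro: `scale` folds the softmax scaling together with log2(e), so `T.exp2`/`T.log2` replace `exp`/`log`, and each KV block rescales the running accumulator by `exp2(scores_max_prev * scale - scores_max * scale)`. A minimal NumPy sketch of this base-2 online-softmax update (my own illustration: a single query, no mask) confirms the streaming result matches a direct softmax:

```python
# NumPy sketch of the base-2 online-softmax update used in flash_attn above
# (illustration only: one query vector, no mask, block-by-block over the KV length).
import numpy as np

def online_softmax_attention(q, k, v, block_n=4):
    dim = q.shape[-1]
    sm_scale = 1.0 / np.sqrt(dim)
    scale = sm_scale * 1.44269504  # fold in log2(e) so exp2/log2 can be used
    acc_o = np.zeros(dim)
    logsum = 0.0
    score_max = -np.inf
    for start in range(0, k.shape[0], block_n):
        s = q @ k[start:start + block_n].T            # partial scores for this block
        new_max = max(score_max, s.max())
        rescale = np.exp2(score_max * scale - new_max * scale)
        p = np.exp2(s * scale - new_max * scale)
        logsum = logsum * rescale + p.sum()           # running denominator
        acc_o = acc_o * rescale + p @ v[start:start + block_n]
        score_max = new_max
    return acc_o / logsum

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=64), rng.normal(size=(16, 64)), rng.normal(size=(16, 64))
scores = (q @ k.T) / np.sqrt(64)
ref = np.exp(scores - scores.max()) / np.exp(scores - scores.max()).sum() @ v
assert np.allclose(online_softmax_attention(q, k, v), ref)
```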
@@ -168,7 +237,7 @@ def combine(
                 Output[bz, by, i] = o_accum_local[i]
 
     @T.prim_func
-    def main(
+    def main_split(
             Q: T.Buffer(shape_q, dtype),
             K: T.Buffer(shape_k, dtype),
             V: T.Buffer(shape_v, dtype),
@@ -180,7 +249,22 @@ def main(
         flash_attn_split(Q, K, V, mask, glse, Output_partial)
         combine(glse, Output_partial, Output)
 
-    return main
+    @T.prim_func
+    def main_no_split(
+            Q: T.Buffer(shape_q, dtype),
+            K: T.Buffer(shape_k, dtype),
+            V: T.Buffer(shape_v, dtype),
+            mask: T.Buffer([batch, seqlen_kv, groups], "uint8"),
+            glse: T.Buffer([batch, heads, num_split], dtype),
+            Output_partial: T.Buffer(part_shape, dtype),
+            Output: T.Buffer(shape_o, dtype),
+    ):
+        flash_attn(Q, K, V, mask, Output)
+
+    if num_split > 1:
+        return main_split
+    else:
+        return main_no_split
 
     if tune:
 
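For context on what `main_split` computes before this fallback: each split writes a normalized partial output (`Output_partial`) plus a base-2 log-sum-exp (`glse`), and `combine` (not shown in this diff) merges them. Below is a NumPy sketch of that standard flash-decoding reduction, under the assumption that `glse` stores base-2 log-sum-exps as computed at the end of the attention macros above:

```python
# NumPy sketch (assumption-based, see note above) of merging split-KV partial results:
# per-split normalized outputs are recombined with weights derived from their log-sum-exps.
import numpy as np

rng = np.random.default_rng(0)
num_split, seqlen, dim = 4, 64, 8
scores = rng.normal(size=seqlen)      # already-scaled attention scores for one query
v = rng.normal(size=(seqlen, dim))

partial_out = np.zeros((num_split, dim))
lse = np.zeros(num_split)
for s, block in enumerate(np.split(np.arange(seqlen), num_split)):
    p = np.exp2(scores[block] - scores[block].max())
    partial_out[s] = p @ v[block] / p.sum()          # normalized partial output
    lse[s] = np.log2(p.sum()) + scores[block].max()  # base-2 log-sum-exp, like glse

w = np.exp2(lse - lse.max())                          # combine step
combined = (w[:, None] * partial_out).sum(0) / w.sum()

p_full = np.exp2(scores - scores.max())               # unsplit reference
assert np.allclose(combined, p_full @ v / p_full.sum())
```

The non-split path added in this commit simply skips the partial buffers and the reduction when `num_split == 1`, since there is nothing to merge.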
@@ -349,6 +433,7 @@ def reduce_ref(Q, K, V, mask, glse, Output_partial):
         block_N=128, block_H=64, num_split=8, num_stages=2, threads=128)
     kernel = tilelang.compile(program, out_idx=[6])
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
+    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
     print("All checks pass.")
     latency = profiler.do_bench(ref_program, warmup=500)
     print("Ref: {:.2f} ms".format(latency))
