
Commit 5b8ab6d

[Enhancement] Add new examples for warp specialization and TMA integration (#448)
* [Refactor] Update KernelLaunch to clarify CPU and GPU kernel launch logic
  * Added comments to distinguish between CPU and GPU kernel launch sections for better code readability.
  * Changed the creation of empty blocks to use a consistent "root" identifier, enhancing clarity in frame management.
* [Refactor] Rename operations for consistency in lower_hopper_intrin and related files
  * Updated function names from CamelCase to snake_case for better consistency across the codebase.
  * Refactored calls to `CreateTMADescriptorOp`, `CreateListofMBarrierOp`, and similar functions to their new names: `create_tma_descriptor`, `create_list_of_mbarrier`, etc.
  * Adjusted corresponding test cases to reflect these changes, ensuring compatibility with the new naming conventions.
* [Refactor] Rename operations to snake_case for consistency
  * Updated function names from CamelCase to snake_case across various files, including `CreateTMADescriptorOp` to `create_tma_descriptor`, `GetMBarrierOp` to `get_mbarrier`, and others.
  * Adjusted corresponding calls and definitions in the codebase to reflect these naming changes, ensuring uniformity and improved readability.
  * Enhanced layout inference and loop partitioning logic to accommodate the new naming conventions.
* [Feature] Introduce Warp Specialization and Eliminate Storage Sync for MBarrier (a minimal usage sketch of the resulting producer/consumer pattern follows this list)
  * Added a new example `gemm_ws.py` demonstrating matrix multiplication with warp specialization using TileLang.
  * Implemented `WarpSpecializeFrame` and `WarpSpecialize` functionality to manage warp group indices in TIR frames.
  * Introduced the `EliminateStorageSyncForMBarrier` transformation to optimize storage synchronization in mbarrier regions.
  * Enhanced the TileLang API with new methods for retrieving block and thread extents.
  * Updated the `LowerAndLegalize` and `OptimizeForTarget` functions to incorporate the new transformation.
  * Improved layout inference and kernel launch logic for better performance and clarity.
* [Refactor] Clean up code formatting and improve readability
  * Added blank lines for better separation of code blocks in `gemm_ws.py`, `phase.py`, `kernel.py`, and `warpgroup.py`.
  * Reformatted the `tilelang.compile` call in `gemm_ws.py` for improved clarity.
  * Updated comments in `warpgroup.py` to clarify the availability of the `WarpSpecialize` function for NVIDIA GPUs.
  * Ensured consistent spacing and formatting across multiple files to enhance overall code readability.
* lint fix
* [Refactor] Update mbarrier functions for improved clarity and consistency
  * Refactored the `mbarrier_wait_parity` and `mbarrier_arrive` functions in `builtin.py` to accept explicit parameters for better readability.
  * Updated calls in `gemm_ws.py` to use the new function signatures, enhancing code clarity.
  * Adjusted `warpgroup.py` to remove an unused thread extent variable, streamlining the code.
  * Added detailed docstrings with usage examples for the memory barrier functions.
  * Added blank lines in `mbarrier_wait_parity` and `mbarrier_arrive` in `builtin.py` to separate logical sections.
* [Feature] Add examples for warp specialization and TMA barrier integration
  * Introduced three new example scripts, `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, and `example_warp_specialize_mla.py`, demonstrating matrix multiplication with warp specialization and TMA barriers.
  * Implemented kernel functions with shared memory allocation and memory barrier synchronization for improved performance.
  * Enhanced the TileLang API with new methods for compiling and testing kernels in Python using PyTorch.
  * Updated `phase.py` to include TMA barrier injection in the optimization process.
  * Improved documentation and comments for better clarity on usage and functionality.
* [Feature] Add example for warp specialization in GEMM with TMA barriers
  * Introduced a new example script, `example_warp_specialize_gemm_stage2.py`, demonstrating matrix multiplication using warp specialization and TMA barriers.
  * Implemented a kernel function with shared memory allocation and memory barrier synchronization for enhanced performance.
  * Included functionality to compile the kernel into a PyTorch-compatible function and validate its correctness against PyTorch's reference implementation.
  * Enhanced documentation and comments for clarity on usage and functionality.
* lint fix
* [Feature] Implement WarpSpecializedDetector for TMA and MBarrier detection
  * Added the `WarpSpecializedDetector` class to identify the presence of TMA operations and memory barrier operations within a given TIR statement.
  * Enhanced the `WarpSpecialized` pass to use the detector, allowing for conditional substitution based on the detection results.
  * Improved code organization by including the necessary headers and using `IRVisitorWithAnalyzer` for analysis.
  * This addition aims to optimize warp specialization by ensuring that only relevant functions are transformed, improving both performance and correctness.
* lint fix
* [Feature] Add new examples for warp specialization and TMA integration
  * Introduced multiple new example scripts demonstrating warp specialization techniques, including `example_warp_specialize_flashmla.py`, `example_warp_specialize_gemm_barrierpipe_stage2.py`, `example_warp_specialize_gemm_copy_0_gemm_1.py`, `example_warp_specialize_gemm_copy_1_gemm_0.py`, and `example_warp_specialize_gemm_softpipe_stage2.py`.
  * Each example showcases matrix multiplication with warp specialization and TMA barriers, implementing kernel functions with shared memory allocation and memory barrier synchronization for enhanced performance.
  * Added a test suite in `test_example_warp_specialize.py` to validate the new examples.
  * Updated the TileLang API to support these examples and improve kernel compilation and testing.
  * Removed outdated example scripts to streamline the codebase and clarify the available functionality.
* lint fix
* Remove outdated example scripts for warp specialization and TMA integration to streamline the codebase. This includes `example_warp_specialize_gemm.py`, `example_warp_specialize_gemm_barrier4.py`, `example_warp_specialize_gemm_stage2.py`, and `example_warp_specialize_mla.py`, which are no longer needed following recent updates and improvements in the TileLang API.
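The recurring pattern in these commits is a producer warp group issuing copies while consumer warp groups wait on mbarriers. Below is a minimal, hedged sketch of a single-buffered warp-specialized GEMM written against the `T.ws`, `T.create_list_of_mbarrier`, `T.get_mbarrier`, `T.mbarrier_arrive`, and `T.mbarrier_wait_parity` constructs used by the example file added further down in this diff. The function name `warp_specialized_gemm`, the tile sizes, the warp-group numbering, and the barrier/parity arithmetic are illustrative assumptions modeled on that example, not an excerpt from the repository, and this simplified variant has not been compiled as written.

    import tilelang
    import tilelang.language as T


    def warp_specialized_gemm(M, N, K, block_M=128, block_N=128, block_K=64, dtype="float16"):

        @T.prim_func
        def main(
                A: T.Tensor([M, K], dtype),
                B: T.Tensor([K, N], dtype),
                C: T.Tensor([M, N], dtype),
        ):
            # 256 threads = two warp groups of 128: group 0 computes, group 1 copies.
            with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=256) as (bx, by):
                A_shared = T.alloc_shared([block_M, block_K], dtype)
                B_shared = T.alloc_shared([block_K, block_N], dtype)
                C_local = T.alloc_fragment([block_M, block_N], "float")

                # Barrier 0 ("tile ready"): 128 producer threads arrive.
                # Barrier 1 ("tile consumed"): 128 consumer threads arrive.
                T.create_list_of_mbarrier(128, 128)
                loop_range = T.ceildiv(K, block_K)

                with T.ws(1):  # producer warp group: fill the shared tiles
                    for k in T.serial(loop_range):
                        # Wait until the consumer has released the buffer; parity flips each round.
                        T.mbarrier_wait_parity(T.get_mbarrier(1), T.bitwise_xor(k % 2, 1))
                        T.copy(A[by * block_M, k * block_K], A_shared)
                        T.copy(B[k * block_K, bx * block_N], B_shared)
                        T.mbarrier_arrive(T.get_mbarrier(0))

                with T.ws(0):  # consumer warp group: accumulate the GEMM
                    T.clear(C_local)
                    for k in T.serial(loop_range):
                        T.mbarrier_wait_parity(T.get_mbarrier(0), k % 2)
                        T.gemm(A_shared, B_shared, C_local)
                        T.mbarrier_arrive(T.get_mbarrier(1))
                    T.copy(C_local, C[by * block_M, bx * block_N])

        return main

Compilation and validation would follow the same path the shipped examples use, e.g. `kernel = tilelang.compile(warp_specialized_gemm(1024, 1024, 1024), out_idx=[2])` followed by a comparison against `torch.matmul` on float16 inputs; this usage line is likewise an assumption mirroring the example below rather than repository code.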
1 parent f617c58 commit 5b8ab6d

29 files changed: +859 −549 lines
Lines changed: 195 additions & 0 deletions
@@ -0,0 +1,195 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
# use default stage 1 template, not the optimal
# schedule, please check out examples/deepseek_mla
import torch
import torch.nn.functional as F
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
import argparse


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

    @T.macro
    def flash_attn(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        with T.Kernel(batch, heads // min(block_H, kv_group_num), threads=384) as (bx, by):
            Q_shared = T.alloc_shared([block_H, dim], dtype)
            S_shared = T.alloc_shared([block_H, block_N], dtype)
            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
            KV_shared = T.alloc_shared([block_N, dim], dtype)
            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            cur_kv_head = by // (kv_group_num // block_H)
            T.use_swizzle(10)
            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
            })

            T.create_list_of_mbarrier(128, 128, 256, 128)

            loop_range = T.ceildiv(seqlen_kv, block_N)
            with T.ws(2):
                T.dec_max_nreg(24)
                T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
                T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
                T.mbarrier_arrive(T.get_mbarrier(3))
                for k in T.serial(loop_range):
                    T.mbarrier_wait_parity(
                        T.FloorMod(k, 1) + 2, T.bitwise_xor(T.FloorDiv(k, 1) % 2, 1))
                    T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
                    T.mbarrier_arrive(T.FloorMod(k, 1))
                    T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
                    T.mbarrier_arrive(T.FloorMod(k, 1) + 1)
            with T.ws(0, 1):
                T.inc_max_nreg(240)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.mbarrier_wait_parity(T.get_mbarrier(3), 0)
                for k in T.serial(loop_range):
                    T.clear(acc_s)
                    T.mbarrier_wait_parity(T.get_mbarrier(T.FloorMod(k, 1)), T.FloorDiv(k, 1) % 2)
                    T.gemm(
                        Q_shared,
                        KV_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullCol)
                    T.mbarrier_wait_parity(
                        T.get_mbarrier(T.FloorMod(k, 1) + 1),
                        T.FloorDiv(k, 1) % 2)
                    T.gemm(
                        Q_pe_shared,
                        K_pe_shared,
                        acc_s,
                        transpose_B=True,
                        policy=T.GemmWarpPolicy.FullCol)
                    T.copy(scores_max, scores_max_prev)
                    T.fill(scores_max, -T.infinity(accum_dtype))
                    T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                    for i in T.Parallel(block_H):
                        scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                    for i, j in T.Parallel(block_H, block_N):
                        acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                    T.reduce_sum(acc_s, scores_sum, dim=1)
                    T.copy(acc_s, S_shared)
                    for i in T.Parallel(block_H):
                        logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                    for i, j in T.Parallel(block_H, dim):
                        acc_o[i, j] *= scores_scale[i]
                    T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
                    T.mbarrier_arrive(T.get_mbarrier(T.FloorMod(k, 1) + 2))
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :])

    @T.prim_func
    def main_no_split(
            Q: T.Tensor([batch, heads, dim], dtype),
            Q_pe: T.Tensor([batch, heads, pe_dim], dtype),
            KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype),
            K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Tensor([batch, heads, num_split], dtype),
            Output_partial: T.Tensor([batch, heads, num_split, dim], dtype),
            Output: T.Tensor([batch, heads, dim], dtype),
    ):
        flash_attn(Q, Q_pe, KV, K_pe, Output)

    return main_no_split


def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    """
    Inputs:
    - q (Tensor): [batch, heads, dim]
    - q_pe (Tensor): [batch, heads, pe_dim]
    - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
    - glse (Tensor): [batch, heads, num_split]
    - Output_partial (Tensor): [batch, heads, num_split, dim]
    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    dim = q.shape[-1]
    pe_dim = q_pe.shape[-1]
    num_head_groups = q.shape[1] // kv.shape[2]
    scale = (dim + pe_dim)**0.5
    q = rearrange(
        q, 'b (h g) d -> b g h d', g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]

    q_pe = rearrange(
        q_pe, 'b (h g) d -> b g h d',
        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]

    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]

    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, pe_dim]

    query = torch.concat([q, q_pe], dim=-1)
    key = torch.concat([kv, k_pe], dim=-1)

    scores = einsum(
        query, key,
        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]

    attention = F.softmax(
        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]

    out = einsum(attention, kv,
                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
    return out


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=128, help='batch size')
    parser.add_argument('--heads', type=int, default=128, help='q heads number')
    parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
    parser.add_argument('--dim', type=int, default=512, help='head dim')
    parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
    args = parser.parse_args()
    batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
    qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
    pv_flops = 2 * batch * heads * kv_ctx * dim
    total_flops = qk_flops + pv_flops
    BLOCK_N = 64
    BLOCK_H = 64
    num_split = 1

    program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
    kernel = tilelang.compile(program, out_idx=[6])

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
    profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print(f"Latency: {latency} ms")
    print(f"TFlops: {total_flops / latency * 1e-9} TFlops")


if __name__ == "__main__":
    main()

examples/warp_specialize/example_warp_specialize_gemm.py

Lines changed: 0 additions & 107 deletions
This file was deleted.
