Commit 0a2b781

[Dev][Bugfix] Fix bug in ThreadTagChecker; Add WgmmaSync rewriter and add MHA WGMMA pipelined example (#128)
* [Dev] Add RetNet Linear Attention example

* [Dev] Add WgmmaSync rewriter for pipelined WGMMA operations and add MHA WGMMA pipelined example (FA3-like scheduling)

  This commit introduces a new transformation pass `RewriteWgmmaSync` to optimize warp group matrix multiply accumulate (WGMMA) operations in the TileLang compiler:

  - Implemented `WgmmaSyncRewriter` in `src/transform/wgmma_sync_rewriter.cc`
  - Added pass registration for `RewriteWgmmaSync`
  - Updated `tilelang/engine/phase.py` to include the new transformation pass
  - Updated `tilelang/transform/__init__.py` to expose the new pass (a hedged usage sketch follows the commit header below)

  The rewriter manages synchronization and dependencies between WGMMA operations, improving pipeline efficiency for complex matrix multiplication kernels.

* [Bugfix] Fix bug in ThreadTagChecker for warp specialization

  Improve thread tag validation in the warp specialized rewriter to prevent unintended transformations:

  - Add more precise checks for threadIdx.y and threadIdx.z
  - Validate thread extent so that only single-extent thread bindings are allowed
  - Prevent warp specialization for multi-extent thread bindings in the y and z dimensions

* lint

* [CI] Add TMA descriptor attribute to transformed module in test case
1 parent aacb91d commit 0a2b781
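The `RewriteWgmmaSync` pass introduced by this commit is, per the commit message, registered as a TileLang transform, exposed through `tilelang/transform/__init__.py`, and wired into the lowering pipeline in `tilelang/engine/phase.py`. The sketch below shows how one might apply it by hand; it assumes the pass follows the usual TVM convention of a callable pass object and that `mod` is an IRModule at a stage where WGMMA operations are already present. In normal use the pass runs automatically inside `tilelang.lower`.

    # Hedged sketch, not taken from the repository: applying the new pass manually,
    # assuming it is exposed as tilelang.transform.RewriteWgmmaSync and behaves
    # like a standard TVM pass (callable on an IRModule).
    import tilelang.transform

    def apply_wgmma_sync_rewrite(mod):
        # `mod` is assumed to be a tvm.IRModule that already contains WGMMA
        # operations; phase.py schedules this pass inside tilelang.lower, so
        # calling it manually like this would only be for inspection.
        return tilelang.transform.RewriteWgmmaSync()(mod)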

File tree

6 files changed: +547 −5 lines changed
Lines changed: 232 additions & 0 deletions
@@ -0,0 +1,232 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
    block_M = [128]
    block_N = [128]
    num_stages = [2]
    threads = [256]
    _configs = list(itertools.product(block_M, block_N, num_stages, threads))

    configs = [{
        'block_M': c[0],
        'block_N': c[1],
        'num_stages': c[2],
        'threads': c[3]
    } for c in _configs]
    return configs


def flashattn(batch, heads, seq_len, dim, is_causal, tune=False):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape = [batch, seq_len, heads, dim]
    dtype = "float16"
    accum_dtype = "float"

    def kernel_func(block_M, block_N, num_stages, threads):

        @T.macro
        def MMA0(
            K: T.Buffer(shape, dtype),
            Q_shared: T.Buffer([block_M, dim], dtype),
            K_shared: T.Buffer([block_N, dim], dtype),
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            k: T.int32,
            bx: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
            if is_causal:
                for i, j in T.Parallel(block_M, block_N):
                    acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
                                                 -T.infinity(acc_s.dtype))
            else:
                T.clear(acc_s)
            T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def MMA1(
            V: T.Buffer(shape, dtype),
            V_shared: T.Buffer([block_M, dim], dtype),
            acc_s_cast: T.Buffer([block_M, block_N], dtype),
            acc_o: T.Buffer([block_M, dim], accum_dtype),
            k: T.int32,
            by: T.int32,
            bz: T.int32,
        ):
            T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
            T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

        @T.macro
        def Softmax(
            acc_s: T.Buffer([block_M, block_N], accum_dtype),
            acc_s_cast: T.Buffer([block_M, block_N], dtype),
            scores_max: T.Buffer([block_M], accum_dtype),
            scores_max_prev: T.Buffer([block_M], accum_dtype),
            scores_scale: T.Buffer([block_M], accum_dtype),
            scores_sum: T.Buffer([block_M], accum_dtype),
            logsum: T.Buffer([block_M], accum_dtype),
        ):
            T.copy(scores_max, scores_max_prev)
            T.fill(scores_max, -T.infinity(accum_dtype))
            T.reduce_max(acc_s, scores_max, dim=1, clear=False)
            # To do causal softmax, we need to set the scores_max to 0 if it is -inf
            # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
            # in the first ceil_div(kBlockM, kBlockN) steps.
            # for i in T.Parallel(block_M):
            #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
            for i in T.Parallel(block_M):
                scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
            for i, j in T.Parallel(block_M, block_N):
                # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                # max * log_2(e)). This allows the compiler to use the ffma
                # instruction instead of fadd and fmul separately.
                acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
            T.reduce_sum(acc_s, scores_sum, dim=1)
            for i in T.Parallel(block_M):
                logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
            T.copy(acc_s, acc_s_cast)

        @T.macro
        def Rescale(
            acc_o: T.Buffer([block_M, dim], accum_dtype),
            scores_scale: T.Buffer([block_M], accum_dtype),
        ):
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] *= scores_scale[i]

        @T.prim_func
        def main(
            Q: T.Buffer(shape, dtype),
            K: T.Buffer(shape, dtype),
            V: T.Buffer(shape, dtype),
            Output: T.Buffer(shape, dtype),
        ):
            with T.Kernel(
                    T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
                Q_shared = T.alloc_shared([block_M, dim], dtype)
                K_shared = T.alloc_shared([block_N, dim], dtype)
                V_shared = T.alloc_shared([block_N, dim], dtype)
                O_shared = T.alloc_shared([block_M, dim], dtype)
                acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
                acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
                acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
                scores_max = T.alloc_fragment([block_M], accum_dtype)
                scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
                scores_scale = T.alloc_fragment([block_M], accum_dtype)
                scores_sum = T.alloc_fragment([block_M], accum_dtype)
                logsum = T.alloc_fragment([block_M], accum_dtype)

                T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
                T.fill(acc_o, 0)
                T.fill(logsum, 0)
                T.fill(scores_max, -T.infinity(accum_dtype))

                loop_range = (
                    T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
                        (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

                for k in T.Pipelined(
                        loop_range,
                        num_stages=num_stages,
                        order=[-1, 0, 3, 1, -1, 2],
                        stage=[-1, 0, 0, 1, -1, 1],
                        group=[[0], [1, 2], [3, 4, 5, 6, 7, 8, 9, 10], [11], [12], [13]]):
                    MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
                    Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
                            scores_sum, logsum)
                    Rescale(acc_o, scores_scale)
                    MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
                for i, j in T.Parallel(block_M, dim):
                    acc_o[i, j] /= logsum[i]
                T.copy(acc_o, O_shared)
                T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

        return main

    if tune:

        @autotune(
            configs=get_configs(),
            keys=["block_M", "block_N", "num_stages", "threads"],
            warmup=10,
            rep=10)
        @jit(
            out_idx=[3],
            supply_type=tilelang.TensorSupplyType.Integer,
            ref_prog=None,
            profiler="auto")
        def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel()
    else:

        def kernel(block_M, block_N, num_stages, threads):
            return kernel_func(block_M, block_N, num_stages, threads)

        return kernel


def ref_program(Q, K, V, is_causal):
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    if is_causal:
        seq_len = Q.size(1)
        mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
        mask = mask.unsqueeze(0).unsqueeze(0)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=8, help='batch size')
    parser.add_argument('--heads', type=int, default=32, help='heads')
    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
    parser.add_argument('--dim', type=int, default=128, help='dim')
    parser.add_argument('--is_causal', action='store_true', help='causal')
    parser.add_argument('--tune', action='store_true', help='tune configs')
    args = parser.parse_args()
    batch, heads, seq_len, dim, is_causal = args.batch, args.heads, args.seq_len, args.dim, args.is_causal
    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
    total_flops = 2 * flops_per_matmul
    if is_causal:
        total_flops *= 0.5

    if (not args.tune):
        program = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune)(
                block_M=128, block_N=128, num_stages=2, threads=256)
        ref_program = partial(ref_program, is_causal=is_causal)
        mod, params = tilelang.lower(program)
        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
        print("All checks pass.")
        latency = mod.do_bench(ref_program, warmup=500)
        print("Ref: {:.2f} ms".format(latency))
        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
        latency = mod.do_bench(mod.func, warmup=500)
        print("Tile-lang: {:.2f} ms".format(latency))
        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
    else:
        best_latency, best_config, _ = flashattn(
            batch, heads, seq_len, dim, is_causal, tune=args.tune)
        print(f"Best latency: {best_latency}")
        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
        print(f"Best config: {best_config}")
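The Softmax and Rescale macros above implement the standard online-softmax update in base 2: each tile's scores are exponentiated as exp2(x * scale - m * scale) with log2(e) folded into scale, previously accumulated results are rescaled by exp2((m_prev - m_new) * scale), and logsum accumulates the rescaled row sums. The standalone PyTorch check below verifies that this blockwise update reproduces a full softmax-weighted sum; it is illustrative only (variable names mirror the kernel's buffers, the 128-column split and head_dim of 8 are arbitrary, and the fp16 cast of acc_s is omitted).

    # Standalone check of the online-softmax / exp2 identity used by the Softmax
    # and Rescale macros (illustrative; not part of the example above).
    import torch

    torch.manual_seed(0)
    scale = 1.44269504                    # log2(e): exp2(x * scale) == exp(x)
    x = torch.randn(4, 256)               # pre-scaled attention scores for one row block
    v = torch.randn(256, 8)               # matching values, small head_dim for brevity

    m = torch.full((4,), float("-inf"))   # running row max, like scores_max
    logsum = torch.zeros(4)               # running sum of exponentials, like logsum
    acc = torch.zeros(4, 8)               # running un-normalized output, like acc_o

    for xt, vt in zip(x.split(128, dim=1), v.split(128, dim=0)):
        m_prev = m
        m = torch.maximum(m_prev, xt.max(dim=1).values)
        scores_scale = torch.exp2(m_prev * scale - m * scale)  # rescale factor for old partial results
        p = torch.exp2(xt * scale - m[:, None] * scale)        # exp2 form of exp(x - m)
        logsum = logsum * scores_scale + p.sum(dim=1)
        acc = acc * scores_scale[:, None] + p @ vt             # Rescale followed by MMA1

    out = acc / logsum[:, None]                                # final normalization
    ref = torch.softmax(x, dim=1) @ v                          # full softmax reference
    assert torch.allclose(out, ref, atol=1e-5)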

src/transform/warp_specialized_rewriter.cc

Lines changed: 14 additions & 5 deletions
@@ -878,9 +878,12 @@ class ThreadTagChecker : public StmtExprVisitor {
 private:
   void VisitStmt_(const AttrStmtNode *op) final {
     if (op->attr_key == tir::attr::thread_extent) {
-      auto iter_var = Downcast<IterVar>(op->node);
-      if (iter_var->thread_tag.length() > 0 &&
-          iter_var->thread_tag != "threadIdx.x") {
+      IterVar iter_var = Downcast<IterVar>(op->node);
+      String thread_tag = iter_var->thread_tag;
+      bool is_y_or_z =
+          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
+
+      if (!thread_tag.empty() && is_y_or_z && !is_one(iter_var->dom->extent)) {
         is_valid_ = false;
       }
     }
@@ -891,8 +894,14 @@ class ThreadTagChecker : public StmtExprVisitor {
     if (op->kind == ForKind::kThreadBinding) {
       ICHECK(op->thread_binding.defined());
       String thread_tag = op->thread_binding.value()->thread_tag;
-      if (thread_tag.length() > 0 && thread_tag != "threadIdx.x") {
-        is_valid_ = false;
+      bool is_y_or_z =
+          thread_tag == "threadIdx.y" || thread_tag == "threadIdx.z";
+      if (!thread_tag.empty() && is_y_or_z) {
+        auto iter_var = Downcast<IterVar>(op->thread_binding);
+        if (iter_var.defined() && iter_var->dom.defined() &&
+            !is_one(iter_var->dom->extent)) {
+          is_valid_ = false;
+        }
       }
     }
     StmtExprVisitor::VisitStmt_(op);
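Restated in plain terms, the updated ThreadTagChecker only rejects warp specialization when a threadIdx.y or threadIdx.z binding has an extent greater than one; the previous check rejected any tag other than threadIdx.x, including trivial single-extent y/z bindings. A small Python restatement of the new predicate (illustrative only, not a TileLang API):

    # Illustrative restatement of the new ThreadTagChecker rule; not TileLang code.
    def keeps_warp_specialization(thread_tag: str, extent: int) -> bool:
        is_y_or_z = thread_tag in ("threadIdx.y", "threadIdx.z")
        if thread_tag and is_y_or_z and extent != 1:
            return False                   # checker sets is_valid_ = false
        return True                        # all other bindings stay eligible

    assert keeps_warp_specialization("threadIdx.x", 256)        # x bindings are unaffected
    assert keeps_warp_specialization("threadIdx.y", 1)          # single-extent y is now allowed
    assert not keeps_warp_specialization("threadIdx.y", 2)      # multi-extent y still disqualifies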
