Commit ba56e06

[CI][Test] Add test cases for tilelang kernel convolution (#51)
* [CI][Test] Add test cases for tilelang kernel convolution
1 parent da65817 commit ba56e06

File tree

2 files changed: +658 -0 lines changed
Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T
from functools import partial

num_split = 4


def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_casual, block_M, block_N):
    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, seqlen_q, heads, dim]
    shape_kv = [batch, seqlen_kv, heads, dim]
    part_shape = [batch, seqlen_q, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"

    @T.macro
    def MMA0(
        K: T.Buffer(shape_kv, dtype),
        Q_shared: T.Buffer([block_M, dim], dtype),
        K_shared: T.Buffer([block_N, dim], dtype),
        acc_s: T.Buffer([block_M, block_N], accum_dtype),
        k: T.int32,
        mid: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(
            K[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
              (k + 1) * block_N, hid, :], K_shared)
        # TODO: Handle causal split case
        if is_casual:
            for i, j in T.Parallel(block_M, block_N):
                acc_s[i, j] = T.if_then_else(mid * block_M + i >= k * block_N + j, 0,
                                             -T.infinity(acc_s.dtype))
        else:
            T.clear(acc_s)
        T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def MMA1(
        V: T.Buffer(shape_kv, dtype),
        V_shared: T.Buffer([block_M, dim], dtype),
        acc_s_cast: T.Buffer([block_M, block_N], dtype),
        acc_o: T.Buffer([block_M, dim], accum_dtype),
        k: T.int32,
        hid: T.int32,
        bid: T.int32,
        sid: T.int32,
    ):
        T.copy(
            V[bid, (seqlen_kv // num_split) * sid + k * block_N:(seqlen_kv // num_split) * sid +
              (k + 1) * block_N, hid, :], V_shared)
        T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

    @T.macro
    def Softmax(
        acc_s: T.Buffer([block_M, block_N], accum_dtype),
        acc_s_cast: T.Buffer([block_M, block_N], dtype),
        scores_max: T.Buffer([block_M], accum_dtype),
        scores_max_prev: T.Buffer([block_M], accum_dtype),
        scores_scale: T.Buffer([block_M], accum_dtype),
        scores_sum: T.Buffer([block_M], accum_dtype),
        logsum: T.Buffer([block_M], accum_dtype),
    ):
        T.copy(scores_max, scores_max_prev)
        T.fill(scores_max, -T.infinity(accum_dtype))
        T.reduce_max(acc_s, scores_max, dim=1, clear=False)
        # To do causal softmax, we need to set the scores_max to 0 if it is -inf
        # This process is called Check_inf in FlashAttention3 code, and it only needs to be done
        # in the first ceil_div(kBlockM, kBlockN) steps.
        # for i in T.Parallel(block_M):
        #     scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
        for i in T.Parallel(block_M):
            scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
        for i, j in T.Parallel(block_M, block_N):
            # Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            # max * log_2(e)). This allows the compiler to use the ffma
            # instruction instead of fadd and fmul separately.
            acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
        T.reduce_sum(acc_s, scores_sum, dim=1)
        for i in T.Parallel(block_M):
            logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
        T.copy(acc_s, acc_s_cast)

    @T.macro
    def Rescale(
        acc_o: T.Buffer([block_M, dim], accum_dtype),
        scores_scale: T.Buffer([block_M], accum_dtype),
    ):
        for i, j in T.Parallel(block_M, dim):
            acc_o[i, j] *= scores_scale[i]

    @T.macro
    def flash_attn_split(
        Q: T.Buffer(shape_q, dtype),
        K: T.Buffer(shape_kv, dtype),
        V: T.Buffer(shape_kv, dtype),
        glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Buffer(part_shape, dtype),
    ):
        with T.Kernel(
                T.ceildiv(seqlen_q, block_M), heads * batch, num_split,
                threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_M, dim], dtype)
            K_shared = T.alloc_shared([block_N, dim], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_M, dim], dtype)
            acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
            acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_M], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
            scores_scale = T.alloc_fragment([block_M], accum_dtype)
            scores_sum = T.alloc_fragment([block_M], accum_dtype)
            logsum = T.alloc_fragment([block_M], accum_dtype)

            mid = bx
            hid = by % heads
            bid = by // heads
            sid = bz

            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
            T.copy(Q[bid, mid * block_M:(mid + 1) * block_M, hid, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            # TODO: Handle causal split case
            loop_range = (
                T.min(T.ceildiv(seqlen_kv, block_N), T.ceildiv(
                    (mid + 1) * block_M, block_N)) if is_casual else T.ceildiv(
                        (seqlen_kv // num_split), block_N))

            for k in T.Pipelined(loop_range, num_stages=2):
                MMA0(K, Q_shared, K_shared, acc_s, k, mid, hid, bid, sid)
                Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum,
                        logsum)
                Rescale(acc_o, scores_scale)
                MMA1(V, V_shared, acc_s_cast, acc_o, k, hid, bid, sid)
            for i, j in T.Parallel(block_M, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_M):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
            T.copy(logsum, glse[bid, hid, sid, mid * block_M:(mid + 1) * block_M])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, mid * block_M:(mid + 1) * block_M, hid, sid, :])

    @T.macro
    def combine(
        glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Buffer(part_shape, dtype),
        Output: T.Buffer(shape_q, dtype),
    ):
        with T.Kernel(T.ceildiv(seqlen_q, block_M), heads, batch, threads=128) as (bx, by, bz):
            po_local = T.alloc_fragment([block_M, dim], dtype)
            po_shared = T.alloc_shared([block_M, dim], dtype)
            o_accum_local = T.alloc_fragment([block_M, dim], accum_dtype)
            o_shared = T.alloc_shared([block_M, dim], dtype)
            lse_local = T.alloc_fragment([num_split, block_M], dtype)
            lse_local_split = T.alloc_fragment([block_M], accum_dtype)
            lse_logsum_local = T.alloc_fragment([block_M], accum_dtype)
            lse_max_local = T.alloc_fragment([block_M], accum_dtype)
            scale_local = T.alloc_fragment([block_M], accum_dtype)

            T.annotate_layout({
                o_accum_local: T.Fragment(o_accum_local.shape, forward_thread_fn=lambda i, j: i),
                lse_local_split: T.Fragment(lse_local_split.shape, forward_thread_fn=lambda i: i),
                o_shared: tilelang.layout.make_swizzled_layout(o_shared),
                po_shared: tilelang.layout.make_swizzled_layout(po_shared),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            T.copy(glse[
                bz,
                by,
                :,
                bx * block_M:(bx + 1) * block_M,
            ], lse_local)
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=False)
            for k in T.Pipelined(num_split):
                T.copy(lse_local[k, :], lse_local_split)
                for i in T.Parallel(block_M):
                    lse_logsum_local[i] += T.exp2(lse_local_split[i] - lse_max_local[i])
            for i in T.Parallel(block_M):
                lse_logsum_local[i] = T.log2(lse_logsum_local[i]) + lse_max_local[i]
            for k in T.Pipelined(num_split, num_stages=2):
                T.copy(Output_partial[bz, bx * block_M:(bx + 1) * block_M, by, k, :], po_shared)
                T.copy(po_shared, po_local)
                T.copy(lse_local[k, :], lse_local_split)
                for i in T.Parallel(block_M):
                    scale_local[i] = T.exp2(lse_local_split[i] - lse_logsum_local[i])
                for i, j in T.Parallel(block_M, dim):
                    o_accum_local[i, j] += po_local[i, j] * scale_local[i]
            T.copy(o_accum_local, o_shared)
            T.copy(o_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

    @T.prim_func
    def main(
        Q: T.Buffer(shape_q, dtype),
        K: T.Buffer(shape_kv, dtype),
        V: T.Buffer(shape_kv, dtype),
        glse: T.Buffer([batch, heads, num_split, seqlen_q], dtype),
        Output_partial: T.Buffer(part_shape, dtype),  # [batch, seqlen_q, heads, num_split, dim]
        Output: T.Buffer(shape_q, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return main


def ref_program(Q, K, V, glse, Output_partial, casual):
    assert casual is False
    dim = Q.size(-1)
    scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)
    output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
    return output


def reduce_ref(Q, K, V, glse, Output_partial, casual):
    o = torch.empty_like(Output_partial[:, :, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0, :]).fill_(0)  # [batch, heads, seqlen_q]
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks, :]
        scale = torch.exp2(lse - lse_logsum)  # [batch, heads, seqlen_q]
        o += Output_partial[:, :, :, ks, :] * scale[:, :, :, None].transpose(1, 2)
    return o.to(torch.float16)


def flash_split_ref(Q, K, V, casual):
    # [batch, seqlen_q, heads, dim]
    batch = Q.size(0)
    block_M = Q.size(1)
    nheads = Q.size(2)
    dim = Q.size(3)
    block_N = 128
    seqlen_kv = K.size(1)

    scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_M, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads, block_M), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, block_M, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads, block_M), device="cuda", dtype=torch.float)

    Q_ = Q * scale

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum('bqhd,bkhd->bhqk', Q_,
                                 K[:, (seqlen_kv // num_split) * ks +
                                   i * block_N:(seqlen_kv // num_split) * ks +
                                   (i + 1) * block_N, :, :])  # [batch, nheads, block_M, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads, block_M]
            scores_scale = torch.exp2(scores_max_prev - scores_max)
            acc_o *= scores_scale[:, :, :, None].transpose(1, 2)
            acc_s = torch.exp2(acc_s - scores_max[:, :, :, None])
            acc_s_cast = acc_s.to(torch.float16)
            acc_o += torch.einsum(
                'bhqk,bkhd->bqhd', acc_s_cast,
                V[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                  (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, :, None].transpose(1, 2)
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :, :] = acc_o
        glogsum[ks, :, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0,
                                             3), gacc_o.to(torch.float16).permute(1, 2, 3, 0, 4)


if __name__ == "__main__":
    BATCH, H, Q_CTX, KV_CTX, D_HEAD = 1, 32, 128, 8192, 128
    casual = False
    flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD
    total_flops = 2 * flops_per_matmul
    if casual:
        total_flops *= 0.5
    BLOCK_M = 128
    BLOCK_N = 64  # if D_HEAD <= 128 else 32
    program = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, casual, BLOCK_M, BLOCK_N)
    ref_program = partial(ref_program, casual=casual)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    print("All checks passed!")

    latency = mod.do_bench(ref_program, warmup=500)
    print("{:.2f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))
    latency = mod.do_bench(mod, n_warmup=10, n_repeat=10, profiler="tvm")
    print("{:.4f} ms".format(latency))
    print("{:.2f} TFlops".format(total_flops / latency * 1e-9))

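Note (not part of the commit): the combine kernel and reduce_ref merge the per-split partials with a base-2 log-sum-exp: each split k yields a split-local softmax output o_k and a log2-sum-exp l_k, the total is L = log2(sum_k exp2(l_k)) computed max-shifted for stability, and the final output is sum_k exp2(l_k - L) * o_k. A small self-contained PyTorch check of that recombination, with made-up shapes:

import math
import torch

torch.manual_seed(0)
num_split, rows, cols, dim = 4, 8, 256, 16
S = torch.randn(rows, cols)  # hypothetical attention scores for one (batch, head)
V = torch.randn(cols, dim)

# Reference: base-2 softmax over the full row, applied to V.
ref = torch.softmax(S * math.log(2.0), dim=-1) @ V

# Per-split partials, as a split kernel would produce them.
chunk = cols // num_split
lse = torch.empty(num_split, rows)
out_part = torch.empty(num_split, rows, dim)
for k in range(num_split):
    Sk = S[:, k * chunk:(k + 1) * chunk]
    Vk = V[k * chunk:(k + 1) * chunk]
    mk = Sk.max(dim=-1, keepdim=True).values
    pk = torch.exp2(Sk - mk)
    lse[k] = torch.log2(pk.sum(dim=-1)) + mk.squeeze(-1)    # log2-sum-exp of this split
    out_part[k] = (pk / pk.sum(dim=-1, keepdim=True)) @ Vk  # split-local softmax output

# Combine: max-shifted log2-sum-exp across splits, then rescale each partial.
lse_max = lse.max(dim=0).values
lse_total = torch.log2(torch.exp2(lse - lse_max).sum(dim=0)) + lse_max
out = (torch.exp2(lse - lse_total)[:, :, None] * out_part).sum(dim=0)

print(torch.allclose(out, ref, atol=1e-5))  # expected: True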