Skip to content

Commit e055782

Browse files
authored
[Example] Add Split-K and Stream-K Examples and move MLA from fld to mla (#110)
* Add DeepSeek MLA decode example with Flash Attention implementation * Add GEMM SplitK and StreamK example implementations This commit introduces two new example scripts demonstrating advanced GEMM (matrix multiplication) techniques: - `example_tilelang_gemm_splitk.py`: Implements a Split-K GEMM kernel using TileLang - `example_tilelang_gemm_streamk.py`: Implements a Stream-K GEMM kernel using TileLang Both examples showcase different parallel computation strategies for matrix multiplication, with comprehensive testing using PyTorch reference implementations. * Refactor GEMM SplitK and StreamK example implementations Clean up and improve code formatting for the SplitK and StreamK GEMM example scripts: - Remove unused import (Profiler) in splitk example - Simplify line breaks and improve code readability - Standardize indentation and remove unnecessary whitespace - Optimize atomic add and copy operations for better clarity
1 parent f08efcc commit e055782

File tree

4 files changed

+537
-0
lines changed

4 files changed

+537
-0
lines changed

examples/deepseek_mla/.gitkeep

Whitespace-only changes.
Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import torch
2+
import torch.nn.functional as F
3+
import tilelang
4+
from tilelang.autotuner import *
5+
import tilelang.language as T
6+
7+
num_split = 4
8+
9+
10+
def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
11+
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
12+
shape_q = [batch, heads, (dim + pe_dim)]
13+
shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
14+
shape_v = [batch, seqlen_kv, kv_head_num, dim]
15+
shape_o = [batch, heads, dim]
16+
part_shape = [batch, heads, num_split, dim]
17+
dtype = "float16"
18+
accum_dtype = "float"
19+
kv_group_num = heads // kv_head_num
20+
VALID_BLOCK_H = min(block_H, kv_group_num)
21+
assert kv_head_num == 1, "kv_head_num must be 1"
22+
23+
@T.macro
24+
def flash_attn_split(
25+
Q: T.Buffer(shape_q, dtype),
26+
K: T.Buffer(shape_k, dtype),
27+
V: T.Buffer(shape_v, dtype),
28+
glse: T.Buffer([batch, heads, num_split], dtype),
29+
Output_partial: T.Buffer(part_shape, dtype),
30+
):
31+
with T.Kernel(
32+
batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
33+
Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
34+
K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
35+
V_shared = T.alloc_shared([block_N, dim], dtype)
36+
O_shared = T.alloc_shared([block_H, dim], dtype)
37+
acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
38+
acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
39+
acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
40+
scores_max = T.alloc_fragment([block_H], accum_dtype)
41+
scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
42+
scores_scale = T.alloc_fragment([block_H], accum_dtype)
43+
scores_sum = T.alloc_fragment([block_H], accum_dtype)
44+
logsum = T.alloc_fragment([block_H], accum_dtype)
45+
46+
bid = bx
47+
hid = by
48+
sid = bz
49+
cur_kv_head = hid // (kv_group_num // block_H)
50+
51+
T.annotate_layout({
52+
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
53+
})
54+
55+
T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
56+
T.fill(acc_o, 0)
57+
T.fill(logsum, 0)
58+
T.fill(scores_max, -T.infinity(accum_dtype))
59+
60+
loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
61+
for k in T.Pipelined(loop_range, num_stages=1):
62+
T.copy(
63+
K[bid, (seqlen_kv // num_split) * sid +
64+
k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
65+
cur_kv_head, :], K_shared)
66+
T.clear(acc_s)
67+
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
68+
T.copy(scores_max, scores_max_prev)
69+
T.fill(scores_max, -T.infinity(accum_dtype))
70+
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
71+
for i in T.Parallel(block_H):
72+
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
73+
for i, j in T.Parallel(block_H, block_N):
74+
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
75+
T.reduce_sum(acc_s, scores_sum, dim=1)
76+
for i in T.Parallel(block_H):
77+
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
78+
T.copy(acc_s, acc_s_cast)
79+
for i, j in T.Parallel(block_H, dim):
80+
acc_o[i, j] *= scores_scale[i]
81+
T.copy(
82+
V[bid, (seqlen_kv // num_split) * sid +
83+
k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
84+
cur_kv_head, :], V_shared)
85+
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
86+
for i, j in T.Parallel(block_H, dim):
87+
acc_o[i, j] /= logsum[i]
88+
for i in T.Parallel(block_H):
89+
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
90+
91+
T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
92+
T.copy(acc_o, O_shared)
93+
T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
94+
sid, :])
95+
96+
@T.macro
97+
def combine(
98+
glse: T.Buffer([batch, heads, num_split], dtype),
99+
Output_partial: T.Buffer(part_shape, dtype),
100+
Output: T.Buffer(shape_o, dtype),
101+
):
102+
with T.Kernel(heads, batch, threads=128) as (by, bz):
103+
po_local = T.alloc_fragment([dim], dtype)
104+
o_accum_local = T.alloc_fragment([dim], accum_dtype)
105+
lse_local = T.alloc_fragment([num_split, 1], dtype)
106+
lse_local_split = T.alloc_local([1], accum_dtype)
107+
lse_logsum_local = T.alloc_local([1], accum_dtype)
108+
lse_max_local = T.alloc_fragment([1], accum_dtype)
109+
scale_local = T.alloc_local([1], accum_dtype)
110+
111+
T.annotate_layout({
112+
lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
113+
})
114+
115+
T.clear(lse_logsum_local)
116+
T.clear(o_accum_local)
117+
for k in T.Parallel(num_split):
118+
lse_local[k, 0] = glse[bz, by, k]
119+
T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
120+
for k in T.Pipelined(num_split, num_stages=1):
121+
lse_local_split[0] = glse[bz, by, k]
122+
lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
123+
lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
124+
for k in T.serial(num_split):
125+
for i in T.Parallel(dim):
126+
po_local[i] = Output_partial[bz, by, k, i]
127+
lse_local_split[0] = glse[bz, by, k]
128+
scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
129+
for i in T.Parallel(dim):
130+
o_accum_local[i] += po_local[i] * scale_local[0]
131+
for i in T.Parallel(dim):
132+
Output[bz, by, i] = o_accum_local[i]
133+
134+
@T.prim_func
135+
def main(
136+
Q: T.Buffer(shape_q, dtype),
137+
K: T.Buffer(shape_k, dtype),
138+
V: T.Buffer(shape_v, dtype),
139+
glse: T.Buffer([batch, heads, num_split], dtype),
140+
Output_partial: T.Buffer(part_shape, dtype), # [batch, heads, num_split, dim]
141+
Output: T.Buffer(shape_o, dtype),
142+
):
143+
flash_attn_split(Q, K, V, glse, Output_partial)
144+
combine(glse, Output_partial, Output)
145+
146+
return main
147+
148+
149+
def ref_program(query, key, value, glse, Output_partial):
150+
# """
151+
# Inputs:
152+
# - query (Tensor): [batch, heads, dim]
153+
# - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
154+
# - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
155+
156+
# Outputs:
157+
# - output (Tensor): [batch, heads, dim]
158+
# """
159+
from einops import rearrange
160+
batch_size, query_heads, dim = query.shape # [batch_size, query_heads, dim]
161+
_, seqlen_kv, kv_heads, _ = key.shape # [batch_size, seqlen_kv, kv_heads, kv_dim]
162+
dim_v = value.shape[-1]
163+
assert kv_heads == 1, "kv_heads must be 1"
164+
165+
query_expanded = rearrange(query, 'b h d -> b h 1 d') # [batch_size, query_heads, 1, dim]
166+
key_expanded = key.expand(-1, -1, query_heads, -1) # [batch_size, query_heads, seqlen_kv, dim]
167+
value_expanded = value.expand(-1, -1, query_heads,
168+
-1) # [batch_size, query_heads, seqlen_kv, dim]
169+
key_expanded = rearrange(key_expanded,
170+
'b n h d -> b h n d') # [batch_size, kv_head_num, seqlen_kv, dim]
171+
value_expanded = rearrange(value_expanded,
172+
'b n h d -> b h n d') # [batch_size, query_heads, seqlen_kv, dim]
173+
174+
scores = torch.matmul(query_expanded,
175+
key_expanded.transpose(-1, -2)) # [batch_size, query_heads, 1, seqlen_kv]
176+
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
177+
attention_weights = F.softmax(scores, dim=-1) # [batch_size, query_heads, 1, seqlen_kv]
178+
output = torch.matmul(attention_weights, value_expanded) # [batch_size, query_heads, 1, dim]
179+
return output.view(batch_size, query_heads, dim_v)
180+
181+
182+
def flash_split_ref(Q, K, V):
183+
dim = 512
184+
pe_dim = 64
185+
batch = Q.size(0)
186+
nheads = Q.size(1)
187+
assert Q.size(2) == dim + pe_dim, "dim must be 576=512+64"
188+
block_N = 32
189+
seqlen_kv = K.size(1)
190+
191+
scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504 # log2(e)
192+
acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
193+
acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
194+
acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
195+
scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
196+
scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
197+
scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
198+
scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
199+
logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
200+
gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
201+
glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)
202+
203+
Q_ = Q * scale
204+
K_ = K.expand(-1, -1, nheads, -1)
205+
V_ = V.expand(-1, -1, nheads, -1)
206+
207+
for ks in range(num_split):
208+
acc_o.fill_(0)
209+
logsum.fill_(0)
210+
scores_max.fill_(float('-inf'))
211+
scores_max_prev.fill_(float('-inf'))
212+
for i in range(int((seqlen_kv // num_split) / block_N)):
213+
acc_s.fill_(0)
214+
acc_s = torch.einsum('bhd,bkhd->bhk', Q_,
215+
K_[:, (seqlen_kv // num_split) * ks +
216+
i * block_N:(seqlen_kv // num_split) * ks +
217+
(i + 1) * block_N, :, :]) # [batch, nheads, block_N]
218+
scores_max_prev = scores_max
219+
scores_max = acc_s.max(dim=-1, keepdim=False).values # [batch, nheads]
220+
scores_scale = torch.exp2(scores_max_prev - scores_max) # [batch, nheads]
221+
acc_o *= scores_scale[:, :, None]
222+
acc_s = torch.exp2(acc_s - scores_max[:, :, None])
223+
acc_s_cast = acc_s.to(torch.float16) # [batch, nheads, block_N]
224+
acc_o += torch.einsum(
225+
'bhk,bkhd->bhd', acc_s_cast,
226+
V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
227+
(i + 1) * block_N, :, :])
228+
scores_sum = acc_s.sum(dim=-1, keepdim=False)
229+
logsum = logsum * scores_scale + scores_sum
230+
acc_o /= logsum[:, :, None]
231+
logsum = torch.log2(logsum) + scores_max
232+
gacc_o[ks, :, :, :] = acc_o
233+
glogsum[ks, :, :] = logsum
234+
235+
return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)
236+
237+
238+
def reduce_ref(Q, K, V, glse, Output_partial):
239+
o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
240+
lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)
241+
lse_max = glse.max(dim=2, keepdim=False).values
242+
for ks in range(num_split):
243+
lse = glse[:, :, ks]
244+
lse_logsum += torch.exp2(lse - lse_max)
245+
lse_logsum = torch.log2(lse_logsum) + lse_max
246+
for ks in range(num_split):
247+
lse = glse[:, :, ks]
248+
scale = torch.exp2(lse - lse_logsum)
249+
o += Output_partial[:, :, ks, :] * scale[:, :, None]
250+
return o.to(torch.float16)
251+
252+
253+
if __name__ == "__main__":
254+
BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
255+
qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
256+
pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
257+
total_flops = qk_flops + pv_flops
258+
BLOCK_N = 32 # if D_HEAD <= 128 else 32
259+
BLOCK_H = 64
260+
261+
program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
262+
mod, params = tilelang.lower(program)
263+
mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
264+
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
265+
latency = mod.do_bench(mod.func, warmup=500)
266+
print("Tile-lang: {:.2f} ms".format(latency))
267+
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import tilelang
5+
import tilelang.language as T
6+
from tvm import DataType
7+
8+
9+
def matmul(M, N, K, block_M, block_N, block_K, split_k, dtype="float16", accum_dtype="float"):
10+
11+
splitK = K // split_k
12+
13+
@T.prim_func
14+
def main(
15+
A: T.Buffer((M, K), dtype),
16+
B: T.Buffer((N, K), dtype),
17+
C: T.Buffer((M, N), dtype),
18+
):
19+
with T.Kernel(
20+
T.ceildiv(N, block_N), T.ceildiv(M, block_M), split_k, threads=128) as (bx, by, bz):
21+
A_shared = T.alloc_shared((block_M, block_K), dtype, "shared")
22+
B_shared = T.alloc_shared((block_K, block_N), dtype, "shared")
23+
C_shared = T.alloc_shared((block_M, block_N), dtype, "shared")
24+
C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
25+
26+
if bz == 0:
27+
# fuse the zero initialization kernel
28+
for i, j in T.Parallel(block_M, block_N):
29+
m, n = by * block_M + i, bx * block_N + j
30+
C[m, n] = T.cast(0, dtype)
31+
32+
T.clear(C_local)
33+
for ko in T.Pipelined(T.ceildiv(splitK, block_K), num_stages=0):
34+
T.copy(A[by * block_M, bz * splitK + ko * block_K], A_shared)
35+
T.copy(B[bz * splitK + ko * block_K, bx * block_N], B_shared)
36+
T.gemm(A_shared, B_shared, C_local)
37+
38+
T.copy(C_local, C_shared)
39+
40+
if DataType(dtype).bits == 16:
41+
for i, j in T.Parallel(block_M, block_N // 2):
42+
m, n = by * block_M + i, bx * block_N + j * 2
43+
# vectorized atomic
44+
T.atomic_addx2(C[m, n], C_shared[i, j * 2])
45+
else:
46+
for i, j in T.Parallel(block_M, block_N):
47+
T.atomic_add(C[by * block_M + i, bx * block_N + j], C_shared[i, j])
48+
49+
return main
50+
51+
52+
program = matmul(1024, 1024, 1024, 128, 128, 32, 4)
53+
54+
kernel = tilelang.compile(program)
55+
56+
print(kernel.get_kernel_source())
57+
58+
import torch
59+
60+
a = torch.randn(1024, 1024).cuda().half()
61+
b = torch.randn(1024, 1024).cuda().half()
62+
c = torch.zeros(1024, 1024).cuda().half()
63+
kernel(a, b, c)
64+
65+
ref_c = a @ b
66+
67+
print(c)
68+
print(ref_c)
69+
70+
torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)

0 commit comments

Comments
 (0)