Commit 540aef4
[Dev] Update MLA decode kernel (#120)
1 parent 4cd8a9b

1 file changed: +84 −62 lines

examples/flash_decoding/example_mla_decode.py
Lines changed: 84 additions & 62 deletions
@@ -3,17 +3,13 @@
 import tilelang
 from tilelang.autotuner import *
 import tilelang.language as T
+from einops import rearrange, einsum
 
-num_split = 4
+num_split = 1
 
 
 def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
     scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    shape_q = [batch, heads, (dim + pe_dim)]
-    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
-    shape_v = [batch, seqlen_kv, kv_head_num, dim]
-    shape_o = [batch, heads, dim]
-    part_shape = [batch, heads, num_split, dim]
     dtype = "float16"
     accum_dtype = "float"
     kv_group_num = heads // kv_head_num
@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_
 
     @T.macro
     def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
     ):
         with T.Kernel(
-                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
-            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
-            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            S_shared = T.alloc_shared([block_H, block_N], dtype)
+            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            KV_shared = T.alloc_shared([block_N, dim], dtype)
+            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
             O_shared = T.alloc_shared([block_H, dim], dtype)
             acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
             acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
             acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
             scores_max = T.alloc_fragment([block_H], accum_dtype)
@@ -53,20 +53,32 @@ def flash_attn_split(
             })
 
             T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
             loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
-            for k in T.Pipelined(loop_range, num_stages=1):
+            for k in T.Pipelined(loop_range, num_stages=2):
+                kv_start = (seqlen_kv // num_split) * sid + k * block_N
+                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
+
                 T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], K_shared)
-                T.clear(acc_s)
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                    KV[bid, kv_start:kv_end, cur_kv_head, :],
+                    KV_shared
+                )
+                T.copy(
+                    K_pe[bid, kv_start:kv_end, cur_kv_head, :],
+                    K_pe_shared
+                )
+
+                T.clear(acc_s_0)
+                T.gemm(Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(Q_pe_shared, K_pe_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(acc_s_0, S_shared)
+                T.copy(S_shared, acc_s)
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                 for i in T.Parallel(block_H):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
@@ -78,11 +90,7 @@ def flash_attn_split(
                 T.copy(acc_s, acc_s_cast)
                 for i, j in T.Parallel(block_H, dim):
                     acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
             for i, j in T.Parallel(block_H, dim):
                 acc_o[i, j] /= logsum[i]
             for i in T.Parallel(block_H):
@@ -96,8 +104,8 @@ def flash_attn_split(
     @T.macro
     def combine(
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
         with T.Kernel(heads, batch, threads=128) as (by, bz):
             po_local = T.alloc_fragment([dim], dtype)
@@ -133,50 +141,63 @@ def combine(
 
     @T.prim_func
     def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
             glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
     ):
-        flash_attn_split(Q, K, V, glse, Output_partial)
+        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
         combine(glse, Output_partial, Output)
 
     return main
 
 
-def ref_program(query, key, value, glse, Output_partial):
+
+def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     # """
     # Inputs:
-    # - query (Tensor): [batch, heads, dim]
-    # - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-    # - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-
+    # - q (Tensor): [batch, heads, dim]
+    # - q_pe (Tensor): [batch, heads, pe_dim]
+    # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
+    # - glse (Tensor): [batch, heads, num_split]
+    # - Output_partial (Tensor): [batch, heads, num_split, dim]
     # Outputs:
     # - output (Tensor): [batch, heads, dim]
     # """
-    from einops import rearrange
-    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
-    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
-    dim_v = value.shape[-1]
-    assert kv_heads == 1, "kv_heads must be 1"
-
-    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
-    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    value_expanded = value.expand(-1, -1, query_heads,
-                                  -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    key_expanded = rearrange(key_expanded,
-                             'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
-    value_expanded = rearrange(value_expanded,
-                               'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
-
-    scores = torch.matmul(query_expanded,
-                          key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
-    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim]
-    return output.view(batch_size, query_heads, dim_v)
+    dim = q.shape[-1]
+    pe_dim = q_pe.shape[-1]
+    num_head_groups = q.shape[1] // kv.shape[2]
+    scale = (dim + pe_dim) ** 0.5
+    q = rearrange(
+        q, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+
+    q_pe = rearrange(
+        q_pe, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+
+    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+
+    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+
+    query = torch.concat([q, q_pe], dim=-1)
+    key = torch.concat([kv, k_pe], dim=-1)
+
+    scores = einsum(
+        query, key,
+        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    attention = F.softmax(
+        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    out = einsum(attention, kv,
+                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    return out
 
 
 def flash_split_ref(Q, K, V):
@@ -251,7 +272,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
 
 if __name__ == "__main__":
-    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
+    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
     qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
     pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
     total_flops = qk_flops + pv_flops
@@ -260,8 +281,9 @@ def reduce_ref(Q, K, V, glse, Output_partial):
 
     program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
     mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
+    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
     mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    latency = mod.do_bench(mod.func, warmup=500)
+    print("All close")
+    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
     print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
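Note on the split score computation: the updated flash_attn_split accumulates scores with two GEMMs (Q against KV, then Q_pe against K_pe) instead of one GEMM over the concatenated [dim + pe_dim] head that the old K_shared path used. A minimal PyTorch sketch (hypothetical toy sizes, not part of the example file) illustrating why the two forms give identical scores:

import torch

# Toy sizes chosen only for the check; the example uses dim=512, pe_dim=64.
torch.manual_seed(0)
dim, pe_dim, seqlen = 8, 4, 16
q, q_pe = torch.randn(1, dim), torch.randn(1, pe_dim)
kv, k_pe = torch.randn(seqlen, dim), torch.randn(seqlen, pe_dim)

# Two partial GEMMs, as in the kernel's acc_s_0 accumulation.
scores_split = q @ kv.T + q_pe @ k_pe.T
# One GEMM over the concatenated head, as in ref_program's torch.concat path.
scores_concat = torch.cat([q, q_pe], dim=-1) @ torch.cat([kv, k_pe], dim=-1).T
assert torch.allclose(scores_split, scores_concat, atol=1e-5)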

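The body of the combine macro is not shown in this diff. As a rough sketch of the split-KV reduction such a step typically performs, assuming Output_partial holds each split's softmax-normalized partial output and glse its per-split log-sum-exp (an assumption; the actual kernel is written in tilelang and works with base-2 exponentials), the merge can be done like this:

import torch

def combine_splits(output_partial: torch.Tensor, glse: torch.Tensor) -> torch.Tensor:
    # output_partial: [batch, heads, num_split, dim], each split normalized by
    # its own softmax denominator (assumption for this sketch).
    # glse: [batch, heads, num_split], per-split log-sum-exp of the scores.
    lse_max = glse.max(dim=-1, keepdim=True).values       # [batch, heads, 1]
    w = torch.exp(glse - lse_max)                          # stable split weights
    w = w / w.sum(dim=-1, keepdim=True)                    # renormalize across splits
    return (w.unsqueeze(-1) * output_partial).sum(dim=2)   # [batch, heads, dim]

# With num_split = 1, the new default at the top of the example, the reduction
# degenerates to copying the single partial result.
out_partial = torch.randn(2, 4, 1, 8)
glse = torch.randn(2, 4, 1)
assert torch.allclose(combine_splits(out_partial, glse), out_partial[:, :, 0])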