
Commit 8f00a9e

sparse mla kernels
1 parent d4fb9b7 commit 8f00a9e

File tree

4 files changed (+277 / -259 lines changed)


examples/deepseek_v32/fp8_mqa_logits.py

Lines changed: 79 additions & 136 deletions
@@ -1,69 +1,26 @@
 import itertools
-import math
-from einops import rearrange
 import tilelang
 from tilelang import language as T
 import torch
-from tilelang.autotuner import autotune
 from tilelang import tvm
-from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q
-
+from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar
 from typing import Tuple
 
 
 def ceil_to_ue8m0(x: torch.Tensor):
     assert x.view(-1).amax().item() > 0
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))
 
-def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+
+def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
+                                use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
     excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
     x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
     sf = x_amax / 448.0
     sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
     x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
     return x_scaled, sf.squeeze()
 
-def print_red_warning(message):
-    print(f"\033[31mWARNING: {message}\033[0m")
-
-
-def calc_sim(x, y, name="tensor"):
-    x, y = x.data.double(), y.data.double()
-    denominator = (x * x + y * y).sum()
-    if denominator == 0:
-        print_red_warning(f"{name} all zero")
-        return 1
-    sim = 2 * (x * y).sum() / denominator
-    return sim
-
-
-def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
-    x_mask = torch.isfinite(x)
-    y_mask = torch.isfinite(y)
-    if not torch.all(x_mask == y_mask):
-        print_red_warning(f"{name} Error: isfinite mask mismatch")
-        if raise_assert:
-            assert False
-    if not torch.isclose(
-        x.masked_fill(x_mask, 0),
-        y.masked_fill(y_mask, 0),
-        rtol=0,
-        atol=0,
-        equal_nan=True,
-    ).all():
-        print_red_warning(f"{name} Error: nonfinite value mismatch")
-        if raise_assert:
-            assert False
-    x = x.masked_fill(~x_mask, 0)
-    y = y.masked_fill(~y_mask, 0)
-    sim = calc_sim(x, y, name)
-    diff = 1.0 - sim
-    if not (0 <= diff <= eps):
-        print_red_warning(f"{name} Error: {diff}")
-        if raise_assert:
-            assert False
-    return diff
-
 
 def get_configs():
     iter_params = dict(
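
The per_custom_dims_cast_to_fp8 helper kept above produces one FP8 scale per index along the given dims and shares it across the remaining dims (for kv below, one scale per key row). A minimal usage sketch, assuming a PyTorch build with float8_e4m3fn support and running inside this module; the tensor names x, x_fp8, x_rec are illustrative only:

import torch

# Quantize an 8x64 tensor with one scale per row (dims=(0,)), as done for kv
# in the __main__ block of this file, then dequantize to check the round-trip.
x = torch.randn(8, 64, device="cuda", dtype=torch.bfloat16)
x_fp8, scales = per_custom_dims_cast_to_fp8(x, (0,), use_ue8m0=False)
x_rec = x_fp8.to(torch.float32) * scales[:, None]  # undo the 1/sf scaling
print((x.float() - x_rec).abs().max())              # small FP8 quantization error expected
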
@@ -72,13 +29,13 @@ def get_configs():
         threads=[128, 256],
         block_Q=[1, 2, 4],
     )
-    return [
-        {k: v for k, v in zip(iter_params, values)}
-        for values in itertools.product(*iter_params.values())
-    ]
+    return [{
+        k: v for k, v in zip(iter_params, values)
+    } for values in itertools.product(*iter_params.values())]
 
 
 class SupplyProg:
+
     def __init__(self):
         self.tensors_dict = {}
 
@@ -127,13 +84,13 @@ def mqa_attn_return_logits(
 
     @T.prim_func
     def mqa_attn_return_logits_kernel(
-        IndexQ: T.Tensor(index_q_shape, dtype),  # type: ignore
-        IndexK: T.Tensor(index_k_shape, dtype),  # type: ignore
-        IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),  # type: ignore
-        Logits: T.Tensor(logits_shape, accum_dtype),  # type: ignore
-        Weights: T.Tensor([seq_len, heads], accum_dtype),  # type: ignore
-        CuSeqLenKS: T.Tensor([seq_len], index_dtype),  # type: ignore
-        CuSeqLenKE: T.Tensor([seq_len], index_dtype),  # type: ignore
+            IndexQ: T.Tensor(index_q_shape, dtype),  # type: ignore
+            IndexK: T.Tensor(index_k_shape, dtype),  # type: ignore
+            IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),  # type: ignore
+            Logits: T.Tensor(logits_shape, accum_dtype),  # type: ignore
+            Weights: T.Tensor([seq_len, heads], accum_dtype),  # type: ignore
+            CuSeqLenKS: T.Tensor([seq_len], index_dtype),  # type: ignore
+            CuSeqLenKE: T.Tensor([seq_len], index_dtype),  # type: ignore
     ):
         with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:
 
@@ -156,20 +113,17 @@ def mqa_attn_return_logits_kernel(
             cu_k_e_max[0] = -2147483648
 
             for bq_i in T.serial(block_Q):
-                cu_k_s_min[0] = T.min(
-                    cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)
-                )
+                cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
+                                                           seq_len_kv))
             for bq_i in T.serial(block_Q):
-                cu_k_e_max[0] = T.max(
-                    cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)
-                )
+                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
+                                                           seq_len_kv))
 
             T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
             T.copy(Weights[seq_len_i, 0], weights)
 
             for nbn_i in T.Pipelined(
-                T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages
-            ):
+                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
                 T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                 T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)
 
@@ -183,16 +137,16 @@ def mqa_attn_return_logits_kernel(
                 )
 
                 for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
-                    s_reshaped[bn_i, bq_i, h_i] = (
-                        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
-                    ) * index_k_scale_fragment[bn_i]
+                    s_reshaped[bn_i, bq_i,
+                               h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                                       weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]
 
                 T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)
 
                 for bq_i, bn_i in T.Parallel(block_Q, block_N):
                     Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
-                        logits[bn_i, bq_i]
-                    )
+                        logits[bn_i, bq_i])
+
     return mqa_attn_return_logits_kernel
 
 
@@ -209,9 +163,9 @@ def clean_logits_(
 
     @T.prim_func
     def clean_logits_kernel(
-        Logits: T.Tensor([seq_len, seq_len_kv], dtype),  # type: ignore
-        CuSeqLenKS: T.Tensor([seq_len], indices_dtype),  # type: ignore
-        CuSeqLenKE: T.Tensor([seq_len], indices_dtype),  # type: ignore
+            Logits: T.Tensor([seq_len, seq_len_kv], dtype),  # type: ignore
+            CuSeqLenKS: T.Tensor([seq_len], indices_dtype),  # type: ignore
+            CuSeqLenKE: T.Tensor([seq_len], indices_dtype),  # type: ignore
     ):
         with T.Kernel(seq_len, threads=threads) as bx:
             tx = T.thread_binding(0, threads, thread="threadIdx.x")
@@ -229,17 +183,19 @@ def clean_logits_kernel(
     return clean_logits_kernel
 
 
-def mqa_attn_return_logits_interface(
-    q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True
-):
+def mqa_attn_return_logits_interface(q,
+                                     kv,
+                                     kv_scales,
+                                     weights,
+                                     cu_seqlen_ks,
+                                     cu_seqlen_ke,
+                                     clean_logits=True):
     seq_len, heads, index_dim = q.shape
     seq_len_kv = kv.shape[0]
 
     clean_logits_kernel = clean_logits_()
 
-    mqa_attn_return_logits_kernel = mqa_attn_return_logits(
-        heads=heads, index_dim=index_dim
-    )
+    mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
     logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
     mqa_attn_return_logits_kernel(
         q.view(seq_len * heads, index_dim),
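
For readers following the kernel math above: the interface returns, for each query position and key position, the head-sum of ReLU'd Q.K logits, weighted per (query, head) and rescaled by the per-key FP8 scale. A rough dense PyTorch restatement as a sketch, with assumed shapes, the function name made up, and the cu_seqlen_ks/cu_seqlen_ke masking applied by clean_logits_ omitted:

import torch

def dense_mqa_logits_sketch(q, kv, weights, kv_scales):
    # q: [S, H, D], kv: [SKV, D], weights: [S, H], kv_scales: [SKV]
    # relu(q . k) per head, weighted per (query, head), scaled per key, summed over heads.
    score = torch.relu(torch.einsum("shd,kd->shk", q.float(), kv.float()))
    return torch.einsum("shk,sh->sk", score, weights.float()) * kv_scales.float()[None, :]
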
@@ -273,33 +229,30 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
     cost = mask.sum()
     return logits, cost
 
+
 if __name__ == "__main__":
     torch.manual_seed(0)
     S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
-    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(
-        torch.bfloat16
-    )
-    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(
-        torch.bfloat16
-    )
+    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
+    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
     p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)
 
-    def generate_random_cu_seqlens(
-        per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512
-    ):
+    def generate_random_cu_seqlens(per_cp_seqlen,
+                                   cp_size=4,
+                                   cp_rank=3,
+                                   kv_stride=1,
+                                   average_q_len=512):
         total_seqlen = per_cp_seqlen * cp_size
 
-        cu_seqlens = torch.randint(
-            0, average_q_len * 2, (total_seqlen // average_q_len * 2,)
-        ).cuda()
+        cu_seqlens = torch.randint(0, average_q_len * 2,
+                                   (total_seqlen // average_q_len * 2,)).cuda()
         last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
         cu_seqlens = cu_seqlens[:last_seq_id]
 
         if cu_seqlens.sum() < total_seqlen:
             cu_seqlens = torch.cat(
-                [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]
-            )
+                [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()])
 
         total_seqlen_k = (cu_seqlens // kv_stride).sum()
 
@@ -328,75 +281,65 @@ def generate_random_cu_seqlens(
 
         assert per_cp_seqlen % 2 == 0
         per_chunk_seqlen = per_cp_seqlen // 2
-        slice_short = slice(
-            cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen
-        )
+        slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen)
         slice_long = slice(
             total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
             total_seqlen - cp_rank * per_chunk_seqlen,
         )
-        ks = torch.cat(
-            [
-                cu_seqlens_ks_for_each_q[slice_short],
-                cu_seqlens_ks_for_each_q[slice_long],
-            ]
-        )
-        ke = torch.cat(
-            [
-                cu_seqlens_ke_for_each_q[slice_short],
-                cu_seqlens_ke_for_each_q[slice_long],
-            ]
-        )
+        ks = torch.cat([
+            cu_seqlens_ks_for_each_q[slice_short],
+            cu_seqlens_ks_for_each_q[slice_long],
+        ])
+        ke = torch.cat([
+            cu_seqlens_ke_for_each_q[slice_short],
+            cu_seqlens_ke_for_each_q[slice_long],
+        ])
         assert len(ks) == len(ke) == per_cp_seqlen
         return ks, ke
 
     ks, ke = generate_random_cu_seqlens(
-        per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048
-    )
+        per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)
 
     logits_ref, cost_ref = ref_fp8_mqa_logits(
-        q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-    )
-
+        q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
+
     q_fp8 = q.to(torch.float8_e4m3fn)
-    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0, ), False)
+    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)
 
     logits_tl = mqa_attn_return_logits_interface(
-        q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-    )
-    diff = assert_similar(
-        logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False
-    )
+        q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
+    diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False)
 
     original_diff = None
     for i in range(10):
         logits_tl = mqa_attn_return_logits_interface(
-            q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-        )
-        diff = assert_similar(
-            logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False
-        )
+            q=q_fp8,
+            kv=kv_fp8,
+            kv_scales=kv_scales,
+            weights=weights,
+            cu_seqlen_ks=ks,
+            cu_seqlen_ke=ke)
+        diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False)
         if original_diff is None:
             original_diff = diff
         else:
             assert original_diff == diff
 
-    from tilelang.profiler import do_bench
-
+    from tilelang.profiler import do_bench
 
     def logits_fn():
         return mqa_attn_return_logits_interface(
-            q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-        )
-
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA]
-    ) as prof:
+            q=q_fp8,
+            kv=kv_fp8,
+            kv_scales=kv_scales,
+            weights=weights,
+            cu_seqlen_ks=ks,
+            cu_seqlen_ke=ke)
+
+    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
         logits_fn()
 
-    print(
-        prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)
-    )
+    print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))
 
     logits_ms = do_bench(logits_fn, warmup=100, rep=100)
     logits_flops = 2 * cost_ref * H * D
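
The accuracy check in this file goes through assert_similar, now imported from utils; the calc_sim/assert_similar helpers deleted from this file above show the underlying metric. A minimal restatement, assuming finite inputs and skipping the NaN/inf mask handling of the full helper; the standalone function name similarity_diff is illustrative only:

import torch

def similarity_diff(x, y):
    # diff = 1 - 2*sum(x*y) / sum(x*x + y*y); assert_similar requires diff <= eps.
    x, y = x.double(), y.double()
    return (1.0 - 2.0 * (x * y).sum() / (x * x + y * y).sum()).item()

# e.g. the logits comparison above passes when
# similarity_diff(logits_ref, logits_tl) <= 1e-14.
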
