Skip to content

Commit 1038da8

Browse files
authored
[Bugfix] Fix flops comp and softmax scale in mla (tile-ai#900)
* Fix FLOPs computation and softmax scale
* Format
1 parent 6c75403 commit 1038da8

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

examples/deepseek_mla/benchmark_mla.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def flash_mla():
8787

8888

8989
@torch.inference_mode()
90-
def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
91-
h_q, h_kv, d, dv, causal, dtype):
90+
def run_flashinfer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens,
91+
h_q, h_kv, d, dv, causal, dtype):
9292
# pip install flashinfer-python
9393
import flashinfer
9494
assert d > dv, "mla with rope dim should be larger than no rope dim"
@@ -128,7 +128,7 @@ def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_
128128
blocked_k.dtype,
129129
)
130130

131-
def flash_infer():
131+
def flashinfer():
132132
output, lse = mla_wrapper.run(
133133
q_nope.view(-1, h_q, dv),
134134
q_pe.view(-1, h_q, d - dv),
@@ -137,8 +137,8 @@ def flash_infer():
137137
return_lse=True)
138138
return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)
139139

140-
out_flash, lse_flash = flash_infer()
141-
t = triton.testing.do_bench(flash_infer)
140+
out_flash, lse_flash = flashinfer()
141+
t = triton.testing.do_bench(flashinfer)
142142
return out_flash, lse_flash, t
143143

144144

@@ -459,7 +459,7 @@ def flash_mla_tilelang():
459459
"torch": run_torch_mla,
460460
"tilelang": run_flash_mla_tilelang,
461461
"flash_mla": run_flash_mla,
462-
"flash_infer": run_flash_infer,
462+
"flashinfer": run_flashinfer,
463463
"flash_mla_triton": run_flash_mla_triton,
464464
}
465465

@@ -496,9 +496,9 @@ def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal
496496
s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype)
497497

498498
torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out"
499-
if target not in ["flash_infer", "flash_mla_triton", "tilelang"
500-
] and baseline not in ["flash_infer", "flash_mla_triton", "tilelang"]:
501-
# flash_infer has a different lse return value
499+
if target not in ["flashinfer", "flash_mla_triton", "tilelang"
500+
] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
501+
# flashinfer has a different lse return value
502502
# flash_mla_triton and flash_mla_tilelang doesn't return lse
503503
torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse"
504504

@@ -554,7 +554,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
554554
"torch",
555555
"tilelang",
556556
"flash_mla",
557-
"flash_infer",
557+
"flashinfer",
558558
"flash_mla_triton",
559559
]
560560

examples/deepseek_mla/example_mla_decode_paged.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,19 @@
1111
out_idx=[8], pass_configs={
1212
tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
1313
})
14-
def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split,
15-
block_size, softmax_scale):
14+
def mla_decode_tilelang(batch,
15+
h_q,
16+
h_kv,
17+
max_seqlen_pad,
18+
dv,
19+
dpe,
20+
block_N,
21+
block_H,
22+
num_split,
23+
block_size,
24+
softmax_scale=None):
25+
if softmax_scale is None:
26+
softmax_scale = (dv + dpe)**-0.5
1627
scale = float(softmax_scale * 1.44269504) # log2(e)
1728
dtype = "float16"
1829
accum_dtype = "float"
@@ -322,7 +333,7 @@ def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s
322333
num_kv_splits = 1
323334
BLOCK_N = 64
324335
BLOCK_H = min(64, h_q // h_kv)
325-
softmax_scale = (d + dv)**-0.5
336+
softmax_scale = d**-0.5
326337

327338
out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device)
328339
glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device)
@@ -379,7 +390,7 @@ def flash_mla_tilelang():
379390
max_seqlen = cache_seqlens.max().item()
380391
max_seqlen_pad = math.ceil(max_seqlen / 256) * 256
381392

382-
total_flops = s_q * total_seqlens * h_q * (d + dv) * 2
393+
total_flops = s_q * total_seqlens * h_q * d * 2
383394

384395
q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
385396
block_table = torch.arange(

0 commit comments

Comments (0)