[Bugfix] Fix flops comp and softmax scale in mla #900
Conversation
Walkthrough

Renames the FlashInfer entry points from run_flash_infer/flash_infer to run_flashinfer/flashinfer across benchmark dispatch, targets, and comparisons. Adds an optional softmax_scale parameter (default None) to mla_decode_tilelang, computing it when omitted, and updates the total_flops calculation and initialization logic accordingly.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor Caller
    participant Decode as mla_decode_tilelang
    participant Scale as Softmax Scale Init
    Caller->>Decode: call(..., softmax_scale=None)
    alt softmax_scale is None
        Decode->>Scale: compute softmax_scale (fallback)
        Scale-->>Decode: softmax_scale value
    else softmax_scale provided
        Note over Decode: Use provided softmax_scale
    end
    Decode->>Decode: attention computation (uses softmax_scale)
    Decode-->>Caller: outputs
```
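To make the fallback in the first diagram concrete, here is a minimal sketch; the signature is simplified and head_dim is an illustrative stand-in for the real parameters, with only the optional softmax_scale and its None fallback taken from the walkthrough:

```python
import math

def mla_decode_tilelang(q, kv_cache, head_dim, softmax_scale=None):
    # Hypothetical simplified signature; only the optional softmax_scale
    # (default None) and its fallback mirror the described change.
    if softmax_scale is None:
        # Fallback: derive the scale from the full head dim (dpe + dv).
        softmax_scale = 1.0 / math.sqrt(head_dim)
    # ... attention computation would use softmax_scale here ...
    return softmax_scale

print(mla_decode_tilelang(None, None, 576))        # computed fallback
print(mla_decode_tilelang(None, None, 576, 0.05))  # caller-provided value
```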
```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Bench as benchmark_mla.py
    participant Dispatch as FUNC_TABLE
    User->>Bench: select target "flashinfer"
    Bench->>Dispatch: lookup("flashinfer")
    Dispatch-->>Bench: run_flashinfer
    Bench->>Bench: execute selected path (unchanged behavior)
```
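The dispatch in the second diagram is a plain table lookup. A hedged sketch follows: run_flashinfer and FUNC_TABLE are named in the walkthrough, while the function body and arguments are made up.

```python
# Sketch of the renamed dispatch path; the body is a placeholder.
def run_flashinfer(batch_size):
    return f"flashinfer path, batch={batch_size}"

FUNC_TABLE = {
    "flashinfer": run_flashinfer,  # was "flash_infer" -> run_flash_infer
}

target = "flashinfer"
print(FUNC_TABLE[target](64))
```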
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
@Edenzzzz Thanks for your contribution! And PR ID 900 is a pretty cool number. We know the trick behind FlashMLA, but for now we need to implement it using an identically warp-specialized kernel — for example, as shown here: https://github.com/tile-ai/tilelang/blob/main/examples/warp_specialize/example_warp_specialize_flashmla.py, with a carefully designed layout. We'll work on improving performance in the future; for now, our focus will be on programmability and robustness.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/deepseek_mla/benchmark_mla.py (1)
99-109: Fix kv_indptr double-append — length should be b+1 (CSR row pointer)

You build kv_indptr once in the first loop (yielding length b+1), then append again in the second loop, producing an invalid length (2b). This can break flashinfer plan inputs.
Apply this minimal fix:
```diff
-    for seq_len in cache_seqlens[1:]:
-        kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
+    # kv_indptr is already built to length b+1 above.
```

If you want a vectorized version (optional), I can provide it.
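For context, a small self-contained sketch of the intended single-pass construction, with made-up lengths, showing why the row pointer ends at length b + 1:

```python
block_size = 64
cache_seqlens = [100, 200, 300]  # b = 3 sequences (illustrative lengths)

kv_indptr = [0]
for seq_len in cache_seqlens:
    kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])

print(kv_indptr)  # [0, 2, 6, 11]: length b + 1, a valid CSR row pointer
assert len(kv_indptr) == len(cache_seqlens) + 1
# Re-appending for cache_seqlens[1:] would grow it to length 2b,
# which is exactly the bug flagged above.
```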
🧹 Nitpick comments (2)
examples/deepseek_mla/benchmark_mla.py (2)
131-137: Avoid shadowing the imported flashinfer module with a local function name

The inner function named flashinfer masks the imported module, which is confusing and risks mistakes.
Apply this refactor:
```diff
-    def flashinfer():
+    def _run_once():
         output, lse = mla_wrapper.run(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, d - dv),
             blocked_k_nope,
             blocked_k_pe,
             return_lse=True)
         return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

-    out_flash, lse_flash = flashinfer()
-    t = triton.testing.do_bench(flashinfer)
+    out_flash, lse_flash = _run_once()
+    t = triton.testing.do_bench(_run_once)
```

Also applies to: 140-142
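To see the hazard concretely, here is a tiny standalone illustration (not from the PR) that substitutes the stdlib math module for flashinfer:

```python
import math

def bench():
    def math():      # local function shadows the imported module in this scope
        return 42
    result = math()  # calls the local function, not the module
    # math.sqrt(4.0) here would raise AttributeError: 'function' object
    # has no attribute 'sqrt' -- the module is unreachable under this name.
    return result

print(bench())  # 42
```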
499-501: Correctly skip LSE comparison for backends with different/absent LSE

Guard is correct for flashinfer (different LSE semantics) and triton/tilelang (no LSE).
For readability/maintainability, consider:
```diff
-    if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
+    LSE_INCOMPARABLE = {"flashinfer", "flash_mla_triton", "tilelang"}
+    if target not in LSE_INCOMPARABLE and baseline not in LSE_INCOMPARABLE:
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/deepseek_mla/benchmark_mla.py (6 hunks)
- examples/deepseek_mla/example_mla_decode_paged.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- examples/deepseek_mla/example_mla_decode_paged.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: build-test-amd
🔇 Additional comments (4)
examples/deepseek_mla/benchmark_mla.py (4)
124-128: Softmax scale uses full head dim d (correct)

Passing 1 / sqrt(d) to flashinfer.plan matches d = dpe + dv and the intended scaling fix.

462-462: FUNC_TABLE mapping to "flashinfer" looks good

Dispatch table entry updated consistently.

558-558: available_targets updated to "flashinfer"

Target list matches the new name.

90-92: Approve rename to run_flashinfer

Rename is consistent; no remaining references to flash_infer or run_flash_infer.
LGTM, Merge :)
Commits:
- fix flops comp and softmax scale
- format
Fixes some issues: d = dpe + dv, so the softmax scale is adjusted accordingly. See https://github.com/flashinfer-ai/flashinfer/blob/0f8ecab1121ad52dd22b47317662cbbd7f0d343c/benchmarks/bench_deepseek_mla.py#L73

I ran the bench against the FlashInfer main branch on H100, and it looks like it has packed in more optimizations than when the last comparison was made. Looking forward to learning more TileLang tricks to speed it up!
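For concreteness, a hedged sketch of the two quantities the PR touches; the dimension values follow common DeepSeek MLA conventions and the flops expression is a standard estimate, not necessarily the benchmark's exact formula:

```python
import math

dv, dpe = 512, 64                   # latent/value dim and RoPE dim (MLA convention)
d = dpe + dv                        # full QK head dim, hence the scale fix
softmax_scale = 1.0 / math.sqrt(d)  # 1 / sqrt(576), not 1 / sqrt(dv)

# Illustrative decode-flops estimate: QK^T over d dims plus PV over dv dims,
# per query head and KV position; made-up batch/head/seq-len values.
b, h_q, avg_kv = 64, 128, 4096
total_flops = 2 * b * h_q * avg_kv * (d + dv)
print(softmax_scale, total_flops / 1e12, "TFLOPs")
```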
