
Conversation

Contributor

@Edenzzzz Edenzzzz commented Sep 29, 2025

Fixes a couple of issues:

  1. The benchmark cannot run due to a missing argument.
     [screenshot: missing-argument error]
  2. The FLOPs computation should use the full head dim d = dpe + dv, and the softmax scale is adjusted accordingly (a rough sketch follows this list). For reference, see
     https://github.com/flashinfer-ai/flashinfer/blob/0f8ecab1121ad52dd22b47317662cbbd7f0d343c/benchmarks/bench_deepseek_mla.py#L73
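
For illustration, here is a rough FLOPs sketch under that convention; the DeepSeek decode dims dv = 512, dpe = 64 and the per-GEMM breakdown are assumptions for this example, not the exact expression in the diff.

# Assumed DeepSeek MLA decode dims (after matrix absorption): latent/value dim and rope dim.
dv, dpe = 512, 64
d = dv + dpe  # full head dim seen by Q·K^T

def mla_decode_attn_flops(batch: int, heads: int, kv_len: int) -> int:
    """Attention FLOPs for one decode step: Q·K^T over d dims plus P·V over dv dims."""
    qk = 2 * batch * heads * kv_len * d
    pv = 2 * batch * heads * kv_len * dv
    return qk + pv

# Example: b=64, h=128 query heads, 4096 cached KV tokens per sequence.
print(mla_decode_attn_flops(batch=64, heads=128, kv_len=4096))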

I ran the bench against the FlashInfer main branch on an H100, and it looks like FlashInfer has packed in more optimizations since the last comparison was made. Looking forward to learning more TileLang tricks to speed it up!
[screenshot: benchmark results]

Summary by CodeRabbit

  • New Features

    • softmax_scale is now optional and auto-computed at runtime when omitted; callers remain backward compatible.
  • Refactor

    • Renamed the FlashInfer backend from "flash_infer" to "flashinfer" across targets and dispatch.
  • Improvements

    • FLOPs reporting updated to use d only (was d + dv), aligning metrics with the new attention scaling.

Contributor

coderabbitai bot commented Sep 29, 2025

Walkthrough

Renames FlashInfer entry points from run_flash_infer/flash_infer to run_flashinfer/flashinfer across benchmark dispatch, targets, and comparisons. Adds an optional softmax_scale parameter (default None) to mla_decode_tilelang, computing it when omitted, and updates the total_flops calculation and initialization logic accordingly.

Changes

Cohort / File(s): Summary

FlashInfer naming updates
examples/deepseek_mla/benchmark_mla.py
Renamed the public API and internal helper: run_flash_infer → run_flashinfer, flash_infer → flashinfer. Updated FUNC_TABLE, the available targets, and comparison references to the new name. No functional logic changes.

Softmax scale default and FLOPs adjustment
examples/deepseek_mla/example_mla_decode_paged.py
Made softmax_scale optional with a default of None, computing a fallback when it is not provided (sketched below this table). Adjusted internal initialization and usage, and updated total_flops from (d + dv) to d. Aligned run_tilelang_mla with the new default behavior. Backward compatible.
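
A minimal sketch of the fallback noted above (the helper name resolve_softmax_scale and the DeepSeek dims are illustrative, not the code in the diff):

import math
from typing import Optional

def resolve_softmax_scale(softmax_scale: Optional[float], dv: int, dpe: int) -> float:
    """When the caller omits the scale, fall back to 1/sqrt(d) with d = dv + dpe."""
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(dv + dpe)
    return softmax_scale

# Existing callers that pass an explicit scale are unchanged; omitting it now
# yields 1/sqrt(576) for the DeepSeek decode dims dv=512, dpe=64.
print(resolve_softmax_scale(None, dv=512, dpe=64))   # ~0.0417
print(resolve_softmax_scale(0.05, dv=512, dpe=64))   # 0.05, passed through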

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor Caller
  participant Decode as mla_decode_tilelang
  participant Scale as Softmax Scale Init

  Caller->>Decode: call(..., softmax_scale=None)
  alt softmax_scale is None
    Decode->>Scale: compute softmax_scale (fallback)
    Scale-->>Decode: softmax_scale value
  else softmax_scale provided
    Note over Decode: Use provided softmax_scale
  end
  Decode->>Decode: attention computation (uses softmax_scale)
  Decode-->>Caller: outputs

sequenceDiagram
  autonumber
  actor User
  participant Bench as benchmark_mla.py
  participant Dispatch as FUNC_TABLE

  User->>Bench: select target "flashinfer"
  Bench->>Dispatch: lookup("flashinfer")
  Dispatch-->>Bench: run_flashinfer
  Bench->>Bench: execute selected path (unchanged behavior)

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Poem

A bunny renames with nimble flair,
From flash_infer to flashinfer—now fair.
Softmax scales, when none are there,
Hop in and compute with gentle care.
FLOPs slim down, d takes the chair—
Code fields bloom; I nibble air. 🐇✨

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 0.00%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): The title clearly and concisely highlights the primary bugfix (correcting the FLOPs computation and softmax scaling in the MLA example), which directly aligns with the pull request's stated objectives. It avoids unnecessary detail and noise, providing enough context for team members to understand the core change at a glance.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@LeiWang1999
Member

@Edenzzzz Thanks for your contribution! And PR ID 900 is a pretty cool number.

We know the trick behind FlashMLA, but for now, we need to implement it using an identically warp‑specialized kernel — for example, as shown here: https://github.com/tile-ai/tilelang/blob/main/examples/warp_specialize/example_warp_specialize_flashmla.py with a carefully designed layout.

We’ll work on improving performance in the future. For now, our focus will be on programmability and robustness.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
examples/deepseek_mla/benchmark_mla.py (1)

99-109: Fix kv_indptr double-append — length should be b+1 (CSR row pointer)

You build kv_indptr once in the first loop (yielding length b+1), then append again in the second loop, producing an invalid length (2b). This can break flashinfer plan inputs.

Apply this minimal fix:

-    for seq_len in cache_seqlens[1:]:
-        kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1])
+    # kv_indptr is already built to length b+1 above.

If you want a vectorized version (optional), I can provide it.
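
For instance, something along these lines (illustrative; build_kv_indptr is a hypothetical helper and it assumes cache_seqlens is a 1-D integer tensor on the target device):

import torch

def build_kv_indptr(cache_seqlens: torch.Tensor, block_size: int) -> torch.Tensor:
    """CSR row pointer of length b+1: cumulative count of KV blocks per sequence."""
    blocks_per_seq = (cache_seqlens + block_size - 1) // block_size  # ceil division
    kv_indptr = torch.zeros(cache_seqlens.numel() + 1, dtype=torch.int32,
                            device=cache_seqlens.device)
    kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0)
    return kv_indptr

# Example: lengths [100, 257, 64] with block_size=64 -> tensor([0, 2, 7, 8])
print(build_kv_indptr(torch.tensor([100, 257, 64]), block_size=64))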

🧹 Nitpick comments (2)
examples/deepseek_mla/benchmark_mla.py (2)

131-137: Avoid shadowing the imported flashinfer module with a local function name

The inner function named flashinfer masks the imported module, which is confusing and risks mistakes.

Apply this refactor:

-    def flashinfer():
+    def _run_once():
         output, lse = mla_wrapper.run(
             q_nope.view(-1, h_q, dv),
             q_pe.view(-1, h_q, d - dv),
             blocked_k_nope,
             blocked_k_pe,
             return_lse=True)
         return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1)

-    out_flash, lse_flash = flashinfer()
-    t = triton.testing.do_bench(flashinfer)
+    out_flash, lse_flash = _run_once()
+    t = triton.testing.do_bench(_run_once)

Also applies to: 140-142


499-501: Correctly skip LSE comparison for backends with different/absent LSE

Guard is correct for flashinfer (different LSE semantics) and triton/tilelang (no LSE).

For readability/maintainability, consider:

-    if target not in ["flashinfer", "flash_mla_triton", "tilelang"] and baseline not in ["flashinfer", "flash_mla_triton", "tilelang"]:
+    LSE_INCOMPARABLE = {"flashinfer", "flash_mla_triton", "tilelang"}
+    if target not in LSE_INCOMPARABLE and baseline not in LSE_INCOMPARABLE:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9c0209c and 2c5e0a5.

📒 Files selected for processing (2)
  • examples/deepseek_mla/benchmark_mla.py (6 hunks)
  • examples/deepseek_mla/example_mla_decode_paged.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/deepseek_mla/example_mla_decode_paged.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: build-test-amd
🔇 Additional comments (4)
examples/deepseek_mla/benchmark_mla.py (4)

124-128: Softmax scale uses full head dim d (correct)

Passing 1 / sqrt(d) to flashinfer.plan matches d = dpe + dv and the intended scaling fix.


462-462: FUNC_TABLE mapping to "flashinfer" looks good

Dispatch table entry updated consistently.


558-558: available_targets updated to "flashinfer"

Target list matches the new name.


90-92: Approve rename to run_flashinfer

Rename is consistent; no remaining references to flash_infer or run_flash_infer.

@LeiWang1999
Member

LGTM, Merge :)
