Conversation


@PerkzZheng PerkzZheng commented Nov 24, 2025

📌 Description

This PR adds trtllm-gen per-tensor sparse MLA kernels.

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Added Sparse MLA mode to enable top-k sparse attention paths and configure sparse top-k behavior.
  • Performance

    • Improved kernel selection and runtime behavior to better support sparse MLA and varied head dimensions.
  • Tests

    • Expanded tests for multiple head dimensions and added comprehensive sparse MLA decoding tests and utilities.
  • Validation

    • Strengthened input/shape/runtime checks for sparse MLA configuration.
  • Chores

    • Updated public artifact references/checksums; tests now skip when insufficient GPUs are available.



coderabbitai bot commented Nov 24, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR adds sparse MLA (top-k) support across the FMHA stack by introducing runner/kernel params (mSparseMla, mSparseMlaTopK), threading a new sparse_mla_top_k argument through Python APIs, launcher, and kernel launch, and adapting kernel parameter layout, kernel selection, and reduction kernel behavior when sparse MLA is enabled.
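For orientation, here is a deliberately simplified, pure-PyTorch sketch of what the sparse top-k decode path computes: for each query token, only the top-k selected KV rows are gathered and attended to. This is an illustration of the semantics described above, not the PR's kernels or its test reference implementation; paging, causal masking, NaN handling, and the kernel's explicit bmm1/bmm2 scaling are omitted, and the generic 1/sqrt(d) scale is an assumption.

```python
import torch

def sparse_mla_decode_sketch(q, kv, topk_indices, head_dim_vo=512):
    # q:            [q_len, num_heads, head_dim_qk]  (head_dim_qk = 576 for decode MLA)
    # kv:           [seq_len_kv, head_dim_qk]        flat KV rows; the first head_dim_vo dims act as V
    # topk_indices: [q_len, top_k]                   per-token KV indices selected upstream (int64)
    q_len, num_heads, head_dim_qk = q.shape
    scale = head_dim_qk ** -0.5  # illustrative scale only; the real API takes explicit scale arguments
    out = torch.empty(q_len, num_heads, head_dim_vo, dtype=q.dtype, device=q.device)
    for t in range(q_len):
        k = kv[topk_indices[t]]                         # gather only the top_k rows: [top_k, head_dim_qk]
        v = k[:, :head_dim_vo]                          # MLA-style: the value is a prefix of the cached row
        logits = (q[t].float() @ k.float().T) * scale   # [num_heads, top_k]
        probs = torch.softmax(logits, dim=-1)
        out[t] = (probs @ v.float()).to(q.dtype)        # [num_heads, head_dim_vo]
    return out
```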

Changes

Cohort / File(s) Summary
Kernel runtime & reduction
csrc/fmhaReduction.cu
Added bool sparseMla parameter to fmhaReductionKernel, early-exit on seqQ range, cap seqLenKv by mSparseMlaTopK when sparseMla, updated function-pointer type and kernel launch to pass sparse flag.
Launcher & public entry points
csrc/trtllm_fmha_kernel_launcher.cu
Added int64_t sparse_mla_top_k parameter to trtllm_paged_attention_launcher and trtllm_paged_attention_decode; set runner_params.mSparseMla/mSparseMlaTopK; added ICHECKs; propagated arg to callers.
Runner params & kernel traits
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h, include/flashinfer/trtllm/fmha/fmhaKernels.cuh
Added fields mSparseMla and mSparseMlaTopK; kernel trait logic now considers sparseMla (caps maxAttentionWindow, restricts tile-size paths, adjusts numTokensPerPage, enforces head-dim constraints, includes mSparseMla in kernel hash).
Kernel parameter layout & validation
include/flashinfer/trtllm/fmha/kernelParams.h
Relaxed min dims for TMA descriptors, set 2D K/V layout and tile adjustments when sparseMla, validated SparseMlaTopK multiple-of-4, propagated mSparseMlaTopK into KernelParams.
Python API & decode plumbing
flashinfer/decode.py
Added sparse_mla_top_k argument to public decode API and internal shape checks (_check_trtllm_gen_mla_shape), branched page_table validation for sparse MLA, and threaded parameter through _paged_run → launcher.
Artifacts
flashinfer/artifacts.py
Updated ArtifactPath.TRTLLM_GEN_FMHA path and CheckSumHash.TRTLLM_GEN_FMHA checksum.
Tests — head-dim parametrization
tests/attention/test_trtllm_gen_attention.py
Parameterized tests over head_dim (128, 256); propagated head_dim through test helpers and removed prior xfail for 256.
Tests — sparse MLA
tests/attention/test_trtllm_gen_mla.py
Added generate_sparse_indices, sparse_mla_reference_torch, and test_trtllm_batch_decode_mla_sparse covering sparse MLA scenarios with reference comparisons.
Tests — skip behavior
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
Replaced raising ValueError with pytest.skip when insufficient GPUs.

Sequence Diagram(s)

sequenceDiagram
    participant API as Public Decode API
    participant Launcher as trtllm_paged_attention_launcher
    participant Runner as RunnerParams
    participant Selector as Kernel Selector
    participant Kernel as FMHA Reduction Kernel

    API->>Launcher: call(..., sparse_mla_top_k)
    Launcher->>Runner: set mSparseMla=(sparse_mla_top_k>0), mSparseMlaTopK
    Launcher->>Launcher: ICHECK head-dim / MLA constraints
    Launcher->>Selector: request kernel (include mSparseMla)
    alt sparse MLA enabled
        Selector->>Selector: cap maxAttentionWindow, set numTokensPerPage=1, restrict tile paths
    else dense mode
        Selector->>Selector: normal kernel selection
    end
    Selector->>Kernel: launch fmhaReductionKernel(params, sparseMla, ...)
    rect rgba(150,200,255,0.12)
      Kernel->>Kernel: if sparseMla then cap seqLenKv by mSparseMlaTopK
      Kernel->>Kernel: use adjusted K/V layout and perform reduction
    end
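As a companion to the diagram, the sketch below restates the launcher-side gating in Python terms. It is an illustration only: the real checks are C++ TVM_FFI_ICHECK / FLASHINFER_CHECK calls in csrc/trtllm_fmha_kernel_launcher.cu and include/flashinfer/trtllm/fmha/kernelParams.h, and the helper name and return values here are assumptions.

```python
def derive_sparse_mla_flags(sparse_mla_top_k: int, head_dim_qk: int, head_dim_vo: int,
                            max_seq_len_kv: int):
    # Illustrative restatement of the gating shown in the diagram, not the actual C++ code.
    sparse_mla = sparse_mla_top_k > 0
    if sparse_mla:
        # Sparse MLA is gated to the decode-MLA head dimensions.
        assert (head_dim_qk, head_dim_vo) == (576, 512), "Only decode MLA supports sparse MLA"
        # Indices are loaded with 16-byte cp.async, hence the multiple-of-4 requirement.
        assert sparse_mla_top_k % 4 == 0, "SparseMlaTopK must be a multiple of 4"
    # Only the selected top-k tokens are ever attended, so the attention window can be capped.
    max_attention_window = min(max_seq_len_kv, sparse_mla_top_k) if sparse_mla else max_seq_len_kv
    return sparse_mla, sparse_mla_top_k, max_attention_window
```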

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Areas needing extra attention:
    • Correctness of updated kernel function pointer type and cudaLaunchKernel invocation argument ordering.
    • KernelParams K/V layout changes and TMA descriptor min-dims relaxation.
    • Kernel selection branches in fmhaKernels.cuh for regressions when sparseMla toggles.
    • Validation/ICHECK logic (SparseMlaTopK % 4, head-dim constraints).
    • New sparse-MLA tests and numeric tolerance/NaN handling in reference comparisons.

Possibly related PRs

Suggested reviewers

  • joker-eph
  • aleozlx
  • djmmoss
  • cyx-6
  • wenscarl

Poem

🐰 I hop through kernels, top-K in paw,

Sparse MLA paths I gently draw,
From Python burrow to CUDA glade,
Params and tests in tidy parade,
A merry little hop — code freshly law! 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 44.44%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check: ✅ Passed. The title clearly and specifically describes the main feature addition: sparse MLA kernels for trtllm-gen, which aligns with the extensive changes across kernel files and supporting infrastructure throughout the PR.
  • Description check: ✅ Passed. The PR description follows the template with a brief description of the feature and checked-off pre-commit and test items, but lacks details on what sparse MLA does, why it's needed, and specific implementation notes.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between aacfabc and 388cc99.

📒 Files selected for processing (1)
  • flashinfer/artifacts.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • flashinfer/artifacts.py


Comment @coderabbitai help to get the list of available commands and usage tips.

@gemini-code-assist

Summary of Changes

Hello @PerkzZheng, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the trtllm-gen backend by adding support for per-tensor sparse Multi-head Latent Attention (MLA) kernels. This feature aims to optimize the attention mechanism by allowing the model to attend to a limited, 'top-k' set of key-value pairs, which can lead to performance improvements in scenarios where full attention is not strictly necessary. The changes span kernel implementations and API definitions, and include robust tests to ensure the new functionality is correct and reliable.

Highlights

  • Sparse MLA Kernel Integration: Introduced new trtllm-gen per-tensor sparse MLA (Multi-head Latent Attention) kernels, enabling more efficient attention computations by focusing on a subset of key-value pairs.
  • API and Parameter Updates: Modified core attention functions and their Python wrappers to accept a sparse_mla_top_k parameter, controlling the number of top key-value pairs considered in sparse attention.
  • Kernel Logic Enhancements: Updated fmhaReductionKernel to incorporate sparse MLA logic, adjusted kernel selection in fmhaKernels.cuh based on sparse MLA flags, and modified KV cache handling in kernelParams.h for sparse layouts.
  • Comprehensive Testing: Added a new dedicated test suite (test_trtllm_batch_decode_mla_sparse) with a PyTorch reference implementation to rigorously validate the correctness and precision of the sparse MLA kernels across various configurations and data types.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for per-tensor sparse MLA kernels in trtllm-gen. The changes are well-structured, spanning from the CUDA kernels and their launchers to the Python bindings and tests. A comprehensive new test case with a PyTorch reference implementation has been added to validate the new sparse MLA functionality, which is excellent. I've identified a minor issue: a leftover debug print statement that should be removed. Additionally, there's a misleading comment in the new test code that should be corrected for clarity. Overall, this is a solid contribution.

"Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
)

print(f"query shape: {query.shape}")


medium

This print statement appears to be for debugging purposes. Please remove it before merging.

Comment on lines +49 to +50
            # Randomly sample topk positions from the sequence
            if cur_seq_len > 0:
                # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
                cur_abs_indices = torch.arange(0, topk, device="cpu")


medium

The comment on line 49 states that topk positions are randomly sampled, but the implementation on line 52 uses torch.arange, which deterministically selects the first topk indices. Please update the comment to reflect the actual behavior for clarity and maintainability. Using torch.randperm (which is commented out) would be another option if random sampling is desired.

Suggested change
-            # Randomly sample topk positions from the sequence
+            # Deterministically select the first topk positions from the sequence
             if cur_seq_len > 0:
                 # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
                 cur_abs_indices = torch.arange(0, topk, device="cpu")


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

466-490: Fix log2(numTokensPerPage) undefined behavior when value is 0

The issue is confirmed. At line 106-107 in hashID, the power-of-2 check (numTokensPerPage & (numTokensPerPage - 1)) == 0 accepts 0 as valid (since 0 & -1 = 0), but line 133 then calls log2(0), which is undefined behavior. This occurs when non-paged layouts set numTokensPerPage = 0 at line 547-549 and later call hashID at line 572.

Apply one of the proposed fixes:

Option 1 (guard log2 call, encode 0 as exponent 0):

  • Update line 106-107 check to explicitly allow only powers of 2 OR 0
  • Add int log2NumTokensPerPage = 0; if (numTokensPerPage > 0) { log2NumTokensPerPage = static_cast<int>(log2(numTokensPerPage)); }
  • Replace log2(numTokensPerPage) with log2NumTokensPerPage at line 133

Option 2 (normalize non-paged to numTokensPerPage = 1):

  • At line 549 in hashFromRunnerParams, set numTokensPerPage = 1 instead of 0
include/flashinfer/trtllm/fmha/kernelParams.h (1)

489-537: Fix dimension-aware debug logging to prevent out-of-bounds access in TMA descriptor error path

When sparse MLA is enabled, shapeK is explicitly reshaped to 2D {headDimQk, INT_MAX}, making dim = 2. The error-path logging unconditionally prints 5 elements from shapes, 4 from stridesInBytes, 5 from tileShapes, and 5 from tileStrides—causing out-of-bounds access and undefined behavior for arrays with size < 5.

Implement dimension-aware logging using loops to support all valid ranks (2–5):

      char const* err_str;
      cuGetErrorString(result, &err_str);
      std::cerr << "Error: Failed to initialize the TMA descriptor due to " << err_str << std::endl;
      std::cerr << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim
                << " gmem: " << gmemAddr << std::endl;
-      std::cerr << "Shape: " << shapes[0] << " " << shapes[1] << " " << shapes[2] << " "
-                << shapes[3] << " " << shapes[4] << std::endl;
-      std::cerr << "Stride: " << stridesInBytes[0] << " " << stridesInBytes[1] << " "
-                << stridesInBytes[2] << " " << stridesInBytes[3] << std::endl;
-      std::cerr << "tileShapes: " << tileShapes[0] << " " << tileShapes[1] << " " << tileShapes[2]
-                << " " << tileShapes[3] << " " << tileShapes[4] << std::endl;
-      std::cerr << "tileStrides: " << tileStrides[0] << " " << tileStrides[1] << " "
-                << tileStrides[2] << " " << tileStrides[3] << " " << tileStrides[4] << std::endl;
+      std::cerr << "Shape:";
+      for (int ii = 0; ii < dim; ++ii) {
+        std::cerr << " " << shapes[ii];
+      }
+      std::cerr << std::endl;
+
+      std::cerr << "Stride (bytes):";
+      for (int ii = 0; ii < dim - 1; ++ii) {
+        std::cerr << " " << stridesInBytes[ii];
+      }
+      std::cerr << std::endl;
+
+      std::cerr << "tileShapes:";
+      for (size_t ii = 0; ii < tileShapes.size(); ++ii) {
+        std::cerr << " " << tileShapes[ii];
+      }
+      std::cerr << std::endl;
+
+      std::cerr << "tileStrides:";
+      for (int ii = 0; ii < dim; ++ii) {
+        std::cerr << " " << tileStrides[ii];
+      }
+      std::cerr << std::endl;
🧹 Nitpick comments (7)
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)

290-293: Sparse MLA flags wiring looks consistent

mSparseMla/mSparseMlaTopK are added next to other scaling params and are zero‑initialized via the existing POD/memset constructor, which gives a sensible default (disabled / top‑k=0). Just ensure all call sites set both fields consistently (e.g., top‑k > 0 always implies mSparseMla == true) so downstream logic in kernelParams.h and fmhaKernels.cuh can rely on that invariant.

include/flashinfer/trtllm/fmha/kernelParams.h (1)

601-608: Sparse MLA K 2D layout and top‑k validation look good; consider clarifying constraints

  • The sparse MLA branch rewrites K’s TMA descriptor to a 2D layout [headDimQk, INT_MAX] with stride [1, headDimQk] and tileShapeKv[1] = 1. This is coherent with a per‑token sparse gather along the second dimension and matches the new dim >= 2 support.
  • The runtime check that SparseMlaTopK must be a multiple of 4 is a nice guardrail for the 16‑byte cp.async usage, and wiring options.mSparseMlaTopK into params.mSparseMlaTopK closes the loop.

Two small follow‑ups you might consider (non‑blocking):

  • If sparse MLA is not intended to support FP4 KV (DATA_TYPE_E2M1), an early FLASHINFER_CHECK(!options.mSparseMla || kernelMeta.mDataTypeKv != DATA_TYPE_E2M1, ...) would fail fast instead of relying on undocumented behavior of the 2D layout with packed nibbles.
  • Document in comments that mSparseMlaTopK == 0 is the “non‑sparse” mode and that callers should keep mSparseMla and mSparseMlaTopK consistent (mSparseMla true iff top‑k > 0) to avoid surprising combinations.

Also applies to: 734-739

tests/attention/test_trtllm_gen_mla.py (2)

15-86: Sparse index generation and reference MLA look structurally sound, with minor cleanups available

  • generate_sparse_indices correctly produces absolute indices and KV‑cache indices consistent with the block‑table scheme, and seq_lens sampling logic ensures cur_seq_len >= topk, so torch.arange(0, topk) is safe for the tests.
  • sparse_mla_reference_torch matches the layout assumptions of the kernel path (blocked KV, per‑batch gathering, optional sparse indices, causal masking, NaN handling).

Minor style/maintainability nits you may want to address:

  • random.seed(seed) is no longer needed now that randperm is commented out; either remove it or switch back to randomized sampling if you want to stress more index patterns.
  • The batch_idx parameter of scaled_dot_product_attention is unused; you can drop it (and its argument at the callsite) to satisfy Ruff:
-    def scaled_dot_product_attention(
-        batch_idx: int,
-        query: torch.Tensor,  # [h_q, s_q, d]
+    def scaled_dot_product_attention(
+        query: torch.Tensor,  # [h_q, s_q, d]
         key: torch.Tensor,  # [s_k, d]
...
-        cur_out, cur_lse = scaled_dot_product_attention(
-            i,
-            q[i].transpose(0, 1),  # [h_q, s_q, d]
+        cur_out, cur_lse = scaled_dot_product_attention(
+            q[i].transpose(0, 1),  # [h_q, s_q, d]
             cur_key,  # [s_k, d]

These are non‑functional improvements and won’t affect test behavior.

Also applies to: 89-211


657-710: Tighten Ruff‑flagged issues in sparse MLA test comparison

At the bottom of test_trtllm_batch_decode_mla_sparse:

  • The second return from sparse_mla_reference_torch is unpacked into lse_ref but never used.
  • The exception handlers re‑raise e, which drops the original traceback level.

Both are easy to clean up:

-    out_ref, lse_ref = sparse_mla_reference_torch(
+    out_ref, _lse_ref = sparse_mla_reference_torch(
         cache_seqlens=seq_lens_tensor,
...
-        except AssertionError as e:
+        except AssertionError:
...
-            raise e
+            raise
...
-        except AssertionError as e:
+        except AssertionError:
...
-            raise e
+            raise

This keeps the tests behavior identical while satisfying the Ruff hints and preserving full tracebacks on failures.

csrc/fmhaReduction.cu (1)

37-85: Sparse MLA top‑k and variable seq‑len handling look consistent; consider guarding degenerate seqLenKv

The new sparseMla flag and mSparseMlaTopK cap are threaded correctly and the early exit on ctaIdxQ >= seqLenQ avoids doing useless work on padded tokens. One thing to double‑check: after speculative‑decoding adjustment and the top‑k cap, seqLenKv can become very small (or even non‑positive in edge cases), which makes numCtasKv zero and leaves sumVal at 0 before the normalization 1.0f / sumVal. Adding a simple clamp such as seqLenKv = max(seqLenKv, 1); (or an early return for empty KV) before computing numCtasKv would harden this against divide‑by‑zero without changing the typical path.
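To make the hazard concrete, here is a small Python restatement of the caps discussed above (the actual code is CUDA; deriving numCtasKv by ceil-division over a per-CTA tile size is an assumption made for illustration):

```python
def num_ctas_kv_sketch(seq_len_kv: int, sparse_mla_top_k: int, tokens_per_cta_kv: int = 128) -> int:
    # Illustrative only: shows why an unguarded top-k cap can leave numCtasKv == 0
    # and the subsequent 1.0 / sumVal normalization dividing by zero.
    seq_len_kv = min(seq_len_kv, sparse_mla_top_k)  # top-k cap applied when sparseMla is enabled
    seq_len_kv = max(seq_len_kv, 1)                 # suggested clamp from the comment above
    return (seq_len_kv + tokens_per_cta_kv - 1) // tokens_per_cta_kv
```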

flashinfer/decode.py (1)

2535-2703: Tighten MLA decode API: remove debug print and guard unsupported sparse_mla_top_k backends

The new sparse_mla_top_k argument is threaded correctly into the shape check and trtllm‑gen C++ call, but two small issues are worth fixing:

  1. The unconditional print(f"query shape: {query.shape}") will spam stdout in normal use and should be removed or gated behind an explicit debug flag.
  2. When backend == "xqa", sparse_mla_top_k is effectively ignored; if a user passes a positive value here, they’ll silently get dense behavior. It would be safer to raise if backend == "xqa" and sparse_mla_top_k > 0.

You might also want to extend the block_tables docstring to mention the 3D [B, Q_len, top_k] layout used when sparse_mla_top_k > 0.

 def trtllm_batch_decode_with_kv_cache_mla(
@@
-    if backend == "xqa":
+    if backend == "xqa":
+        if sparse_mla_top_k > 0:
+            raise ValueError(
+                "sparse_mla_top_k > 0 is only supported for trtllm-gen backend"
+            )
@@
-        if out is None:
+        if out is None:
             out_shape = query.shape[:-1] + (kv_lora_rank,)
             out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
@@
-        if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
+        if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
             # dynamic scale factors
@@
-        print(f"query shape: {query.shape}")
         run_func(
             out,
             None,  # fp4 output not supported in wrapper api yet.
csrc/trtllm_fmha_kernel_launcher.cu (1)

73-140: Sparse MLA flags on runner_params are wired correctly; consider explicit defaults in other launchers

The extended trtllm_paged_attention_launcher signature, the assignments:

runner_params.mSparseMla = sparse_mla_top_k > 0;
runner_params.mSparseMlaTopK = sparse_mla_top_k;
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
    << "Only decode MLA supports sparse MLA";

cleanly gate sparse MLA to the intended (576, 512) MLA configuration and leave the dense path unchanged when sparse_mla_top_k <= 0.

To be fully robust, it would be good to also set mSparseMla = false and mSparseMlaTopK = 0 in other launchers that construct TllmGenFmhaRunnerParams (e.g., trtllm_ragged_attention_launcher), so those fields are never left unintentionally uninitialized if future kernels start consulting them in more places.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cf2df82 and 934f093.

📒 Files selected for processing (9)
  • csrc/fmhaReduction.cu (4 hunks)
  • csrc/trtllm_fmha_kernel_launcher.cu (5 hunks)
  • flashinfer/artifacts.py (2 hunks)
  • flashinfer/decode.py (10 hunks)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7 hunks)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (3 hunks)
  • tests/attention/test_trtllm_gen_attention.py (6 hunks)
  • tests/attention/test_trtllm_gen_mla.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_mla.py (2)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
flashinfer/decode.py (1)
  • trtllm_batch_decode_with_kv_cache_mla (2535-2705)
🪛 Ruff (0.14.5)
tests/attention/test_trtllm_gen_mla.py

130-130: Unused function argument: batch_idx

(ARG001)


657-657: Unpacked variable lse_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


690-690: Use raise without specifying exception name

Remove exception name

(TRY201)


710-710: Use raise without specifying exception name

Remove exception name

(TRY201)

flashinfer/decode.py

2519-2521: Avoid specifying long messages outside the exception class

(TRY003)


2526-2528: Avoid specifying long messages outside the exception class

(TRY003)


2530-2532: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (10)
flashinfer/artifacts.py (1)

90-111: TRTLLM‑GEN FMHA artifact path/hash update is coherent

ArtifactPath.TRTLLM_GEN_FMHA and CheckSumHash.TRTLLM_GEN_FMHA are updated together, and map_checksums references the updated path constant, so the mapping remains self‑consistent. Just make sure the new hash matches the published checksums.txt for that directory in the cubin repository.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

335-339: Sparse MLA max‑attention window and tileSizeKv heuristic are consistent

Reducing maxAttentionWindow to min(mMaxSeqLenKv, mSparseMlaTopK) under mSparseMla correctly reflects that only top‑k keys are ever attended, which tightens multi‑CTAs‑per‑Kv heuristics. Likewise, gating the tileSizeKv 128→64 downshift with !params.mSparseMla avoids picking kernels that don’t exist for sparse MLA.

No functional issues spotted in this part.

Also applies to: 371-378

csrc/fmhaReduction.cu (1)

353-365: Kernel function pointer and launch signature update are correct

The function pointer type and cudaLaunchKernelEx invocation now match the new bool sparseMla parameter, and wiring kernelMeta.mSparseMla through here looks consistent with the template instantiations selected by SELECT_FMHA_REDUCTION_KERNEL.

tests/attention/test_trtllm_gen_attention.py (3)

374-390: Head‑dim parametrization for prefill tests is wired correctly

head_dim is cleanly added to parametrization, test signatures, and is used consistently in create_query_tensor and create_kv_cache, so both 128 and 256 head dimensions are now exercised without altering other test logic.


590-622: BS1 prefill wrapper correctly forwards head_dim

test_trtllm_batch_prefill_bs1 now simply forwards the new head_dim argument into test_trtllm_batch_prefill, maintaining behavior while expanding coverage to both supported head dimensions.


948-983: Decode tests now cover head_dim 128 and 256 through shared helper

The additional head_dim parametrization and its forwarding into _test_trtllm_batch_decode look consistent; this should give good coverage of the new 256‑dim decode path without duplicating logic.

flashinfer/decode.py (3)

1874-1924: Defaulting sparse MLA off for generic trtllm‑gen decode is correct

Passing a hardcoded 0 for sparse_mla_top_k in _paged_run keeps the existing non‑sparse behavior for the generic TrtllmGenDecodeModule path while matching the updated C++ launcher signature. This is a sensible default until/if sparse MLA is plumbed through higher‑level wrappers.


2486-2533: Sparse MLA shape checks clearly distinguish dense vs per‑token top‑k layouts

The updated _check_trtllm_gen_mla_shape correctly enforces:

  • page_table.shape == (B_q, Q_len, sparse_mla_top_k) when sparse_mla_top_k > 0 (sparse MLA), and
  • the original 2D [B, num_blocks] plus block_num % (128 / block_size) == 0 constraint for the dense case.

This matches the new semantics without breaking existing dense MLA callers.
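A minimal sketch of the two page_table contracts summarized above; the real validation lives in _check_trtllm_gen_mla_shape, and the argument names here are illustrative:

```python
import torch

def check_page_table_layout(page_table: torch.Tensor, batch_size: int,
                            q_len_per_request: int, sparse_mla_top_k: int,
                            block_size: int) -> None:
    # Illustrative restatement of the constraints described above, not the real flashinfer code.
    if sparse_mla_top_k > 0:
        # Sparse MLA: one row of top-k KV indices per query token.
        assert page_table.shape == (batch_size, q_len_per_request, sparse_mla_top_k)
    else:
        # Dense MLA: 2D block table whose width is padded to a multiple of 128 / block_size.
        assert page_table.dim() == 2
        assert page_table.shape[1] % (128 // block_size) == 0
```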


2778-2787: xqa MLA correctly reuses common shape checks with sparse_mla_top_k disabled

Calling _check_trtllm_gen_mla_shape(..., sparse_mla_top_k=0, ...) here keeps the xqa MLA path on the dense 2D block‑table contract while still benefiting from the shared dimension validations.

csrc/trtllm_fmha_kernel_launcher.cu (1)

203-271: Decode/context wrappers forward sparse_mla_top_k consistently

trtllm_paged_attention_decode and both its callers have been updated consistently:

  • Decode path passes through the user’s sparse_mla_top_k down to trtllm_paged_attention_launcher.
  • Context path hardcodes /*sparse_mla_top_k=*/0, ensuring no change to existing behavior.

This matches the intent that sparse MLA is only enabled for the decode MLA API.

Also applies to: 321-332

Comment on lines +580 to +633
    # Mask unused KV cache entries with NaN for correctness checking
    kv_cache_ref = kv_cache.clone()
    if dtype == torch.float8_e4m3fn:
        kv_cache_ref = kv_cache_ref.to(torch.bfloat16)

    # Mark all positions as NaN initially
    all_indices = indices_in_kvcache.flatten().tolist()
    all_indices = list(set(all_indices))
    if -1 in all_indices:
        all_indices.remove(-1)

    # Only used indices should be valid
    kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
    kv_cache_flat[~used_mask] = float("0")


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

Fix device mismatch when masking unused KV cache entries

Here:

kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
kv_cache_flat[~used_mask] = float("0")

kv_cache_flat lives on the CUDA device, but used_mask (and the index tensor) are on CPU. Boolean / index masks in PyTorch must be on the same device as the tensor being indexed, so this is likely to raise a runtime error.

You can keep everything on the CUDA device instead:

-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")

If you want to exercise the NaN‑clean‑up path in the reference implementation more directly, you could also set unused entries to float("nan") instead of 0.0.



---

🏁 Script executed:

```shell
#!/bin/bash
# Extract the full function/test context to understand kv_cache_ref device placement
cat tests/attention/test_trtllm_gen_mla.py | sed -n '560,600p'
```

Length of output: 1417


🌐 Web query:

PyTorch boolean indexing device mismatch error CUDA CPU

💡 Result:

Cause

  • PyTorch requires all tensors involved in an operation to live on the same device. A boolean mask on CPU cannot be used to index a tensor on CUDA (and vice‑versa), which raises a device-mismatch RuntimeError. [1][2]

Quick fixes

  • Move the mask to the tensor's device:
    mask = mask.to(x.device)
    x = x[mask]
  • Or move the tensor to the mask's device:
    x = x.to(mask.device)
    x = x[mask]
  • If converting to NumPy, first move to CPU: arr = x.cpu().numpy() (NumPy doesn't accept CUDA tensors). [3]

Example

  • Preferred one-liner:
    x = x[mask.to(x.device)]

Notes

  • For assignment/indexed writes consider masked_fill / masked_scatter for performance (GPU vs CPU behavior differs). [2]

References

  • PyTorch indexing / maskedSelect docs. [1]
  • PyTorch GitHub discussion about boolean mask/indexing and performance. [2]
  • NumPy conversion / CPU vs CUDA reminder. [3]

Fix device mismatch when masking unused KV cache entries

kv_cache_flat is on the CUDA device (inherited from kv_cache_ref at line 580), but used_mask and the index tensor are explicitly created on CPU at lines 591–592. PyTorch requires all tensors involved in indexing operations to be on the same device—attempting to index a CUDA tensor with a CPU boolean mask raises a RuntimeError.

Correct this by moving the mask and indices to the same device as the tensor:

-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
     # Mask unused KV cache entries with NaN for correctness checking
     kv_cache_ref = kv_cache.clone()
     if dtype == torch.float8_e4m3fn:
         kv_cache_ref = kv_cache_ref.to(torch.bfloat16)
     # Mark all positions as NaN initially
     all_indices = indices_in_kvcache.flatten().tolist()
     all_indices = list(set(all_indices))
     if -1 in all_indices:
         all_indices.remove(-1)
     # Only used indices should be valid
     kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")
🤖 Prompt for AI Agents
In tests/attention/test_trtllm_gen_mla.py around lines 580 to 595 the test
creates kv_cache_ref on CUDA but constructs used_mask and the index tensor on
CPU, causing a device mismatch when indexing kv_cache_flat; move the mask and
the index tensor to the same device as kv_cache_flat (i.e., create them on
kv_cache_ref.device or call .to(kv_cache_flat.device)) before using them to
index and assign, ensuring the boolean mask and index tensor share device with
kv_cache_flat.

@PerkzZheng

let me rebase this.

@PerkzZheng PerkzZheng force-pushed the user/perkzz/trtllm-gen-sparse-mla branch from 934f093 to ba867e6 on November 24, 2025 at 06:26
@PerkzZheng

/bot run

@flashinfer-bot

@PerkzZheng is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww

@yzh119 yzh119 requested a review from IwakuraRein as a code owner on November 24, 2025 at 08:35

yzh119 commented Nov 24, 2025

/bot run

@flashinfer-bot

GitLab MR !161 has been created, and the CI pipeline #39062444 is currently running. I'll report back once the pipeline job completes.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
tests/attention/test_trtllm_gen_mla.py (2)

47-50: Update comment to match deterministic implementation.

The comment at line 49 states "Randomly sample topk positions from the sequence", but the implementation uses torch.arange(0, topk) which deterministically selects the first topk indices.

Based on past review, please update the comment to reflect the actual deterministic behavior:

         # Generate indices for each query position
         for j in range(q_len_per_request):
-            # Randomly sample topk positions from the sequence
+            # Deterministically select the first topk positions from the sequence
             if cur_seq_len > 0:
                 # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
                 cur_abs_indices = torch.arange(0, topk, device="cpu")

630-633: Fix device mismatch in KV cache masking.

The code creates used_mask and index tensors on CPU (lines 631-632), but attempts to use them to index kv_cache_flat which is on CUDA (line 630). This will raise a device mismatch RuntimeError.

Move the mask and indices to the same device as kv_cache_flat:

     # Only used indices should be valid
     kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")
🧹 Nitpick comments (6)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

476-489: Document the numHeadsQPerKv==128 constraint for sparse MLA.

The code enforces that KeepsMmaAbForGeneration sparse MLA kernels only support numHeadsQPerKv == 128 (line 488), while SwapsMmaAbForGeneration allows numHeadsQPerKv < 128 (line 476).

This asymmetry should be documented to explain why the high-throughput kernel path has this restriction. Consider adding a comment explaining the architectural reason for this constraint.

         kernelType = FmhaKernelType::KeepsMmaAbForGeneration;
         // Always use the separate reduction kernel.
         if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
           selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel;
         }
+        // The keepsMmaAbForGeneration sparseMla kernels are optimized for numHeadsQPerKv=128
+        // due to 2-CTA MMA layout requirements and register pressure constraints.
         // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128.
         FLASHINFER_CHECK(
tests/attention/test_trtllm_gen_mla.py (3)

128-128: Remove unused batch_idx parameter.

The batch_idx parameter in the scaled_dot_product_attention helper function is never used in the function body.

Apply this diff:

     def scaled_dot_product_attention(
-        batch_idx: int,
         query: torch.Tensor,  # [h_q, s_q, d]
         key: torch.Tensor,  # [s_k, d]
         value: torch.Tensor,  # [s_k, dv]

And update the call site at line 196:

         cur_out, cur_lse = scaled_dot_product_attention(
-            i,
             q[i].transpose(0, 1),  # [h_q, s_q, d]

695-695: Consider using lse_ref for additional validation.

The lse_ref return value from sparse_mla_reference_torch is computed but never used. While the current output comparison is sufficient, you could optionally add LSE validation if the kernel also returns LSE values.

If LSE validation would be valuable:

out_ref, lse_ref = sparse_mla_reference_torch(...)
# Add LSE comparison if kernel returns it
# torch.testing.assert_close(lse_kernel, lse_ref, ...)

Otherwise, suppress the warning by prefixing with underscore:

-    out_ref, lse_ref = sparse_mla_reference_torch(
+    out_ref, _lse_ref = sparse_mla_reference_torch(

728-728: Use bare raise to re-raise exceptions.

When re-raising exceptions after debugging output, use bare raise instead of raise e to preserve the original traceback.

Apply this diff at both locations:

             print(f"Output sample: {output[0, 0, 0, :8]}")
             print(f"Reference sample: {out_ref[0, 0, 0, :8]}")
-            raise e
+            raise

Also applies to: 748-748

flashinfer/decode.py (2)

2505-2546: Sparse MLA page_table shape check is sound; consider tightening docs and (optionally) lint style

The extended _check_trtllm_gen_mla_shape:

  • Correctly distinguishes sparse vs dense modes by sparse_mla_top_k > 0, enforcing a 3‑D page_table shape (B_q, Q_len, sparse_mla_top_k) for sparse MLA while preserving the original 2‑D [B_block_table, block_num] contract when sparse_mla_top_k == 0.
  • Is used consistently: trtllm_batch_decode_with_kv_cache_mla passes the runtime sparse_mla_top_k for the trtllm‑gen path, and xqa_batch_decode_with_kv_cache_mla pins it to 0 to keep the dense [batch_size, num_pages] layout.

Two follow‑ups you may want to consider:

  • The public docstring for trtllm_batch_decode_with_kv_cache_mla still documents block_tables only as [batch_size, num_pages]. It would be clearer to state the sparse case explicitly (e.g., [batch_size, q_len_per_request, sparse_mla_top_k] when sparse_mla_top_k > 0).
  • Ruff’s TRY003 warning about long error messages here is purely stylistic. If you want to silence it, you could shorten these messages or funnel them through a small helper, but functionally the current checks are fine.

Also applies to: 2779-2788


2560-2561: Make sparse_mla_top_k semantics explicit for backends (trtllm‑gen vs xqa)

trtllm_batch_decode_with_kv_cache_mla now plumbs sparse_mla_top_k through to:

  • _check_trtllm_gen_mla_shape and trtllm_paged_attention_decode in the backend == "trtllm-gen" branch, which is great for enabling per‑tensor sparse MLA.
  • But the backend == "xqa" branch neither validates nor uses sparse_mla_top_k, effectively ignoring it while still expecting a 2‑D block_tables.

To avoid surprising users or accidental misuse (e.g., passing a non‑zero sparse_mla_top_k with an XQA backend expecting dense layout), consider:

  • Either raising ValueError when backend == "xqa" and sparse_mla_top_k != 0, or
  • Explicitly documenting that sparse MLA (sparse_mla_top_k > 0) is only supported for the trtllm‑gen backend and must be left as 0 for XQA.

This keeps the API contract sharp without changing current intended behavior.

Also applies to: 2576-2577, 2651-2653, 2669-2704

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 934f093 and aacfabc.

📒 Files selected for processing (10)
  • csrc/fmhaReduction.cu (4 hunks)
  • csrc/trtllm_fmha_kernel_launcher.cu (5 hunks)
  • flashinfer/artifacts.py (2 hunks)
  • flashinfer/decode.py (9 hunks)
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7 hunks)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1 hunks)
  • include/flashinfer/trtllm/fmha/kernelParams.h (3 hunks)
  • tests/attention/test_trtllm_gen_attention.py (7 hunks)
  • tests/attention/test_trtllm_gen_mla.py (3 hunks)
  • tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
🧰 Additional context used
🧬 Code graph analysis (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
  • std (229-583)
  • std (239-244)
  • std (279-300)
  • std (290-296)
csrc/trtllm_fmha_kernel_launcher.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
  • enable_pdl (220-220)
tests/attention/test_trtllm_gen_mla.py (2)
flashinfer/utils.py (1)
  • get_compute_capability (253-256)
flashinfer/decode.py (1)
  • trtllm_batch_decode_with_kv_cache_mla (2550-2712)
🪛 Ruff (0.14.5)
tests/attention/test_trtllm_gen_mla.py

128-128: Unused function argument: batch_idx

(ARG001)


695-695: Unpacked variable lse_ref is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


728-728: Use raise without specifying exception name

Remove exception name

(TRY201)


748-748: Use raise without specifying exception name

Remove exception name

(TRY201)

flashinfer/decode.py

2533-2535: Avoid specifying long messages outside the exception class

(TRY003)


2540-2542: Avoid specifying long messages outside the exception class

(TRY003)


2544-2546: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/artifacts.py (1)

90-90: LGTM: Artifact version updates for sparse MLA kernels.

The artifact path and checksum updates align with the new sparse MLA kernel generation. These changes are consistent with the broader PR objective of adding sparse MLA support.

Also applies to: 110-110

tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)

253-255: LGTM: Proper test skipping behavior.

Using pytest.skip instead of raising a ValueError is the correct approach when test requirements (sufficient GPUs) aren't met. This allows test suites to gracefully skip rather than fail.
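A tiny sketch of the pattern; the helper name is hypothetical, and the actual test calls pytest.skip inline with its own message:

```python
import pytest
import torch

def require_gpus(needed: int) -> None:
    # Hypothetical helper: skip, rather than fail, when the machine has too few GPUs.
    available = torch.cuda.device_count()
    if available < needed:
        pytest.skip(f"needs {needed} GPUs, found {available}")
```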

include/flashinfer/trtllm/fmha/kernelParams.h (3)

489-490: LGTM: Dimension check relaxation for sparse MLA.

Relaxing the minimum TMA descriptor dimension from 3 to 2 is necessary to support the 2D K/V layout used in sparse MLA paths (as configured at lines 601-608).


733-737: LGTM: Proper alignment validation for sparse MLA top-k.

The validation that SparseMlaTopK must be a multiple of 4 is correct, as the implementation uses 16B cpAsync instructions for loading indices. The error message clearly explains the constraint.


601-608: Based on my verification of the codebase, the concerns raised in the review have been addressed by the existing code design.

Key findings:

The TMA descriptor validation at lines 487-495 checks that each shape dimension is within the valid range [1, 2^32]. INT_MAX (2,147,483,647 = 2^31 - 1) is well within this valid range, so the descriptor creation will succeed.

More importantly, the actual kernel access pattern is bounded: maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK), with a comment explicitly stating "The sparseMla only selects topK tokensKv". This means the INT_MAX dimension is not actually accessed beyond the mSparseMlaTopK boundary—it's a design pattern to indicate an unbounded logical dimension while the kernel controls actual accesses through the attention window calculation.

Additionally, mSparseMlaTopK is validated at construction time to be a multiple of 4, and stride validation requires strides[0] == 1, which is satisfied by strideK = [1, mHeadDimQk]. The second stride component (mHeadDimQk) is bounded by head dimension constraints, eliminating overflow risks.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (3)

336-339: LGTM: Proper attention window clamping for sparse MLA.

When sparse MLA is enabled, the maximum attention window is correctly clamped to the minimum of the KV sequence length and the sparse MLA top-k value, as sparse MLA only selects a subset of KV pairs.


372-373: LGTM: Correct kernel selection for sparse MLA.

Disabling the tileSizeKv=64 optimization when sparse MLA is active is appropriate, as the sparse path requires different kernel characteristics.


539-547: LGTM: Correct numTokensPerPage handling for sparse MLA.

Setting numTokensPerPage = 1 for sparse MLA is correct, as the sparse indexing treats each token independently. For non-paged layouts, setting it to 0 is also appropriate.

tests/attention/test_trtllm_gen_attention.py (1)

417-417: LGTM: Proper head dimension parameterization.

Adding head_dim as a test parameter with values [128, 256] extends test coverage to include both standard and MLA-specific head dimensions. The parameter is correctly threaded through all test functions.

Also applies to: 642-642, 657-657, 674-674, 696-696, 711-711, 728-728

csrc/fmhaReduction.cu (3)

67-77: LGTM: Proper early exit for sparse MLA.

Adding the early exit when ctaIdxQ >= seqLenQ is correct for sparse MLA, where each CTA processes one query token and we may launch more CTAs than actual query tokens.


82-85: LGTM: Correct seqLenKv clamping for sparse MLA.

When sparse MLA is enabled, clamping seqLenKv to min(seqLenKv, params.mSparseMlaTopK) correctly restricts the reduction to only the top-k selected KV pairs.


354-354: LGTM: Kernel signature and launch updated consistently.

The function pointer type and kernel launch call are correctly updated to include the bool sparseMla parameter, maintaining consistency with the kernel signature changes.

Also applies to: 364-365

csrc/trtllm_fmha_kernel_launcher.cu (3)

142-146: LGTM: Proper sparse MLA validation.

The validation ensures sparse MLA is only enabled for decode-MLA configurations (head_dim_qk==576 and head_dim_vo==512). Setting mSparseMla = sparse_mla_top_k > 0 is a clean way to enable the feature.


85-86: LGTM: Consistent sparse_mla_top_k parameter threading.

The sparse_mla_top_k parameter is correctly:

  1. Added to the launcher signature after o_sf_start_index
  2. Threaded through the decode function
  3. Passed to the launcher with the correct value

The parameter positioning maintains consistency with the existing signature.

Also applies to: 210-216, 294-295


374-375: LGTM: Correct default for context path.

Passing sparse_mla_top_k=0 for the context path correctly disables sparse MLA for prefill/context operations, which only applies to decode.

flashinfer/decode.py (1)

1925-1930: sparse_mla_top_k explicitly zeroed for non‑MLA decode preserves prior behavior

Both TrtllmGenDecodeModule._paged_run and the trtllm_batch_decode_with_kv_cache trtllm‑gen path pass sparse_mla_top_k=0 into trtllm_paged_attention_decode, which cleanly gates the new kernel parameter without changing existing dense decode semantics. This looks correct and keeps the non‑MLA paths backwards compatible.

Also applies to: 2332-2337

@flashinfer-bot

[FAILED] Pipeline #39062444: 5/18 passed


yzh119 commented Nov 25, 2025

/bot run

@flashinfer-bot

GitLab MR !161 has been updated with latest changes, and the CI pipeline #39124916 is currently running. I'll report back once the pipeline job completes.
