feat: add trtllm-gen per-tensor sparseMla kernels. #2138
Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR adds sparse MLA (top-k) support across the FMHA stack by introducing runner/kernel params (mSparseMla, mSparseMlaTopK), threading a new sparse_mla_top_k argument through the Python APIs, the launcher, and the kernel launch, and adapting the kernel parameter layout, kernel selection, and reduction kernel behavior when sparse MLA is enabled.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant API as Public Decode API
    participant Launcher as trtllm_paged_attention_launcher
    participant Runner as RunnerParams
    participant Selector as Kernel Selector
    participant Kernel as FMHA Reduction Kernel
    API->>Launcher: call(..., sparse_mla_top_k)
    Launcher->>Runner: set mSparseMla=(sparse_mla_top_k>0), mSparseMlaTopK
    Launcher->>Launcher: ICHECK head-dim / MLA constraints
    Launcher->>Selector: request kernel (include mSparseMla)
    alt sparse MLA enabled
        Selector->>Selector: cap maxAttentionWindow, set numTokensPerPage=1, restrict tile paths
    else dense mode
        Selector->>Selector: normal kernel selection
    end
    Selector->>Kernel: launch fmhaReductionKernel(params, sparseMla, ...)
    rect rgba(150,200,255,0.12)
        Kernel->>Kernel: if sparseMla then cap seqLenKv by mSparseMlaTopK
        Kernel->>Kernel: use adjusted K/V layout and perform reduction
    end
```
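To make the flag derivation and block-table layouts in the walkthrough above concrete, here is a small shape-level sketch in Python. It is illustrative only: the variable names and sizes are assumptions, and it does not call the real flashinfer API; it just mirrors the relationships this PR describes (sparse_mla_top_k > 0 implies mSparseMla, a 3D per-token top-k block table, and the (576, 512) decode-MLA head dimensions).

```python
import torch

# Illustrative sketch only; names and sizes are assumptions, not the exact API surface.
batch_size, q_len_per_request, top_k = 4, 1, 256   # top_k must be a multiple of 4
kv_lora_rank, qk_rope_head_dim = 512, 64           # head_dim_vo = 512, head_dim_qk = 576

# Dense mode: classic 2D block table of page indices per request.
dense_block_tables = torch.zeros(batch_size, 32, dtype=torch.int32)

# Sparse mode: per-token top-k KV slot indices, 3D [batch, q_len, top_k].
sparse_block_tables = torch.zeros(
    batch_size, q_len_per_request, top_k, dtype=torch.int32
)

# The launcher derives the runner flags from the Python-level argument.
sparse_mla_top_k = top_k
mSparseMla = sparse_mla_top_k > 0        # True selects the sparse kernel path
mSparseMlaTopK = sparse_mla_top_k        # later caps seqLenKv in the reduction kernel

# Sparse MLA is only valid for the decode-MLA head-dim configuration.
assert not mSparseMla or (kv_lora_rank + qk_rope_head_dim, kv_lora_rank) == (576, 512)
print(dense_block_tables.shape, sparse_block_tables.shape, mSparseMla, mSparseMlaTopK)
```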
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Summary of Changes

Hello @PerkzZheng, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request adds per-tensor sparse MLA (top-k) kernel support to the trtllm-gen attention path.
Code Review
This pull request introduces support for per-tensor sparse MLA kernels in trtllm-gen. The changes are well-structured, spanning from the CUDA kernels and their launchers to the Python bindings and tests. A comprehensive new test case with a PyTorch reference implementation has been added to validate the new sparse MLA functionality, which is excellent. I've identified a minor issue: a leftover debug print statement that should be removed. Additionally, there's a misleading comment in the new test code that should be corrected for clarity. Overall, this is a solid contribution.
flashinfer/decode.py (outdated review context, showing the leftover debug print):

```python
            "Dynamic scale factors bmm1_scale_tensor and bmm2_scale_tensor are only supported for fp8 tensor core operation"
        )

        print(f"query shape: {query.shape}")
```
tests/attention/test_trtllm_gen_mla.py:

```python
        # Randomly sample topk positions from the sequence
        if cur_seq_len > 0:
            # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
            cur_abs_indices = torch.arange(0, topk, device="cpu")
```
The comment on line 49 states that topk positions are randomly sampled, but the implementation on line 52 uses torch.arange, which deterministically selects the first topk indices. Please update the comment to reflect the actual behavior for clarity and maintainability. Using torch.randperm (which is commented out) would be another option if random sampling is desired.
Suggested change:

```diff
-        # Randomly sample topk positions from the sequence
+        # Deterministically select the first topk positions from the sequence
         if cur_seq_len > 0:
             # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
             cur_abs_indices = torch.arange(0, topk, device="cpu")
```
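For reference, here is a minimal helper (not part of the test file, just an illustration of the two options discussed above) contrasting the deterministic torch.arange selection with the commented-out torch.randperm sampling:

```python
import torch

def pick_topk_indices(cur_seq_len: int, topk: int, deterministic: bool = True) -> torch.Tensor:
    # Assumes cur_seq_len >= topk, as the test's seq_lens sampling guarantees.
    if deterministic:
        # Current test behavior: the first `topk` positions.
        return torch.arange(0, topk, device="cpu")
    # Commented-out alternative: random sample without replacement.
    return torch.randperm(cur_seq_len, device="cpu")[:topk]

print(pick_topk_indices(100, 8))  # tensor([0, 1, 2, 3, 4, 5, 6, 7])
```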
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
466-490: Fix log2(numTokensPerPage) undefined behavior when the value is 0

The issue is confirmed. At lines 106-107 in hashID, the power-of-2 check (numTokensPerPage & (numTokensPerPage - 1)) == 0 accepts 0 as valid (since 0 & -1 = 0), but line 133 then calls log2(0), which is undefined behavior. This occurs when non-paged layouts set numTokensPerPage = 0 at lines 547-549 and later call hashID at line 572. Apply one of the proposed fixes:

Option 1 (guard the log2 call, encode 0 as exponent 0):
- Update the check at lines 106-107 to explicitly allow only powers of 2 or 0.
- Add int log2NumTokensPerPage = 0; if (numTokensPerPage > 0) { log2NumTokensPerPage = static_cast<int>(log2(numTokensPerPage)); }
- Replace log2(numTokensPerPage) with log2NumTokensPerPage at line 133.

Option 2 (normalize non-paged to numTokensPerPage = 1):
- At line 549 in hashFromRunnerParams, set numTokensPerPage = 1 instead of 0.

include/flashinfer/trtllm/fmha/kernelParams.h (1)
489-537: Fix dimension-aware debug logging to prevent out-of-bounds access in the TMA descriptor error path

When sparse MLA is enabled, shapeK is explicitly reshaped to 2D {headDimQk, INT_MAX}, making dim = 2. The error-path logging unconditionally prints 5 elements from shapes, 4 from stridesInBytes, 5 from tileShapes, and 5 from tileStrides, causing out-of-bounds access and undefined behavior for arrays with size < 5.

Implement dimension-aware logging using loops to support all valid ranks (2-5):

```diff
     char const* err_str;
     cuGetErrorString(result, &err_str);
     std::cerr << "Error: Failed to initialize the TMA descriptor due to " << err_str << std::endl;
     std::cerr << "tmaFormat: " << static_cast<int>(tmaDataFormat) << " dim: " << dim
               << " gmem: " << gmemAddr << std::endl;
-    std::cerr << "Shape: " << shapes[0] << " " << shapes[1] << " " << shapes[2] << " "
-              << shapes[3] << " " << shapes[4] << std::endl;
-    std::cerr << "Stride: " << stridesInBytes[0] << " " << stridesInBytes[1] << " "
-              << stridesInBytes[2] << " " << stridesInBytes[3] << std::endl;
-    std::cerr << "tileShapes: " << tileShapes[0] << " " << tileShapes[1] << " " << tileShapes[2]
-              << " " << tileShapes[3] << " " << tileShapes[4] << std::endl;
-    std::cerr << "tileStrides: " << tileStrides[0] << " " << tileStrides[1] << " "
-              << tileStrides[2] << " " << tileStrides[3] << " " << tileStrides[4] << std::endl;
+    std::cerr << "Shape:";
+    for (int ii = 0; ii < dim; ++ii) {
+      std::cerr << " " << shapes[ii];
+    }
+    std::cerr << std::endl;
+
+    std::cerr << "Stride (bytes):";
+    for (int ii = 0; ii < dim - 1; ++ii) {
+      std::cerr << " " << stridesInBytes[ii];
+    }
+    std::cerr << std::endl;
+
+    std::cerr << "tileShapes:";
+    for (size_t ii = 0; ii < tileShapes.size(); ++ii) {
+      std::cerr << " " << tileShapes[ii];
+    }
+    std::cerr << std::endl;
+
+    std::cerr << "tileStrides:";
+    for (int ii = 0; ii < dim; ++ii) {
+      std::cerr << " " << tileStrides[ii];
+    }
+    std::cerr << std::endl;
```
🧹 Nitpick comments (7)
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)
290-293: Sparse MLA flags wiring looks consistent
mSparseMla / mSparseMlaTopK are added next to the other scaling params and are zero-initialized via the existing POD/memset constructor, which gives a sensible default (disabled / top-k = 0). Just ensure all call sites set both fields consistently (e.g., top-k > 0 always implies mSparseMla == true) so downstream logic in kernelParams.h and fmhaKernels.cuh can rely on that invariant.

include/flashinfer/trtllm/fmha/kernelParams.h (1)
601-608: Sparse MLA K 2D layout and top‑k validation look good; consider clarifying constraints
- The sparse MLA branch rewrites K's TMA descriptor to a 2D layout [headDimQk, INT_MAX] with stride [1, headDimQk] and tileShapeKv[1] = 1. This is coherent with a per-token sparse gather along the second dimension and matches the new dim >= 2 support.
- The runtime check that SparseMlaTopK must be a multiple of 4 is a nice guardrail for the 16-byte cp.async usage, and wiring options.mSparseMlaTopK into params.mSparseMlaTopK closes the loop.

Two small follow-ups you might consider (non-blocking):

- If sparse MLA is not intended to support FP4 KV (DATA_TYPE_E2M1), an early FLASHINFER_CHECK(!options.mSparseMla || kernelMeta.mDataTypeKv != DATA_TYPE_E2M1, ...) would fail fast instead of relying on undocumented behavior of the 2D layout with packed nibbles.
- Document in comments that mSparseMlaTopK == 0 is the "non-sparse" mode and that callers should keep mSparseMla and mSparseMlaTopK consistent (mSparseMla true iff top-k > 0) to avoid surprising combinations.

Also applies to: 734-739
tests/attention/test_trtllm_gen_mla.py (2)
15-86: Sparse index generation and reference MLA look structurally sound, with minor cleanups available
- generate_sparse_indices correctly produces absolute indices and KV-cache indices consistent with the block-table scheme (see the sketch after this item), and the seq_lens sampling logic ensures cur_seq_len >= topk, so torch.arange(0, topk) is safe for the tests.
- sparse_mla_reference_torch matches the layout assumptions of the kernel path (blocked KV, per-batch gathering, optional sparse indices, causal masking, NaN handling).

Minor style/maintainability nits you may want to address:

- random.seed(seed) is no longer needed now that randperm is commented out; either remove it or switch back to randomized sampling if you want to stress more index patterns.
- The batch_idx parameter of scaled_dot_product_attention is unused; you can drop it (and its argument at the call site) to satisfy Ruff:

```diff
-    def scaled_dot_product_attention(
-        batch_idx: int,
-        query: torch.Tensor,  # [h_q, s_q, d]
+    def scaled_dot_product_attention(
+        query: torch.Tensor,  # [h_q, s_q, d]
         key: torch.Tensor,  # [s_k, d]
         ...
-        cur_out, cur_lse = scaled_dot_product_attention(
-            i,
-            q[i].transpose(0, 1),  # [h_q, s_q, d]
+        cur_out, cur_lse = scaled_dot_product_attention(
+            q[i].transpose(0, 1),  # [h_q, s_q, d]
             cur_key,  # [s_k, d]
```

These are non-functional improvements and won't affect test behavior.
Also applies to: 89-211
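As a point of reference, here is a hypothetical helper showing the usual paged-KV mapping alluded to above (absolute token position to flat cache slot via a block-table row). It mirrors the blocked-KV layout the reference assumes; it is not the test's exact code, and the names are made up for illustration:

```python
import torch

def abs_to_kvcache_indices(abs_indices: torch.Tensor, block_table_row: torch.Tensor,
                           page_size: int) -> torch.Tensor:
    # page = pos // page_size, slot = pos % page_size,
    # flat cache index = block_table[page] * page_size + slot
    pages = abs_indices // page_size
    offsets = abs_indices % page_size
    return block_table_row[pages] * page_size + offsets

# Example: page_size=4, logical pages [0, 1, 2] mapped to physical pages [7, 2, 5]
bt = torch.tensor([7, 2, 5])
print(abs_to_kvcache_indices(torch.tensor([0, 5, 10]), bt, 4))  # tensor([28,  9, 22])
```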
657-710: Tighten Ruff-flagged issues in sparse MLA test comparison

At the bottom of test_trtllm_batch_decode_mla_sparse:

- The second return from sparse_mla_reference_torch is unpacked into lse_ref but never used.
- The exception handlers re-raise with raise e, which drops the original traceback level.

Both are easy to clean up:

```diff
-    out_ref, lse_ref = sparse_mla_reference_torch(
+    out_ref, _lse_ref = sparse_mla_reference_torch(
         cache_seqlens=seq_lens_tensor,
         ...
-    except AssertionError as e:
+    except AssertionError:
         ...
-        raise e
+        raise
     ...
-    except AssertionError as e:
+    except AssertionError:
         ...
-        raise e
+        raise
```

This keeps the test behavior identical while satisfying the Ruff hints and preserving full tracebacks on failures.
csrc/fmhaReduction.cu (1)
37-85: Sparse MLA top-k and variable seq-len handling look consistent; consider guarding degenerate seqLenKv

The new sparseMla flag and mSparseMlaTopK cap are threaded correctly, and the early exit on ctaIdxQ >= seqLenQ avoids doing useless work on padded tokens. One thing to double-check: after the speculative-decoding adjustment and the top-k cap, seqLenKv can become very small (or even non-positive in edge cases), which makes numCtasKv zero and leaves sumVal at 0 before the normalization 1.0f / sumVal. Adding a simple clamp such as seqLenKv = max(seqLenKv, 1); (or an early return for empty KV) before computing numCtasKv would harden this against divide-by-zero without changing the typical path.

flashinfer/decode.py (1)
2535-2703: Tighten MLA decode API: remove debug print and guard unsupported sparse_mla_top_k backends

The new sparse_mla_top_k argument is threaded correctly into the shape check and the trtllm-gen C++ call, but two small issues are worth fixing:

- The unconditional print(f"query shape: {query.shape}") will spam stdout in normal use and should be removed or gated behind an explicit debug flag.
- When backend == "xqa", sparse_mla_top_k is effectively ignored; if a user passes a positive value here, they'll silently get dense behavior. It would be safer to raise if backend == "xqa" and sparse_mla_top_k > 0.

You might also want to extend the block_tables docstring to mention the 3D [B, Q_len, top_k] layout used when sparse_mla_top_k > 0.

```diff
 def trtllm_batch_decode_with_kv_cache_mla(
 @@
     if backend == "xqa":
+        if sparse_mla_top_k > 0:
+            raise ValueError(
+                "sparse_mla_top_k > 0 is only supported for trtllm-gen backend"
+            )
 @@
     if out is None:
         out_shape = query.shape[:-1] + (kv_lora_rank,)
         out = torch.empty(out_shape, dtype=torch.bfloat16, device=query.device)
 @@
     if bmm1_scale_log2_tensor is not None and bmm2_scale_tensor is not None:
         # dynamic scale factors
 @@
-    print(f"query shape: {query.shape}")
     run_func(
         out,
         None,  # fp4 output not supported in wrapper api yet.
```
73-140: Sparse MLA flags on runner_params are wired correctly; consider explicit defaults in other launchers

The extended trtllm_paged_attention_launcher signature and the assignments:

```cpp
runner_params.mSparseMla = sparse_mla_top_k > 0;
runner_params.mSparseMlaTopK = sparse_mla_top_k;
TVM_FFI_ICHECK((head_dim_qk == 576 && head_dim_vo == 512) || sparse_mla_top_k <= 0)
    << "Only decode MLA supports sparse MLA";
```

cleanly gate sparse MLA to the intended (576, 512) MLA configuration and leave the dense path unchanged when sparse_mla_top_k <= 0.

To be fully robust, it would be good to also set mSparseMla = false and mSparseMlaTopK = 0 in other launchers that construct TllmGenFmhaRunnerParams (e.g., trtllm_ragged_attention_launcher), so those fields are never left unintentionally uninitialized if future kernels start consulting them in more places.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (9)
- csrc/fmhaReduction.cu (4 hunks)
- csrc/trtllm_fmha_kernel_launcher.cu (5 hunks)
- flashinfer/artifacts.py (2 hunks)
- flashinfer/decode.py (10 hunks)
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7 hunks)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (3 hunks)
- tests/attention/test_trtllm_gen_attention.py (6 hunks)
- tests/attention/test_trtllm_gen_mla.py (3 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_gen_mla.py (2)
flashinfer/utils.py (1)
get_compute_capability (252-255)

flashinfer/decode.py (1)
trtllm_batch_decode_with_kv_cache_mla(2535-2705)
🪛 Ruff (0.14.5)
tests/attention/test_trtllm_gen_mla.py
130-130: Unused function argument: batch_idx
(ARG001)
657-657: Unpacked variable lse_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
690-690: Use raise without specifying exception name
Remove exception name
(TRY201)
710-710: Use raise without specifying exception name
Remove exception name
(TRY201)
flashinfer/decode.py
2519-2521: Avoid specifying long messages outside the exception class
(TRY003)
2526-2528: Avoid specifying long messages outside the exception class
(TRY003)
2530-2532: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (10)
flashinfer/artifacts.py (1)
90-111: TRTLLM-GEN FMHA artifact path/hash update is coherent

ArtifactPath.TRTLLM_GEN_FMHA and CheckSumHash.TRTLLM_GEN_FMHA are updated together, and map_checksums references the updated path constant, so the mapping remains self-consistent. Just make sure the new hash matches the published checksums.txt for that directory in the cubin repository.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

335-339: Sparse MLA max-attention window and tileSizeKv heuristic are consistent

Reducing maxAttentionWindow to min(mMaxSeqLenKv, mSparseMlaTopK) under mSparseMla correctly reflects that only top-k keys are ever attended, which tightens the multi-CTAs-per-KV heuristics. Likewise, gating the tileSizeKv 128 to 64 downshift with !params.mSparseMla avoids picking kernels that don't exist for sparse MLA.

No functional issues spotted in this part.
Also applies to: 371-378
csrc/fmhaReduction.cu (1)
353-365: Kernel function pointer and launch signature update are correct

The function pointer type and cudaLaunchKernelEx invocation now match the new bool sparseMla parameter, and wiring kernelMeta.mSparseMla through here looks consistent with the template instantiations selected by SELECT_FMHA_REDUCTION_KERNEL.

tests/attention/test_trtllm_gen_attention.py (3)
374-390: Head‑dim parametrization for prefill tests is wired correctly
head_dim is cleanly added to the parametrization and test signatures, and is used consistently in create_query_tensor and create_kv_cache, so both 128 and 256 head dimensions are now exercised without altering other test logic.
590-622: BS1 prefill wrapper correctly forwards head_dim
test_trtllm_batch_prefill_bs1 now simply forwards the new head_dim argument into test_trtllm_batch_prefill, maintaining behavior while expanding coverage to both supported head dimensions.
948-983: Decode tests now cover head_dim 128 and 256 through a shared helper

The additional head_dim parametrization and its forwarding into _test_trtllm_batch_decode look consistent; this should give good coverage of the new 256-dim decode path without duplicating logic.

flashinfer/decode.py (3)
1874-1924: Defaulting sparse MLA off for generic trtllm-gen decode is correct

Passing a hardcoded 0 for sparse_mla_top_k in _paged_run keeps the existing non-sparse behavior for the generic TrtllmGenDecodeModule path while matching the updated C++ launcher signature. This is a sensible default until/if sparse MLA is plumbed through higher-level wrappers.
2486-2533: Sparse MLA shape checks clearly distinguish dense vs per-token top-k layouts

The updated _check_trtllm_gen_mla_shape correctly enforces:

- page_table.shape == (B_q, Q_len, sparse_mla_top_k) when sparse_mla_top_k > 0 (sparse MLA), and
- the original 2D [B, num_blocks] shape plus the block_num % (128 / block_size) == 0 constraint for the dense case (see the shape sketch below).

This matches the new semantics without breaking existing dense MLA callers.
2778-2787: xqa MLA correctly reuses common shape checks with sparse_mla_top_k disabled

Calling _check_trtllm_gen_mla_shape(..., sparse_mla_top_k=0, ...) here keeps the xqa MLA path on the dense 2D block-table contract while still benefiting from the shared dimension validations.

csrc/trtllm_fmha_kernel_launcher.cu (1)
203-271: Decode/context wrappers forward sparse_mla_top_k consistently
trtllm_paged_attention_decode and both its callers have been updated consistently:

- The decode path passes the user's sparse_mla_top_k down to trtllm_paged_attention_launcher.
- The context path hardcodes /*sparse_mla_top_k=*/0, ensuring no change to existing behavior.

This matches the intent that sparse MLA is only enabled for the decode MLA API.
Review context (tests/attention/test_trtllm_gen_mla.py):

```python
    # Mask unused KV cache entries with NaN for correctness checking
    kv_cache_ref = kv_cache.clone()
    if dtype == torch.float8_e4m3fn:
        kv_cache_ref = kv_cache_ref.to(torch.bfloat16)

    # Mark all positions as NaN initially
    all_indices = indices_in_kvcache.flatten().tolist()
    all_indices = list(set(all_indices))
    if -1 in all_indices:
        all_indices.remove(-1)

    # Only used indices should be valid
    kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
    kv_cache_flat[~used_mask] = float("0")
```
🧩 Analysis chain
Fix device mismatch when masking unused KV cache entries
Here:

```python
kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
kv_cache_flat[~used_mask] = float("0")
```

kv_cache_flat lives on the CUDA device, but used_mask (and the index tensor) are on CPU. Boolean / index masks in PyTorch must be on the same device as the tensor being indexed, so this is likely to raise a runtime error.
You can keep everything on the CUDA device instead:
```diff
-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")
```

If you want to exercise the NaN clean-up path in the reference implementation more directly, you could also set unused entries to float("nan") instead of 0.0.
🌐 Web query:
PyTorch boolean indexing device mismatch error CUDA CPU
💡 Result:
Cause
- PyTorch requires all tensors involved in an operation to live on the same device. A boolean mask on CPU cannot be used to index a tensor on CUDA (and vice‑versa), which raises a device-mismatch RuntimeError. [1][2]
Quick fixes
- Move the mask to the tensor's device: mask = mask.to(x.device); x = x[mask]
- Or move the tensor to the mask's device: x = x.to(mask.device); x = x[mask]
- If converting to NumPy, first move to CPU: arr = x.cpu().numpy() (NumPy doesn't accept CUDA tensors). [3]
Example
- Preferred one-liner:
x = x[mask.to(x.device)]
Notes
- For assignment/indexed writes consider masked_fill / masked_scatter for performance (GPU vs CPU behavior differs). [2]
References
- PyTorch indexing / maskedSelect docs. [1]
- PyTorch GitHub discussion about boolean mask/indexing and performance. [2]
- NumPy conversion / CPU vs CUDA reminder. [3]
Fix device mismatch when masking unused KV cache entries
kv_cache_flat is on the CUDA device (inherited from kv_cache_ref at line 580), but used_mask and the index tensor are explicitly created on CPU at lines 591–592. PyTorch requires all tensors involved in indexing operations to be on the same device—attempting to index a CUDA tensor with a CPU boolean mask raises a RuntimeError.
Correct this by moving the mask and indices to the same device as the tensor:
- used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
- used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
- kv_cache_flat[~used_mask] = float("0")
+ used_mask = torch.zeros(
+ kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+ )
+ used_indices = torch.tensor(
+ all_indices, dtype=torch.int64, device=kv_cache_flat.device
+ )
+ used_mask[used_indices] = True
+ kv_cache_flat[~used_mask] = float("0")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Suggested replacement (committable suggestion):

```python
    # Mask unused KV cache entries with NaN for correctness checking
    kv_cache_ref = kv_cache.clone()
    if dtype == torch.float8_e4m3fn:
        kv_cache_ref = kv_cache_ref.to(torch.bfloat16)

    # Mark all positions as NaN initially
    all_indices = indices_in_kvcache.flatten().tolist()
    all_indices = list(set(all_indices))
    if -1 in all_indices:
        all_indices.remove(-1)

    # Only used indices should be valid
    kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
    used_mask = torch.zeros(
        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
    )
    used_indices = torch.tensor(
        all_indices, dtype=torch.int64, device=kv_cache_flat.device
    )
    used_mask[used_indices] = True
    kv_cache_flat[~used_mask] = float("0")
```
🤖 Prompt for AI Agents
In tests/attention/test_trtllm_gen_mla.py around lines 580 to 595 the test
creates kv_cache_ref on CUDA but constructs used_mask and the index tensor on
CPU, causing a device mismatch when indexing kv_cache_flat; move the mask and
the index tensor to the same device as kv_cache_flat (i.e., create them on
kv_cache_ref.device or call .to(kv_cache_flat.device)) before using them to
index and assign, ensuring the boolean mask and index tensor share device with
kv_cache_flat.
|
let me rebase this.
Signed-off-by: Perkz Zheng <67892460+PerkzZheng@users.noreply.github.com>
934f093 to ba867e6 (Compare)
/bot run

@PerkzZheng is not authorized to trigger this CI job. cc: @yzh119, @sricketts, @yongwww
/bot run |
Actionable comments posted: 0
♻️ Duplicate comments (2)
tests/attention/test_trtllm_gen_mla.py (2)
47-50: Update comment to match deterministic implementation.

The comment at line 49 states "Randomly sample topk positions from the sequence", but the implementation uses torch.arange(0, topk), which deterministically selects the first topk indices.

Based on past review, please update the comment to reflect the actual deterministic behavior:

```diff
     # Generate indices for each query position
     for j in range(q_len_per_request):
-        # Randomly sample topk positions from the sequence
+        # Deterministically select the first topk positions from the sequence
         if cur_seq_len > 0:
             # cur_abs_indices = torch.randperm(cur_seq_len, device="cpu")[:topk]
             cur_abs_indices = torch.arange(0, topk, device="cpu")
```

630-633: Fix device mismatch in KV cache masking.

The code creates used_mask and index tensors on CPU (lines 631-632), but attempts to use them to index kv_cache_flat, which is on CUDA (line 630). This will raise a device mismatch RuntimeError.

Move the mask and indices to the same device as kv_cache_flat:

```diff
     # Only used indices should be valid
     kv_cache_flat = kv_cache_ref.view(-1, kv_lora_rank + qk_rope_head_dim)
-    used_mask = torch.zeros(kv_cache_flat.size(0), dtype=torch.bool, device="cpu")
-    used_mask[torch.tensor(all_indices, dtype=torch.int64, device="cpu")] = True
-    kv_cache_flat[~used_mask] = float("0")
+    used_mask = torch.zeros(
+        kv_cache_flat.size(0), dtype=torch.bool, device=kv_cache_flat.device
+    )
+    used_indices = torch.tensor(
+        all_indices, dtype=torch.int64, device=kv_cache_flat.device
+    )
+    used_mask[used_indices] = True
+    kv_cache_flat[~used_mask] = float("0")
```
🧹 Nitpick comments (6)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
476-489: Document the numHeadsQPerKv==128 constraint for sparse MLA.

The code enforces that KeepsMmaAbForGeneration sparse MLA kernels only support numHeadsQPerKv == 128 (line 488), while SwapsMmaAbForGeneration allows numHeadsQPerKv < 128 (line 476). This asymmetry should be documented to explain why the high-throughput kernel path has this restriction. Consider adding a comment explaining the architectural reason for this constraint.

```diff
     kernelType = FmhaKernelType::KeepsMmaAbForGeneration;
     // Always use the separate reduction kernel.
     if (isMultiCtasKvEnabled(selectKernelParams.mMultiCtasKvMode)) {
       selectKernelParams.mMultiCtasKvMode = MultiCtasKvMode::GmemReductionWithSeparateKernel;
     }
+    // The keepsMmaAbForGeneration sparseMla kernels are optimized for numHeadsQPerKv=128
+    // due to 2-CTA MMA layout requirements and register pressure constraints.
     // The keepsMmaAbForGeneration sparseMla kernels only support numHeadsQPerKv = 128.
     FLASHINFER_CHECK(
```

tests/attention/test_trtllm_gen_mla.py (3)
128-128: Remove unused batch_idx parameter.

The batch_idx parameter in the scaled_dot_product_attention helper function is never used in the function body. Apply this diff:

```diff
 def scaled_dot_product_attention(
-    batch_idx: int,
     query: torch.Tensor,  # [h_q, s_q, d]
     key: torch.Tensor,  # [s_k, d]
     value: torch.Tensor,  # [s_k, dv]
```

And update the call site at line 196:

```diff
     cur_out, cur_lse = scaled_dot_product_attention(
-        i,
         q[i].transpose(0, 1),  # [h_q, s_q, d]
```
695-695: Consider using lse_ref for additional validation.

The lse_ref return value from sparse_mla_reference_torch is computed but never used. While the current output comparison is sufficient, you could optionally add LSE validation if the kernel also returns LSE values. If LSE validation would be valuable:

```python
out_ref, lse_ref = sparse_mla_reference_torch(...)
# Add LSE comparison if the kernel returns it
# torch.testing.assert_close(lse_kernel, lse_ref, ...)
```

Otherwise, suppress the warning by prefixing with an underscore:

```diff
-    out_ref, lse_ref = sparse_mla_reference_torch(
+    out_ref, _lse_ref = sparse_mla_reference_torch(
```
728-728: Use bare raise to re-raise exceptions.

When re-raising exceptions after debugging output, use a bare raise instead of raise e to preserve the original traceback. Apply this diff at both locations:

```diff
         print(f"Output sample: {output[0, 0, 0, :8]}")
         print(f"Reference sample: {out_ref[0, 0, 0, :8]}")
-        raise e
+        raise
```

Also applies to: 748-748
flashinfer/decode.py (2)
2505-2546: Sparse MLA page_table shape check is sound; consider tightening docs and (optionally) lint style

The extended _check_trtllm_gen_mla_shape:

- Correctly distinguishes sparse vs dense modes by sparse_mla_top_k > 0, enforcing a 3-D page_table shape (B_q, Q_len, sparse_mla_top_k) for sparse MLA while preserving the original 2-D [B_block_table, block_num] contract when sparse_mla_top_k == 0.
- Is used consistently: trtllm_batch_decode_with_kv_cache_mla passes the runtime sparse_mla_top_k for the trtllm-gen path, and xqa_batch_decode_with_kv_cache_mla pins it to 0 to keep the dense [batch_size, num_pages] layout.

Two follow-ups you may want to consider:
- The public docstring for
trtllm_batch_decode_with_kv_cache_mlastill documentsblock_tablesonly as[batch_size, num_pages]. It would be clearer to state the sparse case explicitly (e.g.,[batch_size, q_len_per_request, sparse_mla_top_k]whensparse_mla_top_k > 0).- Ruff’s TRY003 warning about long error messages here is purely stylistic. If you want to silence it, you could shorten these messages or funnel them through a small helper, but functionally the current checks are fine.
Also applies to: 2779-2788
2560-2561: Make sparse_mla_top_k semantics explicit for backends (trtllm‑gen vs xqa)
trtllm_batch_decode_with_kv_cache_mla now plumbs sparse_mla_top_k through to:

- _check_trtllm_gen_mla_shape and trtllm_paged_attention_decode in the backend == "trtllm-gen" branch, which is great for enabling per-tensor sparse MLA.
- But the backend == "xqa" branch neither validates nor uses sparse_mla_top_k, effectively ignoring it while still expecting a 2-D block_tables.

To avoid surprising users or accidental misuse (e.g., passing a non-zero sparse_mla_top_k with an XQA backend expecting a dense layout), consider:

- Either raising ValueError when backend == "xqa" and sparse_mla_top_k != 0, or
Also applies to: 2576-2577, 2651-2653, 2669-2704
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (10)
- csrc/fmhaReduction.cu (4 hunks)
- csrc/trtllm_fmha_kernel_launcher.cu (5 hunks)
- flashinfer/artifacts.py (2 hunks)
- flashinfer/decode.py (9 hunks)
- include/flashinfer/trtllm/fmha/fmhaKernels.cuh (7 hunks)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1 hunks)
- include/flashinfer/trtllm/fmha/kernelParams.h (3 hunks)
- tests/attention/test_trtllm_gen_attention.py (7 hunks)
- tests/attention/test_trtllm_gen_mla.py (3 hunks)
- tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
🧰 Additional context used
🧬 Code graph analysis (3)
include/flashinfer/trtllm/fmha/kernelParams.h (1)
include/flashinfer/trtllm/batched_gemm/trtllmGen_bmm_export/KernelParams.h (4)
std (229-583), std (239-244), std (279-300), std (290-296)
csrc/trtllm_fmha_kernel_launcher.cu (1)
csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h (1)
enable_pdl(220-220)
tests/attention/test_trtllm_gen_mla.py (2)
flashinfer/utils.py (1)
get_compute_capability (253-256)

flashinfer/decode.py (1)
trtllm_batch_decode_with_kv_cache_mla(2550-2712)
🪛 Ruff (0.14.5)
tests/attention/test_trtllm_gen_mla.py
128-128: Unused function argument: batch_idx
(ARG001)
695-695: Unpacked variable lse_ref is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
728-728: Use raise without specifying exception name
Remove exception name
(TRY201)
748-748: Use raise without specifying exception name
Remove exception name
(TRY201)
flashinfer/decode.py
2533-2535: Avoid specifying long messages outside the exception class
(TRY003)
2540-2542: Avoid specifying long messages outside the exception class
(TRY003)
2544-2546: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
🔇 Additional comments (16)
flashinfer/artifacts.py (1)
90-90: LGTM: Artifact version updates for sparse MLA kernels.The artifact path and checksum updates align with the new sparse MLA kernel generation. These changes are consistent with the broader PR objective of adding sparse MLA support.
Also applies to: 110-110
tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py (1)
253-255: LGTM: Proper test skipping behavior.

Using pytest.skip instead of raising a ValueError is the correct approach when test requirements (sufficient GPUs) aren't met. This allows test suites to gracefully skip rather than fail.

include/flashinfer/trtllm/fmha/kernelParams.h (3)
489-490: LGTM: Dimension check relaxation for sparse MLA.

Relaxing the minimum TMA descriptor dimension from 3 to 2 is necessary to support the 2D K/V layout used in sparse MLA paths (as configured at lines 601-608).
733-737: LGTM: Proper alignment validation for sparse MLA top-k.

The validation that SparseMlaTopK must be a multiple of 4 is correct, as the implementation uses 16B cpAsync instructions for loading indices. The error message clearly explains the constraint (a small padding sketch follows below).
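To illustrate the multiple-of-4 requirement on the Python side, here is a hypothetical padding helper. Whether padding with -1 (the test's marker for unused slots) is the intended way to satisfy the constraint is an assumption for illustration, not something this PR states:

```python
import torch

def pad_topk_indices(indices: torch.Tensor, pad_value: int = -1) -> torch.Tensor:
    # Round the per-token index count up to a multiple of 4 (16B = 4 x int32 per cp.async load).
    topk = indices.numel()
    padded_len = -(-topk // 4) * 4  # ceil to the next multiple of 4
    if padded_len == topk:
        return indices
    pad = torch.full((padded_len - topk,), pad_value, dtype=indices.dtype)
    return torch.cat([indices, pad])

print(pad_topk_indices(torch.arange(10)).shape)  # torch.Size([12]); last two entries are -1
```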
601-608: Based on my verification of the codebase, the concerns raised in the review have been addressed by the existing code design.

Key findings:

- The TMA descriptor validation at lines 487-495 checks that each shape dimension is within the valid range [1, 2^32]. INT_MAX (2,147,483,647 = 2^31 - 1) is well within this valid range, so the descriptor creation will succeed.
- More importantly, the actual kernel access pattern is bounded: maxAttentionWindow = std::min(params.mMaxSeqLenKv, params.mSparseMlaTopK), with a comment explicitly stating "The sparseMla only selects topK tokensKv". This means the INT_MAX dimension is not actually accessed beyond the mSparseMlaTopK boundary; it's a design pattern that indicates an unbounded logical dimension while the kernel controls actual accesses through the attention window calculation.
- Additionally, mSparseMlaTopK is validated at construction time to be a multiple of 4, and stride validation requires strides[0] == 1, which is satisfied by strideK = [1, mHeadDimQk]. The second stride component (mHeadDimQk) is bounded by head dimension constraints, eliminating overflow risks.

include/flashinfer/trtllm/fmha/fmhaKernels.cuh (3)
336-339: LGTM: Proper attention window clamping for sparse MLA.

When sparse MLA is enabled, the maximum attention window is correctly clamped to the minimum of the KV sequence length and the sparse MLA top-k value, as sparse MLA only selects a subset of KV pairs.

372-373: LGTM: Correct kernel selection for sparse MLA.

Disabling the tileSizeKv=64 optimization when sparse MLA is active is appropriate, as the sparse path requires different kernel characteristics.

539-547: LGTM: Correct numTokensPerPage handling for sparse MLA.

Setting numTokensPerPage = 1 for sparse MLA is correct, as the sparse indexing treats each token independently. For non-paged layouts, setting it to 0 is also appropriate.

tests/attention/test_trtllm_gen_attention.py (1)
417-417: LGTM: Proper head dimension parameterization.Adding
head_dimas a test parameter with values [128, 256] extends test coverage to include both standard and MLA-specific head dimensions. The parameter is correctly threaded through all test functions.Also applies to: 642-642, 657-657, 674-674, 696-696, 711-711, 728-728
csrc/fmhaReduction.cu (3)
67-77: LGTM: Proper early exit for sparse MLA.

Adding the early exit when ctaIdxQ >= seqLenQ is correct for sparse MLA, where each CTA processes one query token and we may launch more CTAs than actual query tokens.
82-85: LGTM: Correct seqLenKv clamping for sparse MLA.

When sparse MLA is enabled, clamping seqLenKv to min(seqLenKv, params.mSparseMlaTopK) correctly restricts the reduction to only the top-k selected KV pairs.
354-354: LGTM: Kernel signature and launch updated consistently.The function pointer type and kernel launch call are correctly updated to include the
bool sparseMlaparameter, maintaining consistency with the kernel signature changes.Also applies to: 364-365
csrc/trtllm_fmha_kernel_launcher.cu (3)
142-146: LGTM: Proper sparse MLA validation.

The validation ensures sparse MLA is only enabled for decode-MLA configurations (head_dim_qk==576 and head_dim_vo==512). Setting mSparseMla = sparse_mla_top_k > 0 is a clean way to enable the feature.
85-86: LGTM: Consistent sparse_mla_top_k parameter threading.The
sparse_mla_top_kparameter is correctly:
- Added to the launcher signature after
o_sf_start_index- Threaded through the decode function
- Passed to the launcher with the correct value
The parameter positioning maintains consistency with the existing signature.
Also applies to: 210-216, 294-295
374-375: LGTM: Correct default for context path.Passing
sparse_mla_top_k=0for the context path correctly disables sparse MLA for prefill/context operations, which only applies to decode.flashinfer/decode.py (1)
1925-1930: sparse_mla_top_k explicitly zeroed for non‑MLA decode preserves prior behaviorBoth
TrtllmGenDecodeModule._paged_runand thetrtllm_batch_decode_with_kv_cachetrtllm‑gen path passsparse_mla_top_k=0intotrtllm_paged_attention_decode, which cleanly gates the new kernel parameter without changing existing dense decode semantics. This looks correct and keeps the non‑MLA paths backwards compatible.Also applies to: 2332-2337
|
[FAILED] Pipeline #39062444: 5/18 passed
|
/bot run |
📌 Description
This MR adds trtllm-gen per-tensor sparseMla kernels.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- Installed pre-commit by running pip install pre-commit (or used your preferred method).
- Installed the hooks with pre-commit install.
- Ran pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Added or updated relevant tests (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
New Features
Performance
Tests
Validation
Chores