Conversation

@YAMY1234 (Contributor) commented Nov 23, 2025

Description

Fix TMA descriptor failures for ragged KV layouts in the TRT-LLM FMHA path.

When using trtllm_ragged_attention_deepseek with a non-paged, non-contiguous KV layout (ragged KV), the KV batch dimension is effectively collapsed (shape [head_dim, sum_seq_lens_kv, num_heads_kv, 1]). However, kStrideBatch / vStrideBatch can be set to a large numel()-based value that overflows a 32-bit int and becomes a negative sentinel. This negative stride is then interpreted as a uint64_t in buildNdTmaDescriptor, producing an enormous strideInBytes and causing cuTensorMapEncodeTiled to fail with:

Error: Failed to initialize the TMA descriptor due to invalid argument
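For illustration, a minimal, self-contained sketch of the overflow path. The element count and the two-byte element size are assumptions chosen to match the numbers discussed in this thread, and the variable names follow the description above rather than the actual header:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical ragged-KV element count, consistent with the figures in this
  // thread (roughly 128k total KV tokens x 128 KV heads x 192-dim keys).
  int64_t numel = 3221225472LL;  // 3 * 2^30, larger than INT32_MAX

  // A numel()-based stride stored in a 32-bit int wraps to a negative value
  // (well-defined modulo 2^32 since C++20, the usual behavior in practice).
  int32_t strideBatch = static_cast<int32_t>(numel);  // -1073741824

  // Reinterpreted as uint64_t (as described for buildNdTmaDescriptor), the
  // negative value sign-extends into an enormous byte stride, which
  // cuTensorMapEncodeTiled rejects as an invalid argument.
  uint64_t strideInBytes = static_cast<uint64_t>(strideBatch) * 2;  // 2 B/elem, e.g. bf16

  printf("strideBatch   = %d\n", strideBatch);
  printf("strideInBytes = %llu\n", static_cast<unsigned long long>(strideInBytes));
  return 0;
}
```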

This PR updates makeStrideKv so that for ragged KV layouts (!isPagedKv && !isContiguousKv), any negative strideBatch is treated as a sentinel and clamped to 0. This matches the actual memory layout (no real batch stride for the collapsed batch dimension) and prevents overflow in the TMA descriptor.
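A minimal sketch of that clamp under the same assumptions (the names isPagedKv, isContiguousKv, and strideBatch follow the PR description; the actual change lives inside KernelParams::makeStrideKv in include/flashinfer/trtllm/fmha/kernelParams.h):

```cpp
#include <cstdint>

// Sketch of the fix described above, not the exact header code.
inline int64_t clampRaggedKvStrideBatch(bool isPagedKv, bool isContiguousKv,
                                        int64_t strideBatch) {
  // Ragged KV = neither paged nor contiguous. Its batch dimension is collapsed
  // (size 1), so a negative, overflowed strideBatch is only a sentinel and the
  // correct stride for the collapsed dimension is 0.
  if (!isPagedKv && !isContiguousKv && strideBatch < 0) {
    return 0;
  }
  return strideBatch;
}
```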

🔍 Related Issues

  • SGLang: DeepSeek-R1 + trtllm_ragged_attention_deepseek on SM100 with long KV (e.g., 128k total KV tokens) failing at buildNdTmaDescriptor with invalid TMA configuration.

[Bug] [DeepSeek-R1] Error: Failed to initialize the TMA descriptor due to invalid argument on B200

Pull Request Checklist

Pre-commit Checks

  • I have installed pre-commit (e.g., pip install pre-commit).
  • I have installed the hooks with pre-commit install.
  • I have run pre-commit run --all-files and fixed any reported issues.

Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

  • The change is intentionally minimal and only affects ragged KV layouts.
  • The goal is to keep the existing device kernels intact while fixing the host-side TMA descriptor construction for this layout.

Summary by CodeRabbit

  • Bug Fixes

    • Prevented negative batch stride for ragged key/value cache layouts, avoiding invalid descriptor errors in large or irregular attention workloads.
  • Tests

    • Added an integration test that reproduces large ragged KV scenarios to verify the clamped stride behavior on CUDA systems.
    • Updated a communication test to skip when required GPUs aren't available instead of failing, improving test robustness.


@coderabbitai bot commented Nov 23, 2025

Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Clamps a negative batch stride to zero in KernelParams::makeStrideKv for ragged (neither paged nor contiguous) Q/K/V layouts; adds a CUDA regression test exercising very large total KV to trigger int32/stride overflow paths; changes one test to skip when requested world size exceeds available GPUs.

Changes

  • TMA Descriptor Batch Stride Clamping (include/flashinfer/trtllm/fmha/kernelParams.h): Added conditional logic in KernelParams::makeStrideKv to set strideBatch = 0 when the Q/K/V layout is neither paged nor contiguous and strideBatch is negative (clamps negative batch stride for ragged KV layouts).
  • Ragged KV Integration Test (tests/attention/test_trtllm_ragged_kv_stride.py): New CUDA-focused test test_trtllm_ragged_kv_large_stride_overflow that constructs ragged Q/KV tensors with a very large total_kv to exercise int32/stride overflow handling; includes CUDA/SM checks, seeded randomness, workspace allocation, a call to trtllm_ragged_attention_deepseek, and output shape assertions.
  • Test skip instead of error (tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py): Replaced raising ValueError when world_size exceeds available GPUs with pytest.skip, so the test is skipped in that scenario.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Integration Test
    participant Prefill as flashinfer.prefill.trtllm_ragged_attention_deepseek
    participant KP as KernelParams::makeStrideKv
    participant TMA as TMA Descriptor

    Test->>Prefill: call prefill with ragged Q/KV + cum_seq_lens
    Prefill->>KP: compute strides for Q/K/V
    KP-->>KP: compute strideBatch (may be negative)
    alt ragged layout and strideBatch < 0
        KP->>KP: clamp strideBatch = 0
    end
    KP->>TMA: populate descriptor with non-negative batch stride
    Prefill->>Test: return prefilled tensor
    Note over TMA: avoids invalid TMA descriptor when total_kv is very large

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

  • Check layout-detection logic (paged vs contiguous vs ragged) and boundary conditions.
  • Validate no behavioral change for paged/contiguous layouts.
  • Run or inspect the new CUDA test and its SM/device guards.
  • Confirm skip change in comm test preserves intent and messaging.

Suggested reviewers

  • joker-eph
  • aleozlx
  • wenscarl
  • cyx-6
  • nvmbreughe

Poem

🐇
I hopped through buffers, late at night,
Found a stride that slipped from sight.
One gentle clamp, I set it right,
Ragged keys now sleep polite. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 33.33%, which is below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title correctly summarizes the main fix: resetting negative strideBatch to 0 for ragged KV layouts in the TRT-LLM FMHA path.
  • Description check ✅ Passed: The description is comprehensive and complete. It clearly explains the problem (int32 overflow leading to TMA descriptor failure), the root cause, the solution, and references related issues. All checklist items are marked as completed.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 96f3713 and bc5e69c.

📒 Files selected for processing (1)
  • tests/attention/test_trtllm_ragged_kv_stride.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/attention/test_trtllm_ragged_kv_stride.py


@gemini-code-assist (Contributor) commented

Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request provides a targeted fix for a critical issue within the TRT-LLM Flash Attention mechanism, specifically impacting ragged key-value (KV) layouts. It addresses an integer overflow in stride calculations that led to negative strideBatch values, subsequently causing Tensor Memory Accelerator (TMA) descriptor initialization to fail. By clamping these erroneous negative strides to zero for ragged layouts, the change ensures the stability and correct operation of the FMHA path for advanced memory configurations, preventing runtime errors.

Highlights

  • Fixes TMA Descriptor Failures: Resolves an issue where Tensor Memory Accelerator (TMA) descriptor initialization failed for ragged key-value (KV) layouts in the TRT-LLM Flash Attention (FMHA) path.
  • Addresses Negative Stride Overflow: Specifically targets kStrideBatch/vStrideBatch values that would overflow into a negative int32 sentinel, causing cuTensorMapEncodeTiled to fail with an "invalid argument" error.
  • Clamps StrideBatch to Zero: Implements a change in makeStrideKv to reset any negative strideBatch to 0 when dealing with non-paged, non-contiguous (ragged) KV layouts, aligning with the actual memory layout and preventing overflow.

@gemini-code-assist bot left a comment

Code Review

This pull request provides a targeted fix for an issue where strideBatch could overflow to a negative value for ragged KV layouts, leading to TMA descriptor failures. The change correctly identifies the ragged layout case and clamps the negative (overflowed) stride to zero, which is the correct behavior for this memory layout. The fix is clean and minimal. I have one minor suggestion to improve code conciseness.

@yzh119 (Collaborator) left a comment

Thanks for the bugfix, it looks good to me overall. Would you mind adding a unittest for the case

However, kStrideBatch/vStrideBatch can be set to a large numel()-based value that overflows into a negative int32 sentinel. This negative stride is then cast to uint64_t in buildNdTmaDescriptor, producing an enormous strideInBytes and causing cuTensorMapEncodeTiled to fail with:

which used to fail but works after your fix?

@YAMY1234 (Contributor, Author) commented

Thanks for the bugfix, it looks good to me overall. Would you mind adding a unittest for the case

However, kStrideBatch/vStrideBatch can be set to a large numel()-based value that overflows into a negative int32 sentinel. This negative stride is then cast to uint64_t in buildNdTmaDescriptor, producing an enormous strideInBytes and causing cuTensorMapEncodeTiled to fail with:

which used to fail but works after your fix?

Thanks for the suggestion! A unittest has been added. @yzh119

@coderabbitai bot left a comment

Actionable comments posted: 0

🧹 Nitpick comments (1)
tests/attention/test_trtllm_ragged_kv_stride.py (1)

91-113: Function call parameters match the API signature.

The call to trtllm_ragged_attention_deepseek includes all required parameters and correctly omits optional ones. The test implicitly verifies that no TMA descriptor error is raised (pytest fails on uncaught exceptions).

Consider adding a brief comment explaining the 128MB workspace buffer size choice, e.g., based on empirical testing or documented requirements.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0c88758 and efb8990.

📒 Files selected for processing (2)
  • include/flashinfer/trtllm/fmha/kernelParams.h (1 hunks)
  • tests/attention/test_trtllm_ragged_kv_stride.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_ragged_kv_stride.py (2)
flashinfer/utils.py (1)
  • get_compute_capability (252-255)
flashinfer/prefill.py (1)
  • trtllm_ragged_attention_deepseek (3194-3317)
🔇 Additional comments (5)
tests/attention/test_trtllm_ragged_kv_stride.py (5)

1-9: LGTM!

The imports are appropriate and the test is correctly marked for CUDA execution.


10-29: Well-structured preflight checks.

The comprehensive guards (CUDA availability, function existence, and SM 10.0+ requirement) ensure the test runs only in supported environments. The SM 10.0 requirement targets Blackwell and newer architectures, which is appropriately restrictive for this TRT-LLM feature.


31-60: LGTM!

The ragged Q construction is correct. The head dimensions (192 for QK, 128 for VO) match the API requirements for DeepSeek R1, and the cumulative sequence length computation follows the expected pattern.


62-89: Overflow scenario correctly constructed.

The configuration successfully creates a ragged KV scenario where key.numel() = 3,221,225,472 > 2^31, which triggers the int32 overflow condition that the fix addresses. The cumulative sequence lengths are properly constructed for the ragged layout.
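As a quick arithmetic check, one shape decomposition consistent with that figure (assuming roughly 128k total KV tokens, 128 KV heads, and 192-dim keys; not necessarily the test's exact configuration):

$$3{,}221{,}225{,}472 = 131{,}072 \times 128 \times 192 = 3 \cdot 2^{30} > 2^{31} - 1 = 2{,}147{,}483{,}647$$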


115-118: Shape assertions are appropriate for this regression test.

The output shape checks confirm the function executes successfully and returns the expected tensor dimensions. For a test focused on preventing TMA descriptor errors (not output correctness), these assertions are sufficient.

@yzh119 (Collaborator) commented Nov 23, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !159 has been created, and the CI pipeline #39039563 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented

[FAILED] Pipeline #39039563: 2/18 passed

@yzh119 yzh119 requested a review from IwakuraRein as a code owner November 24, 2025 08:34
@yzh119 (Collaborator) commented Nov 24, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !159 has been updated with latest changes, and the CI pipeline #39062279 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot (Collaborator) commented

[FAILED] Pipeline #39062279: 9/18 passed

@yzh119 (Collaborator) commented Nov 25, 2025

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !159 has been updated with latest changes, and the CI pipeline #39125055 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 merged commit aeeccac into flashinfer-ai:main Nov 25, 2025
4 checks passed