fix(trtllm): reset negative strideBatch to 0 for ragged KV layout to avoid TMA descriptor error #2134

Conversation
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Clamps a negative batch stride to zero in KernelParams::makeStrideKv for ragged (neither paged nor contiguous) Q/K/V layouts; adds a CUDA regression test exercising very large total KV to trigger int32/stride overflow paths; changes one test to skip when the requested world size exceeds available GPUs.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Integration Test
    participant Prefill as flashinfer.prefill.trtllm_ragged_attention_deepseek
    participant KP as KernelParams::makeStrideKv
    participant TMA as TMA Descriptor
    Test->>Prefill: call prefill with ragged Q/KV + cum_seq_lens
    Prefill->>KP: compute strides for Q/K/V
    KP-->>KP: compute strideBatch (may be negative)
    alt ragged layout and strideBatch < 0
        KP->>KP: clamp strideBatch = 0
    end
    KP->>TMA: populate descriptor with non-negative batch stride
    Prefill->>Test: return prefilled tensor
    Note over TMA: avoids invalid TMA descriptor when total_kv is very large
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Summary of Changes

Hello @YAMY1234, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request provides a targeted fix for a critical issue within the TRT-LLM Flash Attention mechanism, specifically impacting ragged key-value (KV) layouts. It addresses an integer overflow in stride calculations that led to negative stride values.

Highlights
Code Review
This pull request provides a targeted fix for an issue where strideBatch could overflow to a negative value for ragged KV layouts, leading to TMA descriptor failures. The change correctly identifies the ragged layout case and clamps the negative (overflowed) stride to zero, which is the correct behavior for this memory layout. The fix is clean and minimal. I have one minor suggestion to improve code conciseness.
yzh119 left a comment
Thanks for the bugfix, and it looks good to me overall. Would you mind adding a unittest for the case below, which used to fail but works after your fix?

However, kStrideBatch/vStrideBatch can be set to a large numel()-based value that overflows into a negative int32 sentinel. This negative stride is then cast to uint64_t in buildNdTmaDescriptor, producing an enormous strideInBytes and causing cuTensorMapEncodeTiled to fail.
Thanks for the suggestion! A unittest has been added. @yzh119
Actionable comments posted: 0
🧹 Nitpick comments (1)
tests/attention/test_trtllm_ragged_kv_stride.py (1)
91-113: Function call parameters match the API signature. The call to trtllm_ragged_attention_deepseek includes all required parameters and correctly omits optional ones. The test implicitly verifies that no TMA descriptor error is raised (pytest fails on uncaught exceptions). Consider adding a brief comment explaining the 128 MB workspace buffer size choice, e.g., based on empirical testing or documented requirements.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- include/flashinfer/trtllm/fmha/kernelParams.h (1 hunks)
- tests/attention/test_trtllm_ragged_kv_stride.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- include/flashinfer/trtllm/fmha/kernelParams.h
🧰 Additional context used
🧬 Code graph analysis (1)
tests/attention/test_trtllm_ragged_kv_stride.py (2)
flashinfer/utils.py (1)
- get_compute_capability (252-255)

flashinfer/prefill.py (1)
- trtllm_ragged_attention_deepseek (3194-3317)
🔇 Additional comments (5)
tests/attention/test_trtllm_ragged_kv_stride.py (5)
1-9: LGTM! The imports are appropriate and the test is correctly marked for CUDA execution.

10-29: Well-structured preflight checks. The comprehensive guards (CUDA availability, function existence, and SM 10.0+ requirement) ensure the test runs only in supported environments. The SM 10.0 requirement targets Blackwell and newer architectures, which is appropriately restrictive for this TRT-LLM feature.

31-60: LGTM! The ragged Q construction is correct. The head dimensions (192 for QK, 128 for VO) match the API requirements for DeepSeek R1, and the cumulative sequence length computation follows the expected pattern.

62-89: Overflow scenario correctly constructed. The configuration successfully creates a ragged KV scenario where key.numel() = 3,221,225,472 > 2^31, which triggers the int32 overflow condition that the fix addresses. The cumulative sequence lengths are properly constructed for the ragged layout.

115-118: Shape assertions are appropriate for this regression test. The output shape checks confirm the function executes successfully and returns the expected tensor dimensions. For a test focused on preventing TMA descriptor errors (not output correctness), these assertions are sufficient.
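As a rough sketch of why the test's key.numel() crosses 2^31: the sizes below are inferred from the 128k-total-KV-token example in the PR description and the DeepSeek head dimensions quoted above, and may not match the exact parameters in tests/attention/test_trtllm_ragged_kv_stride.py.

```python
# Hypothetical configuration reconstructed from the PR discussion:
# 128k total KV tokens across the ragged batch, 128 KV heads,
# and a QK head dimension of 192 (DeepSeek R1).
total_kv_tokens = 128 * 1024  # sum of KV sequence lengths over the batch
num_kv_heads = 128
head_dim_qk = 192

# A ragged K tensor of shape [total_kv_tokens, num_kv_heads, head_dim_qk]
key_numel = total_kv_tokens * num_kv_heads * head_dim_qk
print(key_numel)          # 3221225472 -- matches the value quoted in the review
print(key_numel > 2**31)  # True: a numel()-based stride no longer fits in int32
```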
/bot run

[FAILED] Pipeline #39039563: 2/18 passed

/bot run

[FAILED] Pipeline #39062279: 9/18 passed

/bot run
Description
Fix TMA descriptor failures for ragged KV layouts in the TRT-LLM FMHA path.
When using
trtllm_ragged_attention_deepseekwith a non-paged, non-contiguous KV layout (ragged KV), the KV batch dimension is effectively collapsed (shape[head_dim, sum_seq_lens_kv, num_heads_kv, 1]). However,kStrideBatch/vStrideBatchcan be set to a largenumel()-based value that overflows a 32-bitintand becomes a negative sentinel. This negative stride is then interpreted as auint64_tinbuildNdTmaDescriptor, producing an enormousstrideInBytesand causingcuTensorMapEncodeTiledto fail with:This PR updates
makeStrideKvso that for ragged KV layouts (!isPagedKv && !isContiguousKv), any negativestrideBatchis treated as a sentinel and clamped to0. This matches the actual memory layout (no real batch stride for the collapsed batch dimension) and prevents overflow in the TMA descriptor.🔍 Related Issues
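A minimal sketch of the failure mode and the clamp, emulating 32-bit wraparound in Python. The actual fix lives in C++ in KernelParams::makeStrideKv; the helper and variable names here are illustrative, not the real identifiers.

```python
def to_int32(x: int) -> int:
    """Emulate storing a Python int into a 32-bit signed int (two's complement)."""
    return (x + 2**31) % 2**32 - 2**31

key_numel = 3_221_225_472           # > 2^31, from the regression test's KV sizes
stride_batch = to_int32(key_numel)  # wraps around to a negative sentinel
print(stride_batch)                 # -1073741824

# Reinterpreting that negative int32 as an unsigned 64-bit value (roughly what
# happens inside buildNdTmaDescriptor) yields an enormous stride, which makes
# cuTensorMapEncodeTiled reject the TMA descriptor.
as_uint64 = stride_batch % 2**64
print(as_uint64)                    # 18446744072635809792

# The fix for the ragged (not paged, not contiguous) layout: the collapsed
# batch dimension has no real stride, so clamp the negative sentinel to 0.
is_paged_kv = is_contiguous_kv = False
if not is_paged_kv and not is_contiguous_kv and stride_batch < 0:
    stride_batch = 0
print(stride_batch)                 # 0
```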
trtllm_ragged_attention_deepseek on SM100 with long KV (e.g., 128k total KV tokens) failing at buildNdTmaDescriptor with an invalid TMA configuration: [Bug] [DeepSeek-R1] Error: Failed to initialize the TMA descriptor due to invalid argument on B200
Pull Request Checklist
Pre-commit Checks
- I have installed pre-commit (e.g., pip install pre-commit).
- I have run pre-commit install.
- I have run pre-commit run --all-files and fixed any reported issues.

Tests
Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Tests