[BugFix] Fix UB in per_token_group_quant.cu #24913
Conversation
Code Review
This pull request addresses an undefined behavior bug in the GroupReduceMax CUDA function. The original implementation used a fixed mask for __shfl_xor_sync that was only valid for threads in the lower half of a warp, causing UB for threads in the upper half. The fix correctly computes the mask at runtime based on the thread's lane ID within its warp, ensuring that the shuffle operation is always valid. The change also removes an unused parameter from the function signature. The fix is correct and effectively resolves the reported issue. The changes are well-targeted and I have no further suggestions.
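The review above refers to the warp-shuffle butterfly used by GroupReduceMax. As a hedged sketch (the 16-lane group size is inferred from the description, and group_max_16 is an illustrative name, not the vLLM source): the reduction XORs each lane with partners at offsets 8, 4, 2, and 1, so after four steps every lane in a 16-lane group holds the group maximum — but only if every participating lane's bit is set in the mask passed to __shfl_xor_sync.

```cuda
// Illustrative sketch, not the vLLM kernel source.
// A 16-lane XOR-butterfly max reduction: each step pairs lane i with
// lane i ^ offset, and fmaxf keeps the larger value. After offsets
// 8, 4, 2, 1 all 16 lanes in the group agree on the group maximum.
__device__ float group_max_16(float val, unsigned mask) {
  val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
  val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
  // Per the CUDA programming guide, behavior is undefined unless the
  // calling thread's own lane bit is set in mask.
  return val;
}
```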
The GroupReduceMax function causes undefined behavior. The existing mask specifies only a half-warp, causing UB for the upper 16 threads of each warp. Use threadIdx.x to select the correct mask at runtime. Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
Purpose
The GroupReduceMax function causes undefined behavior. The existing mask specifies only a half-warp, causing UB for the upper 16 threads of each warp.
Use threadIdx.x to select the correct mask at runtime.
This change follows a patch to SGLang for the same kernel.
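A minimal sketch of the fix described above, reconstructed from this description rather than copied from the diff (the real GroupReduceMax signature in per_token_group_quant.cu may differ):

```cuda
// Buggy version (for illustration): the constant mask 0x0000ffff names only
// lanes 0-15. For a calling thread in lanes 16-31 its own lane bit is absent
// from the mask, which is undefined behavior for __shfl_xor_sync.
__device__ float GroupReduceMax_buggy(float val) {
  for (int offset = 8; offset > 0; offset >>= 1)
    val = fmaxf(val, __shfl_xor_sync(0x0000ffffu, val, offset));
  return val;
}

// Fixed version: derive the mask from threadIdx.x so it covers the half-warp
// the calling thread actually belongs to (0x0000ffff for lanes 0-15,
// 0xffff0000 for lanes 16-31), making the shuffle valid for every lane.
__device__ float GroupReduceMax(float val) {
  unsigned lane = threadIdx.x % 32;  // lane ID within the warp
  unsigned mask = (lane < 16) ? 0x0000ffffu : 0xffff0000u;
  for (int offset = 8; offset > 0; offset >>= 1)
    val = fmaxf(val, __shfl_xor_sync(mask, val, offset));
  return val;
}
```

With a runtime-derived mask, both half-warps can reduce independently while every thread's lane bit is guaranteed to be present in the mask it passes.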
Test Plan
Existing tests cover this kernel.
Test Result
Existing tests, including CI, should continue to pass.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.