[Perf] Use upstream CUTLASS for SM90 Block FP8 kernel #23280
Conversation
Signed-off-by: mgoin <mgoin64@gmail.com>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀
Code Review
This pull request correctly enables the SM90 CUTLASS Block FP8 kernel for weight shapes that are not divisible by 128, addressing a previous limitation and a TODO in the code. The logic in vllm/model_executor/layers/quantization/utils/fp8_utils.py has been simplified and corrected to use a more general condition for SM90+ GPUs, which will improve performance for models like DeepSeekV3. The newly added test case in tests/kernels/quantization/test_block_fp8.py effectively validates that the CUTLASS kernel works as expected with non-aligned shapes. The changes are well-implemented and look good.
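For orientation, the simplified dispatch rule the review describes can be sketched in plain Python. The function names and the exact predicate below are illustrative, not vllm's actual code in `fp8_utils.py`:

```python
def used_cutlass_before(capability: int, n: int, k: int) -> bool:
    # Old rule (sketch): SM90+ AND both weight dims divisible by the
    # 128x128 scale block size; otherwise fall back to the Triton kernel.
    return capability >= 90 and n % 128 == 0 and k % 128 == 0


def uses_cutlass_now(capability: int, n: int, k: int) -> bool:
    # New rule (sketch): any SM90+ GPU takes the CUTLASS path;
    # the shape-divisibility restriction is lifted.
    return capability >= 90


# DeepSeekV3's kv_a_proj_with_mqa weight is [576, 7168]; 576 % 128 != 0,
# so this shape previously fell back to Triton but now uses CUTLASS.
```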
yewentao256
left a comment
Thanks for the work!
Review comments on: csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh (resolved)
yewentao256
left a comment
LGTM, thanks for the work!
Review comment on: csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh (resolved)
LucasWilkinson
left a comment
Amazing; thank you for doing this!
…3280) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
…3280) Signed-off-by: mgoin <mgoin64@gmail.com> Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Purpose
The `weight.shape % 128 != 0` limitation of the SM90 CUTLASS implementation was a product of our custom implementation and works now on modern CUTLASS. We can completely remove `csrc/cutlass_extensions/gemm/` and migrate to using CUTLASS directly.

Caveats with the current support in CUTLASS are a restriction of `M % 4 == 0` and the weight block scales layout when using `KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum` for the `MainloopScheduler`. We work around this by:

- `fp8.py::process_weights_after_loading` (https://github.com/vllm-project/vllm/pull/23280/files#diff-5511bfcc9c53f7d96517ad43e4087f6777bef21302da983f42cafae40a866644R466)
- using `cutlass::detail::Sm90BlockwiseScaleConfig` to support this

This means we can replace the Triton kernel when running the DeepSeekV3 layers `kv_a_proj_with_mqa` with shape `[576, 7168]` and `fused_qkv_a_proj` with shape `[2112, 7168]` for better performance on Hopper. Additional tuning should come next to help with smaller M.
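Since the cooperative FP8 block-scaled schedule requires `M % 4 == 0`, the workaround amounts to rounding the activation row count up before the GEMM and discarding the extra output rows afterwards. A minimal CPU-side sketch (the helper name is hypothetical; the real handling lives in the kernel dispatch path):

```python
def pad_m_to_multiple(activation: list[list[float]], multiple: int = 4) -> list[list[float]]:
    """Round the M (row) dimension up to a multiple of `multiple` by
    zero-padding, so the GEMM sees an M with M % 4 == 0. The padded
    output rows would simply be sliced off after the matmul."""
    m = len(activation)
    padded_m = -(-m // multiple) * multiple  # ceiling to the next multiple
    width = len(activation[0]) if activation else 0
    return activation + [[0.0] * width for _ in range(padded_m - m)]
```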
Test Plan
Added a kernel test case that exercises the non-aligned shapes. Will manually test DeepSeek and profile.
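A CPU-only sketch of the kind of reference a blockwise-scaled matmul test can compare the kernel output against, with ragged (non-128-divisible) tail blocks handled by truncating the expanded scale grid. Names, and the use of 2-D block scales for both operands, are illustrative rather than the actual vllm test code:

```python
import numpy as np


def blockwise_dequant_matmul(aq, a_scales, bq, b_scales, block=128):
    """Dequantize per-block-scaled operands and matmul in fp32.
    Shapes need not be multiples of `block`: the expanded scale grid
    is cropped at the ragged edges, mirroring the non-aligned weight
    shapes the new kernel test covers."""
    def expand(scales, rows, cols):
        # blow each block scale up to per-element, then crop to the operand shape
        grid = np.repeat(np.repeat(scales, block, axis=0), block, axis=1)
        return grid[:rows, :cols]

    a = aq.astype(np.float32) * expand(a_scales, *aq.shape)
    b = bq.astype(np.float32) * expand(b_scales, *bq.shape)
    return a @ b.T  # weights stored [N, K], so contract over K
```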
Test Result
GSM8k for DSV3:
Profiling result with dummy DeepSeek using `llm = LLM(model="deepseek-ai/DeepSeek-R1", hf_overrides={'num_hidden_layers': 4}, load_format="dummy")`. You can see the `_w8a8_block_fp8_matmul` Triton kernel in the before trace and none of it in the after, with e2e perf being better.

(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
(If relevant) Update `supported_models.md` and `examples` for a new model.