[NVIDIA] Support Cutlass MLA for Blackwell GPUs #16032
Conversation
Overall looks pretty good, left a couple nits
Thanks for the updates. Left a few more nits (they can be punted to a future PR if you think that's more appropriate); overall though LGTM. Thanks for the contribution!
nit: is this wrapper required? can we just do:
template <typename T, bool PersistenceOption = true>
Done.
vllm/_custom_ops.py
nit: this probably shouldn't be hard-coded to 512; we should pass in the latent size. We should also pass in the softmax scale so we can avoid hardcoding:
// the scale is based on the non-absorbed sizes, change as appropriate
// we can't determine this parameter from the info we have, it's an input
int D_non_latent = 128;
float scale = 1.0 / sqrt(1.0 * (D_non_latent + D_rope));
Done. PTAL.
nit: we should pass in the scale (from Python) to avoid having to hard-code D_non_latent
Done.
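(For illustration, a minimal sketch of what computing the softmax scale on the Python side might look like, so the kernel no longer needs to hard-code D_non_latent; the helper name and the commented call signature are assumptions for this discussion, not code from this PR.)

```python
import math

def mla_softmax_scale(d_non_latent: int, d_rope: int) -> float:
    """Softmax scale based on the non-absorbed head sizes: 1/sqrt(d_qk)."""
    return 1.0 / math.sqrt(d_non_latent + d_rope)

# Example values for a DeepSeek-style MLA layer (illustrative only):
# 128-dim non-latent part, 64-dim rope part.
scale = mla_softmax_scale(d_non_latent=128, d_rope=64)

# The scale would then be passed into the op from Python rather than being
# derived inside the kernel, e.g. (hypothetical argument order):
# out = ops.cutlass_mla_decode(q_nope_and_q_pe, kv_c_and_k_pe_cache,
#                              seq_lens, page_table, scale)
```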
nit: maybe for a future PR: if Q_ptr (q_nope) and Q_ptr + D_latent (q_pe) were passed as separate tensors (assuming the kernel is OK with these being arbitrary pointers with separate strides, which this interface suggests it could be), then we could save the cat of q_nope and q_pe, e.g.:
vllm/v1/attention/backends/mla/flashmla.py, lines 133 to 134 in 7011645:
q = torch.cat([q_nope, q_pe], dim=-1)\
    .unsqueeze(1)  # Add seqlen dim of 1 (decode)
Right, currently we follow the CUTLASS example, which only supports a single query tensor. If this is needed, or if it is common practice, we can ask for an improvement.
I'm confused, it appears to support multiple: https://github.com/NVIDIA/cutlass/blob/e94e888df3551224738bfa505787b515eae8352f/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp#L246-L249
Am I missing something here?
Tried the separate tensors and it works. Updated the PR. PTAL.
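(A rough sketch of the difference, assuming the kernel accepts separate q_nope / q_pe pointers with their own strides as discussed above; the shapes and the commented call are illustrative, not the merged interface.)

```python
import torch

# Illustrative decode-time shapes: B tokens, H heads,
# 512-dim latent ("nope") part and 64-dim rope part per head.
B, H, D_LATENT, D_ROPE = 4, 128, 512, 64
q_nope = torch.randn(B, H, D_LATENT, dtype=torch.bfloat16)
q_pe = torch.randn(B, H, D_ROPE, dtype=torch.bfloat16)

# Before: concatenate into one query tensor, paying for an extra copy.
q = torch.cat([q_nope, q_pe], dim=-1).unsqueeze(1)  # add seqlen dim of 1 (decode)

# After: hand the two tensors to the kernel directly so the cat is no longer
# needed; the kernel can read each part via its own data pointer and strides
# (hypothetical call):
# out = ops.cutlass_mla_decode(q_nope, q_pe, kv_c_and_k_pe_cache,
#                              seq_lens, page_table, scale)
```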
Landing this to help Blackwell perf, but I would like to follow up on #16032 (comment) in a future PR.
The latest CUTLASS supports MLA for Blackwell GPUs. Examples can be found here. It should be available in the next release (v3.9).
This PR integrates this kernel as ops.cutlass_mla_decode.
cc @kushanam
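(For readers, a hedged end-to-end usage sketch of the new op; the tensor layout and argument list below are assumptions inferred from the review discussion above, not the committed interface.)

```python
import math
import torch

# Illustrative DeepSeek-style MLA decode setup (shapes are assumptions).
num_tokens, num_heads = 8, 128
kv_lora_rank, rope_dim = 512, 64        # latent ("nope") and rope head dims
num_blocks, block_size = 16, 128

q_nope = torch.randn(num_tokens, num_heads, kv_lora_rank, dtype=torch.bfloat16)
q_pe = torch.randn(num_tokens, num_heads, rope_dim, dtype=torch.bfloat16)
# Paged latent KV cache: each block stores the compressed KV plus rope part.
kv_c_and_k_pe_cache = torch.randn(num_blocks, block_size,
                                  kv_lora_rank + rope_dim,
                                  dtype=torch.bfloat16)
seq_lens = torch.full((num_tokens,), 512, dtype=torch.int32)
page_table = torch.zeros(num_tokens, num_blocks, dtype=torch.int32)

# Softmax scale based on the non-absorbed sizes, computed on the Python side.
scale = 1.0 / math.sqrt(128 + rope_dim)

# Hypothetical call (requires a Blackwell GPU and the compiled extension;
# move the tensors to CUDA first):
# from vllm import _custom_ops as ops
# out = ops.cutlass_mla_decode(q_nope.cuda(), q_pe.cuda(),
#                              kv_c_and_k_pe_cache.cuda(), seq_lens.cuda(),
#                              page_table.cuda(), scale)
```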