[Kernel] Support decode context parallelism on Blackwell with CUTLASS MLA #24385
Conversation
Signed-off-by: Ming Yang <minos.future@gmail.com>
Code Review
This pull request adds support for returning log-sum-exp (LSE) values from the CUTLASS MLA decode kernel, which is a key requirement for enabling decode context parallelism on GB200. The changes are well-contained and correctly plumb the new lse tensor through the C++ kernel, PyTorch bindings, and the Python attention backend. I have identified one critical issue related to tensor shape consistency that needs to be addressed.
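For readers not following the diff, the "plumbing" amounts to allocating an LSE buffer next to the attention output and letting the decode kernel write both. A minimal Python-level sketch is below; the function name, `decode_op` placeholder, and argument order are illustrative assumptions, not the actual vLLM binding.

```python
import torch

def mla_decode_with_lse(q, kv_cache, block_table, head_dim_v, decode_op):
    """Illustrative sketch: allocate output and LSE buffers, then call the
    decode kernel so it writes both. `decode_op` stands in for the CUTLASS
    MLA decode binding; the real vLLM signature may differ."""
    num_tokens, num_heads = q.shape[0], q.shape[1]
    out = torch.empty(num_tokens, num_heads, head_dim_v,
                      dtype=q.dtype, device=q.device)
    # One LSE value per (token, head), kept in fp32. The reviewer's
    # shape-consistency point is about keeping this layout identical in the
    # C++ kernel, the PyTorch binding, and the Python attention backend.
    lse = torch.empty(num_tokens, num_heads,
                      dtype=torch.float32, device=q.device)
    decode_op(out, lse, q, kv_cache, block_table)
    return out, lse
```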
Co-authored-by: youkaichao <youkaichao@gmail.com> Signed-off-by: Ming Yang <minos.future@gmail.com>
Locally verified that tests/distributed/test_context_parallel.py now passes on B200. Thanks for the great work!
Signed-off-by: youkaichao <youkaichao@gmail.com>
LGTM! Thanks for doing this!
… MLA (vllm-project#24385) Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
… MLA (vllm-project#24385) Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
Purpose
This PR adds support for decode context parallelism with the CUTLASS MLA kernel on GB200.
Credits to #22789 from @LucasWilkinson, and to @youkaichao.
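For context, decode context parallelism shards the KV cache across ranks: each rank computes a partial attention output plus the log-sum-exp (LSE) of its local scores, and the partials are then merged into one result. The sketch below shows that LSE-based merge in plain PyTorch; it is an assumption-laden illustration of the math, not the actual vLLM implementation (which performs the combine after gathering the per-rank partials).

```python
import torch

def merge_partial_attention(outs, lses):
    """Merge per-rank partial attention outputs using their LSE values.

    outs[r]: (num_tokens, num_heads, head_dim) partial output over rank r's
             shard of the KV cache.
    lses[r]: (num_tokens, num_heads) log-sum-exp of rank r's attention scores.
    """
    lse_stack = torch.stack(lses)                   # (R, T, H)
    global_lse = torch.logsumexp(lse_stack, dim=0)  # (T, H)
    # Rescale each rank's partial output by exp(lse_r - lse_global) so the
    # per-shard softmaxes combine into one softmax over the full context.
    weights = torch.exp(lse_stack - global_lse)     # (R, T, H)
    out_stack = torch.stack(outs)                   # (R, T, H, D)
    return (weights.unsqueeze(-1) * out_stack).sum(dim=0)
```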
Test Plan
pytest -v -s tests/distributed/test_context_parallel.py
Note: on GB200, this unit test needs to be modified to use only two GPUs; that change can land later in a follow-up PR.
Test Result
Both `-tp 2 -dcp 2` and `-tp 2` work.
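For anyone wanting to try this end to end, a decode-context-parallel engine could be configured roughly as below. The keyword-argument names are assumptions and may differ across vLLM versions; the model name is just an example of an MLA model, not something tested in this PR.

```python
# Illustrative only: check the engine arguments for the vLLM release you use.
from vllm import LLM

llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # example MLA model (assumption)
    tensor_parallel_size=2,                # corresponds to "-tp 2"
    decode_context_parallel_size=2,        # corresponds to "-dcp 2" (assumed kwarg name)
)
```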