[Feature] Support Decode Context Parallel (DCP) for MLA #23734
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request introduces Context Parallelism (CP) support for MLA inference, which is a significant feature enhancement. The changes are extensive, touching configuration, parallel state management, scheduling, KV cache, and attention backends. The implementation seems well-thought-out, with new CUDA kernels for CP-specific operations and corresponding Python wrappers and tests. The end-to-end tests comparing CP with TP are a good validation strategy.
My review found one critical issue in cuda_communicator.py, where a reduce_scatter operation was using a potentially non-contiguous tensor, which could lead to incorrect results; the provided patch fixes it. The rest of the context-parallelism implementation appears solid.
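As a rough illustration of that contiguity pitfall (a minimal sketch, not the actual vLLM patch; the helper name `reduce_scatter_last_dim` is hypothetical): PyTorch collectives such as `torch.distributed.reduce_scatter_tensor` expect a contiguous input buffer, so a strided view must be made contiguous before the call.

```python
import torch
import torch.distributed as dist


def reduce_scatter_last_dim(input_: torch.Tensor, world_size: int) -> torch.Tensor:
    """Hypothetical helper: reduce-scatter a tensor along its last dimension.

    Assumes the default process group is already initialized.
    """
    # Moving the scattered dimension to the front usually produces a
    # non-contiguous view; handing that view to the collective can silently
    # yield wrong results, so we force contiguity first (the gist of the fix).
    chunked = input_.movedim(-1, 0).contiguous()
    output = torch.empty(
        (chunked.shape[0] // world_size, *chunked.shape[1:]),
        dtype=input_.dtype,
        device=input_.device,
    )
    dist.reduce_scatter_tensor(output, chunked)
    return output.movedim(0, -1)
```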
Thanks for the great work!
As discussed, there can be two types of CP: CP for prefill (where the world size is enlarged by CP) and CP for decode (where the world size does not change with CP). If possible, let's denote the current PR's option as decode-context-parallel-size / dcp_size, to leave room for prefill CP in the future.
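For concreteness, here is a minimal sketch of the world-size arithmetic behind the two flavors (the sizes are made up purely for illustration and are not taken from the PR):

```python
tp_size = 8

# Prefill-style CP: CP ranks are added on top of the TP ranks,
# so the world size grows by the CP factor.
prefill_cp_size = 2
world_size_with_prefill_cp = tp_size * prefill_cp_size  # 16 GPUs

# Decode-style CP (this PR): CP ranks are carved out of the existing
# TP group, so the world size stays unchanged.
dcp_size = 2
world_size_with_decode_cp = tp_size  # still 8 GPUs
```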
@youzhedian To accelerate the review and merge (especially CI testing), maybe we can split the kernel-side changes into a separate PR and get it merged first. Follow-up PRs can then use pre-compiled wheels from that PR, with much faster CI testing.
I've just come across this PR adding …
#23791 as suggested.
Cool, thanks for taking this on! I think this can be done without any GPU model runner changes; I was working on a prototype but unfortunately it got back-burnered for a few months 😞 Anyway, just sharing here as an alternative solution that doesn't require as many changes to the core code, but is potentially more susceptible to imbalance (it's not fully functional yet).
@wushidonguc has imported this pull request. If you are a Meta employee, you can view this in D81728831. |
Is this PR compatible with #22668? @youzhedian
This PR adds Decode Context Parallel (DCP) support for MLA inference, fully compatible with chunked prefill and APC.
You can enable DCP with `--decode-context-parallel-size` / `-dcp xxx` (only the FlashMLA backend is supported for now). `tp_size` must be divisible by `dcp_size`: because the world size does not change with DCP, it simply reuses the GPUs of the TP group and splits one TP group into `tp_size // dcp_size` DCP groups (see the sketch further below).
This DCP implementation stores the KV cache in an interleaved style: the KV cache for the token whose `token_idx` is `i` is always stored on the GPU whose `dcp_rank` equals `i % dcp_world_size`.
deepseek-ai/DeepSeek-V2-Lite-Chat GSM8K eval:
For more info, please refer to the introduction doc.
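As a rough sketch of the two rules from the description above (illustrative sizes and rank layout only, not vLLM's actual group-construction code):

```python
tp_size, dcp_size = 8, 2
assert tp_size % dcp_size == 0, "tp_size must be divisible by dcp_size"

# The world size stays tp_size; one TP group is split into
# tp_size // dcp_size DCP groups of dcp_size ranks each.
num_dcp_groups = tp_size // dcp_size
dcp_groups = [list(range(g * dcp_size, (g + 1) * dcp_size))
              for g in range(num_dcp_groups)]
print(dcp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]

# Interleaved KV-cache placement: the KV cache of token i lives on the
# rank whose dcp_rank equals i % dcp_world_size.
def owner_dcp_rank(token_idx: int, dcp_world_size: int = dcp_size) -> int:
    return token_idx % dcp_world_size

print([owner_dcp_rank(i) for i in range(8)])  # [0, 1, 0, 1, 0, 1, 0, 1]
```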
Future work (these items will be tackled in follow-up PRs; community contributions are warmly welcome):