Support mtp run in full graph mode #3903
base: v0.11.0-dev
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run the linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request adds support for running MTP models in full graph mode on Ascend hardware, which should improve performance. The changes adapt the MTP model to full graph capture and execution. The implementation introduces a new set of graph parameters and functions for MTP that duplicate existing ones. My main feedback is to refactor this duplicated code to improve maintainability.
# Imports assumed from the surrounding module (not shown in the diff hunk).
from dataclasses import dataclass
from typing import Any, Optional

import torch
import torch_npu


@dataclass
class MTPGraphParams:
    # Per-capture-size state recorded while building the MTP full graph.
    events: dict[int, list[torch.npu.ExternalEvent]]
    workspaces: dict[int, torch.Tensor]
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    attn_params: dict[int, list[tuple]]


# Module-level singleton holding the MTP graph parameters.
_mtp_graph_params: Optional[MTPGraphParams] = None


def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
    global _mtp_graph_params
    if _mtp_graph_params is not None:
        raise ValueError("MTPGraph parameters have already been set!")
    _mtp_graph_params = MTPGraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
    # Record the workspace captured for a given decode token count.
    global _mtp_graph_params
    if _mtp_graph_params is not None:
        _mtp_graph_params.workspaces[num_tokens] = workspace


def get_mtp_graph_params():
    return _mtp_graph_params
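For context, here is a hypothetical capture-time usage of these helpers. This is a sketch only: the capture sizes, the dummy workspace tensor, and the call order are illustrative assumptions, not taken from the runner code in this PR.

import torch
import torch_npu  # registers the "npu" device backend

# Illustrative decode batch sizes; the real sizes come from the aclgraph
# capture configuration.
capture_sizes = {1, 8, 16}
set_mtp_graph_params(capture_sizes)

for num_tokens in sorted(capture_sizes):
    # During capture, the attention backend would supply a real workspace
    # tensor for this batch size; a dummy tensor stands in here.
    workspace = torch.empty(1024, dtype=torch.int8, device="npu")
    update_mtp_graph_params_workspaces(num_tokens, workspace)

params = get_mtp_graph_params()
assert all(params.workspaces[size] is not None for size in capture_sizes)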
The MTPGraphParams class and its associated functions (set_mtp_graph_params, update_mtp_graph_params_workspaces, get_mtp_graph_params) are duplicates of GraphParams and its functions. This introduces significant code duplication, which can lead to maintenance issues and potential bugs if one version is updated and the other is forgotten.
Consider refactoring to avoid this duplication. You could, for example, use a single set of functions and a dictionary to manage parameters for different graph types (e.g., DEFAULT and MTP). This would involve:
- Removing MTPGraphParams and its related functions; GraphParams can be used for both.
- Using a dictionary to store GraphParams instances for different graph types, keyed by an enum.
- Modifying the set/update/get functions to accept a graph_type parameter so they operate on the correct GraphParams instance (see the sketch after this list).
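A minimal sketch of what this refactor could look like. The GraphType enum, the graph_type keyword defaults, and the function names are assumptions for illustration, not existing project APIs; only the field layout is reused from the GraphParams/MTPGraphParams definition above.

from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional

import torch
import torch_npu


class GraphType(Enum):
    DEFAULT = "default"
    MTP = "mtp"


@dataclass
class GraphParams:
    events: dict[int, list[torch.npu.ExternalEvent]]
    workspaces: dict[int, torch.Tensor]
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    attn_params: dict[int, list[tuple]]


# One GraphParams slot per graph type, replacing the two parallel globals.
_graph_params: dict[GraphType, Optional[GraphParams]] = {
    GraphType.DEFAULT: None,
    GraphType.MTP: None,
}


def set_graph_params(aclgraph_capture_sizes: set[int],
                     graph_type: GraphType = GraphType.DEFAULT):
    if _graph_params[graph_type] is not None:
        raise ValueError(
            f"{graph_type.name} graph parameters have already been set!")
    _graph_params[graph_type] = GraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def update_graph_params_workspaces(num_tokens: int, workspace: Any,
                                   graph_type: GraphType = GraphType.DEFAULT):
    params = _graph_params[graph_type]
    if params is not None:
        params.workspaces[num_tokens] = workspace


def get_graph_params(graph_type: GraphType = GraphType.DEFAULT):
    return _graph_params[graph_type]

Callers that today use the MTP-specific helpers would instead pass GraphType.MTP, keeping a single implementation to maintain.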
This pull request has conflicts; please resolve them before we can evaluate the pull request.
Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>
14f3211 to 59b458f
What this PR does / why we need it?
Currently, the MTP model still runs in eager mode even when full graph mode is enabled. This PR adapts the MTP model to full graph capture and execution. When the graph mode is set to "FULL_DECODE_ONLY", the MTP model runs as a full graph to improve performance.
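For reference, enabling this path might look like the sketch below. This is a hypothetical configuration: the speculative_config keys ("method", "num_speculative_tokens"), the "deepseek_mtp" method name, the "FULL_DECODE_ONLY" mode string, and the model path are assumptions based on upstream vLLM conventions and may differ by version.

from vllm import LLM

llm = LLM(
    model="path/to/deepseek-model",            # placeholder model path
    speculative_config={
        "method": "deepseek_mtp",              # assumed MTP drafting method name
        "num_speculative_tokens": 1,
    },
    compilation_config={
        "cudagraph_mode": "FULL_DECODE_ONLY",  # full-graph capture for decode
    },
)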
Does this PR introduce any user-facing change?
How was this patch tested?