
Conversation

@anon189Ty (Contributor)

What this PR does / why we need it?

Currently, the MTP model still runs in eager mode even when full graph mode is enabled. This PR adapts MTP for full graph capture and execution: when the graph mode is set to "FULL_DECODE_ONLY", MTP runs as a full graph to improve performance.
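For context, a hedged sketch of how a user might opt into this mode. The "FULL_DECODE_ONLY" value comes from this PR's description; the `compilation_config`/`speculative_config` plumbing is assumed from upstream vLLM and the exact vllm-ascend knobs may differ:

```python
# Hypothetical usage sketch -- config keys assumed from upstream vLLM;
# the model path and speculative settings below are placeholders.
from vllm import LLM

llm = LLM(
    model="path/to/model-with-mtp",  # placeholder: a model with MTP weights
    speculative_config={"method": "mtp", "num_speculative_tokens": 1},
    compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
)
```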

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling in the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds support for running MTP models in full graph mode on Ascend hardware, which should improve performance. The changes adapt the MTP model with full graph capture and execution. The implementation introduces a new set of graph parameters and functions for MTP, which are duplicates of existing ones. My main feedback is to refactor this duplicated code to improve maintainability.

Comment on lines +344 to +367
@dataclass
class MTPGraphParams:
    events: dict[int, list[torch.npu.ExternalEvent]]
    workspaces: dict[int, torch.Tensor]
    handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
    attn_params: dict[int, list[tuple]]


_mtp_graph_params: Optional[MTPGraphParams] = None


def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
    global _mtp_graph_params
    if _mtp_graph_params is not None:
        raise ValueError("MTPGraph parameters have already been set!")
    _mtp_graph_params = MTPGraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
    global _mtp_graph_params
    if _mtp_graph_params is not None:
        _mtp_graph_params.workspaces[num_tokens] = workspace


def get_mtp_graph_params():
    return _mtp_graph_params

Severity: high

The MTPGraphParams class and its associated functions (set_mtp_graph_params, update_mtp_graph_params_workspaces, get_mtp_graph_params) are duplicates of GraphParams and its functions. This introduces significant code duplication, which can lead to maintenance issues and potential bugs if one version is updated and the other is forgotten.

Consider refactoring to avoid this duplication. You could, for example, use a single set of functions and a dictionary to manage parameters for different graph types (e.g., DEFAULT and MTP). This would involve:

  1. Removing MTPGraphParams and its related functions. GraphParams can be used for both.
  2. Using a dictionary to store GraphParams instances for different graph types, keyed by an enum.
  3. Modifying the set/update/get functions to accept a graph_type parameter to operate on the correct GraphParams instance.
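The steps above can be sketched as follows. This is a minimal illustration of the reviewer's suggestion, not the project's actual code: the name `GraphParamsType`, the field types (`Any` stands in for the NPU-specific types), and the default-argument plumbing are all assumptions.

```python
# Hypothetical refactor sketch: one GraphParams registry keyed by graph
# type, replacing the duplicated MTPGraphParams and its helper functions.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class GraphParamsType(Enum):
    DEFAULT = "default"
    MTP = "mtp"


@dataclass
class GraphParams:
    # `Any` stands in for torch.npu.ExternalEvent, torch.Tensor, and
    # torch_npu task-group handles from the original dataclass.
    events: dict[int, list[Any]]
    workspaces: dict[int, Optional[Any]]
    handles: dict[int, list[Any]]
    attn_params: dict[int, list[tuple]]


# One registry instead of two module-level globals.
_graph_params: dict[GraphParamsType, GraphParams] = {}


def set_graph_params(aclgraph_capture_sizes: set[int],
                     graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    if graph_type in _graph_params:
        raise ValueError(f"Graph parameters for {graph_type} already set!")
    _graph_params[graph_type] = GraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def update_graph_params_workspaces(
        num_tokens: int, workspace: Any,
        graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    params = _graph_params.get(graph_type)
    if params is not None:
        params.workspaces[num_tokens] = workspace


def get_graph_params(graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    return _graph_params.get(graph_type)
```

With defaults on the `graph_type` parameter, existing call sites for the non-MTP path would keep working unchanged, while the MTP path passes `GraphParamsType.MTP` explicitly.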

@github-actions

github-actions bot commented Nov 3, 2025

This pull request has conflicts, please resolve those before we can evaluate the pull request.

Signed-off-by: anon189Ty <Stari_Falcon@outlook.com>