4 changes: 3 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -71,7 +71,8 @@ def set_ascend_forward_context(
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
weight_prefetch_method: Optional[WeightPrefetchMethod] = None,
is_mtp_model=False):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@@ -153,6 +154,7 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
forward_context.model_instance = model_instance
forward_context.weight_prefetch_method = weight_prefetch_method
forward_context.is_mtp_model = is_mtp_model

# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.
6 changes: 5 additions & 1 deletion vllm_ascend/attention/mla_v1.py
@@ -26,6 +26,7 @@
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
get_mtp_graph_params,
update_graph_params_workspaces)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
@@ -1022,8 +1023,11 @@ def _forward_decode(
"actual_seq_lengths": actual_seq_lengths,
"actual_seq_lengths_kv": decode_meta.seq_lens_list,
}
graph_params = get_graph_params()
forward_context: ForwardContext = get_forward_context()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
if forward_context.capturing:
stream = torch_npu.npu.current_stream()

48 changes: 46 additions & 2 deletions vllm_ascend/compilation/acl_graph.py
@@ -232,7 +232,10 @@ def update_attn_params(update_stream, forward_context, runtime_shape):

def update_mla_attn_params(update_stream, forward_context, runtime_shape,
speculative_config):
graph_params = get_graph_params()
if forward_context.is_mtp_model:
graph_params = get_mtp_graph_params()
else:
graph_params = get_graph_params()
# FIXME: Behold! We are using a temporary hack here to update the args
# for each layer's attention op in the graph.
for key, param, handle, event in zip(
@@ -245,7 +248,8 @@ def update_mla_attn_params(update_stream, forward_context, runtime_shape,
spec_attn_mask, sparse_mode, scale, block_table, block_size,
seq_lens_list, actual_seq_lengths, attn_output, softmax_lse) = param
seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list
if speculative_config and speculative_config.method == "deepseek_mtp":
if speculative_config and speculative_config.method == "deepseek_mtp" \
and not forward_context.is_mtp_model:
actual_seq_lengths = forward_context.attn_metadata[
key].decode.actual_seq_lengths_q
spec_multiple = speculative_config.num_speculative_tokens + 1
@@ -255,6 +259,9 @@
spec_multiple * (i + 1)
for i in range(runtime_shape // spec_multiple)
]
elif forward_context.is_mtp_model:
seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) -
len(seq_lens_list))
else:
seq_lens_list = seq_lens_list + [0] * (runtime_shape -
len(seq_lens_list))
@@ -321,3 +328,40 @@ def update_graph_params_workspaces(num_tokens: int, workspace: Any):

def get_graph_params():
return _graph_params


@dataclass
class MTPGraphParams:
events: dict[int, list[torch.npu.ExternalEvent]]
workspaces: dict[int, torch.Tensor]
handles: dict[int, list[torch_npu._C._NPUTaskGroupHandle]]
attn_params: dict[int, list[tuple]]


_mtp_graph_params: Optional[MTPGraphParams] = None


def set_mtp_graph_params(aclgraph_capture_sizes: set[int]):
global _mtp_graph_params
if _mtp_graph_params is not None:
raise ValueError("MTPGraph parameters have already been set!")
_mtp_graph_params = MTPGraphParams(
{size: []
for size in aclgraph_capture_sizes},
{size: None
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
{size: []
for size in aclgraph_capture_sizes},
)


def update_mtp_graph_params_workspaces(num_tokens: int, workspace: Any):
global _mtp_graph_params
if _mtp_graph_params is not None:
_mtp_graph_params.workspaces[num_tokens] = workspace


def get_mtp_graph_params():
return _mtp_graph_params
Comment on lines +333 to +367
Contributor

high

The MTPGraphParams class and its associated functions (set_mtp_graph_params, update_mtp_graph_params_workspaces, get_mtp_graph_params) are duplicates of GraphParams and its functions. This introduces significant code duplication, which can lead to maintenance issues and potential bugs if one version is updated and the other is forgotten.

Consider refactoring to avoid this duplication. You could, for example, use a single set of functions and a dictionary to manage parameters for different graph types (e.g., DEFAULT and MTP). This would involve:

  1. Removing MTPGraphParams and its related functions. GraphParams can be used for both.
  2. Using a dictionary to store GraphParams instances for different graph types, keyed by an enum.
  3. Modifying the set/update/get functions to accept a graph_type parameter to operate on the correct GraphParams instance.

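A minimal sketch of the refactor suggested in the comment above, not part of this PR: a single GraphParams registry keyed by graph type replaces the duplicated MTP variants. The GraphParamsType enum and the graph_type keyword are assumed names, and the event/handle element types are loosened to Any so the sketch stays self-contained without torch_npu.

# Hypothetical refactor sketch (not from this PR): one registry keyed by graph
# type instead of parallel *_mtp_* duplicates. GraphParamsType and graph_type
# are assumed names; event/handle entries use Any in place of
# torch.npu.ExternalEvent / torch_npu._C._NPUTaskGroupHandle.
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional


class GraphParamsType(Enum):
    DEFAULT = "default"
    MTP = "mtp"


@dataclass
class GraphParams:
    events: dict[int, list[Any]]
    workspaces: dict[int, Any]
    handles: dict[int, list[Any]]
    attn_params: dict[int, list[tuple]]


_graph_params: dict[GraphParamsType, Optional[GraphParams]] = {
    GraphParamsType.DEFAULT: None,
    GraphParamsType.MTP: None,
}


def set_graph_params(aclgraph_capture_sizes: set[int],
                     graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    # Mirror the existing guard: each graph type may only be initialized once.
    if _graph_params[graph_type] is not None:
        raise ValueError(f"Graph parameters for {graph_type} already set!")
    _graph_params[graph_type] = GraphParams(
        {size: [] for size in aclgraph_capture_sizes},
        {size: None for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
        {size: [] for size in aclgraph_capture_sizes},
    )


def update_graph_params_workspaces(
        num_tokens: int,
        workspace: Any,
        graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    # Mutating the per-type entry avoids rebinding the module-level registry.
    params = _graph_params[graph_type]
    if params is not None:
        params.workspaces[num_tokens] = workspace


def get_graph_params(graph_type: GraphParamsType = GraphParamsType.DEFAULT):
    return _graph_params[graph_type]

Under this sketch, call sites such as _forward_decode could select the right parameters with get_graph_params(GraphParamsType.MTP if forward_context.is_mtp_model else GraphParamsType.DEFAULT) rather than branching between two separate getter functions.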