@@ -43,7 +43,7 @@
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
 from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
                                              get_tp_group)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import DPMetadata, get_forward_context
 from vllm.logger import logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -80,7 +80,6 @@
 from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
 from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
 from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     NativeAllGatherCommImpl,
                                                      DummyCommImpl,
                                                      MoECommMethod)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -1029,6 +1028,32 @@ def _gather_mm_embeddings(
             mm_embeds.append(mm_embeds_item)
         return mm_embeds
 
+    def get_dp_padding(self,
+                       num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
+        """This implementation is derived from vLLM's `GPUModelRunner.get_dp_padding`.
+        Please note that vLLM may refactor or modify this function over time;
+        at present, we are using the version introduced in PR #18935.
+        """
+        dp_size = self.vllm_config.parallel_config.data_parallel_size
+        dp_rank = self.vllm_config.parallel_config.data_parallel_rank
+
+        # For DP: don't pad when enforce_eager is set.
+        # This lets us set enforce_eager on the prefiller in a P/D setup and
+        # still use ACL graphs (enabled by this padding) on the decoder.
+
+        if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
+            # Early exit.
+            return 0, None
+
+        num_tokens_across_dp = DPMetadata.num_tokens_across_dp(
+            num_tokens, dp_size, dp_rank)
+        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp).item()
+        num_tokens_after_padding = torch.tensor([max_tokens_across_dp_cpu] *
+                                                dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+        return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding
+
     def _process_reqs(
         self,
         scheduler_output: "SchedulerOutput",
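For intuition, here is a minimal standalone sketch of the padding math that `get_dp_padding` performs. It emulates the cross-rank exchange that `DPMetadata.num_tokens_across_dp` does with a collective over the DP group in vLLM; `dp_padding_sketch` and `local_token_counts` are hypothetical names introduced for illustration, not part of this PR:

```python
import torch

def dp_padding_sketch(local_token_counts: list[int], dp_rank: int):
    """Emulate get_dp_padding for one rank, assuming every rank's
    scheduled token count is already known on all ranks (vLLM obtains
    this via a collective over the DP group)."""
    num_tokens_across_dp = torch.tensor(local_token_counts, dtype=torch.int32)
    max_tokens = int(torch.max(num_tokens_across_dp).item())
    # Pad every rank up to the same maximum so the captured graph sees a
    # uniform batch shape across all DP ranks.
    num_tokens_after_padding = torch.tensor([max_tokens] * len(local_token_counts),
                                            device="cpu",
                                            dtype=torch.int32)
    return max_tokens - local_token_counts[dp_rank], num_tokens_after_padding

# Example: four DP ranks scheduled 5, 8, 3 and 8 tokens respectively.
num_pad, padded = dp_padding_sketch([5, 8, 3, 8], dp_rank=0)
print(num_pad)  # 3 -> rank 0 pads its 5 tokens up to 8
print(padded)   # tensor([8, 8, 8, 8], dtype=torch.int32)
```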
@@ -1051,6 +1076,11 @@ def _process_reqs(
             # Eager mode.
             num_input_tokens = total_num_scheduled_tokens
 
+        # Padding for DP
+        num_pad, num_tokens_across_dp_native = self.get_dp_padding(
+            num_input_tokens)
+        num_input_tokens += num_pad
+
         modified_batch = self.attn_metadata_builder.reorder_batch(
             self.input_batch, scheduler_output)
         if modified_batch:
@@ -1280,8 +1310,11 @@ def _process_reqs(
12801310
12811311 # NOTE: Currently this padding logic is really messy,
12821312 # MC2 may not be available in eager mode
1313+ # TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
12831314 if not self .use_aclgraph or self .torchair_graph_enabled :
12841315 num_input_tokens = padded_num_tokens_across_dp
1316+ else :
1317+ num_tokens_across_dp = num_tokens_across_dp_native
12851318
12861319 # Run forward pass
12871320 with set_ascend_forward_context (
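To spell out the final branch: only the ACL graph path adopts the per-rank tensor computed by `get_dp_padding`; TorchAir (and the no-graph path) keeps its pre-existing scalar padding, which is the duplication the TODO above flags. Below is a hedged sketch of that selection, presumably feeding `set_ascend_forward_context`; `select_dp_padding` and its parameters are hypothetical names introduced for illustration, not part of this PR:

```python
from typing import Optional

import torch

def select_dp_padding(use_aclgraph: bool,
                      torchair_graph_enabled: bool,
                      num_input_tokens: int,
                      padded_num_tokens_across_dp: int,
                      num_tokens_across_dp_native: Optional[torch.Tensor]):
    """Mirror the branch at the end of _process_reqs (simplified)."""
    if not use_aclgraph or torchair_graph_enabled:
        # TorchAir / eager path: override with the scalar padding computed by
        # the pre-existing TorchAir logic; the per-rank tensor from
        # get_dp_padding is unused here (None stands in for the value the
        # TorchAir path would have computed).
        return padded_num_tokens_across_dp, None
    # ACL graph path: num_input_tokens was already padded via get_dp_padding,
    # so keep it and expose the per-rank token tensor to the forward context.
    return num_input_tokens, num_tokens_across_dp_native

# ACL graph path on a 2-rank DP setup, already padded 5 -> 8 tokens:
tokens, across_dp = select_dp_padding(
    use_aclgraph=True,
    torchair_graph_enabled=False,
    num_input_tokens=8,
    padded_num_tokens_across_dp=16,
    num_tokens_across_dp_native=torch.tensor([8, 8], dtype=torch.int32))
print(tokens, across_dp)  # 8 tensor([8, 8], dtype=torch.int32)
```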