Skip to content

Commit fc3899e

Browse files
committed
feat: Implement DP padding logic in NPUModelRunner
Adds data parallelism (DP) padding to ensure token tensors have a uniform shape across all DP ranks. This change mirrors the padding logic from the GPU model runner. This alignment is necessary for features such as ACL graphs, which require consistent tensor shapes in distributed environments. The padding is calculated and applied before the model forward pass.

Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 6078002 commit fc3899e

File tree

1 file changed

+35
-2
lines changed

1 file changed

+35
-2
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
4444
from vllm.distributed.parallel_state import (get_dp_group, get_pp_group,
4545
get_tp_group)
46-
from vllm.forward_context import get_forward_context
46+
from vllm.forward_context import DPMetadata, get_forward_context
4747
from vllm.logger import logger
4848
from vllm.model_executor.layers.fused_moe import FusedMoE
4949
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
@@ -80,7 +80,6 @@
8080
from vllm_ascend.attention.attention_v1_torchair import AscendTorchairMetadata
8181
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
8282
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
83-
NativeAllGatherCommImpl,
8483
DummyCommImpl,
8584
MoECommMethod)
8685
from vllm_ascend.multistream.ms_split import compute_split_seq_index
@@ -1029,6 +1028,32 @@ def _gather_mm_embeddings(
10291028
mm_embeds.append(mm_embeds_item)
10301029
return mm_embeds
10311030

1031+
def get_dp_padding(self,
                   num_tokens: int) -> tuple[int, Optional[torch.Tensor]]:
    """Compute the padding required to equalize token counts across DP ranks.

    Ported from vLLM's ``GPUModelRunner.get_dp_padding`` (the version
    introduced in PR #18935); note that upstream vLLM may refactor or
    modify that function over time.

    Args:
        num_tokens: Number of tokens scheduled on this DP rank.

    Returns:
        A tuple ``(pad, tokens_after_padding)`` where ``pad`` is the
        number of extra tokens this rank must append so all ranks match
        the maximum, and ``tokens_after_padding`` is a CPU int32 tensor
        of length ``dp_size`` holding the (uniform) padded token count
        per rank — or ``(0, None)`` when padding is skipped.
    """
    parallel_config = self.vllm_config.parallel_config
    dp_size = parallel_config.data_parallel_size
    dp_rank = parallel_config.data_parallel_rank

    # For DP: Don't pad when setting enforce_eager.
    # This lets us set enforce_eager on the prefiller in a P/D setup and
    # still use ACL graphs (enabled by this padding) on the decoder.
    if dp_size == 1 or self.vllm_config.model_config.enforce_eager:
        # Single-rank or eager mode: nothing to align.
        return 0, None

    tokens_across_dp = DPMetadata.num_tokens_across_dp(
        num_tokens, dp_size, dp_rank)
    max_tokens = torch.max(tokens_across_dp).item()
    # Every rank is padded up to the global maximum token count.
    tokens_after_padding = torch.full((dp_size, ),
                                      max_tokens,
                                      device="cpu",
                                      dtype=torch.int32)
    return max_tokens - num_tokens, tokens_after_padding
1056+
10321057
def _process_reqs(
10331058
self,
10341059
scheduler_output: "SchedulerOutput",
@@ -1051,6 +1076,11 @@ def _process_reqs(
10511076
# Eager mode.
10521077
num_input_tokens = total_num_scheduled_tokens
10531078

1079+
# Padding for DP
1080+
num_pad, num_tokens_across_dp_native = self.get_dp_padding(
1081+
num_input_tokens)
1082+
num_input_tokens += num_pad
1083+
10541084
modified_batch = self.attn_metadata_builder.reorder_batch(
10551085
self.input_batch, scheduler_output)
10561086
if modified_batch:
@@ -1280,8 +1310,11 @@ def _process_reqs(
12801310

12811311
# NOTE: Currently this padding logic is really messy,
12821312
# MC2 may not be available in eager mode
1313+
# TODO: Unify the padding logic between TorchAir and ACL Graph ASAP
12831314
if not self.use_aclgraph or self.torchair_graph_enabled:
12841315
num_input_tokens = padded_num_tokens_across_dp
1316+
else:
1317+
num_tokens_across_dp = num_tokens_across_dp_native
12851318

12861319
# Run forward pass
12871320
with set_ascend_forward_context(

0 commit comments

Comments
 (0)