
Commit f3176d0

[main][refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding.
Signed-off-by: yx0716 <jinyx1007@foxmail.com>
1 parent ba3dfbd commit f3176d0

File tree: 1 file changed, +51 −47 lines

vllm_ascend/worker/model_runner_v1.py

Lines changed: 51 additions & 47 deletions
```diff
@@ -569,45 +569,55 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+            self, num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
-        if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2, dtype=torch.int32, device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        if self.dp_size == 1:
+            return num_tokens, None, with_prefill, enable_dbo
+
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
-
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size,), maybe_padded_num_tokens, dtype=torch.int32, device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
 
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
```
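
The refactored `_get_forward_metadata_across_dp` packs each rank's token count plus two flag slots into one CPU tensor and relies on the default SUM `all_reduce`: every other rank contributes zero to a given count slot, so the reduction recovers all per-rank counts, `with_prefill` becomes true if any rank is prefilling, and `enable_dbo` stays true only if every rank enabled it (because `not enable_dbo` is what gets summed). Below is a standalone sketch of that pattern, not vllm-ascend code; the helper name and the single-process gloo group exist only to make it runnable.

```python
# Minimal sketch (assumption: illustrative, not the vllm-ascend implementation)
# of syncing per-rank token counts and flags with one SUM all_reduce.
import os
import torch
import torch.distributed as dist


def sync_forward_metadata(num_tokens: int, with_prefill: bool, enable_dbo: bool,
                          dp_rank: int, dp_size: int):
    meta = torch.zeros(dp_size + 2, dtype=torch.int32)
    meta[dp_rank] = num_tokens      # only this rank's slot is non-zero
    meta[-2] = int(with_prefill)    # sum > 0  -> some rank is prefilling
    meta[-1] = int(not enable_dbo)  # sum > 0  -> some rank disabled dbo
    dist.all_reduce(meta)           # default op is SUM
    return meta[:-2], bool(meta[-2]), not bool(meta[-1])


if __name__ == "__main__":
    # Single-process gloo group, just to make the sketch executable end to end.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    counts, prefill, dbo = sync_forward_metadata(128, False, True,
                                                 dp_rank=0, dp_size=1)
    print(counts.tolist(), prefill, dbo)  # [128] False True
    dist.destroy_process_group()
```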
```diff
@@ -1083,13 +1093,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
 
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
```
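
In the torchair graph path, the padded size returned by `_get_forward_metadata_across_dp_and_pad` feeds `graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens`, i.e. how many filler tokens are appended before the captured graph runs. The body of `select_torchair_padded_batch_size` is not part of this diff; the sketch below only illustrates the assumed behavior of picking the smallest captured batch size that fits.

```python
# Hypothetical sketch of the batch-size selection the torchair path relies on;
# the real select_torchair_padded_batch_size is not shown in this commit.
def select_torchair_padded_batch_size(batch_size: int,
                                      graph_batch_sizes: list[int]) -> int:
    # Pick the smallest pre-captured graph batch size that can hold the batch.
    for candidate in sorted(graph_batch_sizes):
        if candidate >= batch_size:
            return candidate
    # Fall back to the largest captured size if the batch exceeds them all.
    return max(graph_batch_sizes)


padded = select_torchair_padded_batch_size(190, [64, 128, 256, 512])
graph_pad_size = padded - 190  # 256 - 190 = 66 filler tokens before the graph run
```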
```diff
@@ -1722,15 +1731,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
```
```diff
@@ -2590,7 +2594,7 @@ def select_torchair_padded_batch_size(self, batch_size: int):
         return selected_batch_size
 
     def get_supported_pooling_tasks(self):
-        model = self.get_model()
+        model=self.get_model()
         if not is_pooling_model(model):
             return []
 
```