
Commit 73e7a33

[main][refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding.
Signed-off-by: yx0716 <jinyx1007@foxmail.com>
1 parent 0190b68 commit 73e7a33
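
The refactored `_get_forward_metadata_across_dp` packs each rank's token count and the two boolean flags into a single CPU tensor and exchanges everything with one sum all_reduce: every rank writes its count into its own slot and leaves the others at zero, so the reduced tensor carries the full per-rank vector, while the two trailing flag slots act as an "any rank" vote (`with_prefill` becomes true if any rank has prefill, and DBO stays enabled only if no rank vetoed it). Below is a minimal standalone sketch of that packing and unpacking scheme; the all_reduce is simulated by summing per-rank tensors in one process instead of using a real `torch.distributed` group, and the rank values are made up.

```python
import torch

def pack_metadata(dp_size: int, dp_rank: int, num_tokens: int,
                  with_prefill: bool, enable_dbo: bool) -> torch.Tensor:
    # One slot per rank for its token count, plus two trailing flag slots.
    meta = torch.zeros(dp_size + 2, dtype=torch.int32)
    meta[dp_rank] = num_tokens
    meta[-2] = int(with_prefill)      # any rank prefilling -> reduced sum > 0
    meta[-1] = int(not enable_dbo)    # any rank vetoing DBO -> reduced sum > 0
    return meta

def unpack_metadata(reduced: torch.Tensor):
    with_prefill = bool(reduced[-2])      # true if any rank had prefill
    enable_dbo = not bool(reduced[-1])    # true only if no rank vetoed DBO
    num_tokens_across_dp = reduced[:-2]
    return num_tokens_across_dp, with_prefill, enable_dbo

# Simulated "all_reduce": element-wise sum over the per-rank tensors.
dp_size = 4
per_rank = [
    pack_metadata(dp_size, 0, 128, with_prefill=False, enable_dbo=True),
    pack_metadata(dp_size, 1, 512, with_prefill=True,  enable_dbo=True),
    pack_metadata(dp_size, 2,  64, with_prefill=False, enable_dbo=False),
    pack_metadata(dp_size, 3, 256, with_prefill=False, enable_dbo=True),
]
reduced = torch.stack(per_rank).sum(dim=0)

tokens, with_prefill, enable_dbo = unpack_metadata(reduced)
print(tokens.tolist(), with_prefill, enable_dbo)
# [128, 512, 64, 256] True False
```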

File tree

1 file changed: +51 −46 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 51 additions & 46 deletions
@@ -559,45 +559,56 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+        self,
+        num_tokens: int,
+        with_prefill: bool, enable_dbo: bool
+    ) -> tuple[Optional[torch.Tensor], bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2, dtype=torch.int32, device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+        self, num_tokens: int,
+        with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
-
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
-
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+            return num_tokens, None, with_prefill, enable_dbo
+
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(num_tokens, with_prefill, enable_dbo)
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
-
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ), maybe_padded_num_tokens, dtype=torch.int32, device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+
     def get_eagle_atten_dict(
         self,
         scheduler_output: "SchedulerOutput",
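
In the decode-only torchair path above, `_get_forward_metadata_across_dp_and_pad` pads every rank to one common batch size derived from the maximum token count across DP, which is what removes the redundant per-call-site padding. The sketch below shows that step in isolation; `select_torchair_padded_batch_size` here is a hypothetical stand-in (assumed to round up to the smallest captured graph batch size that fits), not the implementation from this file, and the batch-size list is invented for the example.

```python
import torch

# Hypothetical stand-in for self.torchair_graph_batch_sizes.
TORCHAIR_GRAPH_BATCH_SIZES = [8, 16, 32, 64, 128]

def select_torchair_padded_batch_size(num_tokens: int) -> int:
    # Assumed behaviour: pick the smallest captured graph batch size that
    # can hold num_tokens; fall back to the largest one otherwise.
    for size in TORCHAIR_GRAPH_BATCH_SIZES:
        if num_tokens <= size:
            return size
    return TORCHAIR_GRAPH_BATCH_SIZES[-1]

def pad_for_torchair(num_tokens_across_dp: torch.Tensor, dp_size: int):
    # Mirror of the decode-only branch: every rank gets padded to the same
    # graph batch size chosen from the global maximum.
    max_num_token = num_tokens_across_dp.max().item()
    padded = select_torchair_padded_batch_size(max_num_token)
    return padded, torch.full((dp_size, ), padded, dtype=torch.int32)

padded, across_dp = pad_for_torchair(torch.tensor([5, 12, 9, 3]), dp_size=4)
print(padded, across_dp.tolist())  # 16 [16, 16, 16, 16]
```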
@@ -1073,13 +1084,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
 
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
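
`graph_pad_size` computed at this call site is simply the number of filler token slots this rank needs to reach the DP-wide padded size. The snippet below is a generic illustration of applying such a pad count to a flat token buffer; the tensor names are illustrative and do not come from `_process_reqs`.

```python
import torch

total_num_scheduled_tokens = 9
padded_num_tokens_across_dp = 16
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens

# Illustrative flat buffer of scheduled token ids for this rank.
input_ids = torch.arange(total_num_scheduled_tokens, dtype=torch.int64)

if graph_pad_size > 0:
    # Append zero-filled slots so every DP rank presents the same batch
    # size to the captured torchair graph.
    padding = torch.zeros(graph_pad_size, dtype=input_ids.dtype)
    input_ids = torch.cat([input_ids, padding])

print(input_ids.shape)  # torch.Size([16])
```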
@@ -1712,15 +1722,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
