
Commit ca23f38

yx0716 authored and committed
[Main][Refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding.
Signed-off-by: yx0716 <jinyx1007@foxmail.com>
1 parent 0190b68 commit ca23f38

File tree

1 file changed: +80 -43 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 80 additions & 43 deletions
@@ -337,6 +337,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False

@@ -345,6 +347,33 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
 
+    def _check_dbo_is_valid(self, query_lens, attn_state, num_tokens) -> bool:
+        """Reserved DBO validation interface, not actually implemented yet."""
+        # TODO: Add real DBO validation later
+        return False
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes,
+                             device="cpu",
+                             dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size,
+                        group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(
+                as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
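The check_batch_sizes_consistency() method added above relies on a simple invariant: if every DP rank holds the same torchair graph batch sizes, an element-wise all_reduce sum equals the local list multiplied by the DP world size. Below is a minimal standalone sketch of that pattern, assuming an already-initialized torch.distributed process group with a CPU-capable (gloo) backend; the names graph_batch_sizes and dp_size are illustrative parameters, not identifiers from the repository.

import torch
import torch.distributed as dist

def check_graph_batch_sizes_consistency(graph_batch_sizes: list, dp_size: int) -> None:
    # Each rank contributes its local list; after an element-wise SUM across
    # ranks, the result must equal local * dp_size if all ranks agree.
    local = torch.tensor(graph_batch_sizes, dtype=torch.int32, device="cpu")
    summed = local.clone()
    dist.all_reduce(summed)  # defaults to SUM over the default process group
    expected = local * dp_size
    if not torch.equal(summed, expected):
        diff_idxs = (summed != expected).nonzero(as_tuple=False).flatten().tolist()
        raise AssertionError(
            f"graph batch sizes differ across DP ranks at indices {diff_idxs}")

Summing and comparing against local * dp_size flags any rank whose list differs without needing an all_gather.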
@@ -559,44 +588,58 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()
 
     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+            self, num_tokens: int, with_prefill: bool,
+            enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2,
+                                           dtype=torch.int32,
+                                           device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
 
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def get_eagle_atten_dict(
         self,
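The refactored _get_forward_metadata_across_dp above packs everything the DP ranks must agree on into one CPU int32 tensor of length dp_size + 2: one slot per rank for its token count, plus two flag slots whose values are summed by the all_reduce (any rank with a prefill forces with_prefill on; any rank that cannot run DBO forces enable_dbo off). The following is a hedged, self-contained sketch of that packing trick; the process-group setup is assumed, and dp_rank/dp_size are passed explicitly only to keep the example free of the runner object.

import torch
import torch.distributed as dist

def exchange_dp_forward_metadata(num_tokens: int, with_prefill: bool,
                                 enable_dbo: bool, dp_rank: int, dp_size: int):
    # One collective carries the per-rank token counts plus two global flags.
    buf = torch.zeros(dp_size + 2, dtype=torch.int32, device="cpu")
    buf[dp_rank] = num_tokens        # this rank's token count in its own slot
    buf[-2] = int(with_prefill)      # sum > 0  <=>  some rank has a prefill
    buf[-1] = int(not enable_dbo)    # sum > 0  <=>  some rank must disable DBO
    dist.all_reduce(buf)             # element-wise SUM across all DP ranks
    with_prefill = bool(buf[-2].item())
    enable_dbo = not bool(buf[-1].item())
    return buf[:-2], with_prefill, enable_dbo

Folding the flags into the same tensor is what lets the runner drop the old 2 * dp_size layout (padded and unpadded counts side by side) in favour of a single exchange, with padding deferred to _get_forward_metadata_across_dp_and_pad.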
@@ -1073,13 +1116,12 @@ def _process_reqs(
             AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding
         ]
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
 
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
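Both the removed lines above and the new helper rely on select_torchair_padded_batch_size, which is not part of this diff. As a rough illustration of the kind of padding it performs (an assumption about its behaviour, not the repository's implementation), a token count is rounded up to the smallest pre-captured torchair graph batch size that can hold it:

def select_padded_batch_size(num_tokens: int, graph_batch_sizes: list) -> int:
    # Hypothetical helper: round up to the smallest captured graph size that
    # fits, falling back to the largest one if the request exceeds them all.
    for size in sorted(graph_batch_sizes):
        if num_tokens <= size:
            return size
    return max(graph_batch_sizes)

In the refactored flow this rounding happens once, on the maximum token count across ranks, so every DP rank pads to the same captured graph size instead of each rank padding locally before the exchange.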
@@ -1712,15 +1754,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, num_tokens, with_prefill, False)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
