Commit 7b133f6

[Main][Refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding.
Signed-off-by: yx0716 <jinyx1007@foxmail.com>
1 parent 8cf97d8 commit 7b133f6

File tree

1 file changed: +70 -46 lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 70 additions & 46 deletions
@@ -346,13 +346,35 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
             torch._logging.set_logs(
                 recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)

+        self.check_batch_sizes_consistency()
         # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
         self.in_profile_run = False

         # kv role
         self.is_kv_producer = False
+        self.is_kv_consumer = False
         if vllm_config.kv_transfer_config is not None:
             self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
+            self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
+
+    def check_batch_sizes_consistency(self) -> None:
+        if not dist.is_initialized():
+            return
+
+        local = torch.tensor(self.torchair_graph_batch_sizes, device="cpu", dtype=torch.int32)
+        gathered_graph_batch_size = local.clone()
+        dist.all_reduce(gathered_graph_batch_size, group=get_dp_group().cpu_group)
+        expected = local * self.dp_size
+
+        if not torch.equal(gathered_graph_batch_size, expected):
+            diff_idxs = (gathered_graph_batch_size != expected).nonzero(as_tuple=False).flatten().tolist()
+            raise AssertionError(
+                f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
+                f"Local (rank {self.dp_rank}): {local.tolist()}\n"
+                f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
+                f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
+            )
+

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
@@ -568,44 +590,55 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         self.input_batch.refresh_sampling_metadata()

     def _get_forward_metadata_across_dp(
-        self,
-        maybe_padded_num_tokens: int,
-        num_tokens: int,
-        with_prefill: bool,
-        enable_dbo: bool = False,
+            self,
+            num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
+    ) -> tuple[torch.Tensor, bool, bool]:
+
+        # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
+        num_tokens_across_dp = torch.zeros(self.dp_size + 2, dtype=torch.int32, device="cpu")
+        num_tokens_across_dp[self.dp_rank] = num_tokens
+        num_tokens_across_dp[-2] = int(with_prefill)
+        num_tokens_across_dp[-1] = int(not enable_dbo)
+        dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
+        with_prefill = bool(num_tokens_across_dp[-2])
+        enable_dbo = not bool(num_tokens_across_dp[-1])
+        num_tokens_across_dp = num_tokens_across_dp[:-2]
+        return num_tokens_across_dp, with_prefill, enable_dbo
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int,
+            with_prefill: bool, enable_dbo: bool
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo

-        num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
-        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
-        forward_metadata = torch.tensor(num_tokens_across_dp +
-                                        [with_prefill, not enable_dbo],
-                                        device="cpu",
-                                        dtype=torch.int32)
-        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
-        with_prefill = bool(forward_metadata[-2])
-
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
-        if with_prefill:
-            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
-                                                    2]
-            maybe_padded_num_tokens = num_tokens
-        else:
-            num_tokens_across_dp = forward_metadata[:self.dp_size]
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo

-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+        if self.is_kv_consumer and self.torchair_graph_enabled and len(
+                self.torchair_graph_batch_sizes
+        ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
                                                 self.dp_size,
                                                 device="cpu",
                                                 dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo

-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
-            forward_metadata[-1])
+        maybe_padded_num_tokens = num_tokens
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(num_tokens, with_prefill, enable_dbo)
+
+        if self.torchair_graph_enabled and not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size,), maybe_padded_num_tokens, dtype=torch.int32, device="cpu")
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo

     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
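
Note: the refactored _get_forward_metadata_across_dp above packs everything the ranks must agree on into one int32 CPU tensor of dp_size + 2 slots: each rank writes its unpadded token count into its own slot, with_prefill is summed so any non-zero result means "some rank is prefilling" (an OR), and enable_dbo is stored negated so the sum acts as an AND. A single all_reduce replaces the old 2 * dp_size-element exchange, and _get_forward_metadata_across_dp_and_pad only pads to a torchair graph batch size afterwards, when graph mode is on and no rank has prefill work. A simplified sketch of the pack/all-reduce/unpack step as a free function (hypothetical name and setup, assuming an initialised CPU-capable process group, not the vllm-ascend API):

    # Sketch of the packed-metadata all_reduce used by
    # _get_forward_metadata_across_dp (simplified; hypothetical helper,
    # assumes dist.init_process_group with a CPU backend was already called).
    import torch
    import torch.distributed as dist

    def exchange_forward_metadata(num_tokens: int, with_prefill: bool,
                                  enable_dbo: bool, dp_size: int, dp_rank: int):
        buf = torch.zeros(dp_size + 2, dtype=torch.int32)
        buf[dp_rank] = num_tokens      # each rank fills only its own slot
        buf[-2] = int(with_prefill)    # summed flag: >0 means some rank prefills
        buf[-1] = int(not enable_dbo)  # negated flag: >0 means some rank disables DBO
        dist.all_reduce(buf)           # one collective carries all metadata
        with_prefill = bool(buf[-2])   # OR across ranks
        enable_dbo = not bool(buf[-1]) # AND across ranks
        return buf[:-2], with_prefill, enable_dbo
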
@@ -1106,16 +1139,13 @@ def _process_reqs(
                                               attn_state,
                                               total_num_scheduled_tokens)

-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
+        enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
+                                              attn_state,
+                                              total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
-         enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+         enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
-
         if self.torchair_graph_enabled and not with_prefill:
             graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens

@@ -1759,16 +1789,10 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
-                                                   False)
+         _) = self._get_forward_metadata_across_dp_and_pad(
+             num_tokens, with_prefill, False)

         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
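
Note: both call sites (_process_reqs and _dummy_run) now hand the raw token count to _get_forward_metadata_across_dp_and_pad and let it round decode batches up via select_torchair_padded_batch_size, whose implementation is not part of this diff. Purely as an assumption about its behaviour, a rounding helper of that kind usually looks like the following sketch:

    # Hypothetical stand-in for select_torchair_padded_batch_size; the real
    # method is not shown in this diff. Rounds a token count up to the
    # nearest captured torchair graph batch size.
    import bisect

    def select_padded_batch_size(num_tokens: int, graph_batch_sizes: list[int]) -> int:
        sizes = sorted(graph_batch_sizes)
        idx = bisect.bisect_left(sizes, num_tokens)
        # Fall back to the largest captured size if num_tokens exceeds them all.
        return sizes[idx] if idx < len(sizes) else sizes[-1]
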
