
Commit 512a2c9

[2/N][Refactor] torchair model runner refactor
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent bc32acb · commit 512a2c9

File tree

2 files changed: +53 −28 lines

vllm_ascend/torchair/torchair_model_runner.py (+48 −0)
vllm_ascend/worker/model_runner_v1.py (+5 −28)


vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 48 additions & 0 deletions
@@ -17,8 +17,12 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 
+from typing import Optional
+
 import torch
+import torch.distributed as dist
 from vllm.config import VllmConfig
+from vllm.distributed.parallel_state import get_dp_group
 
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
@@ -27,3 +31,47 @@ class NPUTorchairModelRunner(NPUModelRunner):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)
+
+    def _get_forward_metadata_across_dp(
+        self,
+        num_tokens: int,
+        with_prefill: bool,
+        enable_dbo: bool = False,
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        if with_prefill:
+            maybe_padded_num_tokens = num_tokens
+        else:
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                num_tokens)
+        if self.dp_size == 1:
+            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+
+        num_tokens_across_dp = [0] * self.dp_size * 2
+        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
+        num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
+        forward_metadata = torch.tensor(num_tokens_across_dp +
+                                        [with_prefill, not enable_dbo],
+                                        device="cpu",
+                                        dtype=torch.int32)
+        dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
+        with_prefill = bool(forward_metadata[-2])
+
+        # NOTE: if with_prefill was False before the all_reduce and is True after it, we must revert the padding.
+        if with_prefill:
+            num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
+                                                    2]
+            maybe_padded_num_tokens = num_tokens
+        else:
+            num_tokens_across_dp = forward_metadata[:self.dp_size]
+
+        # NOTE: in torchair graph mode, local num_tokens must be padded up to
+        # `max_tokens_across_dp`; in other situations this is not necessary.
+        if not with_prefill:
+            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
+            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
+            forward_metadata[-1])
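The interesting trick in the new override is how it shares metadata across data-parallel ranks: every rank packs its padded count, its raw count, and two flags into one int32 CPU tensor that is zero everywhere except its own slots, so a single SUM all_reduce reconstructs every rank's values at once. Below is a minimal sketch of those semantics, simulated in-process instead of with dist.all_reduce and a real process group; the dp_size and all per-rank numbers are made up for illustration.

import torch

# Hypothetical DP world of 4 ranks with made-up token counts.
dp_size = 4
padded = [8, 16, 8, 32]   # per-rank padded counts (e.g. from a bucket table)
raw = [5, 13, 7, 30]      # per-rank unpadded scheduled-token counts
with_prefill = [False, True, False, False]
enable_dbo = [True, True, False, True]

# Each rank's contribution: zeros except its own two slots, plus the flags.
contribs = []
for r in range(dp_size):
    vec = [0] * (dp_size * 2)
    vec[r] = padded[r]                     # first half: padded counts
    vec[dp_size + r] = raw[r]              # second half: raw counts
    vec += [int(with_prefill[r]), int(not enable_dbo[r])]
    contribs.append(torch.tensor(vec, dtype=torch.int32))

# dist.all_reduce with the default SUM op is element-wise summation.
reduced = torch.stack(contribs).sum(dim=0)

print(reduced[:dp_size].tolist())             # [8, 16, 8, 32]
print(reduced[dp_size:2 * dp_size].tolist())  # [5, 13, 7, 30]
print(bool(reduced[-2]))       # True: at least one rank is prefilling
print(not bool(reduced[-1]))   # False: dbo stays on only if all ranks agree

Because the flags travel as sums, forward_metadata[-2] > 0 means "some rank has prefill" and forward_metadata[-1] > 0 means "some rank turned dbo off", which is why the method returns not bool(forward_metadata[-1]). It is also why the padding is reverted when with_prefill flips to True after the reduce: the prefill path runs eagerly, so the raw counts in the second half of the tensor are the ones that matter.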

vllm_ascend/worker/model_runner_v1.py

Lines changed: 5 additions & 28 deletions
@@ -571,16 +571,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
 
     def _get_forward_metadata_across_dp(
         self,
-        maybe_padded_num_tokens: int,
         num_tokens: int,
         with_prefill: bool,
         enable_dbo: bool = False,
     ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
         if self.dp_size == 1:
-            return maybe_padded_num_tokens, None, with_prefill, enable_dbo
+            return num_tokens, None, with_prefill, enable_dbo
 
         num_tokens_across_dp = [0] * self.dp_size * 2
-        num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
+        num_tokens_across_dp[self.dp_rank] = num_tokens
         num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
         forward_metadata = torch.tensor(num_tokens_across_dp +
                                         [with_prefill, not enable_dbo],
@@ -589,24 +588,13 @@ def _get_forward_metadata_across_dp(
         dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
         with_prefill = bool(forward_metadata[-2])
 
-        # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
         if with_prefill:
             num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
                                                     2]
-            maybe_padded_num_tokens = num_tokens
         else:
             num_tokens_across_dp = forward_metadata[:self.dp_size]
 
-        # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
-        # `max_tokens_across_dp`, in other situation it is not necessary.
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
-            num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
-                                                self.dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
+        return num_tokens, num_tokens_across_dp, with_prefill, not bool(
             forward_metadata[-1])
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
@@ -1108,14 +1096,9 @@ def _process_reqs(
                                              attn_state,
                                              total_num_scheduled_tokens)
 
-        maybe_padded_num_tokens = total_num_scheduled_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                total_num_scheduled_tokens)
         (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
          enable_dbo) = self._get_forward_metadata_across_dp(
-             maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
-             enable_dbo)
+             total_num_scheduled_tokens, with_prefill, enable_dbo)
         extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
 
         if self.torchair_graph_enabled and not with_prefill:
@@ -1791,15 +1774,9 @@ def _dummy_run(
         with_prefill: bool = False,
         is_torchair_compile: bool = False,
     ) -> torch.Tensor:
-        maybe_padded_num_tokens = num_tokens
-        if self.torchair_graph_enabled and not with_prefill:
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                num_tokens)
-
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
-         _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
-                                                   num_tokens, with_prefill,
+         _) = self._get_forward_metadata_across_dp(num_tokens, with_prefill,
                                                    False)
 
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
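Stepping back, the refactor's net effect is visible in these deletions: the shared NPUModelRunner no longer knows about torchair padding, callers pass only the raw token count, and NPUTorchairModelRunner layers the bucket padding on top by overriding _get_forward_metadata_across_dp. A condensed sketch of that override pattern, with simplified class names and a made-up bucket table standing in for select_torchair_padded_batch_size:

# Simplified stand-ins for NPUModelRunner / NPUTorchairModelRunner;
# the bucket sizes are hypothetical, not vllm-ascend's real capture sizes.
class BaseRunner:
    def maybe_pad(self, num_tokens: int, with_prefill: bool) -> int:
        # Eager path: no graph capture, hence no padding.
        return num_tokens

class TorchairRunner(BaseRunner):
    buckets = (8, 16, 32, 64)  # graph-captured batch sizes (hypothetical)

    def maybe_pad(self, num_tokens: int, with_prefill: bool) -> int:
        if with_prefill:
            return num_tokens  # prefill runs eagerly, keep the raw count
        # Decode: round up to the nearest captured batch size.
        return next((b for b in self.buckets if b >= num_tokens), num_tokens)

print(BaseRunner().maybe_pad(13, with_prefill=False))      # 13
print(TorchairRunner().maybe_pad(13, with_prefill=False))  # 16
print(TorchairRunner().maybe_pad(13, with_prefill=True))   # 13

Keeping the decision in the subclass is what lets this commit delete the torchair_graph_enabled branches from _process_reqs and _dummy_run.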
