Commit 1d5d653

[2/N][Refactor] torchair model runner refactor

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>

Parent: 292fb8f

File tree: 2 files changed (+41 −22 lines)

vllm_ascend/torchair/torchair_model_runner.py

Lines changed: 40 additions & 0 deletions
```diff
@@ -17,13 +17,53 @@
 # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
 #
 
+from typing import Optional
+
 import torch
 from vllm.config import VllmConfig
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
 
 
 class NPUTorchairModelRunner(NPUModelRunner):
 
     def __init__(self, vllm_config: VllmConfig, device: torch.device):
         super().__init__(vllm_config, device)
+
+    def _get_forward_metadata_across_dp_and_pad(
+            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
+    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
+        if self.dp_size == 1:
+            return num_tokens, None, with_prefill, enable_dbo
+
+        if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
+            num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return num_tokens, num_tokens_across_dp, True, enable_dbo
+
+        if self.is_kv_consumer and len(self.torchair_graph_batch_sizes
+                                       ) == 1 and not self.in_profile_run:
+            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
+            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
+                                                self.dp_size,
+                                                device="cpu",
+                                                dtype=torch.int32)
+            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
+
+        num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
+            num_tokens, with_prefill, enable_dbo)
+
+        if not with_prefill:
+            max_num_token = num_tokens_across_dp.max().item()
+            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
+                max_num_token)
+            num_tokens_across_dp = torch.full((self.dp_size, ),
+                                              maybe_padded_num_tokens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+        else:
+            maybe_padded_num_tokens = num_tokens
+
+        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
```
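
For context on the decode branch above: every DP rank pads its token count up to a common captured graph batch size so the compiled torchair graph sees stable shapes. Below is a minimal, self-contained sketch of that bucketing step. `select_padded_batch_size` is a hypothetical stand-in for `NPUModelRunner.select_torchair_padded_batch_size`, whose real implementation is not shown in this diff; the round-up-to-nearest-bucket rule is an assumption.

```python
import bisect

import torch


def select_padded_batch_size(num_tokens: int,
                             graph_batch_sizes: list[int]) -> int:
    # Hypothetical stand-in: round up to the smallest captured graph
    # batch size that fits num_tokens.
    sizes = sorted(graph_batch_sizes)
    idx = bisect.bisect_left(sizes, num_tokens)
    # Fall back to the largest captured size if nothing is big enough.
    return sizes[min(idx, len(sizes) - 1)]


# Decode-only case from the override: all DP ranks pad to the same bucket.
dp_size = 4
graph_batch_sizes = [8, 16, 32, 64]
num_tokens_per_rank = [5, 12, 12, 30]

max_num_token = max(num_tokens_per_rank)  # 30
padded = select_padded_batch_size(max_num_token, graph_batch_sizes)  # 32
num_tokens_across_dp = torch.full((dp_size, ), padded,
                                  dtype=torch.int32, device="cpu")
print(padded, num_tokens_across_dp)
# 32 tensor([32, 32, 32, 32], dtype=torch.int32)
```

So ranks with 5, 12, 12, and 30 decode tokens all report the 32-token bucket, which is what the `torch.full` call in the override encodes.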

vllm_ascend/worker/model_runner_v1.py

Lines changed: 1 addition & 22 deletions
```diff
@@ -623,30 +623,9 @@ def _get_forward_metadata_across_dp_and_pad(
                                                 dtype=torch.int32)
             return num_tokens, num_tokens_across_dp, True, enable_dbo
 
-        if self.is_kv_consumer and self.torchair_graph_enabled and len(
-                self.torchair_graph_batch_sizes
-        ) == 1 and not self.in_profile_run:
-            max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
-            num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
-                                                self.dp_size,
-                                                device="cpu",
-                                                dtype=torch.int32)
-            return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
-
-        maybe_padded_num_tokens = num_tokens
         num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
             num_tokens, with_prefill, enable_dbo)
-
-        if self.torchair_graph_enabled and not with_prefill:
-            max_num_token = num_tokens_across_dp.max().item()
-            maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
-                max_num_token)
-            num_tokens_across_dp = torch.full((self.dp_size, ),
-                                              maybe_padded_num_tokens,
-                                              dtype=torch.int32,
-                                              device="cpu")
-
-        return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
+        return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
 
     def _check_dbo_is_valid(self, query_lens: torch.Tensor,
                             attn_state: AscendAttentionState,
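
Net effect of the two hunks: the torchair-specific branches (the single-bucket KV-consumer shortcut and the decode-path padding, both previously guarded by `torchair_graph_enabled`) move out of the base `NPUModelRunner` into the `NPUTorchairModelRunner` override, and the base method now returns token counts unpadded. A toy analogue of the resulting split, runnable on its own; the class and method names mirror the diff, but the bodies are deliberately simplified (a fixed 64-token bucket stands in for the real graph batch sizes, and the `dp_size == 1` and KV-transfer shortcuts are omitted):

```python
from typing import Optional

import torch


class BaseRunner:
    """Toy analogue of NPUModelRunner after this commit: it syncs token
    counts across DP ranks but never pads them."""

    def __init__(self, dp_size: int):
        self.dp_size = dp_size

    def _get_forward_metadata_across_dp_and_pad(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        num_tokens_across_dp = torch.full((self.dp_size, ), num_tokens,
                                          dtype=torch.int32, device="cpu")
        return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo


class TorchairRunner(BaseRunner):
    """Toy analogue of NPUTorchairModelRunner: reuses the base sync, then
    pads decode batches up to a fixed graph bucket."""

    GRAPH_BUCKET = 64  # stand-in for select_torchair_padded_batch_size

    def _get_forward_metadata_across_dp_and_pad(
            self, num_tokens: int, with_prefill: bool, enable_dbo: bool
    ) -> tuple[int, Optional[torch.Tensor], bool, bool]:
        _, num_tokens_across_dp, with_prefill, enable_dbo = \
            super()._get_forward_metadata_across_dp_and_pad(
                num_tokens, with_prefill, enable_dbo)
        if not with_prefill:
            padded = self.GRAPH_BUCKET
            num_tokens_across_dp = torch.full((self.dp_size, ), padded,
                                              dtype=torch.int32, device="cpu")
            return padded, num_tokens_across_dp, with_prefill, enable_dbo
        return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo


print(BaseRunner(2)._get_forward_metadata_across_dp_and_pad(30, False, False))
print(TorchairRunner(2)._get_forward_metadata_across_dp_and_pad(30, False, False))
# The base runner returns 30 unpadded; the torchair subclass pads the
# decode batch to the 64-token bucket on every DP rank.
```

The design payoff is that the base runner no longer needs `torchair_graph_enabled` checks in this path: the torchair behavior lives entirely in the subclass that needs it.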
