# Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py
#
1919
20+ from typing import Optional
21+
2022import torch
23+ import torch .distributed as dist
2124from vllm .config import VllmConfig
25+ from vllm .distributed .parallel_state import get_dp_group
2226
2327from vllm_ascend .worker .model_runner_v1 import NPUModelRunner
2428
@@ -27,3 +31,47 @@ class NPUTorchairModelRunner(NPUModelRunner):
2731
def __init__(self, vllm_config: VllmConfig, device: torch.device):
    """Construct the runner by delegating to the parent class constructor.

    Args:
        vllm_config: Configuration object, forwarded unchanged to the parent.
        device: Torch device, forwarded unchanged to the parent.
    """
    # No additional state is set up here; both arguments are passed
    # positionally to the base-class initialiser.
    super().__init__(vllm_config, device)
34+
def _get_forward_metadata_across_dp(
    self,
    num_tokens: int,
    with_prefill: bool,
    enable_dbo: bool = False,
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
    """Agree on token counts and mode flags across the data-parallel group.

    Args:
        num_tokens: Number of tokens scheduled on this rank (unpadded).
        with_prefill: Whether this rank's batch contains prefill work.
        enable_dbo: Local ``enable_dbo`` flag.

    Returns:
        A 4-tuple ``(maybe_padded_num_tokens, num_tokens_across_dp,
        with_prefill, enable_dbo)``:

        * ``dp_size == 1``: the locally computed token count (padded through
          ``select_torchair_padded_batch_size`` unless ``with_prefill``),
          ``None``, and the two input flags unchanged. No communication is
          performed.
        * otherwise, if any rank reported prefill: this rank's unpadded
          ``num_tokens`` and the per-rank unpadded counts.
        * otherwise: the maximum padded count over all ranks, and a tensor
          repeating that maximum ``dp_size`` times.

        In the multi-rank cases ``with_prefill`` is true when any rank set
        it, and ``enable_dbo`` is true only when every rank set it.
    """
    # Prefill batches keep their real size; other batches are padded locally.
    if with_prefill:
        local_padded = num_tokens
    else:
        local_padded = self.select_torchair_padded_batch_size(num_tokens)

    world = self.dp_size
    if world == 1:
        return local_padded, None, with_prefill, enable_dbo

    rank = self.dp_rank
    # Payload layout (each rank fills only its own slots, the rest stay 0):
    #   [0, world)        padded token count per rank
    #   [world, 2*world)  unpadded token count per rank
    #   [-2]              with_prefill
    #   [-1]              not enable_dbo
    payload = [0] * (2 * world)
    payload[rank] = local_padded
    payload[world + rank] = num_tokens
    payload.extend((with_prefill, not enable_dbo))
    forward_metadata = torch.tensor(payload,
                                    device="cpu",
                                    dtype=torch.int32)
    # all_reduce uses its default reduction op (sum, per torch.distributed),
    # so a flag slot is non-zero iff at least one rank contributed a 1.
    dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)

    any_prefill = bool(forward_metadata[-2])
    dbo_on_all_ranks = not bool(forward_metadata[-1])

    # NOTE: a rank that padded locally (no prefill of its own) must revert
    # that padding once some other rank turns out to be running prefill.
    if any_prefill:
        unpadded_counts = forward_metadata[world:2 * world]
        return num_tokens, unpadded_counts, any_prefill, dbo_on_all_ranks

    # NOTE: in torchair graph mode every rank pads its local token count up
    # to the maximum across the DP group; otherwise this is not necessary.
    global_max = torch.max(forward_metadata[:world]).item()
    uniform_counts = torch.tensor([global_max] * world,
                                  device="cpu",
                                  dtype=torch.int32)
    return global_max, uniform_counts, any_prefill, dbo_on_all_ranks
0 commit comments