
Commit d45d021

[BugFix] Fix data parallel initialization and device management in NPUWorker and DPEngineCoreProc
Co-authored-by: rjg-lyh <rjg-lyh@users.noreply.github.com>
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 17f05b1 commit d45d021

2 files changed: 27 additions & 4 deletions


vllm_ascend/patch/platform/patch_common/patch_distributed.py

Lines changed: 19 additions & 3 deletions
@@ -25,7 +25,8 @@
                                                 _get_default_timeout,
                                                 is_nccl_available)
 from torch.distributed.rendezvous import rendezvous
-from vllm.config import ParallelConfig
+from vllm.config import ParallelConfig, VllmConfig
+from vllm.v1.engine.core import DPEngineCoreProc


 def ascend_destroy_model_parallel():
@@ -171,7 +172,7 @@ def parallel_config_get_dp_port(self) -> int:
     return port


-def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
+def stateless_init_dp_group(self) -> "ProcessGroup":
     # TODO(Yizhou): Currently we have to set the backend to gloo
     # because in vllm.config.ParallelConfig.has_unfinished_dp the
     # device is set to cpu. We need to fix this in the future.
@@ -187,6 +188,21 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group


+def _init_data_parallel(self, vllm_config: VllmConfig):
+    # Configure NPUs and stateless process group for data parallel.
+    dp_rank = vllm_config.parallel_config.data_parallel_rank
+    dp_size = vllm_config.parallel_config.data_parallel_size
+    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
+
+    assert dp_size > 1
+    assert 0 <= local_dp_rank <= dp_rank < dp_size
+
+    self.local_dp_rank = local_dp_rank
+    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
+    self.current_wave = 0
+
+
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
+DPEngineCoreProc._init_data_parallel = _init_data_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
+ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
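
The final lines of this hunk apply the patch by rebinding attributes on the vLLM classes at import time. Below is a minimal, self-contained sketch of that monkey-patching technique using toy names (EngineCore and _patched_init_data_parallel are illustrative only, not the real vLLM API); the key point is that the rebinding must happen before any engine-core instance is created, which is why vllm_ascend performs it when the patch module is imported.

class EngineCore:
    def _init_data_parallel(self, cfg):
        # Stand-in for the upstream, CUDA-oriented behaviour.
        return "default"

def _patched_init_data_parallel(self, cfg):
    # Stand-in for the NPU-friendly replacement defined in the patch above.
    return "patched"

# Rebinding the class attribute affects every instance created afterwards.
EngineCore._init_data_parallel = _patched_init_data_parallel

assert EngineCore()._init_data_parallel(None) == "patched"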

vllm_ascend/worker/worker_v1.py

Lines changed: 8 additions & 1 deletion
@@ -72,6 +72,13 @@ def __init__(
                          rank=rank,
                          distributed_init_method=distributed_init_method,
                          is_driver_worker=is_driver_worker)
+
+        # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in
+        # vllm_ascend, we need to set the device id manually.
+        local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local
+        world_size = self.vllm_config.parallel_config.world_size
+        self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank
+
         # Try to import mindie_turbo to accelerate vLLM inference.
         try_register_lib(
             "mindie_turbo",
@@ -98,7 +105,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:

     def init_device(self):
         if self.device_config.device.type == "npu":
-            self.device = torch.device(f"npu:{self.local_rank}")
+            self.device = torch.device(f"npu:{self.local_rank_across_dp}")
             NPUPlatform.set_device(self.device)
             NPUPlatform.empty_cache()
             self.init_npu_memory = NPUPlatform.mem_get_info()[0]
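
To make the new device indexing concrete, here is a small worked example of the local_rank_across_dp arithmetic. The scenario of two DP replicas with a per-replica world size of 2 (e.g. TP=2) is assumed purely for illustration; the variable names mirror the diff.

world_size = 2  # workers per DP replica (e.g. TP=2), assumed for this example

for local_dp_rank in range(2):            # DP replica index on this host
    for local_rank in range(world_size):  # worker index within the replica
        local_rank_across_dp = local_dp_rank * world_size + local_rank
        print(f"replica {local_dp_rank}, worker {local_rank} -> npu:{local_rank_across_dp}")

# replica 0, worker 0 -> npu:0
# replica 0, worker 1 -> npu:1
# replica 1, worker 0 -> npu:2
# replica 1, worker 1 -> npu:3

Without the offset, both replicas would address npu:0 and npu:1, since ASCEND_RT_VISIBLE_DEVICES is not used to partition the devices per replica; the offset gives every worker a distinct physical NPU.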
