From f64c8225c8c3ea0ba23bbdeb6fac129caba89a2d Mon Sep 17 00:00:00 2001 From: Yizhou Liu Date: Sat, 17 May 2025 16:18:25 +0800 Subject: [PATCH] [BugFix] Fix data parallel initialization and device management in NPUWorker and DPEngineCoreProc Co-authored-by: rjg-lyh Signed-off-by: Yizhou Liu --- .../patch/platform/patch_0_9_0/__init__.py | 1 + .../platform/patch_0_9_0/patch_distributed.py | 116 +++++++++++++++ .../patch_common/patch_distributed.py | 135 +++--------------- vllm_ascend/platform.py | 45 ++++++ vllm_ascend/worker/worker_v1.py | 9 +- 5 files changed, 191 insertions(+), 115 deletions(-) create mode 100644 vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py diff --git a/vllm_ascend/patch/platform/patch_0_9_0/__init__.py b/vllm_ascend/patch/platform/patch_0_9_0/__init__.py index 116c73c06c..f0ac16236a 100644 --- a/vllm_ascend/patch/platform/patch_0_9_0/__init__.py +++ b/vllm_ascend/patch/platform/patch_0_9_0/__init__.py @@ -14,3 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import vllm_ascend.patch.platform.patch_0_9_0.patch_distributed # noqa diff --git a/vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py b/vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py new file mode 100644 index 0000000000..d468326bd3 --- /dev/null +++ b/vllm_ascend/patch/platform/patch_0_9_0/patch_distributed.py @@ -0,0 +1,116 @@ +import torch +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import (Backend, PrefixStore, + _get_default_timeout, + is_nccl_available) +from torch.distributed.rendezvous import rendezvous +from vllm.distributed import utils + + +def stateless_init_torch_distributed_process_group( + host: str, port: int, rank: int, world_size: int, + backend: str) -> ProcessGroup: + """ + A replacement for `torch.distributed.init_process_group` that does not + pollute the global state. The created ProcessGroup object can be used for + some operations such as `allreduce`, because it does not depend on the + global rank. However, some operations such as `broadcast` cannot be used + because it depends on the global rank. + + # TODO: ask for help from PyTorch team if we need the `broadcast` operation. + + This function is useful when we are not sure about the total number of + processes in the process group. For example, we may have process + 1, 2, ..., 8 who want to communicate, and process 9 might be the same + process as process 1, or it might be a different process; process 10 + might be the same process as process 5, or it might be a different process. + In this case, how can we reliably form a communication channel within + process 9 and 10, without affecting the communication channel within + process 1, 2, ..., 8? + + One possible solution is to figure out if process 9 and 10 are the same + as process 1 and 5 beforehand, and then form a communication channel + based on the information, adjusting the ranks and world_size etc. However, + figuring out the information is not always easy, and it will interfere + with the main communication channel. + + Our solution is to always form a communication channel with process 1, 2, + ..., 8, and then use this function to form another communication channel + with process 9 and 10. This way, regardless of whether process 9 and 10 + are the same as process 1 and 5, the main communication channel is + always formed with process 1, 2, ..., 8, and the additional communication + channel is formed with process 9 and 10. + """ + init_method = f"tcp://{host}:{port}" + backend = Backend(backend) # it is basically string + timeout = _get_default_timeout(backend) + + store, rank, world_size = next( + rendezvous(init_method, rank, world_size, timeout=timeout)) + store.set_timeout(timeout) + + group_rank = rank + group_size = world_size + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + prefix_store = PrefixStore(init_method, store) + + # TODO(Yizhou): The reason we need to set options while vllm does not + # seems to be related to the version of PyTorch. In the latest version, + # there is no need to set options. While in the older version, 2.5.1 + # specifically, we need to set options. + options = ProcessGroup.Options(backend=backend) + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + options, + ) + if backend == "gloo": + from torch.distributed.distributed_c10d import ProcessGroupGloo + backend_class = ProcessGroupGloo(prefix_store, + group_rank, + group_size, + timeout=timeout) + backend_type = ProcessGroup.BackendType.GLOO + device = torch.device("cpu") + elif backend == "nccl": + assert is_nccl_available() + from torch.distributed.distributed_c10d import ProcessGroupNCCL + + backend_options = ProcessGroupNCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, + backend_options) + backend_type = ProcessGroup.BackendType.NCCL + device = torch.device("cuda") + elif backend == "hccl": + from torch.distributed import is_hccl_available + assert is_hccl_available() + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, + backend_options) + device = torch.device("npu") + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + pg._register_backend(device, backend_type, backend_class) + return pg + else: + raise RuntimeError(f"Unsupported torch distributed backend: {backend}") + + # TODO(Yizhou): Like we mentioned above, _set_default_backend is not + # implemented in the 2.5.1 version of PyTorch. But we need to set it + # after the latest version is released. + # pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + + pg._register_backend(device, backend_type, backend_class) + + return pg + + +utils.stateless_init_torch_distributed_process_group = stateless_init_torch_distributed_process_group diff --git a/vllm_ascend/patch/platform/patch_common/patch_distributed.py b/vllm_ascend/patch/platform/patch_common/patch_distributed.py index 0b88264b4d..21f846d319 100644 --- a/vllm_ascend/patch/platform/patch_common/patch_distributed.py +++ b/vllm_ascend/patch/platform/patch_common/patch_distributed.py @@ -17,16 +17,14 @@ # Adapted from vllm/model_executor/models/qwen2_vl.py # This file is a part of the vllm-ascend project. -import torch import vllm import vllm.distributed import vllm.envs as envs from torch.distributed import ProcessGroup -from torch.distributed.distributed_c10d import (Backend, PrefixStore, - _get_default_timeout, - is_nccl_available) -from torch.distributed.rendezvous import rendezvous -from vllm.config import ParallelConfig +from vllm.config import ParallelConfig, VllmConfig +from vllm.distributed.utils import \ + stateless_init_torch_distributed_process_group +from vllm.v1.engine.core import DPEngineCoreProc def ascend_destroy_model_parallel(): @@ -48,112 +46,6 @@ def ascend_destroy_model_parallel(): destory_ascend_model_parallel() -def stateless_init_torch_distributed_process_group( - host: str, port: int, rank: int, world_size: int, - backend: str) -> ProcessGroup: - """ - A replacement for `torch.distributed.init_process_group` that does not - pollute the global state. The created ProcessGroup object can be used for - some operations such as `allreduce`, because it does not depend on the - global rank. However, some operations such as `broadcast` cannot be used - because it depends on the global rank. - - # TODO: ask for help from PyTorch team if we need the `broadcast` operation. - - This function is useful when we are not sure about the total number of - processes in the process group. For example, we may have process - 1, 2, ..., 8 who want to communicate, and process 9 might be the same - process as process 1, or it might be a different process; process 10 - might be the same process as process 5, or it might be a different process. - In this case, how can we reliably form a communication channel within - process 9 and 10, without affecting the communication channel within - process 1, 2, ..., 8? - - One possible solution is to figure out if process 9 and 10 are the same - as process 1 and 5 beforehand, and then form a communication channel - based on the information, adjusting the ranks and world_size etc. However, - figuring out the information is not always easy, and it will interfere - with the main communication channel. - - Our solution is to always form a communication channel with process 1, 2, - ..., 8, and then use this function to form another communication channel - with process 9 and 10. This way, regardless of whether process 9 and 10 - are the same as process 1 and 5, the main communication channel is - always formed with process 1, 2, ..., 8, and the additional communication - channel is formed with process 9 and 10. - """ - init_method = f"tcp://{host}:{port}" - backend = Backend(backend) # it is basically string - timeout = _get_default_timeout(backend) - - store, rank, world_size = next( - rendezvous(init_method, rank, world_size, timeout=timeout)) - store.set_timeout(timeout) - - group_rank = rank - group_size = world_size - - # Use a PrefixStore to avoid accidental overrides of keys used by - # different systems (e.g. RPC) in case the store is multi-tenant. - prefix_store = PrefixStore(init_method, store) - - # TODO(Yizhou): The reason we need to set options while vllm does not - # seems to be related to the version of PyTorch. In the latest version, - # there is no need to set options. While in the older version, 2.5.1 - # specifically, we need to set options. - options = ProcessGroup.Options(backend=backend) - pg: ProcessGroup = ProcessGroup( - prefix_store, - group_rank, - group_size, - options, - ) - if backend == "gloo": - from torch.distributed.distributed_c10d import ProcessGroupGloo - backend_class = ProcessGroupGloo(prefix_store, - group_rank, - group_size, - timeout=timeout) - backend_type = ProcessGroup.BackendType.GLOO - device = torch.device("cpu") - elif backend == "nccl": - assert is_nccl_available() - from torch.distributed.distributed_c10d import ProcessGroupNCCL - - backend_options = ProcessGroupNCCL.Options() - backend_options._timeout = timeout - - backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size, - backend_options) - backend_type = ProcessGroup.BackendType.NCCL - device = torch.device("cuda") - elif backend == "hccl": - from torch.distributed import is_hccl_available - assert is_hccl_available() - from torch_npu._C._distributed_c10d import ProcessGroupHCCL - backend_options = ProcessGroupHCCL.Options() - backend_options._timeout = timeout - backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, - backend_options) - device = torch.device("npu") - backend_class._set_sequence_number_for_group() - backend_type = ProcessGroup.BackendType.CUSTOM - pg._register_backend(device, backend_type, backend_class) - return pg - else: - raise RuntimeError(f"Unsupported torch distributed backend: {backend}") - - # TODO(Yizhou): Like we mentioned above, _set_default_backend is not - # implemented in the 2.5.1 version of PyTorch. But we need to set it - # after the latest version is released. - # pg._set_default_backend(backend_type) - backend_class._set_sequence_number_for_group() - - pg._register_backend(device, backend_type, backend_class) - - return pg - - def parallel_config_get_dp_port(self) -> int: """ We might need to initialize process groups in multiple @@ -171,7 +63,7 @@ def parallel_config_get_dp_port(self) -> int: return port -def ascend_stateless_init_dp_group(self) -> "ProcessGroup": +def stateless_init_dp_group(self) -> "ProcessGroup": # TODO(Yizhou): Currently we have to set the backend to gloo # because in vllm.config.ParallelConfig.has_unfinished_dp the # device is set to cpu. We need to fix this in the future. @@ -187,6 +79,21 @@ def ascend_stateless_init_dp_group(self) -> "ProcessGroup": return dp_group +def _init_data_parallel(self, vllm_config: VllmConfig): + # Configure NPUs and stateless process group for data parallel. + dp_rank = vllm_config.parallel_config.data_parallel_rank + dp_size = vllm_config.parallel_config.data_parallel_size + local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local + + assert dp_size > 1 + assert 0 <= local_dp_rank <= dp_rank < dp_size + + self.local_dp_rank = local_dp_rank + self.dp_group = vllm_config.parallel_config.stateless_init_dp_group() + self.current_wave = 0 + + vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel +DPEngineCoreProc._init_data_parallel = _init_data_parallel ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port -ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group +ParallelConfig.stateless_init_dp_group = stateless_init_dp_group diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index 647fefbe0e..0c29b03f61 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -17,10 +17,13 @@ import logging import os +from datetime import timedelta from typing import TYPE_CHECKING, Optional, Tuple import torch import vllm.envs as envs +from torch.distributed import ProcessGroup +from torch.distributed.distributed_c10d import PrefixStore from vllm.logger import logger from vllm.platforms import Platform, PlatformEnum @@ -249,3 +252,45 @@ def get_piecewise_backend_cls(cls) -> str: Get piecewise backend class for piecewise graph. """ return "vllm_ascend.compilation.piecewise_backend.NPUPiecewiseBackend" # noqa + + @classmethod + def stateless_init_device_torch_dist_pg( + cls, + backend: str, + prefix_store: PrefixStore, + group_rank: int, + group_size: int, + timeout: timedelta, + ) -> ProcessGroup: + from torch.distributed import is_hccl_available + from torch_npu._C._distributed_c10d import ProcessGroupHCCL + + assert is_hccl_available() + + # TODO(Yizhou): The reason we need to set options while vllm does not + # seems to be related to the version of PyTorch. In the latest version, + # there is no need to set options. While in the older version, 2.5.1 + # specifically, we need to set options. + options = ProcessGroup.Options(backend=backend) + pg: ProcessGroup = ProcessGroup( + prefix_store, + group_rank, + group_size, + options, + ) + + backend_options = ProcessGroupHCCL.Options() + backend_options._timeout = timeout + + backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size, + backend_options) + device = torch.device("npu") + # TODO(Yizhou): Like we mentioned above, _set_default_backend is not + # implemented in the 2.5.1 version of PyTorch. But we need to set it + # after the latest version is released. + # pg._set_default_backend(backend_type) + backend_class._set_sequence_number_for_group() + backend_type = ProcessGroup.BackendType.CUSTOM + + pg._register_backend(device, backend_type, backend_class) + return pg diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 4d01f84d38..89c6d4b0dc 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -76,6 +76,13 @@ def __init__( rank=rank, distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker) + + # NOTE(Yizhou): Since we do not set ASCEND_RT_VISIBLE_DEVICES in + # vllm_ascend, we need to set the device id manually. + local_dp_rank = self.vllm_config.parallel_config.data_parallel_rank_local + world_size = self.vllm_config.parallel_config.world_size + self.local_rank_across_dp = local_dp_rank * world_size + self.local_rank + # Try to import mindie_turbo to accelerate vLLM inference. try_register_lib( "mindie_turbo", @@ -102,7 +109,7 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None: def init_device(self): if self.device_config.device.type == "npu": - self.device = torch.device(f"npu:{self.local_rank}") + self.device = torch.device(f"npu:{self.local_rank_across_dp}") NPUPlatform.set_device(self.device) NPUPlatform.empty_cache() self.init_npu_memory = NPUPlatform.mem_get_info()[0]