 # Adapted from vllm/model_executor/models/qwen2_vl.py
 # This file is a part of the vllm-ascend project.
 
+import torch
 import vllm
 import vllm.distributed
 import vllm.envs as envs
 from torch.distributed import ProcessGroup
-from vllm.config import ParallelConfig, VllmConfig
-from vllm.distributed.utils import \
-    stateless_init_torch_distributed_process_group
-from vllm.v1.engine.core import DPEngineCoreProc
+from torch.distributed.distributed_c10d import (Backend, PrefixStore,
+                                                 _get_default_timeout,
+                                                 is_nccl_available)
+from torch.distributed.rendezvous import rendezvous
+from vllm.config import ParallelConfig
 
 
 def ascend_destroy_model_parallel():
@@ -46,6 +48,112 @@ def ascend_destroy_model_parallel():
     destory_ascend_model_parallel()
 
 
+def stateless_init_torch_distributed_process_group(
+        host: str, port: int, rank: int, world_size: int,
+        backend: str) -> ProcessGroup:
+    """
+    A replacement for `torch.distributed.init_process_group` that does not
+    pollute the global state. The created ProcessGroup object can be used for
+    some operations such as `allreduce`, because they do not depend on the
+    global rank. However, operations such as `broadcast` cannot be used,
+    because they depend on the global rank.
+
+    # TODO: ask for help from the PyTorch team if we need the `broadcast` operation.
+
+    This function is useful when we are not sure about the total number of
+    processes in the process group. For example, processes 1, 2, ..., 8 may
+    want to communicate, and process 9 might be the same process as process 1
+    or a different one; process 10 might be the same process as process 5 or
+    a different one. In this case, how can we reliably form a communication
+    channel between processes 9 and 10 without affecting the communication
+    channel among processes 1, 2, ..., 8?
+
+    One possible solution is to figure out beforehand whether processes 9 and
+    10 are the same as processes 1 and 5, and then form a communication
+    channel based on that information, adjusting the ranks, world_size, etc.
+    However, figuring this out is not always easy, and it would interfere
+    with the main communication channel.
+
+    Our solution is to always form a communication channel among processes 1,
+    2, ..., 8, and then use this function to form another communication
+    channel between processes 9 and 10. This way, regardless of whether
+    processes 9 and 10 are the same as processes 1 and 5, the main
+    communication channel is always formed among processes 1, 2, ..., 8, and
+    the additional communication channel is formed between processes 9 and 10.
+    """
+    init_method = f"tcp://{host}:{port}"
+    backend = Backend(backend)  # it is basically a string
+    timeout = _get_default_timeout(backend)
+
+    store, rank, world_size = next(
+        rendezvous(init_method, rank, world_size, timeout=timeout))
+    store.set_timeout(timeout)
+
+    group_rank = rank
+    group_size = world_size
+
+    # Use a PrefixStore to avoid accidental overrides of keys used by
+    # different systems (e.g. RPC) in case the store is multi-tenant.
+    prefix_store = PrefixStore(init_method, store)
+
+    # TODO(Yizhou): The reason we need to set options while vllm does not
+    # appears to be the PyTorch version. In the latest version there is no
+    # need to set options, while in the older version (2.5.1 specifically)
+    # we do.
+    options = ProcessGroup.Options(backend=backend)
+    pg: ProcessGroup = ProcessGroup(
+        prefix_store,
+        group_rank,
+        group_size,
+        options,
+    )
+    if backend == "gloo":
+        from torch.distributed.distributed_c10d import ProcessGroupGloo
+        backend_class = ProcessGroupGloo(prefix_store,
+                                         group_rank,
+                                         group_size,
+                                         timeout=timeout)
+        backend_type = ProcessGroup.BackendType.GLOO
+        device = torch.device("cpu")
+    elif backend == "nccl":
+        assert is_nccl_available()
+        from torch.distributed.distributed_c10d import ProcessGroupNCCL
+
+        backend_options = ProcessGroupNCCL.Options()
+        backend_options._timeout = timeout
+
+        backend_class = ProcessGroupNCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        backend_type = ProcessGroup.BackendType.NCCL
+        device = torch.device("cuda")
+    elif backend == "hccl":
+        from torch.distributed import is_hccl_available
+        assert is_hccl_available()
+        from torch_npu._C._distributed_c10d import ProcessGroupHCCL
+        backend_options = ProcessGroupHCCL.Options()
+        backend_options._timeout = timeout
+        backend_class = ProcessGroupHCCL(prefix_store, group_rank, group_size,
+                                         backend_options)
+        device = torch.device("npu")
+        backend_class._set_sequence_number_for_group()
+        backend_type = ProcessGroup.BackendType.CUSTOM
+        pg._register_backend(device, backend_type, backend_class)
+        return pg
+    else:
+        raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
+
+    # TODO(Yizhou): As mentioned above, _set_default_backend is not
+    # implemented in PyTorch 2.5.1, but it will need to be set once the
+    # latest version is released.
+    # pg._set_default_backend(backend_type)
+    backend_class._set_sequence_number_for_group()
+
+    pg._register_backend(device, backend_type, backend_class)
+
+    return pg
+
+
 def parallel_config_get_dp_port(self) -> int:
     """
     We might need to initialize process groups in multiple
@@ -63,7 +171,7 @@ def parallel_config_get_dp_port(self) -> int:
     return port
 
 
-def stateless_init_dp_group(self) -> "ProcessGroup":
+def ascend_stateless_init_dp_group(self) -> "ProcessGroup":
     # TODO(Yizhou): Currently we have to set the backend to gloo
     # because in vllm.config.ParallelConfig.has_unfinished_dp the
     # device is set to cpu. We need to fix this in the future.
@@ -79,21 +187,6 @@ def stateless_init_dp_group(self) -> "ProcessGroup":
     return dp_group
 
 
-def _init_data_parallel(self, vllm_config: VllmConfig):
-    # Configure NPUs and stateless process group for data parallel.
-    dp_rank = vllm_config.parallel_config.data_parallel_rank
-    dp_size = vllm_config.parallel_config.data_parallel_size
-    local_dp_rank = vllm_config.parallel_config.data_parallel_rank_local
-
-    assert dp_size > 1
-    assert 0 <= local_dp_rank <= dp_rank < dp_size
-
-    self.local_dp_rank = local_dp_rank
-    self.dp_group = vllm_config.parallel_config.stateless_init_dp_group()
-    self.current_wave = 0
-
-
 vllm.distributed.parallel_state.destroy_model_parallel = ascend_destroy_model_parallel
-DPEngineCoreProc._init_data_parallel = _init_data_parallel
 ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
-ParallelConfig.stateless_init_dp_group = stateless_init_dp_group
+ParallelConfig.stateless_init_dp_group = ascend_stateless_init_dp_group
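
For illustration only (not part of the patch), here is a minimal sketch of how the stateless helper added above might be exercised on its own: it builds a two-process gloo group without ever initializing torch.distributed's default group, the same pattern the patched ParallelConfig.stateless_init_dp_group relies on for data-parallel coordination. The module path, port, and world size below are assumptions made for the sketch.

# Illustrative sketch only -- not part of the patch. The import path is an
# assumption; adjust it to wherever this patched module lives in your tree.
import torch
import torch.multiprocessing as mp

from vllm_ascend.patch.platform.patch_common.patch_distributed import (  # hypothetical path
    stateless_init_torch_distributed_process_group)


def _worker(rank: int, world_size: int, port: int) -> None:
    # Build a standalone two-process gloo group; the default (global)
    # torch.distributed process group is never initialized.
    pg = stateless_init_torch_distributed_process_group(
        host="127.0.0.1", port=port, rank=rank,
        world_size=world_size, backend="gloo")

    # allreduce works on such a group because it does not rely on global
    # ranks; this mirrors how vllm's ParallelConfig.has_unfinished_dp uses
    # the DP group returned by stateless_init_dp_group.
    t = torch.ones(1) * (rank + 1)
    torch.distributed.all_reduce(t, group=pg)
    assert t.item() == 3.0  # 1 + 2 summed across the two ranks


if __name__ == "__main__":
    mp.spawn(_worker, args=(2, 29500), nprocs=2)

As the docstring above notes, collectives such as all_reduce are safe on a group created this way, while rank-dependent operations such as broadcast should not be relied on.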