diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 8b980458ddaf..59d3efb42449 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -68,6 +68,8 @@ class ParallelConfig:
     """Number of pipeline parallel groups."""
     tensor_parallel_size: int = 1
     """Number of tensor parallel groups."""
+    context_parallel_size: int = 1
+    """Number of context parallel groups."""
     data_parallel_size: int = 1
     """Number of data parallel groups. MoE layers will be sharded according to
     the product of the tensor parallel size and data parallel size."""
@@ -185,7 +187,7 @@ class is dynamically inherited by the worker class. This is used to inject
     calls."""
 
     world_size: int = field(init=False)
-    """world_size is TPxPP, it affects the number of workers we create."""
+    """world_size is TPxCPxPP, it affects the number of workers we create."""
 
     rank: int = 0
     """Global rank in distributed setup."""
@@ -335,6 +337,7 @@ def compute_hash(self):
         factors: list[Any] = []
         factors.append(self.pipeline_parallel_size)
         factors.append(self.tensor_parallel_size)
+        factors.append(self.context_parallel_size)
         factors.append(self.enable_expert_parallel)
         factors.append(self.data_parallel_size)
         factors.append(envs.VLLM_ALL2ALL_BACKEND)
@@ -374,7 +377,7 @@ def __post_init__(self) -> None:
 
         # Continue with the rest of the initialization
         self.world_size = self.pipeline_parallel_size * \
-            self.tensor_parallel_size
+            self.tensor_parallel_size * self.context_parallel_size
 
         if self.distributed_executor_backend == "external_launcher":
             logger.info("Using external launcher for distributed inference.")
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index 638170963e2b..dfd77e661c69 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -57,7 +57,7 @@ class GraphCaptureContext:
 
 
 def _split_tensor_dict(
-    tensor_dict: dict[str, Union[torch.Tensor, Any]]
+    tensor_dict: dict[str, Union[torch.Tensor, Any]],
 ) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
     """Split the tensor dictionary into two parts:
     1. A list of (key, value) pairs. If the value is a tensor, it is replaced
@@ -261,17 +261,19 @@ def __init__(
         from vllm.distributed.device_communicators.shm_broadcast import (
             MessageQueue)
 
+        self.mq_broadcaster: Optional[MessageQueue] = None
         if use_message_queue_broadcaster and self.world_size > 1:
             self.mq_broadcaster = MessageQueue.create_from_process_group(
                 self.cpu_group, 1 << 22, 6)
 
         from vllm.platforms import current_platform
+
         self.use_custom_op_call = (current_platform.is_cuda_alike()
                                    or current_platform.is_tpu())
 
-        self.use_cpu_custom_send_recv = (current_platform.is_cpu() and hasattr(
-            torch.ops._C, "init_shm_manager"))
+        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
+            torch.ops._C, "init_shm_manager")
 
     @property
     def first_rank(self):
@@ -321,6 +323,7 @@ def graph_capture(
         maybe_ca_context = nullcontext()
         from vllm.distributed.device_communicators.cuda_communicator import (
             CudaCommunicator)
+
         if self.device_communicator is not None:
             assert isinstance(self.device_communicator, CudaCommunicator)
             ca_comm = self.device_communicator.ca_comm
@@ -371,8 +374,9 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
-        assert -input_.dim() <= dim < input_.dim(), (
-            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        assert (
+            -input_.dim() <= dim < input_.dim()
+        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
 
         if self.use_custom_op_call:
             return torch.ops.vllm.all_gather(input_,
@@ -388,10 +392,12 @@ def _all_gather_out_place(self, input_: torch.Tensor,
             raise ValueError("No device communicator found")
         return self.device_communicator.all_gather(input_, dim)
 
-    def all_gatherv(self,
-                    input_: Union[torch.Tensor, list[torch.Tensor]],
-                    dim: int = 0,
-                    sizes: Optional[list[int]] = None):
+    def all_gatherv(
+        self,
+        input_: Union[torch.Tensor, list[torch.Tensor]],
+        dim: int = 0,
+        sizes: Optional[list[int]] = None,
+    ):
         if self.device_communicator is None:
             raise ValueError("No device communicator found")
         return self.device_communicator.all_gatherv(input_, dim, sizes)
@@ -403,8 +409,9 @@ def reduce_scatter(self,
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
-        assert -input_.dim() <= dim < input_.dim(), (
-            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
+        assert (
+            -input_.dim() <= dim < input_.dim()
+        ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
 
         if self.use_custom_op_call:
             return torch.ops.vllm.reduce_scatter(input_,
@@ -538,9 +545,9 @@ def recv_object(self, src: int) -> Any:
 
         assert src < self.world_size, f"Invalid src rank ({src})"
 
-        assert src != self.rank_in_group, (
-            "Invalid source rank. Source rank is the same as the current rank."
-        )
+        assert (
+            src != self.rank_in_group
+        ), "Invalid source rank. Source rank is the same as the current rank."
 
         size_tensor = torch.empty(1, dtype=torch.long, device="cpu")
@@ -553,14 +560,16 @@ def recv_object(self, src: int) -> Any:
         object_tensor = torch.empty(  # type: ignore[call-overload]
             size_tensor.item(),  # type: ignore[arg-type]
             dtype=torch.uint8,
-            device="cpu")
+            device="cpu",
+        )
 
         rank_object = torch.distributed.recv(object_tensor,
                                              src=self.ranks[src],
                                              group=self.cpu_group)
 
-        assert rank_object == rank_size, (
-            "Received object sender rank does not match the size sender rank.")
+        assert (
+            rank_object == rank_size
+        ), "Received object sender rank does not match the size sender rank."
 
         obj = pickle.loads(object_tensor.numpy().tobytes())
@@ -571,13 +580,13 @@ def broadcast_tensor_dict(
         tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
         src: int = 0,
         group: Optional[ProcessGroup] = None,
-        metadata_group: Optional[ProcessGroup] = None
+        metadata_group: Optional[ProcessGroup] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Broadcast the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
         """
         # Bypass the function if we are using only 1 GPU.
-        if (not torch.distributed.is_initialized() or self.world_size == 1):
+        if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
         group = self.device_group
@@ -589,7 +598,7 @@ def broadcast_tensor_dict(
             metadata_list: list[tuple[Any, Any]] = []
             assert isinstance(
                 tensor_dict,
-                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
+                dict), f"Expecting a dictionary, got {type(tensor_dict)}"
             metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
             # `metadata_list` lives in CPU memory.
             # `broadcast_object_list` has serialization & deserialization,
@@ -635,7 +644,8 @@ def broadcast_tensor_dict(
                         tensor,
                         src=self.ranks[src],
                         group=metadata_group,
-                        async_op=True)
+                        async_op=True,
+                    )
                 else:
                     # use group for GPU tensors
                     handle = torch.distributed.broadcast(
@@ -694,8 +704,8 @@ def send_tensor_dict(
         if self.use_cpu_custom_send_recv:
             if self.device_communicator is None:
                 raise ValueError("No device communicator found")
-            self.device_communicator.send_tensor_dict(  # type: ignore
-                tensor_dict, dst)
+            self.device_communicator.send_tensor_dict(tensor_dict,
+                                                       dst)  # type: ignore
             return None
 
         metadata_list: list[tuple[Any, Any]] = []
@@ -721,8 +731,8 @@ def send_tensor_dict(
             # send-allgather: send only a slice, then do allgather.
             use_all_gather = (all_gather_group is not None
                               and tensor.numel() % all_gather_size == 0)
-            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
-                if all_gather_tensors else use_all_gather
+            use_all_gather = (all_gather_tensors.get(key, use_all_gather)
+                              if all_gather_tensors else use_all_gather)
 
             if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
@@ -780,8 +790,8 @@ def recv_tensor_dict(
         if self.use_cpu_custom_send_recv:
             if self.device_communicator is None:
                 raise ValueError("No device communicator found")
-            return self.device_communicator.recv_tensor_dict(  # type: ignore
-                src)
+            return self.device_communicator.recv_tensor_dict(
+                src)  # type: ignore
 
         recv_metadata_list = self.recv_object(src=src)
         tensor_dict: dict[str, Any] = {}
@@ -798,8 +808,8 @@ def recv_tensor_dict(
             # send-allgather: send only a slice, then do allgather.
             use_all_gather = (all_gather_group is not None
                               and tensor.numel() % all_gather_size == 0)
-            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
-                if all_gather_tensors else use_all_gather
+            use_all_gather = (all_gather_tensors.get(key, use_all_gather)
+                              if all_gather_tensors else use_all_gather)
 
             if use_all_gather:
                 orig_shape = tensor.shape
@@ -874,7 +884,7 @@ def dispatch(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
-        is_sequence_parallel: bool = False
+        is_sequence_parallel: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         if self.device_communicator is not None:
             return self.device_communicator.dispatch(hidden_states,
@@ -898,7 +908,7 @@ def combine(self,
 
 
 def get_world_group() -> GroupCoordinator:
-    assert _WORLD is not None, ("world group is not initialized")
+    assert _WORLD is not None, "world group is not initialized"
     return _WORLD
@@ -935,7 +945,7 @@ def init_model_parallel_group(
 
 
 def get_tp_group() -> GroupCoordinator:
-    assert _TP is not None, ("tensor model parallel group is not initialized")
+    assert _TP is not None, "tensor model parallel group is not initialized"
     return _TP
@@ -946,12 +956,20 @@ def get_tensor_model_parallel_group():
     return get_tp_group()
 
 
+_CP: Optional[GroupCoordinator] = None
+
+
+def get_cp_group() -> GroupCoordinator:
+    assert _CP is not None, "context model parallel group is not initialized"
+    return _CP
+
+
 _DCP: Optional[GroupCoordinator] = None
 
 
 def get_dcp_group() -> GroupCoordinator:
-    assert _DCP is not None, (
-        "decode context model parallel group is not initialized")
+    assert _DCP is not None, ("decode context model parallel group is not "
+                              "initialized")
     return _DCP
@@ -964,7 +982,7 @@ def get_dcp_group() -> GroupCoordinator:
 
 
 def get_dp_group() -> GroupCoordinator:
-    assert _DP is not None, ("data parallel group is not initialized")
+    assert _DP is not None, "data parallel group is not initialized"
     return _DP
@@ -972,13 +990,12 @@ def get_dp_group() -> GroupCoordinator:
 
 
 def get_ep_group() -> GroupCoordinator:
-    assert _EP is not None, ("expert parallel group is not initialized")
+    assert _EP is not None, "expert parallel group is not initialized"
     return _EP
 
 
 def get_pp_group() -> GroupCoordinator:
-    assert _PP is not None, (
-        "pipeline model parallel group is not initialized")
+    assert _PP is not None, "pipeline model parallel group is not initialized"
     return _PP
@@ -1020,21 +1037,29 @@ def set_custom_all_reduce(enable: bool):
     _ENABLE_CUSTOM_ALL_REDUCE = enable
 
 
-def init_distributed_environment(world_size: int = -1,
-                                 rank: int = -1,
-                                 distributed_init_method: str = "env://",
-                                 local_rank: int = -1,
-                                 backend: str = "nccl",
-                                 timeout: Optional[timedelta] = None):
+def init_distributed_environment(
+    world_size: int = -1,
+    rank: int = -1,
+    distributed_init_method: str = "env://",
+    local_rank: int = -1,
+    backend: str = "nccl",
+    timeout: Optional[timedelta] = None,
+):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d "
-        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
-        distributed_init_method, backend)
+        "distributed_init_method=%s backend=%s",
+        world_size,
+        rank,
+        local_rank,
+        distributed_init_method,
+        backend,
+    )
     from vllm.config import get_current_vllm_config
+
     config = get_current_vllm_config()
-    if config is not None and config.parallel_config.data_parallel_size > 1 \
-        and config.parallel_config.distributed_executor_backend \
-        != "external_launcher":
+    if (config is not None and config.parallel_config.data_parallel_size > 1
+            and config.parallel_config.distributed_executor_backend
+            != "external_launcher"):
         parallel_config = config.parallel_config
         # adjust to take into account data parallelism
         # offset the rank by the data parallel rank
@@ -1046,7 +1071,10 @@ def init_distributed_environment(world_size: int = -1,
         distributed_init_method = get_distributed_init_method(ip, port)
         logger.info(
             "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
-            world_size, rank, distributed_init_method)
+            world_size,
+            rank,
+            distributed_init_method,
+        )
     if not torch.distributed.is_initialized():
         assert distributed_init_method is not None, (
             "distributed_init_method must be provided when initializing "
@@ -1054,9 +1082,11 @@ def init_distributed_environment(world_size: int = -1,
         if not torch.distributed.is_backend_available(backend):
             logger.warning(
                 "Distributed backend %s is not available; "
-                "falling back to gloo.", backend)
-            assert torch.distributed.is_gloo_available(), (
-                "Fallback Gloo backend is not available.")
+                "falling back to gloo.",
+                backend,
+            )
+            assert (torch.distributed.is_gloo_available()
+                    ), "Fallback Gloo backend is not available."
             backend = "gloo"
         # this backend is used for WORLD
         torch.distributed.init_process_group(
@@ -1064,7 +1094,8 @@ def init_distributed_environment(world_size: int = -1,
             init_method=distributed_init_method,
             world_size=world_size,
             rank=rank,
-            timeout=timeout)
+            timeout=timeout,
+        )
         # set the local rank
         # local_rank is not available in torch ProcessGroup,
         # see https://github.com/pytorch/pytorch/issues/122816
@@ -1083,12 +1114,13 @@ def init_distributed_environment(world_size: int = -1,
         logger.debug("Detected %d nodes in the distributed environment",
                      _NODE_COUNT)
     else:
-        assert _WORLD.world_size == torch.distributed.get_world_size(), (
-            "world group already initialized with a different world size")
+        assert (_WORLD.world_size == torch.distributed.get_world_size(
+        )), "world group already initialized with a different world size"
 
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,
+    context_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
@@ -1099,23 +1131,37 @@ def initialize_model_parallel(
     Arguments:
         tensor_model_parallel_size: number of GPUs used for tensor model
             parallelism.
+        context_parallel_size: number of GPUs used for context model
+            parallelism.
         pipeline_model_parallel_size: number of GPUs used for pipeline model
             parallelism.
         backend: name of torch distributed communication backend.
 
-    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
-    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
-    the model pipeline. The present function will
-    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
-    4 tensor model-parallel groups:
-        [g0, g1], [g2, g3], [g4, g5], [g6, g7]
-    2 pipeline model-parallel groups:
-        [g0, g2, g4, g6], [g1, g3, g5, g7]
+    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
+    use 2 GPUs to parallelize the model tensor, 2 GPUs to parallelize the
+    context, and 4 GPUs to parallelize the model pipeline. The present
+    function will create 8 tensor model-parallel groups, 8 context-parallel
+    groups and 4 pipeline model-parallel groups:
+    8 tensor model-parallel groups:
+        [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11],
+        [g12, g13], [g14, g15]
+    8 context model-parallel groups:
+        [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11],
+        [g12, g14], [g13, g15]
+    4 pipeline model-parallel groups:
+        [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14],
+        [g3, g7, g11, g15]
     Note that for efficiency, the caller should make sure adjacent ranks
     are on the same DGX box. For example if we are using 2 DGX-1 boxes
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
+    assert not (
+        decode_context_model_parallel_size is not None
+        and decode_context_model_parallel_size > 1 and context_parallel_size
+        > 1), ("decode_context_model_parallel_size and context_parallel_size "
+               "cannot be enabled together")
+
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size: int = torch.distributed.get_world_size()
@@ -1125,11 +1171,12 @@ def initialize_model_parallel(
     data_parallel_size = 1
 
     from vllm.config import get_current_vllm_config
+
    config = get_current_vllm_config()
     if config is not None:
         data_parallel_size = config.parallel_config.data_parallel_size
 
-    # the layout order is: ExternalDP x DP x PP x TP
+    # the layout order is: ExternalDP x DP x PP x CP x TP
     # ExternalDP is the data parallel group that is not part of the model,
     # every dp rank can generate independently (in verl integration).
     # DP is the data parallel group that is part of the model,
@@ -1139,26 +1186,32 @@
     # to get group_ranks for each dimension, transpose that dimension to the
     # last dimension, then reshape to 2D, then unbind the last dimension
     all_ranks = torch.arange(world_size).reshape(
-        -1, data_parallel_size, pipeline_model_parallel_size,
-        tensor_model_parallel_size)  # noqa
+        -1,
+        data_parallel_size,
+        pipeline_model_parallel_size,
+        context_parallel_size,
+        tensor_model_parallel_size,
+    )  # noqa
 
     # Build the tensor model-parallel groups.
     global _TP
-    assert _TP is None, ("tensor model parallel group is already initialized")
+    assert _TP is None, "tensor model parallel group is already initialized"
     group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     # message queue broadcaster is only used in tensor model parallel group
-    _TP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank,
-                                    backend,
-                                    use_message_queue_broadcaster=True,
-                                    group_name="tp")
+    _TP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_message_queue_broadcaster=True,
+        group_name="tp",
+    )
 
     # Build the DCP model-parallel groups.
     global _DCP
-    assert _DCP is None, (
-        "decode context model parallel group is already initialized")
+    assert _DCP is None, ("decode context model parallel group is already "
+                          "initialized")
     # Note(hc): In the current implementation of decode context parallel,
     # dcp_size must not exceed tp_size, because the world size does not
     # change by DCP, it simply reuses the GPUs of TP group, and split one
@@ -1166,18 +1219,30 @@
     group_ranks = all_ranks.reshape(
         -1, decode_context_model_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
-    _DCP = init_model_parallel_group(group_ranks,
-                                     get_world_group().local_rank,
-                                     backend,
-                                     use_message_queue_broadcaster=True,
-                                     group_name="dcp")
+    _DCP = init_model_parallel_group(
+        group_ranks,
+        get_world_group().local_rank,
+        backend,
+        use_message_queue_broadcaster=True,
+        group_name="dcp",
+    )
+
+    # Build the context parallel groups.
+    global _CP
+    assert _CP is None, "context model parallel group is already initialized"
+    group_ranks = all_ranks.transpose(3, 4).reshape(
+        -1, context_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    _CP = init_model_parallel_group(group_ranks,
+                                    get_world_group().local_rank,
+                                    backend,
+                                    group_name="cp")
 
     # Build the pipeline model-parallel groups.
     global _PP
-    assert _PP is None, (
-        "pipeline model parallel group is already initialized")
-    group_ranks = all_ranks.transpose(2, 3).reshape(
-        -1, pipeline_model_parallel_size).unbind(0)
+    assert _PP is None, "pipeline model parallel group is already initialized"
+    group_ranks = (all_ranks.transpose(2, 4).reshape(
+        -1, pipeline_model_parallel_size).unbind(0))
     group_ranks = [x.tolist() for x in group_ranks]
     _PP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
@@ -1185,9 +1250,9 @@
                                     group_name="pp")
 
     global _DP
-    assert _DP is None, ("data parallel group is already initialized")
+    assert _DP is None, "data parallel group is already initialized"
     group_ranks = all_ranks.transpose(1,
-                                      3).reshape(-1,
+                                      4).reshape(-1,
                                                  data_parallel_size).unbind(0)
     group_ranks = [x.tolist() for x in group_ranks]
     _DP = init_model_parallel_group(group_ranks,
@@ -1196,9 +1261,9 @@
                                     group_name="dp")
 
     global _EP
-    assert _EP is None, ("expert parallel group is already initialized")
-    group_ranks = all_ranks.transpose(1, 2).reshape(
-        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
+    assert _EP is None, "expert parallel group is already initialized"
+    group_ranks = (all_ranks.permute(0, 2, 3, 1, 4).reshape(
+        -1, data_parallel_size * tensor_model_parallel_size).unbind(0))
     group_ranks = [x.tolist() for x in group_ranks]
     _EP = init_model_parallel_group(group_ranks,
                                     get_world_group().local_rank,
@@ -1207,13 +1272,20 @@
 
     logger.info(
         "rank %s in world size %s is assigned as "
-        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
-        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
-        _EP.rank_in_group)
+        "DP rank %s, PP rank %s, CP rank %s, TP rank %s, EP rank %s",
+        rank,
+        world_size,
+        _DP.rank_in_group,
+        _PP.rank_in_group,
+        _CP.rank_in_group,
+        _TP.rank_in_group,
+        _EP.rank_in_group,
+    )
 
 
 def ensure_model_parallel_initialized(
     tensor_model_parallel_size: int,
+    context_parallel_size: int,
     pipeline_model_parallel_size: int,
     decode_context_model_parallel_size: Optional[int] = 1,
     backend: Optional[str] = None,
@@ -1225,18 +1297,28 @@ def ensure_model_parallel_initialized(
     backend = backend or torch.distributed.get_backend(
         get_world_group().device_group)
     if not model_parallel_is_initialized():
-        initialize_model_parallel(tensor_model_parallel_size,
-                                  pipeline_model_parallel_size,
-                                  decode_context_model_parallel_size, backend)
+        initialize_model_parallel(
+            tensor_model_parallel_size,
+            context_parallel_size,
+            pipeline_model_parallel_size,
+            decode_context_model_parallel_size,
+            backend,
+        )
         return
 
-    assert (
-        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
-    ), ("tensor parallel group already initialized, but of unexpected size. "
-        f"got: {get_tensor_model_parallel_world_size()=} vs. "
-        f"wanted: {tensor_model_parallel_size=}")
+    assert get_tensor_model_parallel_world_size(
+    ) == tensor_model_parallel_size, (
+        "tensor parallel group already initialized, but of unexpected size. "
+        f"got: {get_tensor_model_parallel_world_size()=} vs. "
+        f"wanted: {tensor_model_parallel_size=}")
+
+    assert get_context_parallel_world_size() == context_parallel_size, (
+        "context parallel group already initialized, but of unexpected size. "
+        f"got: {get_context_parallel_world_size()=} vs. "
+        f"wanted: {context_parallel_size=}")
+
     pp_world_size = get_pp_group().world_size
-    assert (pp_world_size == pipeline_model_parallel_size), (
+    assert pp_world_size == pipeline_model_parallel_size, (
         "pipeline parallel group already initialized, but of unexpected size. "
         f"got: {pp_world_size=} vs. "
         f"wanted: {pipeline_model_parallel_size=}")
@@ -1251,6 +1333,8 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
     """
     if _TP is not None:
         _TP.prepare_communication_buffer_for_model(model)
+    if _CP is not None:
+        _CP.prepare_communication_buffer_for_model(model)
     if _PP is not None:
         _PP.prepare_communication_buffer_for_model(model)
     if _DP is not None:
@@ -1261,7 +1345,7 @@ def prepare_communication_buffer_for_model(model: torch.nn.Module):
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TP is not None and _PP is not None)
+    return _TP is not None and _CP is not None and _PP is not None
 
 
 _TP_STATE_PATCHED = False
@@ -1312,10 +1396,19 @@ def get_decode_context_model_parallel_rank():
     return get_dcp_group().rank_in_group
 
 
+def get_context_parallel_world_size():
+    """Return world size for the context model parallel group."""
+    return get_cp_group().world_size
+
+
+def get_context_parallel_rank():
+    """Return my rank for the context model parallel group."""
+    return get_cp_group().rank_in_group
+
+
 def get_node_count() -> int:
-    """Return the total number of nodes in the distributed environment. """
-    assert _NODE_COUNT is not None, (
-        "distributed environment is not initialized")
+    """Return the total number of nodes in the distributed environment."""
+    assert _NODE_COUNT is not None, "distributed environment is not initialized"
     return _NODE_COUNT
@@ -1327,6 +1420,11 @@ def destroy_model_parallel():
         _TP.destroy()
     _TP = None
 
+    global _CP
+    if _CP:
+        _CP.destroy()
+    _CP = None
+
     global _PP
     if _PP:
         _PP.destroy()
@@ -1363,9 +1461,11 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
     destroy_distributed_environment()
     if shutdown_ray:
         import ray  # Lazy import Ray
+
         ray.shutdown()
     gc.collect()
     from vllm.platforms import current_platform
+
     empty_cache = current_platform.empty_cache
     if empty_cache is not None:
         empty_cache()
@@ -1385,9 +1485,9 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
         memory system (shared access to shared memory).
     """
     if isinstance(pg, ProcessGroup):
-        assert torch.distributed.get_backend(
-            pg) != torch.distributed.Backend.NCCL, (
-                "in_the_same_node_as should be tested with a non-NCCL group.")
+        assert (
+            torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL
+        ), "in_the_same_node_as should be tested with a non-NCCL group."
         # local rank inside the group
         rank = torch.distributed.get_rank(group=pg)
         world_size = torch.distributed.get_world_size(group=pg)
@@ -1429,8 +1529,10 @@ def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                 # fix to https://stackoverflow.com/q/62748654/9191338
                 # Python incorrectly tracks shared memory even if it is not
                 # created by the process. The following patch is a workaround.
-                with patch("multiprocessing.resource_tracker.register",
-                           lambda *args, **kwargs: None):
+                with patch(
+                        "multiprocessing.resource_tracker.register",
+                        lambda *args, **kwargs: None,
+                ):
                     shm = shared_memory.SharedMemory(name=name)
                     if shm.buf[:len(magic_message)] == magic_message:
                         is_in_the_same_node[rank] = 1
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ec61fc4b9b06..85a01c27eaf9 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -318,6 +318,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
     tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
+    context_parallel_size: int = ParallelConfig.context_parallel_size
     decode_context_parallel_size: int = \
         ParallelConfig.decode_context_parallel_size
     data_parallel_size: int = ParallelConfig.data_parallel_size
@@ -657,6 +658,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         parallel_group.add_argument(
             "--decode-context-parallel-size", "-dcp",
             **parallel_kwargs["decode_context_parallel_size"])
+        parallel_group.add_argument("--context-parallel-size", "-cp",
+                                    **parallel_kwargs["context_parallel_size"])
         parallel_group.add_argument("--data-parallel-size", "-dp",
                                     **parallel_kwargs["data_parallel_size"])
         parallel_group.add_argument(
@@ -1317,6 +1320,7 @@ def create_engine_config(
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            context_parallel_size=self.context_parallel_size,
             data_parallel_size=self.data_parallel_size,
             data_parallel_rank=self.data_parallel_rank or 0,
             data_parallel_external_lb=data_parallel_external_lb,
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index eecdf8def6de..c13bf28c4e5e 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -61,11 +61,14 @@ def _init_executor(self) -> None:
         self.world_size = self.parallel_config.world_size
 
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
+        context_parallel_size = self.parallel_config.context_parallel_size
         pp_parallel_size = self.parallel_config.pipeline_parallel_size
-        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
-            f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
-            f"_parallel_size ({pp_parallel_size}). ")
+        assert self.world_size == (
+            tensor_parallel_size * context_parallel_size * pp_parallel_size), (
+                f"world_size ({self.world_size}) must be equal to the "
+                f"tensor_parallel_size ({tensor_parallel_size}) x "
+                f"context_parallel_size ({context_parallel_size}) x "
+                f"pipeline_parallel_size ({pp_parallel_size}). ")
 
         # Set multiprocessing envs
         set_multiprocessing_worker_envs()
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index a135a594ac6f..d218feb1a9ed 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -703,6 +703,7 @@ def init_worker_distributed_environment(
 
     ensure_model_parallel_initialized(
         parallel_config.tensor_parallel_size,
+        parallel_config.context_parallel_size,
        parallel_config.pipeline_parallel_size,
         parallel_config.decode_context_parallel_size)