From f8fe2bfc3c082e4c1da2efe6cedb417ddfa0873d Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 3 Jul 2024 13:04:30 -0700
Subject: [PATCH 1/2] allow ca in pp

---
 vllm/config.py | 16 +++++-----------
 1 file changed, 5 insertions(+), 11 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 24f536c04ae65..1eb5e10452892 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -723,17 +723,11 @@ def _verify_args(self) -> None:
         if self.distributed_executor_backend == "ray":
             from vllm.executor import ray_utils
             ray_utils.assert_ray_available()
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            if is_hip():
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported on AMD GPUs.")
-            elif self.pipeline_parallel_size > 1:
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported with pipeline parallelism.")
+        if is_hip():
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported on AMD GPUs.")
         if self.ray_workers_use_nsight and (
                 not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "

From 87965c689810d53908f08144197c7e7eb4ca2b61 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Wed, 3 Jul 2024 13:05:43 -0700
Subject: [PATCH 2/2] relax constraint

---
 vllm/distributed/parallel_state.py | 16 ++++++++++++----
 1 file changed, 12 insertions(+), 4 deletions(-)

diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index faf9177adc8d3..66ffe6e8a9fa9 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -719,14 +719,19 @@ def init_world_group(ranks: List[int], local_rank: int,
     )
 
 
-def init_model_parallel_group(group_ranks: List[List[int]], local_rank: int,
-                              backend: str) -> GroupCoordinator:
+def init_model_parallel_group(
+    group_ranks: List[List[int]],
+    local_rank: int,
+    backend: str,
+    use_custom_allreduce: Optional[bool] = None) -> GroupCoordinator:
+    if use_custom_allreduce is None:
+        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
     return GroupCoordinator(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
         use_pynccl=True,
-        use_custom_allreduce=_ENABLE_CUSTOM_ALL_REDUCE,
+        use_custom_allreduce=use_custom_allreduce,
     )
 
 
@@ -888,8 +893,11 @@ def initialize_model_parallel(
     for i in range(num_pipeline_model_parallel_groups):
         ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
         group_ranks.append(ranks)
+    # pipeline parallel does not need custom allreduce
     _PP = init_model_parallel_group(group_ranks,
-                                    get_world_group().local_rank, backend)
+                                    get_world_group().local_rank,
+                                    backend,
+                                    use_custom_allreduce=False)
 
 
 def ensure_model_parallel_initialized(
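
Note (illustration, not part of the patches): taken together, the two commits let the tensor-parallel group keep using the custom all-reduce kernel while the pipeline-parallel group opts out of it, since that group does not need allreduce. Below is a minimal, self-contained sketch of that dispatch pattern. The names init_model_parallel_group and _ENABLE_CUSTOM_ALL_REDUCE are reused from the diff above, but FakeGroup and the 4-rank TP/PP layouts are hypothetical stand-ins so the snippet runs without torch.distributed or a GPU.

    from typing import List, Optional

    # Stand-in for the module-level flag in vllm/distributed/parallel_state.py;
    # in vLLM it is derived from ParallelConfig.disable_custom_all_reduce.
    _ENABLE_CUSTOM_ALL_REDUCE = True


    class FakeGroup:
        """Hypothetical stand-in for GroupCoordinator; it only records its flags."""

        def __init__(self, group_ranks: List[List[int]],
                     use_custom_allreduce: bool) -> None:
            self.group_ranks = group_ranks
            self.use_custom_allreduce = use_custom_allreduce


    def init_model_parallel_group(
            group_ranks: List[List[int]],
            use_custom_allreduce: Optional[bool] = None) -> FakeGroup:
        # Same idea as the patched signature: a caller may force the flag off,
        # otherwise the process-wide default applies.
        if use_custom_allreduce is None:
            use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
        return FakeGroup(group_ranks, use_custom_allreduce)


    if __name__ == "__main__":
        # Hypothetical 4-rank layout with tensor_parallel_size=2 and
        # pipeline_parallel_size=2.
        tp_group_ranks = [[0, 1], [2, 3]]  # TP groups keep the default
        pp_group_ranks = [[0, 2], [1, 3]]  # PP groups opt out, as in the patch

        _TP = init_model_parallel_group(tp_group_ranks)
        _PP = init_model_parallel_group(pp_group_ranks,
                                        use_custom_allreduce=False)

        assert _TP.use_custom_allreduce is True
        assert _PP.use_custom_allreduce is False

With patch 1 removing the blanket pipeline_parallel_size > 1 check from ParallelConfig._verify_args, this per-group override in initialize_model_parallel is what keeps the custom all-reduce kernel off for the PP group while still allowing it for tensor parallelism in pipeline-parallel deployments.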