From 977a6cd72ad686cd445aac9e985cea35edf9aedc Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Thu, 2 May 2024 17:32:33 -0700
Subject: [PATCH] [Core][Distributed] enable allreduce for multiple tp groups
 (#4566)

---
 tests/distributed/test_pynccl.py     | 43 +++++++++++++++++++++++++---
 vllm/distributed/communication_op.py |  1 -
 vllm/distributed/parallel_state.py   | 36 ++++++++++++++++-------
 vllm/worker/worker.py                | 13 +++++----
 4 files changed, 71 insertions(+), 22 deletions(-)

diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py
index e71d839648c83..b6f461b76ed03 100644
--- a/tests/distributed/test_pynccl.py
+++ b/tests/distributed/test_pynccl.py
@@ -3,9 +3,13 @@
 import pytest
 import torch
 
+import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                           ncclGetUniqueId)
-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (
+    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
+    init_distributed_environment, with_pynccl_for_all_reduce)
 from vllm.utils import update_environment_variables
 
 
@@ -67,7 +71,7 @@ def multiple_tp_worker_fn():
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     comm = NCCLCommunicator(group=group, device=device)
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
     # two groups can communicate independently
     if torch.distributed.get_rank() in [0, 1]:
         comm.all_reduce(tensor)
@@ -81,9 +85,40 @@
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
-                    reason="Need at least 2 GPUs to run the test.")
+                    reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
-    distributed_run(worker_fn, 4)
+    # this tests pynccl for multiple tp groups, in a standalone way
+    # i.e. call `comm.all_reduce` directly
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def multiple_tp_with_vllm_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    torch.cuda.set_device(torch.distributed.get_rank())
+    ensure_model_parallel_initialized(2, 2)
+    pynccl_utils.init_process_group(
+        group=get_tensor_model_parallel_cpu_group())
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with with_pynccl_for_all_reduce():
+        # two tp groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp_with_vllm():
+    # this tests pynccl for multiple tp groups, together with vllm
+    # i.e. call `tensor_model_parallel_all_reduce`
+    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
 
 
 @worker_fn_wrapper
diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py
index 8b2c26c3a8afb..b539a7beedbfe 100644
--- a/vllm/distributed/communication_op.py
+++ b/vllm/distributed/communication_op.py
@@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     if out is not None:
         return out
     if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
         pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index a82a1254693df..be5bb4e857caf 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -14,7 +14,8 @@
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
@@ -132,15 +133,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -197,9 +200,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
+    assert _TP_DEVICE_GROUP is not None, (
         "tensor model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py
index 39ad428f16fe3..808261e47318b 100644
--- a/vllm/worker/worker.py
+++ b/vllm/worker/worker.py
@@ -11,6 +11,7 @@
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
+                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -288,6 +289,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
         if pynccl_world_size != parallel_config.world_size:
@@ -298,12 +302,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        # NOTE(kaichao): By default, pynccl will use information inside
-        # `parallel_state` for initialization.
-        pynccl_utils.init_process_group()
-
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+        # NOTE(kaichao): By default, pynccl is initialized for tp group.
+        pynccl_utils.init_process_group(
+            group=get_tensor_model_parallel_cpu_group())
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
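
With tensor_parallel_size=2 and pipeline_parallel_size=2 on 4 ranks, initialize_model_parallel now builds one NCCL device group plus one gloo CPU group per tensor-parallel group, and init_worker_distributed_environment hands the caller's TP CPU group to pynccl_utils.init_process_group, so each TP group ends up with its own pynccl communicator for all-reduce. The sketch below is illustrative only (plain Python, no GPUs or torch needed; every name in it is local to the sketch rather than a vLLM API): it reproduces the rank-to-group layout computed by the loop in initialize_model_parallel and the tensor means asserted by the new test.

# Illustrative sketch only (assumed setup: world_size=4, tensor_parallel_size=2,
# pipeline_parallel_size=2); mirrors the TP-group construction loop in
# initialize_model_parallel and the values asserted in
# test_pynccl_multiple_tp_with_vllm. Pure Python, runs without GPUs.
world_size = 4
tensor_parallel_size = 2

num_tensor_model_parallel_groups = world_size // tensor_parallel_size
tp_groups = [
    list(range(i * tensor_parallel_size, (i + 1) * tensor_parallel_size))
    for i in range(num_tensor_model_parallel_groups)
]
assert tp_groups == [[0, 1], [2, 3]]  # two independent TP groups


def expected_mean(rank: int) -> int:
    # Starting from a tensor of ones, each all-reduce over a TP group of
    # size 2 doubles every element: ranks 0/1 reduce twice (1 -> 2 -> 4),
    # ranks 2/3 reduce once (1 -> 2).
    num_all_reduces = 2 if rank in tp_groups[0] else 1
    return tensor_parallel_size ** num_all_reduces


assert [expected_mean(r) for r in range(world_size)] == [4, 4, 2, 2]

Because the two device groups are disjoint, the all-reduces issued by ranks 0 and 1 never mix with those issued by ranks 2 and 3, which is what both new tests check.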