[Core][Distributed] enable allreduce for multiple tp groups (vllm-pro…
youkaichao authored and joerunde committed May 6, 2024
1 parent 3e9f425 commit 977a6cd
Showing 4 changed files with 71 additions and 22 deletions.
43 changes: 39 additions & 4 deletions tests/distributed/test_pynccl.py
@@ -3,9 +3,13 @@
 import pytest
 import torch
 
+import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
+from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
                                                            ncclGetUniqueId)
-from vllm.distributed.parallel_state import init_distributed_environment
+from vllm.distributed.parallel_state import (
+    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
+    init_distributed_environment, with_pynccl_for_all_reduce)
 from vllm.utils import update_environment_variables
 
 
@@ -67,7 +71,7 @@ def multiple_tp_worker_fn():
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
     comm = NCCLCommunicator(group=group, device=device)
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
     # two groups can communicate independently
     if torch.distributed.get_rank() in [0, 1]:
         comm.all_reduce(tensor)
@@ -81,9 +85,40 @@ def multiple_tp_worker_fn():
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 4,
-                    reason="Need at least 2 GPUs to run the test.")
+                    reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
-    distributed_run(worker_fn, 4)
+    # this tests pynccl for multiple tp groups, in a standalone way
+    # i.e. call `comm.all_reduce` directly
+    distributed_run(multiple_tp_worker_fn, 4)
+
+
+@worker_fn_wrapper
+def multiple_tp_with_vllm_worker_fn():
+    device = torch.device(f"cuda:{torch.distributed.get_rank()}")
+    torch.cuda.set_device(torch.distributed.get_rank())
+    ensure_model_parallel_initialized(2, 2)
+    pynccl_utils.init_process_group(
+        group=get_tensor_model_parallel_cpu_group())
+    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
+    with with_pynccl_for_all_reduce():
+        # two tp groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            tensor = tensor_model_parallel_all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2
+
+
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+def test_pynccl_multiple_tp_with_vllm():
+    # this tests pynccl for multiple tp groups, together with vllm
+    # i.e. call `tensor_model_parallel_all_reduce`
+    distributed_run(multiple_tp_with_vllm_worker_fn, 4)
 
 
 @worker_fn_wrapper
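
Note: the new multiple_tp_worker_fn and tests above capture the core behavior this commit enables: four ranks split into two tensor-parallel groups of two, each group all-reducing independently. For reference, here is the same pattern as a minimal standalone sketch using plain torch.distributed (gloo backend) rather than vLLM's NCCLCommunicator; the file name and torchrun launch line are illustrative only.

# sketch_two_tp_groups.py -- hypothetical standalone illustration, not vLLM code.
# Launch with: torchrun --nproc_per_node=4 sketch_two_tp_groups.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # env:// rendezvous provided by torchrun
    rank = dist.get_rank()
    # Every rank must create every group, in the same order.
    group_a = dist.new_group([0, 1])
    group_b = dist.new_group([2, 3])
    my_group = group_a if rank in (0, 1) else group_b
    tensor = torch.ones(4)
    # The sum only runs over the caller's 2-rank group, so the mean is 2, not 4:
    # the two tensor-parallel groups communicate independently.
    dist.all_reduce(tensor, group=my_group)
    assert tensor.mean().item() == 2
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
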
1 change: 0 additions & 1 deletion vllm/distributed/communication_op.py
@@ -34,7 +34,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     if out is not None:
         return out
     if is_pynccl_enabled_for_all_reduce():
-        # TODO: support multiple parallel groups.
         pynccl_utils.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
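
Note: the branch above is toggled from the caller side by the with_pynccl_for_all_reduce() context manager that the new test exercises. A minimal usage sketch (reduce_twice is an illustrative helper, not a vLLM function, and it assumes the distributed/TP/pynccl setup shown in the test above has already run):

from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import with_pynccl_for_all_reduce


def reduce_twice(tensor):
    # Inside the context, is_pynccl_enabled_for_all_reduce() holds and the
    # all-reduce runs through pynccl on this rank's tensor-parallel group.
    with with_pynccl_for_all_reduce():
        tensor = tensor_model_parallel_all_reduce(tensor)
    # Outside the context, the call presumably falls back to the custom
    # all-reduce (when enabled) or torch.distributed on the TP device group.
    return tensor_model_parallel_all_reduce(tensor)
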
36 changes: 25 additions & 11 deletions vllm/distributed/parallel_state.py
@@ -14,7 +14,8 @@
 logger = init_logger(__name__)
 
 # Tensor model parallel group that the current rank belongs to.
-_TENSOR_MODEL_PARALLEL_GROUP = None
+_TP_DEVICE_GROUP = None
+_TP_CPU_GROUP = None
 # Pipeline model parallel group that the current rank belongs to.
 _PIPELINE_MODEL_PARALLEL_GROUP = None
 
@@ -132,15 +133,17 @@ def initialize_model_parallel(
     rank = torch.distributed.get_rank()
 
     # Build the tensor model-parallel groups.
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
+    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
+    assert _TP_DEVICE_GROUP is None, (
         "tensor model parallel group is already initialized")
     for i in range(num_tensor_model_parallel_groups):
         ranks = range(i * tensor_model_parallel_size,
                       (i + 1) * tensor_model_parallel_size)
         group = torch.distributed.new_group(ranks, backend=backend)
+        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
         if rank in ranks:
-            _TENSOR_MODEL_PARALLEL_GROUP = group
+            _TP_DEVICE_GROUP = group
+            _TP_CPU_GROUP = cpu_group
 
     # Build the pipeline model-parallel groups.
     global _PIPELINE_MODEL_PARALLEL_GROUP
@@ -185,7 +188,7 @@ def ensure_model_parallel_initialized(
 
 def model_parallel_is_initialized():
     """Check if tensor and pipeline parallel groups are initialized."""
-    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
+    return (_TP_DEVICE_GROUP is not None
             and _PIPELINE_MODEL_PARALLEL_GROUP is not None)
 
 
@@ -197,9 +200,16 @@ def get_cpu_world_group():
 
 def get_tensor_model_parallel_group():
     """Get the tensor model parallel group the caller rank belongs to."""
-    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
+    assert _TP_DEVICE_GROUP is not None, (
         "tensor model parallel group is not initialized")
-    return _TENSOR_MODEL_PARALLEL_GROUP
+    return _TP_DEVICE_GROUP
+
+
+def get_tensor_model_parallel_cpu_group():
+    """Get the tensor model parallel cpu group the caller rank belongs to."""
+    assert _TP_CPU_GROUP is not None, (
+        "tensor model parallel cpu group is not initialized")
+    return _TP_CPU_GROUP
 
 
 def get_pipeline_model_parallel_group():
@@ -277,10 +287,14 @@ def get_pipeline_model_parallel_prev_rank():
 
 def destroy_model_parallel():
     """Set the groups to none and destroy them."""
-    global _TENSOR_MODEL_PARALLEL_GROUP
-    if _TENSOR_MODEL_PARALLEL_GROUP:
-        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
-    _TENSOR_MODEL_PARALLEL_GROUP = None
+    global _TP_DEVICE_GROUP
+    if _TP_DEVICE_GROUP:
+        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
+    _TP_DEVICE_GROUP = None
+    global _TP_CPU_GROUP
+    if _TP_CPU_GROUP:
+        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
+    _TP_CPU_GROUP = None
     global _PIPELINE_MODEL_PARALLEL_GROUP
     if _PIPELINE_MODEL_PARALLEL_GROUP:
         torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
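
Note: splitting _TENSOR_MODEL_PARALLEL_GROUP into a per-TP-group NCCL device group and a gloo CPU group is what lets pynccl be initialized per tensor-parallel group (see worker.py below, where the CPU group is passed to pynccl_utils.init_process_group). A plausible reason for carrying a CPU group is to exchange small bootstrap objects such as the NCCL unique id without depending on the device backend. A hypothetical sketch of that pattern, assuming PyTorch >= 2.0 and a stand-in make_unique_id() in place of ncclGetUniqueId:

import torch.distributed as dist


def exchange_unique_id(cpu_group, make_unique_id):
    # The first rank of the tensor-parallel group creates the id; the others
    # receive it over the gloo CPU group, so NCCL is not needed for its own
    # bootstrap.
    ranks = dist.get_process_group_ranks(cpu_group)  # global ranks in this group
    src = ranks[0]
    holder = [make_unique_id() if dist.get_rank() == src else None]
    dist.broadcast_object_list(holder, src=src, group=cpu_group)
    return holder[0]
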
13 changes: 7 additions & 6 deletions vllm/worker/worker.py
@@ -11,6 +11,7 @@
                          VisionLanguageConfig)
 from vllm.distributed import (broadcast_tensor_dict,
                               ensure_model_parallel_initialized,
+                              get_tensor_model_parallel_cpu_group,
                               init_distributed_environment)
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -288,6 +289,9 @@ def init_worker_distributed_environment(
     init_distributed_environment(parallel_config.world_size, rank,
                                  distributed_init_method, local_rank)
 
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
     if pynccl_utils.is_initialized():
         pynccl_world_size = pynccl_utils.get_world_size()
         if pynccl_world_size != parallel_config.world_size:
@@ -298,12 +302,9 @@ def init_worker_distributed_environment(
     elif parallel_config.world_size > 1:
         # NOTE(woosuk): We don't initialize pynccl process group when world size
         # is 1.
-        # NOTE(kaichao): By default, pynccl will use information inside
-        # `parallel_state` for initialization.
-        pynccl_utils.init_process_group()
-
-    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
-                                      parallel_config.pipeline_parallel_size)
+        # NOTE(kaichao): By default, pynccl is initialized for tp group.
+        pynccl_utils.init_process_group(
+            group=get_tensor_model_parallel_cpu_group())
 
     # Initialize a custom fast all-reduce implementation.
     if not parallel_config.disable_custom_all_reduce:
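
Note: the net effect of the worker change is an ordering constraint: the tensor-parallel groups (including the new CPU group) must exist before pynccl is initialized, because pynccl is now scoped to the TP group rather than to the world. A condensed sketch of the resulting call order, reusing the vLLM helpers from the diff (the wrapper function and its hard-coded 2/2 sizes are illustrative):

from vllm.distributed import (ensure_model_parallel_initialized,
                              get_tensor_model_parallel_cpu_group,
                              init_distributed_environment)
from vllm.distributed.device_communicators import pynccl_utils


def init_worker_example(world_size, rank, init_method, local_rank):
    # 1. World-level process group.
    init_distributed_environment(world_size, rank, init_method, local_rank)
    # 2. TP/PP groups (device + CPU) -- created *before* pynccl from now on.
    ensure_model_parallel_initialized(2, 2)
    # 3. pynccl bootstraps over this rank's TP CPU (gloo) group.
    if world_size > 1:
        pynccl_utils.init_process_group(
            group=get_tensor_model_parallel_cpu_group())
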
