[Core][Distributed] refactor pynccl to hold multiple communicators #4591

Merged May 10, 2024 · 44 commits

Commits (the diff below shows changes from all commits):
9c6130a
add cache for loading the same library multiple times
youkaichao May 3, 2024
1493243
refactor code
youkaichao May 3, 2024
cadcd02
fix import
youkaichao May 3, 2024
7918798
remove pynccl_utils.init_process_group
youkaichao May 3, 2024
5924038
remove pynccl_utils.is_initialized
youkaichao May 3, 2024
813b047
remove pynccl_utils.destroy_process_group
youkaichao May 3, 2024
b244e6c
remove pynccl_utils.get_world_size
youkaichao May 3, 2024
7e15c98
remove pynccl_utils.get_nccl_backend
youkaichao May 3, 2024
e610f64
remove is_pynccl_enabled_for_all_reduce
youkaichao May 3, 2024
8480995
remove _ENABLE_PYNCCL_FOR_ALL_REDUCE
youkaichao May 3, 2024
5ed6f07
remove set_pynccl_stream
youkaichao May 3, 2024
8134287
remove pynccl utils
youkaichao May 3, 2024
e65e9ef
fix state
youkaichao May 3, 2024
c8b6fc0
fix test
youkaichao May 3, 2024
c7a2f0c
fix import
youkaichao May 3, 2024
75a8d11
move warmup into pynccl
youkaichao May 4, 2024
59c064e
add device
youkaichao May 4, 2024
16aeef1
fix device for allreduce warmup
youkaichao May 4, 2024
4710fc3
improve ways of discovering default local rank
youkaichao May 4, 2024
c8542ec
make sure warmup happens in stream
youkaichao May 4, 2024
b2d2661
add disable
youkaichao May 4, 2024
67d1d9a
do not init when world size is 1
youkaichao May 4, 2024
c86199c
fix initial state of pynccl allreduce
youkaichao May 4, 2024
0030a31
add comments
youkaichao May 4, 2024
49f6d91
add context manager
youkaichao May 4, 2024
38b148b
refactor logic of available
youkaichao May 4, 2024
d241480
non-intrusive code
youkaichao May 4, 2024
d7209f1
clean up pynccl enable or disable
youkaichao May 4, 2024
7b55026
fix isort
youkaichao May 4, 2024
ee734b1
fix stream attribute
youkaichao May 4, 2024
0516956
fix import
youkaichao May 9, 2024
9f63bf8
rename to PyNcclCommunicator and pynccl_comm
youkaichao May 9, 2024
e9aa766
rename use_pynccl_allreduce
youkaichao May 9, 2024
0f64301
fix lint
youkaichao May 9, 2024
a64962e
fix lint
youkaichao May 9, 2024
d2f83ba
fix lint
youkaichao May 9, 2024
12f309b
fix dependency on custom_all_reduce
youkaichao May 9, 2024
68e448c
fix lint
youkaichao May 9, 2024
ad6f840
use _PP_DEVICE_GROUP
youkaichao May 9, 2024
e2153b2
use _PP_GLOBAL_RANKS
youkaichao May 9, 2024
80aca94
fix lint
youkaichao May 9, 2024
c1b1cdb
use change_state rather than enable
youkaichao May 9, 2024
c4e3b0f
Merge branch 'main' into bind_pynccl_to_group
youkaichao May 9, 2024
70a7e26
add get_tp_pynccl_communicator
youkaichao May 9, 2024
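
Taken together, these commits remove the module-level pynccl state (pynccl_utils.init_process_group, the global enable flag, set_pynccl_stream) and let each process group own a PyNcclCommunicator whose change_state context manager controls whether all-reduce calls go through NCCL and on which stream. The standalone sketch below mimics that enable/disable pattern; it is illustrative only (ToyCommunicator, group_name, and the list-based all_reduce are made up, not vLLM code), but it shows why a per-object context manager can replace a global flag once several communicators coexist in one process:

from contextlib import contextmanager


class ToyCommunicator:
    """Illustrative stand-in for a per-group communicator object."""

    def __init__(self, group_name: str):
        self.group_name = group_name
        # Start disabled so ordinary eager code is not silently rerouted;
        # callers opt in through change_state().
        self.disabled = True

    @contextmanager
    def change_state(self, enable: bool):
        # Temporarily flip the enabled/disabled state, then restore it.
        old_disabled = self.disabled
        self.disabled = not enable
        try:
            yield self
        finally:
            self.disabled = old_disabled

    def all_reduce(self, values: list) -> list:
        # Single-process stand-in; the real class would launch a NCCL kernel.
        assert not self.disabled, "all_reduce called while disabled"
        return values


# Two independent "groups", each holding its own communicator object.
tp_comm = ToyCommunicator("tp")
pp_comm = ToyCommunicator("pp")
with tp_comm.change_state(enable=True):
    tp_comm.all_reduce([1.0, 2.0])  # enabled only inside the context
assert tp_comm.disabled and pp_comm.disabled
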
tests/distributed/test_pynccl.py (78 changes: 42 additions & 36 deletions)

@@ -1,15 +1,15 @@
 import multiprocessing
+import os

 import pytest
 import torch

-import vllm.distributed.device_communicators.pynccl_utils as pynccl_utils
-from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
-from vllm.distributed.device_communicators.pynccl import (NCCLCommunicator,
-                                                           ncclGetUniqueId)
-from vllm.distributed.parallel_state import (
-    ensure_model_parallel_initialized, get_tensor_model_parallel_cpu_group,
-    init_distributed_environment, with_pynccl_for_all_reduce)
+from vllm.distributed.communication_op import (  # noqa
+    graph_capture_mode, tensor_model_parallel_all_reduce)
+from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
+from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
+                                             init_distributed_environment)
 from vllm.utils import update_environment_variables


@@ -41,6 +41,9 @@ def worker_fn_wrapper(fn):
     # and update the environment variables in the function
     def wrapped_fn(env):
         update_environment_variables(env)
+        local_rank = os.environ['LOCAL_RANK']
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
         init_distributed_environment()
         fn()

@@ -49,11 +52,13 @@ def wrapped_fn(env):

 @worker_fn_wrapper
 def worker_fn():
-    comm = NCCLCommunicator()
-    tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(comm.rank)
-    comm.all_reduce(tensor)
+    pynccl_comm = PyNcclCommunicator()
+    tensor = torch.ones(16, 1024, 1024,
+                        dtype=torch.float32).cuda(pynccl_comm.rank)
+    with pynccl_comm.change_state(enable=True):
+        pynccl_comm.all_reduce(tensor)
     result = tensor.mean().cpu().item()
-    assert result == comm.world_size
+    assert result == pynccl_comm.world_size


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -70,37 +75,35 @@ def multiple_tp_worker_fn():
         torch.distributed.new_group(ranks=[2, 3], backend="gloo")
     ]
     group = groups[0] if torch.distributed.get_rank() in [0, 1] else groups[1]
-    comm = NCCLCommunicator(group=group, device=device)
+    pynccl_comm = PyNcclCommunicator(group=group, device=device)
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    # two groups can communicate independently
-    if torch.distributed.get_rank() in [0, 1]:
-        comm.all_reduce(tensor)
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 4
-    else:
-        comm.all_reduce(tensor)
-        result = tensor.mean().cpu().item()
-        assert result == 2
+    with pynccl_comm.change_state(enable=True):
+        # two groups can communicate independently
+        if torch.distributed.get_rank() in [0, 1]:
+            pynccl_comm.all_reduce(tensor)
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 4
+        else:
+            pynccl_comm.all_reduce(tensor)
+            result = tensor.mean().cpu().item()
+            assert result == 2


 @pytest.mark.skipif(torch.cuda.device_count() < 4,
                     reason="Need at least 4 GPUs to run the test.")
 def test_pynccl_multiple_tp():
     # this tests pynccl for multiple tp groups, in a standalone way
-    # i.e. call `comm.all_reduce` directly
+    # i.e. call `pynccl_comm.all_reduce` directly
     distributed_run(multiple_tp_worker_fn, 4)


 @worker_fn_wrapper
 def multiple_tp_with_vllm_worker_fn():
     device = torch.device(f"cuda:{torch.distributed.get_rank()}")
-    torch.cuda.set_device(torch.distributed.get_rank())
     ensure_model_parallel_initialized(2, 2)
-    pynccl_utils.init_process_group(
-        group=get_tensor_model_parallel_cpu_group())
     tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
-    with with_pynccl_for_all_reduce():
+    with graph_capture_mode():
         # two tp groups can communicate independently
         if torch.distributed.get_rank() in [0, 1]:
             tensor = tensor_model_parallel_all_reduce(tensor)
@@ -125,19 +128,21 @@ def test_pynccl_multiple_tp_with_vllm():
 def worker_fn_with_cudagraph():
     with torch.no_grad():
         graph = torch.cuda.CUDAGraph()
-        comm = NCCLCommunicator()
+        pynccl_comm = PyNcclCommunicator()
         # run something in the default stream to initialize torch engine
-        a = torch.ones((4, 4), device=f'cuda:{comm.rank}')
+        a = torch.ones((4, 4), device=f'cuda:{pynccl_comm.rank}')
         torch.cuda.synchronize()
-        with torch.cuda.graph(graph, stream=comm.stream):
+        with torch.cuda.graph(
+                graph, stream=pynccl_comm.stream), pynccl_comm.change_state(
+                    enable=True):
             # operation during the graph capture is recorded but not executed
             # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture  # noqa
-            comm.all_reduce(a)
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**0
+            pynccl_comm.all_reduce(a)
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**0
         graph.replay()
-        comm.stream.synchronize()
-        assert a.mean().cpu().item() == comm.world_size**1
+        pynccl_comm.stream.synchronize()
+        assert a.mean().cpu().item() == pynccl_comm.world_size**1


 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -147,7 +152,8 @@ def test_pynccl_with_cudagraph():


 def test_ncclGetUniqueId():
-    unique_id = ncclGetUniqueId()
+    lib = NCCLLibrary()
+    unique_id = lib.ncclGetUniqueId()
     # `list(unique_id.internal)` is something like this:
     # [34, -16, 23, 83, 109, -19, 59, 95, 2, 0, -86, 55, 10, -128, 0, 29, 0,
     # 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
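
Aside: the last hunk above switches the unique-id test from a module-level ncclGetUniqueId() function to an NCCLLibrary() wrapper object. In NCCL generally, one rank generates such a unique id and shares the raw bytes with every other rank of the group before the communicator is constructed. The sketch below shows that standard bootstrap over an existing CPU (gloo) process group; it is a hedged illustration, not vLLM's implementation (share_unique_id, make_unique_id, and id_size are made-up names):

import torch
import torch.distributed as dist


def share_unique_id(make_unique_id, id_size: int, group=None) -> bytes:
    """Rank 0 generates an opaque id and broadcasts its bytes to all ranks."""
    rank = dist.get_rank(group=group)
    if rank == 0:
        payload = make_unique_id()   # raw bytes of the freshly created id
        assert len(payload) == id_size
    else:
        payload = bytes(id_size)     # placeholder buffer, overwritten below
    buf = torch.tensor(list(payload), dtype=torch.uint8)
    dist.broadcast(buf, src=0, group=group)  # gloo group, CPU tensors are fine
    return bytes(buf.tolist())
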
vllm/distributed/communication_op.py (28 changes: 24 additions & 4 deletions)

@@ -1,4 +1,5 @@
 from collections import namedtuple
+from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
@@ -8,7 +9,26 @@
     get_tensor_model_parallel_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    is_pynccl_enabled_for_all_reduce)
+    get_tp_pynccl_communicator)
+
+
+@contextmanager
+def graph_capture_mode():
+    # In graph capture, we have to be very careful about the collective
+    # operations. The current status is:
+    #     allreduce \ Mode   |  Eager  |  Graph  |
+    # --------------------------------------------
+    # custom allreduce       | enabled | enabled |
+    # PyNccl                 | disabled| enabled |
+    # torch.distributed      | enabled | disabled|
+    #
+    # Note that custom allreduce will have a runtime check, if the tensor size
+    # is too large, it will fallback to the next available option.
+    pynccl_comm = get_tp_pynccl_communicator()
+    assert pynccl_comm is not None
+    with pynccl_comm.change_state(enable=True,
+                                  stream=torch.cuda.current_stream()):
+        yield


 def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
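
The table in the comment above is the reason this context manager exists: during CUDA graph capture, torch.distributed all-reduce is disabled while PyNccl is enabled, so capture code opts into PyNccl for the duration of the capture. A minimal usage sketch, assuming init_distributed_environment() and a tensor-parallel group with world size > 1 have already been set up (otherwise get_tp_pynccl_communicator() returns None and the assert fires); the wrapper function name is illustrative:

import torch

from vllm.distributed.communication_op import (graph_capture_mode,
                                               tensor_model_parallel_all_reduce)


def reduce_in_capture_mode(x: torch.Tensor) -> torch.Tensor:
    # Inside the context the TP PyNccl communicator is enabled on the current
    # stream, so the all-reduce routes through it; outside the context it is
    # disabled and torch.distributed handles the reduction instead.
    with graph_capture_mode():
        return tensor_model_parallel_all_reduce(x)
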
@@ -23,7 +43,6 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     TLDR: always assume this function modifies its input, but use the return
     value as the output.
     """
-    from vllm.distributed.device_communicators import pynccl_utils
     from vllm.distributed.device_communicators.custom_all_reduce import (
         custom_all_reduce)

@@ -33,8 +52,9 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
     out = custom_all_reduce(input_)
     if out is not None:
         return out
-    if is_pynccl_enabled_for_all_reduce():
-        pynccl_utils.all_reduce(input_)
+    pynccl_comm = get_tp_pynccl_communicator()
+    if (pynccl_comm is not None and not pynccl_comm.disabled):
+        pynccl_comm.all_reduce(input_)
     else:
         torch.distributed.all_reduce(input_,
                                      group=get_tensor_model_parallel_group())
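
For reference, the dispatch order this hunk implements is: try the custom all-reduce kernel first (it may decline and return None, e.g. when the tensor is too large), then PyNccl when a TP communicator exists and is enabled, and finally plain torch.distributed. Distilled into a standalone sketch (the callables and parameter names are stand-ins, not vLLM APIs):

from typing import Callable, Optional

import torch


def dispatch_all_reduce(
    x: torch.Tensor,
    try_custom: Callable[[torch.Tensor], Optional[torch.Tensor]],
    pynccl_comm,  # object exposing .disabled and .all_reduce(tensor), or None
    torch_all_reduce: Callable[[torch.Tensor], None],
) -> torch.Tensor:
    out = try_custom(x)
    if out is not None:  # the custom kernel handled it
        return out
    if pynccl_comm is not None and not pynccl_comm.disabled:
        pynccl_comm.all_reduce(x)  # in-place, mirroring the vLLM path
    else:
        torch_all_reduce(x)        # default torch.distributed path
    return x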