
Commit 8c1c81a

Amir-19 and mgoin authored
[core] add nccl symmetric memory for all reduce (#24532)
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent a3a7828 commit 8c1c81a
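
The change is opt-in through environment variables that recur throughout the diffs below. A minimal sketch of enabling the new path from Python; the three variable names come from this commit, while the model name and tensor-parallel size are placeholders:

    import os

    # Opt in to the NCCL symmetric-memory all-reduce path added by this commit.
    os.environ["VLLM_USE_NCCL_SYMM_MEM"] = "1"
    os.environ["NCCL_NVLS_ENABLE"] = "1"   # allow NCCL's NVLS algorithm
    os.environ["NCCL_CUMEM_ENABLE"] = "1"  # cuMem allocations, required for this path

    from vllm import LLM

    # Placeholder model; world sizes below 4 fall back to the existing paths
    # (see NCCL_SYMM_MEM_ALL_REDUCE_CONFIG in all_reduce_utils.py below).
    llm = LLM(model="<your-model>", tensor_parallel_size=4)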

File tree

12 files changed (+489, -6 lines)


.buildkite/test-pipeline.yaml

Lines changed: 1 addition & 0 deletions
@@ -1039,3 +1039,4 @@ steps:
   num_gpus: 2
   commands:
     - pytest -v -s tests/distributed/test_context_parallel.py
+    - pytest -v -s tests/distributed/test_nccl_symm_mem_allreduce.py

benchmarks/kernels/benchmark_device_communicators.py

Lines changed: 24 additions & 2 deletions
@@ -7,6 +7,10 @@
 CustomAllreduce (oneshot, twoshot), PyNcclCommunicator,
 and SymmMemCommunicator (multimem, two-shot).
 
+For NCCL symmetric memory you need to set the environment variables
+NCCL_NVLS_ENABLE=1 NCCL_CUMEM_ENABLE=1 VLLM_USE_NCCL_SYMM_MEM=1; otherwise
+NCCL will not use its fast NVLS implementation for all-reduce.
+
 Usage:
 torchrun --nproc_per_node=<N> benchmark_device_communicators.py [options]
@@ -26,7 +30,13 @@
 from torch.distributed import ProcessGroup
 
 from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce
-from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
+from vllm.distributed.device_communicators.pynccl import (
+    PyNcclCommunicator,
+    register_nccl_symmetric_ops,
+)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id,
+)
 from vllm.distributed.device_communicators.symm_mem import SymmMemCommunicator
 from vllm.logger import init_logger
 from vllm.utils import FlexibleArgumentParser
@@ -98,6 +108,7 @@ def _init_communicators(self):
         )
         if not self.pynccl_comm.disabled:
             logger.info("Rank %s: PyNcclCommunicator initialized", self.rank)
+            register_nccl_symmetric_ops(self.pynccl_comm)
         else:
             logger.info("Rank %s: PyNcclCommunicator disabled", self.rank)
             self.pynccl_comm = None
@@ -194,6 +205,15 @@ def benchmark_allreduce(
                     None,  # no env variable needed
                 )
             )
+            communicators.append(
+                (
+                    "pynccl-symm",
+                    lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t),
+                    lambda t: True,  # Always available if initialized
+                    nullcontext(),
+                    None,  # no env variable needed
+                )
+            )
 
         if self.symm_mem_comm_multimem is not None:
             comm = self.symm_mem_comm_multimem
@@ -271,7 +291,9 @@ def benchmark_allreduce_single(
         # Capture the graph using context manager
         with context:
             graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(graph):
+            graph_pool = torch.cuda.graph_pool_handle()
+            set_graph_pool_id(graph_pool)
+            with torch.cuda.graph(graph, pool=graph_pool):
                 for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                     allreduce_fn(graph_input)
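
The new "pynccl-symm" entry benchmarks the custom op registered by register_nccl_symmetric_ops. A minimal sketch of that call pattern outside the benchmark, assuming the tensor-parallel group is already initialized (the attribute path mirrors the test added below):

    import torch

    from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
    from vllm.distributed.parallel_state import get_tp_group

    pynccl_comm = get_tp_group().device_communicator.pynccl_comm
    register_nccl_symmetric_ops(pynccl_comm)  # registers torch.ops.vllm.all_reduce_symmetric_with_copy

    x = torch.ones(4 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
    y = torch.ops.vllm.all_reduce_symmetric_with_copy(x)  # copy into symmetric buffers, then all-reduce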

tests/distributed/test_nccl_symm_mem_allreduce.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import random
+import typing
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+
+import vllm.envs as envs
+from vllm.distributed import cleanup_dist_env_and_memory
+from vllm.distributed.device_communicators.cuda_communicator import (
+    CudaCommunicator)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    get_nccl_mem_pool, is_symmetric_memory_enabled)
+from vllm.distributed.parallel_state import (get_tp_group,
+                                             init_distributed_environment,
+                                             initialize_model_parallel)
+from vllm.platforms import current_platform
+from vllm.utils import update_environment_variables
+
+torch.manual_seed(42)
+random.seed(44)
+
+test_size_elements = 4 * 1024 * 1024
+
+
+def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
+    monkeypatch = pytest.MonkeyPatch()
+    with monkeypatch.context() as m:
+        m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
+        dtype = torch.bfloat16
+        device = torch.device(f"cuda:{local_rank}")
+        torch.cuda.set_device(device)
+        torch.set_default_device(device)
+        torch.set_default_dtype(dtype)
+        update_environment_variables({
+            "RANK": str(local_rank),
+            "LOCAL_RANK": str(local_rank),
+            "WORLD_SIZE": str(world_size),
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": "12345",
+        })
+
+        init_distributed_environment()
+        initialize_model_parallel(tensor_model_parallel_size=world_size)
+
+        cuda_communicator = typing.cast(CudaCommunicator,
+                                        get_tp_group().device_communicator)
+        pynccl_comm = cuda_communicator.pynccl_comm
+        if get_nccl_mem_pool() is None:
+            pytest.skip("NCCL allocator compilation failed "
+                        "(probably missing NCCL headers).")
+        if not is_symmetric_memory_enabled():
+            pytest.skip("NCCL symmetric memory allreduce is disabled.")
+
+        register_nccl_symmetric_ops(pynccl_comm)
+        input = torch.randint(1,
+                              23, (test_size_elements, ),
+                              dtype=dtype,
+                              device=device)
+        input_clone = input.clone()
+        output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
+        assert output is not None
+
+        group = get_tp_group().device_group
+        dist.all_reduce(input_clone, group=group)
+        torch.testing.assert_close(output, input_clone, atol=2.5, rtol=0.1)
+
+
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
+)
+@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+
+    # Enable SymmMemCommunicator
+    monkeypatch.setenv("VLLM_USE_NCCL_SYMM_MEM", "1")
+    monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
+    monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
+
+    mp.spawn(nccl_symm_mem_allreduce_worker,
+             args=(world_size, ),
+             nprocs=world_size)
+    cleanup_dist_env_and_memory()

vllm/compilation/cuda_graph.py

Lines changed: 6 additions & 0 deletions
@@ -12,6 +12,8 @@
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.monitor import validate_cudagraph_capturing_enabled
 from vllm.config import CUDAGraphMode, VllmConfig
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    set_graph_pool_id)
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -154,6 +156,10 @@ def __call__(self, *args, **kwargs):
                 stack.enter_context(
                     patch("torch.cuda.empty_cache", lambda: None))
 
+            if self.graph_pool is not None:
+                set_graph_pool_id(self.graph_pool)
+            else:
+                set_graph_pool_id(current_platform.graph_pool_handle())
             # mind-exploding: carefully manage the reference and memory.
             with torch.cuda.graph(cudagraph, pool=self.graph_pool):
                 # `output` is managed by pytorch's cudagraph pool
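
Both this file and the benchmark above follow the same capture-time pattern: hand the CUDA graph pool id to the pynccl allocator before capture, presumably so allocations made during capture can be associated with that pool. A condensed sketch of the pattern (my reading of the change, not additional API):

    import torch

    from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id

    graph = torch.cuda.CUDAGraph()
    graph_pool = torch.cuda.graph_pool_handle()
    set_graph_pool_id(graph_pool)  # tell the allocator which pool the capture uses
    with torch.cuda.graph(graph, pool=graph_pool):
        ...  # capture the all-reduce calls here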

vllm/distributed/device_communicators/all_reduce_utils.py

Lines changed: 26 additions & 1 deletion
@@ -10,8 +10,9 @@
 import tempfile
 from collections.abc import Sequence
 from itertools import product
-from typing import Optional
+from typing import Any, Optional
 
+import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
@@ -56,6 +57,30 @@
     }
 }
 
+NCCL_SYMM_MEM_ALL_REDUCE_CONFIG: dict[str, Any] = {
+    "min_world_size": 4,
+    "thresholds": {
+        4: 2 * MiB,  # 2 MB
+        8: 1 * MiB,  # 1 MB
+    },
+    "always_use_above_world_size": 8  # Always use symm mem for world_size > 8
+}
+
+
+def should_nccl_symm_mem_allreduce(world_size: int,
+                                   input_tensor: torch.Tensor) -> bool:
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        is_symmetric_memory_enabled)
+    if not is_symmetric_memory_enabled():
+        return False
+    if world_size < NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["min_world_size"]:
+        return False
+    threshold = NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["thresholds"].get(world_size)
+    if threshold is not None and input_tensor.nbytes >= threshold:
+        return True
+    return (world_size
+            > NCCL_SYMM_MEM_ALL_REDUCE_CONFIG["always_use_above_world_size"])
+
 
 def producer(batch_src: Sequence[int],
              producer_queue,
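
To make the dispatch heuristic concrete, a few worked examples of the thresholds above. This is a self-contained mirror of should_nccl_symm_mem_allreduce (it assumes MiB = 1024 * 1024 and that symmetric memory is enabled) so it runs without a distributed setup:

    MiB = 1024 * 1024
    THRESHOLDS = {4: 2 * MiB, 8: 1 * MiB}

    def use_symm_mem(world_size: int, nbytes: int) -> bool:
        # Mirrors the logic above, with byte counts instead of tensors.
        if world_size < 4:                       # min_world_size
            return False
        threshold = THRESHOLDS.get(world_size)
        if threshold is not None and nbytes >= threshold:
            return True
        return world_size > 8                    # always_use_above_world_size

    assert use_symm_mem(2, 16 * MiB) is False    # below min_world_size
    assert use_symm_mem(4, 1 * MiB) is False     # under the 2 MiB threshold for TP=4
    assert use_symm_mem(4, 4 * MiB) is True      # over the threshold
    assert use_symm_mem(16, 1024) is True        # world_size > 8: always use symm mem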

vllm/distributed/device_communicators/cuda_communicator.py

Lines changed: 15 additions & 0 deletions
@@ -7,6 +7,12 @@
 from torch.distributed import ProcessGroup
 
 import vllm.envs as envs
+from vllm.distributed.device_communicators.all_reduce_utils import (
+    should_nccl_symm_mem_allreduce)
+from vllm.distributed.device_communicators.pynccl import (
+    register_nccl_symmetric_ops)
+from vllm.distributed.device_communicators.pynccl_allocator import (
+    is_symmetric_memory_enabled)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -53,6 +59,8 @@ def __init__(self,
             group=self.cpu_group,
             device=self.device,
         )
+        if is_symmetric_memory_enabled():
+            register_nccl_symmetric_ops(self.pynccl_comm)
 
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
@@ -107,6 +115,13 @@ def __init__(self,
             raise ValueError(f"Unknown all2all backend: {all2all_backend}")
 
     def all_reduce(self, input_):
+        # since we currently copy input -> symm_input, do an out-of-place AR,
+        # and return symm_output, we don't need to check if input is symmetric
+        if self.pynccl_comm is not None and \
+            should_nccl_symm_mem_allreduce(self.pynccl_comm.world_size, input_):
+            out = torch.ops.vllm.all_reduce_symmetric_with_copy(input_)
+            if out is not None:
+                return out
         # always try quick reduce first, then custom allreduce,
         # and then pynccl. (quick reduce just for ROCM MI3*)
         qr_comm = self.qr_comm
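
With this in place, callers of the device communicator do not change; the symmetric-memory path is chosen inside all_reduce whenever should_nccl_symm_mem_allreduce allows it. A sketch, assuming the tensor-parallel group is initialized and the environment variables above are set:

    import torch

    from vllm.distributed.parallel_state import get_tp_group

    x = torch.ones(4 * 1024 * 1024, dtype=torch.bfloat16, device="cuda")
    # Routed to torch.ops.vllm.all_reduce_symmetric_with_copy when the world-size
    # and message-size thresholds are met; otherwise falls back as before.
    y = get_tp_group().device_communicator.all_reduce(x)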

vllm/distributed/device_communicators/pynccl.py

Lines changed: 52 additions & 1 deletion
@@ -17,6 +17,39 @@
 
 logger = init_logger(__name__)
 
+_NCCL_SYMM_OPS_REGISTERED = False
+
+
+def register_nccl_symmetric_ops(pynccl_comm):
+    from vllm.distributed.device_communicators.pynccl_allocator import (
+        nccl_symm_mem_context)
+    from vllm.utils import direct_register_custom_op
+
+    global _NCCL_SYMM_OPS_REGISTERED
+    if _NCCL_SYMM_OPS_REGISTERED:
+        return
+    _NCCL_SYMM_OPS_REGISTERED = True
+
+    def all_reduce_symmetric_with_copy_impl(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        with nccl_symm_mem_context(pynccl_comm):
+            symm_input = torch.empty_like(input_tensor)
+            symm_output = torch.empty_like(input_tensor)
+        symm_input.copy_(input_tensor)
+        symm_output = pynccl_comm.all_reduce(symm_input, symm_output)
+        return symm_output
+
+    def all_reduce_symmetric_with_copy_fake(
+            input_tensor: torch.Tensor) -> torch.Tensor:
+        return torch.empty_like(input_tensor)
+
+    direct_register_custom_op(
+        op_name="all_reduce_symmetric_with_copy",
+        op_func=all_reduce_symmetric_with_copy_impl,
+        mutates_args=[],
+        fake_impl=all_reduce_symmetric_with_copy_fake,
+    )
+
 
 class PyNcclCommunicator:

@@ -67,6 +100,7 @@ def __init__(
         self.available = True
         self.disabled = False
 
+        self.nccl_version = self.nccl.ncclGetRawVersion()
         logger.info("vLLM is using nccl==%s", self.nccl.ncclGetVersion())
 
         if self.rank == 0:
@@ -109,6 +143,7 @@ def __init__(
 
     def all_reduce(self,
                    in_tensor: torch.Tensor,
+                   out_tensor: torch.Tensor = None,
                    op: ReduceOp = ReduceOp.SUM,
                    stream=None) -> torch.Tensor:
         if self.disabled:
@@ -120,7 +155,8 @@ def all_reduce(self,
             f"this nccl communicator is created to work on {self.device}, "
             f"but the input tensor is on {in_tensor.device}")
 
-        out_tensor = torch.empty_like(in_tensor)
+        if out_tensor is None:
+            out_tensor = torch.empty_like(in_tensor)
 
         if stream is None:
             stream = current_stream()
@@ -288,3 +324,18 @@ def group_start(self):
 
     def group_end(self):
         self.nccl.ncclGroupEnd()
+
+    def register_comm_window(self, tensor: torch.Tensor):
+        return self.nccl.ncclCommWindowRegister(
+            self.comm,
+            buffer_type(tensor.data_ptr()),
+            tensor.numel() * tensor.element_size(),
+            1,
+        )
+
+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr),
+                                                size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
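
A short usage note on the widened all_reduce signature: out_tensor is optional, so existing callers keep the old allocate-internally behavior, while the symmetric-memory path can pass a preallocated symmetric output buffer. Sketch, assuming pynccl_comm is an initialized PyNcclCommunicator on the current device:

    import torch

    x = torch.ones(1024, device="cuda")
    out = torch.empty_like(x)
    y = pynccl_comm.all_reduce(x, out)  # reduces into the provided buffer and returns it
    z = pynccl_comm.all_reduce(x)       # unchanged old behavior: output allocated internally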
