Commit 8bdd8b5

ilmarkov and mgoin authored
Enable symmetric memory all reduce by default only enabling for TP (#25070)
Signed-off-by: ilmarkov <markovilya197@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent a8ffc4f commit 8bdd8b5
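
With this commit, torch symmetric-memory all-reduce defaults to on, but it is only wired into the tensor-parallel communicator; other groups ("pp", "dp", "ep") are unaffected. A minimal opt-out sketch (not part of the diff; the model name and TP size are placeholders, and the flag is assumed to be set before the engine constructs its communicators):

import os

# Opt out of the new default before vLLM spins up its communicators.
os.environ["VLLM_ALLREDUCE_USE_SYMM_MEM"] = "0"

from vllm import LLM

# A tensor-parallel run on CUDA would otherwise pick up symmetric-memory
# all-reduce automatically; with the override above it falls back to the
# existing all-reduce backends.
llm = LLM(model="distilbert/distilgpt2",
          tensor_parallel_size=2,
          enforce_eager=True)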

File tree: 4 files changed, +56 -16 lines

.buildkite/test-pipeline.yaml
Lines changed: 2 additions & 0 deletions

@@ -164,6 +164,7 @@ steps:
   - tests/v1/test_internal_lb_dp.py
   - tests/v1/test_hybrid_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
+  - tests/distributed/test_symm_mem_allreduce.py
   commands:
   # test with torchrun tp=2 and external_dp=2
   - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -188,6 +189,7 @@ steps:
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s distributed/test_events.py
+  - pytest -v -s distributed/test_symm_mem_allreduce.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - pushd ../examples/offline_inference

tests/distributed/test_symm_mem_allreduce.py
Lines changed: 47 additions & 12 deletions

@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import queue
 import random
 import typing
 
@@ -10,26 +11,31 @@
 import torch.multiprocessing as mp
 
 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 
 torch.manual_seed(42)
 random.seed(44)
 
-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024
 
 
-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                         get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return
 
         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return
 
         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None
 
-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+    if val is not None:
+        pytest.skip(val)
 
-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
 
-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)
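
The reworked test avoids calling pytest.skip() inside spawned workers: a skip raised in a child process propagates out of mp.spawn as a failure rather than a skip, so the worker reports the reason through a queue and the parent decides. A condensed sketch of that pattern, using a hypothetical _worker and run_check in place of the test's own functions:

import queue

import torch
import torch.multiprocessing as mp


def _worker(local_rank: int, world_size: int, q: mp.Queue):
    # A child process cannot pytest.skip(); it reports the reason instead.
    if not torch.cuda.is_available():
        q.put("CUDA not available in this environment")
        return
    # ... the real collective checks would run here ...


def run_check(world_size: int = 2) -> None:
    q = mp.get_context("spawn").Queue()
    # mp.spawn calls _worker(rank, world_size, q) once per process.
    mp.spawn(_worker, args=(world_size, q), nprocs=world_size)
    try:
        reason = q.get(timeout=1)
    except queue.Empty:
        reason = None
    if reason is not None:
        print(f"would skip: {reason}")  # the test calls pytest.skip(reason)


if __name__ == "__main__":
    run_check()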

vllm/distributed/device_communicators/cuda_communicator.py
Lines changed: 5 additions & 2 deletions

@@ -30,18 +30,21 @@ def __init__(self,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
         if "tp" not in unique_name:
-            # only tp uses custom allreduce
+            # custom allreduce or torch symm mem can be used only by tp
             use_custom_allreduce = False
+            use_torch_symm_mem = False
         else:
             from vllm.distributed.parallel_state import (
                 _ENABLE_CUSTOM_ALL_REDUCE)
             use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+            use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM
 
         # ep does not use pynccl
         use_pynccl = "ep" not in unique_name
 
         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem
 
         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -65,7 +68,7 @@ def __init__(self,
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+        if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
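
The communicator change gates the new path the same way custom all-reduce is already gated: only the tensor-parallel group ever enables it, and only when the environment flag is set. A rough standalone paraphrase of that decision (the helper name is made up for illustration; in vLLM the logic lives inline in CudaCommunicator.__init__):

def allreduce_backends_for_group(unique_name: str,
                                 custom_allreduce_enabled: bool,
                                 symm_mem_flag: bool) -> dict[str, bool]:
    """Decide which extra all-reduce backends a communicator may use.

    Mirrors the intent of the diff above: groups whose name does not
    contain "tp" (e.g. "pp", "dp", "ep") never enable custom all-reduce
    or torch symmetric-memory all-reduce.
    """
    is_tp = "tp" in unique_name
    return {
        "use_custom_allreduce": is_tp and custom_allreduce_enabled,
        "use_torch_symm_mem": is_tp and symm_mem_flag,
    }


# Example: a data-parallel group keeps both disabled even with the flag on.
assert allreduce_backends_for_group("dp", True, True) == {
    "use_custom_allreduce": False,
    "use_torch_symm_mem": False,
}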

vllm/envs.py
Lines changed: 2 additions & 2 deletions

@@ -182,7 +182,7 @@
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
@@ -1370,7 +1370,7 @@ def get_vllm_port() -> Optional[int]:
 
     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
-    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
+    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),
 
     # Allows vllm to find tuned config under customized folder
     "VLLM_TUNED_CONFIG_FOLDER":
