2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -164,6 +164,7 @@ steps:
- tests/v1/test_internal_lb_dp.py
- tests/v1/test_hybrid_lb_dp.py
- tests/v1/engine/test_engine_core_client.py
- tests/distributed/test_symm_mem_allreduce.py
commands:
# test with torchrun tp=2 and external_dp=2
- torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -188,6 +189,7 @@ steps:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s distributed/test_events.py
- pytest -v -s distributed/test_symm_mem_allreduce.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
- pushd ../examples/offline_inference
59 changes: 47 additions & 12 deletions tests/distributed/test_symm_mem_allreduce.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import queue
import random
import typing

@@ -10,26 +11,31 @@
import torch.multiprocessing as mp

import vllm.envs as envs
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.device_communicators.cuda_communicator import (
CudaCommunicator)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
get_tp_group,
from vllm.distributed.parallel_state import (get_tp_group,
init_distributed_environment,
initialize_model_parallel)
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

torch.manual_seed(42)
random.seed(44)

test_size_elements = 4 * 1024 * 1024
test_size_elements = 1024 * 1024


def symm_mem_allreduce_worker(local_rank: int, world_size: int):
def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
monkeypatch = pytest.MonkeyPatch()
with monkeypatch.context() as m:
config = VllmConfig(parallel_config=ParallelConfig(
tensor_parallel_size=world_size))

with monkeypatch.context() as m, set_current_vllm_config(config):
m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
dtype = torch.bfloat16
device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
get_tp_group().device_communicator)
symm_mem_comm = cuda_communicator.symm_mem_comm
if symm_mem_comm is None or symm_mem_comm.disabled:
pytest.skip("SymmMemCommunicator is not available or disabled.")
# can't use skip under multiprocessing
q.put("SymmMemCommunicator is not available or disabled.")
return

inp_direct_symm_mem = torch.randint(1,
23, (test_size_elements, ),
dtype=dtype,
device=device)
if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
pytest.skip(
# can't use skip under multiprocessing
q.put(
"SymmMemCommunicator isn't used for this world and input size."
)
return

original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
assert out_direct_symm_mem is not None

group = get_tensor_model_parallel_group().device_group
group = get_tp_group().device_group
dist.all_reduce(original_inp_direct_symm_mem, group=group)
torch.testing.assert_close(out_direct_symm_mem,
original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
world_size = tp_size * pipeline_parallel_size
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
q = mp.get_context('spawn').Queue()
mp.spawn(symm_mem_allreduce_worker,
args=(world_size, q),
nprocs=world_size)
try:
val = q.get(timeout=1)
except queue.Empty:
val = None
finally:
cleanup_dist_env_and_memory()
if val is not None:
pytest.skip(val)

# Enable SymmMemCommunicator
monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
cleanup_dist_env_and_memory()
@pytest.mark.skipif(
not current_platform.is_cuda(),
reason="SymmMemAllreduce is only available for CUDA platforms.")
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
world_size = 4
if world_size > torch.cuda.device_count():
pytest.skip("Not enough GPUs to run the test.")
# Verify that data-parallel execution with symm mem allreduce runs without error
engine_args = EngineArgs(model="distilbert/distilgpt2",
enforce_eager=True,
enable_prefix_caching=True,
data_parallel_size=2,
tensor_parallel_size=2,
data_parallel_backend="mp")
LLMEngine.from_engine_args(engine_args)
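
The worker can no longer call pytest.skip once it runs under mp.spawn: the skip exception is raised in the child process and never reaches the pytest session in the parent. The pattern adopted above hands the skip reason back through a multiprocessing queue and lets the parent decide whether to skip. A minimal, self-contained sketch of that pattern (the worker body and names here are illustrative, not the vLLM test itself):

# Sketch of the queue-based skip pattern used in the test above.
import queue

import torch.multiprocessing as mp


def _worker(local_rank: int, world_size: int, q: mp.Queue):
    feature_available = False  # placeholder for the real capability check
    if not feature_available:
        # pytest.skip() would only abort this subprocess, so report instead
        q.put("feature is not available in this subprocess")
        return
    # ... run the real assertions here ...


def run_or_collect_skip_reason(world_size: int = 2) -> None:
    q = mp.get_context("spawn").Queue()
    mp.spawn(_worker, args=(world_size, q), nprocs=world_size)
    try:
        reason = q.get(timeout=1)
    except queue.Empty:
        reason = None
    if reason is not None:
        print(f"skipping: {reason}")  # the real test calls pytest.skip(reason)
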
7 changes: 5 additions & 2 deletions vllm/distributed/device_communicators/cuda_communicator.py
@@ -24,18 +24,21 @@ def __init__(self,
unique_name: str = ""):
super().__init__(cpu_group, device, device_group, unique_name)
if "tp" not in unique_name:
# only tp uses custom allreduce
# custom allreduce or torch symm mem can be used only by tp
use_custom_allreduce = False
use_torch_symm_mem = False
else:
from vllm.distributed.parallel_state import (
_ENABLE_CUSTOM_ALL_REDUCE)
use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM

# ep does not use pynccl
use_pynccl = "ep" not in unique_name

self.use_pynccl = use_pynccl
self.use_custom_allreduce = use_custom_allreduce
self.use_torch_symm_mem = use_torch_symm_mem

# lazy import to avoid documentation build error
from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -57,7 +60,7 @@ def __init__(self,
self.ca_comm: Optional[CustomAllreduce] = None
self.qr_comm: Optional[QuickAllReduce] = None
self.symm_mem_comm: Optional[SymmMemCommunicator] = None
if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
if use_torch_symm_mem and current_platform.is_cuda():
self.symm_mem_comm = SymmMemCommunicator(
group=self.cpu_group,
device=self.device,
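
A condensed view of the gate introduced above, as a sketch with simplified names (the real constructor also wires up pynccl, custom all-reduce, and quick all-reduce): the symmetric-memory communicator is now built only for tensor-parallel groups, and only when the env flag is set on a CUDA platform.

# Sketch of the new gating logic in CudaCommunicator.__init__ (simplified).
def should_build_symm_mem(unique_name: str, env_flag: bool, is_cuda: bool) -> bool:
    # symm mem, like custom allreduce, is restricted to tp groups
    use_torch_symm_mem = env_flag if "tp" in unique_name else False
    return use_torch_symm_mem and is_cuda
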
4 changes: 2 additions & 2 deletions vllm/envs.py
@@ -182,7 +182,7 @@
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
@@ -1368,7 +1368,7 @@ def get_vllm_port() -> Optional[int]:

# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM":
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),

# Allows vllm to find tuned config under customized folder
"VLLM_TUNED_CONFIG_FOLDER":
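
With the default flipped to enabled, symmetric-memory all-reduce becomes opt-out rather than opt-in. The flag is parsed as an integer-valued environment variable, so it can still be disabled per run, e.g. VLLM_ALLREDUCE_USE_SYMM_MEM=0. A minimal reproduction of the parsing above:

# How the new default behaves: unset or "1" -> True, "0" -> False.
import os

use_symm_mem = bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")))
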