diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index c4ea4b675649..955592588506 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -164,6 +164,7 @@ steps:
   - tests/v1/test_internal_lb_dp.py
   - tests/v1/test_hybrid_lb_dp.py
   - tests/v1/engine/test_engine_core_client.py
+  - tests/distributed/test_symm_mem_allreduce.py
   commands:
   # test with torchrun tp=2 and external_dp=2
   - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
@@ -188,6 +189,7 @@ steps:
   - pytest -v -s compile/test_basic_correctness.py
   - pytest -v -s distributed/test_pynccl.py
   - pytest -v -s distributed/test_events.py
+  - pytest -v -s distributed/test_symm_mem_allreduce.py
   # TODO: create a dedicated test section for multi-GPU example tests
   # when we have multiple distributed example tests
   - pushd ../examples/offline_inference
diff --git a/tests/distributed/test_symm_mem_allreduce.py b/tests/distributed/test_symm_mem_allreduce.py
index 5a804a389123..83e1fe47aeec 100644
--- a/tests/distributed/test_symm_mem_allreduce.py
+++ b/tests/distributed/test_symm_mem_allreduce.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import queue
 import random
 import typing

@@ -10,26 +11,31 @@
 import torch.multiprocessing as mp

 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables

 torch.manual_seed(42)
 random.seed(44)

-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024


-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                      get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return

         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return

         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None

-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+    if val is not None:
+        pytest.skip(val)

-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()

+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)
diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py
index b2bf3bc3cc2e..177eaecdbd6c 100644
--- a/vllm/distributed/device_communicators/cuda_communicator.py
+++ b/vllm/distributed/device_communicators/cuda_communicator.py
@@ -24,18 +24,21 @@ def __init__(self,
                  unique_name: str = ""):
         super().__init__(cpu_group, device, device_group, unique_name)
         if "tp" not in unique_name:
-            # only tp uses custom allreduce
+            # custom allreduce or torch symm mem can be used only by tp
             use_custom_allreduce = False
+            use_torch_symm_mem = False
         else:
             from vllm.distributed.parallel_state import (
                 _ENABLE_CUSTOM_ALL_REDUCE)
             use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
+            use_torch_symm_mem = envs.VLLM_ALLREDUCE_USE_SYMM_MEM

         # ep does not use pynccl
         use_pynccl = "ep" not in unique_name

         self.use_pynccl = use_pynccl
         self.use_custom_allreduce = use_custom_allreduce
+        self.use_torch_symm_mem = use_torch_symm_mem

         # lazy import to avoid documentation build error
         from vllm.distributed.device_communicators.custom_all_reduce import (
@@ -57,7 +60,7 @@ def __init__(self,
         self.ca_comm: Optional[CustomAllreduce] = None
         self.qr_comm: Optional[QuickAllReduce] = None
         self.symm_mem_comm: Optional[SymmMemCommunicator] = None
-        if envs.VLLM_ALLREDUCE_USE_SYMM_MEM and current_platform.is_cuda():
+        if use_torch_symm_mem and current_platform.is_cuda():
             self.symm_mem_comm = SymmMemCommunicator(
                 group=self.cpu_group,
                 device=self.device,
diff --git a/vllm/envs.py b/vllm/envs.py
index f6eafe892ef2..5f6242f51b11 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -182,7 +182,7 @@
     VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
     VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
     VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
-    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = False
+    VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
     VLLM_TUNED_CONFIG_FOLDER: Optional[str] = None
     VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
@@ -1368,7 +1368,7 @@ def get_vllm_port() -> Optional[int]:

     # Whether to use pytorch symmetric memory for allreduce
     "VLLM_ALLREDUCE_USE_SYMM_MEM":
-    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "0"))),
+    lambda: bool(int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))),

     # Allows vllm to find tuned config under customized folder
     "VLLM_TUNED_CONFIG_FOLDER":
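
Note on the default flip in vllm/envs.py: with VLLM_ALLREDUCE_USE_SYMM_MEM now defaulting to "1", the torch symmetric-memory all-reduce path is initialized for the tensor-parallel communicator on CUDA unless the variable is explicitly set to "0". The snippet below is a minimal sketch of that gating, not vLLM's actual code; should_init_symm_mem is a hypothetical helper name used only to illustrate the condition the cuda_communicator.py and envs.py hunks implement.

import os


def should_init_symm_mem(unique_name: str, is_cuda: bool) -> bool:
    # Mirrors the envs.py hunk: the default changes from "0" to "1" (enabled).
    use_torch_symm_mem = bool(
        int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")))
    # Mirrors the cuda_communicator.py hunk: only the "tp" communicator on a
    # CUDA platform ever constructs a SymmMemCommunicator.
    return "tp" in unique_name and use_torch_symm_mem and is_cuda


# To opt out after this change, export VLLM_ALLREDUCE_USE_SYMM_MEM=0 before
# starting vLLM; non-tp communicators (e.g. "ep") are unaffected either way.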