 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+import queue
 import random
 import typing

 import pytest
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp

 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables

 torch.manual_seed(42)
 random.seed(44)

-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024


-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
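+    # Make the TP size visible through the global vLLM config so the
+    # distributed setup inside the with-block below can pick it up.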
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                         get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return

         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return

         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None

-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
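+    # pytest.skip cannot propagate out of spawned worker processes, so the
+    # workers report skip reasons back through this queue instead.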
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+    if val is not None:
+        pytest.skip(val)

-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")

-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
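+    # Bringing the engine up exercises the full DP=2 x TP=2 startup path.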
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)