@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import queue
 import random
 import typing
 
@@ -10,26 +11,31 @@
 import torch.multiprocessing as mp
 
 import vllm.envs as envs
+from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.distributed import cleanup_dist_env_and_memory
 from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
 from vllm.distributed.device_communicators.cuda_communicator import (
     CudaCommunicator)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group,
+from vllm.distributed.parallel_state import (get_tp_group,
                                              init_distributed_environment,
                                              initialize_model_parallel)
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
 from vllm.platforms import current_platform
 from vllm.utils import update_environment_variables
 
 torch.manual_seed(42)
 random.seed(44)
 
-test_size_elements = 4 * 1024 * 1024
+test_size_elements = 1024 * 1024
 
 
-def symm_mem_allreduce_worker(local_rank: int, world_size: int):
+def symm_mem_allreduce_worker(local_rank: int, world_size: int, q: mp.Queue):
     monkeypatch = pytest.MonkeyPatch()
-    with monkeypatch.context() as m:
+    config = VllmConfig(parallel_config=ParallelConfig(
+        tensor_parallel_size=world_size))
+
+    with monkeypatch.context() as m, set_current_vllm_config(config):
         m.delenv("CUDA_VISIBLE_DEVICES", raising=False)
         dtype = torch.bfloat16
         device = torch.device(f"cuda:{local_rank}")
@@ -51,22 +57,26 @@ def symm_mem_allreduce_worker(local_rank: int, world_size: int):
                                         get_tp_group().device_communicator)
         symm_mem_comm = cuda_communicator.symm_mem_comm
         if symm_mem_comm is None or symm_mem_comm.disabled:
-            pytest.skip("SymmMemCommunicator is not available or disabled.")
+            # can't use skip under multiprocessing
+            q.put("SymmMemCommunicator is not available or disabled.")
+            return
 
         inp_direct_symm_mem = torch.randint(1,
                                             23, (test_size_elements, ),
                                             dtype=dtype,
                                             device=device)
         if not symm_mem_comm.should_use_symm_mem(inp_direct_symm_mem):
-            pytest.skip(
+            # can't use skip under multiprocessing
+            q.put(
                 "SymmMemCommunicator isn't used for this world and input size."
             )
+            return
 
         original_inp_direct_symm_mem = inp_direct_symm_mem.clone()
         out_direct_symm_mem = symm_mem_comm.all_reduce(inp_direct_symm_mem)
         assert out_direct_symm_mem is not None
 
-        group = get_tensor_model_parallel_group().device_group
+        group = get_tp_group().device_group
         dist.all_reduce(original_inp_direct_symm_mem, group=group)
         torch.testing.assert_close(out_direct_symm_mem,
                                    original_inp_direct_symm_mem,
@@ -100,9 +110,34 @@ def test_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, tp_size,
     world_size = tp_size * pipeline_parallel_size
     if world_size > torch.cuda.device_count():
         pytest.skip("Not enough GPUs to run the test.")
+    q = mp.get_context('spawn').Queue()
+    mp.spawn(symm_mem_allreduce_worker,
+             args=(world_size, q),
+             nprocs=world_size)
+    try:
+        val = q.get(timeout=1)
+    except queue.Empty:
+        val = None
+    finally:
+        cleanup_dist_env_and_memory()
+        if val is not None:
+            pytest.skip(val)
 
-    # Enable SymmMemCommunicator
-    monkeypatch.setenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1")
 
-    mp.spawn(symm_mem_allreduce_worker, args=(world_size, ), nprocs=world_size)
-    cleanup_dist_env_and_memory()
+@pytest.mark.skipif(
+    not current_platform.is_cuda(),
+    reason="SymmMemAllreduce is only available for CUDA platforms.")
+@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
+                    reason="Only test on CUDA")
+def test_dp_with_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch):
+    world_size = 4
+    if world_size > torch.cuda.device_count():
+        pytest.skip("Not enough GPUs to run the test.")
+    # Verify that the DataParallel runs without error
+    engine_args = EngineArgs(model="distilbert/distilgpt2",
+                             enforce_eager=True,
+                             enable_prefix_caching=True,
+                             data_parallel_size=2,
+                             tensor_parallel_size=2,
+                             data_parallel_backend="mp")
+    LLMEngine.from_engine_args(engine_args)
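The queue-based handling added above works around the fact that pytest.skip() raised inside a torch.multiprocessing.spawn worker only reaches the parent as a child-process failure, not as a skip. A minimal, GPU-free sketch of that pattern, under the assumption that it captures the intent of the change; worker, feature_available and test_feature_or_skip are illustrative names, not part of this diff:

import queue

import pytest
import torch.multiprocessing as mp


def worker(local_rank: int, world_size: int, q: mp.Queue):
    # Stand-in for a capability check such as "is symmetric memory usable".
    feature_available = False
    if not feature_available:
        # pytest.skip() here would only crash this child process,
        # so hand the skip reason back to the parent instead.
        q.put("feature not available in worker")
        return
    # ... the real collective test body would run here ...


def test_feature_or_skip():
    q = mp.get_context("spawn").Queue()
    mp.spawn(worker, args=(2, q), nprocs=2)
    try:
        reason = q.get(timeout=1)
    except queue.Empty:
        reason = None  # no message means no worker asked to skip
    if reason is not None:
        pytest.skip(reason)

The timeout on q.get() keeps the parent test from blocking when no worker reported a skip reason, which mirrors how the updated test_symm_mem_allreduce consumes its queue.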