
 from vllm.distributed import (broadcast_tensor_dict, get_pp_group,
                               tensor_model_parallel_all_gather,
-                              tensor_model_parallel_all_reduce)
+                              tensor_model_parallel_all_reduce,
+                              tensor_model_parallel_reduce_scatter)

 from ..utils import init_test_distributed_environment, multi_process_parallel

@@ -38,6 +39,33 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
     torch.testing.assert_close(t, expected)


+@ray.remote(num_gpus=1, max_calls=1)
+def reduce_scatter_test_worker(tp_size: int, pp_size: int, rank: int,
+                               distributed_init_port: str):
+    # it is important to delete the CUDA_VISIBLE_DEVICES environment variable
+    # so that each worker can see all the GPUs
+    # they will be able to set the device to the correct GPU
+    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
+    device = torch.device(f"cuda:{rank}")
+    torch.cuda.set_device(device)
+    init_test_distributed_environment(tp_size, pp_size, rank,
+                                      distributed_init_port)
+
+    num_elements = 8
+    all_tensors = [
+        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
+        (r + 1) for r in range(tp_size)
+    ]
+
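+    # reduce-scatter is equivalent to an all-reduce followed by a scatter of
+    # equal contiguous partitions, so rank i expects the i-th slice of the
+    # element-wise sum of all ranks' tensors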
+    index = rank % tp_size
+    partition_size = num_elements // tp_size
+    all_reduce = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
+    expected = all_reduce[index * partition_size:(index + 1) * partition_size]
+    t = all_tensors[index]
+    t = tensor_model_parallel_reduce_scatter(t)
+    torch.testing.assert_close(t, expected)
+
+
 @ray.remote(num_gpus=1, max_calls=1)
 def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
                            distributed_init_port: str):
@@ -178,6 +206,17 @@ def test_multi_process_tensor_parallel(tp_size, test_target):
     multi_process_parallel(tp_size, 1, test_target)


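+# sequence-parallel variant that also covers reduce_scatter_test_worker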
+@pytest.mark.skipif(torch.cuda.device_count() < 2,
+                    reason="Need at least 2 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel(
+        tp_size, test_target):
+    multi_process_parallel(tp_size, 1, test_target)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="Need at least 2 GPUs to run the test.")
 @pytest.mark.parametrize("pp_size", [2])
@@ -199,3 +238,17 @@ def test_multi_process_pipeline_parallel(pp_size, test_target):
 def test_multi_process_tensor_parallel_pipeline_parallel(
         tp_size, pp_size, test_target):
     multi_process_parallel(tp_size, pp_size, test_target)
+
+
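+# TP + PP variant that also covers reduce_scatter_test_worker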
+@pytest.mark.skipif(torch.cuda.device_count() < 4,
+                    reason="Need at least 4 GPUs to run the test.")
+@pytest.mark.parametrize("tp_size", [2])
+@pytest.mark.parametrize("pp_size", [2])
+@pytest.mark.parametrize("test_target", [
+    send_recv_test_worker, send_recv_tensor_dict_test_worker,
+    all_reduce_test_worker, all_gather_test_worker, reduce_scatter_test_worker,
+    broadcast_tensor_dict_test_worker
+])
+def test_multi_process_tensor_parallel_sequence_parallel_pipeline_parallel(
+        tp_size, pp_size, test_target):
+    multi_process_parallel(tp_size, pp_size, test_target)