3232from _test_utils .torch_quantization .quantize_common import (
3333 auto_quantize_helper ,
3434 tensor_parallel_test_helper ,
35+ data_parallel_test_helper ,
36+ context_parallel_test_helper ,
37+ data_tensor_context_parallel_test_helper ,
3538)
3639from packaging .version import Version
3740
4144from megatron .core .parallel_state import (
4245 destroy_model_parallel ,
4346 get_data_parallel_group ,
47+ get_context_parallel_group ,
4448 get_tensor_model_parallel_group ,
4549)
4650from megatron .core .tensor_parallel .layers import ColumnParallelLinear , RowParallelLinear
@@ -91,13 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
9195 # Clean up since this is not a spawned process
9296 destroy_model_parallel ()
9397
# 1. Tensor Parallel Test
def _test_tensor_parallel_helper(config, rank, size):
    """Spawned worker: build a TP-sharded model and run the TP quantization check.

    ``rank`` is supplied by spawn_multiprocess_job but is not needed here;
    ``size`` is the tensor-parallel world size.
    """
    initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
    tp_model = MegatronModel(tp_size=size).cuda()
    tensor_parallel_test_helper(tp_model, config, get_tensor_model_parallel_group())
102106
103107
@@ -118,6 +122,85 @@ def test_tensor_parallel(need_2_gpus, config):
118122 size = 2 , job = partial (_test_tensor_parallel_helper , config ), backend = "nccl"
119123 )
120124
# 2. Data Parallel Test
def _test_data_parallel_helper(config, rank, size):
    """Spawned worker: build a plain (data-parallel) model and run the DP quantization check."""
    # TODO(review): does this model automatically get replicated onto both DP ranks? confirm
    initialize_for_megatron(seed=SEED)
    dp_model = MegatronModel().cuda()
    data_parallel_test_helper(dp_model, config, get_data_parallel_group())
134+
135+
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_parallel(need_2_gpus, config):
    """Spawn 2 NCCL workers and run the data-parallel quantization test for each config."""
    worker = partial(_test_data_parallel_helper, config)
    spawn_multiprocess_job(size=2, job=worker, backend="nccl")
152+
# 3. Context Parallel Test
def _test_context_parallel_helper(config, rank, size):
    """Spawned worker: build a CP-sharded model and run the CP quantization check.

    ``size`` is the context-parallel world size; ``rank`` is unused.
    """
    initialize_for_megatron(context_parallel_size=size, seed=SEED)
    cp_model = MegatronModel(cp_size=size).cuda()
    context_parallel_test_helper(cp_model, config, get_context_parallel_group())
161+
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_context_parallel(need_2_gpus, config):
    """Spawn 2 NCCL workers and run the context-parallel quantization test for each config."""
    worker = partial(_test_context_parallel_helper, config)
    spawn_multiprocess_job(size=2, job=worker, backend="nccl")
178+
# 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
def _test_data_tensor_context_parallel_helper(config, rank, size):
    """Spawned worker: build a TP=2/CP=2 model (DP follows from world size) and run the combined check."""
    initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
    combined_model = MegatronModel(tp_size=2, cp_size=2).cuda()
    data_tensor_context_parallel_test_helper(
        combined_model,
        config,
        get_data_parallel_group(),
        get_tensor_model_parallel_group(),
        get_context_parallel_group(),
    )
187+
@pytest.mark.parametrize(
    "config",
    [
        mtq.INT8_DEFAULT_CFG,
        mtq.FP8_DEFAULT_CFG,
        mtq.W4A8_AWQ_BETA_CFG,
        mtq.INT8_SMOOTHQUANT_CFG,
        mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG,
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_DEFAULT_CFG,
    ],
)
def test_data_tensor_context_parallel(need_8_gpus, config):
    """Spawn 8 NCCL workers (DP=2 x TP=2 x CP=2) and run the combined quantization test."""
    worker = partial(_test_data_tensor_context_parallel_helper, config)
    spawn_multiprocess_job(size=8, job=worker, backend="nccl")
121204
122205def _gpt_model_provider (tp_size : int , hidden_size = 256 , vocab_size = 64 , meta_device = False ):
123206 """Build the model."""
0 commit comments