@@ -31,10 +31,10 @@
 from _test_utils.torch_quantization.quant_utils import get_model_size
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
-    tensor_parallel_test_helper,
-    data_parallel_test_helper,
     context_parallel_test_helper,
+    data_parallel_test_helper,
     data_tensor_context_parallel_test_helper,
+    tensor_parallel_test_helper,
 )
 from packaging.version import Version
 
@@ -43,8 +43,8 @@
 import megatron.core
 from megatron.core.parallel_state import (
     destroy_model_parallel,
-    get_data_parallel_group,
     get_context_parallel_group,
+    get_data_parallel_group,
     get_tensor_model_parallel_group,
 )
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
@@ -95,14 +95,13 @@ def test_convert_megatron_parallel_linear(distributed_setup_size_1):
     # Clean up since this is not a spawned process
     destroy_model_parallel()
 
+
 # 1. Tensor Parallel Test
 def _test_tensor_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, seed=SEED)
     model = MegatronModel(tp_size=size).cuda()
 
-    tensor_parallel_test_helper(
-        model, config, get_tensor_model_parallel_group()
-    )
+    tensor_parallel_test_helper(model, config, get_tensor_model_parallel_group())
 
 
 @pytest.mark.parametrize(
@@ -122,15 +121,14 @@ def test_tensor_parallel(need_2_gpus, config):
         size=2, job=partial(_test_tensor_parallel_helper, config), backend="nccl"
     )
 
+
 # 2. Data Parallel Test
 def _test_data_parallel_helper(config, rank, size):
     # TODO does this model automatically get copied to both DP ranks?
     initialize_for_megatron(seed=SEED)
     model = MegatronModel().cuda()
 
-    data_parallel_test_helper(
-        model, config, get_data_parallel_group()
-    )
+    data_parallel_test_helper(model, config, get_data_parallel_group())
 
 
 @pytest.mark.parametrize(
@@ -146,18 +144,16 @@ def _test_data_parallel_helper(config, rank, size):
     ],
 )
 def test_data_parallel(need_2_gpus, config):
-    spawn_multiprocess_job(
-        size=2, job=partial(_test_data_parallel_helper, config), backend="nccl"
-    )
+    spawn_multiprocess_job(size=2, job=partial(_test_data_parallel_helper, config), backend="nccl")
+
 
 # 3. Context Parallel Test
 def _test_context_parallel_helper(config, rank, size):
     initialize_for_megatron(context_parallel_size=size, seed=SEED)
     model = MegatronModel(cp_size=size).cuda()
 
-    context_parallel_test_helper(
-        model, config, get_context_parallel_group()
-    )
+    context_parallel_test_helper(model, config, get_context_parallel_group())
+
 
 @pytest.mark.parametrize(
     "config",
@@ -176,15 +172,21 @@ def test_context_parallel(need_2_gpus, config):
         size=2, job=partial(_test_context_parallel_helper, config), backend="nccl"
     )
 
+
 # 4. DP=2 + TP=2 + CP=2 Test (on 2*2*2=8 GPUs)
 def _test_data_tensor_context_parallel_helper(config, rank, size):
     initialize_for_megatron(tensor_model_parallel_size=2, context_parallel_size=2, seed=SEED)
     model = MegatronModel(tp_size=2, cp_size=2).cuda()
 
     data_tensor_context_parallel_test_helper(
-        model, config, get_data_parallel_group(), get_tensor_model_parallel_group(), get_context_parallel_group()
+        model,
+        config,
+        get_data_parallel_group(),
+        get_tensor_model_parallel_group(),
+        get_context_parallel_group(),
     )
 
+
 @pytest.mark.parametrize(
     "config",
     [
@@ -199,9 +201,10 @@ def _test_data_tensor_context_parallel_helper(config, rank, size):
 )
 def test_data_tensor_context_parallel(need_8_gpus, config):
     spawn_multiprocess_job(
-        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
+        size=8, job=partial(_test_data_tensor_context_parallel_helper, config), backend="nccl"
     )
 
+
 def _gpt_model_provider(tp_size: int, hidden_size=256, vocab_size=64, meta_device=False):
     """Build the model."""
 