@@ -611,13 +611,9 @@ class _MegatronTEGroupedMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )
         # initialize parallel state for submodules linear_fc1 and linear_fc2
         self.linear_fc1.parallel_state = self.parallel_state
@@ -630,13 +626,9 @@ class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )

         # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2