@@ -14,6 +14,7 @@
 # limitations under the License.
 import copy
 import re
+from collections import defaultdict
 from warnings import warn

 import torch
@@ -41,6 +42,7 @@
 from megatron.core.parallel_state import (
     get_expert_model_parallel_group,
     get_expert_tensor_parallel_group,
+    get_expert_tensor_parallel_rank,
     initialize_model_parallel,
     is_pipeline_first_stage,
     is_pipeline_last_stage,
@@ -190,7 +192,7 @@ def squared_relu(x):
         pipeline_model_parallel_size=pipeline_model_parallel_size,
         expert_model_parallel_size=expert_model_parallel_size,
         expert_tensor_parallel_size=expert_tensor_parallel_size,
-        sequence_parallel=expert_model_parallel_size > 1,
+        sequence_parallel=False,
         moe_grouped_gemm=moe_grouped_gemm,
         num_layers=num_layers,
         num_layers_in_first_pipeline_stage=num_layers_in_first_pipeline_stage,
@@ -221,7 +223,12 @@ def squared_relu(x):
     else:
         assert HAS_TE, "Transformer Engine not installed"
         transformer_layer_spec = (
-            get_gpt_modelopt_spec(config, remap_te_layernorm=True)
+            get_gpt_modelopt_spec(
+                config,
+                remap_te_layernorm=True,
+                # TODO: uncomment this when TEGroupedMLP is enabled in Megatron-LM
+                # moe_grouped_gemm=moe_grouped_gemm
+            )
             if transformer_impl == "modelopt"
             else get_gpt_layer_with_transformer_engine_spec()
         )
@@ -565,8 +572,7 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
                 # Convert to scalar only if tensor has a single element
-                amax_val = module.amax.detach().clone().cpu()
-                expert_amax_values[name] = amax_val
+                expert_amax_values[name] = module.amax.detach().clone().cpu()

     # Early return if no expert quantizers found
     assert expert_amax_values, "No expert quantizers found"
@@ -577,19 +583,16 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
     torch.distributed.all_gather_object(all_amax_values, expert_amax_values)

     # Group quantizers by type (ignoring specific expert indices) and check sync
-    expert_quantizers = {}
+    expert_quantizers = defaultdict(dict)
     for rank_idx, rank_amax in enumerate(all_amax_values):
         for name, amax_val in rank_amax.items():
             # Create quantizer type key by normalizing the name
-            if "local_experts" in name:
-                # sequential MoE: replace expert index with wildcard
-                quantizer_type = re.sub(r"local_experts\.\d+", "local_experts.*", name)
-            else:
-                # TEGrouped MoE: use the name as-is since experts are grouped
-                quantizer_type = name
-
-            if quantizer_type not in expert_quantizers:
-                expert_quantizers[quantizer_type] = {}
+            quantizer_type = (
+                re.sub(r"local_experts\.\d+", "local_experts.*", name)
+                if "local_experts" in name
+                else name
+            )
+
             if (
                 quantizer_type in expert_quantizers
                 and rank_idx in expert_quantizers[quantizer_type]
@@ -608,21 +611,53 @@ def compare_amax_sync_across_expert_parallel(model, compare_across_experts=True)
                 )
             expert_quantizers[quantizer_type][rank_idx] = amax_val

-    # Check synchronization - fail fast on first inconsistency
+    rank_info = {
+        "global_rank": torch.distributed.get_rank(),
+        "etp_rank": get_expert_tensor_parallel_rank(),
+    }
+
+    all_rank_info = [None] * world_size
+    torch.distributed.all_gather_object(all_rank_info, rank_info)
+
+    # Group ranks by ETP rank for fc1 (ColumnParallel: same output channels should match)
+    etp_groups = defaultdict(list)
+    for info in all_rank_info:
+        etp_groups[info["etp_rank"] if info["etp_rank"] else 0].append(info["global_rank"])
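+    # e.g. with EP=2 and ETP=2 (global ranks laid out as in the fc1/fc2 comment below),
+    # this yields {0: [0, 2], 1: [1, 3]}, mapping ETP rank -> global ranks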
+
     for quantizer_type, rank_values in expert_quantizers.items():
-        if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
-            # Handle both scalar and tensor comparisons
-            first_val = values[0]
-            if isinstance(first_val, torch.Tensor):
-                # For tensors, check if all values are close to the first one
-                for val in values[1:]:
-                    if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
-                        return False, quantizer_type, rank_values
-            else:
-                # For scalars, use numeric comparison
-                max_diff = max(values) - min(values)
-                if max_diff > 1e-6:  # Allow for small floating point differences
-                    return False, quantizer_type, rank_values
+        # Determine which ranks should have the same amax.
+        #
+        # fc1 (ColumnParallel): X @ [A_1, A_2] (weights split along Cout),
+        # so amax should match across ranks with the same ETP rank.
+        # With EP=2 and ETP=2 there are 4 ranks: (EP1, ETP1): 0, (EP1, ETP2): 1,
+        # (EP2, ETP1): 2, (EP2, ETP2): 3, so for per-channel quantization we compare
+        # amax within the same-ETP-rank groups [0, 2] and [1, 3].
+        #
+        # fc2 (RowParallel): [X_1, X_2] @ [A_1; A_2] (weights split along Cin),
+        # so amax should be the same across all ranks.
+
+        rank_groups = (
+            list(etp_groups.values())
+            if "linear_fc1" in quantizer_type and rank_values[0].ndim > 0
+            else [list(range(world_size))]
+        )
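+        # Per-tensor amax (ndim == 0) and all fc2 quantizers fall into a single all-ranks
+        # group; only per-channel fc1 amax is compared within same-ETP-rank groups.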
+
+        # Check each group independently
+        for group in rank_groups:
+            group_values = [rank_values[r] for r in group if r in rank_values]
+            if len(group_values) > 1:
+                # All values in this group should be identical
+                first_val = group_values[0]
+                for val in group_values[1:]:
+                    if isinstance(first_val, torch.Tensor):
+                        if not torch.allclose(first_val, val, rtol=1e-6, atol=1e-6):
+                            group_rank_values = {
+                                r: rank_values[r] for r in group if r in rank_values
+                            }
+                            return False, f"{quantizer_type} (group {group})", group_rank_values
+                    elif abs(first_val - val) > 1e-6:
+                        group_rank_values = {r: rank_values[r] for r in group if r in rank_values}
+                        return False, f"{quantizer_type} (group {group})", group_rank_values

     return True, None, None