
Commit 4919b08

Code cleanup
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent 5bc99e0 commit 4919b08

File tree

2 files changed: +10 -18 lines changed


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 6 additions & 14 deletions
@@ -611,13 +611,9 @@ class _MegatronTEGroupedMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )
         # initialize parallel state for submodules linear_fc1 and linear_fc2
         self.linear_fc1.parallel_state = self.parallel_state
@@ -630,13 +626,9 @@ class _MegatronSequentialMLP(_MegatronMLP):
     def _setup(self):
         if not hasattr(self, "parallel_state") or self.parallel_state is None:
             self.parallel_state = ParallelState(
-                mcore_parallel.get_expert_data_parallel_group(check_initialized=False),
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(
-                    check_initialized=False
-                ),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(
-                    check_initialized=False
-                ),
+                mcore_parallel.get_expert_data_parallel_group(),
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
             )

         # Initialize parallel state for submodules local_experts.*.linear_fc1 and local_experts.*.linear_fc2
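Aside (not part of the commit): a minimal sketch of the simplified ParallelState construction shown above. It assumes mcore_parallel is megatron.core.parallel_state and that the expert parallel groups are already initialized (e.g. via initialize_model_parallel) before _setup() runs, which is the premise behind dropping check_initialized=False; the helper name build_expert_parallel_state is hypothetical.

# Sketch only: mirrors the simplified call pattern from the diff above.
from megatron.core import parallel_state as mcore_parallel

def build_expert_parallel_state(parallel_state_cls):
    # Assumes the expert data/tensor/model parallel groups are already set up;
    # without check_initialized=False the getters are expected to fail loudly
    # if model parallelism has not been initialized.
    return parallel_state_cls(
        mcore_parallel.get_expert_data_parallel_group(),
        tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
        expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
    )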

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 4 additions & 4 deletions
@@ -517,16 +517,16 @@ def copy_weights_from_grouped_to_non_grouped(te_grouped_moe_model, sequential_mo
     weight_mapping = {}
     sequential_key_template = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight"
     for key, value in te_grouped_state.items():
-        if "experts.linear_fc" in key and "weight" in key:
+        if "experts.linear_fc" in key and any(param in key for param in ("weight", "bias")):
             # Extract expert index from grouped weight name
             # Format: decoder.layers.X.mlp.experts.linear_fcY.weightZ
             parts = key.split(".")
             layer_idx = parts[2]  # X
             fc_idx = parts[5]  # Y (linear_fc1 or linear_fc2)
-            weight_idx = parts[6]  # Z (weight0, weight1, etc.)
-
+            param_idx = parts[6]  # weight0 / bias0 / etc.
+            match = re.search(r"\d+", param_idx)
+            expert_idx = match.group(0) if match else "0"  # Z for expert index
             # Map to sequential format: decoder.layers.X.mlp.experts.local_experts.Y.linear_fcZ.weight
-            expert_idx = weight_idx.replace("weight", "")
             sequential_key = sequential_key_template.format(layer_idx, expert_idx, fc_idx[-1])
             weight_mapping[sequential_key] = value
         elif isinstance(value, torch.Tensor):
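For reference, a standalone sketch (an illustration, not the repo's helper) of the key translation the updated test utility performs, now covering bias parameters as well: a grouped parameter name such as decoder.layers.X.mlp.experts.linear_fc1.weight3 (or bias3) is mapped onto the sequential local_experts layout. The function name grouped_key_to_sequential is made up for this example.

# Sketch: extract layer index, fc index, and expert index from a grouped key
# and rebuild the sequential-format key, matching the logic in the diff above.
import re

SEQUENTIAL_KEY_TEMPLATE = "decoder.layers.{}.mlp.experts.local_experts.{}.linear_fc{}.weight"

def grouped_key_to_sequential(key: str) -> str:
    parts = key.split(".")
    layer_idx = parts[2]                 # X in decoder.layers.X
    fc_idx = parts[5][-1]                # 1 or 2 from linear_fc1 / linear_fc2
    match = re.search(r"\d+", parts[6])  # expert index from weight3 / bias3
    expert_idx = match.group(0) if match else "0"
    return SEQUENTIAL_KEY_TEMPLATE.format(layer_idx, expert_idx, fc_idx)

# Usage: weight3 on grouped linear_fc1 belongs to local expert 3.
assert (
    grouped_key_to_sequential("decoder.layers.0.mlp.experts.linear_fc1.weight3")
    == "decoder.layers.0.mlp.experts.local_experts.3.linear_fc1.weight"
)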
