Commit 2df77b1

fixing test and comments
Signed-off-by: Kinjal Patel <kinjalpravin@nvidia.com>
1 parent: 153e376

3 files changed: +14 -17 lines


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 11 additions & 12 deletions
@@ -28,8 +28,6 @@
 from megatron.core.extensions import transformer_engine as megatron_te
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
-from megatron.core.parallel_state import get_data_parallel_group
-from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 from megatron.core.utils import get_tensor_model_parallel_group_if_none
@@ -63,23 +61,24 @@ def get_sequential_mlp_expert_names(name: str, module: torch.nn.Module):
             expert_name, local_expert_name = name.split(".local_experts.")
             # extract quantizer name by removing local_expert number from the name
             local_expert_name = ".".join(local_expert_name.split(".")[1:])
-            return expert_name, local_expert_name
-        return None, None
+            return f"{expert_name}.{local_expert_name}"
+        return None
 
     # gather amax values from SequentialMLP experts
     for name, module in model.named_modules():
-        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and local_expert_name:
-            amax_dict[local_expert_name] = amax_dict.get(local_expert_name, {})
-            amax_dict[local_expert_name][expert_name] = max(
-                amax_dict[local_expert_name].get(expert_name, 0), module.amax
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            stored_amax = amax_dict.get(expert_name)
+            amax_tensor = module.amax.detach().clone()
+            amax_dict[expert_name] = (
+                amax_tensor if stored_amax is None else torch.maximum(stored_amax, amax_tensor)
             )
 
     # sync amax values across experts in SequentialMLP
     for name, module in model.named_modules():
-        expert_name, local_expert_name = get_sequential_mlp_expert_names(name, module)
-        if expert_name and local_expert_name:
-            module.amax = amax_dict[local_expert_name][expert_name]
+        expert_name = get_sequential_mlp_expert_names(name, module)
+        if expert_name and module.amax is not None:
+            module.amax = amax_dict[expert_name].detach().clone().to(module.amax.device)
 
 
 CUSTOM_POST_CALIBRATION_PLUGINS.add(sync_amax_across_sequential_mlp)
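
Note on the reduction above: switching from the built-in max() over scalars to torch.maximum over tensors makes the merge elementwise, which per-channel amax requires (max() over multi-element tensors raises an ambiguity error). A minimal standalone sketch of the same pattern, with hypothetical quantizer names and values:

import torch

# Accumulate an elementwise maximum across experts sharing a quantizer name.
# Hypothetical example; in the plugin the key comes from
# get_sequential_mlp_expert_names and the values from module.amax.
amax_dict: dict[str, torch.Tensor] = {}
observed = [
    ("mlp.linear_fc1.input_quantizer", torch.tensor([0.9, 1.2])),
    ("mlp.linear_fc1.input_quantizer", torch.tensor([1.1, 0.8])),
]
for name, amax in observed:
    stored = amax_dict.get(name)
    amax_tensor = amax.detach().clone()
    # torch.maximum reduces per channel, unlike Python's max().
    amax_dict[name] = amax_tensor if stored is None else torch.maximum(stored, amax_tensor)

print(amax_dict["mlp.linear_fc1.input_quantizer"])  # tensor([1.1000, 1.2000])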

tests/_test_utils/torch_dist/plugins/megatron_common.py

Lines changed: 3 additions & 4 deletions
@@ -562,9 +562,8 @@ def compare_amax_sync_across_expert_parallel(model):
         if isinstance(module, mtq.nn.TensorQuantizer) and hasattr(module, "_amax"):
             # Check for both TEGrouped and sequential MoE patterns
             if "local_experts" in name or ("experts" in name and "linear_fc" in name):
-                expert_amax_values[name] = (
-                    module.amax.item() if hasattr(module.amax, "item") else module.amax
-                )
+                amax_val = module.amax.detach().clone() if hasattr(module.amax, "detach") else module.amax
+                expert_amax_values[name] = amax_val
 
     # Early return if no expert quantizers found
     assert expert_amax_values, "No expert quantizers found"
@@ -602,7 +601,7 @@ def compare_amax_sync_across_expert_parallel(model):
     # Check synchronization - fail fast on first inconsistency
     for quantizer_type, rank_values in expert_quantizers.items():
         if len(rank_values) > 1:  # Only check if we have multiple ranks
-            values = list(rank_values.values())
+            values = [v.detach().cpu() if hasattr(v, "detach") else v for v in rank_values.values()]
             max_diff = max(values) - min(values)
             if max_diff > 1e-6:  # Allow for small floating point differences
                 return False, quantizer_type, rank_values
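
The check above reduces each quantizer's per-rank amax values to a spread and compares it against a small tolerance. A sketch of that comparison in isolation, assuming scalar (0-dim) amax tensors and made-up rank values:

import torch

# One amax value per rank for a single quantizer; in-sync ranks agree.
rank_values = {0: torch.tensor(1.25), 1: torch.tensor(1.25), 2: torch.tensor(1.25)}

values = [v.detach().cpu() for v in rank_values.values()]
max_diff = max(values) - min(values)
# Allow for small floating point differences, mirroring the test's 1e-6 tolerance.
print(bool(max_diff <= 1e-6))  # True when amax is synchronized across ranks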

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 0 additions & 1 deletion
@@ -34,7 +34,6 @@
 from _test_utils.torch_quantization.quantize_common import (
     auto_quantize_helper,
     data_tensor_context_parallel_test_helper,
-    dp_cp_parallel_test_helper,
 )
 
 skip_if_no_megatron()
