
Commit e2858f9

updated parallel state for experts
1 parent 17df5ca commit e2858f9

File tree

2 files changed: +62 -29 lines changed


modelopt/torch/quantization/plugins/megatron.py

Lines changed: 49 additions & 21 deletions
@@ -22,6 +22,7 @@
 import megatron.core.parallel_state as mcore_parallel
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
+import megatron.core.transformer.moe.experts as megatron_moe
 import torch
 import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear
 from megatron.core.extensions import transformer_engine as megatron_te
@@ -38,7 +39,7 @@
 from modelopt.torch.utils.distributed import ParallelState
 
 from ..nn import QuantModuleRegistry, TensorQuantizer
-from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear
+from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear
 
@@ -518,29 +519,18 @@ def forward(self, input, *args, **kwargs):
 
 
 # Register the public te.pytorch.GroupedLinear class
-@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"})
+@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"})
 class _QuantTEGroupedLinear(_MegatronParallelLinear):
     def _setup(self):
-        if not hasattr(self, "parallel_state") or self.parallel_state is None:
-            data_parallel_group = None
-            try:
-                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
-            except AssertionError:
-                data_parallel_group = get_data_parallel_group()
-
-            self.parallel_state = ParallelState(
-                data_parallel_group,
-                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
-                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
-            )
-        self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input)
-        self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight)
-        self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output)
-        self.output_quantizer.disable()
-
+        # GroupedMLP stores the weights as weight0, weight1, etc. To run setup in order to
+        # initialize the quantizer states, self.weight is used to extract shape, dtype etc. Assigning
+        # self.weight0 to self.weight to run the quantizer states initialization.
+        self.weight = self.weight0
         # Memorize the original weight.dtype for modelopt_post_restore given that
         # the dtype can change later.
-        self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype
+        super()._setup()
+        # Revert the weight to None after setup.
+        self.weight = None
 
     @property
     def functionals_to_replace(self):
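
The `_setup` override above relies on TE's GroupedLinear exposing one parameter per expert (weight0, weight1, ...) rather than a single `weight` attribute. A minimal, self-contained sketch of the aliasing trick, using a toy module rather than modelopt's actual classes (all names below are illustrative only):

import torch
import torch.nn as nn

class GroupedLinearLike(nn.Module):
    """Toy stand-in for TE GroupedLinear: per-expert parameters named weight0..weightN."""

    def __init__(self, num_experts: int, in_features: int, out_features: int):
        super().__init__()
        for i in range(num_experts):
            self.register_parameter(f"weight{i}", nn.Parameter(torch.empty(out_features, in_features)))

    def generic_setup(self):
        # Generic setup code (like the parent _setup) only knows about self.weight.
        print(self.weight.shape, self.weight.dtype)

    def setup(self):
        # Temporarily alias weight0 so shape/dtype can be read, then revert, as in the commit.
        self.weight = self.weight0
        self.generic_setup()
        self.weight = None

GroupedLinearLike(num_experts=2, in_features=4, out_features=8).setup()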
@@ -577,7 +567,7 @@ def modelopt_post_restore(self, prefix: str = ""):
         # self.weight0 to self.weight to run the quantizer states initialization.
         self.weight = self.weight0
         super().modelopt_post_restore(prefix=prefix)
-        # Revert the weight to None after post_restore to avoid the weight being None during forward pass.
+        # Revert the weight to None after post_restore.
         self.weight = None
 
     def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
@@ -611,3 +601,41 @@ class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumn
     )
 class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear):
     _is_row_parallel = True
+
+
+# Register the public megatron_moe.TEGroupedMLP class
+@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"})
+class _QuantTEGroupedMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            data_parallel_group = None
+            try:
+                data_parallel_group = get_data_parallel_group(with_context_parallel=True)
+            except AssertionError:
+                logger.warning(
+                    "Context parallel group is not initialized, using data parallel group"
+                )
+                data_parallel_group = get_data_parallel_group()
+
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
+
+
+# Register the public megatron_moe.SequentialMLP class
+@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"})
+class _QuantSequentialMLP(_MegatronMLP):
+    def _setup(self):
+        if not hasattr(self, "parallel_state") or self.parallel_state is None:
+            try:
+                data_parallel_group = mcore_parallel.get_expert_data_parallel_group()
+            except AssertionError:
+                data_parallel_group = None
+
+            self.parallel_state = ParallelState(
+                data_parallel_group,
+                tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(),
+                expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(),
+            )
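
Both new registrations attach an expert-aware ParallelState to Megatron's MoE containers: TEGroupedMLP prefers the context-parallel-aware data-parallel group and falls back to the plain one, while SequentialMLP uses the expert data-parallel group (or None if it is not initialized). A hedged usage sketch of what this enables when quantizing an MoE model; `initialize_for_megatron` and `_gpt_model_provider` mirror the helpers in the test file changed below, their availability outside that test and the abridged argument lists are assumptions, and the calibration loop is omitted:

import modelopt.torch.quantization as mtq

# Assumes a multi-GPU Megatron setup; helper names mirror the test file in this commit.
initialize_for_megatron(
    tensor_model_parallel_size=1,
    pipeline_model_parallel_size=1,
    expert_model_parallel_size=2,
    expert_tensor_parallel_size=2,
)
model = _gpt_model_provider(tp_size=1, ep_size=2, etp_size=2, hidden_size=256)  # MoE args abridged

# With TEGroupedMLP/SequentialMLP in QuantModuleRegistry, quantization converts the
# expert containers and attaches a parallel_state built from the expert parallel groups,
# so calibration statistics can be synchronized across expert ranks.
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=None)
for name, module in model.named_modules():
    if getattr(module, "parallel_state", None) is not None:
        print(name, module.parallel_state)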

tests/gpu/torch/quantization/plugins/test_megatron.py

Lines changed: 13 additions & 8 deletions
@@ -516,17 +516,19 @@ def test_fp8_real_quantize():
         mtq.NVFP4_DEFAULT_CFG,
     ],
 )
-def test_moe_sharded_state_dict(need_8_gpus, tmp_path, config):
+@pytest.mark.parametrize("moe_grouped_gemm", [False, True])
+def test_moe_sharded_state_dict(tmp_path, config, moe_grouped_gemm):
     size = torch.cuda.device_count()
-    # TODO: Meta device doesn't work with TE
     # TODO: Add support for compress=True for TEGroupedMLP
+    if size < 4:
+        pytest.skip("Requires at least 4 GPUs for expert parallel test")
     moe_config = {
-        "tp_size": 2,
+        "tp_size": 1,
         "ep_size": 2,
         "etp_size": 2,
         "num_moe_experts": 4,
-        "moe_grouped_gemm": True,
-        "use_te": True,
+        "moe_grouped_gemm": moe_grouped_gemm,
+        "use_te": moe_grouped_gemm,
     }
     spawn_multiprocess_job(
         size=size,
@@ -627,10 +629,12 @@ def test_te_grouped_vs_sequential_quantize():
     )
 
 
-def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, config, rank, size):
+def _test_expert_model_parallel_amax_sync(
+    tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
+):
     """Test expert parallel synchronization with different configurations."""
     initialize_for_megatron(
-        tensor_model_parallel_size=1,
+        tensor_model_parallel_size=tp_size,
         pipeline_model_parallel_size=1,
         expert_model_parallel_size=ep_size,
         expert_tensor_parallel_size=etp_size,
@@ -639,7 +643,7 @@ def _test_expert_model_parallel_amax_sync(ep_size, etp_size, moe_grouped_gemm, c
 
     # Create model with expert parallelism
     model = _gpt_model_provider(
-        tp_size=1,
+        tp_size=tp_size,
         ep_size=ep_size,
         etp_size=etp_size,
         hidden_size=256,
@@ -700,6 +704,7 @@ def test_expert_parallel_sync(ep_size, etp_size, moe_grouped_gemm):
         size=size,
         job=partial(
             _test_expert_model_parallel_amax_sync,
+            1,
             ep_size,
             etp_size,
             moe_grouped_gemm,
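
The bare `1` prepended to the partial call is the value for the new leading `tp_size` parameter: functools.partial binds positionally, so the pre-bound arguments must line up with the updated signature before the spawner supplies rank and size. A small self-contained illustration of that binding (not the test itself; the argument values are placeholders):

from functools import partial

def sync_job(tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size):
    return (tp_size, ep_size, etp_size, moe_grouped_gemm, rank, size)

# Mirrors job=partial(_test_expert_model_parallel_amax_sync, 1, ep_size, etp_size, ...)
job = partial(sync_job, 1, 2, 2, True, {"cfg": "NVFP4"})

# spawn_multiprocess_job supplies the remaining (rank, size) arguments per worker.
print(job(0, 8))  # -> (1, 2, 2, True, 0, 8)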
