|
22 | 22 | import megatron.core.parallel_state as mcore_parallel |
23 | 23 | import megatron.core.tensor_parallel.layers as megatron_parallel |
24 | 24 | import megatron.core.transformer.mlp as megatron_mlp |
| 25 | +import megatron.core.transformer.moe.experts as megatron_moe |
25 | 26 | import torch |
26 | 27 | import transformer_engine.pytorch.module.grouped_linear as te_grouped_linear |
27 | 28 | from megatron.core.extensions import transformer_engine as megatron_te |
|
38 | 39 | from modelopt.torch.utils.distributed import ParallelState |
39 | 40 |
|
40 | 41 | from ..nn import QuantModuleRegistry, TensorQuantizer |
41 | | -from ..nn.modules.quant_linear import RealQuantLinear, _QuantLinear |
| 42 | +from ..nn.modules.quant_linear import RealQuantLinear |
42 | 43 | from ..qtensor import QTensorWrapper |
43 | 44 | from .custom import CUSTOM_MODEL_PLUGINS, CUSTOM_POST_CALIBRATION_PLUGINS, _ParallelLinear |
44 | 45 |
|
@@ -518,29 +519,18 @@ def forward(self, input, *args, **kwargs): |
518 | 519 |
|
519 | 520 |
|
520 | 521 | # Register the public te.pytorch.GroupedLinear class |
521 | | -@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear_public"}) |
| 522 | +@QuantModuleRegistry.register({te_grouped_linear.GroupedLinear: "te_GroupedLinear"}) |
522 | 523 | class _QuantTEGroupedLinear(_MegatronParallelLinear): |
523 | 524 | def _setup(self): |
524 | | - if not hasattr(self, "parallel_state") or self.parallel_state is None: |
525 | | - data_parallel_group = None |
526 | | - try: |
527 | | - data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
528 | | - except AssertionError: |
529 | | - data_parallel_group = get_data_parallel_group() |
530 | | - |
531 | | - self.parallel_state = ParallelState( |
532 | | - data_parallel_group, |
533 | | - tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), |
534 | | - expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), |
535 | | - ) |
536 | | - self.input_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_input) |
537 | | - self.weight_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_weight) |
538 | | - self.output_quantizer = TensorQuantizer(_QuantLinear.default_quant_desc_output) |
539 | | - self.output_quantizer.disable() |
540 | | - |
 | 525 | + # GroupedLinear stores the expert weights as weight0, weight1, etc. The base
 | 526 | + # _setup() reads self.weight to infer shape, dtype, etc. when initializing the
 | 527 | + # quantizer states, so temporarily alias self.weight0 as self.weight.
| 528 | + self.weight = self.weight0 |
541 | 529 | # Memorize the original weight.dtype for modelopt_post_restore given that |
542 | 530 | # the dtype can change later. |
543 | | - self.original_weight_dtype = None if self.weight0 is None else self.weight0.dtype |
| 531 | + super()._setup() |
| 532 | + # Revert the weight to None after setup. |
| 533 | + self.weight = None |
544 | 534 |
|
545 | 535 | @property |
546 | 536 | def functionals_to_replace(self): |
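A minimal sketch of the weight0-aliasing pattern used in _setup() above (hypothetical class and attribute names; only the temporary self.weight alias mirrors the real change, and the base class here is a stand-in, not modelopt's _MegatronParallelLinear):

import torch

class _BaseQuantSetup:
    def _setup(self):
        # The base setup only inspects self.weight for shape/dtype when sizing
        # its quantizer state; it stores the metadata, not the tensor itself.
        self.weight_shape = tuple(self.weight.shape)
        self.weight_dtype = self.weight.dtype

class _GroupedLinearLike(_BaseQuantSetup):
    def __init__(self, expert_weights):
        # Grouped modules register one parameter per expert: weight0, weight1, ...
        for i, w in enumerate(expert_weights):
            setattr(self, f"weight{i}", w)
        self.weight = None  # there is no single fused weight

    def _setup(self):
        self.weight = self.weight0  # temporarily alias expert 0's weight
        super()._setup()            # base setup reads shape/dtype from self.weight
        self.weight = None          # revert so forward() keeps using weight0..weightN

# All experts share shape/dtype, so expert 0 is a safe proxy for initialization.
mlp = _GroupedLinearLike([torch.zeros(8, 4), torch.zeros(8, 4)])
mlp._setup()
assert mlp.weight is None and mlp.weight_shape == (8, 4)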
@@ -577,7 +567,7 @@ def modelopt_post_restore(self, prefix: str = ""): |
577 | 567 | # self.weight0 to self.weight to run the quantizer states initialization. |
578 | 568 | self.weight = self.weight0 |
579 | 569 | super().modelopt_post_restore(prefix=prefix) |
580 | | - # Revert the weight to None after post_restore to avoid the weight being None during forward pass. |
| 570 | + # Revert the weight to None after post_restore. |
581 | 571 | self.weight = None |
582 | 572 |
|
583 | 573 | def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs): |
@@ -611,3 +601,41 @@ class _QuantTEGroupedColumnParallelLinear(_QuantTEGroupedLinear, _MegatronColumn |
611 | 601 | ) |
612 | 602 | class _QuantTEGroupedRowParallelLinear(_QuantTEGroupedLinear, _MegatronRowParallelLinear): |
613 | 603 | _is_row_parallel = True |
| 604 | + |
| 605 | + |
| 606 | +# Register the public megatron_moe.TEGroupedMLP class |
| 607 | +@QuantModuleRegistry.register({megatron_moe.TEGroupedMLP: "megatron_moe_TEGroupedMLP"}) |
| 608 | +class _QuantTEGroupedMLP(_MegatronMLP): |
| 609 | + def _setup(self): |
| 610 | + if not hasattr(self, "parallel_state") or self.parallel_state is None: |
| 611 | + data_parallel_group = None |
| 612 | + try: |
| 613 | + data_parallel_group = get_data_parallel_group(with_context_parallel=True) |
| 614 | + except AssertionError: |
| 615 | + logger.warning( |
 | 616 | + "Context parallel is not initialized; falling back to the plain data parallel group"
| 617 | + ) |
| 618 | + data_parallel_group = get_data_parallel_group() |
| 619 | + |
| 620 | + self.parallel_state = ParallelState( |
| 621 | + data_parallel_group, |
| 622 | + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), |
| 623 | + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), |
| 624 | + ) |
| 625 | + |
| 626 | + |
| 627 | +# Register the public megatron_moe.SequentialMLP class |
| 628 | +@QuantModuleRegistry.register({megatron_moe.SequentialMLP: "megatron_moe_SequentialMLP"}) |
| 629 | +class _QuantSequentialMLP(_MegatronMLP): |
| 630 | + def _setup(self): |
| 631 | + if not hasattr(self, "parallel_state") or self.parallel_state is None: |
| 632 | + try: |
| 633 | + data_parallel_group = mcore_parallel.get_expert_data_parallel_group() |
| 634 | + except AssertionError: |
| 635 | + data_parallel_group = None |
| 636 | + |
| 637 | + self.parallel_state = ParallelState( |
| 638 | + data_parallel_group, |
| 639 | + tensor_parallel_group=mcore_parallel.get_expert_tensor_parallel_group(), |
| 640 | + expert_model_parallel_group=mcore_parallel.get_expert_model_parallel_group(), |
| 641 | + ) |
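The two registrations above follow the same decorator pattern as the te_GroupedLinear one: once a Megatron MoE module class is in the registry, conversion swaps in the quantized subclass and its _setup() attaches the expert-aware ParallelState. A rough sketch of how such a decorator-based registry could work (hypothetical simplified names; not modelopt's actual QuantModuleRegistry or conversion code):

class _Registry:
    def __init__(self):
        self._map = {}

    def register(self, mapping):
        # mapping: {original_module_class: registry_name}, mirroring the decorators above
        def decorator(quant_cls):
            for original_cls in mapping:
                self._map[original_cls] = quant_cls
            return quant_cls
        return decorator

    def convert(self, module):
        quant_cls = self._map.get(type(module))
        if quant_cls is not None:
            module.__class__ = quant_cls  # in-place class swap keeps existing parameters
            module._setup()               # e.g. attach ParallelState / quantizers
        return module

registry = _Registry()

class SequentialMLPStub:  # stand-in for megatron_moe.SequentialMLP
    pass

@registry.register({SequentialMLPStub: "megatron_moe_SequentialMLP"})
class _QuantSequentialMLPStub(SequentialMLPStub):
    def _setup(self):
        self.parallel_state = "expert data/tensor/model parallel groups"

module = registry.convert(SequentialMLPStub())
assert type(module) is _QuantSequentialMLPStub and module.parallel_state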