Commit cf3a107

Mamba MoE export fixes
Signed-off-by: Jennifer Chen <jennifchen@nvidia.com>
1 parent 484b95a commit cf3a107

5 files changed: +127, -19 lines

modelopt/torch/export/plugins/mcore_nemotron.py

Lines changed: 28 additions & 0 deletions
@@ -17,8 +17,10 @@
 """Custom mapping from Nemotron Hugging Face models to Megatron Core models."""
 
 from .mcore_custom import (
+    COL_ETP,
     COL_TP,
     REPLICATE,
+    ROW_ETP,
     ROW_TP,
     CustomModuleMapping,
     NameRemapping,
@@ -63,6 +65,22 @@
     "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm.", REPLICATE),
     "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj.", COL_TP),
     "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj.", ROW_TP),
+    # MoE
+    "router": NameRemapping(
+        "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
+    ),
+    "local_experts.linear_fc1": NameRemapping(
+        "backbone.layers.{}.mixer.experts.{}.up_proj.", COL_ETP
+    ),
+    "local_experts.linear_fc2": NameRemapping(
+        "backbone.layers.{}.mixer.experts.{}.down_proj.", ROW_ETP
+    ),
+    "shared_experts.linear_fc1": NameRemapping(
+        "backbone.layers.{}.mixer.shared_experts.up_proj.", COL_TP
+    ),
+    "shared_experts.linear_fc2": NameRemapping(
+        "backbone.layers.{}.mixer.shared_experts.down_proj.", ROW_TP
+    ),
 }
@@ -87,4 +105,14 @@
     "pre_mlp_layernorm": NameRemapping("backbone.layers.{}.norm."),
     "linear_fc1": NameRemapping("backbone.layers.{}.mixer.up_proj."),
     "linear_fc2": NameRemapping("backbone.layers.{}.mixer.down_proj."),
+    # MoE
+    "router": NameRemapping(
+        "backbone.layers.{}.mixer.gate.", {"mapping": {"expert_bias": "e_score_correction_bias"}}
+    ),
+    "local_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.experts.{}.up_proj."),
+    "local_experts.linear_fc2": NameRemapping("backbone.layers.{}.mixer.experts.{}.down_proj."),
+    "shared_experts.linear_fc1": NameRemapping("backbone.layers.{}.mixer.shared_experts.up_proj."),
+    "shared_experts.linear_fc2": NameRemapping(
+        "backbone.layers.{}.mixer.shared_experts.down_proj."
+    ),
 }

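Not part of the commit, but a minimal sketch of how these new MoE entries are consumed, assuming NameRemapping simply stores the Hugging Face prefix as a format string that the importer/exporter later fills with the layer index (and, for local experts, the expert index):

# Illustrative only: how a mapping prefix becomes a Hugging Face tensor name.
prefix = "backbone.layers.{}.mixer.experts.{}.up_proj."
layer_id, expert_id = 3, 7
print(prefix.format(layer_id, expert_id) + "weight")
# prints: backbone.layers.3.mixer.experts.7.up_proj.weight
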
modelopt/torch/export/plugins/megatron_importer.py

Lines changed: 17 additions & 3 deletions
@@ -77,11 +77,18 @@ def __init__(
         dequantize: bool = True,
         trust_remote_code: bool = True,
         verbose: bool = False,
+        moe_router_dtype: torch.dtype | None = None,
     ):
         """Create a GPTModel importer instance."""
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
         )
+        self.moe_router_dtype = None
+        if moe_router_dtype == "fp32":
+            self.moe_router_dtype = torch.float32
+        elif moe_router_dtype == "fp64":
+            self.moe_router_dtype = torch.float64
+
         pretrained_model_path = Path(pretrained_model_name_or_path)
         if not pretrained_model_path.is_dir():
             if workspace_dir is None:
@@ -118,7 +125,9 @@ def _custom_mapping_to_lambda(mapping):
             func = method_map[mapping.func_name]
             prefix = mapping.target_name_or_prefix
             func_kwargs = mapping.func_kwargs
-            return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
+            return lambda m, *args, **kwargs: func(
+                m, prefix.format(*args), **{**func_kwargs, **kwargs}
+            )
 
         for arch, mappings in all_mcore_hf_import_mapping.items():
             all_rules[arch] = {
@@ -140,7 +149,10 @@ def _name_remapping(
         prefix,
         mapping={},
         parallel_config: ParallelConfig | None = None,
+        dtype: torch.dtype | None = None,
     ):
+        if dtype is None:
+            dtype = self.dtype
         if isinstance(module, torch.Tensor):
             tensor = self._get_safetensor(prefix, parallel_config=parallel_config)
             module.data.copy_(tensor)
@@ -193,7 +205,7 @@ def _name_remapping(
                 tensor = self._get_safetensor(
                     prefix + source_key, parallel_config=parallel_config
                 )
-                state_dict[key] = tensor.to(dtype=self.dtype).to(device=val.device)
+                state_dict[key] = tensor.to(dtype=dtype).to(device=val.device)
 
         module.load_state_dict(state_dict)
 
@@ -523,7 +535,9 @@ def _import_state_dict(self):
             if not isinstance(layer.mlp, IdentityOp):
                 if "MoE" in str(type(layer.mlp)):
                     layer_pbar.set_description("Importing MoE")
-                    self.rules["router"](layer.mlp.router, layer_id)
+                    self.rules["router"](
+                        layer.mlp.router, layer_id, dtype=self.moe_router_dtype
+                    )
                     if (
                         hasattr(layer.mlp, "shared_experts")
                         and layer.mlp.shared_experts is not None

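Not part of the commit, but a minimal sketch of the kwargs-merging change above, under the assumption that rules are plain closures: keyword arguments supplied at the call site (such as the MoE router dtype) are layered on top of the defaults baked into the mapping table, which is how dtype=self.moe_router_dtype reaches _name_remapping without touching the mapping definitions.

# Illustrative only: per-call kwargs override the mapping defaults.
def make_rule(func, prefix, func_kwargs):
    return lambda m, *args, **kwargs: func(m, prefix.format(*args), **{**func_kwargs, **kwargs})

defaults = {"mapping": {"expert_bias": "e_score_correction_bias"}}
rule = make_rule(lambda m, name, **kw: print(name, kw), "backbone.layers.{}.mixer.gate.", defaults)

rule("router_module", 0)                   # uses only the defaults
rule("router_module", 0, dtype="float32")  # adds dtype (a torch.dtype in the real code)
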
modelopt/torch/export/unified_export_megatron.py

Lines changed: 33 additions & 8 deletions
@@ -109,7 +109,7 @@ def get_kv_cache_scaling_factor(kv_module: nn.Module) -> torch.Tensor:
 
 def get_quantized_state(
     module: torch.nn.Module,
-    dtype: torch.dtype = torch.float16,
+    dtype: torch.dtype = torch.bfloat16,
 ) -> tuple[dict[str, torch.Tensor], str, int]:
     """Return a state_dict, quantization format, and block_size of the module.
 
@@ -186,6 +186,7 @@ def __init__(
         export_extra_modules: bool = False,
         dtype=torch.bfloat16,
         trust_remote_code: bool = True,
+        moe_router_dtype: torch.dtype | None = None,
     ):
         """Create a GPTModel exporter instance."""
         if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
@@ -196,6 +197,12 @@ def __init__(
         self._hf_config = transformers.AutoConfig.from_pretrained(
             pretrained_model_name_or_path, trust_remote_code=trust_remote_code
         )
+        self.moe_router_dtype = None
+        if moe_router_dtype == "fp32":
+            self.moe_router_dtype = torch.float32
+        elif moe_router_dtype == "fp64":
+            self.moe_router_dtype = torch.float64
+
         # If multimodal, extra the text_config
         self._hf_text_config = getattr(self._hf_config, "text_config", self._hf_config)
 
@@ -489,7 +496,9 @@ def _custom_mapping_to_lambda(mapping):
             func = method_map[mapping.func_name]
             prefix = mapping.target_name_or_prefix
             func_kwargs = mapping.func_kwargs
-            return lambda m, *args: func(m, prefix.format(*args), **func_kwargs)
+            return lambda m, *args, **kwargs: func(
+                m, prefix.format(*args), **{**func_kwargs, **kwargs}
+            )
 
         for arch, mappings in all_mcore_hf_export_mapping.items():
             all_rules[arch] = {
@@ -519,12 +528,16 @@ def _name_remapping(
         prefix: str,
         skip_output_scale: bool = True,
         mapping={},
+        dtype: torch.dtype | None = None,
     ):
+        if dtype is None:
+            dtype = self.dtype
+
         if isinstance(module, torch.Tensor):
             self._state_dict[prefix] = module
             return
 
-        name_to_value, qformat, block_size = get_quantized_state(module, self.dtype)
+        name_to_value, qformat, block_size = get_quantized_state(module, dtype)
 
         weight = name_to_value.pop("weight")
         weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
@@ -1098,7 +1111,9 @@ def _get_state_dict(self):
 
             if not isinstance(layer.mlp, IdentityOp):
                 if "MoE" in str(type(layer.mlp)):
-                    self.rules["router"](layer.mlp.router, layer_id)
+                    self.rules["router"](
+                        layer.mlp.router, layer_id, dtype=self.moe_router_dtype
+                    )
                     if (
                         hasattr(layer.mlp, "shared_experts")
                         and layer.mlp.shared_experts is not None
@@ -1136,8 +1151,9 @@ def export_mcore_gpt_to_hf(
     model: torch.nn.Module,
     pretrained_model_name_or_path: str | os.PathLike | None = None,
     export_extra_modules: bool = False,
-    dtype: torch.dtype = torch.float16,
+    dtype: torch.dtype = torch.bfloat16,
     export_dir: Path | str = tempfile.gettempdir(),
+    moe_router_dtype: torch.dtype | None = None,
 ):
     """Export Megatron Core GPTModel to unified checkpoint and save to export_dir.
 
@@ -1153,7 +1169,11 @@
         export_dir: The target export path.
     """
     exporter = GPTModelExporter(
-        model, pretrained_model_name_or_path, export_extra_modules=export_extra_modules, dtype=dtype
+        model,
+        pretrained_model_name_or_path,
+        export_extra_modules=export_extra_modules,
+        dtype=dtype,
+        moe_router_dtype=moe_router_dtype,
     )
     exporter.save_pretrained(export_dir, pretrained_model_name_or_path)
 
@@ -1162,7 +1182,8 @@ def import_mcore_gpt_from_hf(
     model: torch.nn.Module,
     pretrained_model_path: str,
     workspace_dir: str | None = None,
-    dtype: torch.dtype = torch.float16,
+    dtype: torch.dtype = torch.bfloat16,
+    moe_router_dtype: torch.dtype | None = None,
 ):
     """Import GPTModel state_dict from supported HuggingFace pretrained model path.
 
@@ -1173,6 +1194,10 @@
         dtype: The weights data type to import.
     """
     importer = GPTModelImporter(
-        model, pretrained_model_path, workspace_dir=workspace_dir, dtype=dtype
+        model,
+        pretrained_model_path,
+        workspace_dir=workspace_dir,
+        dtype=dtype,
+        moe_router_dtype=moe_router_dtype,
     )
     importer._import_state_dict()

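A hedged usage sketch (not from the commit): the model object and checkpoint paths below are placeholders. Note that although the new moe_router_dtype parameter is annotated as torch.dtype | None, the constructors above match it against the strings "fp32" and "fp64", so that is what a caller would pass to keep the MoE router weights in higher precision.

import torch

from modelopt.torch.export.unified_export_megatron import export_mcore_gpt_to_hf

# `megatron_model` is a placeholder for an already-built Megatron Core GPTModel/MambaModel.
export_mcore_gpt_to_hf(
    megatron_model,
    pretrained_model_name_or_path="path/to/hf_reference_checkpoint",  # placeholder path
    dtype=torch.bfloat16,     # new default in this commit
    export_dir="/tmp/unified_ckpt",
    moe_router_dtype="fp32",  # router weights exported in float32; "fp64" is also recognized
)
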
modelopt/torch/quantization/plugins/huggingface.py

Lines changed: 16 additions & 7 deletions
@@ -345,7 +345,16 @@ def backward(ctx, grad_output):
 _transposed_quantize = _TransposedQuantization.apply
 
 
-class _QuantMoeSparseMoe(QuantModule):
+class _QuantSparseMoe(QuantModule):
+    """Module to support special handling of token dispatching during calibration.
+
+    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
+    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
+    returns.
+
+    If calibration is not enabled, this module behaves as a normal MoELayer.
+    """
+
     def _setup(self):
         pass
 
@@ -480,7 +489,7 @@ def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
         return self.w2_linear[expert_idx](x1)
 
 
-class _QuantDbrxFFN(_QuantMoeSparseMoe):
+class _QuantDbrxFFN(_QuantSparseMoe):
     @property
     def num_experts(self):
         return self.router.moe_num_experts
@@ -498,7 +507,7 @@ def top_k(self, value):
    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
 
     if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantMoeSparseMoe)
+        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
 
     if Llama4TextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -526,7 +535,7 @@ def top_k(self, value):
 
     if MixtralSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -544,7 +553,7 @@ def top_k(self, value):
 
     if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -554,7 +563,7 @@ def top_k(self, value):
 
     if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass
@@ -564,7 +573,7 @@ def top_k(self, value):
 
     if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantMoeSparseMoe
+            _QuantSparseMoe
         )
 except ImportError:
     pass

modelopt/torch/quantization/plugins/megatron.py

Lines changed: 33 additions & 1 deletion
@@ -23,6 +23,7 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
+import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
 import torch
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -36,7 +37,7 @@
 )
 from modelopt.torch.utils.distributed import ParallelState
 
-from ..nn import QuantModuleRegistry, TensorQuantizer
+from ..nn import QuantModule, QuantModuleRegistry, TensorQuantizer
 from ..nn.modules.quant_linear import RealQuantLinear
 from ..qtensor import QTensorWrapper
 from .custom import CUSTOM_MODEL_PLUGINS, _ParallelLinear
@@ -247,6 +248,14 @@ def _setup(self):
             data_parallel_group,
             mcore_parallel.get_tensor_model_parallel_group(),
         )
+
+        if getattr(self, "gradient_accumulation_fusion", False):
+            warnings.warn(
+                "gradient_accumulation_fusion is not supported with ModelOpt quantization. "
+                "Setting gradient_accumulation_fusion to False."
+            )
+            self.gradient_accumulation_fusion = False
+
         super()._setup()
 
     def _process_quantizer_amax(self, k, v, quantizer_state_dict):
@@ -580,3 +589,26 @@ def _setup(self):
         # initialize parallel state for submodules linear_fc1 and linear_fc2
         self.linear_fc1.parallel_state = self.parallel_state
         self.linear_fc2.parallel_state = self.parallel_state
+
+
+@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
+class _QuantMoELayer(QuantModule):
+    """Module to support special handling of token dispatching during calibration.
+
+    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
+    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
+    returns.
+
+    If calibration is not enabled, this module behaves as a normal MoELayer.
+    """
+
+    def _setup(self):
+        pass
+
+    def forward(self, hidden_states):
+        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
+            original_top_k = self.router.topk
+            self.router.topk = self.router.num_experts
+            super().forward(hidden_states)
+            self.router.topk = original_top_k
+        return super().forward(hidden_states)

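Not part of the commit, but a minimal standalone sketch of the calibration-time dispatch trick that the _QuantMoELayer and _QuantSparseMoe docstrings describe, using simplified stand-in names rather than the Megatron/Hugging Face classes the commit patches: when any expert is calibrating, temporarily raise the router's top-k to the number of experts so every expert sees tokens, then restore top-k and return the normally routed output.

class CalibAwareMoE:
    """Toy MoE wrapper: routes to all experts during calibration, returns the real top-k output."""

    def __init__(self, router, experts):
        self.router = router    # stand-in object with .topk and .num_experts attributes
        self.experts = experts  # list of expert modules; each may set ._if_calib during calibration

    def _routed_forward(self, hidden_states):
        ...  # normal top-k dispatch would happen here

    def forward(self, hidden_states):
        if any(getattr(m, "_if_calib", False) for m in self.experts):
            original_top_k = self.router.topk
            self.router.topk = self.router.num_experts  # calibration pass sees all experts
            self._routed_forward(hidden_states)
            self.router.topk = original_top_k
        return self._routed_forward(hidden_states)      # actual output uses the real routing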