diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index 4b256ffc7324..c1c5f66f4aac 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -383,6 +383,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
         state_dict = model.state_dict()
+        # Get num_local_experts from model config
+        num_local_experts = getattr(model.config, "num_local_experts", 32)
+        hidden_size = getattr(model.config, "hidden_size", 2880)
+
         for name, module in model.named_modules():
             if (
                 isinstance(module, Mxfp4GptOssExperts)
@@ -392,7 +396,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
                 state_dict[f"{name}.gate_up_proj_blocks"] = (
                     module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
                     .transpose(-1, -2)
-                    .reshape(32, -1, 90, 16)
+                    .reshape(num_local_experts, -1, 90, 16)
                 )
                 state_dict[f"{name}.gate_up_proj_scales"] = (
                     module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
@@ -402,7 +406,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
                 state_dict[f"{name}.down_proj_blocks"] = (
                     module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
                     .transpose(-1, -2)
-                    .reshape(32, 2880, 90, -1)
+                    .reshape(num_local_experts, hidden_size, 90, -1)
                 )
                 state_dict[f"{name}.down_proj_scales"] = (
                     module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
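For context, a minimal sketch of what the config-driven reshape buys us; `DummyConfig` and the tensor sizes below are illustrative assumptions, not values taken from the patch or from a real GPT-OSS checkpoint:

```python
import torch

class DummyConfig:
    # Hypothetical config: a variant with more experts than the default 32.
    num_local_experts = 64
    hidden_size = 2880

config = DummyConfig()

# Mirror the patched logic: read the dimensions from the config and fall back
# to the GPT-OSS defaults (32 experts, hidden size 2880) when they are absent.
num_local_experts = getattr(config, "num_local_experts", 32)
hidden_size = getattr(config, "hidden_size", 2880)

# Stand-in for unswizzled, transposed gate_up_proj data; the trailing (90, 16)
# block layout matches the reshape in the quantizer.
flat = torch.zeros(num_local_experts * 4 * 90 * 16, dtype=torch.uint8)

blocks = flat.reshape(num_local_experts, -1, 90, 16)
print(blocks.shape)  # torch.Size([64, 4, 90, 16]) under these assumptions
```

With the previous hard-coded `reshape(32, ...)`, a checkpoint with a different expert count would either be folded into the wrong expert dimension or fail outright when the element count does not divide evenly.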