From 74fc96ea5e28034a8d9e0b10b6533c802d1126c7 Mon Sep 17 00:00:00 2001
From: marksverdhei
Date: Tue, 21 Oct 2025 21:25:28 +0200
Subject: [PATCH] Fix MXFP4 quantizer to support variable num_local_experts
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The quantizer hardcoded 32 experts and a hidden_size of 2880 in its reshape
operations. This caused failures when quantizing models with a different
number of experts (e.g., averaged single-expert models).

Changes:
- Read num_local_experts and hidden_size from model.config
- Use the dynamic values in the reshape operations instead of hardcoded constants
- Default to 32 and 2880 for backward compatibility

This enables quantizing averaged/merged MoE models with fewer experts.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude
---
 src/transformers/quantizers/quantizer_mxfp4.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index 4b256ffc7324..c1c5f66f4aac 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -383,6 +383,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
 
         state_dict = model.state_dict()
 
+        # Get num_local_experts and hidden_size from the model config
+        num_local_experts = getattr(model.config, "num_local_experts", 32)
+        hidden_size = getattr(model.config, "hidden_size", 2880)
+
         for name, module in model.named_modules():
             if (
                 isinstance(module, Mxfp4GptOssExperts)
@@ -392,7 +396,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
                 state_dict[f"{name}.gate_up_proj_blocks"] = (
                     module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
                     .transpose(-1, -2)
-                    .reshape(32, -1, 90, 16)
+                    .reshape(num_local_experts, -1, 90, 16)
                 )
                 state_dict[f"{name}.gate_up_proj_scales"] = (
                     module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
@@ -402,7 +406,7 @@ def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
                 state_dict[f"{name}.down_proj_blocks"] = (
                     module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
                     .transpose(-1, -2)
-                    .reshape(32, 2880, 90, -1)
+                    .reshape(num_local_experts, hidden_size, 90, -1)
                 )
                 state_dict[f"{name}.down_proj_scales"] = (
                     module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
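
Note (illustrative sketch, not part of the patch): a minimal standalone example of how the getattr fallbacks and the config-driven reshape behave. The toy config object and tensor below are hypothetical; the 90 and 16 constants are the block dimensions already used in quantizer_mxfp4.py, and the real data comes from the unswizzled MXFP4 storage, not from torch.zeros.

    from types import SimpleNamespace

    import torch

    # Hypothetical config for an averaged single-expert model (attribute names mirror model.config).
    config = SimpleNamespace(num_local_experts=1, hidden_size=2880)

    # Same fallbacks as the patch: missing attributes resolve to the old hardcoded values.
    num_local_experts = getattr(config, "num_local_experts", 32)
    hidden_size = getattr(config, "hidden_size", 2880)

    # Stand-in for unswizzled, transposed gate_up_proj data of a single expert.
    flat = torch.zeros(num_local_experts * 4 * 90 * 16)

    blocks = flat.reshape(num_local_experts, -1, 90, 16)
    print(blocks.shape)  # torch.Size([1, 4, 90, 16]); reshape(32, -1, 90, 16) would raise a RuntimeError here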