diff --git a/src/transformers/integrations/mxfp4.py b/src/transformers/integrations/mxfp4.py
index b37f72e7c388..67cb89ba4812 100644
--- a/src/transformers/integrations/mxfp4.py
+++ b/src/transformers/integrations/mxfp4.py
@@ -314,12 +314,12 @@ def should_convert_module(current_key_name, patterns):
 def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
     from ..integrations.tensor_parallel import shard_and_distribute_module
 
-    model = kwargs.get("model", None)
-    empty_param = kwargs.get("empty_param", None)
-    casting_dtype = kwargs.get("casting_dtype", None)
-    to_contiguous = kwargs.get("to_contiguous", None)
-    rank = kwargs.get("rank", None)
-    device_mesh = kwargs.get("device_mesh", None)
+    model = kwargs.get("model")
+    empty_param = kwargs.get("empty_param")
+    casting_dtype = kwargs.get("casting_dtype")
+    to_contiguous = kwargs.get("to_contiguous")
+    rank = kwargs.get("rank")
+    device_mesh = kwargs.get("device_mesh")
 
     for proj in ["gate_up_proj", "down_proj"]:
         if proj in param_name:
@@ -357,12 +357,12 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwa
     )
     from ..integrations.tensor_parallel import shard_and_distribute_module
 
-    model = kwargs.get("model", None)
-    empty_param = kwargs.get("empty_param", None)
-    casting_dtype = kwargs.get("casting_dtype", None)
-    to_contiguous = kwargs.get("to_contiguous", None)
-    rank = kwargs.get("rank", None)
-    device_mesh = kwargs.get("device_mesh", None)
+    model = kwargs.get("model")
+    empty_param = kwargs.get("empty_param")
+    casting_dtype = kwargs.get("casting_dtype")
+    to_contiguous = kwargs.get("to_contiguous")
+    rank = kwargs.get("rank")
+    device_mesh = kwargs.get("device_mesh")
 
     for proj in ["gate_up_proj", "down_proj"]:
         if proj in param_name:
diff --git a/src/transformers/quantizers/quantizer_mxfp4.py b/src/transformers/quantizers/quantizer_mxfp4.py
index 5281d4d76388..4fb4168a8bd9 100644
--- a/src/transformers/quantizers/quantizer_mxfp4.py
+++ b/src/transformers/quantizers/quantizer_mxfp4.py
@@ -101,7 +101,7 @@ def validate_environment(self, *args, **kwargs):
             global triton_kernels_hub
             triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
 
-        device_map = kwargs.get("device_map", None)
+        device_map = kwargs.get("device_map")
         if device_map is None:
             logger.warning_once(
                 "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set "
@@ -210,11 +210,11 @@ def create_quantized_param(
         # we take this path if already quantized but not in a compatible way
         # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
         else:
-            empty_param = kwargs.get("empty_param", None)
-            casting_dtype = kwargs.get("casting_dtype", None)
-            to_contiguous = kwargs.get("to_contiguous", None)
-            rank = kwargs.get("rank", None)
-            device_mesh = kwargs.get("device_mesh", None)
+            empty_param = kwargs.get("empty_param")
+            casting_dtype = kwargs.get("casting_dtype")
+            to_contiguous = kwargs.get("to_contiguous")
+            rank = kwargs.get("rank")
+            device_mesh = kwargs.get("device_mesh")
             if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
                 # blocks and scales have the same length that's this works for both
                 module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
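
Note (not part of the patch): the change is behavior-preserving, since dict.get(key) already returns None when the key is missing, so the explicit None default is redundant. A minimal standalone sketch illustrating the equivalence:

# Minimal sketch: dict.get(key) and dict.get(key, None) are equivalent;
# the only behavioral difference from attribute-style access would be a
# missing key raising, which .get never does.
kwargs = {"rank": 0}

assert kwargs.get("device_mesh") is None
assert kwargs.get("device_mesh") == kwargs.get("device_mesh", None)
assert kwargs.get("rank") == kwargs.get("rank", None) == 0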