Merged
9 changes: 8 additions & 1 deletion src/transformers/quantizers/quantizer_mxfp4.py
@@ -61,7 +61,14 @@ def validate_environment(self, *args, **kwargs):
             return
 
         if not torch.cuda.is_available():
-            raise RuntimeError("Using MXFP4 quantized models requires a GPU")
+            if self.pre_quantized:
+                logger.warning_once(
+                    "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
+                )
+                self.quantization_config.dequantize = True
+                return
+            else:
+                raise RuntimeError("Quantizing a model using MXFP4 requires a GPU")
 
         if not is_accelerate_available():
             raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")