diff --git a/src/transformers/quantizers/quantizer_bnb_4bit.py b/src/transformers/quantizers/quantizer_bnb_4bit.py index 764475e7b1db..c49610bbfe87 100644 --- a/src/transformers/quantizers/quantizer_bnb_4bit.py +++ b/src/transformers/quantizers/quantizer_bnb_4bit.py @@ -58,6 +58,8 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") if not (is_accelerate_available() and is_bitsandbytes_available()): raise ImportError( "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " @@ -70,9 +72,6 @@ def validate_environment(self, *args, **kwargs): " sure the weights are in PyTorch format." ) - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed for quantization.") - device_map = kwargs.get("device_map", None) if ( device_map is not None diff --git a/src/transformers/quantizers/quantizer_bnb_8bit.py b/src/transformers/quantizers/quantizer_bnb_8bit.py index 8016194f9d86..2d24c3404972 100644 --- a/src/transformers/quantizers/quantizer_bnb_8bit.py +++ b/src/transformers/quantizers/quantizer_bnb_8bit.py @@ -58,6 +58,9 @@ def __init__(self, quantization_config, **kwargs): self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules def validate_environment(self, *args, **kwargs): + if not torch.cuda.is_available(): + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + if not (is_accelerate_available() and is_bitsandbytes_available()): raise ImportError( "Using `bitsandbytes` 8-bit quantization requires Accelerate: `pip install accelerate` " @@ -70,9 +73,6 @@ def validate_environment(self, *args, **kwargs): " sure the weights are in PyTorch format." ) - if not torch.cuda.is_available(): - raise RuntimeError("No GPU found. A GPU is needed for quantization.") - device_map = kwargs.get("device_map", None) if ( device_map is not None