From 5e5c4d629d8bcd9d5b2b4fa859a2bbdbb0011e36 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 30 May 2024 11:45:03 +0200 Subject: [PATCH] FIX / Quantization: Add extra validation for bnb config (#31135) add validation for bnb config --- src/transformers/utils/quantization_config.py | 4 ++++ tests/quantization/bnb/test_4bit.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index f9e503cf862f18..6236827de34bb2 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -383,6 +383,10 @@ def __init__( if bnb_4bit_quant_storage is None: self.bnb_4bit_quant_storage = torch.uint8 elif isinstance(bnb_4bit_quant_storage, str): + if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]: + raise ValueError( + "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') " + ) self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage) elif isinstance(bnb_4bit_quant_storage, torch.dtype): self.bnb_4bit_quant_storage = bnb_4bit_quant_storage diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 443b1020a30e07..ac17979d175ce6 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -303,6 +303,13 @@ def test_fp32_4bit_conversion(self): model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-small", load_in_4bit=True, device_map="auto") self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32) + def test_bnb_4bit_wrong_config(self): + r""" + Test whether creating a bnb config with unsupported values leads to errors. + """ + with self.assertRaises(ValueError): + _ = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_storage="add") + @require_bitsandbytes @require_accelerate