diff --git a/docs/source/en/main_classes/quantization.mdx b/docs/source/en/main_classes/quantization.mdx
index 3dd6d36ee497d8..c168b11a8302f5 100644
--- a/docs/source/en/main_classes/quantization.mdx
+++ b/docs/source/en/main_classes/quantization.mdx
@@ -19,8 +19,45 @@ This is supported by most of the GPU hardwares since the `0.37.0` release of `bi
Learn more about the quantization method in the [LLM.int8()](https://arxiv.org/abs/2208.07339) paper, or the [blogpost](https://huggingface.co/blog/hf-bitsandbytes-integration) about the collaboration.
+Since the `0.39.0` release of `bitsandbytes`, you can load any model that supports `device_map` using 4-bit quantization, leveraging the FP4 data type.
+
Here are the things you can do using `bitsandbytes` integration
+### FP4 quantization
+
+#### Requirements
+
+Make sure that you have installed the requirements below before running any of the following code snippets.
+
+- Latest `bitsandbytes` library
+`pip install bitsandbytes>=0.39.0`
+
+- Install latest `accelerate` from source
+`pip install git+https://github.com/huggingface/accelerate.git`
+
+- Install latest `transformers` from source
+`pip install git+https://github.com/huggingface/transformers.git`
+
+#### Load a large model in 4bit
+
+By passing `load_in_4bit=True` when calling the `.from_pretrained` method, you can roughly divide your memory usage by 4.
+
+```python
+# pip install transformers accelerate bitsandbytes
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "bigscience/bloom-1b7"
+
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_4bit=True)
+```
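+
+Once loaded, the 4-bit model can be used for generation like any other model. Below is a minimal sketch reusing the `tokenizer` and `model` from the snippet above (the prompt text is arbitrary):
+
+```python
+text = "Hello my name is"
+inputs = tokenizer(text, return_tensors="pt").to(0)
+
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```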
+
+
+
+Note that once a model has been loaded in 4-bit, it is currently not possible to push the quantized weights to the Hub. Note also that you cannot train 4-bit weights, as this is not supported yet. However, you can use 4-bit models to train extra parameters; this will be covered in the next section.
+
+
+
### Load a large model in 8bit
You can load a model by roughly halving the memory requirements by using `load_in_8bit=True` argument when calling `.from_pretrained` method
@@ -48,10 +85,56 @@ With this integration we were able to load large models on smaller devices and r
-Note that once a model has been loaded in 8-bit it is currently not possible to push the quantized weights on the Hub. Note also that you cannot train 8-bit weights as this is not supported yet. However you can use 8-bit models to train extra parameters, this will be covered in the next section.
+Note that once a model has been loaded in 8-bit, it is currently not possible to push the quantized weights to the Hub unless you use the latest `transformers` and `bitsandbytes`. Note also that you cannot train 8-bit weights, as this is not supported yet. However, you can use 8-bit models to train extra parameters; this will be covered in the next section.
+#### Advanced use cases
+
+Here we will cover some advanced use cases you can perform with FP4 quantization.
+
+##### Change the compute dtype
+
+The compute dtype sets the dtype used during computation. For example, hidden states can be in `float32` while the computation is performed in `bfloat16` for speedups. By default, the compute dtype is set to `float32`.
+
+```python
+import torch
+from transformers import BitsAndBytesConfig
+
+quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
+```
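+
+The configuration is then passed to `.from_pretrained` through the `quantization_config` argument. Below is a minimal sketch, assuming the same `model_id` as in the section above:
+
+```python
+from transformers import AutoModelForCausalLM
+
+model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=quantization_config)
+```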
+
+##### Using NF4 (Normal Float 4) data type
+
+You can also use the NF4 data type, which is a new 4-bit data type adapted for weights that have been initialized using a normal distribution. To do so, run:
+
+```python
+from transformers import BitsAndBytesConfig
+
+nf4_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_quant_type="nf4",
+)
+
+model_nf4 = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=nf4_config)
+```
+
+##### Use nested quantization for more memory efficient inference
+
+We also advise users to use the nested quantization technique. This saves more memory at no additional performance cost - from our empirical observations, this enables fine-tuning a llama-13b model on an NVIDIA T4 16GB GPU with a sequence length of 1024, a batch size of 1 and 4 gradient accumulation steps.
+
+```python
+from transformers import BitsAndBytesConfig
+
+double_quant_config = BitsAndBytesConfig(
+ load_in_4bit=True,
+ bnb_4bit_use_double_quant=True,
+)
+
+model_double_quant = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=double_quant_config)
+```
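+
+These options can be combined. As a sketch, the configuration below enables NF4, nested quantization and a `bfloat16` compute dtype at the same time (reusing the `model_id` defined earlier):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_quant_type="nf4",
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_compute_dtype=torch.bfloat16,
+)
+
+model_4bit = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", quantization_config=bnb_config)
+```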
+
+
### Push quantized models on the 🤗 Hub
You can push a quantized model on the Hub by naively using `push_to_hub` method. This will first push the quantization configuration file, then push the quantized model weights.
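+
+For example, a minimal sketch for an 8-bit model, assuming you are logged in with `huggingface-cli login` and using the latest `transformers` and `bitsandbytes` as noted above (the repository name is hypothetical):
+
+```python
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto", load_in_8bit=True)
+tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
+
+model.push_to_hub("bloom-560m-8bit")  # hypothetical repository name
+tokenizer.push_to_hub("bloom-560m-8bit")
+```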
diff --git a/docs/source/en/perf_infer_gpu_one.mdx b/docs/source/en/perf_infer_gpu_one.mdx
index 3403e81fb38451..4bcf3c1111161c 100644
--- a/docs/source/en/perf_infer_gpu_one.mdx
+++ b/docs/source/en/perf_infer_gpu_one.mdx
@@ -34,6 +34,60 @@ model.save_pretrained("saved_model")
As of PyTorch 2.0, the attention fastpath is supported for both encoders and decoders. The list of supported architectures can be found [here](https://huggingface.co/docs/optimum/bettertransformer/overview#supported-models).
+## `bitsandbytes` integration for FP4 mixed-precision inference
+
+You can install `bitsandbytes` and benefit from easy model compression on GPUs. With FP4 quantization you can expect to reduce the model size by up to 8x compared to its native full-precision version. See below for how to get started.
+
+
+
+Note that this feature can also be used in a multi-GPU setup.
+
+
+
+### Requirements
+
+- Latest `bitsandbytes` library
+`pip install bitsandbytes>=0.39.0`
+
+- Install latest `accelerate` from source
+`pip install git+https://github.com/huggingface/accelerate.git`
+
+- Install latest `transformers` from source
+`pip install git+https://github.com/huggingface/transformers.git`
+
+### Running FP4 models - single GPU setup - Quickstart
+
+You can quickly run an FP4 model on a single GPU by running the following code:
+
+```py
+from transformers import AutoModelForCausalLM
+
+model_name = "bigscience/bloom-2b5"
+model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
+```
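+
+As a minimal sketch of how to run inference with it (assuming a tokenizer loaded for the same checkpoint):
+
+```py
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)
+
+outputs = model_4bit.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```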
+
+### Running FP4 models - multi GPU setup
+
+The way to load your mixed 4-bit model on multiple GPUs is as follows (same command as in the single-GPU setup):
+```py
+model_name = "bigscience/bloom-2b5"
+model_4bit = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", load_in_4bit=True)
+```
+But you can control the GPU RAM you want to allocate on each GPU using `accelerate`. Use the `max_memory` argument as follows:
+
+```py
+max_memory_mapping = {0: "600MB", 1: "1GB"}
+model_name = "bigscience/bloom-3b"
+model_4bit = AutoModelForCausalLM.from_pretrained(
+ model_name, device_map="auto", load_in_4bit=True, max_memory=max_memory_mapping
+)
+```
+In this example, the first GPU will use 600MB of memory and the second 1GB.
+
+### Advanced usage
+
+For more advanced usage of this method, please have a look at the [quantization](main_classes/quantization) documentation page.
+
## `bitsandbytes` integration for Int8 mixed-precision matrix decomposition
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b7b832b4640ed0..6495b997b89a5a 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -606,7 +606,7 @@ def _load_state_dict_into_meta_model(
state_dict_folder=None,
state_dict_index=None,
dtype=None,
- load_in_8bit=False,
+ is_quantized=False,
is_safetensors=False,
keep_in_fp32_modules=None,
):
@@ -627,8 +627,8 @@ def _load_state_dict_into_meta_model(
# - Is there a situation where some keys aren't in `loaded_state_dict_keys` and in which case
# they won't get loaded.
- if load_in_8bit:
- from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+ if is_quantized:
+ from .utils.bitsandbytes import set_module_quantized_tensor_to_device
error_msgs = []
@@ -699,12 +699,13 @@ def _load_state_dict_into_meta_model(
# TODO: group all errors and raise at the end.
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
+
if param_device == "disk":
if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
- elif not load_in_8bit:
+ elif not is_quantized:
# For backward compatibility with older versions of `accelerate`
set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs)
else:
@@ -714,7 +715,7 @@ def _load_state_dict_into_meta_model(
fp16_statistics = None
if "SCB" not in param_name:
- set_module_8bit_tensor_to_device(
+ set_module_quantized_tensor_to_device(
model, param_name, param_device, value=param, fp16_statistics=fp16_statistics
)
@@ -1700,6 +1701,11 @@ def save_pretrained(
UserWarning,
)
+ if getattr(self, "is_loaded_in_4bit", False):
+ raise NotImplementedError(
+ "You are calling `save_pretrained` on a 4-bit converted model. This is currently not supported"
+ )
+
if "save_config" in kwargs:
warnings.warn(
"`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
@@ -1876,9 +1882,9 @@ def get_memory_footprint(self, return_buffers=True):
def to(self, *args, **kwargs):
# Checks if the model has been loaded in 8-bit
- if getattr(self, "is_loaded_in_8bit", False):
+ if getattr(self, "is_quantized", False):
raise ValueError(
- "`.to` is not supported for `8-bit` models. Please use the model as it is, since the"
+ "`.to` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the"
" model has already been set to the correct devices and casted to the correct `dtype`."
)
else:
@@ -1886,9 +1892,9 @@ def to(self, *args, **kwargs):
def half(self, *args):
# Checks if the model has been loaded in 8-bit
- if getattr(self, "is_loaded_in_8bit", False):
+ if getattr(self, "is_quantized", False):
raise ValueError(
- "`.half()` is not supported for `8-bit` models. Please use the model as it is, since the"
+ "`.half()` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
@@ -1896,9 +1902,9 @@ def half(self, *args):
def float(self, *args):
# Checks if the model has been loaded in 8-bit
- if getattr(self, "is_loaded_in_8bit", False):
+ if getattr(self, "is_quantized", False):
raise ValueError(
- "`.float()` is not supported for `8-bit` models. Please use the model as it is, since the"
+ "`.float()` is not supported for `4-bit` or `8-bit` models. Please use the model as it is, since the"
" model has already been casted to the correct `dtype`."
)
else:
@@ -2156,6 +2162,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder = kwargs.pop("offload_folder", None)
offload_state_dict = kwargs.pop("offload_state_dict", False)
load_in_8bit = kwargs.pop("load_in_8bit", False)
+ load_in_4bit = kwargs.pop("load_in_4bit", False)
quantization_config = kwargs.pop("quantization_config", None)
subfolder = kwargs.pop("subfolder", "")
commit_hash = kwargs.pop("_commit_hash", None)
@@ -2194,10 +2201,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if quantization_config is None:
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
- config_dict={"load_in_8bit": load_in_8bit}, return_unused_kwargs=True, **kwargs
+ config_dict={"load_in_8bit": load_in_8bit, "load_in_4bit": load_in_4bit},
+ return_unused_kwargs=True,
+ **kwargs,
)
elif quantization_config is not None:
load_in_8bit = quantization_config.load_in_8bit
+ load_in_4bit = quantization_config.load_in_4bit
quantization_config_kwargs = {
k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters
@@ -2215,30 +2225,32 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
- if load_in_8bit:
+ if load_in_8bit or load_in_4bit:
if not (is_accelerate_available() and is_bitsandbytes_available()):
raise ImportError(
"Using `load_in_8bit=True` requires Accelerate: `pip install accelerate` and the latest version of"
" bitsandbytes `pip install -i https://test.pypi.org/simple/ bitsandbytes` or"
" pip install bitsandbytes` "
)
- if torch_dtype != torch.float16:
+
+ if torch_dtype is None:
# We force the `dtype` to be float16, this is a requirement from `bitsandbytes`
- logger.warning(
+ logger.info(
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
- "requirements of `bitsandbytes` to enable model loading in mixed int8. "
- "Either pass torch_dtype=torch.float16 or don't pass this argument at all to remove this warning."
+                "requirements of `bitsandbytes` to enable model loading in mixed 4-bit or 8-bit. "
+ "Pass your own torch_dtype to specify the dtype of the remaining non-linear layers or pass"
+ " torch_dtype=torch.float16 to remove this warning."
)
torch_dtype = torch.float16
if device_map is None:
raise ValueError(
- "A device map needs to be passed to run convert models into mixed-int8 format. Please run"
+                    "A device map needs to be passed to convert models into 8-bit or 4-bit format. Please run "
"`.from_pretrained` with `device_map='auto'`"
)
if from_tf or from_flax:
raise ValueError(
- "Converting into mixed 8-bit weights from tf/flax weights is currently not supported, please make"
+ "Converting into 4-bit or 8-bit weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)
@@ -2296,8 +2308,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
load_in_8bit = quantization_config.load_in_8bit
if load_in_8bit:
- torch_dtype = torch.float16
-
+ if torch_dtype is None:
+ torch_dtype = torch.float16
if device_map is None:
device_map = "auto"
@@ -2582,7 +2594,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (
- (cls._keep_in_fp32_modules is not None) and is_accelerate_available() and torch_dtype == torch.float16
+ (cls._keep_in_fp32_modules is not None)
+ and is_accelerate_available()
+ and (torch_dtype == torch.float16 or load_in_4bit or load_in_8bit)
)
if (
(cls._keep_in_fp32_modules is not None)
@@ -2611,7 +2625,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts
- elif load_in_8bit or low_cpu_mem_usage:
+ elif load_in_8bit or load_in_4bit or low_cpu_mem_usage:
init_contexts.append(init_empty_weights())
with ContextManagers(init_contexts):
@@ -2624,20 +2638,19 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
else:
keep_in_fp32_modules = []
- if load_in_8bit:
- from .utils.bitsandbytes import get_keys_to_not_convert, replace_8bit_linear
+ if load_in_8bit or load_in_4bit:
+ from .utils.bitsandbytes import get_keys_to_not_convert, replace_with_bnb_linear
- load_in_8bit_skip_modules = quantization_config.llm_int8_skip_modules
- load_in_8bit_threshold = quantization_config.llm_int8_threshold
+ llm_int8_skip_modules = quantization_config.llm_int8_skip_modules
load_in_8bit_fp32_cpu_offload = quantization_config.llm_int8_enable_fp32_cpu_offload
logger.info("Detected 8-bit loading: activating 8-bit loading for this model")
# We keep some modules such as the lm_head in their original dtype for numerical stability reasons
- if load_in_8bit_skip_modules is None:
+ if llm_int8_skip_modules is None:
modules_to_not_convert = get_keys_to_not_convert(model)
else:
- modules_to_not_convert = load_in_8bit_skip_modules
+ modules_to_not_convert = llm_int8_skip_modules
if not isinstance(modules_to_not_convert, list):
modules_to_not_convert = [modules_to_not_convert]
@@ -2657,21 +2670,36 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
modules_to_not_convert.extend(keys_on_cpu)
- model = replace_8bit_linear(
- model, threshold=load_in_8bit_threshold, modules_to_not_convert=modules_to_not_convert
+ supports_4bit = version.parse(importlib_metadata.version("bitsandbytes")) >= version.parse("0.39.0")
+
+ if load_in_4bit and not supports_4bit:
+ raise ValueError(
+                        "You have a version of `bitsandbytes` that is not compatible with 4-bit inference and training."
+                        " Make sure you have the latest version of `bitsandbytes` installed."
+ )
+
+ model = replace_with_bnb_linear(
+ model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config
)
# training in 8-bit is only available in 0.37.0+
- model._is_int8_training_enabled = version.parse(
+ model._is_kbit_training_enabled = version.parse(
importlib_metadata.version("bitsandbytes")
) >= version.parse("0.37.0")
model.config.quantization_config = quantization_config
model.is_8bit_serializable = is_8bit_serializable
+ if load_in_8bit and torch_dtype is None:
+ logger.warning(
+                "You are loading your model in 8bit but you did not specify a `torch_dtype` attribute. "
+                "All non-linear modules will be loaded in full precision. "
+                "If you want to load the other modules in another precision, please specify a `torch_dtype` attribute."
+ )
+
if isinstance(device_map, str):
special_dtypes = {}
- if load_in_8bit:
+ if load_in_8bit or load_in_4bit:
special_dtypes.update(
{
name: torch_dtype
@@ -2688,6 +2716,23 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
}
)
+ target_dtype = torch_dtype
+
+ if load_in_4bit:
+ if version.parse(importlib_metadata.version("accelerate")) > version.parse("0.19.0"):
+ from accelerate.utils import CustomDtype
+
+ target_dtype = CustomDtype.INT4
+ else:
+ raise ValueError(
+                        "You are using `device_map='auto'` on a 4-bit loaded version of the model. To automatically compute"
+                        " the appropriate device map, you should upgrade your `accelerate` library with "
+                        "`pip install --upgrade accelerate` or install it from source to support FP4 auto device map"
+                        " calculation. Alternatively, you can pass your own device map."
+ )
+ elif load_in_8bit:
+ target_dtype = torch.int8
+
if model._no_split_modules is None:
raise ValueError(f"{model.__class__.__name__} does not support `device_map='{device_map}'` yet.")
no_split_modules = model._no_split_modules
@@ -2710,7 +2755,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if device_map != "sequential" and get_balanced_memory is not None:
max_memory = get_balanced_memory(
model,
- dtype=torch_dtype if not load_in_8bit else torch.int8,
+ dtype=target_dtype,
low_zero=(device_map == "balanced_low_0"),
max_memory=max_memory,
**kwargs,
@@ -2718,9 +2763,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
kwargs["max_memory"] = max_memory
# Make sure tied weights are tied before creating the device map.
model.tie_weights()
- device_map = infer_auto_device_map(model, dtype=torch_dtype if not load_in_8bit else torch.int8, **kwargs)
+ device_map = infer_auto_device_map(model, dtype=target_dtype, **kwargs)
- if load_in_8bit:
+ if load_in_8bit or load_in_4bit:
# The LM head / tied weights or any last module can stay on disk / CPU
device_map_without_lm_head = {
key: device_map[key] for key in device_map.keys() if key not in modules_to_not_convert
@@ -2795,11 +2840,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
offload_folder=offload_folder,
offload_state_dict=offload_state_dict,
dtype=torch_dtype,
- load_in_8bit=load_in_8bit,
+ is_quantized=(load_in_8bit or load_in_4bit),
keep_in_fp32_modules=keep_in_fp32_modules,
)
+ model.is_loaded_in_4bit = load_in_4bit
model.is_loaded_in_8bit = load_in_8bit
+ model.is_quantized = load_in_8bit or load_in_4bit
# make sure token embedding weights are still tied if needed
model.tie_weights()
@@ -2862,12 +2909,12 @@ def _load_pretrained_model(
offload_folder=None,
offload_state_dict=None,
dtype=None,
- load_in_8bit=False,
+ is_quantized=False,
keep_in_fp32_modules=None,
):
is_safetensors = False
- if load_in_8bit:
- from .utils.bitsandbytes import set_module_8bit_tensor_to_device
+ if is_quantized:
+ from .utils.bitsandbytes import set_module_quantized_tensor_to_device
if device_map is not None and "disk" in device_map.values():
archive_file = (
@@ -2973,10 +3020,10 @@ def _fix_key(key):
target_dtype = torch.float32
if param.device == torch.device("meta"):
- if not load_in_8bit:
+                    if not is_quantized:
set_module_tensor_to_device(model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype))
else:
- set_module_8bit_tensor_to_device(
+ set_module_quantized_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=target_dtype)
)
@@ -3134,7 +3181,7 @@ def _find_mismatched_keys(
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
- load_in_8bit=load_in_8bit,
+ is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
@@ -3174,7 +3221,7 @@ def _find_mismatched_keys(
)
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
- if load_in_8bit:
+ if is_quantized:
unexpected_keys = [elem for elem in unexpected_keys if "SCB" not in elem]
missing_keys = [elem for elem in missing_keys if "SCB" not in elem]
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 90b25cfb2097bc..e708de37015282 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -391,8 +391,8 @@ def __init__(
)
# At this stage the model is already loaded
- if getattr(model, "is_loaded_in_8bit", False):
- if getattr(model, "_is_int8_training_enabled", False):
+        if getattr(model, "is_quantized", False):
+ if getattr(model, "_is_kbit_training_enabled", False):
logger.info(
"The model is loaded in 8-bit precision. To train this model you need to add additional modules"
" inside the model such as adapters using `peft` library and freeze the model weights. Please"
diff --git a/src/transformers/utils/bitsandbytes.py b/src/transformers/utils/bitsandbytes.py
index 3768506f41138e..5cb82c44d6999b 100644
--- a/src/transformers/utils/bitsandbytes.py
+++ b/src/transformers/utils/bitsandbytes.py
@@ -1,3 +1,4 @@
+import warnings
from copy import deepcopy
from packaging import version
@@ -15,7 +16,7 @@
from accelerate.utils import find_tied_parameters
-def set_module_8bit_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
+def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
"""
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
@@ -52,12 +53,16 @@ class `Int8Params` from `bitsandbytes`.
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
- if is_buffer:
- has_fp16_weights = None
+ is_4bit = False
+ is_8bit = False
+ if is_buffer or not is_bitsandbytes_available():
+ is_8bit = False
+ is_4bit = False
else:
- has_fp16_weights = getattr(module._parameters[tensor_name], "has_fp16_weights", None)
+ is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
+ is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
- if has_fp16_weights is not None:
+ if is_8bit or is_4bit:
param = module._parameters[tensor_name]
if param.device.type != "cuda":
if value is None:
@@ -75,11 +80,17 @@ class `Int8Params` from `bitsandbytes`.
)
else:
new_value = torch.tensor(value, device="cpu")
- new_value = bnb.nn.Int8Params(new_value, requires_grad=False, has_fp16_weights=has_fp16_weights).to(device)
- module._parameters[tensor_name] = new_value
+ kwargs = old_value.__dict__
+ if is_8bit:
+ new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
+ elif is_4bit:
+ new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
+
+ module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))
+
else:
if value is None:
new_value = old_value.to(device)
@@ -95,10 +106,10 @@ class `Int8Params` from `bitsandbytes`.
module._parameters[tensor_name] = new_value
-def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, current_key_name=None):
+def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
- library. This will enable running your models using mixed int8 precision as described by the paper `GPT3.int8():
+ library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
bitsandbytes`
@@ -113,9 +124,6 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, curre
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
- threshold (`float`, *optional*, defaults to 6.0):
- `int8_threshold` for outlier detection as described in the formentioned paper. This parameters is set to
- `6.0` as described by the paper.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
@@ -128,29 +136,65 @@ def replace_8bit_linear(model, threshold=6.0, modules_to_not_convert=None, curre
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
- current_key_name.append(name)
-
- if len(list(module.children())) > 0:
- replace_8bit_linear(module, threshold, modules_to_not_convert, current_key_name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
- model._modules[name] = bnb.nn.Linear8bitLt(
- module.in_features,
- module.out_features,
- module.bias is not None,
- has_fp16_weights=False,
- threshold=threshold,
- )
+ if quantization_config.quantization_method() == "llm_int8":
+ model._modules[name] = bnb.nn.Linear8bitLt(
+ module.in_features,
+ module.out_features,
+ module.bias is not None,
+ has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
+ threshold=quantization_config.llm_int8_threshold,
+ )
+ else:
+ if (
+ quantization_config.llm_int8_skip_modules is not None
+ and name in quantization_config.llm_int8_skip_modules
+ ):
+ pass
+ else:
+ model._modules[name] = bnb.nn.Linear4bit(
+ module.in_features,
+ module.out_features,
+ module.bias is not None,
+ quantization_config.bnb_4bit_compute_dtype,
+ compress_statistics=quantization_config.bnb_4bit_use_double_quant,
+ quant_type=quantization_config.bnb_4bit_quant_type,
+ )
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# Remove the last key for recursion
- current_key_name.pop(-1)
+ if len(list(module.children())) > 0:
+ replace_with_bnb_linear(
+ module,
+ modules_to_not_convert,
+ current_key_name,
+ quantization_config,
+ )
return model
+# For backward compatibility
+def replace_8bit_linear(*args, **kwargs):
+ warnings.warn(
+ "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
+ FutureWarning,
+ )
+ return replace_with_bnb_linear(*args, **kwargs)
+
+
+# For backward compatibility
+def set_module_8bit_tensor_to_device(*args, **kwargs):
+ warnings.warn(
+ "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
+ FutureWarning,
+ )
+ return set_module_quantized_tensor_to_device(*args, **kwargs)
+
+
def get_keys_to_not_convert(model):
r"""
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index f123faaab32f59..2647418a13a666 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -20,7 +20,14 @@
from dataclasses import dataclass
from typing import Any, Dict, Union
-from ..utils import logging
+from packaging import version
+
+from ..utils import is_torch_available, logging
+from ..utils.import_utils import importlib_metadata
+
+
+if is_torch_available():
+ import torch
logger = logging.get_logger(__name__)
@@ -32,14 +39,17 @@ class BitsAndBytesConfig:
This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded using `bitsandbytes`.
- This replaces `load_in_8bit` therefore both options are mutually exclusive.
+    This replaces `load_in_8bit` or `load_in_4bit`; therefore both options are mutually exclusive.
- For now, only arguments that are relative to `LLM.int8()` are supported, therefore the arguments are all termed as
- `llm_int8_*`. If more methods are added to `bitsandbytes`, then more arguments will be added to this class.
+ Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
+ then more arguments will be added to this class.
Args:
load_in_8bit (`bool`, *optional*, defaults to `False`):
This flag is used to enable 8-bit quantization with LLM.int8().
+ load_in_4bit (`bool`, *optional*, defaults to `False`):
+ This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
+ `bitsandbytes`.
llm_int8_threshold (`float`, *optional*, defaults to 6):
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
@@ -58,6 +68,18 @@ class BitsAndBytesConfig:
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
operations will not be run on CPU.
+ llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
+ This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
+ have to be converted back and forth for the backward pass.
+ bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
+            This sets the computational type which might be different from the input type. For example, inputs might be
+ fp32, but computation can be set to bf16 for speedups.
+        bnb_4bit_quant_type (`str`, {fp4, nf4}, defaults to `fp4`):
+            This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
+            which are specified by `fp4` or `nf4`.
+ bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
+ This flag is used for nested quantization where the quantization constants from the first quantization are
+ quantized again.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
@@ -65,15 +87,33 @@ class BitsAndBytesConfig:
def __init__(
self,
load_in_8bit=False,
+ load_in_4bit=False,
llm_int8_threshold=6.0,
llm_int8_skip_modules=None,
llm_int8_enable_fp32_cpu_offload=False,
+ llm_int8_has_fp16_weight=False,
+ bnb_4bit_compute_dtype=None,
+ bnb_4bit_quant_type="fp4",
+ bnb_4bit_use_double_quant=False,
**kwargs,
):
self.load_in_8bit = load_in_8bit
+ self.load_in_4bit = load_in_4bit
self.llm_int8_threshold = llm_int8_threshold
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
+ self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
+ self.bnb_4bit_quant_type = bnb_4bit_quant_type
+ self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
+
+ if bnb_4bit_compute_dtype is None:
+ self.bnb_4bit_compute_dtype = torch.float32
+ elif isinstance(bnb_4bit_compute_dtype, str):
+ self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+ elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
+ self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
+ else:
+ raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
self.post_init()
@@ -86,10 +126,48 @@ def post_init(self):
if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
raise ValueError("llm_int8_skip_modules must be a list of strings")
-
if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")
+ if not isinstance(self.llm_int8_has_fp16_weight, bool):
+ raise ValueError("llm_int8_has_fp16_weight must be a boolean")
+
+ if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
+ raise ValueError("bnb_4bit_compute_dtype must be torch.dtype")
+
+ if not isinstance(self.bnb_4bit_quant_type, str):
+ raise ValueError("bnb_4bit_quant_type must be a string")
+
+ if not isinstance(self.bnb_4bit_use_double_quant, bool):
+ raise ValueError("bnb_4bit_use_double_quant must be a boolean")
+
+ if self.load_in_4bit and not version.parse(importlib_metadata.version("bitsandbytes")) >= version.parse(
+ "0.39.0"
+ ):
+ raise ValueError(
+ "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
+ )
+
+ def is_quantizable(self):
+ r"""
+ Returns `True` if the model is quantizable, `False` otherwise.
+ """
+ return self.load_in_8bit or self.load_in_4bit
+
+ def quantization_method(self):
+ r"""
+ This method returns the quantization method used for the model. If the model is not quantizable, it returns
+ `None`.
+ """
+ if self.load_in_8bit:
+ return "llm_int8"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
+ return "fp4"
+ elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
+ return "nf4"
+ else:
+ return None
+
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
"""
@@ -107,6 +185,7 @@ def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
Returns:
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
"""
+
config = cls(**config_dict)
to_remove = []
@@ -144,5 +223,8 @@ def to_dict(self) -> Dict[str, Any]:
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
+
output = copy.deepcopy(self.__dict__)
+ output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
+
return output
diff --git a/tests/mixed_int8/README.md b/tests/bitsandbytes/README.md
similarity index 100%
rename from tests/mixed_int8/README.md
rename to tests/bitsandbytes/README.md
diff --git a/tests/mixed_int8/__init__.py b/tests/bitsandbytes/__init__.py
similarity index 100%
rename from tests/mixed_int8/__init__.py
rename to tests/bitsandbytes/__init__.py
diff --git a/tests/bitsandbytes/test_4bit.py b/tests/bitsandbytes/test_4bit.py
new file mode 100644
index 00000000000000..1d0ea6dc3de281
--- /dev/null
+++ b/tests/bitsandbytes/test_4bit.py
@@ -0,0 +1,460 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import gc
+import tempfile
+import unittest
+
+from packaging import version
+
+from transformers import (
+ AutoModel,
+ AutoModelForCausalLM,
+ AutoModelForSeq2SeqLM,
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ BitsAndBytesConfig,
+ pipeline,
+)
+from transformers.testing_utils import (
+ is_torch_available,
+ require_accelerate,
+ require_bitsandbytes,
+ require_torch,
+ require_torch_gpu,
+ require_torch_multi_gpu,
+ slow,
+)
+from transformers.utils.versions import importlib_metadata
+
+
+if is_torch_available():
+ import torch
+ import torch.nn as nn
+
+ class LoRALayer(nn.Module):
+ """Wraps a linear layer with LoRA-like adapter - Used for testing purposes only"""
+
+ def __init__(self, module: nn.Module, rank: int):
+ super().__init__()
+ self.module = module
+ self.adapter = nn.Sequential(
+ nn.Linear(module.in_features, rank, bias=False),
+ nn.Linear(rank, module.out_features, bias=False),
+ )
+ small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
+ nn.init.normal_(self.adapter[0].weight, std=small_std)
+ nn.init.zeros_(self.adapter[1].weight)
+ self.adapter.to(module.weight.device)
+
+ def forward(self, input, *args, **kwargs):
+ return self.module(input, *args, **kwargs) + self.adapter(input)
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Base4bitTest(unittest.TestCase):
+ # We keep the constants inside the init function and model loading inside setUp function
+
+    # We need to test on relatively large models (aka >1b parameters, otherwise the quantization may not work as expected)
+    # Therefore here we use only bloom-1b7 to test our module
+ model_name = "bigscience/bloom-1b7"
+
+ # Constant values
+ EXPECTED_RELATIVE_DIFFERENCE = (
+        2.109659552692574  # This was obtained on an RTX Titan so the number might slightly change
+ )
+
+ input_text = "Hello my name is"
+ EXPECTED_OUTPUTS = set()
+ EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
+ EXPECTED_OUTPUTS.add("Hello my name is John.\nI am a friend of your father.\n")
+ MAX_NEW_TOKENS = 10
+
+ def setUp(self):
+ # Models and tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+
+class Bnb4BitTest(Base4bitTest):
+ def setUp(self):
+ super().setUp()
+
+ # Models and tokenizer
+ self.model_fp16 = AutoModelForCausalLM.from_pretrained(
+ self.model_name, torch_dtype=torch.float16, device_map="auto"
+ )
+ self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+
+ def tearDown(self):
+ r"""
+ TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+ avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+ """
+ del self.model_fp16
+ del self.model_4bit
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_memory_footprint(self):
+ r"""
+ A simple test to check if the model conversion has been done correctly by checking on the
+ memory footprint of the converted model and the class type of the linear layers of the converted models
+ """
+ from bitsandbytes.nn import Params4bit
+
+ mem_fp16 = self.model_fp16.get_memory_footprint()
+ mem_4bit = self.model_4bit.get_memory_footprint()
+
+ self.assertAlmostEqual(mem_fp16 / mem_4bit, self.EXPECTED_RELATIVE_DIFFERENCE)
+ self.assertTrue(self.model_4bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
+
+ def test_linear_are_4bit(self):
+ r"""
+        A simple test to check if the model conversion has been done correctly by checking the dtype of
+        the linear layers of the converted model
+ """
+ from transformers import T5PreTrainedModel
+
+ self.model_fp16.get_memory_footprint()
+ self.model_4bit.get_memory_footprint()
+
+ for name, module in self.model_4bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
+ # 4-bit parameters are packed in uint8 variables
+ self.assertTrue(module.weight.dtype == torch.uint8)
+
+ def test_generate_quality(self):
+ r"""
+ Test the generation quality of the quantized model and see that we are matching the expected output.
+ Given that we are operating on small numbers + the testing model is relatively small, we might not get
+ the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
+ """
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+ output_sequences = self.model_4bit.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+ def test_generate_quality_config(self):
+ r"""
+ Test that loading the model with the config is equivalent
+ """
+ bnb_config = BitsAndBytesConfig()
+ bnb_config.load_in_4bit = True
+
+ model_4bit_from_config = AutoModelForCausalLM.from_pretrained(
+ self.model_name, quantization_config=bnb_config, device_map="auto"
+ )
+
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+ output_sequences = model_4bit_from_config.generate(
+ input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10
+ )
+
+ self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+ def test_raise_on_save_pretrained(self):
+ r"""
+        Test that trying to save a model after converting it to 4-bit raises a `NotImplementedError`.
+ """
+ with self.assertRaises(NotImplementedError), tempfile.TemporaryDirectory() as tmpdirname:
+ self.model_4bit.save_pretrained(tmpdirname)
+
+ def test_raise_if_config_and_load_in_4bit(self):
+ r"""
+ Test that loading the model with the config and `load_in_4bit` raises an error
+ """
+ bnb_config = BitsAndBytesConfig()
+
+ with self.assertRaises(ValueError):
+ _ = AutoModelForCausalLM.from_pretrained(
+ self.model_name,
+ quantization_config=bnb_config,
+ load_in_4bit=True,
+ device_map="auto",
+ bnb_4bit_quant_type="nf4",
+ )
+
+ def test_device_and_dtype_assignment(self):
+ r"""
+        Test whether trying to cast (or assigning a device to) a model after converting it to 4-bit will throw an error.
+ Checks also if other models are casted correctly.
+ """
+ with self.assertRaises(ValueError):
+ # Tries with `str`
+ self.model_4bit.to("cpu")
+
+ with self.assertRaises(ValueError):
+            # Tries with a `dtype`
+ self.model_4bit.to(torch.float16)
+
+ with self.assertRaises(ValueError):
+ # Tries with a `device`
+ self.model_4bit.to(torch.device("cuda:0"))
+
+ with self.assertRaises(ValueError):
+            # Tries to cast with `float()`
+ self.model_4bit.float()
+
+ with self.assertRaises(ValueError):
+            # Tries to cast with `half()`
+ self.model_4bit.half()
+
+ # Test if we did not break anything
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+ self.model_fp16 = self.model_fp16.to(torch.float32)
+ _ = self.model_fp16.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+ # Check this does not throw an error
+ _ = self.model_fp16.to("cpu")
+
+ # Check this does not throw an error
+ _ = self.model_fp16.half()
+
+ # Check this does not throw an error
+ _ = self.model_fp16.float()
+
+ def test_fp32_4bit_conversion(self):
+ r"""
+ Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+ """
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-small", load_in_4bit=True, device_map="auto")
+ self.assertTrue(model.decoder.block[0].layer[2].DenseReluDense.wo.weight.dtype == torch.float32)
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu
+@slow
+class Bnb4BitT5Test(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.model_name = "t5-small"
+ cls.dense_act_model_name = "google/flan-t5-small" # flan-t5 uses dense-act instead of dense-relu-dense
+ cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+ cls.input_text = "Translate in German: Hello, my dog is cute"
+
+ def tearDown(self):
+ r"""
+ TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+ avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+ """
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_inference_without_keep_in_fp32(self):
+ r"""
+        Test that inference works correctly when `_keep_in_fp32_modules` is disabled.
+ `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
+ both cases.
+ """
+ from transformers import T5ForConditionalGeneration
+
+ modules = T5ForConditionalGeneration._keep_in_fp32_modules
+ T5ForConditionalGeneration._keep_in_fp32_modules = None
+
+ # test with `t5-small`
+ model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+ _ = model.generate(**encoded_input)
+
+ # test with `flan-t5-small`
+ model = T5ForConditionalGeneration.from_pretrained(
+ self.dense_act_model_name, load_in_4bit=True, device_map="auto"
+ )
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+ _ = model.generate(**encoded_input)
+ T5ForConditionalGeneration._keep_in_fp32_modules = modules
+
+ def test_inference_with_keep_in_fp32(self):
+ r"""
+ Test whether it is possible to mix both `4bit` and `fp32` weights when using `keep_in_fp32_modules` correctly.
+ `flan-t5-small` uses `T5DenseGatedActDense` whereas `t5-small` uses `T5DenseReluDense`. We need to test
+ both cases.
+ """
+ import bitsandbytes as bnb
+
+ from transformers import T5ForConditionalGeneration
+
+ # test with `t5-small`
+ model = T5ForConditionalGeneration.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+
+ # there was a bug with decoders - this test checks that it is fixed
+ self.assertTrue(isinstance(model.decoder.block[0].layer[0].SelfAttention.q, bnb.nn.Linear4bit))
+
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+ _ = model.generate(**encoded_input)
+
+ # test with `flan-t5-small`
+ model = T5ForConditionalGeneration.from_pretrained(
+ self.dense_act_model_name, load_in_4bit=True, device_map="auto"
+ )
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
+ _ = model.generate(**encoded_input)
+
+
+class Classes4BitModelTest(Base4bitTest):
+ def setUp(self):
+ super().setUp()
+ # model_name
+ self.model_name = "bigscience/bloom-560m"
+ self.seq_to_seq_name = "t5-small"
+
+ # Different types of model
+
+ self.base_model = AutoModel.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+ # Sequence classification model
+ self.sequence_model = AutoModelForSequenceClassification.from_pretrained(
+ self.model_name, load_in_4bit=True, device_map="auto"
+ )
+ # CausalLM model
+ self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+ # Seq2seq model
+ self.seq_to_seq_model = AutoModelForSeq2SeqLM.from_pretrained(
+ self.seq_to_seq_name, load_in_4bit=True, device_map="auto"
+ )
+
+ def tearDown(self):
+ r"""
+ TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+ avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+ """
+ del self.base_model
+ del self.sequence_model
+ del self.model_4bit
+ del self.seq_to_seq_model
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_correct_head_class(self):
+ r"""
+ A simple test to check if the last modules for some classes (AutoModelForCausalLM or SequenceClassification)
+ are kept in their native class.
+ """
+ from bitsandbytes.nn import Params4bit
+
+ self.assertTrue(self.base_model.h[-1].mlp.dense_4h_to_h.weight.__class__ == Params4bit)
+
+ # Other heads should be nn.Parameter
+ self.assertTrue(self.model_4bit.lm_head.weight.__class__ == torch.nn.Parameter)
+ self.assertTrue(self.sequence_model.score.weight.__class__ == torch.nn.Parameter)
+ self.assertTrue(self.seq_to_seq_model.lm_head.weight.__class__ == torch.nn.Parameter)
+
+
+class Pipeline4BitTest(Base4bitTest):
+ def setUp(self):
+ super().setUp()
+
+ def tearDown(self):
+ r"""
+ TearDown function needs to be called at the end of each test to free the GPU memory and cache, also to
+ avoid unexpected behaviors. Please see: https://discuss.pytorch.org/t/how-can-we-release-gpu-memory-cache/14530/27
+ """
+ del self.pipe
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ def test_pipeline(self):
+ r"""
+        The aim of this test is to verify that mixed 4-bit models are compatible with `pipeline` from transformers. Since
+        we used the pipeline for inference speed benchmarking we want to make sure that this feature does not break anything
+        on pipeline.
+ """
+ # self._clear_cuda_cache()
+ self.pipe = pipeline(
+ "text-generation",
+ model=self.model_name,
+ model_kwargs={"device_map": "auto", "load_in_4bit": True, "torch_dtype": torch.float16},
+ max_new_tokens=self.MAX_NEW_TOKENS,
+ )
+
+ # Real second forward pass
+ pipeline_output = self.pipe(self.input_text)
+ self.assertIn(pipeline_output[0]["generated_text"], self.EXPECTED_OUTPUTS)
+
+
+@require_torch_multi_gpu
+class Bnb4bitTestMultiGpu(Base4bitTest):
+ def setUp(self):
+ super().setUp()
+
+ def test_multi_gpu_loading(self):
+ r"""
+ This tests that the model has been loaded and can be used correctly on a multi-GPU setup.
+        Let's just try to load a model on 2 GPUs and see if it works. The model we test has ~2GB in total, so 3GB should suffice
+ """
+
+ model_parallel = AutoModelForCausalLM.from_pretrained(
+ self.model_name, load_in_4bit=True, device_map="balanced"
+ )
+
+ # Check correct device map
+ self.assertEqual(set(model_parallel.hf_device_map.values()), {0, 1})
+
+ # Check that inference pass works on the model
+ encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+ # Second real batch
+ output_parallel = model_parallel.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+ self.assertIn(self.tokenizer.decode(output_parallel[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+
+class Bnb4BitTestTraining(Base4bitTest):
+ def setUp(self):
+ self.model_name = "facebook/opt-350m"
+ super().setUp()
+
+ def test_training(self):
+ if version.parse(importlib_metadata.version("bitsandbytes")) < version.parse("0.37.0"):
+ return
+
+ # Step 1: freeze all parameters
+ model = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True, device_map="auto")
+
+ for param in model.parameters():
+ param.requires_grad = False # freeze the model - train adapters later
+ if param.ndim == 1:
+ # cast the small parameters (e.g. layernorm) to fp32 for stability
+ param.data = param.data.to(torch.float32)
+
+ # Step 2: add adapters
+ for _, module in model.named_modules():
+ if "OPTAttention" in repr(type(module)):
+ module.q_proj = LoRALayer(module.q_proj, rank=16)
+ module.k_proj = LoRALayer(module.k_proj, rank=16)
+ module.v_proj = LoRALayer(module.v_proj, rank=16)
+
+ # Step 3: dummy batch
+ batch = self.tokenizer("Test batch ", return_tensors="pt").to(0)
+
+ # Step 4: Check if the gradient is not None
+ with torch.cuda.amp.autocast():
+ out = model.forward(**batch)
+ out.logits.norm().backward()
+
+ for module in model.modules():
+ if isinstance(module, LoRALayer):
+ self.assertTrue(module.adapter[1].weight.grad is not None)
+ self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0)
+ elif isinstance(module, nn.Embedding):
+ self.assertTrue(module.weight.grad is None)
diff --git a/tests/mixed_int8/test_mixed_int8.py b/tests/bitsandbytes/test_mixed_int8.py
similarity index 97%
rename from tests/mixed_int8/test_mixed_int8.py
rename to tests/bitsandbytes/test_mixed_int8.py
index 7053be30a101d8..b31aaa386ae043 100644
--- a/tests/mixed_int8/test_mixed_int8.py
+++ b/tests/bitsandbytes/test_mixed_int8.py
@@ -131,6 +131,21 @@ def test_memory_footprint(self):
self.assertAlmostEqual(mem_fp16 / mem_8bit, self.EXPECTED_RELATIVE_DIFFERENCE)
self.assertTrue(self.model_8bit.transformer.h[0].mlp.dense_4h_to_h.weight.__class__ == Int8Params)
+ def test_linear_are_8bit(self):
+ r"""
+        A simple test to check if the model conversion has been done correctly by checking the dtype of
+        the linear layers of the converted model
+ """
+ from transformers import T5PreTrainedModel
+
+ self.model_fp16.get_memory_footprint()
+ self.model_8bit.get_memory_footprint()
+
+ for name, module in self.model_8bit.named_modules():
+ if isinstance(module, torch.nn.Linear):
+ if name not in ["lm_head"] + T5PreTrainedModel._keep_in_fp32_modules:
+ self.assertTrue(module.weight.dtype == torch.int8)
+
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
@@ -147,6 +162,7 @@ def test_generate_quality_config(self):
Test that loading the model with the config is equivalent
"""
bnb_config = BitsAndBytesConfig()
+ bnb_config.load_in_8bit = True
model_8bit_from_config = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=bnb_config, device_map="auto"
@@ -329,6 +345,7 @@ def test_inference_without_keep_in_fp32(self):
"""
from transformers import T5ForConditionalGeneration
+ modules = T5ForConditionalGeneration._keep_in_fp32_modules
T5ForConditionalGeneration._keep_in_fp32_modules = None
# test with `t5-small`
@@ -342,6 +359,7 @@ def test_inference_without_keep_in_fp32(self):
)
encoded_input = self.tokenizer(self.input_text, return_tensors="pt").to(0)
_ = model.generate(**encoded_input)
+ T5ForConditionalGeneration._keep_in_fp32_modules = modules
def test_inference_with_keep_in_fp32(self):
r"""