From 4c62a7699c2422a4ac5f20aeb95c6f8630735810 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 14:34:51 +0800 Subject: [PATCH 01/30] add peft --- gptqmodel/integration/__init__.py | 1 + gptqmodel/integration/peft/import_utils.py | 131 +++ .../integration/peft/tuners/adalora/model.py | 365 +++++++ .../integration/peft/tuners/lora/gptq.py | 122 +++ .../integration/peft/tuners/lora/model.py | 939 ++++++++++++++++++ .../integration/peft/tuners/utils/__init__.py | 57 ++ .../integration/peft/tuners/utils/other.py | 759 ++++++++++++++ 7 files changed, 2374 insertions(+) create mode 100644 gptqmodel/integration/__init__.py create mode 100644 gptqmodel/integration/peft/import_utils.py create mode 100644 gptqmodel/integration/peft/tuners/adalora/model.py create mode 100644 gptqmodel/integration/peft/tuners/lora/gptq.py create mode 100644 gptqmodel/integration/peft/tuners/lora/model.py create mode 100644 gptqmodel/integration/peft/tuners/utils/__init__.py create mode 100644 gptqmodel/integration/peft/tuners/utils/other.py diff --git a/gptqmodel/integration/__init__.py b/gptqmodel/integration/__init__.py new file mode 100644 index 000000000..88638ca39 --- /dev/null +++ b/gptqmodel/integration/__init__.py @@ -0,0 +1 @@ +from .optimum import monkey_patch_gptqmodel_into_transformers diff --git a/gptqmodel/integration/peft/import_utils.py b/gptqmodel/integration/peft/import_utils.py new file mode 100644 index 000000000..2314c3778 --- /dev/null +++ b/gptqmodel/integration/peft/import_utils.py @@ -0,0 +1,131 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import importlib.metadata as importlib_metadata +from functools import lru_cache + +import packaging.version + + +@lru_cache +def is_bnb_available() -> bool: + return importlib.util.find_spec("bitsandbytes") is not None + + +@lru_cache +def is_bnb_4bit_available() -> bool: + if not is_bnb_available(): + return False + + import bitsandbytes as bnb + + return hasattr(bnb.nn, "Linear4bit") + + +@lru_cache +def is_auto_gptq_available(): + if importlib.util.find_spec("auto_gptq") is not None: + AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0") + version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq")) + if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq: + return True + else: + raise ImportError( + f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, " + f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported" + ) + + +@lru_cache +def is_gptqmodel_available(): + if importlib.util.find_spec("gptqmodel") is not None: + GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.3.0") + version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel")) + if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel: + return True + else: + raise ImportError( + f"Found an incompatible version of gptqmodel. 
Found version {version_gptqmodel}, " + f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported" + ) + + +@lru_cache +def is_optimum_available() -> bool: + return importlib.util.find_spec("optimum") is not None + + +@lru_cache +def is_torch_tpu_available(check_device=True): + "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm + + _ = xm.xla_device() + return True + except RuntimeError: + return False + return True + return False + + +@lru_cache +def is_aqlm_available(): + return importlib.util.find_spec("aqlm") is not None + + +@lru_cache +def is_auto_awq_available(): + return importlib.util.find_spec("awq") is not None + + +@lru_cache +def is_eetq_available(): + if importlib.util.find_spec("eetq") is None: + return False + + is_available = True + try: + from eetq import EetqLinear # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # eetq is currently broken with newer transformers versions because it tries to import shard_checkpoint + # see https://github.com/NetEase-FuXi/EETQ/issues/34 + # TODO: Remove once eetq releasees a fix and this release is used in CI + is_available = False + return is_available + + +@lru_cache +def is_hqq_available(): + return importlib.util.find_spec("hqq") is not None + + +@lru_cache +def is_torchao_available(): + if importlib.util.find_spec("torchao") is None: + return False + + TORCHAO_MINIMUM_VERSION = packaging.version.parse("0.4.0") + torchao_version = packaging.version.parse(importlib_metadata.version("torchao")) + + if torchao_version < TORCHAO_MINIMUM_VERSION: + raise ImportError( + f"Found an incompatible version of torchao. Found version {torchao_version}, " + f"but only versions above {TORCHAO_MINIMUM_VERSION} are supported" + ) + return True \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/adalora/model.py b/gptqmodel/integration/peft/tuners/adalora/model.py new file mode 100644 index 000000000..0654a87b0 --- /dev/null +++ b/gptqmodel/integration/peft/tuners/adalora/model.py @@ -0,0 +1,365 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
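[Editor's aside, not part of the diff] The availability helpers in `import_utils.py` above gate which quantized-linear backend the tuner code picks. A minimal sketch of that selection logic, mirroring the pattern used later in this patch in `_create_new_module` and `dispatch_gptq` (helper names follow the imports used by the patched files; treat it as illustrative only):

```python
# Illustrative sketch: consume the availability checks above to choose a backend.
from peft.import_utils import is_gptqmodel_available
from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear


def _select_quant_linear(gptq_quantization_config, device_map=None):
    # Prefer gptqmodel (>= 1.3.0 per the version check above); otherwise
    # fall back to auto-gptq. Returns None if neither backend applies.
    if is_gptqmodel_available():
        return get_gptqmodel_quant_linear(gptq_quantization_config, device_map=device_map)
    return get_auto_gptq_quant_linear(gptq_quantization_config)
```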
+ +import warnings + +import torch +from transformers.pytorch_utils import Conv1D + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.lora import LoraConfig, LoraModel +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import ( + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + _freeze_adapter, + _get_submodules, + get_auto_gptq_quant_linear, + get_gptqmodel_quant_linear, + get_quantization_config, +) +from peft.import_utils import is_gptqmodel_available +from peft.utils.integrations import gather_params_ctx + +from .gptq import SVDQuantLinear +from .layer import AdaLoraLayer, RankAllocator, SVDLinear + + +class AdaLoraModel(LoraModel): + """ + Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper: + https://openreview.net/forum?id=lq62uWRJjiY + + Args: + model ([`transformers.PreTrainedModel`]): The model to be adapted. + config ([`AdaLoraConfig`]): The configuration of the AdaLora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + + Returns: + `torch.nn.Module`: The AdaLora model. + + Example:: + + >>> from transformers import AutoModelForSeq2SeqLM >>> from peft import LoraConfig, AdaLoraModel, AdaLoraConfig + >>> config = AdaLoraConfig( + peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", init_r=12, lora_alpha=32, target_modules=["q", "v"], + lora_dropout=0.01, + ) + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") >>> model = AdaLoraModel(model, config, "default") + + **Attributes**: + - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model. + """ + + # Note: don't redefine prefix here, it should be inherited from LoraModel + + def __init__(self, model, config, adapter_name): + super().__init__(model, config, adapter_name) + + traininable_mode_counter = 0 + for config in self.peft_config.values(): + if not config.inference_mode: + traininable_mode_counter += 1 + + if traininable_mode_counter > 1: + raise ValueError( + "AdaLoraModel supports only 1 trainable adapter. " + "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train." + ) + + if self.peft_config[adapter_name].inference_mode: + _freeze_adapter(self.model, adapter_name) + else: + self.trainable_adapter_name = adapter_name + self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name) + + def _check_new_adapter_config(self, config: LoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + super()._check_new_adapter_config(config) + + traininable_mode_counter = 0 + for config_ in self.peft_config.values(): + if not config_.inference_mode: + traininable_mode_counter += 1 + + if traininable_mode_counter > 1: + raise ValueError( + f"{self.__class__.__name__} supports only 1 trainable adapter. " + "When using multiple adapters, set inference_mode to True for all adapters except the one " + "you want to train." 
+ ) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + kwargs = { + "r": lora_config.init_r, + "lora_alpha": lora_config.lora_alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available(): + raise ImportError( + "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. " + "You can install it with `pip install bitsandbytes`." + ) + + quantization_config = get_quantization_config(self.model, method="gptq") + if quantization_config is not None: + kwargs["gptq_quantization_config"] = quantization_config + + # If it is not an AdaLoraLayer, create a new module, else update it with new adapters + if not isinstance(target, AdaLoraLayer): + new_module = self._create_new_module(lora_config, adapter_name, target, self.model.hf_device_map, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + else: + target.update_layer( + adapter_name, + lora_config.init_r, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, device_map, **kwargs): + # avoid eager bnb import + if is_bnb_available(): + import bitsandbytes as bnb + + from .bnb import SVDLinear8bitLt + if is_bnb_4bit_available(): + from .bnb import SVDLinear4bit + + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + + if is_gptqmodel_available(): + QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map) + else: + QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt): + kwargs.update( + { + "has_fp16_weights": target_base_layer.state.has_fp16_weights, + "memory_efficient_backward": target_base_layer.state.memory_efficient_backward, + "threshold": target_base_layer.state.threshold, + "index": target_base_layer.index, + } + ) + new_module = SVDLinear8bitLt(target, adapter_name, **kwargs) + elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit): + fourbit_kwargs = kwargs.copy() + fourbit_kwargs.update( + { + "compute_dtype": target_base_layer.compute_dtype, + "compress_statistics": target_base_layer.weight.compress_statistics, + "quant_type": target_base_layer.weight.quant_type, + } + ) + new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs) + elif QuantLinear is not None and isinstance(target, QuantLinear): + new_module = SVDQuantLinear(target, adapter_name, **kwargs) + else: + if isinstance(target_base_layer, torch.nn.Linear): + if kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. " + "Setting fan_in_fan_out to False." 
+ ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False + elif isinstance(target_base_layer, Conv1D): + if not kwargs["fan_in_fan_out"]: + warnings.warn( + "fan_in_fan_out is set to False but the target module is `Conv1D`. " + "Setting fan_in_fan_out to True." + ) + kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True + else: + raise ValueError( + f"Target module {target} is not supported. " + f"Currently, only `torch.nn.Linear` and `Conv1D` are supported." + ) + new_module = SVDLinear(target, adapter_name, **kwargs) + + return new_module + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[ + model_config["model_type"] + ] + return peft_config + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "model": # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.model, name) + + def forward(self, *args, **kwargs): + outputs = self.model.forward(*args, **kwargs) + + if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor): + # Calculate the orthogonal regularization + orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight + + if orth_reg_weight <= 0: + raise ValueError("orth_reg_weight should be greater than 0. ") + + regu_loss = 0 + num_param = 0 + for n, p in self.model.named_parameters(): + if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n: + if p.shape == torch.Size([0]): + with gather_params_ctx(p, fwd_module=self): + para_cov = p @ p.T if "lora_A" in n else p.T @ p + else: + para_cov = p @ p.T if "lora_A" in n else p.T @ p + I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741 + I.requires_grad = False + num_param += 1 + regu_loss += torch.norm(para_cov - I, p="fro") + if num_param > 0: + regu_loss = regu_loss / num_param + else: + regu_loss = 0 + outputs.loss += orth_reg_weight * regu_loss + return outputs + + def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name): + lora_config = self.peft_config[adapter_name] + for name, rank_idx in rank_pattern.items(): + if isinstance(rank_idx, list): + rank = sum(rank_idx) + elif isinstance(rank_idx, torch.Tensor): + rank_idx = rank_idx.view(-1) + rank = rank_idx.sum().item() + else: + raise ValueError("Unexpected type of rank_idx") + key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) + _, target, _ = _get_submodules(self.model, key) + lora_E_weights = target.lora_E[adapter_name][rank_idx] + lora_A_weights = target.lora_A[adapter_name][rank_idx] + lora_B_weights = target.lora_B[adapter_name][:, rank_idx] + ranknum = target.ranknum[adapter_name] + target.update_layer( + adapter_name, + rank, + lora_config.lora_alpha, + lora_config.lora_dropout, + lora_config.init_lora_weights, + ) + with torch.no_grad(): + if rank > 0: + target.lora_E[adapter_name].copy_(lora_E_weights) + target.lora_A[adapter_name].copy_(lora_A_weights) + target.lora_B[adapter_name].copy_(lora_B_weights) + # The scaling is exactly as the previous + target.ranknum[adapter_name].copy_(ranknum) + + 
def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name): + for name, rank_idx in rank_pattern.items(): + rank = sum(rank_idx) + prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1]) + for layer in ["lora_E", "lora_A", "lora_B"]: + key = f"base_model.model.{prefix}.{layer}.{adapter_name}" + if layer != "lora_B": + state_dict[key] = ( + state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key] + ) + else: + state_dict[key] = ( + state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key] + ) + return state_dict + + def update_and_allocate(self, global_step): + """ + This method updates Adalora budget and mask. + + This should be called in every training step after `loss.backward()` and before `zero_grad()`. + + `tinit`, `tfinal` and `deltaT` are handled with in the method. + + Args: + global_step (`int`): The current training step, it is used to calculate adalora budget. + + Example: + + ```python + >>> loss = model(**input).loss + >>> loss.backward() + >>> optimizer.step() + >>> model.base_model.update_and_allocate(i_step) + >>> optimizer.zero_grad() + ``` + """ + lora_config = self.peft_config[self.trainable_adapter_name] + # Update the importance score and allocate the budget + if global_step < lora_config.total_step - lora_config.tfinal: + _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step) + if rank_pattern: + lora_config.rank_pattern = rank_pattern + # Finalize the budget allocation + elif global_step == lora_config.total_step - lora_config.tfinal: + _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True) + # for some reason, this freezes the trainable parameters and nothing gets updates + # self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name) + lora_config.rank_pattern = rank_pattern + self.rankallocator.reset_ipt() + # Currently using inefficient way to mask the unimportant weights using the rank pattern + # due to problem mentioned above + elif global_step > lora_config.total_step - lora_config.tfinal: + self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern) + # Pass the function and do forward propagation + else: + return None + + def add_weighted_adapter(self, *args, **kwargs): + """This method is not supported for AdaLoRA, use LoRA instead.""" + raise TypeError(f"{self.__class__.__name__} does not support add_weighted_adapter method.") \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/lora/gptq.py b/gptqmodel/integration/peft/tuners/lora/gptq.py new file mode 100644 index 000000000..d33bbb2e5 --- /dev/null +++ b/gptqmodel/integration/peft/tuners/lora/gptq.py @@ -0,0 +1,122 @@ +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
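[Editor's aside, not part of the diff] For reference, the call ordering documented by `update_and_allocate` above, placed in a minimal training loop; `peft_model`, `optimizer`, and `dataloader` are placeholders rather than names defined in this patch:

```python
# Minimal AdaLoRA training-loop sketch based on the update_and_allocate docstring.
for step, batch in enumerate(dataloader):
    loss = peft_model(**batch).loss
    loss.backward()
    optimizer.step()
    # Budget update and masking must happen after backward() and before zero_grad().
    peft_model.base_model.update_and_allocate(step)
    optimizer.zero_grad()
```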
+ +from typing import Any, Optional + +import torch + +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear +from peft.import_utils import is_gptqmodel_available + + +class QuantLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name: str, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + use_dora: bool = False, + lora_bias: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + if use_dora: + raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False") + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + self._active_adapter = adapter_name + self.update_layer( + adapter_name, + r, + lora_alpha=lora_alpha, + lora_dropout=lora_dropout, + init_lora_weights=init_lora_weights, + use_rslora=use_rslora, + use_dora=use_dora, + lora_bias=lora_bias, + ) + + def forward(self, x: torch.Tensor): + # note: logic differs from default Linear because merging is not supported + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result += output + return result + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102 + # def reset_lora_parameters(self, adapter_name): + # if adapter_name in self.lora_A.keys(): + # torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight) + # torch.nn.init.zeros_(self.lora_B[adapter_name].weight) + + +def dispatch_gptq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + cfg = kwargs.get("gptq_quantization_config", None) + + if is_gptqmodel_available(): + device_map = kwargs.get("device_map", None) + quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map) + else: + quant_linear = get_auto_gptq_quant_linear(cfg) + + if quant_linear is not None and isinstance(target_base_layer, quant_linear): + new_module = QuantLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight + + return new_module \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/lora/model.py b/gptqmodel/integration/peft/tuners/lora/model.py new file mode 100644 index 000000000..847f276ec --- /dev/null +++ b/gptqmodel/integration/peft/tuners/lora/model.py @@ -0,0 +1,939 @@ +# Copyright 2023-present the HuggingFace Inc. team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import math +import operator +import warnings +from contextlib import contextmanager +from dataclasses import asdict, replace +from enum import Enum +from functools import partial, reduce +from typing import Literal, Optional + +import torch +from torch import nn +from tqdm import tqdm + +from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.tuners.tuners_utils import ( + BaseTuner, + BaseTunerLayer, + check_target_module_exists, + onload_layer, + replicate_layers, +) +from peft.utils import ( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + ModulesToSaveWrapper, + _freeze_adapter, + _get_submodules, + get_peft_model_state_dict, + get_quantization_config, +) +from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties +from peft.utils.other import get_pattern_key + +from .aqlm import dispatch_aqlm +from .awq import dispatch_awq +from .config import LoraConfig +from .eetq import dispatch_eetq +from .gptq import dispatch_gptq +from .hqq import dispatch_hqq +from .layer import Conv2d, LoraLayer, dispatch_default +from .torchao import dispatch_torchao +from .tp_layer import dispatch_megatron + + +def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): + # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference + kwargs["adapter_names"] = adapter_names + return args, kwargs + + +class LoraModel(BaseTuner): + """ + Creates Low Rank Adapter (LoRA) model from a pretrained transformers model. + + The method is described in detail in https://arxiv.org/abs/2106.09685. + + Args: + model ([`torch.nn.Module`]): The model to be adapted. + config ([`LoraConfig`]): The configuration of the Lora model. + adapter_name (`str`): The name of the adapter, defaults to `"default"`. + low_cpu_mem_usage (`bool`, `optional`, defaults to `False`): + Create empty adapter weights on meta device. Useful to speed up the loading process. + + Returns: + `torch.nn.Module`: The Lora model. + + Example: + + ```py + >>> from transformers import AutoModelForSeq2SeqLM + >>> from peft import LoraModel, LoraConfig + + >>> config = LoraConfig( + ... task_type="SEQ_2_SEQ_LM", + ... r=8, + ... lora_alpha=32, + ... target_modules=["q", "v"], + ... lora_dropout=0.01, + ... ) + + >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base") + >>> lora_model = LoraModel(model, config, "default") + ``` + + ```py + >>> import torch + >>> import transformers + >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training + + >>> rank = ... + >>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"] + >>> config = LoraConfig( + ... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM" + ... 
) + >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True) + + >>> tokenizer = transformers.AutoTokenizer.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... bos_token="[BOS]", + ... eos_token="[EOS]", + ... unk_token="[UNK]", + ... pad_token="[PAD]", + ... mask_token="[MASK]", + ... ) + >>> model = transformers.GPTJForCausalLM.from_pretrained( + ... "kakaobrain/kogpt", + ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b + ... pad_token_id=tokenizer.eos_token_id, + ... use_cache=False, + ... device_map={"": rank}, + ... torch_dtype=torch.float16, + ... quantization_config=quantization_config, + ... ) + >>> model = prepare_model_for_kbit_training(model) + >>> lora_model = get_peft_model(model, config) + ``` + + **Attributes**: + - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted. + - **peft_config** ([`LoraConfig`]): The configuration of the Lora model. + """ + + prefix: str = "lora_" + + def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None: + super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage) + + def _check_new_adapter_config(self, config: LoraConfig) -> None: + """ + A helper method to check the config when a new adapter is being added. + + Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters. + + """ + # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check + # does not fully correspond to the error message. + if (len(self.peft_config) > 1) and (config.bias != "none"): + raise ValueError( + f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, " + "set bias to 'none' for all adapters." + ) + + @staticmethod + def _check_target_module_exists(lora_config, key): + return check_target_module_exists(lora_config, key) + + def _prepare_model(self, peft_config: LoraConfig, model: nn.Module): + r""" + A private method to modify the model structure before adapter is applied. + + Args: + peft_config (`PeftConfig`): + The prepared adapter config. + model (`nn.Module`): + The model that is going to be adapted. 
+ """ + if peft_config.layer_replication: + replicate_layers(model, peft_config.layer_replication) + + def _create_and_replace( + self, + lora_config, + adapter_name, + target, + target_name, + parent, + current_key, + ): + if current_key is None: + raise ValueError("Current Key shouldn't be `None`") + + # Regexp matching - Find key which matches current target_name in patterns provided + r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key) + alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key) + r = lora_config.rank_pattern.get(r_key, lora_config.r) + alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha) + + kwargs = { + "r": r, + "lora_alpha": alpha, + "lora_dropout": lora_config.lora_dropout, + "fan_in_fan_out": lora_config.fan_in_fan_out, + "init_lora_weights": lora_config.init_lora_weights, + "use_rslora": lora_config.use_rslora, + "use_dora": lora_config.use_dora, + "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload, + "lora_bias": lora_config.lora_bias, + "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False), + "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False), + } + # for torchao merging, we need the get_apply_tensor_subclass from the quantization config + try: + kwargs["get_apply_tensor_subclass"] = operator.attrgetter( + "hf_quantizer.quantization_config.get_apply_tensor_subclass" + )(self.model) + except AttributeError: + pass + + quant_methods = ["gptq", "aqlm", "awq"] + for quant_method in quant_methods: + quantization_config = get_quantization_config(self.model, method=quant_method) + if quantization_config is not None: + kwargs[f"{quant_method}_quantization_config"] = quantization_config + + # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it + from peft.tuners.adalora import AdaLoraLayer + + if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer): + target.update_layer( + adapter_name, + r, + lora_alpha=alpha, + lora_dropout=lora_config.lora_dropout, + init_lora_weights=lora_config.init_lora_weights, + use_rslora=lora_config.use_rslora, + use_dora=lora_config.use_dora, + lora_bias=lora_config.lora_bias, + ) + else: + new_module = self._create_new_module(lora_config, adapter_name, target, device_map=self.model.hf_device_map, **kwargs) + if adapter_name not in self.active_adapters: + # adding an additional adapter: it is not automatically trainable + new_module.requires_grad_(False) + self._replace_module(parent, target_name, new_module, target) + + def _replace_module(self, parent, child_name, new_module, child): + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.base_layer + + if not hasattr(new_module, "base_layer"): + if hasattr(new_module, "W_q"): # HQQ + new_module.W_q = child.W_q + else: + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + meta = torch.device("meta") + # dispatch to correct device + for name, module in new_module.named_modules(): + if (self.prefix in name) or ("ranknum" in name): + weight = ( + child.qweight + if hasattr(child, "qweight") + else 
child.W_q + if hasattr(child, "W_q") + else child.weight + if hasattr(child, "weight") + else next(child.parameters()) + ) + if not any(p.device == meta for p in module.parameters()): + module.to(weight.device) + + def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None: + for n, p in model.named_parameters(): + if self.prefix not in n: + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = self.peft_config[active_adapter].bias + if bias == "none": + continue + + if bias == "all": + for n, p in model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + for m in model.modules(): + if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise NotImplementedError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(lora_config, adapter_name, target, **kwargs): + # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters, + # because the first match is always used. Therefore, the default layers should be checked last. + dispatchers = [] + + if lora_config._custom_modules: + # Experimental custom LoRA module support. Allows users to pass a custom mapping for unsupported layer + # types by impelementing their own LoRA layers. + def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs): + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + for key, custom_cls in lora_config._custom_modules.items(): + if isinstance(target_base_layer, key): + new_module = custom_cls(target, adapter_name, **kwargs) + break + + return new_module + + dispatchers.append(dynamic_dispatch_func) + + # avoid eager bnb import + if is_bnb_available(): + from .bnb import dispatch_bnb_8bit + + dispatchers.append(dispatch_bnb_8bit) + + if is_bnb_4bit_available(): + from .bnb import dispatch_bnb_4bit + + dispatchers.append(dispatch_bnb_4bit) + + dispatchers.extend( + [ + dispatch_eetq, + dispatch_aqlm, + dispatch_awq, + dispatch_gptq, + dispatch_hqq, + dispatch_torchao, + dispatch_megatron, + dispatch_default, + ] + ) + + new_module = None + for dispatcher in dispatchers: + new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs) + if new_module is not None: # first match wins + break + + if new_module is None: + # no module could be matched + raise ValueError( + f"Target module {target} is not supported. Currently, only the following modules are supported: " + "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, " + "`transformers.pytorch_utils.Conv1D`." 
+ ) + + return new_module + + def __getattr__(self, name: str): + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + if name == "model": # see #1892: prevent infinite recursion if class is not initialized + raise + return getattr(self.model, name) + + def get_peft_config_as_dict(self, inference: bool = False): + config_dict = {} + for key, value in self.peft_config.items(): + config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()} + if inference: + config["inference_mode"] = True + config_dict[key] = config + return config + + def _set_adapter_layers(self, enabled: bool = True) -> None: + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self) -> None: + """Enable all adapters. + + Call this if you have previously disabled all adapters and want to re-enable them. + """ + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self) -> None: + """Disable all adapters. + + When disabling all adapters, the model output corresponds to the output of the base model. + """ + for active_adapter in self.active_adapters: + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: str | list[str]) -> None: + """Set the active adapter(s). + + Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... param.requires_grad = False + ``` + + Args: + adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated. + """ + for module in self.model.modules(): + if isinstance(module, LoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @contextmanager + def _enable_peft_forward_hooks(self, *args, **kwargs): + # If adapter_names is passed as an argument, we inject it into the forward arguments. + adapter_names = kwargs.pop("adapter_names", None) + if adapter_names is None: + # nothing to do + yield + return + + if self.training: + raise ValueError("Cannot pass `adapter_names` when the model is in training mode.") + + # Check that users only passed actually existing adapters. + # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want + # to check that there is at least one layer with the given name, or else something like typos can easily slip. 
+ expected_adapters = set() + for layer in self.modules(): + if isinstance(layer, LoraLayer): + expected_adapters |= layer.lora_A.keys() + expected_adapters |= layer.lora_embedding_A.keys() + unique_adapters = {name for name in adapter_names if name != "__base__"} + unexpected_adapters = unique_adapters - expected_adapters + if unexpected_adapters: + raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}") + + hook_handles = [] + for module in self.modules(): + if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper): + pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names) + handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True) + hook_handles.append(handle) + + yield + + for handle in hook_handles: + handle.remove() + + def _check_merge_allowed(self): + """Verify that the configuration supports merging. + + Currently gptq quantization and replicated layers do not support merging. + """ + super()._check_merge_allowed() + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge LORA layers when the model is gptq quantized") + if self.peft_config.get("layer_replication"): + raise ValueError("Cannot merge LORA layers when base model layers are replicated") + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + self._check_merge_allowed() + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + desc = "Unloading " + ("and merging " if merge else "") + "model" + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + with onload_layer(target): + if hasattr(target, "base_layer"): + if merge: + target.merge(safe_merge=safe_merge, adapter_names=adapter_names) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + new_module = target.modules_to_save[target.active_adapter] + if hasattr(new_module, "base_layer"): + # check if the module is itself a tuner layer + if merge: + new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + new_module = new_module.get_base_layer() + setattr(parent, target_name, new_module) + + return self.model + + def _check_add_weighted_adapter( + self, adapters: list[str], combination_type: str, svd_rank: int | None + ) -> tuple[str, int, str]: + """ + Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying + model. + """ + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these + # modules cannot be merged. 
First, find the ModulesToSaveWrapper instances in the model, then check if they + # have modules for the adapters to be merged. + modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)] + problematic_wrappers = [ + wrapper + for wrapper in modules_to_save_wrappers + if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1 + ] + if problematic_wrappers: + raise ValueError( + "Cannot add weighted adapters if they target the same module with modules_to_save, but found " + f"{len(problematic_wrappers)} such instance(s)." + ) + + # if there is only one adapter, we can only use linear merging + combination_type = "linear" if len(adapters) == 1 else combination_type + + adapters_ranks = [self.peft_config[adapter].r for adapter in adapters] + if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"): + # all adapters ranks should be same, new rank is just this value + if len(set(adapters_ranks)) != 1: + raise ValueError( + "All adapters must have the same r value when using combination_type linear, ties, dare_ties or " + "dare_linear." + ) + new_rank = adapters_ranks[0] + elif combination_type == "cat": + # adapters ranks may be different, new rank is sum of all ranks + # be careful, because output adapter rank may be really big if mixing a lot of adapters + new_rank = sum(adapters_ranks) + elif combination_type.endswith("svd"): + # new rank is the max of all ranks of the adapters if not provided + new_rank = svd_rank or max(adapters_ranks) + else: + raise ValueError(f"Invalid combination_type: {combination_type}") + + target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters] + if not target_module_types: + raise ValueError(f"Found no adapter matching the names in {adapters}") + if len(set(target_module_types)) > 1: + raise ValueError( + "all adapter configs should follow the same target modules type. " + "Combining adapters with `target_modules` type being a mix of list/set and string is not supported." + ) + + if target_module_types[0] is str: + new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters) + elif target_module_types[0] is set: + new_target_modules = reduce( + operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters) + ) + else: + raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules") + + return combination_type, new_rank, new_target_modules + + def add_weighted_adapter( + self, + adapters: list[str], + weights: list[float], + adapter_name: str, + combination_type: str = "svd", + svd_rank: int | None = None, + svd_clamp: int | None = None, + svd_full_matrices: bool = True, + svd_driver: str | None = None, + density: float | None = None, + majority_sign_method: Literal["total", "frequency"] = "total", + ) -> None: + """ + This method adds a new adapter by merging the given adapters with the given weights. + + When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to + the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM + errors. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. 
+ combination_type (`str`): + The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`, + `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat` + combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the + mixed adapter may be too big and result in OOM errors). + svd_rank (`int`, *optional*): + Rank of output adapter for svd. If None provided, will use max rank of merging adapters. + svd_clamp (`float`, *optional*): + A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform + clamping. Defaults to None. + svd_full_matrices (`bool`, *optional*): + Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned + tensors U and Vh. Defaults to True. + svd_driver (`str`, *optional*): + Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be + one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd` + documentation. Defaults to None. + density (`float`, *optional*): + Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used + with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`, + `magnintude_prune`, `magnitude_prune_svd`] + majority_sign_method (`str`): + The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values. + Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`] + """ + + if adapter_name in list(self.peft_config.keys()): + return + + combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter( + adapters=adapters, + combination_type=combination_type, + svd_rank=svd_rank, + ) + + self.peft_config[adapter_name] = replace( + self.peft_config[adapters[0]], + r=new_rank, + lora_alpha=new_rank, + target_modules=new_target_modules, + ) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? + _freeze_adapter(self.model, adapter_name) + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + if adapter_name in target.lora_A: + target_lora_A = target.lora_A[adapter_name].weight + target_lora_B = target.lora_B[adapter_name].weight + elif adapter_name in target.lora_embedding_A: + target_lora_A = target.lora_embedding_A[adapter_name] + target_lora_B = target.lora_embedding_B[adapter_name] + else: + continue + + target_lora_A.data = target_lora_A.data * 0.0 + target_lora_B.data = target_lora_B.data * 0.0 + if combination_type == "cat": + loras_A, loras_B = [], [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter]) + loras_B.append(current_adapter_lora_B.data) + + if len(loras_A) == 0: + raise ValueError("No matching LoRAs found. 
Please raise an issue on GitHub.") + loras_A = torch.cat(loras_A, dim=0) + loras_B = torch.cat(loras_B, dim=1) + target_lora_A.data[: loras_A.shape[0], :] = loras_A + target_lora_B.data[:, : loras_B.shape[1]] = loras_B + elif combination_type in [ + "svd", + "ties_svd", + "dare_linear_svd", + "dare_ties_svd", + "magnitude_prune_svd", + ]: + target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter( + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + svd_clamp, + full_matrices=svd_full_matrices, + driver=svd_driver, + ) + elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]: + target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter( + combination_type, adapters, weights, target, density, majority_sign_method + ) + + def _svd_generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + new_rank, + target, + target_lora_A, + target_lora_B, + density, + majority_sign_method, + clamp=None, + full_matrices=True, + driver=None, + ): + valid_adapters = [] + valid_weights = [] + is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters) + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A or adapter in target.lora_embedding_A: + valid_adapters.append(adapter) + valid_weights.append(weight * target.scaling[adapter]) + + # if no valid adapter, nothing to do + if len(valid_adapters) == 0: + raise ValueError("No matching LoRAs found. Please raise an issue on Github.") + delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters] + valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device) + if combination_type == "svd": + delta_weight = task_arithmetic(delta_weight, valid_weights) + elif combination_type == "ties_svd": + delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear_svd": + delta_weight = dare_linear(delta_weight, valid_weights, density) + elif combination_type == "dare_ties_svd": + delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune_svd": + delta_weight = magnitude_prune(delta_weight, valid_weights, density) + else: + raise ValueError(f"Invalid value passed to combination type: {combination_type}") + + conv2d = isinstance(target, Conv2d) + if conv2d: + conv2d_1x1 = target.weight.size()[2:4] == (1, 1) + if not conv2d_1x1: + delta_weight = delta_weight.flatten(start_dim=1) + else: + delta_weight = delta_weight.squeeze() + if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding: + delta_weight = delta_weight.T + + # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131 + U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver) + U = U[:, :new_rank] + S = S[:new_rank] + U = U @ torch.diag(S) + Vh = Vh[:new_rank, :] + if clamp is not None: + dist = torch.cat([U.flatten(), Vh.flatten()]) + hi_val = torch.quantile(dist, clamp) + low_val = -hi_val + U = U.clamp(low_val, hi_val) + Vh = Vh.clamp(low_val, hi_val) + if conv2d: + U = U.reshape(target_lora_B.data.shape) + Vh = Vh.reshape(target_lora_A.data.shape) + return Vh, U + + def _generalized_task_arithmetic_weighted_adapter( + self, + combination_type, + adapters, + weights, + target, + density, + 
majority_sign_method, + ): + # account weights for LoRA A and B layers. + valid_weights = [] + lora_A_deltas = [] + lora_B_deltas = [] + for adapter, weight in zip(adapters, weights): + if adapter in target.lora_A: + current_adapter_lora_A = target.lora_A[adapter].weight + current_adapter_lora_B = target.lora_B[adapter].weight + elif adapter in target.lora_embedding_A: + current_adapter_lora_A = target.lora_embedding_A[adapter] + current_adapter_lora_B = target.lora_embedding_B[adapter] + else: + continue + valid_weights.append(math.sqrt(weight * target.scaling[adapter])) + lora_A_deltas.append(current_adapter_lora_A.data) + lora_B_deltas.append(current_adapter_lora_B.data) + valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device) + lora_deltas = [lora_A_deltas, lora_B_deltas] + dtype = lora_A_deltas[0].dtype + for i, task_tensors in enumerate(lora_deltas): + if combination_type == "linear": + lora_deltas[i] = task_arithmetic(task_tensors, valid_weights) + elif combination_type == "ties": + lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "dare_linear": + lora_deltas[i] = dare_linear(task_tensors, valid_weights, density) + elif combination_type == "dare_ties": + lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method) + elif combination_type == "magnitude_prune": + lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density) + else: + raise ValueError("Invalid combination type") + lora_deltas = [delta.to(dtype) for delta in lora_deltas] + return lora_deltas + + def delete_adapter(self, adapter_name: str) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, LoraLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> torch.nn.Module: + r""" + This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model + as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. 
+ Example: + + ```py + >>> from transformers import AutoModelForCausalLM + >>> from peft import PeftModel + + >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b") + >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample" + >>> model = PeftModel.from_pretrained(base_model, peft_model_id) + >>> merged_model = model.merge_and_unload() + ``` + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> torch.nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None): + """ + This function can calculate the updates of the [PiSSA | OLoRA] by comparing the parameters of the [PiSSA | + OLoRA] adapter in `output_state_dict` with the initial values of [PiSSA | OLoRA] in `adapter_name`, thus + converting [PiSSA | OLoRA] to LoRA. + """ + for name, param in self.model.named_parameters(): + if ( + param.data.dtype != torch.float32 + and param.data.dtype != torch.float16 + and param.data.dtype != torch.bfloat16 + ) and adapter_name.startswith("pissa"): + warnings.warn( + r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); " + "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. " + "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. " + ) + mutated_init_state_dict = get_peft_model_state_dict( + self, + state_dict=kwargs.get("state_dict", None), + adapter_name=adapter_name, + ) + tensors_lora = {} + for name in output_state_dict.keys(): + ## W = W^{res} + A_0 \times B_0, + ## W + \Delta W = W^{res} + A \times B, + ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'. + if "lora_A" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0 + ) + elif "lora_B" in name: + tensors_lora[name] = torch.cat( + [output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1 + ) + + return tensors_lora \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/utils/__init__.py b/gptqmodel/integration/peft/tuners/utils/__init__.py new file mode 100644 index 000000000..b4f105907 --- /dev/null +++ b/gptqmodel/integration/peft/tuners/utils/__init__.py @@ -0,0 +1,57 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all + +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
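[Editor's aside, not part of the diff] A hedged usage sketch for the `add_weighted_adapter` API defined above; the adapter names, weights, and density are invented for illustration:

```python
# Combine two already-loaded LoRA adapters into a new one via TIES merging.
# For "ties", both source adapters must share the same r value.
lora_model.add_weighted_adapter(
    adapters=["adapter_a", "adapter_b"],
    weights=[0.7, 0.3],
    adapter_name="merged",
    combination_type="ties",
    density=0.5,
    majority_sign_method="total",
)
lora_model.set_adapter("merged")  # activate the merged adapter for inference
```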
+ +# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType +from .integrations import map_cache_to_layer_device_map +from .loftq_utils import replace_lora_weights_loftq +from .peft_types import PeftType, TaskType +from .other import ( + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, + CONFIG_NAME, + WEIGHTS_NAME, + SAFETENSORS_WEIGHTS_NAME, + INCLUDE_LINEAR_LAYERS_SHORTHAND, + _set_trainable, + bloom_model_postprocess_past_key_value, + prepare_model_for_kbit_training, + shift_tokens_right, + transpose, + _get_batch_size, + _get_submodules, + _set_adapter, + _freeze_adapter, + ModulesToSaveWrapper, + _prepare_prompt_learning_config, + _is_valid_match, + infer_device, + get_auto_gptq_quant_linear, + get_gptqmodel_quant_linear, + get_quantization_config, + id_tensor_storage, + cast_mixed_precision_params, +) +from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/utils/other.py b/gptqmodel/integration/peft/tuners/utils/other.py new file mode 100644 index 000000000..2116a1970 --- /dev/null +++ b/gptqmodel/integration/peft/tuners/utils/other.py @@ -0,0 +1,759 @@ + +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
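+# Adapted copy of PEFT's utility helpers. Alongside the auto-gptq path (get_auto_gptq_quant_linear), this
+# module exposes get_gptqmodel_quant_linear, which picks a trainable quantized-linear class through
+# gptqmodel's hf_select_quant_linear.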
+from __future__ import annotations + +import copy +import inspect +import os +import re +import warnings +from contextlib import nullcontext +from typing import Any, Optional + +import accelerate +import torch +from accelerate.hooks import add_hook_to_module, remove_hook_from_module +from accelerate.utils import is_npu_available, is_xpu_available +from huggingface_hub import file_exists +from huggingface_hub.errors import EntryNotFoundError, HFValidationError +from packaging import version +from safetensors.torch import storage_ptr, storage_size + +from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available +from .constants import ( + CONFIG_NAME, + EMBEDDING_LAYER_NAMES, + INCLUDE_LINEAR_LAYERS_SHORTHAND, + SAFETENSORS_WEIGHTS_NAME, + TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, + TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, + TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, + WEIGHTS_NAME, + bloom_model_postprocess_past_key_value, + starcoder_model_postprocess_past_key_value, +) + + +mlu_available = False +if version.parse(accelerate.__version__) >= version.parse("0.29.0"): + from accelerate.utils import is_mlu_available + + mlu_available = is_mlu_available() + + +__all__ = [ + "CONFIG_NAME", + "EMBEDDING_LAYER_NAMES", + "SAFETENSORS_WEIGHTS_NAME", + "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING", + "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING", + "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING", + "WEIGHTS_NAME", + "INCLUDE_LINEAR_LAYERS_SHORTHAND", + "bloom_model_postprocess_past_key_value", + "starcoder_model_postprocess_past_key_value", +] + + +# Get current device name based on available devices +def infer_device() -> str: + if torch.cuda.is_available(): + return "cuda" + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + elif mlu_available: + return "mlu" + elif is_xpu_available(): + return "xpu" + elif is_npu_available(): + return "npu" + return "cpu" + + +def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None): + r""" + Note this method only works for `transformers` models. + + This method wraps the entire protocol for preparing a model before running a training. This includes: + 1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm + head to fp32 + + Args: + model (`transformers.PreTrainedModel`): + The loaded model from `transformers` + use_gradient_checkpointing (`bool`, *optional*, defaults to `True`): + If True, use gradient checkpointing to save memory at the expense of slower backward pass. 
+ gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): + Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of + `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method. + Note this is only available in the latest transformers versions (> 4.34.1). + """ + loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) + is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" + is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" + is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" + is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao" + is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + + for name, param in model.named_parameters(): + # freeze base model's layers + param.requires_grad = False + + if ( + not is_gptq_quantized + and not is_aqlm_quantized + and not is_eetq_quantized + and not is_hqq_quantized + and not is_torchao_quantized + ): + # cast all non INT8 parameters to fp32 + for param in model.parameters(): + if ( + (param.dtype == torch.float16) or (param.dtype == torch.bfloat16) + ) and param.__class__.__name__ != "Params4bit": + param.data = param.data.to(torch.float32) + + if ( + loaded_in_kbit + or is_gptq_quantized + or is_aqlm_quantized + or is_eetq_quantized + or is_hqq_quantized + or is_torchao_quantized + ) and use_gradient_checkpointing: + # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack + if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]: + # For backward compatibility + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + else: + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs + _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list( + inspect.signature(model.gradient_checkpointing_enable).parameters + ) + + if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0: + warnings.warn( + "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored." + " if you want to use that feature, please upgrade to the latest version of transformers.", + FutureWarning, + ) + + gc_enable_kwargs = ( + {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs} + ) + + # enable gradient checkpointing for memory efficiency + model.gradient_checkpointing_enable(**gc_enable_kwargs) + return model + + +# copied from transformers.models.bart.modeling_bart +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): + """ + Shift input ids one token to the right. + + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids + pad_token_id (`int`): The id of the `padding` token. + decoder_start_token_id (`int`): The id of the `start` token. 
+ """ + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() + shifted_input_ids[:, 0] = decoder_start_token_id + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class ModulesToSaveWrapper(torch.nn.Module): + def __init__(self, module_to_save, adapter_name): + super().__init__() + self.original_module = module_to_save + self.modules_to_save = torch.nn.ModuleDict({}) + self._active_adapter = adapter_name + self._disable_adapters = False + self.update(adapter_name) + self.check_module() + + def check_module(self): + """Perform some sanity checks on the module to ensure that it works""" + # Try to anticipate some modules that users could try to target that would not work. + # Note: It's not possible to check hasattr(module, "forward"), since that returns True for ModuleDict and + # ModuleList, even though their forward methods cannot be called + forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList) + if isinstance(self.original_module, forbidden_classes): + cls_name = self.original_module.__class__ + raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}") + + # local import to avoid circular import + from peft.tuners.tuners_utils import BaseTunerLayer + + if isinstance(self.original_module, BaseTunerLayer): + # e.g. applying modules_to_save to a lora layer makes no sense + cls_name = self.original_module.__class__ + raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}") + + @property + def disable_adapters(self) -> bool: + # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method + return self._disable_adapters + + @property + def active_adapter(self) -> str: + # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method + return self._active_adapter + + def __getattr__(self, name: str): + # Note: This whole method may seem overly complex at first but PyTorch messes with __getattr__ in a way that + # requires very careful handling to avoid infinite recursion. + try: + return super().__getattr__(name) + except AttributeError: + pass + + if "_modules" not in self.__dict__: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + # Could not find the attribute the PyTorch way. So let's check if it's an attribute on the + # original_module/modules_to_save. 
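+        # Resolution order: the frozen original module when adapters are disabled, otherwise the copy
+        # stored for the currently active adapter.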
+ modules = self.__dict__["_modules"] + if self.disable_adapters: + module = modules["original_module"] + elif self.active_adapter in modules["modules_to_save"]: + module = modules["modules_to_save"][self.active_adapter] + else: + # For some reason, there is no module corresponding to the active adapter; this should normally not be + # reached and exists as a failsafe (otherwise, a KeyError would be raised) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + return getattr(module, name) + + def update(self, adapter_name): + context_manager = nullcontext() + for _, param in self.original_module.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + import deepspeed + + context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0) + break + with context_manager: + self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)})) + + if hasattr(self.modules_to_save[adapter_name], "_hf_hook"): + old_hook = self.modules_to_save[adapter_name]._hf_hook + new_hook = self._create_new_hook(old_hook) + remove_hook_from_module(self.modules_to_save[adapter_name]) + add_hook_to_module(self.modules_to_save[adapter_name], new_hook) + + self.original_module.requires_grad_(False) + if adapter_name == self.active_adapter: + self.modules_to_save[adapter_name].requires_grad_(True) + + def _create_new_hook(self, old_hook): + r""" + Creates a new hook based on the old hook. Use it only if you know what you are doing ! + """ + old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__) + old_hook_attr = old_hook.__dict__ + filtered_old_hook_attr = {} + old_hook_init_signature = inspect.signature(old_hook_cls.__init__) + for k in old_hook_attr.keys(): + if k in old_hook_init_signature.parameters: + filtered_old_hook_attr[k] = old_hook_attr[k] + new_hook = old_hook_cls(**filtered_old_hook_attr) + return new_hook + + def _check_forward_args(self, x, *args, **kwargs): + """Check if the arguments are compatible with the configs and state of the model""" + adapter_names = kwargs.get("adapter_names", None) + if adapter_names is None: + return + + if len(x) != len(adapter_names): + msg = ( + "Length of `adapter_names` should be the same as the number of inputs, but got " + f"{len(adapter_names)} and {len(x)} respectively." + ) + raise ValueError(msg) + + def _mixed_batch_forward( + self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any + ) -> torch.Tensor: + # This is a special method that handles the case when users pass the argument `adapter_names`. This is an + # extra argument that allows mixing different adapters in the same batch at inference time. 
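+        # The batch is split into one sub-batch per adapter name: "__base__" rows run through the original
+        # module, every other name runs through its saved copy, and the per-row outputs are stitched back
+        # together in the original batch order before being stacked.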
+ + SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d) + + module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES]) + + if not isinstance(self.original_module, SUPPORTED_MODULES): + raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.") + + unique_adapters = set(adapter_names) + sub_batch_indices_list = [] + + for adapter in unique_adapters: + sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter]) + + results = [0 for _ in range(len(input))] + + for i, active_adapter in enumerate(unique_adapters): + sub_batch = input[sub_batch_indices_list[i]] + + if active_adapter == "__base__": + output = self.original_module(sub_batch, *args, **kwargs) + else: + output = self.modules_to_save[active_adapter](sub_batch, *args, **kwargs) + + for index, j in enumerate(sub_batch_indices_list[i]): + results[j] = output[index] + + return torch.stack(results) + + def forward(self, x: torch.Tensor, *args, **kwargs): + self._check_forward_args(x, *args, **kwargs) + adapter_names = kwargs.pop("adapter_names", None) + + if self.disable_adapters or (self.active_adapter not in self.modules_to_save): + return self.original_module(x, *args, **kwargs) + if adapter_names is None: + return self.modules_to_save[self.active_adapter](x, *args, **kwargs) + return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs) + + def enable_adapters(self, enabled: bool): + """Toggle the enabling and disabling of adapters + + Takes care of setting the requires_grad flag for the adapter weights. + + Args: + enabled (bool): True to enable adapters, False to disable adapters + """ + if self._disable_adapters is not enabled: + # already in the desired state, do nothing + return + + if enabled: + self.original_module.requires_grad_(False) + self.modules_to_save[self.active_adapter].requires_grad_(True) + self._disable_adapters = False + else: + self.original_module.requires_grad_(True) + self.modules_to_save.requires_grad_(False) + self._disable_adapters = True + + def set_adapter(self, adapter_name: str): + """Set the active adapter + + Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is + not desired, use the following code. + + ```py + >>> for name, param in model_peft.named_parameters(): + ... if ...: # some check on name (ex. if 'lora' in name) + ... 
param.requires_grad = False + ``` + + Args: + adapter_name (str): The name of the adapter to set as active + """ + if adapter_name not in self.modules_to_save: + raise ValueError(f"Adapter {adapter_name} not found in {self.modules_to_save.keys()}") + + self.modules_to_save[self.active_adapter].requires_grad_(False) + self.modules_to_save[adapter_name].requires_grad_(True) + self._active_adapter = adapter_name + + +def _get_submodules(model, key): + parent = model.get_submodule(".".join(key.split(".")[:-1])) + target_name = key.split(".")[-1] + target = model.get_submodule(key) + return parent, target, target_name + + +def _freeze_adapter(model, adapter_name): + for n, p in model.named_parameters(): + if adapter_name in n: + p.requires_grad = False + + +def _set_trainable(model, adapter_name): + key_list = [key for key, _ in model.named_modules()] + for key in key_list: + target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save) + if target_module_found: + parent, target, target_name = _get_submodules(model, key) + if isinstance(target, ModulesToSaveWrapper): + target.update(adapter_name) + target.set_adapter(target.active_adapter) + else: + new_module = ModulesToSaveWrapper(target, adapter_name) + new_module.set_adapter(adapter_name) + setattr(parent, target_name, new_module) + + +def _set_adapter(model, adapter_name): + def check_adapter_name(adapter_name): + if isinstance(adapter_name, str): + return adapter_name + + # adapter_name is a list of str + if len(adapter_name) > 1: + raise ValueError("Only one adapter can be set at a time for modules_to_save") + elif len(adapter_name) == 0: + raise ValueError("Please specify at least one adapter to set") + adapter_name = adapter_name[0] + return adapter_name + + for module in model.modules(): + if isinstance(module, ModulesToSaveWrapper): + # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care + adapter_name = check_adapter_name(adapter_name) + + # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this + # module + if adapter_name in module.modules_to_save: + module.set_adapter(adapter_name) + else: + module.enable_adapters(False) + + +def _prepare_prompt_learning_config(peft_config, model_config): + if peft_config.num_layers is None: + if "num_hidden_layers" in model_config: + num_layers = model_config["num_hidden_layers"] + elif "num_layers" in model_config: + num_layers = model_config["num_layers"] + elif "n_layer" in model_config: + num_layers = model_config["n_layer"] + else: + raise ValueError("Please specify `num_layers` in `peft_config`") + peft_config.num_layers = num_layers + + if peft_config.token_dim is None: + if "hidden_size" in model_config: + token_dim = model_config["hidden_size"] + elif "n_embd" in model_config: + token_dim = model_config["n_embd"] + elif "d_model" in model_config: + token_dim = model_config["d_model"] + else: + raise ValueError("Please specify `token_dim` in `peft_config`") + peft_config.token_dim = token_dim + + if peft_config.num_attention_heads is None: + if "num_attention_heads" in model_config: + num_attention_heads = model_config["num_attention_heads"] + elif "n_head" in model_config: + num_attention_heads = model_config["n_head"] + elif "num_heads" in model_config: + num_attention_heads = model_config["num_heads"] + elif "encoder_attention_heads" in model_config: + num_attention_heads = model_config["encoder_attention_heads"] + else: + raise ValueError("Please specify 
`num_attention_heads` in `peft_config`") + peft_config.num_attention_heads = num_attention_heads + + # For grouped-query attention, see #1901. + if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config: + num_key_value_heads = model_config["num_key_value_heads"] + peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads + peft_config.num_attention_heads = num_key_value_heads + + if getattr(peft_config, "encoder_hidden_size", None) is None: + setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) + + return peft_config + + +def fsdp_auto_wrap_policy(model): + import functools + import os + + from accelerate import FullyShardedDataParallelPlugin + + if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"): + get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name + else: + from accelerate.utils.dataclasses import get_module_class_from_name + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + default_transformer_cls_names_to_wrap = ( + ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else "" + ) + transformer_cls_names_to_wrap = os.environ.get( + "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap + ).split(",") + transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding} + for layer_class in transformer_cls_names_to_wrap: + if len(layer_class) == 0: + continue + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=transformer_cls_to_wrap, + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy + + +def transpose(weight, fan_in_fan_out): + if not fan_in_fan_out: + return weight + + if isinstance(weight, torch.nn.Parameter): + return torch.nn.Parameter(weight.T) + return weight.T + + +def _is_valid_match(key: str, target_key: str): + """ + Helper function to match module names target_key and key. Makes sure that either the key is exactly the target_key + or the target_key is a submodule of key + """ + if key.endswith(target_key): + if len(key) > len(target_key): + return key.endswith("." + target_key) # must be a sub module + return True + return False + + +def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int: + """Get the batch size based on either input_ids or input_embeds + + Raises an ValueError if both are None. 
+ + """ + if (input_ids is None) and (inputs_embeds is None): + raise ValueError("You have to provide either input_ids or inputs_embeds") + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + return batch_size + + +def get_quantization_config(model: torch.nn.Module, method: str): + """ + Get the quantization config of the related quantization method + """ + if ( + hasattr(model, "config") + and hasattr(model.config, "quantization_config") + and (getattr(model, "quantization_method", None) == method) + ): + return model.config.quantization_config + return None + + +def get_auto_gptq_quant_linear(gptq_quantization_config): + """ + Get the right AutoGPTQQuantLinear class based on the quantization config file + """ + if gptq_quantization_config is None: + return None + + if is_auto_gptq_available(): + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear + else: + return None + + desc_act = gptq_quantization_config.desc_act + group_size = gptq_quantization_config.group_size + bits = gptq_quantization_config.bits + if hasattr(gptq_quantization_config, "use_exllama"): + use_exllama = gptq_quantization_config.use_exllama + else: + use_exllama = not gptq_quantization_config.disable_exllama + if hasattr(gptq_quantization_config, "exllama_config"): + exllama_version = gptq_quantization_config.exllama_config["version"] + else: + exllama_version = 1 + + QuantLinear = dynamically_import_QuantLinear( + use_triton=False, + desc_act=desc_act, + group_size=group_size, + bits=bits, + disable_exllama=not (use_exllama and exllama_version == 1), + disable_exllamav2=not (use_exllama and exllama_version == 2), + ) + + return QuantLinear + + +def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None): + """ + Get the right GPTQQuantLinear class based on the quantization config file + """ + if gptq_quantization_config is None: + return None + + if not is_gptqmodel_available(): + return None + + from gptqmodel.utils.importer import hf_select_quant_linear + + desc_act = gptq_quantization_config.desc_act + group_size = gptq_quantization_config.group_size + bits = gptq_quantization_config.bits + checkpoint_format = gptq_quantization_config.checkpoint_format if hasattr(gptq_quantization_config, "checkpoint_format") else "gptq" + sym = gptq_quantization_config.sym + meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None + + QuantLinear = hf_select_quant_linear(bits=bits, group_size=group_size, + desc_act=desc_act, sym=sym, device_map=device_map, + checkpoint_format=checkpoint_format, meta=meta, backend="auto_trainable") + + return QuantLinear + + +def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + + This method is the exact same copy of + https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added + it here manually to avoid import issue with old versions of transformers. + """ + if tensor.device.type == "xla" and is_torch_tpu_available(): + # NOTE: xla tensors dont have storage + # use some other unique id to distinguish. 
+ # this is a XLA tensor, it must be created using torch_xla's + # device. So the following import is safe: + import torch_xla + + unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor) + else: + unique_id = storage_ptr(tensor) + + return tensor.device, unique_id, storage_size(tensor) + + +def cast_mixed_precision_params(model, dtype): + """ + Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or + `torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full + precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for + non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using + automatic mixed-precision training. + + Args: + model (`torch.nn.Module`): + The model to cast the non-trainable parameters of. + dtype (`torch.dtype`): + The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or + `torch.bfloat16` as per the mixed-precision training you are performing. + """ + for p in model.parameters(): + if not p.requires_grad: + p.data = p.to(dtype) + else: + p.data = p.to(torch.float32) + + +def str_to_bool(value: str) -> int: + """ + Converts a string representation of truth to `True` (1) or `False` (0). + + True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`; + """ + # same as function as in accelerate.utils, which replaces the deprecated distutils.util.strtobool + value = value.lower() + if value in ("y", "yes", "t", "true", "on", "1"): + return 1 + elif value in ("n", "no", "f", "false", "off", "0"): + return 0 + else: + raise ValueError(f"invalid truth value {value}") + + +def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]: + """Check if a file exists on HF Hub, if check was not successful returns None instead of erroring. + + Respect offline mode if set. + + """ + exists: Optional[bool] = None + if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")): + # user set offline mode, cannot check + return exists + + try: + exists = file_exists(repo_id, filename, **kwargs) + except (HFValidationError, EntryNotFoundError): + # error, exists stays None + pass + except Exception as e: + warnings.warn( + f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup" + f" for the file {filename} in {repo_id}." 
+ ) + + return exists + + +def get_pattern_key(pattern_keys, key_to_match): + """Match a substring of key_to_match in pattern keys""" + return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match) From 9d58131c3cda4b1f6a2d2f9c368c5c71db782462 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 14:38:36 +0800 Subject: [PATCH 02/30] add optimum --- .../integration/optimum/gptq/quantizer.py | 923 ++++++++++++++++++ .../integration/optimum/utils/import_utils.py | 206 ++++ .../optimum/utils/testing_utils.py | 211 ++++ 3 files changed, 1340 insertions(+) create mode 100644 gptqmodel/integration/optimum/gptq/quantizer.py create mode 100644 gptqmodel/integration/optimum/utils/import_utils.py create mode 100644 gptqmodel/integration/optimum/utils/testing_utils.py diff --git a/gptqmodel/integration/optimum/gptq/quantizer.py b/gptqmodel/integration/optimum/gptq/quantizer.py new file mode 100644 index 000000000..6ee53da54 --- /dev/null +++ b/gptqmodel/integration/optimum/gptq/quantizer.py @@ -0,0 +1,923 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. team and GPTQ and AutoGPTQ authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import json +import os +from enum import Enum +from logging import getLogger +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from packaging import version +from torch import nn +from tqdm.auto import tqdm +from transformers import AutoTokenizer +from transformers.pytorch_utils import Conv1D +from transformers.utils.quantization_config import QuantizationMethod + +from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available +from ..utils.modeling_utils import recurse_getattr +from .constants import GPTQ_CONFIG +from .data import get_dataset, prepare_dataset +from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen +from ..version import __version__ as optimum_version + + +if is_accelerate_available(): + from accelerate import ( + cpu_offload_with_hook, + load_checkpoint_and_dispatch, + ) + from accelerate.hooks import remove_hook_from_module + +if is_auto_gptq_available(): + from auto_gptq import exllama_set_max_input_length + from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init + from auto_gptq.quantization import GPTQ + from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear + from auto_gptq import __version__ as autogptq_version + +if is_gptqmodel_available(): + from gptqmodel import exllama_set_max_input_length + from gptqmodel.quantization import GPTQ + from gptqmodel.utils.importer import hf_select_quant_linear + from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format + from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init + from gptqmodel.version import __version__ as gptqmodel_version + +logger = getLogger(__name__) + + +def has_device_more_than_cpu(): + return 
torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available()) + + +class ExllamaVersion(int, Enum): + ONE = 1 + TWO = 2 + + +class GPTQQuantizer(object): + r""" + A simple API for GPTQ Quantization + """ + + def __init__( + self, + bits: int, + dataset: Optional[Union[List[str], str]] = None, + group_size: int = 128, + damp_percent: float = 0.1, + desc_act: bool = False, + sym: bool = True, + true_sequential: bool = True, + checkpoint_format: str = "gptq", + meta: Optional[Dict[str, any]] = None, + backend: Optional[str] = None, + use_cuda_fp16: bool = False, + model_seqlen: Optional[int] = None, + block_name_to_quantize: Optional[str] = None, + module_name_preceding_first_block: Optional[List[str]] = None, + batch_size: int = 1, + pad_token_id: Optional[int] = None, + disable_exllama: bool = False, + exllama_config: Optional[Dict[str, Any]] = None, + max_input_length: Optional[int] = None, + cache_block_outputs: Optional[bool] = True, + modules_in_block_to_quantize: Optional[List[List[str]]] = None, + *args, + **kwargs, + ): + """ + Args: + bits (`int`): + The number of bits to quantize to, supported numbers are (2, 3, 4, 8). + dataset (`Union[List[str], str, Any]`, defaults to `None`): + The dataset used for quantization. You can provide your own dataset in a list of string or in a list of tokenized data + (e.g. [{ "input_ids": [ 1, 100, 15, ... ],"attention_mask": [ 1, 1, 1, ... ]},...]) + or just use the original datasets used in GPTQ paper ['wikitext2','c4','c4-new']. + group_size (int, defaults to 128): + The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization. + damp_percent (`float`, defaults to `0.1`): + The percent of the average Hessian diagonal to use for dampening, recommended value is 0.1. + desc_act (`bool`, defaults to `False`): + Whether to quantize columns in order of decreasing activation size. + Setting it to False can significantly speed up inference but the perplexity may become slightly worse. + Also known as act-order. + sym (`bool`, defaults to `True`): + Whether to use symetric quantization. + true_sequential (`bool`, defaults to `True`): + Whether to perform sequential quantization even within a single Transformer block. + Instead of quantizing the entire block at once, we perform layer-wise quantization. + As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers. + checkpoint_format (`str`, *optional*, defaults to `gptq`): + GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only. + meta (`Dict[str, any]`, *optional*): + Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta. + i.e. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"] + backend (`str`, *optional*): + Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only + valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py + use_cuda_fp16 (`bool`, defaults to `False`): + Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. + model_seqlen (`Optional[int]`, defaults to `None`): + The maximum sequence length that the model can take. + block_name_to_quantize (`Optional[str]`, defaults to `None`): + The transformers block name to quantize. 
If None, we will infer the block name using common patterns (e.g. model.layers) + module_name_preceding_first_block (`Optional[List[str]]`, defaults to `None`): + The layers that are preceding the first Transformer block. + batch_size (`int`, defaults to `1`): + The batch size of the dataset + pad_token_id (`Optional[int]`, defaults to `None`): + The pad token id. Needed to prepare the dataset when `batch_size` > 1. + disable_exllama (`bool`, defaults to `False`): + Whether to use exllama backend. Only works with `bits` = 4. + exllama_config (`Dict[str, Any]`, *optional*): + The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset. + max_input_length (`Optional[int]`, defaults to `None`): + The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length. + It is specific to the exllama backend with act-order. + cache_block_outputs (`bool`, defaults to `True`): + Whether to cache block outputs to reuse as inputs for the succeeding block. It allows optimization of non-standard models + (e.g. ChatGLM) but can require more time. + modules_in_block_to_quantize (`Optional[List[List[str]]]`, defaults to `None`): + List list of module names to quantize in the block specified. This argument is useful to exclude certain linear modules from being quantized. + The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially. + If not set, we will quantize all linear layers. Example: `inside_layer_modules=[["self_attention.query_key_value"], ["mlp.dense_h_to_4h"]]` + """ + + self.bits = bits + self.dataset = dataset + self.group_size = group_size + self.damp_percent = damp_percent + self.desc_act = desc_act + self.sym = sym + self.true_sequential = true_sequential + self.checkpoint_format = checkpoint_format.lower() + self.meta = meta + self.backend = backend.lower() if backend is not None else None + self.use_cuda_fp16 = use_cuda_fp16 + self.model_seqlen = model_seqlen + self.block_name_to_quantize = block_name_to_quantize + self.module_name_preceding_first_block = module_name_preceding_first_block + self.batch_size = batch_size + self.pad_token_id = pad_token_id + self.disable_exllama = disable_exllama + self.exllama_config = exllama_config + self.max_input_length = max_input_length + self.quant_method = QuantizationMethod.GPTQ + self.cache_block_outputs = cache_block_outputs + self.modules_in_block_to_quantize = modules_in_block_to_quantize + + self.serialization_keys = [ + "bits", + "dataset", + "group_size", + "damp_percent", + "desc_act", + "sym", + "true_sequential", + "quant_method", + "modules_in_block_to_quantize", + "checkpoint_format", + "meta", + ] + + if self.bits not in [2, 3, 4, 8]: + raise ValueError("only support quantize to [2,3,4,8] bits.") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + + if self.exllama_config is None: + self.exllama_config = {"version": ExllamaVersion.TWO} + else: + if "version" not in self.exllama_config: + raise ValueError("`exllama_config` needs to have a `version` key") + elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]: + version = self.exllama_config["version"] + raise ValueError( + f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized 
version {version}" + ) + self.exllama_version = self.exllama_config["version"] + + def select_quant_linear(self, device_map: Union[str, dict]): + if is_gptqmodel_available(): + self.quant_linear = hf_select_quant_linear( + bits=self.bits, + group_size=self.group_size, + desc_act=self.desc_act, + sym=self.sym, + checkpoint_format=self.checkpoint_format, + meta=self.meta, + device_map=device_map, + backend=self.backend, + ) + else: + self.quant_linear = hf_select_quant_linear( + use_triton=False, + desc_act=self.desc_act, + group_size=self.group_size, + bits=self.bits, + disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE, + disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO, + ) + + def to_dict(self): + """ + Returns the args in dict format. + """ + gptq_dict = {} + for key in self.serialization_keys: + gptq_dict[key] = getattr(self, key) + + if gptq_dict.get("meta") is None: + gptq_dict["meta"] = {} + + meta = gptq_dict["meta"] + # store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer + if meta.get("quantizer") is None: + meta["quantizer"] = [f"optimum:{optimum_version}"] + + if is_gptqmodel_available(): + meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}") + elif is_auto_gptq_available(): + meta["quantizer"].append(f"auto_gptq:{autogptq_version}") + + return gptq_dict + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any]): + """ + Instantiates a `GPTQQuantizer` using config_dict as kwargs + + Args: + config_dict (`Dict[str,Any]`): + quantization config + + Returns: + `GPTQQuantizer`: The quantizer object instantiated from those parameters. + """ + return cls(**config_dict) + + def convert_model(self, model: nn.Module, **kwargs): + """ + Convert the model to a GPTQ model by getting and replacing the layers. + + Args: + model (`nn.Module`): + Model to be converted + + """ + if self.block_name_to_quantize is None: + self.block_name_to_quantize = get_block_name_with_pattern(model) + block_name = self.block_name_to_quantize + layers_to_be_replaced = get_layers(model, prefix=block_name) + if self.modules_in_block_to_quantize is not None: + layers_to_keep = sum(self.modules_in_block_to_quantize, []) + for name in list(layers_to_be_replaced.keys()): + if not any(name.endswith(layer) for layer in layers_to_keep): + logger.info( + f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)" + ) + del layers_to_be_replaced[name] + + self.select_quant_linear(device_map=kwargs.get("device_map", None)) + + self._replace_by_quant_layers(model, layers_to_be_replaced) + + return model + + def get_no_split_module_classes(self, model): + """ + Get the modules that should not be split across multiple devices. + Args: + model (`nn.Module`): + The input model + """ + + block_class_name = recurse_getattr(model, self.block_name_to_quantize)[0].__class__.__name__ + no_split_module_classes = [block_class_name] + return no_split_module_classes + + def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""): + """ + Replaces linear layers in `module` by `QuantLinear` + + Args: + module (`nn.Module`): + Module to quantize + names (`List[str]`): + List of names of the module to quantize + name (`str`, defaults to `""`): + To keep track of the name of the current module + """ + if isinstance(module, self.quant_linear): + return + for attr in dir(module): + layer = getattr(module, attr) + name1 = name + "." 
+ attr if name != "" else attr + if name1 in names: + device = get_device(layer) + delattr(module, attr) + if isinstance(layer, nn.Linear): + in_features = layer.in_features + out_features = layer.out_features + elif isinstance(layer, nn.Conv2d): + in_features = layer.in_channels + out_features = layer.out_channels + elif isinstance(layer, Conv1D): + in_features = layer.weight.shape[0] + out_features = layer.weight.shape[1] + bias = layer.bias is not None + if is_gptqmodel_available(): + new_layer = self.quant_linear( + self.bits, + self.group_size, + self.desc_act, + self.sym, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + ) + else: + if not (self.desc_act) or self.group_size == -1: + new_layer = self.quant_linear( + self.bits, + self.group_size, + in_features, + out_features, + bias, + use_cuda_fp16=self.use_cuda_fp16, + weight_dtype=layer.weight.dtype, + ) + else: + new_layer = self.quant_linear( + self.bits, + self.group_size, + in_features, + out_features, + bias, + weight_dtype=layer.weight.dtype, + ) + new_layer.device = device + setattr(module, attr, new_layer.to(device)) + for name1, child in module.named_children(): + self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1) + + @torch.no_grad() + def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None): + """ + Quantizes the model using the dataset + + Args: + model (`nn.Module`): + The model to quantize + tokenizer (Optional[`Any`], defaults to `None`): + The tokenizer to use in order to prepare the dataset. You can pass either: + - A custom tokenizer object. + - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved + using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`. + Returns: + `nn.Module`: The quantized model + """ + + if not is_auto_gptq_available() and not is_gptqmodel_available(): + raise RuntimeError( + "gptqmodel or auto-gptq is required in order to perform gptq quantzation: `pip install gptqmodel` or `pip install auto-gptq`. Please notice that auto-gptq will be deprecated in the future." + ) + elif is_gptqmodel_available() and is_auto_gptq_available(): + logger.warning( + "Detected gptqmodel and auto-gptq, will use gptqmodel. The auto_gptq will be deprecated in the future." + ) + + gptq_supports_cpu = ( + is_auto_gptq_available() + and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + ) or is_gptqmodel_available() + + if not gptq_supports_cpu and not torch.cuda.is_available(): + raise RuntimeError( + "No cuda gpu or cpu support using Intel/IPEX found. A gpu or cpu with Intel/IPEX is required for quantization." + ) + + if not self.sym and not is_gptqmodel_available(): + raise ValueError( + "Asymmetric sym=False quantization is not supported with auto-gptq. Please use gptqmodel: `pip install gptqmodel`" + ) + + if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available(): + raise ValueError( + "gptq_v2 format only supported with gptqmodel. 
Please install gptqmodel: `pip install gptqmodel`" + ) + + model.eval() + + # gptqmodel internal is gptq_v2 for asym support, gptq(v1) can only support sym=True + if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2": + self.checkpoint_format = "gptq_v2" + + # For Transformer model + has_config = False + has_device_map = False + if hasattr(model, "config"): + has_config = True + use_cache = model.config.use_cache + model.config.use_cache = False + + # If the model has a device_map, we don't move to model. We have already dispatched the hook that will do the work + if hasattr(model, "hf_device_map"): + devices = list(model.hf_device_map.values()) + has_device_map = True + if "disk" in devices: + raise ValueError("disk offload is not supported with GPTQ quantization") + if "cpu" in devices or torch.device("cpu") in devices: + if len(model.hf_device_map) > 1: + logger.info("Cpu offload is not recommended. There might be some issues with the memory") + hook = None + for name, device in model.hf_device_map.items(): + if device == "cpu": + module = recurse_getattr(model, name) + remove_hook_from_module(module, recurse=True) + module, hook = cpu_offload_with_hook(module, prev_module_hook=hook) + else: + has_device_map = False + + if hasattr(model, "dtype"): + self.use_cuda_fp16 = model.dtype == torch.float16 + + if self.model_seqlen is None: + # We allow a max value of 4028 to avoid passing data with huge length to the model during the calibration step + self.model_seqlen = min(4028, get_seqlen(model)) + + device = get_device(model) + + # Step 1: Prepare the data + if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str): + dataset = self.dataset + logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.") + else: + if isinstance(tokenizer, str): + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer) + except Exception: + raise ValueError( + f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained` + with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input. + For now, we only support quantization for text model. Support for vision, speech and multimodel will come later.""" + ) + if self.dataset is None: + raise ValueError("You need to pass `dataset` in order to quantize your model") + elif isinstance(self.dataset, str): + dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train") + elif isinstance(self.dataset, list): + dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset] + else: + raise ValueError( + f"You need to pass a list of string, a list of tokenized data or a string for `dataset`. Found: {type(self.dataset)}." + ) + + dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size) + + # Step 2: get the input of the 1st block + # To do that, we need to put the modules preceding the first block on the same device as the first bloc. + # Then we run the model and it will stop at the first bloc as we added a prehook that raise an Exception after storing the inputs. 
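+        # store_input_hook (defined below) raises ValueError on purpose once the first block's positional
+        # inputs and keyword arguments have been captured; the model(**data) calls are wrapped in
+        # try/except ValueError so the forward pass never runs past the first block.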
+ + layer_inputs = [] + layer_outputs = [] + layer_input_kwargs = [] + + if self.block_name_to_quantize is None: + self.block_name_to_quantize = get_block_name_with_pattern(model) + + if self.module_name_preceding_first_block is None: + self.module_name_preceding_first_block = get_preceding_modules(model, self.block_name_to_quantize) + + blocks = recurse_getattr(model, self.block_name_to_quantize) + + if not has_device_map: + # put modules from module_name_preceding_first_block on cuda or xpu or cpu + to_device = 0 if has_device_more_than_cpu() else "cpu" + for module_name in self.module_name_preceding_first_block: + module = recurse_getattr(model, module_name) + if module is None: + raise ValueError(f"Module {module_name} was not found in model") + module = module.to(to_device) + blocks[0] = blocks[0].to(to_device) + + def store_input_hook(_, input, *args): + kwargs = args[0] + if input is None: + if "hidden_states" in kwargs: + input = (kwargs["hidden_states"],) + else: + raise ValueError("No input value found in the foward pass") + layer_inputs.append(input) + other_kwargs = {} + for k, v in kwargs.items(): # make sure other arguments also be captured + if k not in ["hidden_states"]: + other_kwargs[k] = v + layer_input_kwargs.append(other_kwargs) + raise ValueError + + if self.cache_block_outputs: + handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) + for data in dataset: + for k, v in data.items(): + # put the data on gpu, we won't put them back to cpu + if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu(): + data[k] = v.to(0) + else: + data[k] = v.to(device) + try: + model(**data) + except ValueError: + pass + handle.remove() + + if not has_device_map: + blocks[0].to(device) + for module_name in self.module_name_preceding_first_block: + module = recurse_getattr(model, module_name) + if module is None: + raise ValueError(f"Module {module_name} was not found in model") + + torch.cuda.empty_cache() + if hasattr(torch, "xpu"): + torch.xpu.empty_cache() + + # Step 3: Quantize the blocks + quantizers = {} + for i, block in enumerate(tqdm(blocks, desc=f"Quantizing {self.block_name_to_quantize} blocks ")): + logger.info(f"Start quantizing block {self.block_name_to_quantize} {i + 1}/{len(blocks)}") + + if not self.cache_block_outputs: + handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True) + for data in dataset: + for k, v in data.items(): + # put the data on gpu, we won't put them back to cpu + if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu(): + data[k] = v.to(0) + else: + data[k] = v.to(device) + try: + model(**data) + except ValueError: + pass + handle.remove() + + # move block to cuda if needed + # in case we have offload modules, we need to put them on cuda because of GPTQ object + if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu(): + block = block.to(0) + layers = get_layers(block) + if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0: + if self.true_sequential: + layers_name_list = self.modules_in_block_to_quantize + else: + layers_name_list = [sum(self.modules_in_block_to_quantize, [])] + else: + if self.true_sequential: + # lazy sequential but works well + layers_name_list = [[key] for key in layers.keys()] + else: + layers_name_list = [list(layers.keys())] + logger.info(f"Module to quantize {layers_name_list}") + for subset_name_list in tqdm(layers_name_list, leave=False, 
desc="Quantizing layers inside the block"): + subset_layers = {name: layers[name] for name in subset_name_list} + gptq = {} + handles = [] + # add hook for each layer in subset_layers + for name in subset_layers: + gptq[name] = GPTQ(subset_layers[name]) + gptq[name].quantizer.configure(bits=self.bits, sym=self.sym, perchannel=True) + + def add_batch(name): + def tmp(_, input, output): + gptq[name].add_batch(input[0].data, output.data) + + return tmp + + # because it adding a hook will replace the old one. + handles.append(subset_layers[name].register_forward_hook(add_batch(name))) + # update Hessian for each layer in subset_layers thanks to the hook + for j in range(len(dataset)): + # the args are already on the gpu + # don't need to store the output + block(*layer_inputs[j], **layer_input_kwargs[j]) + # remove hook + for h in handles: + h.remove() + for name in subset_name_list: + logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...") + quant_outputs = gptq[name].fasterquant( + percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act + ) + scale, zero, g_idx = quant_outputs[0], quant_outputs[1], quant_outputs[2] + quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = ( + gptq[name].quantizer, + scale, + zero, + g_idx, + ) + gptq[name].free() + del subset_layers + # we get the new output from the partial quantized block + if self.cache_block_outputs: + for j in range(len(dataset)): + layer_output = block(*layer_inputs[j], **layer_input_kwargs[j]) + layer_outputs.append(layer_output) + + # put back to device + if not has_device_map: + blocks[i] = block.to(device) + del layers + del layer_inputs + layer_inputs, layer_outputs = layer_outputs, [] + else: + del layers + del layer_inputs + layer_inputs = [] + torch.cuda.empty_cache() + if hasattr(torch, "xpu"): + torch.xpu.empty_cache() + + if self.bits == 4: + # device not on gpu + if device.type != "cuda" or (has_device_map and any(d in devices for d in ["cpu", "disk", "hpu"])): + if not self.disable_exllama and not is_gptqmodel_available(): + logger.warning( + "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`" + ) + self.disable_exllama = True + # act order and exllama + elif self.desc_act and not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE: + logger.warning( + "Using Exllama backend with act_order will reorder the weights offline, thus you will not be able to save the model with the right weights." + "Setting `disable_exllama=True`. You should only use Exllama backend with act_order for inference. " + ) + self.disable_exllama = True + elif not self.disable_exllama and self.exllama_version == ExllamaVersion.TWO: + logger.warning( + "Using Exllamav2 backend will reorder the weights offline, thus you will not be able to save the model with the right weights." + "Setting `disable_exllama=True`. You should only use Exllamav2 backend for inference. " + ) + self.disable_exllama = True + # Step 4: Pack the model at the end (Replacing the layers) + self.pack_model(model=model, quantizers=quantizers) + + model.is_quantized = True + model.quantization_method = QuantizationMethod.GPTQ + if has_config: + model.config.use_cache = use_cache + model.config.quantization_config = self.to_dict() + + # Step 5: Any post-initialization that require device information, for example buffers initialization on device. 
+ model = self.post_init_model(model) + + torch.cuda.empty_cache() + if hasattr(torch, "xpu"): + torch.xpu.empty_cache() + return model + + def post_init_model(self, model): + """ + Post-initialization that require device information, for example buffers initialization on device. + + Args: + model (`nn.Module`): + The input model + """ + if self.bits == 4 and not self.disable_exllama: + if get_device(model).type != "cuda" or ( + hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk", "hpu"]) + ): + if not self.disable_exllama: + logger.warning( + "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`" + ) + self.disable_exllama = True + + class StoreAttr(object): + pass + + if is_gptqmodel_available(): + model, _ = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear, self.checkpoint_format, self.meta) + + model.quantize_config = StoreAttr() + model.quantize_config.desc_act = self.desc_act + model = gptq_post_init(model, use_act_order=self.desc_act) + if ( + self.desc_act + and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE) + and self.max_input_length is not None + ): + model = exllama_set_max_input_length(model, self.max_input_length) + return model + + def pack_model( + self, + model: nn.Module, + quantizers: Dict[str, Tuple], + ): + """ + Pack the model by replacing the layers by quantized layers + + Args: + model (`nn.Module`): + The model to pack + quantizers (`Dict[str,Tuple]`): + A mapping of the layer name and the data needed to pack the layer + """ + logger.info("Packing model...") + layers = get_layers(model) + layers = {n: layers[n] for n in quantizers} + + self.select_quant_linear(device_map=model.hf_device_map) + + self._replace_by_quant_layers(model, quantizers) + qlayers = get_layers(model, [self.quant_linear]) + for name in qlayers: + logger.info(name) + quantizers[name], scale, zero, g_idx = quantizers[name] + # so far can only pack layer on CPU + layer_device = qlayers[name].device + qlayers[name].to("cpu") + layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu") + qlayers[name].pack(layers[name], scale, zero, g_idx) + qlayers[name].to(layer_device) + + logger.info("Model packed.") + + def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True): + """ + Save model state dict and configs + + Args: + model (`nn.Module`): + Model to be saved. The model can be wrapped or unwraped. + save_dir (`str`): + Directory to which to save. Will be created if it doesn't exist. + max_shard_size (`str`, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + safe_serialization (`bool`, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
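# A hedged end-to-end sketch of the quantize-then-save flow implemented above; the
# model id, calibration dataset and quantization settings are illustrative only.
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.gptq import GPTQQuantizer          # or the GPTQQuantizer defined in this module

model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128, desc_act=False)
quantized = quantizer.quantize_model(model, tokenizer)

# shards larger than max_shard_size are split; safetensors is used by default
quantizer.save(quantized, "opt-125m-gptq", max_shard_size="2GB", safe_serialization=True)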
+ + """ + + # convert gptqmodel internal gptq_v2 format to v1 for max compatibility + model, converted = hf_convert_gptq_v2_to_v1_format(model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta) + if converted: + self.checkpoint_format = "gptq" + + os.makedirs(save_dir, exist_ok=True) + model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) + with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f: + json.dump(self.to_dict(), f, indent=2) + + +def load_quantized_model( + model: nn.Module, + save_folder: str, + quant_config_name: str = GPTQ_CONFIG, + state_dict_name: Optional[str] = None, + device_map: Optional[str] = None, + max_memory: Optional[Dict] = None, + no_split_module_classes: Optional[Dict] = None, + offload_folder: Optional[str] = None, + offload_buffers: Optional[str] = None, + offload_state_dict: bool = False, + disable_exllama: bool = False, + exllama_config: Optional[Dict[str, Any]] = None, + max_input_length: Optional[int] = None, +): + """ + Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map. + + Args: + model (`nn.Module`): + The model can be enpty or not. + save_folder (`str`): + Directory to which to load the weights. + quant_config_name (`str`, defaults to `GPTQ_CONFIG`): + Name of the quantization config file + state_dict_name (`Optional[str]`, defaults to `None`): + Name of the state dict file + device_map (`Optional[str]`, defaults to `None`): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. + max_memory (`Optional[Dict]`, defaults to `None`): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. + no_split_module_classes (`Optional[Dict]`, defaults to `None`): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + offload_folder (`Optional[str]`, defaults to `None`): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_buffers (`Optional[str]`, defaults to `None`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + offload_state_dict (`bool`, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map + picked contains `"disk"` values. + disable_exllama (`Optional[bool]`, defaults to `None`): + Whether to use exllama backend. Only works with `bits` = 4. + exllama_config (`Optional[Dict[str, Any]]`, defaults to `None`): + The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset. + max_input_length (`Optional[int]`, defaults to `None`): + The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length. + It is specific to the exllama backend with act-order. 
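# A hedged usage sketch for the loader documented above: build an empty skeleton with
# accelerate, then let `load_quantized_model` attach and dispatch the quantized weights.
# The checkpoint path is a placeholder.
import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

save_folder = "path/to/quantized-checkpoint"
config = AutoConfig.from_pretrained(save_folder)
with init_empty_weights():                       # no memory is allocated for weights here
    empty_model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
empty_model.tie_weights()

quantized_model = load_quantized_model(          # the function documented above
    empty_model,
    save_folder=save_folder,
    device_map="auto",                           # accelerate decides module placement
    exllama_config={"version": 2},
)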
+ + Returns: + `nn.Module`: The quantized model + """ + if not torch.cuda.is_available() and not is_gptqmodel_available(): + raise RuntimeError("No GPU found. A GPU is needed to run quantized model by auto_gptq.") + if not is_auto_gptq_available() and not is_gptqmodel_available(): + raise RuntimeError( + "gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) is required in order to load quantized weights. Please notice that auto-gptq will be deprecated in the future." + ) + if not is_accelerate_available(): + raise RuntimeError( + "You need to install accelerate in order to load and dispatch weights to" + "a quantized model. You can do it with `pip install accelerate`" + ) + if device_map is None: + device_map = {"": torch.cuda.current_device()} + logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.") + + if exllama_config is None: + exllama_config = {"version": ExllamaVersion.TWO} + else: + if "version" not in exllama_config: + raise ValueError("`exllama_config` needs to have a `version` key") + elif exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]: + version = exllama_config["version"] + raise ValueError( + f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {version}" + ) + + # this branch will check if model is from huggingface + try: + if hasattr(model, "config") and hasattr(model.config, "quantization_config"): + quantize_config_dict = model.config.quantization_config.to_dict() + else: + with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f: + quantize_config_dict = json.load(f) + except Exception as err: + raise ValueError( + f"Failed to load quantization config from {save_folder} (lookup for traceback): {err}\nTip: If the save directory is saved from a transformers.PreTrainedModel, make sure that `config.json` contains a 'quantization_config' key." + ) from err + quantizer = GPTQQuantizer.from_dict(quantize_config_dict) + quantizer.disable_exllama = disable_exllama + quantizer.exllama_config = exllama_config + quantizer.exllama_version = quantizer.exllama_config["version"] + quantizer.max_input_length = max_input_length + + model = quantizer.convert_model(model, device_map=device_map) + + if no_split_module_classes is None: + no_split_module_classes = quantizer.get_no_split_module_classes(model) + + model = load_checkpoint_and_dispatch( + model, + checkpoint=os.path.join(save_folder, state_dict_name) if state_dict_name is not None else save_folder, + device_map=device_map, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + offload_folder=offload_folder, + offload_buffers=offload_buffers, + offload_state_dict=offload_state_dict, + ) + + model = quantizer.post_init_model(model) + model.is_quantized = True + model.quantization_method = QuantizationMethod.GPTQ + model.eval() + return model \ No newline at end of file diff --git a/gptqmodel/integration/optimum/utils/import_utils.py b/gptqmodel/integration/optimum/utils/import_utils.py new file mode 100644 index 000000000..3137d453b --- /dev/null +++ b/gptqmodel/integration/optimum/utils/import_utils.py @@ -0,0 +1,206 @@ +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib.util +import itertools +import os +import shutil +import subprocess +import sys +import unittest +from collections.abc import MutableMapping +from typing import Any, Callable, Dict, Iterable, Optional, Tuple + +import torch + +from . import ( + is_accelerate_available, + is_auto_gptq_available, + is_datasets_available, + is_diffusers_available, + is_gptqmodel_available, + is_sentence_transformers_available, + is_timm_available, +) + + +# Used to test the hub +USER = "__DUMMY_OPTIMUM_USER__" + + +def flatten_dict(dictionary: Dict): + """ + Flatten a nested dictionaries as a flat dictionary. + """ + items = [] + for k, v in dictionary.items(): + new_key = k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def require_accelerate(test_case): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) + + +def require_gptq(test_case): + """ + Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed. + """ + return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires auto-gptq")( + test_case + ) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) + + +def require_ort_rocm(test_case): + """Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime.""" + import onnxruntime as ort + + providers = ort.get_available_providers() + + return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")( + test_case + ) + + +def require_hf_token(test_case): + """ + Decorator marking a test that requires huggingface hub token. + """ + # is HF_AUTH_TOKEN used instead of HF_TOKEN to avoid huggigface_hub picking it up ? + hf_token = os.environ.get("HF_AUTH_TOKEN", None) + if hf_token is None: + return unittest.skip("test requires hf token as `HF_AUTH_TOKEN` environment variable")(test_case) + else: + return test_case + + +def require_sigopt_token_and_project(test_case): + """ + Decorator marking a test that requires sigopt API token. 
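# A small illustration of how the skip decorators above are meant to be stacked on a
# test case; the test body is a placeholder.
import unittest

class GPTQDecoratorExample(unittest.TestCase):
    @require_gptq
    @require_torch_gpu
    def test_needs_gptq_and_cuda(self):
        # runs only when gptqmodel/auto-gptq is installed and a CUDA device is visible,
        # otherwise unittest reports the test as skipped
        self.assertTrue(True)

if __name__ == "__main__":
    unittest.main()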
+ """ + sigopt_api_token = os.environ.get("SIGOPT_API_TOKEN", None) + has_sigopt_project = os.environ.get("SIGOPT_PROJECT", None) + if sigopt_api_token is None or has_sigopt_project is None: + return unittest.skip("test requires an environment variable `SIGOPT_API_TOKEN` and `SIGOPT_PROJECT`")( + test_case + ) + else: + return test_case + + +def is_ort_training_available(): + is_ort_train_available = importlib.util.find_spec("onnxruntime.training") is not None + + if importlib.util.find_spec("torch_ort") is not None: + try: + is_torch_ort_configured = True + subprocess.run([sys.executable, "-m", "torch_ort.configure"], shell=False, check=True) + except subprocess.CalledProcessError: + is_torch_ort_configured = False + + return is_ort_train_available and is_torch_ort_configured + + +def require_ort_training(test_case): + """ + Decorator marking a test that requires onnxruntime-training and torch_ort correctly installed and configured. + These tests are skipped otherwise. + """ + return unittest.skipUnless( + is_ort_training_available(), + "test requires torch_ort correctly installed and configured", + )(test_case) + + +def require_diffusers(test_case): + return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case) + + +def require_timm(test_case): + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + + +def require_sentence_transformers(test_case): + return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case) + + +def require_datasets(test_case): + return unittest.skipUnless(is_datasets_available(), "test requires datasets")(test_case) + + +def grid_parameters( + parameters: Dict[str, Iterable[Any]], + yield_dict: bool = False, + add_test_name: bool = True, + filter_params_func: Optional[Callable[[Tuple], Tuple]] = None, +) -> Iterable: + """ + Generates an iterable over the grid of all combinations of parameters. + + Args: + `parameters` (`Dict[str, Iterable[Any]]`): + Dictionary of multiple values to generate a grid from. + `yield_dict` (`bool`, defaults to `False`): + If True, a dictionary with all keys, and sampled values will be returned. Otherwise, return sampled values as a list. + `add_test_name` (`bool`, defaults to `True`): + Whether to add the test name in the yielded list or dictionary. + filter_params_func (`Optional[Callable[[Tuple], Tuple]]`, defaults to `None`): + A function that can modify or exclude the current set of parameters. The function should take a tuple of the + parameters and return the same. If a parameter set is to be excluded, the function should return an empty tuple. + """ + for params in itertools.product(*parameters.values()): + if filter_params_func is not None: + params = filter_params_func(list(params)) + if params is None: + continue + + test_name = "_".join([str(param) for param in params]) + if yield_dict is True: + res_dict = {} + for i, key in enumerate(parameters.keys()): + res_dict[key] = params[i] + if add_test_name is True: + res_dict["test_name"] = test_name + yield res_dict + else: + returned_list = [test_name] + list(params) if add_test_name is True else list(params) + yield returned_list + + +def remove_directory(dirpath): + """ + Remove a directory and its content. + This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows. 
+ Reference: https://github.com/python/cpython/issues/107408 + """ + if os.path.exists(dirpath) and os.path.isdir(dirpath): + if os.name == "nt": + os.system(f"rmdir /S /Q {dirpath}") + else: + shutil.rmtree(dirpath) \ No newline at end of file diff --git a/gptqmodel/integration/optimum/utils/testing_utils.py b/gptqmodel/integration/optimum/utils/testing_utils.py new file mode 100644 index 000000000..68f713860 --- /dev/null +++ b/gptqmodel/integration/optimum/utils/testing_utils.py @@ -0,0 +1,211 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from typing import Callable, Optional, Union + +import torch +from torch import nn +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import Conv1D + +from gptqmodel.utils.logger import setup_logger +from gptqmodel.integration.optimum.gptq.quantizer import BLOCK_PATTERNS, SEQLEN_KEYS_TRANFORMERS + +ori_save_pretrained = PreTrainedModel.save_pretrained + +logger = setup_logger() + + +""" +Set of utilities to get specific attributes of a model +""" + + +def get_layers(module: nn.Module, layers=[Conv1D, nn.Conv2d, nn.Linear], prefix: Optional[str] = None, name: str = ""): + """ + Get all the layers with a specific prefix in the module + Args: + module (`nn.Module`): + The module that contains our layers + layers (`list`, defaults to `[Conv1D, nn.Conv2d, nn.Linear]`): + Type of the layers that we want to get + prefix (`Optional[str]`, defaults to `None`): + Prefix of layers + name (`str`, defaults to `""`): + Used for recursion. Don't modify + + Returns: + `Dict[str,Union[Conv1D, nn.Conv2d, nn.Linear]]`: Mapping of the name of the layer and the actual layer + """ + for layer in layers: + if isinstance(module, layer): + if prefix is not None: + if name.startswith(prefix): + return {name: module} + else: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update(get_layers(child, layers=layers, prefix=prefix, name=name + "." + name1 if name != "" else name1)) + return res + + +def get_block_name_with_pattern(model: nn.Module): + """ + Get the name of the module that contains the transformers blocks by checking if any modules has a specific pattern + + Args: + model (`nn.Module`): + The input model + Returns: + `str`: The name of the module that contains the Transformer blocks. + """ + modules_names = [n for n, _ in model.named_modules()] + for pattern_candidate in BLOCK_PATTERNS: + pattern_candidate = pattern_candidate + if any(pattern_candidate in name for name in modules_names): + return pattern_candidate + raise ValueError("Block pattern could not be match. Pass `block_name_to_quantize` argument in `quantize_model`") + + +def get_preceding_modules(model: nn.Module, module_name: str): + previous_module_name = [] + stop_adding = False + + def _get_preceding_modules(model: nn.Module, module_name: str, name: str = ""): + nonlocal stop_adding + for name_bis, child in model.named_children(): + new_name = name + "." 
+ name_bis if name != "" else name_bis + if new_name == module_name: + stop_adding = True + break + _get_preceding_modules(child, module_name, name=new_name) + if not stop_adding: + previous_module_name.append(name) + return previous_module_name + + return _get_preceding_modules(model, module_name) + + +def get_device(obj: Union[torch.Tensor, nn.Module]): + if isinstance(obj, torch.Tensor): + return obj.device + return next(obj.parameters()).device + + +def get_seqlen(model: nn.Module): + if hasattr(model, "config"): + model_config = model.config.to_dict() + if any(k in model_config for k in SEQLEN_KEYS_TRANFORMERS): + for key in SEQLEN_KEYS_TRANFORMERS: + if key in model_config: + return model_config[key] + logger.info( + "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`" + ) + return 2048 + + +def monkey_patch_gptqmodel_into_transformers(): + # monkey_patch transformers.utils.quantization_config.GPTQConfig.post_init() + # Because it checks the auto_gptq version + def post_init(self): + r""" + Safety checker that arguments are correct + """ + import importlib + + from packaging import version + print("monkey patch postin") + if self.bits not in [2, 3, 4, 8]: + raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + if self.dataset is not None: + if isinstance(self.dataset, str): + if self.dataset in ["ptb", "ptb-new"]: + raise ValueError( + f"""{self.dataset} dataset was deprecated. You can only choose between + ['wikitext2','c4','c4-new']""" + ) + if self.dataset not in ["wikitext2", "c4", "c4-new"]: + raise ValueError( + f"""You have entered a string value for dataset. You can only choose between + ['wikitext2','c4','c4-new'], but we found {self.dataset}""" + ) + elif not isinstance(self.dataset, list): + raise ValueError( + f"""dataset needs to be either a list of string or a value in + ['wikitext2','c4','c4-new'], but we found {self.dataset}""" + ) + + if self.use_exllama is None: + # New default behaviour + self.use_exllama = True + + if self.bits == 4 and self.use_exllama: + optimum_version = version.parse(importlib.metadata.version("optimum")) + # autogptq_version = version.parse(importlib.metadata.version("auto_gptq")) + # if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"): + if optimum_version <= version.parse("1.13.2"): + raise ValueError( + # f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}" + f"You need optimum > 1.13.2 . Make sure to have that version installed - detected version : optimum {optimum_version}" + ) + if self.modules_in_block_to_quantize is not None: + optimum_version = version.parse(importlib.metadata.version("optimum")) + if optimum_version < version.parse("1.15.0"): + raise ValueError( + "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ." 
+ ) + + from transformers.utils.quantization_config import GPTQConfig + GPTQConfig.post_init = post_init + + from transformers.quantizers import auto + + from .hf_quantizer_gptq import GptqHfQuantizer + + auto.AUTO_QUANTIZER_MAPPING["gptq"] = GptqHfQuantizer + + # TODO monkey patch GPTQConfig? + + # model.save_pretrained() will not call optimum.quantizer.GPTQModelQuantizer.save(), + # we need to monkey patch save_pretrained() to convert gptq_v2 to gptq_v1 format. + def monkey_patch_save_pretrained(self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + push_to_hub: bool = False, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + token: Optional[Union[str, bool]] = None, + save_peft_format: bool = True, + **kwargs, ): + hf_quantizer = getattr(self, "hf_quantizer", None) + if hf_quantizer: + ori_model = getattr(self, "model", None) + assert ori_model + + model = hf_quantizer.optimum_quantizer.convert_gptq_v2_to_v1(ori_model) + setattr(self, "model", model) + + ori_save_pretrained(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, + max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs) + + PreTrainedModel.save_pretrained = monkey_patch_save_pretrained From 5a754526e12b889634a7e496d298f14e64e32204 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 14:56:08 +0800 Subject: [PATCH 03/30] fix import --- .../integration/peft/tuners/adalora/model.py | 11 ++++----- .../integration/peft/tuners/utils/__init__.py | 4 ---- .../integration/peft/tuners/utils/other.py | 24 +++---------------- 3 files changed, 8 insertions(+), 31 deletions(-) diff --git a/gptqmodel/integration/peft/tuners/adalora/model.py b/gptqmodel/integration/peft/tuners/adalora/model.py index 0654a87b0..6cfe78158 100644 --- a/gptqmodel/integration/peft/tuners/adalora/model.py +++ b/gptqmodel/integration/peft/tuners/adalora/model.py @@ -15,6 +15,7 @@ import warnings import torch +from peft.tuners.adalora import RankAllocator, AdaLoraLayer, SVDQuantLinear, SVDLinear from transformers.pytorch_utils import Conv1D from peft.import_utils import is_bnb_4bit_available, is_bnb_available @@ -25,14 +26,12 @@ _freeze_adapter, _get_submodules, get_auto_gptq_quant_linear, - get_gptqmodel_quant_linear, get_quantization_config, ) -from peft.import_utils import is_gptqmodel_available from peft.utils.integrations import gather_params_ctx -from .gptq import SVDQuantLinear -from .layer import AdaLoraLayer, RankAllocator, SVDLinear +from ..utils import get_gptqmodel_quant_linear +from ...import_utils import is_gptqmodel_available class AdaLoraModel(LoraModel): @@ -157,9 +156,9 @@ def _create_new_module(lora_config, adapter_name, target, device_map, **kwargs): if is_bnb_available(): import bitsandbytes as bnb - from .bnb import SVDLinear8bitLt + from peft.tuners.adalora.bnb import SVDLinear8bitLt if is_bnb_4bit_available(): - from .bnb import SVDLinear4bit + from peft.tuners.adalora.bnb import SVDLinear4bit gptq_quantization_config = kwargs.get("gptq_quantization_config", None) diff --git a/gptqmodel/integration/peft/tuners/utils/__init__.py b/gptqmodel/integration/peft/tuners/utils/__init__.py index b4f105907..dc96e4d2e 100644 --- a/gptqmodel/integration/peft/tuners/utils/__init__.py +++ b/gptqmodel/integration/peft/tuners/utils/__init__.py @@ -18,9 +18,6 @@ # limitations under the License. 
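# The save_pretrained override above follows the usual wrap-and-delegate monkey-patching
# pattern: keep a module-level reference to the original bound function, do the extra
# work (here, the gptq_v2 -> gptq_v1 conversion), then call the original. A generic,
# illustrative sketch of that pattern:
class Example:
    def save(self, path):
        print(f"saving to {path}")

_original_save = Example.save                    # captured once, before patching

def patched_save(self, path):
    # extra pre-processing would go here (e.g. converting an internal checkpoint format)
    return _original_save(self, path)            # always delegate to the original

Example.save = patched_save
Example().save("/tmp/out")                       # prints "saving to /tmp/out"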
# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType -from .integrations import map_cache_to_layer_device_map -from .loftq_utils import replace_lora_weights_loftq -from .peft_types import PeftType, TaskType from .other import ( TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, @@ -54,4 +51,3 @@ id_tensor_storage, cast_mixed_precision_params, ) -from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict, load_peft_weights \ No newline at end of file diff --git a/gptqmodel/integration/peft/tuners/utils/other.py b/gptqmodel/integration/peft/tuners/utils/other.py index 2116a1970..f37b09d4d 100644 --- a/gptqmodel/integration/peft/tuners/utils/other.py +++ b/gptqmodel/integration/peft/tuners/utils/other.py @@ -29,28 +29,10 @@ from huggingface_hub import file_exists from huggingface_hub.errors import EntryNotFoundError, HFValidationError from packaging import version +from peft.utils.constants import * from safetensors.torch import storage_ptr, storage_size -from ..import_utils import is_auto_gptq_available, is_gptqmodel_available, is_torch_tpu_available -from .constants import ( - CONFIG_NAME, - EMBEDDING_LAYER_NAMES, - INCLUDE_LINEAR_LAYERS_SHORTHAND, - SAFETENSORS_WEIGHTS_NAME, - TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, - TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING, - TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING, - WEIGHTS_NAME, - bloom_model_postprocess_past_key_value, - starcoder_model_postprocess_past_key_value, -) - +from gptqmodel.integration.peft.import_utils import is_auto_gptq_available, is_torch_tpu_available, is_gptqmodel_available mlu_available = False if version.parse(accelerate.__version__) >= version.parse("0.29.0"): @@ -517,7 +499,7 @@ def fsdp_auto_wrap_policy(model): from accelerate.utils.dataclasses import get_module_class_from_name from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy - from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder default_transformer_cls_names_to_wrap = ( ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else "" From 065b14e269b0a2e3ee0eec286b4e32a4684e4e89 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 14:57:48 +0800 Subject: [PATCH 04/30] update optimun --- .../integration/optimum/gptq/quantizer.py | 5 +- .../integration/optimum/utils/import_utils.py | 297 ++++++++++++++ .../optimum/utils/testing_utils.py | 385 +++++++++--------- 3 files changed, 490 insertions(+), 197 deletions(-) diff --git a/gptqmodel/integration/optimum/gptq/quantizer.py b/gptqmodel/integration/optimum/gptq/quantizer.py index 6ee53da54..422cd4438 100644 --- a/gptqmodel/integration/optimum/gptq/quantizer.py +++ b/gptqmodel/integration/optimum/gptq/quantizer.py @@ -1,3 +1,4 @@ + # coding=utf-8 # Copyright 2023 HuggingFace Inc. team and GPTQ and AutoGPTQ authors. # @@ -128,7 +129,7 @@ def __init__( Properties, such as tooling:version, that do not directly contributes to quantization or quant inference are stored in meta. i.e. 
`meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"] backend (`str`, *optional*): - Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only + Controls which gptq kernel to be used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, only valid value is None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py use_cuda_fp16 (`bool`, defaults to `False`): Whether or not to use optimized cuda kernel for fp16 model. Need to have model in fp16. @@ -920,4 +921,4 @@ def load_quantized_model( model.is_quantized = True model.quantization_method = QuantizationMethod.GPTQ model.eval() - return model \ No newline at end of file + return model diff --git a/gptqmodel/integration/optimum/utils/import_utils.py b/gptqmodel/integration/optimum/utils/import_utils.py index 3137d453b..627ae7a1e 100644 --- a/gptqmodel/integration/optimum/utils/import_utils.py +++ b/gptqmodel/integration/optimum/utils/import_utils.py @@ -35,7 +35,304 @@ is_timm_available, ) +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Import utilities.""" + +import importlib.util +import inspect +import sys +from collections import OrderedDict +from contextlib import contextmanager +from typing import Tuple, Union + +import numpy as np +from packaging import version +from transformers.utils import is_torch_available + + +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib.metadata.version(pkg_name) + package_exists = True + except importlib.metadata.PackageNotFoundError: + package_exists = False + if return_version: + return package_exists, package_version + else: + return package_exists + + +# The package importlib_metadata is in a different place, depending on the python version. +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +TORCH_MINIMUM_VERSION = version.parse("1.11.0") +TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0") +DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0") +AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0 +GPTQMODEL_MINIMUM_VERSION = version.parse("1.3.99") # Allows 1.4.0.dev0 + + +# This is the minimal required version to support some ONNX Runtime features +ORT_QUANTIZE_MINIMUM_VERSION = version.parse("1.4.0") + + +_onnx_available = _is_package_available("onnx") + +# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.) 
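# For reference, the availability helper above can be queried either way; the actual
# output depends on the local environment.
has_onnx = _is_package_available("onnx")                                   # -> bool
has_accelerate, accelerate_ver = _is_package_available("accelerate", return_version=True)
print(has_onnx, has_accelerate, accelerate_ver)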
+_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None + +_pydantic_available = _is_package_available("pydantic") +_accelerate_available = _is_package_available("accelerate") +_diffusers_available = _is_package_available("diffusers") +_auto_gptq_available = _is_package_available("auto_gptq") +_gptqmodel_available = _is_package_available("gptqmodel") +_timm_available = _is_package_available("timm") +_sentence_transformers_available = _is_package_available("sentence_transformers") +_datasets_available = _is_package_available("datasets") + +torch_version = None +if is_torch_available(): + torch_version = version.parse(importlib_metadata.version("torch")) + +_is_torch_onnx_support_available = is_torch_available() and ( + TORCH_MINIMUM_VERSION.major, + TORCH_MINIMUM_VERSION.minor, +) <= ( + torch_version.major, + torch_version.minor, + ) + + +_diffusers_version = None +if _diffusers_available: + try: + _diffusers_version = importlib_metadata.version("diffusers") + except importlib_metadata.PackageNotFoundError: + _diffusers_available = False + + +def is_torch_onnx_support_available(): + return _is_torch_onnx_support_available + + +def is_onnx_available(): + return _onnx_available + + +def is_onnxruntime_available(): + try: + # Try to import the source file of onnxruntime - if you run the tests from `tests` the function gets + # confused since there a folder named `onnxruntime` in `tests`. Therefore, `_onnxruntime_available` + # will be set to `True` even if not installed. + mod = importlib.import_module("onnxruntime") + inspect.getsourcefile(mod) + except Exception: + return False + return _onnxruntime_available + + +def is_pydantic_available(): + return _pydantic_available + + +def is_accelerate_available(): + return _accelerate_available + + +def is_diffusers_available(): + return _diffusers_available + + +def is_timm_available(): + return _timm_available + + +def is_sentence_transformers_available(): + return _sentence_transformers_available + + +def is_datasets_available(): + return _datasets_available + + +def is_auto_gptq_available(): + if _auto_gptq_available: + v = version.parse(importlib_metadata.version("auto_gptq")) + if v >= AUTOGPTQ_MINIMUM_VERSION: + return True + else: + raise ImportError( + f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported" + ) + + +def is_gptqmodel_available(): + if _gptqmodel_available: + v = version.parse(importlib_metadata.version("gptqmodel")) + if v >= GPTQMODEL_MINIMUM_VERSION: + return True + else: + raise ImportError( + f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported" + ) + + +@contextmanager +def check_if_pytorch_greater(target_version: str, message: str): + r""" + A context manager that does nothing except checking if the PyTorch version is greater than `pt_version` + """ + import torch + + if not version.parse(torch.__version__) >= version.parse(target_version): + raise ImportError( + f"Found an incompatible version of PyTorch. Found version {torch.__version__}, but only {target_version} and above are supported. {message}" + ) + try: + yield + finally: + pass + + +def check_if_transformers_greater(target_version: Union[str, version.Version]) -> bool: + """ + Checks whether the current install of transformers is greater than or equal to the target version. + + Args: + target_version (`Union[str, packaging.version.Version]`): version used as the reference for comparison. 
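# Illustrative use of the version gates defined here; the thresholds and the guarded
# code are examples, not requirements of this module.
if check_if_transformers_greater("4.34"):
    pass  # safe to rely on APIs introduced in transformers 4.34+

with check_if_pytorch_greater("2.0.0", "scaled_dot_product_attention is required."):
    import torch
    _ = torch.nn.functional.scaled_dot_product_attention  # newer API is available here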
+ + Returns: + bool: whether the check is True or not. + """ + import transformers + + if isinstance(target_version, str): + target_version = version.parse(target_version) + + return version.parse(transformers.__version__) >= target_version + + +def check_if_diffusers_greater(target_version: str) -> bool: + """ + Checks whether the current install of diffusers is greater than or equal to the target version. + + Args: + target_version (str): version used as the reference for comparison. + + Returns: + bool: whether the check is True or not. + """ + if not _diffusers_available: + return False + + return version.parse(_diffusers_version) >= version.parse(target_version) + + +def check_if_torch_greater(target_version: str) -> bool: + """ + Checks whether the current install of torch is greater than or equal to the target version. + + Args: + target_version (str): version used as the reference for comparison. + + Returns: + bool: whether the check is True or not. + """ + if not is_torch_available(): + return False + + return torch_version >= version.parse(target_version) + + +@contextmanager +def require_numpy_strictly_lower(package_version: str, message: str): + if not version.parse(np.__version__) < version.parse(package_version): + raise ImportError( + f"Found an incompatible version of numpy. Found version {np.__version__}, but expected numpy<{version}. {message}" + ) + try: + yield + finally: + pass + + +DIFFUSERS_IMPORT_ERROR = """ +{0} requires the diffusers library but it was not found in your environment. You can install it with pip: `pip install +diffusers`. Please note that you may need to restart your runtime after installation. +""" + +TRANSFORMERS_IMPORT_ERROR = """requires the transformers>={0} library but it was not found in your environment. You can install it with pip: `pip install +-U transformers`. Please note that you may need to restart your runtime after installation. +""" + +DATASETS_IMPORT_ERROR = """ +{0} requires the datasets library but it was not found in your environment. You can install it with pip: +`pip install datasets`. Please note that you may need to restart your runtime after installation. +""" + + +BACKENDS_MAPPING = OrderedDict( + [ + ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)), + ( + "transformers_431", + (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")), + ), + ( + "transformers_432", + (lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")), + ), + ( + "transformers_434", + (lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")), + ), + ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] + if failed: + raise ImportError("".join(failed)) + + +# Copied from: https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/utils/import_utils.py#L1041 +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. 
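# A short illustration of how this metaclass is typically used for optional backends;
# the placeholder class below is an example, not part of this module.
class DummyDiffusersObject(metaclass=DummyObject):
    _backends = ["diffusers"]

# Any public attribute access re-runs the backend check, so the helpful ImportError is
# raised lazily, only when the object is actually touched:
# DummyDiffusersObject.from_pretrained   -> ImportError if diffusers is not installed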
+ """ + def __getattr__(cls, key): + if key.startswith("_"): + return super().__getattr__(cls, key) + requires_backends(cls, cls._backends) # Used to test the hub USER = "__DUMMY_OPTIMUM_USER__" diff --git a/gptqmodel/integration/optimum/utils/testing_utils.py b/gptqmodel/integration/optimum/utils/testing_utils.py index 68f713860..3137d453b 100644 --- a/gptqmodel/integration/optimum/utils/testing_utils.py +++ b/gptqmodel/integration/optimum/utils/testing_utils.py @@ -1,211 +1,206 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import importlib.util +import itertools import os -from typing import Callable, Optional, Union +import shutil +import subprocess +import sys +import unittest +from collections.abc import MutableMapping +from typing import Any, Callable, Dict, Iterable, Optional, Tuple import torch -from torch import nn -from transformers.modeling_utils import PreTrainedModel -from transformers.pytorch_utils import Conv1D -from gptqmodel.utils.logger import setup_logger -from gptqmodel.integration.optimum.gptq.quantizer import BLOCK_PATTERNS, SEQLEN_KEYS_TRANFORMERS +from . import ( + is_accelerate_available, + is_auto_gptq_available, + is_datasets_available, + is_diffusers_available, + is_gptqmodel_available, + is_sentence_transformers_available, + is_timm_available, +) -ori_save_pretrained = PreTrainedModel.save_pretrained -logger = setup_logger() +# Used to test the hub +USER = "__DUMMY_OPTIMUM_USER__" -""" -Set of utilities to get specific attributes of a model -""" +def flatten_dict(dictionary: Dict): + """ + Flatten a nested dictionaries as a flat dictionary. + """ + items = [] + for k, v in dictionary.items(): + new_key = k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v).items()) + else: + items.append((new_key, v)) + return dict(items) -def get_layers(module: nn.Module, layers=[Conv1D, nn.Conv2d, nn.Linear], prefix: Optional[str] = None, name: str = ""): +def require_accelerate(test_case): """ - Get all the layers with a specific prefix in the module - Args: - module (`nn.Module`): - The module that contains our layers - layers (`list`, defaults to `[Conv1D, nn.Conv2d, nn.Linear]`): - Type of the layers that we want to get - prefix (`Optional[str]`, defaults to `None`): - Prefix of layers - name (`str`, defaults to `""`): - Used for recursion. 
Don't modify - - Returns: - `Dict[str,Union[Conv1D, nn.Conv2d, nn.Linear]]`: Mapping of the name of the layer and the actual layer - """ - for layer in layers: - if isinstance(module, layer): - if prefix is not None: - if name.startswith(prefix): - return {name: module} - else: - return {name: module} - res = {} - for name1, child in module.named_children(): - res.update(get_layers(child, layers=layers, prefix=prefix, name=name + "." + name1 if name != "" else name1)) - return res - - -def get_block_name_with_pattern(model: nn.Module): - """ - Get the name of the module that contains the transformers blocks by checking if any modules has a specific pattern + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) - Args: - model (`nn.Module`): - The input model - Returns: - `str`: The name of the module that contains the Transformer blocks. - """ - modules_names = [n for n, _ in model.named_modules()] - for pattern_candidate in BLOCK_PATTERNS: - pattern_candidate = pattern_candidate - if any(pattern_candidate in name for name in modules_names): - return pattern_candidate - raise ValueError("Block pattern could not be match. Pass `block_name_to_quantize` argument in `quantize_model`") - - -def get_preceding_modules(model: nn.Module, module_name: str): - previous_module_name = [] - stop_adding = False - - def _get_preceding_modules(model: nn.Module, module_name: str, name: str = ""): - nonlocal stop_adding - for name_bis, child in model.named_children(): - new_name = name + "." + name_bis if name != "" else name_bis - if new_name == module_name: - stop_adding = True - break - _get_preceding_modules(child, module_name, name=new_name) - if not stop_adding: - previous_module_name.append(name) - return previous_module_name - - return _get_preceding_modules(model, module_name) - - -def get_device(obj: Union[torch.Tensor, nn.Module]): - if isinstance(obj, torch.Tensor): - return obj.device - return next(obj.parameters()).device - - -def get_seqlen(model: nn.Module): - if hasattr(model, "config"): - model_config = model.config.to_dict() - if any(k in model_config for k in SEQLEN_KEYS_TRANFORMERS): - for key in SEQLEN_KEYS_TRANFORMERS: - if key in model_config: - return model_config[key] - logger.info( - "We couldn't get the model sequence length. Setting it to 2048. You can overwrite this value by passing `model_seqlen` in` GPTQQuantizer`" + +def require_gptq(test_case): + """ + Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed. 
+ """ + return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires auto-gptq")( + test_case ) - return 2048 - - -def monkey_patch_gptqmodel_into_transformers(): - # monkey_patch transformers.utils.quantization_config.GPTQConfig.post_init() - # Because it checks the auto_gptq version - def post_init(self): - r""" - Safety checker that arguments are correct - """ - import importlib - - from packaging import version - print("monkey patch postin") - if self.bits not in [2, 3, 4, 8]: - raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}") - if self.group_size != -1 and self.group_size <= 0: - raise ValueError("group_size must be greater than 0 or equal to -1") - if not (0 < self.damp_percent < 1): - raise ValueError("damp_percent must between 0 and 1.") - if self.dataset is not None: - if isinstance(self.dataset, str): - if self.dataset in ["ptb", "ptb-new"]: - raise ValueError( - f"""{self.dataset} dataset was deprecated. You can only choose between - ['wikitext2','c4','c4-new']""" - ) - if self.dataset not in ["wikitext2", "c4", "c4-new"]: - raise ValueError( - f"""You have entered a string value for dataset. You can only choose between - ['wikitext2','c4','c4-new'], but we found {self.dataset}""" - ) - elif not isinstance(self.dataset, list): - raise ValueError( - f"""dataset needs to be either a list of string or a value in - ['wikitext2','c4','c4-new'], but we found {self.dataset}""" - ) - - if self.use_exllama is None: - # New default behaviour - self.use_exllama = True - - if self.bits == 4 and self.use_exllama: - optimum_version = version.parse(importlib.metadata.version("optimum")) - # autogptq_version = version.parse(importlib.metadata.version("auto_gptq")) - # if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"): - if optimum_version <= version.parse("1.13.2"): - raise ValueError( - # f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}" - f"You need optimum > 1.13.2 . Make sure to have that version installed - detected version : optimum {optimum_version}" - ) - if self.modules_in_block_to_quantize is not None: - optimum_version = version.parse(importlib.metadata.version("optimum")) - if optimum_version < version.parse("1.15.0"): - raise ValueError( - "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ." - ) - - from transformers.utils.quantization_config import GPTQConfig - GPTQConfig.post_init = post_init - - from transformers.quantizers import auto - - from .hf_quantizer_gptq import GptqHfQuantizer - - auto.AUTO_QUANTIZER_MAPPING["gptq"] = GptqHfQuantizer - - # TODO monkey patch GPTQConfig? - - # model.save_pretrained() will not call optimum.quantizer.GPTQModelQuantizer.save(), - # we need to monkey patch save_pretrained() to convert gptq_v2 to gptq_v1 format. 
- def monkey_patch_save_pretrained(self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - state_dict: Optional[dict] = None, - save_function: Callable = torch.save, - push_to_hub: bool = False, - max_shard_size: Union[int, str] = "5GB", - safe_serialization: bool = True, - variant: Optional[str] = None, - token: Optional[Union[str, bool]] = None, - save_peft_format: bool = True, - **kwargs, ): - hf_quantizer = getattr(self, "hf_quantizer", None) - if hf_quantizer: - ori_model = getattr(self, "model", None) - assert ori_model - - model = hf_quantizer.optimum_quantizer.convert_gptq_v2_to_v1(ori_model) - setattr(self, "model", model) - - ori_save_pretrained(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, - max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs) - - PreTrainedModel.save_pretrained = monkey_patch_save_pretrained + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + torch_device = "cuda" if torch.cuda.is_available() else "cpu" + + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) + + +def require_ort_rocm(test_case): + """Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime.""" + import onnxruntime as ort + + providers = ort.get_available_providers() + + return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")( + test_case + ) + + +def require_hf_token(test_case): + """ + Decorator marking a test that requires huggingface hub token. + """ + # is HF_AUTH_TOKEN used instead of HF_TOKEN to avoid huggigface_hub picking it up ? + hf_token = os.environ.get("HF_AUTH_TOKEN", None) + if hf_token is None: + return unittest.skip("test requires hf token as `HF_AUTH_TOKEN` environment variable")(test_case) + else: + return test_case + + +def require_sigopt_token_and_project(test_case): + """ + Decorator marking a test that requires sigopt API token. + """ + sigopt_api_token = os.environ.get("SIGOPT_API_TOKEN", None) + has_sigopt_project = os.environ.get("SIGOPT_PROJECT", None) + if sigopt_api_token is None or has_sigopt_project is None: + return unittest.skip("test requires an environment variable `SIGOPT_API_TOKEN` and `SIGOPT_PROJECT`")( + test_case + ) + else: + return test_case + + +def is_ort_training_available(): + is_ort_train_available = importlib.util.find_spec("onnxruntime.training") is not None + + if importlib.util.find_spec("torch_ort") is not None: + try: + is_torch_ort_configured = True + subprocess.run([sys.executable, "-m", "torch_ort.configure"], shell=False, check=True) + except subprocess.CalledProcessError: + is_torch_ort_configured = False + + return is_ort_train_available and is_torch_ort_configured + + +def require_ort_training(test_case): + """ + Decorator marking a test that requires onnxruntime-training and torch_ort correctly installed and configured. + These tests are skipped otherwise. 
+ """ + return unittest.skipUnless( + is_ort_training_available(), + "test requires torch_ort correctly installed and configured", + )(test_case) + + +def require_diffusers(test_case): + return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case) + + +def require_timm(test_case): + return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) + + +def require_sentence_transformers(test_case): + return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case) + + +def require_datasets(test_case): + return unittest.skipUnless(is_datasets_available(), "test requires datasets")(test_case) + + +def grid_parameters( + parameters: Dict[str, Iterable[Any]], + yield_dict: bool = False, + add_test_name: bool = True, + filter_params_func: Optional[Callable[[Tuple], Tuple]] = None, +) -> Iterable: + """ + Generates an iterable over the grid of all combinations of parameters. + + Args: + `parameters` (`Dict[str, Iterable[Any]]`): + Dictionary of multiple values to generate a grid from. + `yield_dict` (`bool`, defaults to `False`): + If True, a dictionary with all keys, and sampled values will be returned. Otherwise, return sampled values as a list. + `add_test_name` (`bool`, defaults to `True`): + Whether to add the test name in the yielded list or dictionary. + filter_params_func (`Optional[Callable[[Tuple], Tuple]]`, defaults to `None`): + A function that can modify or exclude the current set of parameters. The function should take a tuple of the + parameters and return the same. If a parameter set is to be excluded, the function should return an empty tuple. + """ + for params in itertools.product(*parameters.values()): + if filter_params_func is not None: + params = filter_params_func(list(params)) + if params is None: + continue + + test_name = "_".join([str(param) for param in params]) + if yield_dict is True: + res_dict = {} + for i, key in enumerate(parameters.keys()): + res_dict[key] = params[i] + if add_test_name is True: + res_dict["test_name"] = test_name + yield res_dict + else: + returned_list = [test_name] + list(params) if add_test_name is True else list(params) + yield returned_list + + +def remove_directory(dirpath): + """ + Remove a directory and its content. + This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows. 
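# For instance, the grid helper above expands a parameter dictionary into named test
# cases; the parameter values below are illustrative.
params = {"bits": [4, 8], "sym": [True, False]}

for case in grid_parameters(params):
    test_name, bits, sym = case                  # e.g. "4_True", 4, True
    print(test_name, bits, sym)

for case in grid_parameters(params, yield_dict=True):
    print(case)                                  # e.g. {"bits": 4, "sym": True, "test_name": "4_True"}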
+ Reference: https://github.com/python/cpython/issues/107408 + """ + if os.path.exists(dirpath) and os.path.isdir(dirpath): + if os.name == "nt": + os.system(f"rmdir /S /Q {dirpath}") + else: + shutil.rmtree(dirpath) \ No newline at end of file From 663d4d3dd1b5fc301f688235707ef086a9a9bb0e Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 15:00:28 +0800 Subject: [PATCH 05/30] fix optimun import --- gptqmodel/integration/optimum/gptq/quantizer.py | 15 ++++++++------- .../integration/optimum/utils/import_utils.py | 4 +--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/gptqmodel/integration/optimum/gptq/quantizer.py b/gptqmodel/integration/optimum/gptq/quantizer.py index 422cd4438..96b4897dd 100644 --- a/gptqmodel/integration/optimum/gptq/quantizer.py +++ b/gptqmodel/integration/optimum/gptq/quantizer.py @@ -28,13 +28,14 @@ from transformers.pytorch_utils import Conv1D from transformers.utils.quantization_config import QuantizationMethod -from ..utils import is_accelerate_available, is_auto_gptq_available, is_gptqmodel_available -from ..utils.modeling_utils import recurse_getattr -from .constants import GPTQ_CONFIG -from .data import get_dataset, prepare_dataset -from .utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen -from ..version import __version__ as optimum_version - +from optimum.utils import is_accelerate_available, is_auto_gptq_available +from optimum.utils.modeling_utils import recurse_getattr +from optimum.gptq.constants import GPTQ_CONFIG +from optimum.gptq.data import get_dataset, prepare_dataset +from optimum.gptq.utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen +from optimum.version import __version__ as optimum_version + +from gptqmodel.integration.optimum.utils.import_utils import is_gptqmodel_available if is_accelerate_available(): from accelerate import ( diff --git a/gptqmodel/integration/optimum/utils/import_utils.py b/gptqmodel/integration/optimum/utils/import_utils.py index 627ae7a1e..a29cb063d 100644 --- a/gptqmodel/integration/optimum/utils/import_utils.py +++ b/gptqmodel/integration/optimum/utils/import_utils.py @@ -25,12 +25,10 @@ import torch -from . import ( +from optimum.utils import ( is_accelerate_available, is_auto_gptq_available, - is_datasets_available, is_diffusers_available, - is_gptqmodel_available, is_sentence_transformers_available, is_timm_available, ) From 11a9b2681a3fc79f484065b501b025e19a7c41a3 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 15:02:54 +0800 Subject: [PATCH 06/30] add transformers --- .../transformers/quantizers/quantizer_gptq.py | 110 + .../integration/transformers/testing_utils.py | 2756 +++++++++++++++++ .../transformers/utils/import_utils.py | 2218 +++++++++++++ .../transformers/utils/quantization_config.py | 1378 +++++++++ 4 files changed, 6462 insertions(+) create mode 100644 gptqmodel/integration/transformers/quantizers/quantizer_gptq.py create mode 100644 gptqmodel/integration/transformers/testing_utils.py create mode 100644 gptqmodel/integration/transformers/utils/import_utils.py create mode 100644 gptqmodel/integration/transformers/utils/quantization_config.py diff --git a/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py new file mode 100644 index 000000000..51a5f7df4 --- /dev/null +++ b/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py @@ -0,0 +1,110 @@ +# Copyright 2024 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +from typing import TYPE_CHECKING, Optional + +from packaging import version + +from .base import HfQuantizer + + +if TYPE_CHECKING: + from ..modeling_utils import PreTrainedModel + +from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging +from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class GptqHfQuantizer(HfQuantizer): + """ + Quantizer of the GPTQ method - for GPTQ the quantizer support calibration of the model through + `auto_gptq` or `gptqmodel` package. Quantization is done under the hood for users if they load a non-prequantized model. + """ + + requires_calibration = False + required_packages = ["optimum", "auto_gptq", "gptqmodel"] + optimum_quantizer = None + + def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): + super().__init__(quantization_config, **kwargs) + from optimum.gptq import GPTQQuantizer + + self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum()) + + def validate_environment(self, *args, **kwargs): + if not is_optimum_available(): + raise ImportError("Loading a GPTQ quantized model requires optimum (`pip install optimum`)") + if is_auto_gptq_available() and is_gptqmodel_available(): + logger.warning( + "Detected gptqmodel and auto-gptq, will use gptqmodel, auto-gptq will be deprecated in the future." + ) + + gptq_supports_cpu = ( + is_auto_gptq_available() + and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2") + ) or is_gptqmodel_available() + if not gptq_supports_cpu and not torch.cuda.is_available(): + raise RuntimeError("GPU is required to quantize or run quantize model.") + elif not (is_auto_gptq_available() or is_gptqmodel_available()): + raise ImportError( + "Loading a GPTQ quantized model requires gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) library. Please notice that auto-gptq will be deprecated in the future." + ) + elif is_auto_gptq_available() and version.parse(importlib.metadata.version("auto_gptq")) < version.parse( + "0.4.2" + ): + raise ImportError( + "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel`. Please notice that auto-gptq will be deprecated in the future." 
+ ) + elif is_gptqmodel_available() and ( + version.parse(importlib.metadata.version("gptqmodel")) <= version.parse("1.3.1") + or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.99") + ): + raise ImportError("The gptqmodel version should be >= 1.3.2, optimum version should >= 1.24.0") + + def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": + if torch_dtype is None: + torch_dtype = torch.float16 + elif torch_dtype != torch.float16: + logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.") + return torch_dtype + + def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs): + if model.__class__.main_input_name != "input_ids": + raise RuntimeError("We can only quantize pure text model.") + + if self.pre_quantized: + model = self.optimum_quantizer.convert_model(model, **kwargs) + + def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs): + if self.pre_quantized: + model = self.optimum_quantizer.post_init_model(model) + else: + if self.quantization_config.tokenizer is None: + self.quantization_config.tokenizer = model.name_or_path + + self.optimum_quantizer.quantize_model(model, self.quantization_config.tokenizer) + model.config.quantization_config = GPTQConfig.from_dict(self.optimum_quantizer.to_dict()) + + @property + def is_trainable(self, model: Optional["PreTrainedModel"] = None): + return True + + def is_serializable(self, safe_serialization=None): + return True \ No newline at end of file diff --git a/gptqmodel/integration/transformers/testing_utils.py b/gptqmodel/integration/transformers/testing_utils.py new file mode 100644 index 000000000..bc6f98b9e --- /dev/null +++ b/gptqmodel/integration/transformers/testing_utils.py @@ -0,0 +1,2756 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
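As a reading aid for the GptqHfQuantizer path above: when a GPTQConfig is passed to from_pretrained for a non-prequantized checkpoint, the quantizer delegates calibration to optimum's GPTQQuantizer, backed by gptqmodel or auto-gptq. A minimal sketch, assuming a placeholder checkpoint and calibration dataset:

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # placeholder; must be a pure text model (main input: input_ids)

tokenizer = AutoTokenizer.from_pretrained(model_id)
# bits/dataset are illustrative; a tokenizer is needed so the quantizer can build calibration data
gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",  # CPU-only runs require gptqmodel, or auto-gptq newer than 0.4.2
    quantization_config=gptq_config,
)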
+ +import collections +import contextlib +import doctest +import functools +import gc +import importlib +import inspect +import logging +import multiprocessing +import os +import re +import shlex +import shutil +import subprocess +import sys +import tempfile +import time +import unittest +from collections import defaultdict +from collections.abc import Mapping +from dataclasses import MISSING, fields +from functools import wraps +from io import StringIO +from pathlib import Path +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union +from unittest import mock +from unittest.mock import patch + +import huggingface_hub.utils +import urllib3 +from huggingface_hub import delete_repo + +from transformers import logging as transformers_logging + +from .integrations import ( + is_clearml_available, + is_optuna_available, + is_ray_available, + is_sigopt_available, + is_tensorboard_available, + is_wandb_available, +) +from .integrations.deepspeed import is_deepspeed_available +from .utils import ( + ACCELERATE_MIN_VERSION, + GGUF_MIN_VERSION, + is_accelerate_available, + is_apex_available, + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_av_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_bs4_available, + is_compressed_tensors_available, + is_cv2_available, + is_cython_available, + is_detectron2_available, + is_eetq_available, + is_essentia_available, + is_faiss_available, + is_fbgemm_gpu_available, + is_flash_attn_2_available, + is_flax_available, + is_fsdp_available, + is_ftfy_available, + is_g2p_en_available, + is_galore_torch_available, + is_gguf_available, + is_gptqmodel_available, + is_grokadamw_available, + is_ipex_available, + is_jieba_available, + is_jinja_available, + is_jumanpp_available, + is_keras_nlp_available, + is_levenshtein_available, + is_librosa_available, + is_liger_kernel_available, + is_lomo_available, + is_natten_available, + is_nltk_available, + is_onnx_available, + is_optimum_available, + is_optimum_quanto_available, + is_pandas_available, + is_peft_available, + is_phonemizer_available, + is_pretty_midi_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytest_available, + is_pytorch_quantization_available, + is_rjieba_available, + is_sacremoses_available, + is_safetensors_available, + is_schedulefree_available, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_soundfile_availble, + is_spacy_available, + is_sudachi_available, + is_sudachi_projection_available, + is_tensorflow_probability_available, + is_tensorflow_text_available, + is_tf2onnx_available, + is_tf_available, + is_tiktoken_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available_on_device, + is_torch_bf16_cpu_available, + is_torch_bf16_gpu_available, + is_torch_deterministic, + is_torch_fp16_available_on_device, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_sdpa_available, + is_torch_tensorrt_fx_available, + is_torch_tf32_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + is_torchaudio_available, + is_torchdynamo_available, + is_torchvision_available, + is_vision_available, + strtobool, +) + + +if is_accelerate_available(): + from accelerate.state import AcceleratorState, PartialState + from accelerate.utils.imports import is_fp8_available + + +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + 
_get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + ) + from _pytest.outcomes import skip + from _pytest.pathlib import import_path + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + + +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" +DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" +DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer" +# Used to test Auto{Config, Model, Tokenizer} model_type detection. + +# Used to test the hub +USER = "__DUMMY_TRANSFORMERS_USER__" +ENDPOINT_STAGING = "https://hub-ci.huggingface.co" + +# Not critical, only usable on the sandboxed CI instance. +TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL" + +if is_torch_available(): + import torch + + IS_ROCM_SYSTEM = torch.version.hip is not None + IS_CUDA_SYSTEM = torch.version.cuda is not None +else: + IS_ROCM_SYSTEM = False + IS_CUDA_SYSTEM = False + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +def parse_int_from_env(key, default=None): + try: + value = os.environ[key] + except KeyError: + _value = default + else: + try: + _value = int(value) + except ValueError: + raise ValueError(f"If set, {key} must be a int.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True) +_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True) +_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False) +_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False) +_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None) +_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) +_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False) +_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False) + + +def get_device_count(): + import torch + + if is_torch_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = torch.cuda.device_count() + + return num_devices + + +def is_pt_tf_cross_test(test_case): + """ + Decorator marking a test as a test that control interactions between PyTorch and TensorFlow. + + PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable + to a truthy value and selecting the is_pt_tf_cross_test pytest mark. + + """ + if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available(): + return unittest.skip(reason="test is PT+TF test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pt_tf_cross_test()(test_case) + + +def is_pt_flax_cross_test(test_case): + """ + Decorator marking a test as a test that control interactions between PyTorch and Flax + + PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment + variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark. 
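+ For example (test path is illustrative): RUN_PT_FLAX_CROSS_TESTS=1 pytest -sv ./tests -m is_pt_flax_cross_test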
+ + """ + if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available(): + return unittest.skip(reason="test is PT+FLAX test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pt_flax_cross_test()(test_case) + + +def is_staging_test(test_case): + """ + Decorator marking a test as a staging test. + + Those tests will run using the staging environment of huggingface.co instead of the real model hub. + """ + if not _run_staging: + return unittest.skip(reason="test is staging test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_staging_test()(test_case) + + +def is_pipeline_test(test_case): + """ + Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be + skipped. + """ + if not _run_pipeline_tests: + return unittest.skip(reason="test is pipeline test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pipeline_test()(test_case) + + +def is_agent_test(test_case): + """ + Decorator marking a test as an agent test. If RUN_TOOL_TESTS is set to a falsy value, those tests will be skipped. + """ + if not _run_agent_tests: + return unittest.skip(reason="test is an agent test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_agent_test()(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def tooslow(test_case): + """ + Decorator marking a test as too slow. + + Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as + these will not be tested by the CI. + + """ + return unittest.skip(reason="test is too slow")(test_case) + + +def skip_if_not_implemented(test_func): + @functools.wraps(test_func) + def wrapper(*args, **kwargs): + try: + return test_func(*args, **kwargs) + except NotImplementedError as e: + raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}") + + return wrapper + + +def apply_skip_if_not_implemented(cls): + """ + Class decorator to apply @skip_if_not_implemented to all test methods. + """ + for attr_name in dir(cls): + if attr_name.startswith("test_"): + attr = getattr(cls, attr_name) + if callable(attr): + setattr(cls, attr_name, skip_if_not_implemented(attr)) + return cls + + +def custom_tokenizers(test_case): + """ + Decorator marking a test for a custom tokenizer. + + Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS + environment variable to a truthy value to run them. + """ + return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case) + + +def require_bs4(test_case): + """ + Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed. 
+ """ + return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case) + + +def require_galore_torch(test_case): + """ + Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed. + https://github.com/jiaweizzhao/GaLore + """ + return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case) + + +def require_lomo(test_case): + """ + Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed. + https://github.com/OpenLMLab/LOMO + """ + return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case) + + +def require_grokadamw(test_case): + """ + Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed. + """ + return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case) + + +def require_schedulefree(test_case): + """ + Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed. + https://github.com/facebookresearch/schedule_free + """ + return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case) + + +def require_cv2(test_case): + """ + Decorator marking a test that requires OpenCV. + + These tests are skipped when OpenCV isn't installed. + + """ + return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case) + + +def require_levenshtein(test_case): + """ + Decorator marking a test that requires Levenshtein. + + These tests are skipped when Levenshtein isn't installed. + + """ + return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case) + + +def require_nltk(test_case): + """ + Decorator marking a test that requires NLTK. + + These tests are skipped when NLTK isn't installed. + + """ + return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case) + + +def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION): + """ + Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. + """ + return unittest.skipUnless( + is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}" + )(test_case) + + +def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION): + """ + Decorator marking a test that requires ggguf. These tests are skipped when gguf isn't installed. + """ + return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")( + test_case + ) + + +def require_fsdp(test_case, min_version: str = "1.12.0"): + """ + Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed. + """ + return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")( + test_case + ) + + +def require_g2p_en(test_case): + """ + Decorator marking a test that requires g2p_en. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case) + + +def require_safetensors(test_case): + """ + Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed. + """ + return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case) + + +def require_rjieba(test_case): + """ + Decorator marking a test that requires rjieba. 
These tests are skipped when rjieba isn't installed. + """ + return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case) + + +def require_jieba(test_case): + """ + Decorator marking a test that requires jieba. These tests are skipped when jieba isn't installed. + """ + return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case) + + +def require_jinja(test_case): + """ + Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed. + """ + return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case) + + +def require_tf2onnx(test_case): + return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case) + + +def require_onnx(test_case): + return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires Timm. + + These tests are skipped when Timm isn't installed. + + """ + return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case) + + +def require_natten(test_case): + """ + Decorator marking a test that requires NATTEN. + + These tests are skipped when NATTEN isn't installed. + + """ + return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case) + + +def require_torch(test_case): + """ + Decorator marking a test that requires PyTorch. + + These tests are skipped when PyTorch isn't installed. + + """ + return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) + + +def require_flash_attn(test_case): + """ + Decorator marking a test that requires Flash Attention. + + These tests are skipped when Flash Attention isn't installed. + + """ + return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) + + +def require_torch_sdpa(test_case): + """ + Decorator marking a test that requires PyTorch's SDPA. + + These tests are skipped when requirements are not met (torch version). + """ + return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case) + + +def require_read_token(fn): + """ + A decorator that loads the HF token for tests that require to load gated models. + """ + token = os.getenv("HF_HUB_READ_TOKEN") + + @wraps(fn) + def _inner(*args, **kwargs): + if token is not None: + with patch("huggingface_hub.utils._headers.get_token", return_value=token): + return fn(*args, **kwargs) + else: # Allow running locally with the default token env variable + return fn(*args, **kwargs) + + return _inner + + +def require_peft(test_case): + """ + Decorator marking a test that requires PEFT. + + These tests are skipped when PEFT isn't installed. + + """ + return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case) + + +def require_torchvision(test_case): + """ + Decorator marking a test that requires Torchvision. + + These tests are skipped when Torchvision isn't installed. + + """ + return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case) + + +def require_torch_or_tf(test_case): + """ + Decorator marking a test that requires PyTorch or TensorFlow. + + These tests are skipped when neither PyTorch not TensorFlow is installed. + + """ + return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")( + test_case + ) + + +def require_intel_extension_for_pytorch(test_case): + """ + Decorator marking a test that requires Intel Extension for PyTorch. 
+ + These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch + version. + + """ + return unittest.skipUnless( + is_ipex_available(), + "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see" + " https://github.com/intel/intel-extension-for-pytorch", + )(test_case) + + +def require_tensorflow_probability(test_case): + """ + Decorator marking a test that requires TensorFlow probability. + + These tests are skipped when TensorFlow probability isn't installed. + + """ + return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")( + test_case + ) + + +def require_torchaudio(test_case): + """ + Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed. + """ + return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case) + + +def require_tf(test_case): + """ + Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed. + """ + return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case) + + +def require_flax(test_case): + """ + Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed + """ + return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) + + +def require_sentencepiece(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case) + + +def require_sacremoses(test_case): + """ + Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed. + """ + return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case) + + +def require_seqio(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case) + + +def require_scipy(test_case): + """ + Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case) + + +def require_tokenizers(test_case): + """ + Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. + """ + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) + + +def require_tensorflow_text(test_case): + """ + Decorator marking a test that requires tensorflow_text. These tests are skipped when tensroflow_text isn't + installed. + """ + return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case) + + +def require_keras_nlp(test_case): + """ + Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed. + """ + return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed. 
+ """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + +def require_pytesseract(test_case): + """ + Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed. + """ + return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case) + + +def require_pytorch_quantization(test_case): + """ + Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch + Quantization Toolkit isn't installed. + """ + return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")( + test_case + ) + + +def require_vision(test_case): + """ + Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't + installed. + """ + return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case) + + +def require_ftfy(test_case): + """ + Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. + """ + return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case) + + +def require_spacy(test_case): + """ + Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed. + """ + return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case) + + +def require_torch_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without + multiple GPUs. + + To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + device_count = get_device_count() + + return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case) + + +def require_torch_multi_accelerator(test_case): + """ + Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine + without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain + multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator" + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")( + test_case + ) + + +def require_torch_non_multi_gpu(test_case): + """ + Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case) + + +def require_torch_non_multi_accelerator(test_case): + """ + Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case) + + +def require_torch_up_to_2_gpus(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch). 
+ """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + import torch + + return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case) + + +def require_torch_up_to_2_accelerators(test_case): + """ + Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch). + """ + if not is_torch_available(): + return unittest.skip(reason="test requires PyTorch")(test_case) + + return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")( + test_case + ) + + +def require_torch_xla(test_case): + """ + Decorator marking a test that requires TorchXLA (in PyTorch). + """ + return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case) + + +def require_torch_neuroncore(test_case): + """ + Decorator marking a test that requires NeuronCore (in PyTorch). + """ + return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")( + test_case + ) + + +def require_torch_npu(test_case): + """ + Decorator marking a test that requires NPU (in PyTorch). + """ + return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case) + + +def require_torch_multi_npu(test_case): + """ + Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without + multiple NPUs. + + To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu" + """ + if not is_torch_npu_available(): + return unittest.skip(reason="test requires PyTorch NPU")(test_case) + + return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case) + + +def require_torch_xpu(test_case): + """ + Decorator marking a test that requires XPU (in PyTorch). + + These tests are skipped when XPU backend is not available. XPU backend might be available either via stock + PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version + must match match current PyTorch version. + """ + return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case) + + +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_torch_multi_xpu(test_case): + """ + Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without + multiple XPUs. + + To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu" + """ + if not is_torch_xpu_available(): + return unittest.skip(reason="test requires PyTorch XPU")(test_case) + + return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) + + +if is_torch_available(): + # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode + import torch + + if "TRANSFORMERS_TEST_BACKEND" in os.environ: + backend = os.environ["TRANSFORMERS_TEST_BACKEND"] + try: + _ = importlib.import_module(backend) + except ModuleNotFoundError as e: + raise ModuleNotFoundError( + f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. 
The original error (look up to see its" + f" traceback):\n{e}" + ) from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ: + torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"] + if torch_device == "cuda" and not torch.cuda.is_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment." + ) + if torch_device == "xpu" and not is_torch_xpu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment." + ) + if torch_device == "npu" and not is_torch_npu_available(): + raise ValueError( + f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment." + ) + + try: + # try creating device to see if provided device is valid + _ = torch.device(torch_device) + except RuntimeError as e: + raise RuntimeError( + f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}" + ) from e + elif torch.cuda.is_available(): + torch_device = "cuda" + elif _run_third_party_device_tests and is_torch_npu_available(): + torch_device = "npu" + elif _run_third_party_device_tests and is_torch_xpu_available(): + torch_device = "xpu" + else: + torch_device = "cpu" +else: + torch_device = None + +if is_tf_available(): + import tensorflow as tf + +if is_flax_available(): + import jax + + jax_device = jax.default_backend() +else: + jax_device = None + + +def require_torchdynamo(test_case): + """Decorator marking a test that requires TorchDynamo""" + return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case) + + +def require_torchao(test_case): + """Decorator marking a test that requires torchao""" + return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case) + + +def require_torch_tensorrt_fx(test_case): + """Decorator marking a test that requires Torch-TensorRT FX""" + return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case) + + +def require_torch_gpu(test_case): + """Decorator marking a test that requires CUDA and PyTorch.""" + return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case) + + +def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case): + """ + Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled. 
+ """ + if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available(): + return test_case + return require_torch_gpu(test_case) + + +def require_torch_accelerator(test_case): + """Decorator marking a test that requires an accessible accelerator and PyTorch.""" + return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")( + test_case + ) + + +def require_torch_fp16(test_case): + """Decorator marking a test that requires a device that supports fp16""" + return unittest.skipUnless( + is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support" + )(test_case) + + +def require_fp8(test_case): + """Decorator marking a test that requires supports for fp8""" + return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")( + test_case + ) + + +def require_torch_bf16(test_case): + """Decorator marking a test that requires a device that supports bf16""" + return unittest.skipUnless( + is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support" + )(test_case) + + +def require_torch_bf16_gpu(test_case): + """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0""" + return unittest.skipUnless( + is_torch_bf16_gpu_available(), + "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0", + )(test_case) + + +def require_torch_bf16_cpu(test_case): + """Decorator marking a test that requires torch>=1.10, using CPU.""" + return unittest.skipUnless( + is_torch_bf16_cpu_available(), + "test requires torch>=1.10, using CPU", + )(test_case) + + +def require_deterministic_for_xpu(test_case): + if is_torch_xpu_available(): + return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")( + test_case + ) + else: + return test_case + + +def require_torch_tf32(test_case): + """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.""" + return unittest.skipUnless( + is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7" + )(test_case) + + +def require_detectron2(test_case): + """Decorator marking a test that requires detectron2.""" + return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case) + + +def require_faiss(test_case): + """Decorator marking a test that requires faiss.""" + return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case) + + +def require_optuna(test_case): + """ + Decorator marking a test that requires optuna. + + These tests are skipped when optuna isn't installed. + + """ + return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case) + + +def require_ray(test_case): + """ + Decorator marking a test that requires Ray/tune. + + These tests are skipped when Ray/tune isn't installed. + + """ + return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case) + + +def require_sigopt(test_case): + """ + Decorator marking a test that requires SigOpt. + + These tests are skipped when SigOpt isn't installed. + + """ + return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case) + + +def require_wandb(test_case): + """ + Decorator marking a test that requires wandb. + + These tests are skipped when wandb isn't installed. 
+ + """ + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) + + +def require_clearml(test_case): + """ + Decorator marking a test requires clearml. + + These tests are skipped when clearml isn't installed. + + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_soundfile(test_case): + """ + Decorator marking a test that requires soundfile + + These tests are skipped when soundfile isn't installed. + + """ + return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case) + + +def require_deepspeed(test_case): + """ + Decorator marking a test that requires deepspeed + """ + return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case) + + +def require_apex(test_case): + """ + Decorator marking a test that requires apex + """ + return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case) + + +def require_aqlm(test_case): + """ + Decorator marking a test that requires aqlm + """ + return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case) + + +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq + """ + eetq_available = is_eetq_available() + if eetq_available: + try: + import eetq # noqa: F401 + except ImportError as exc: + if "shard_checkpoint" in str(exc): + # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed + # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34. + # TODO: Remove once eetq releases a fix and this release is used in CI + eetq_available = False + return unittest.skipUnless(eetq_available, "test requires eetq")(test_case) + + +def require_av(test_case): + """ + Decorator marking a test that requires av + """ + return unittest.skipUnless(is_av_available(), "test requires av")(test_case) + + +def require_bitsandbytes(test_case): + """ + Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed. 
+ """ + if is_bitsandbytes_available() and is_torch_available(): + try: + import pytest + + return pytest.mark.bitsandbytes(test_case) + except ImportError: + return test_case + else: + return unittest.skip(reason="test requires bitsandbytes and torch")(test_case) + + +def require_optimum(test_case): + """ + Decorator for optimum dependency + """ + return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case) + + +def require_tensorboard(test_case): + """ + Decorator for `tensorboard` dependency + """ + return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard") + + +def require_gptq(test_case): + """ + Decorator for auto_gptq dependency + """ + return unittest.skipUnless( + is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq" + )(test_case) + + +def require_auto_awq(test_case): + """ + Decorator for auto_awq dependency + """ + return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case) + + +def require_optimum_quanto(test_case): + """ + Decorator for quanto dependency + """ + return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case) + + +def require_compressed_tensors(test_case): + """ + Decorator for compressed_tensors dependency + """ + return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case) + + +def require_fbgemm_gpu(test_case): + """ + Decorator for fbgemm_gpu dependency + """ + return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case) + + +def require_phonemizer(test_case): + """ + Decorator marking a test that requires phonemizer + """ + return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case) + + +def require_pyctcdecode(test_case): + """ + Decorator marking a test that requires pyctcdecode + """ + return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case) + + +def require_librosa(test_case): + """ + Decorator marking a test that requires librosa + """ + return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) + + +def require_liger_kernel(test_case): + """ + Decorator marking a test that requires liger_kernel + """ + return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case) + + +def require_essentia(test_case): + """ + Decorator marking a test that requires essentia + """ + return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case) + + +def require_pretty_midi(test_case): + """ + Decorator marking a test that requires pretty_midi + """ + return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case) + + +def cmd_exists(cmd): + return shutil.which(cmd) is not None + + +def require_usr_bin_time(test_case): + """ + Decorator marking a test that requires `/usr/bin/time` + """ + return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case) + + +def require_sudachi(test_case): + """ + Decorator marking a test that requires sudachi + """ + return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case) + + +def require_sudachi_projection(test_case): + """ + Decorator marking a test that requires sudachi_projection + """ + return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")( + test_case + ) + + +def require_jumanpp(test_case): 
+ """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case) + + +def require_cython(test_case): + """ + Decorator marking a test that requires jumanpp + """ + return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case) + + +def require_tiktoken(test_case): + """ + Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed. + """ + return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case) + + +def get_gpu_count(): + """ + Return the number of available gpus (regardless of whether torch, tf or jax is used) + """ + if is_torch_available(): + import torch + + return torch.cuda.device_count() + elif is_tf_available(): + import tensorflow as tf + + return len(tf.config.list_physical_devices("GPU")) + elif is_flax_available(): + import jax + + return jax.device_count() + else: + return 0 + + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + else: + return tests_dir + + +# +# Helper functions for dealing with testing text outputs +# The original code came from: +# https://github.com/fastai/fastai/blob/master/tests/utils/text.py + + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + return re.sub(r"^.*\r", "", buf, 0, re.M) + + +def assert_screenout(out, what): + out_pr = apply_print_resets(out).lower() + match_str = out_pr.find(what.lower()) + assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + + +class CaptureStd: + """ + Context manager to capture: + + - stdout: replay it, clean it up and make it available via `obj.out` + - stderr: replay it and make it available via `obj.err` + + Args: + out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not. + err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not. + replay (`bool`, *optional*, defaults to `True`): Whether to replay or not. + By default each captured stream gets replayed back on context's exit, so that one can see what the test was + doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to + disable this feature. 
+ + Examples: + + ```python + # to capture stdout only with auto-replay + with CaptureStdout() as cs: + print("Secret message") + assert "message" in cs.out + + # to capture stderr only with auto-replay + import sys + + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + assert "Warning" in cs.err + + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay + with CaptureStd(err=False) as cs: + print("Secret message") + assert "message" in cs.out + # but best use the stream-specific subclasses + + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + ```""" + + def __init__(self, out=True, err=True, replay=True): + self.replay = replay + + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + def __enter__(self): + if self.out_buf: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + return self + + def __exit__(self, *exc): + if self.out_buf: + sys.stdout = self.out_old + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) + + if self.err_buf: + sys.stderr = self.err_old + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured + + def __repr__(self): + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """Same as CaptureStd but captures only stdout""" + + def __init__(self, replay=True): + super().__init__(err=False, replay=replay) + + +class CaptureStderr(CaptureStd): + """Same as CaptureStd but captures only stderr""" + + def __init__(self, replay=True): + super().__init__(out=False, replay=replay) + + +class CaptureLogger: + """ + Context manager to capture `logging` streams + + Args: + logger: 'logging` logger object + + Returns: + The captured output is available via `self.out` + + Example: + + ```python + >>> from transformers import logging + >>> from transformers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart") + >>> with CaptureLogger(logger) as cl: + ... 
logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + + def __init__(self, logger): + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + return f"captured: {self.out}\n" + + +@contextlib.contextmanager +def LoggingLevel(level): + """ + This is a context manager to temporarily change transformers modules logging level to the desired value and have it + restored to the original setting at the end of the scope. + + Example: + + ```python + with LoggingLevel(logging.INFO): + AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times + ``` + """ + orig_level = transformers_logging.get_verbosity() + try: + transformers_logging.set_verbosity(level) + yield + finally: + transformers_logging.set_verbosity(orig_level) + + +class TemporaryHubRepo: + """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to + `tempfile.TemporaryDirectory` and can be used as a context manager. For example: + + with TemporaryHubRepo(token=self._token) as temp_repo: + ... + + Upon exiting the context, the repository and everything contained in it are removed. + + Example: + + ```python + with TemporaryHubRepo(token=self._token) as temp_repo: + model.push_to_hub(tmp_repo.repo_id, token=self._token) + ``` + """ + + def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None: + self.token = token + with tempfile.TemporaryDirectory() as tmp_dir: + repo_id = Path(tmp_dir).name + if namespace is not None: + repo_id = f"{namespace}/{repo_id}" + self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token) + + def __enter__(self): + return self.repo_url + + def __exit__(self, exc, value, tb): + delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True) + + +@contextlib.contextmanager +# adapted from https://stackoverflow.com/a/64789046/9201239 +def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]: + """ + Temporary add given path to `sys.path`. + + Usage : + + ```python + with ExtendSysPath("/path/to/dir"): + mymodule = importlib.import_module("mymodule") + ``` + """ + + path = os.fspath(path) + try: + sys.path.insert(0, path) + yield + finally: + sys.path.remove(path) + + +class TestCasePlus(unittest.TestCase): + """ + This class extends *unittest.TestCase* with additional features. + + Feature 1: A set of fully resolved important file and dir path accessors. + + In tests often we need to know where things are relative to the current test file, and it's not trivial since the + test could be invoked from more than one directory or could reside in sub-directories with different depths. This + class solves this problem by sorting out all the basic paths and provides easy accessors to them: + + - `pathlib` objects (all fully resolved): + + - `test_file_path` - the current test file path (=`__file__`) + - `test_file_dir` - the directory containing the current test file + - `tests_dir` - the directory of the `tests` test suite + - `examples_dir` - the directory of the `examples` test suite + - `repo_root_dir` - the directory of the repository + - `src_dir` - the directory of `src` (i.e. 
where the `transformers` sub-dir resides) + + - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects: + + - `test_file_path_str` + - `test_file_dir_str` + - `tests_dir_str` + - `examples_dir_str` + - `repo_root_dir_str` + - `src_dir_str` + + Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test. + + 1. Create a unique temporary dir: + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir() + ``` + + `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the + test. + + + 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't + empty it after the test. + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir("./xxx") + ``` + + This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests + didn't leave any data in there. + + 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the + following behavior: + + `before=True`: the temporary dir will always be cleared at the beginning of the test. + + `before=False`: if the temporary dir already existed, any existing files will remain there. + + `after=True`: the temporary dir will always be deleted at the end of the test. + + `after=False`: the temporary dir will always be left intact at the end of the test. + + Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are + allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem + will get nuked. i.e. please always pass paths that start with `./` + + Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested + otherwise. + + Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This + is useful for invoking external programs from the test suite - e.g. distributed training. + + + ```python + def test_whatever(self): + env = self.get_env() + ```""" + + def setUp(self): + # get_auto_remove_tmp_dir feature: + self.teardown_tmp_dirs = [] + + # figure out the resolved paths for repo_root, tests, examples, etc. 
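+ # the walk below treats the first ancestor containing both a `src/` and a `tests/` directory as the repo root (standard transformers layout)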
+ self._test_file_path = inspect.getfile(self.__class__) + path = Path(self._test_file_path).resolve() + self._test_file_dir = path.parents[0] + for up in [1, 2, 3]: + tmp_dir = path.parents[up] + if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir(): + break + if tmp_dir: + self._repo_root_dir = tmp_dir + else: + raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}") + self._tests_dir = self._repo_root_dir / "tests" + self._examples_dir = self._repo_root_dir / "examples" + self._src_dir = self._repo_root_dir / "src" + + @property + def test_file_path(self): + return self._test_file_path + + @property + def test_file_path_str(self): + return str(self._test_file_path) + + @property + def test_file_dir(self): + return self._test_file_dir + + @property + def test_file_dir_str(self): + return str(self._test_file_dir) + + @property + def tests_dir(self): + return self._tests_dir + + @property + def tests_dir_str(self): + return str(self._tests_dir) + + @property + def examples_dir(self): + return self._examples_dir + + @property + def examples_dir_str(self): + return str(self._examples_dir) + + @property + def repo_root_dir(self): + return self._repo_root_dir + + @property + def repo_root_dir_str(self): + return str(self._repo_root_dir) + + @property + def src_dir(self): + return self._src_dir + + @property + def src_dir_str(self): + return str(self._src_dir) + + def get_env(self): + """ + Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's + invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training. + + It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally + the preset `PYTHONPATH` if any (all full resolved paths). + + """ + env = os.environ.copy() + paths = [self.src_dir_str] + if "/examples" in self.test_file_dir_str: + paths.append(self.examples_dir_str) + else: + paths.append(self.tests_dir_str) + paths.append(env.get("PYTHONPATH", "")) + + env["PYTHONPATH"] = ":".join(paths) + return env + + def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): + """ + Args: + tmp_dir (`string`, *optional*): + if `None`: + + - a unique temporary path will be created + - sets `before=True` if `before` is `None` + - sets `after=True` if `after` is `None` + else: + + - `tmp_dir` will be created + - sets `before=True` if `before` is `None` + - sets `after=False` if `after` is `None` + before (`bool`, *optional*): + If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the + `tmp_dir` already exists, any existing files will remain there. + after (`bool`, *optional*): + If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents + intact at the end of the test. + + Returns: + tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir + """ + if tmp_dir is not None: + # defining the most likely desired behavior for when a custom path is provided. + # this most likely indicates the debug mode where we want an easily locatable dir that: + # 1. gets cleared out before the test (if it already exists) + # 2. 
is left intact after the test + if before is None: + before = True + if after is None: + after = False + + # using provided path + path = Path(tmp_dir).resolve() + + # to avoid nuking parts of the filesystem, only relative paths are allowed + if not tmp_dir.startswith("./"): + raise ValueError( + f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`" + ) + + # ensure the dir is empty to start with + if before is True and path.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + path.mkdir(parents=True, exist_ok=True) + + else: + # defining the most likely desired behavior for when a unique tmp path is auto generated + # (not a debug mode), here we require a unique tmp dir that: + # 1. is empty before the test (it will be empty in this situation anyway) + # 2. gets fully removed after the test + if before is None: + before = True + if after is None: + after = True + + # using unique tmp dir (always empty, regardless of `before`) + tmp_dir = tempfile.mkdtemp() + + if after is True: + # register for deletion + self.teardown_tmp_dirs.append(tmp_dir) + + return tmp_dir + + def python_one_liner_max_rss(self, one_liner_str): + """ + Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the + program. + + Args: + one_liner_str (`string`): + a python one liner code that gets passed to `python -c` + + Returns: + max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run. + + Requirements: + this helper needs `/usr/bin/time` to be installed (`apt install time`) + + Example: + + ``` + one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")' + max_rss = self.python_one_liner_max_rss(one_liner_str) + ``` + """ + + if not cmd_exists("/usr/bin/time"): + raise ValueError("/usr/bin/time is required, install with `apt install time`") + + cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'") + with CaptureStd() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # returned data is in KB so convert to bytes + max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024 + return max_rss + + def tearDown(self): + # get_auto_remove_tmp_dir feature: remove registered temp dirs + for path in self.teardown_tmp_dirs: + shutil.rmtree(path, ignore_errors=True) + self.teardown_tmp_dirs = [] + if is_accelerate_available(): + AcceleratorState._reset_state() + PartialState._reset_state() + + # delete all the env variables having `ACCELERATE` in them + for k in list(os.environ.keys()): + if "ACCELERATE" in k: + del os.environ[k] + + +def mockenv(**kwargs): + """ + this is a convenience wrapper, that allows this :: + + @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): + run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False) + + """ + return mock.patch.dict(os.environ, kwargs) + + +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mockenv_context(*remove, **update): + """ + Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv + + The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations. + + Args: + remove: Environment variables to remove. + update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. 
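+    # The bookkeeping collections computed below let the `finally` block restore the
+    # environment exactly as it was: overwritten/removed values are put back and newly
+    # added keys are popped. A hypothetical usage sketch (variable names are illustrative only):
+    #
+    #     with mockenv_context("CUDA_VISIBLE_DEVICES", RUN_SLOW="1"):
+    #         ...  # code here sees RUN_SLOW=1 and no CUDA_VISIBLE_DEVICES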
+ stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + [env.pop(k, None) for k in remove] + yield + finally: + env.update(update_after) + [env.pop(k) for k in remove_after] + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, id): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal + changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-` + plugins and interfere. 
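+    Example (a sketch of the `conftest.py` wrapper this function expects; the exact wiring may differ):
+
+    ```python
+    def pytest_terminal_summary(terminalreporter):
+        make_reports = terminalreporter.config.getoption("--make-reports")
+        if make_reports:
+            pytest_terminal_summary_main(terminalreporter, id=make_reports)
+    ```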
+ + """ + from _pytest.config import create_terminal_writer + + if not len(id): + id = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dir = f"reports/{id}" + Path(dir).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dir}/{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + + # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it + # takes > 10 minutes (as this part doesn't generate any output on the terminal). 
+ # (also, it seems there is no useful information in this report, and we rarely need to read it) + # with open(report_files["passes"], "w") as f: + # tr._tw = create_terminal_writer(config, f) + # tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# --- distributed testing functions --- # + +# adapted from https://stackoverflow.com/a/59041913/9201239 +import asyncio # noqa + + +class _RunOutput: + def __init__(self, returncode, stdout, stderr): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. + # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + # XXX: the timeout doesn't seem to make any difference here + await asyncio.wait( + [ + _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")), + _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + loop = asyncio.get_event_loop() + result = loop.run_until_complete( + _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) + ) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + # check that the subprocess actually did run and produced some output, should the test rely on + # the remote side to do the testing + if not result.stdout and not result.stderr: + raise RuntimeError(f"'{cmd_str}' produced no output.") + + return result + + +def pytest_xdist_worker_id(): + """ + Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0 + if `-n 1` or `pytest-xdist` isn't being used. + """ + worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0") + worker = re.sub(r"^gw", "", worker, 0, re.M) + return int(worker) + + +def get_torch_dist_unique_port(): + """ + Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument. 
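+    For example (illustrative only; the script name and launcher flags are placeholders):
+
+    ```python
+    port = get_torch_dist_unique_port()
+    cmd = f"torchrun --nproc_per_node=2 --master_port={port} my_dist_test_script.py".split()
+    execute_subprocess_async(cmd, env=self.get_env())
+    ```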
+ + Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same + port at once. + """ + port = 29500 + uniq_delta = pytest_xdist_worker_id() + return port + uniq_delta + + +def nested_simplify(obj, decimals=3): + """ + Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test + within tests. + """ + import numpy as np + + if isinstance(obj, list): + return [nested_simplify(item, decimals) for item in obj] + if isinstance(obj, tuple): + return tuple([nested_simplify(item, decimals) for item in obj]) + elif isinstance(obj, np.ndarray): + return nested_simplify(obj.tolist()) + elif isinstance(obj, Mapping): + return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} + elif isinstance(obj, (str, int, np.int64)): + return obj + elif obj is None: + return obj + elif is_torch_available() and isinstance(obj, torch.Tensor): + return nested_simplify(obj.tolist(), decimals) + elif is_tf_available() and tf.is_tensor(obj): + return nested_simplify(obj.numpy().tolist()) + elif isinstance(obj, float): + return round(obj, decimals) + elif isinstance(obj, (np.int32, np.float32, np.float16)): + return nested_simplify(obj.item(), decimals) + else: + raise Exception(f"Not supported: {type(obj)}") + + +def check_json_file_has_correct_format(file_path): + with open(file_path, "r") as f: + lines = f.readlines() + if len(lines) == 1: + # length can only be 1 if dict is empty + assert lines[0] == "{}" + else: + # otherwise make sure json has correct format (at least 3 lines) + assert len(lines) >= 3 + # each key one line, ident should be 2, min length is 3 + assert lines[0].strip() == "{" + for line in lines[1:-1]: + left_indent = len(lines[1]) - len(lines[1].lstrip()) + assert left_indent == 2 + assert lines[-1].strip() == "}" + + +def to_2tuple(x): + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# These utils relate to ensuring the right error message is received when running scripts +class SubprocessCallException(Exception): + pass + + +def run_command(command: List[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occured while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +class RequestCounter: + """ + Helper class that will count all requests made online. + + Might not be robust if urllib3 changes its logging format but should be good enough for us. 
+ + Usage: + ```py + with RequestCounter() as counter: + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + assert counter["GET"] == 0 + assert counter["HEAD"] == 1 + assert counter.total_calls == 1 + ``` + """ + + def __enter__(self): + self._counter = defaultdict(int) + self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) + self.mock = self.patcher.start() + return self + + def __exit__(self, *args, **kwargs) -> None: + for call in self.mock.call_args_list: + log = call.args[0] % call.args[1:] + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): + if method in log: + self._counter[method] += 1 + break + self.patcher.stop() + + def __getitem__(self, key: str) -> int: + return self._counter[key] + + @property + def total_calls(self) -> int: + return sum(self._counter.values()) + + +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. 
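+    # By convention (see the `results["error"]` check below), `target_func` reads its inputs
+    # from `input_queue` and puts a dict with at least an "error" key (None on success, an
+    # error message otherwise) onto `output_queue`. A minimal, hypothetical sketch:
+    #
+    #     def target_func(in_queue, out_queue, timeout):
+    #         error = None
+    #         try:
+    #             inputs = in_queue.get(timeout=timeout)
+    #             ...  # actual assertions / testing logic
+    #         except Exception as e:
+    #             error = f"{e}"
+    #         out_queue.put({"error": error}, timeout=timeout)
+    #         out_queue.join()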
+ try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + +def run_test_using_subprocess(func): + """ + To decorate a test to run in a subprocess using the `subprocess` module. This could avoid potential GPU memory + issues (GPU OOM or a test that causes many subsequential failing with `CUDA error: device-side assert triggered`). + """ + import pytest + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if os.getenv("_INSIDE_SUB_PROCESS", None) == "1": + func(*args, **kwargs) + else: + test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1]) + try: + import copy + + env = copy.deepcopy(os.environ) + env["_INSIDE_SUB_PROCESS"] = "1" + # This prevents the entries in `short test summary info` given by the subprocess being truncated. so the + # full information can be passed to the parent pytest process. + # See: https://docs.pytest.org/en/stable/explanation/ci.html + env["CI"] = "true" + + # If not subclass of `unitTest.TestCase` and `pytestconfig` is used: try to grab and use the arguments + if "pytestconfig" in kwargs: + command = list(kwargs["pytestconfig"].invocation_params.args) + for idx, x in enumerate(command): + if x in kwargs["pytestconfig"].args: + test = test.split("::")[1:] + command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test) + command = [f"{sys.executable}", "-m", "pytest"] + command + command = [x for x in command if x not in ["--no-summary"]] + # Otherwise, simply run the test with no option at all + else: + command = [f"{sys.executable}", "-m", "pytest", f"{test}"] + + subprocess.run(command, env=env, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + exception_message = e.stdout.decode() + lines = exception_message.split("\n") + # Add a first line with more informative information instead of just `= test session starts =`. + # This makes the `short test summary info` section more useful. + if "= test session starts =" in lines[0]: + text = "" + for line in lines[1:]: + if line.startswith("FAILED "): + text = line[len("FAILED ") :] + text = "".join(text.split(" - ")[1:]) + elif line.startswith("=") and line.endswith("=") and " failed in " in line: + break + elif len(text) > 0: + text += f"\n{line}" + text = "(subprocess) " + text + lines = [text] + lines + exception_message = "\n".join(lines) + raise pytest.fail(exception_message, pytrace=False) + + return wrapper + + +""" +The following contains utils to run the documentation tests without having to overwrite any files. + +The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +made as a print would otherwise fail the corresonding line. + +To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules +""" + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.md` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of + its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a + cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for + `string`. 
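+    For example, a doctest line such as `>>> ds = load_dataset("c4", "en")` gets
+    `# doctest: +IGNORE_RESULT` appended, and the whole string is reduced to `""` when
+    `skip_cuda_tests` is `True` and a code block matches `cuda`, `to(0)` or `device=0`.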
+ """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)" + codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough. + """ + + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False)) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + + def parse(self, string, name=""): + """ + Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before + calling `super().parse` + """ + string = preprocess_string(string, self.skip_cuda_tests) + return super().parse(string, name) + + +class HfDoctestModule(Module): + """ + Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering + tests. + """ + + def collect(self) -> Iterable[DoctestItem]: + class MockAwareDocTestFinder(doctest.DocTestFinder): + """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug. + + https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532 + """ + + def _find_lineno(self, obj, source_lines): + """Doctest code does not take into account `@property`, this + is a hackish way to fix it. https://bugs.python.org/issue17446 + + Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be + reported upstream. 
#8796 + """ + if isinstance(obj, property): + obj = getattr(obj, "fget", obj) + + if hasattr(obj, "__wrapped__"): + # Get the main obj in case of it being wrapped + obj = inspect.unwrap(obj) + + # Type ignored because this is a private function. + return super()._find_lineno( # type:ignore[misc] + obj, + source_lines, + ) + + def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: + if _is_mocked(obj): + return + with _patch_unwrap_mock_aware(): + # Type ignored because this is a private function. + super()._find( # type:ignore[misc] + tests, obj, name, module, source_lines, globs, seen + ) + + if self.path.name == "conftest.py": + module = self.config.pluginmanager._importconftest( + self.path, + self.config.getoption("importmode"), + rootpath=self.config.rootpath, + ) + else: + try: + module = import_path( + self.path, + root=self.config.rootpath, + mode=self.config.getoption("importmode"), + ) + except ImportError: + if self.config.getvalue("doctest_ignore_import_errors"): + skip("unable to import module %r" % self.path) + else: + raise + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + finder = MockAwareDocTestFinder(parser=HfDocTestParser()) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + optionflags = get_optionflags(self) + runner = _get_runner( + verbose=False, + optionflags=optionflags, + checker=_get_checker(), + continue_on_failure=_get_continue_on_failure(self.config), + ) + for test in finder.find(module, module.__name__): + if test.examples: # skip empty doctests and cuda + yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) + + +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against `None` + # instead at user level. + if fn is None: + return None + return fn(*args, **kwargs) + + +if is_torch_available(): + # Mappings from device names to callable functions to support device agnostic + # testing. + BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed} + BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "default": None} + BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "default": lambda: 1} +else: + BACKEND_MANUAL_SEED = {"default": None} + BACKEND_EMPTY_CACHE = {"default": None} + BACKEND_DEVICE_COUNT = {"default": lambda: 0} + + +def backend_manual_seed(device: str, seed: int): + return _device_agnostic_dispatch(device, BACKEND_MANUAL_SEED, seed) + + +def backend_empty_cache(device: str): + return _device_agnostic_dispatch(device, BACKEND_EMPTY_CACHE) + + +def backend_device_count(device: str): + return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT) + + +if is_torch_available(): + # If `TRANSFORMERS_TEST_DEVICE_SPEC` is enabled we need to import extra entries + # into device to function mappings. + if "TRANSFORMERS_TEST_DEVICE_SPEC" in os.environ: + device_spec_path = os.environ["TRANSFORMERS_TEST_DEVICE_SPEC"] + if not Path(device_spec_path).is_file(): + raise ValueError( + f"Specified path to device spec file is not a file or not found. Received '{device_spec_path}" + ) + + # Try to strip extension for later import – also verifies we are importing a + # python file. 
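+        # (For reference, a hypothetical spec file such as `my_device_spec.py` might contain:
+        #
+        #     import torch
+        #
+        #     DEVICE_NAME = "cpu"
+        #     MANUAL_SEED_FN = torch.manual_seed
+        #     EMPTY_CACHE_FN = lambda: None
+        #     DEVICE_COUNT_FN = lambda: 1
+        #
+        # and is selected via TRANSFORMERS_TEST_DEVICE_SPEC=my_device_spec.py; the *_FN
+        # attributes feed the `update_mapping_from_spec` calls further below.)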
+ try: + import_name = device_spec_path[: device_spec_path.index(".py")] + except ValueError as e: + raise ValueError(f"Provided device spec file was not a Python file! Received '{device_spec_path}") from e + + device_spec_module = importlib.import_module(import_name) + + # Imported file must contain `DEVICE_NAME`. If it doesn't, terminate early. + try: + device_name = device_spec_module.DEVICE_NAME + except AttributeError as e: + raise AttributeError("Device spec file did not contain `DEVICE_NAME`") from e + + if "TRANSFORMERS_TEST_DEVICE" in os.environ and torch_device != device_name: + msg = f"Mismatch between environment variable `TRANSFORMERS_TEST_DEVICE` '{torch_device}' and device found in spec '{device_name}'\n" + msg += "Either unset `TRANSFORMERS_TEST_DEVICE` or ensure it matches device spec name." + raise ValueError(msg) + + torch_device = device_name + + def update_mapping_from_spec(device_fn_dict: Dict[str, Callable], attribute_name: str): + try: + # Try to import the function directly + spec_fn = getattr(device_spec_module, attribute_name) + device_fn_dict[torch_device] = spec_fn + except AttributeError as e: + # If the function doesn't exist, and there is no default, throw an error + if "default" not in device_fn_dict: + raise AttributeError( + f"`{attribute_name}` not found in '{device_spec_path}' and no default fallback function found." + ) from e + + # Add one entry here for each `BACKEND_*` dictionary. + update_mapping_from_spec(BACKEND_MANUAL_SEED, "MANUAL_SEED_FN") + update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN") + update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN") + + +def compare_pipeline_output_to_hub_spec(output, hub_spec): + missing_keys = [] + unexpected_keys = [] + all_field_names = {field.name for field in fields(hub_spec)} + matching_keys = sorted([key for key in output.keys() if key in all_field_names]) + + # Fields with a MISSING default are required and must be in the output + for field in fields(hub_spec): + if field.default is MISSING and field.name not in output: + missing_keys.append(field.name) + + # All output keys must match either a required or optional field in the Hub spec + for output_key in output: + if output_key not in all_field_names: + unexpected_keys.append(output_key) + + if missing_keys or unexpected_keys: + error = ["Pipeline output does not match Hub spec!"] + if matching_keys: + error.append(f"Matching keys: {matching_keys}") + if missing_keys: + error.append(f"Missing required keys in pipeline output: {missing_keys}") + if unexpected_keys: + error.append(f"Keys in pipeline output that are not in Hub spec: {unexpected_keys}") + raise KeyError("\n".join(error)) + + +@require_torch +def cleanup(device: str, gc_collect=False): + if gc_collect: + gc.collect() + backend_empty_cache(device) \ No newline at end of file diff --git a/gptqmodel/integration/transformers/utils/import_utils.py b/gptqmodel/integration/transformers/utils/import_utils.py new file mode 100644 index 000000000..506ebafa4 --- /dev/null +++ b/gptqmodel/integration/transformers/utils/import_utils.py @@ -0,0 +1,2218 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Import utilities: Utilities related to imports and our lazy inits. +""" + +import importlib.machinery +import importlib.metadata +import importlib.util +import json +import os +import shutil +import subprocess +import sys +import warnings +from collections import OrderedDict +from functools import lru_cache +from itertools import chain +from types import ModuleType +from typing import Any, Dict, FrozenSet, Optional, Set, Tuple, Union + +from packaging import version + +from . import logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +# TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check if the package spec exists and grab its version to avoid importing a local directory + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method: Only for "torch" and versions containing "dev" + if pkg_name == "torch": + try: + package = importlib.import_module(pkg_name) + temp_version = getattr(package, "__version__", "N/A") + # Check if the version contains "dev" + if "dev" in temp_version: + package_version = temp_version + package_exists = True + else: + package_exists = False + except ImportError: + # If the package can't be imported, it's not available + package_exists = False + else: + # For packages other than "torch", don't attempt the fallback and set as not available + package_exists = False + logger.debug(f"Detected {pkg_name} version: {package_version}") + if return_version: + return package_exists, package_version + else: + return package_exists + + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() +USE_JAX = os.environ.get("USE_FLAX", "AUTO").upper() + +# Try to run a native pytorch job in an environment with TorchXLA installed by setting this value to 0. +USE_TORCH_XLA = os.environ.get("USE_TORCH_XLA", "1").upper() + +FORCE_TF_AVAILABLE = os.environ.get("FORCE_TF_AVAILABLE", "AUTO").upper() + +# `transformers` requires `torch>=1.11` but this variable is exposed publicly, and we can't simply remove it. +# This is the version of torch required to run torch.fx features and torch.onnx with dictionary inputs. 
+TORCH_FX_REQUIRED_VERSION = version.parse("1.10") + +ACCELERATE_MIN_VERSION = "0.26.0" +FSDP_MIN_VERSION = "1.12.0" +GGUF_MIN_VERSION = "0.10.0" +XLA_FSDPV2_MIN_VERSION = "2.2.0" +HQQ_MIN_VERSION = "0.2.1" + + +_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) +_apex_available = _is_package_available("apex") +_aqlm_available = _is_package_available("aqlm") +_av_available = importlib.util.find_spec("av") is not None +_bitsandbytes_available = _is_package_available("bitsandbytes") +_eetq_available = _is_package_available("eetq") +_fbgemm_gpu_available = _is_package_available("fbgemm_gpu") +_galore_torch_available = _is_package_available("galore_torch") +_lomo_available = _is_package_available("lomo_optim") +_grokadamw_available = _is_package_available("grokadamw") +_schedulefree_available = _is_package_available("schedulefree") +# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed. +_bs4_available = importlib.util.find_spec("bs4") is not None +_coloredlogs_available = _is_package_available("coloredlogs") +# `importlib.metadata.util` doesn't work with `opencv-python-headless`. +_cv2_available = importlib.util.find_spec("cv2") is not None +_datasets_available = _is_package_available("datasets") +_detectron2_available = _is_package_available("detectron2") +# We need to check both `faiss` and `faiss-cpu`. +_faiss_available = importlib.util.find_spec("faiss") is not None +try: + _faiss_version = importlib.metadata.version("faiss") + logger.debug(f"Successfully imported faiss version {_faiss_version}") +except importlib.metadata.PackageNotFoundError: + try: + _faiss_version = importlib.metadata.version("faiss-cpu") + logger.debug(f"Successfully imported faiss version {_faiss_version}") + except importlib.metadata.PackageNotFoundError: + _faiss_available = False +_ftfy_available = _is_package_available("ftfy") +_g2p_en_available = _is_package_available("g2p_en") +_ipex_available, _ipex_version = _is_package_available("intel_extension_for_pytorch", return_version=True) +_jieba_available = _is_package_available("jieba") +_jinja_available = _is_package_available("jinja2") +_kenlm_available = _is_package_available("kenlm") +_keras_nlp_available = _is_package_available("keras_nlp") +_levenshtein_available = _is_package_available("Levenshtein") +_librosa_available = _is_package_available("librosa") +_natten_available = _is_package_available("natten") +_nltk_available = _is_package_available("nltk") +_onnx_available = _is_package_available("onnx") +_openai_available = _is_package_available("openai") +_optimum_available = _is_package_available("optimum") +_auto_gptq_available = _is_package_available("auto_gptq") +_gptqmodel_available = _is_package_available("gptqmodel") +# `importlib.metadata.version` doesn't work with `awq` +_auto_awq_available = importlib.util.find_spec("awq") is not None +_quanto_available = _is_package_available("quanto") +_is_optimum_quanto_available = False +try: + importlib.metadata.version("optimum_quanto") + _is_optimum_quanto_available = True +except importlib.metadata.PackageNotFoundError: + _is_optimum_quanto_available = False +# For compressed_tensors, only check spec to allow compressed_tensors-nightly package +_compressed_tensors_available = importlib.util.find_spec("compressed_tensors") is not None +_pandas_available = _is_package_available("pandas") +_peft_available = _is_package_available("peft") +_phonemizer_available = _is_package_available("phonemizer") 
+_uroman_available = _is_package_available("uroman") +_psutil_available = _is_package_available("psutil") +_py3nvml_available = _is_package_available("py3nvml") +_pyctcdecode_available = _is_package_available("pyctcdecode") +_pygments_available = _is_package_available("pygments") +_pytesseract_available = _is_package_available("pytesseract") +_pytest_available = _is_package_available("pytest") +_pytorch_quantization_available = _is_package_available("pytorch_quantization") +_rjieba_available = _is_package_available("rjieba") +_sacremoses_available = _is_package_available("sacremoses") +_safetensors_available = _is_package_available("safetensors") +_scipy_available = _is_package_available("scipy") +_sentencepiece_available = _is_package_available("sentencepiece") +_is_seqio_available = _is_package_available("seqio") +_is_gguf_available, _gguf_version = _is_package_available("gguf", return_version=True) +_sklearn_available = importlib.util.find_spec("sklearn") is not None +if _sklearn_available: + try: + importlib.metadata.version("scikit-learn") + except importlib.metadata.PackageNotFoundError: + _sklearn_available = False +_smdistributed_available = importlib.util.find_spec("smdistributed") is not None +_soundfile_available = _is_package_available("soundfile") +_spacy_available = _is_package_available("spacy") +_sudachipy_available, _sudachipy_version = _is_package_available("sudachipy", return_version=True) +_tensorflow_probability_available = _is_package_available("tensorflow_probability") +_tensorflow_text_available = _is_package_available("tensorflow_text") +_tf2onnx_available = _is_package_available("tf2onnx") +_timm_available = _is_package_available("timm") +_tokenizers_available = _is_package_available("tokenizers") +_torchaudio_available = _is_package_available("torchaudio") +_torchao_available = _is_package_available("torchao") +_torchdistx_available = _is_package_available("torchdistx") +_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True) +_mlx_available = _is_package_available("mlx") +_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True) +_tiktoken_available = _is_package_available("tiktoken") +_blobfile_available = _is_package_available("blobfile") +_liger_kernel_available = _is_package_available("liger_kernel") + + +_torch_version = "N/A" +_torch_available = False +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available, _torch_version = _is_package_available("torch", return_version=True) +else: + logger.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +_tf_version = "N/A" +_tf_available = False +if FORCE_TF_AVAILABLE in ENV_VARS_TRUE_VALUES: + _tf_available = True +else: + if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + # Note: _is_package_available("tensorflow") fails for tensorflow-cpu. Please test any changes to the line below + # with tensorflow-cpu to make sure it still works! 
+ _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "tf-nightly-rocm", + "intel-tensorflow", + "intel-tensorflow-avx512", + "tensorflow-rocm", + "tensorflow-macos", + "tensorflow-aarch64", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib.metadata.version(pkg) + break + except importlib.metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if version.parse(_tf_version) < version.parse("2"): + logger.info( + f"TensorFlow found but with version {_tf_version}. Transformers requires version 2 minimum." + ) + _tf_available = False + else: + logger.info("Disabling Tensorflow because USE_TORCH is set") + + +_essentia_available = importlib.util.find_spec("essentia") is not None +try: + _essentia_version = importlib.metadata.version("essentia") + logger.debug(f"Successfully imported essentia version {_essentia_version}") +except importlib.metadata.PackageNotFoundError: + _essentia_version = False + + +_pretty_midi_available = importlib.util.find_spec("pretty_midi") is not None +try: + _pretty_midi_version = importlib.metadata.version("pretty_midi") + logger.debug(f"Successfully imported pretty_midi version {_pretty_midi_version}") +except importlib.metadata.PackageNotFoundError: + _pretty_midi_available = False + + +ccl_version = "N/A" +_is_ccl_available = ( + importlib.util.find_spec("torch_ccl") is not None + or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None +) +try: + ccl_version = importlib.metadata.version("oneccl_bind_pt") + logger.debug(f"Detected oneccl_bind_pt version {ccl_version}") +except importlib.metadata.PackageNotFoundError: + _is_ccl_available = False + + +_flax_available = False +if USE_JAX in ENV_VARS_TRUE_AND_AUTO_VALUES: + _flax_available, _flax_version = _is_package_available("flax", return_version=True) + if _flax_available: + _jax_available, _jax_version = _is_package_available("jax", return_version=True) + if _jax_available: + logger.info(f"JAX version {_jax_version}, Flax version {_flax_version} available.") + else: + _flax_available = _jax_available = False + _jax_version = _flax_version = "N/A" + + +_torch_fx_available = False +if _torch_available: + torch_version = version.parse(_torch_version) + _torch_fx_available = (torch_version.major, torch_version.minor) >= ( + TORCH_FX_REQUIRED_VERSION.major, + TORCH_FX_REQUIRED_VERSION.minor, + ) + + +_torch_xla_available = False +if USE_TORCH_XLA in ENV_VARS_TRUE_VALUES: + _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla", return_version=True) + if _torch_xla_available: + logger.info(f"Torch XLA version {_torch_xla_version} available.") + + +def is_kenlm_available(): + return _kenlm_available + + +def is_cv2_available(): + return _cv2_available + + +def is_torch_available(): + return _torch_available + + +def is_accelerate_available(min_version: str = ACCELERATE_MIN_VERSION): + return _accelerate_available and version.parse(_accelerate_version) >= version.parse(min_version) + + +def is_torch_deterministic(): + """ + Check whether pytorch uses deterministic algorithms by looking if torch.set_deterministic_debug_mode() is set to 1 or 2" + """ + import torch + + if torch.get_deterministic_debug_mode() == 0: + return False + else: + return True + + +def 
is_hqq_available(min_version: str = HQQ_MIN_VERSION): + return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version) + + +def is_pygments_available(): + return _pygments_available + + +def get_torch_version(): + return _torch_version + + +def is_torch_sdpa_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False + + # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons: + # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259 + # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310 + # NOTE: MLU is OK with non-contiguous inputs. + if is_torch_mlu_available(): + return version.parse(_torch_version) >= version.parse("2.1.0") + # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577 + return version.parse(_torch_version) >= version.parse("2.1.1") + + +def is_torch_flex_attn_available(): + if not is_torch_available(): + return False + elif _torch_version == "N/A": + return False + + # TODO check if some bugs cause push backs on the exact version + # NOTE: We require torch>=2.5.0 as it is the first release + return version.parse(_torch_version) >= version.parse("2.5.0") + + +def is_torchvision_available(): + return _torchvision_available + + +def is_torchvision_v2_available(): + if not is_torchvision_available(): + return False + + # NOTE: We require torchvision>=0.15 as v2 transforms are available from this version: https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use + return version.parse(_torchvision_version) >= version.parse("0.15") + + +def is_galore_torch_available(): + return _galore_torch_available + + +def is_lomo_available(): + return _lomo_available + + +def is_grokadamw_available(): + return _grokadamw_available + + +def is_schedulefree_available(): + return _schedulefree_available + + +def is_pyctcdecode_available(): + return _pyctcdecode_available + + +def is_librosa_available(): + return _librosa_available + + +def is_essentia_available(): + return _essentia_available + + +def is_pretty_midi_available(): + return _pretty_midi_available + + +def is_torch_cuda_available(): + if is_torch_available(): + import torch + + return torch.cuda.is_available() + else: + return False + + +def is_mamba_ssm_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + else: + return _is_package_available("mamba_ssm") + return False + + +def is_mamba_2_ssm_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + else: + if _is_package_available("mamba_ssm"): + import mamba_ssm + + if version.parse(mamba_ssm.__version__) >= version.parse("2.0.4"): + return True + return False + + +def is_causal_conv1d_available(): + if is_torch_available(): + import torch + + if not torch.cuda.is_available(): + return False + return _is_package_available("causal_conv1d") + return False + + +def is_mambapy_available(): + if is_torch_available(): + return _is_package_available("mambapy") + return False + + +def is_torch_mps_available(min_version: Optional[str] = None): + if is_torch_available(): + import torch + + if hasattr(torch.backends, "mps"): + backend_available = torch.backends.mps.is_available() and torch.backends.mps.is_built() + if min_version is not None: + flag = 
version.parse(_torch_version) >= version.parse(min_version) + backend_available = backend_available and flag + return backend_available + return False + + +def is_torch_bf16_gpu_available(): + if not is_torch_available(): + return False + + import torch + + return torch.cuda.is_available() and torch.cuda.is_bf16_supported() + + +def is_torch_bf16_cpu_available(): + if not is_torch_available(): + return False + + import torch + + try: + # multiple levels of AttributeError depending on the pytorch version so do them all in one check + _ = torch.cpu.amp.autocast + except AttributeError: + return False + + return True + + +def is_torch_bf16_available(): + # the original bf16 check was for gpu only, but later a cpu/bf16 combo has emerged so this util + # has become ambiguous and therefore deprecated + warnings.warn( + "The util is_torch_bf16_available is deprecated, please use is_torch_bf16_gpu_available " + "or is_torch_bf16_cpu_available instead according to whether it's used with cpu or gpu", + FutureWarning, + ) + return is_torch_bf16_gpu_available() + + +@lru_cache() +def is_torch_fp16_available_on_device(device): + if not is_torch_available(): + return False + + import torch + + try: + x = torch.zeros(2, 2, dtype=torch.float16).to(device) + _ = x @ x + + # At this moment, let's be strict of the check: check if `LayerNorm` is also supported on device, because many + # models use this layer. + batch, sentence_length, embedding_dim = 3, 4, 5 + embedding = torch.randn(batch, sentence_length, embedding_dim, dtype=torch.float16, device=device) + layer_norm = torch.nn.LayerNorm(embedding_dim, dtype=torch.float16, device=device) + _ = layer_norm(embedding) + + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. + return False + + return True + + +@lru_cache() +def is_torch_bf16_available_on_device(device): + if not is_torch_available(): + return False + + import torch + + if device == "cuda": + return is_torch_bf16_gpu_available() + + try: + x = torch.zeros(2, 2, dtype=torch.bfloat16).to(device) + _ = x @ x + except: # noqa: E722 + # TODO: more precise exception matching, if possible. + # most backends should return `RuntimeError` however this is not guaranteed. 
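+        # Any failure of the tiny bf16 matmul probe above is interpreted as
+        # "bf16 is not usable on this device".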
+ return False + + return True + + +def is_torch_tf32_available(): + if not is_torch_available(): + return False + + import torch + + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + if int(torch.version.cuda.split(".")[0]) < 11: + return False + if version.parse(version.parse(torch.__version__).base_version) < version.parse("1.7"): + return False + + return True + + +def is_torch_fx_available(): + return _torch_fx_available + + +def is_peft_available(): + return _peft_available + + +def is_bs4_available(): + return _bs4_available + + +def is_tf_available(): + return _tf_available + + +def is_coloredlogs_available(): + return _coloredlogs_available + + +def is_tf2onnx_available(): + return _tf2onnx_available + + +def is_onnx_available(): + return _onnx_available + + +def is_openai_available(): + return _openai_available + + +def is_flax_available(): + return _flax_available + + +def is_ftfy_available(): + return _ftfy_available + + +def is_g2p_en_available(): + return _g2p_en_available + + +@lru_cache() +def is_torch_tpu_available(check_device=True): + "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" + warnings.warn( + "`is_torch_tpu_available` is deprecated and will be removed in 4.41.0. " + "Please use the `is_torch_xla_available` instead.", + FutureWarning, + ) + + if not _torch_available: + return False + if importlib.util.find_spec("torch_xla") is not None: + if check_device: + # We need to check if `xla_device` can be found, will raise a RuntimeError if not + try: + import torch_xla.core.xla_model as xm + + _ = xm.xla_device() + return True + except RuntimeError: + return False + return True + return False + + +@lru_cache +def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): + """ + Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set + the USE_TORCH_XLA to false. + """ + assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." + + if not _torch_xla_available: + return False + + import torch_xla + + if check_is_gpu: + return torch_xla.runtime.device_type() in ["GPU", "CUDA"] + elif check_is_tpu: + return torch_xla.runtime.device_type() == "TPU" + + return True + + +@lru_cache() +def is_torch_neuroncore_available(check_device=True): + if importlib.util.find_spec("torch_neuronx") is not None: + return is_torch_xla_available() + return False + + +@lru_cache() +def is_torch_npu_available(check_device=False): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if not _torch_available or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + +@lru_cache() +def is_torch_mlu_available(check_device=False): + """ + Checks if `mlu` is available via an `cndev-based` check which won't trigger the drivers and leave mlu + uninitialized. 
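+    The check temporarily sets `PYTORCH_CNDEV_BASED_MLU_CHECK=1` around the
+    `torch.mlu.is_available()` call and restores the previous value afterwards.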
+ """ + if not _torch_available or importlib.util.find_spec("torch_mlu") is None: + return False + + import torch + import torch_mlu # noqa: F401 + + pytorch_cndev_based_mlu_check_previous_value = os.environ.get("PYTORCH_CNDEV_BASED_MLU_CHECK") + try: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = str(1) + available = torch.mlu.is_available() + finally: + if pytorch_cndev_based_mlu_check_previous_value: + os.environ["PYTORCH_CNDEV_BASED_MLU_CHECK"] = pytorch_cndev_based_mlu_check_previous_value + else: + os.environ.pop("PYTORCH_CNDEV_BASED_MLU_CHECK", None) + + return available + + +@lru_cache() +def is_torch_musa_available(check_device=False): + "Checks if `torch_musa` is installed and potentially if a MUSA is in the environment" + if not _torch_available or importlib.util.find_spec("torch_musa") is None: + return False + + import torch + import torch_musa # noqa: F401 + + torch_musa_min_version = "0.33.0" + if _accelerate_available and version.parse(_accelerate_version) < version.parse(torch_musa_min_version): + return False + + if check_device: + try: + # Will raise a RuntimeError if no MUSA is found + _ = torch.musa.device_count() + return torch.musa.is_available() + except RuntimeError: + return False + return hasattr(torch, "musa") and torch.musa.is_available() + + +def is_torchdynamo_available(): + if not is_torch_available(): + return False + + return version.parse(_torch_version) >= version.parse("2.0.0") + + +def is_torch_compile_available(): + if not is_torch_available(): + return False + + import torch + + # We don't do any version check here to support nighlies marked as 1.14. Ultimately needs to check version against + # 2.0 but let's do it later. + return hasattr(torch, "compile") + + +def is_torchdynamo_compiling(): + if not is_torch_available(): + return False + + # Importing torch._dynamo causes issues with PyTorch profiler (https://github.com/pytorch/pytorch/issues/130622) + # hence rather relying on `torch.compiler.is_compiling()` when possible (torch>=2.3) + try: + import torch + + return torch.compiler.is_compiling() + except Exception: + try: + import torch._dynamo as dynamo # noqa: F401 + + return dynamo.is_compiling() + except Exception: + return False + + +def is_torch_tensorrt_fx_available(): + if importlib.util.find_spec("torch_tensorrt") is None: + return False + return importlib.util.find_spec("torch_tensorrt.fx") is not None + + +def is_datasets_available(): + return _datasets_available + + +def is_detectron2_available(): + return _detectron2_available + + +def is_rjieba_available(): + return _rjieba_available + + +def is_psutil_available(): + return _psutil_available + + +def is_py3nvml_available(): + return _py3nvml_available + + +def is_sacremoses_available(): + return _sacremoses_available + + +def is_apex_available(): + return _apex_available + + +def is_aqlm_available(): + return _aqlm_available + + +def is_av_available(): + return _av_available + + +def is_ninja_available(): + r""" + Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the + [ninja](https://ninja-build.org/) build system is available on the system, `False` otherwise. + """ + try: + subprocess.check_output("ninja --version".split()) + except Exception: + return False + else: + return True + + +def is_ipex_available(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." 
+ str(version.parse(full_version).minor) + + if not is_torch_available() or not _ipex_available: + return False + + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + logger.warning( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." + ) + return False + return True + + +@lru_cache +def is_torch_xpu_available(check_device=False): + """ + Checks if XPU acceleration is available either via `intel_extension_for_pytorch` or + via stock PyTorch (>=2.4) and potentially if a XPU is in the environment + """ + if not is_torch_available(): + return False + + torch_version = version.parse(_torch_version) + if is_ipex_available(): + import intel_extension_for_pytorch # noqa: F401 + elif torch_version.major < 2 or (torch_version.major == 2 and torch_version.minor < 4): + return False + + import torch + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache() +def is_bitsandbytes_available(): + if not is_torch_available() or not _bitsandbytes_available: + return False + + import torch + + # `bitsandbytes` versions older than 0.43.1 eagerly require CUDA at import time, + # so those versions of the library are practically only available when CUDA is too. + if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.1"): + return torch.cuda.is_available() + + # Newer versions of `bitsandbytes` can be imported on systems without CUDA. 
+ return True + + +def is_bitsandbytes_multi_backend_available() -> bool: + if not is_bitsandbytes_available(): + return False + + import bitsandbytes as bnb + + return "multi_backend" in getattr(bnb, "features", set()) + + +def is_flash_attn_2_available(): + if not is_torch_available(): + return False + + if not _is_package_available("flash_attn"): + return False + + # Let's add an extra check to see if cuda is available + import torch + + if not (torch.cuda.is_available() or is_torch_mlu_available()): + return False + + if torch.version.cuda: + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + elif torch.version.hip: + # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4") + elif is_torch_mlu_available(): + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3") + else: + return False + + +@lru_cache() +def is_flash_attn_greater_or_equal_2_10(): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0") + + +@lru_cache() +def is_flash_attn_greater_or_equal(library_version: str): + if not _is_package_available("flash_attn"): + return False + + return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) + + +@lru_cache() +def is_torch_greater_or_equal(library_version: str): + if not _is_package_available("torch"): + return False + + return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) + + +def is_torchdistx_available(): + return _torchdistx_available + + +def is_faiss_available(): + return _faiss_available + + +def is_scipy_available(): + return _scipy_available + + +def is_sklearn_available(): + return _sklearn_available + + +def is_sentencepiece_available(): + return _sentencepiece_available + + +def is_seqio_available(): + return _is_seqio_available + + +def is_gguf_available(min_version: str = GGUF_MIN_VERSION): + return _is_gguf_available and version.parse(_gguf_version) >= version.parse(min_version) + + +def is_protobuf_available(): + if importlib.util.find_spec("google") is None: + return False + return importlib.util.find_spec("google.protobuf") is not None + + +def is_fsdp_available(min_version: str = FSDP_MIN_VERSION): + return is_torch_available() and version.parse(_torch_version) >= version.parse(min_version) + + +def is_optimum_available(): + return _optimum_available + + +def is_auto_awq_available(): + return _auto_awq_available + + +def is_optimum_quanto_available(): + # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto` + return _is_optimum_quanto_available + + +def is_compressed_tensors_available(): + return _compressed_tensors_available + + +def is_auto_gptq_available(): + return _auto_gptq_available + + +def is_gptqmodel_available(): + return _gptqmodel_available + + +def is_eetq_available(): + return _eetq_available + + +def is_fbgemm_gpu_available(): + return _fbgemm_gpu_available + + +def is_levenshtein_available(): + return _levenshtein_available + + +def is_optimum_neuron_available(): + return _optimum_available and _is_package_available("optimum.neuron") + + +def is_safetensors_available(): + return _safetensors_available + + +def is_tokenizers_available(): + return _tokenizers_available + + +@lru_cache +def is_vision_available(): + 
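+ # Pillow and Pillow-SIMD both provide the "PIL" import name, so the distribution metadata queried below determines which one (if any) is actually installed.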
_pil_available = importlib.util.find_spec("PIL") is not None + if _pil_available: + try: + package_version = importlib.metadata.version("Pillow") + except importlib.metadata.PackageNotFoundError: + try: + package_version = importlib.metadata.version("Pillow-SIMD") + except importlib.metadata.PackageNotFoundError: + return False + logger.debug(f"Detected PIL version {package_version}") + return _pil_available + + +def is_pytesseract_available(): + return _pytesseract_available + + +def is_pytest_available(): + return _pytest_available + + +def is_spacy_available(): + return _spacy_available + + +def is_tensorflow_text_available(): + return is_tf_available() and _tensorflow_text_available + + +def is_keras_nlp_available(): + return is_tensorflow_text_available() and _keras_nlp_available + + +def is_in_notebook(): + try: + # Test adapted from tqdm.autonotebook: https://github.com/tqdm/tqdm/blob/master/tqdm/autonotebook.py + get_ipython = sys.modules["IPython"].get_ipython + if "IPKernelApp" not in get_ipython().config: + raise ImportError("console") + if "VSCODE_PID" in os.environ: + raise ImportError("vscode") + if "DATABRICKS_RUNTIME_VERSION" in os.environ and os.environ["DATABRICKS_RUNTIME_VERSION"] < "11.0": + # Databricks Runtime 11.0 and above uses IPython kernel by default so it should be compatible with Jupyter notebook + # https://docs.microsoft.com/en-us/azure/databricks/notebooks/ipython-kernel + raise ImportError("databricks") + + return importlib.util.find_spec("IPython") is not None + except (AttributeError, ImportError, KeyError): + return False + + +def is_pytorch_quantization_available(): + return _pytorch_quantization_available + + +def is_tensorflow_probability_available(): + return _tensorflow_probability_available + + +def is_pandas_available(): + return _pandas_available + + +def is_sagemaker_dp_enabled(): + # Get the sagemaker specific env variable. + sagemaker_params = os.getenv("SM_FRAMEWORK_PARAMS", "{}") + try: + # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". + sagemaker_params = json.loads(sagemaker_params) + if not sagemaker_params.get("sagemaker_distributed_dataparallel_enabled", False): + return False + except json.JSONDecodeError: + return False + # Lastly, check if the `smdistributed` module is present. + return _smdistributed_available + + +def is_sagemaker_mp_enabled(): + # Get the sagemaker specific mp parameters from smp_options variable. + smp_options = os.getenv("SM_HP_MP_PARAMETERS", "{}") + try: + # Parse it and check the field "partitions" is included, it is required for model parallel. + smp_options = json.loads(smp_options) + if "partitions" not in smp_options: + return False + except json.JSONDecodeError: + return False + + # Get the sagemaker specific framework parameters from mpi_options variable. + mpi_options = os.getenv("SM_FRAMEWORK_PARAMS", "{}") + try: + # Parse it and check the field "sagemaker_distributed_dataparallel_enabled". + mpi_options = json.loads(mpi_options) + if not mpi_options.get("sagemaker_mpi_enabled", False): + return False + except json.JSONDecodeError: + return False + # Lastly, check if the `smdistributed` module is present. 
+ return _smdistributed_available + + +def is_training_run_on_sagemaker(): + return "SAGEMAKER_JOB_NAME" in os.environ + + +def is_soundfile_availble(): + return _soundfile_available + + +def is_timm_available(): + return _timm_available + + +def is_natten_available(): + return _natten_available + + +def is_nltk_available(): + return _nltk_available + + +def is_torchaudio_available(): + return _torchaudio_available + + +def is_torchao_available(): + return _torchao_available + + +def is_speech_available(): + # For now this depends on torchaudio but the exact dependency might evolve in the future. + return _torchaudio_available + + +def is_phonemizer_available(): + return _phonemizer_available + + +def is_uroman_available(): + return _uroman_available + + +def torch_only_method(fn): + def wrapper(*args, **kwargs): + if not _torch_available: + raise ImportError( + "You need to install pytorch to use this method or class, " + "or activate it with environment variables USE_TORCH=1 and USE_TF=0." + ) + else: + return fn(*args, **kwargs) + + return wrapper + + +def is_ccl_available(): + return _is_ccl_available + + +def is_sudachi_available(): + return _sudachipy_available + + +def get_sudachi_version(): + return _sudachipy_version + + +def is_sudachi_projection_available(): + if not is_sudachi_available(): + return False + + # NOTE: We require sudachipy>=0.6.8 to use projection option in sudachi_kwargs for the constructor of BertJapaneseTokenizer. + # - `projection` option is not supported in sudachipy<0.6.8, see https://github.com/WorksApplications/sudachi.rs/issues/230 + return version.parse(_sudachipy_version) >= version.parse("0.6.8") + + +def is_jumanpp_available(): + return (importlib.util.find_spec("rhoknp") is not None) and (shutil.which("jumanpp") is not None) + + +def is_cython_available(): + return importlib.util.find_spec("pyximport") is not None + + +def is_jieba_available(): + return _jieba_available + + +def is_jinja_available(): + return _jinja_available + + +def is_mlx_available(): + return _mlx_available + + +def is_tiktoken_available(): + return _tiktoken_available and _blobfile_available + + +def is_liger_kernel_available(): + if not _liger_kernel_available: + return False + + return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0") + + +# docstyle-ignore +AV_IMPORT_ERROR = """ +{0} requires the PyAv library but it was not found in your environment. You can install it with: +``` +pip install av +``` +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +CV2_IMPORT_ERROR = """ +{0} requires the OpenCV library but it was not found in your environment. You can install it with: +``` +pip install opencv-python +``` +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +DATASETS_IMPORT_ERROR = """ +{0} requires the 🤗 Datasets library but it was not found in your environment. You can install it with: +``` +pip install datasets +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install datasets +``` +then restarting your kernel. + +Note that if you have a local folder named `datasets` or a local python file named `datasets.py` in your current +working directory, python may try to import this instead of the 🤗 Datasets library. You should rename this folder or +that python file if that's the case. Please note that you may need to restart your runtime after installation. 
+""" + + +# docstyle-ignore +TOKENIZERS_IMPORT_ERROR = """ +{0} requires the 🤗 Tokenizers library but it was not found in your environment. You can install it with: +``` +pip install tokenizers +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install tokenizers +``` +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +SENTENCEPIECE_IMPORT_ERROR = """ +{0} requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +PROTOBUF_IMPORT_ERROR = """ +{0} requires the protobuf library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/protocolbuffers/protobuf/tree/master/python#installation and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +FAISS_IMPORT_ERROR = """ +{0} requires the faiss library but it was not found in your environment. Checkout the instructions on the +installation page of its repo: https://github.com/facebookresearch/faiss/blob/master/INSTALL.md and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +PYTORCH_IMPORT_ERROR = """ +{0} requires the PyTorch library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +TORCHVISION_IMPORT_ERROR = """ +{0} requires the Torchvision library but it was not found in your environment. Checkout the instructions on the +installation page: https://pytorch.org/get-started/locally/ and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PYTORCH_IMPORT_ERROR_WITH_TF = """ +{0} requires the PyTorch library but it was not found in your environment. +However, we were able to find a TensorFlow installation. TensorFlow classes begin +with "TF", but are otherwise identically named to our PyTorch classes. This +means that the TF equivalent of the class you tried to import would be "TF{0}". +If you want to use TensorFlow, please use TF classes instead! + +If you really do want to use PyTorch please go to +https://pytorch.org/get-started/locally/ and follow the instructions that +match your environment. +""" + +# docstyle-ignore +TF_IMPORT_ERROR_WITH_PYTORCH = """ +{0} requires the TensorFlow library but it was not found in your environment. +However, we were able to find a PyTorch installation. PyTorch classes do not begin +with "TF", but are otherwise identically named to our TF classes. +If you want to use PyTorch, please use those classes instead! + +If you really do want to use TensorFlow, please follow the instructions on the +installation page https://www.tensorflow.org/install that match your environment. +""" + +# docstyle-ignore +BS4_IMPORT_ERROR = """ +{0} requires the Beautiful Soup library but it was not found in your environment. 
You can install it with pip: +`pip install beautifulsoup4`. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +SKLEARN_IMPORT_ERROR = """ +{0} requires the scikit-learn library but it was not found in your environment. You can install it with: +``` +pip install -U scikit-learn +``` +In a notebook or a colab, you can install it by executing a cell with +``` +!pip install -U scikit-learn +``` +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +TENSORFLOW_IMPORT_ERROR = """ +{0} requires the TensorFlow library but it was not found in your environment. Checkout the instructions on the +installation page: https://www.tensorflow.org/install and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +DETECTRON2_IMPORT_ERROR = """ +{0} requires the detectron2 library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +FLAX_IMPORT_ERROR = """ +{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the +installation page: https://github.com/google/flax and follow the ones that match your environment. +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +FTFY_IMPORT_ERROR = """ +{0} requires the ftfy library but it was not found in your environment. Checkout the instructions on the +installation section: https://github.com/rspeer/python-ftfy/tree/master#installing and follow the ones +that match your environment. Please note that you may need to restart your runtime after installation. +""" + +LEVENSHTEIN_IMPORT_ERROR = """ +{0} requires the python-Levenshtein library but it was not found in your environment. You can install it with pip: `pip +install python-Levenshtein`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +G2P_EN_IMPORT_ERROR = """ +{0} requires the g2p-en library but it was not found in your environment. You can install it with pip: +`pip install g2p-en`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PYTORCH_QUANTIZATION_IMPORT_ERROR = """ +{0} requires the pytorch-quantization library but it was not found in your environment. You can install it with pip: +`pip install pytorch-quantization --extra-index-url https://pypi.ngc.nvidia.com` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TENSORFLOW_PROBABILITY_IMPORT_ERROR = """ +{0} requires the tensorflow_probability library but it was not found in your environment. You can install it with pip as +explained here: https://github.com/tensorflow/probability. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TENSORFLOW_TEXT_IMPORT_ERROR = """ +{0} requires the tensorflow_text library but it was not found in your environment. You can install it with pip as +explained here: https://www.tensorflow.org/text/guide/tf_text_intro. +Please note that you may need to restart your runtime after installation. 
+""" + +# docstyle-ignore +TORCHAUDIO_IMPORT_ERROR = """ +{0} requires the torchaudio library but it was not found in your environment. Please install it and restart your +runtime. +""" + +# docstyle-ignore +PANDAS_IMPORT_ERROR = """ +{0} requires the pandas library but it was not found in your environment. You can install it with pip as +explained here: https://pandas.pydata.org/pandas-docs/stable/getting_started/install.html. +Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +PHONEMIZER_IMPORT_ERROR = """ +{0} requires the phonemizer library but it was not found in your environment. You can install it with pip: +`pip install phonemizer`. Please note that you may need to restart your runtime after installation. +""" +# docstyle-ignore +UROMAN_IMPORT_ERROR = """ +{0} requires the uroman library but it was not found in your environment. You can install it with pip: +`pip install uroman`. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +SACREMOSES_IMPORT_ERROR = """ +{0} requires the sacremoses library but it was not found in your environment. You can install it with pip: +`pip install sacremoses`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +SCIPY_IMPORT_ERROR = """ +{0} requires the scipy library but it was not found in your environment. You can install it with pip: +`pip install scipy`. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +SPEECH_IMPORT_ERROR = """ +{0} requires the torchaudio library but it was not found in your environment. You can install it with pip: +`pip install torchaudio`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +TIMM_IMPORT_ERROR = """ +{0} requires the timm library but it was not found in your environment. You can install it with pip: +`pip install timm`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +NATTEN_IMPORT_ERROR = """ +{0} requires the natten library but it was not found in your environment. You can install it by referring to: +shi-labs.com/natten . You can also install it with pip (may take longer to build): +`pip install natten`. Please note that you may need to restart your runtime after installation. +""" + +NUMEXPR_IMPORT_ERROR = """ +{0} requires the numexpr library but it was not found in your environment. You can install it by referring to: +https://numexpr.readthedocs.io/en/latest/index.html. +""" + + +# docstyle-ignore +NLTK_IMPORT_ERROR = """ +{0} requires the NLTK library but it was not found in your environment. You can install it by referring to: +https://www.nltk.org/install.html. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +VISION_IMPORT_ERROR = """ +{0} requires the PIL library but it was not found in your environment. You can install it with pip: +`pip install pillow`. Please note that you may need to restart your runtime after installation. +""" + + +# docstyle-ignore +PYTESSERACT_IMPORT_ERROR = """ +{0} requires the PyTesseract library but it was not found in your environment. You can install it with pip: +`pip install pytesseract`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PYCTCDECODE_IMPORT_ERROR = """ +{0} requires the pyctcdecode library but it was not found in your environment. 
You can install it with pip: +`pip install pyctcdecode`. Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +ACCELERATE_IMPORT_ERROR = """ +{0} requires the accelerate library >= {ACCELERATE_MIN_VERSION} it was not found in your environment. +You can install or update it with pip: `pip install --upgrade accelerate`. Please note that you may need to restart your +runtime after installation. +""" + +# docstyle-ignore +CCL_IMPORT_ERROR = """ +{0} requires the torch ccl library but it was not found in your environment. You can install it with pip: +`pip install oneccl_bind_pt -f https://developer.intel.com/ipex-whl-stable` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +ESSENTIA_IMPORT_ERROR = """ +{0} requires essentia library. But that was not found in your environment. You can install them with pip: +`pip install essentia==2.1b6.dev1034` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +LIBROSA_IMPORT_ERROR = """ +{0} requires thes librosa library. But that was not found in your environment. You can install them with pip: +`pip install librosa` +Please note that you may need to restart your runtime after installation. +""" + +# docstyle-ignore +PRETTY_MIDI_IMPORT_ERROR = """ +{0} requires thes pretty_midi library. But that was not found in your environment. You can install them with pip: +`pip install pretty_midi` +Please note that you may need to restart your runtime after installation. +""" + + +CYTHON_IMPORT_ERROR = """ +{0} requires the Cython library but it was not found in your environment. You can install it with pip: `pip install +Cython`. Please note that you may need to restart your runtime after installation. +""" + +JIEBA_IMPORT_ERROR = """ +{0} requires the jieba library but it was not found in your environment. You can install it with pip: `pip install +jieba`. Please note that you may need to restart your runtime after installation. +""" + +PEFT_IMPORT_ERROR = """ +{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install +peft`. Please note that you may need to restart your runtime after installation. +""" + +JINJA_IMPORT_ERROR = """ +{0} requires the jinja library but it was not found in your environment. You can install it with pip: `pip install +jinja2`. Please note that you may need to restart your runtime after installation. 
+""" + +BACKENDS_MAPPING = OrderedDict( + [ + ("av", (is_av_available, AV_IMPORT_ERROR)), + ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)), + ("cv2", (is_cv2_available, CV2_IMPORT_ERROR)), + ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)), + ("detectron2", (is_detectron2_available, DETECTRON2_IMPORT_ERROR)), + ("essentia", (is_essentia_available, ESSENTIA_IMPORT_ERROR)), + ("faiss", (is_faiss_available, FAISS_IMPORT_ERROR)), + ("flax", (is_flax_available, FLAX_IMPORT_ERROR)), + ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), + ("g2p_en", (is_g2p_en_available, G2P_EN_IMPORT_ERROR)), + ("pandas", (is_pandas_available, PANDAS_IMPORT_ERROR)), + ("phonemizer", (is_phonemizer_available, PHONEMIZER_IMPORT_ERROR)), + ("uroman", (is_uroman_available, UROMAN_IMPORT_ERROR)), + ("pretty_midi", (is_pretty_midi_available, PRETTY_MIDI_IMPORT_ERROR)), + ("levenshtein", (is_levenshtein_available, LEVENSHTEIN_IMPORT_ERROR)), + ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)), + ("protobuf", (is_protobuf_available, PROTOBUF_IMPORT_ERROR)), + ("pyctcdecode", (is_pyctcdecode_available, PYCTCDECODE_IMPORT_ERROR)), + ("pytesseract", (is_pytesseract_available, PYTESSERACT_IMPORT_ERROR)), + ("sacremoses", (is_sacremoses_available, SACREMOSES_IMPORT_ERROR)), + ("pytorch_quantization", (is_pytorch_quantization_available, PYTORCH_QUANTIZATION_IMPORT_ERROR)), + ("sentencepiece", (is_sentencepiece_available, SENTENCEPIECE_IMPORT_ERROR)), + ("sklearn", (is_sklearn_available, SKLEARN_IMPORT_ERROR)), + ("speech", (is_speech_available, SPEECH_IMPORT_ERROR)), + ("tensorflow_probability", (is_tensorflow_probability_available, TENSORFLOW_PROBABILITY_IMPORT_ERROR)), + ("tf", (is_tf_available, TENSORFLOW_IMPORT_ERROR)), + ("tensorflow_text", (is_tensorflow_text_available, TENSORFLOW_TEXT_IMPORT_ERROR)), + ("timm", (is_timm_available, TIMM_IMPORT_ERROR)), + ("torchaudio", (is_torchaudio_available, TORCHAUDIO_IMPORT_ERROR)), + ("natten", (is_natten_available, NATTEN_IMPORT_ERROR)), + ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)), + ("tokenizers", (is_tokenizers_available, TOKENIZERS_IMPORT_ERROR)), + ("torch", (is_torch_available, PYTORCH_IMPORT_ERROR)), + ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)), + ("vision", (is_vision_available, VISION_IMPORT_ERROR)), + ("scipy", (is_scipy_available, SCIPY_IMPORT_ERROR)), + ("accelerate", (is_accelerate_available, ACCELERATE_IMPORT_ERROR)), + ("oneccl_bind_pt", (is_ccl_available, CCL_IMPORT_ERROR)), + ("cython", (is_cython_available, CYTHON_IMPORT_ERROR)), + ("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)), + ("peft", (is_peft_available, PEFT_IMPORT_ERROR)), + ("jinja", (is_jinja_available, JINJA_IMPORT_ERROR)), + ] +) + + +def requires_backends(obj, backends): + if not isinstance(backends, (list, tuple)): + backends = [backends] + + name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__ + + # Raise an error for users who might not realize that classes without "TF" are torch-only + if "torch" in backends and "tf" not in backends and not is_torch_available() and is_tf_available(): + raise ImportError(PYTORCH_IMPORT_ERROR_WITH_TF.format(name)) + + # Raise the inverse error for PyTorch users trying to load TF classes + if "tf" in backends and "torch" not in backends and is_torch_available() and not is_tf_available(): + raise ImportError(TF_IMPORT_ERROR_WITH_PYTORCH.format(name)) + + checks = (BACKENDS_MAPPING[backend] for backend in backends) + failed = [msg.format(name) for available, msg in checks if not available()] 
+ if failed: + raise ImportError("".join(failed)) + + +class DummyObject(type): + """ + Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by + `requires_backend` each time a user tries to access any method of that class. + """ + + def __getattribute__(cls, key): + if key.startswith("_") and key != "_from_config": + return super().__getattribute__(key) + requires_backends(cls, cls._backends) + + +def is_torch_fx_proxy(x): + if is_torch_fx_available(): + import torch.fx + + return isinstance(x, torch.fx.Proxy) + return False + + +BACKENDS_T = FrozenSet[str] +IMPORT_STRUCTURE_T = Dict[BACKENDS_T, Dict[str, Set[str]]] + + +class _LazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__( + self, + name: str, + module_file: str, + import_structure: IMPORT_STRUCTURE_T, + module_spec: importlib.machinery.ModuleSpec = None, + extra_objects: Dict[str, object] = None, + ): + super().__init__(name) + + self._object_missing_backend = {} + if any(isinstance(key, frozenset) for key in import_structure.keys()): + self._modules = set() + self._class_to_module = {} + self.__all__ = [] + + _import_structure = {} + + for backends, module in import_structure.items(): + missing_backends = [] + for backend in backends: + if backend not in BACKENDS_MAPPING: + raise ValueError( + f"Error: the following backend: '{backend}' was specified around object {module} but isn't specified in the backends mapping." + ) + callable, error = BACKENDS_MAPPING[backend] + if not callable(): + missing_backends.append(backend) + self._modules = self._modules.union(set(module.keys())) + + for key, values in module.items(): + if len(missing_backends): + self._object_missing_backend[key] = missing_backends + + for value in values: + self._class_to_module[value] = key + if len(missing_backends): + self._object_missing_backend[value] = missing_backends + _import_structure.setdefault(key, []).extend(values) + + # Needed for autocompletion in an IDE + self.__all__.extend(list(module.keys()) + list(chain(*module.values()))) + + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = _import_structure + + # This can be removed once every exportable object has a `export()` export. + else: + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) + self.__file__ = module_file + self.__spec__ = module_spec + self.__path__ = [os.path.dirname(module_file)] + self._objects = {} if extra_objects is None else extra_objects + self._name = name + self._import_structure = import_structure + + # Needed for autocompletion in an IDE + def __dir__(self): + result = super().__dir__() + # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether + # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. 
+ for attr in self.__all__: + if attr not in result: + result.append(attr) + return result + + def __getattr__(self, name: str) -> Any: + if name in self._objects: + return self._objects[name] + if name in self._object_missing_backend.keys(): + missing_backends = self._object_missing_backend[name] + + class Placeholder(metaclass=DummyObject): + _backends = missing_backends + + def __init__(self, *args, **kwargs): + requires_backends(self, missing_backends) + + Placeholder.__name__ = name + Placeholder.__module__ = self.__spec__ + + value = Placeholder + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + elif name in self._modules: + value = self._get_module(name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str): + try: + return importlib.import_module("." + module_name, self.__name__) + except Exception as e: + raise RuntimeError( + f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" + f" traceback):\n{e}" + ) from e + + def __reduce__(self): + return (self.__class__, (self._name, self.__file__, self._import_structure)) + + +class OptionalDependencyNotAvailable(BaseException): + """Internally used error class for signalling an optional dependency was not found.""" + + +def direct_transformers_import(path: str, file="__init__.py") -> ModuleType: + """Imports transformers directly + + Args: + path (`str`): The path to the source file + file (`str`, *optional*): The file to join with the path. Defaults to "__init__.py". + + Returns: + `ModuleType`: The resulting imported module + """ + name = "transformers" + location = os.path.join(path, file) + spec = importlib.util.spec_from_file_location(name, location, submodule_search_locations=[path]) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module = sys.modules[name] + return module + + +def export(*, backends=()): + """ + This decorator enables two things: + - Attaching a `__backends` tuple to an object to see what are the necessary backends for it + to execute correctly without instantiating it + - The '@export' string is used to dynamically import objects + """ + for backend in backends: + if backend not in BACKENDS_MAPPING: + raise ValueError(f"Backend should be defined in the BACKENDS_MAPPING. Offending backend: {backend}") + + if not isinstance(backends, tuple): + raise ValueError("Backends should be a tuple.") + + def inner_fn(fun): + fun.__backends = backends + return fun + + return inner_fn + + +BASE_FILE_REQUIREMENTS = { + lambda e: "modeling_tf_" in e: ("tf",), + lambda e: "modeling_flax_" in e: ("flax",), + lambda e: "modeling_" in e: ("torch",), + lambda e: e.startswith("tokenization_") and e.endswith("_fast"): ("tokenizers",), +} + + +def fetch__all__(file_content): + """ + Returns the content of the __all__ variable in the file content. + Returns None if not defined, otherwise returns a list of strings. 
+ """ + + if "__all__" not in file_content: + return [] + + lines = file_content.splitlines() + for index, line in enumerate(lines): + if line.startswith("__all__"): + start_index = index + + lines = lines[start_index:] + + if not lines[0].startswith("__all__"): + raise ValueError( + "fetch__all__ accepts a list of lines, with the first line being the __all__ variable declaration" + ) + + # __all__ is defined on a single line + if lines[0].endswith("]"): + return [obj.strip("\"' ") for obj in lines[0].split("=")[1].strip(" []").split(",")] + + # __all__ is defined on multiple lines + else: + _all = [] + for __all__line_index in range(1, len(lines)): + if lines[__all__line_index].strip() == "]": + return _all + else: + _all.append(lines[__all__line_index].strip("\"', ")) + + return _all + + +@lru_cache() +def create_import_structure_from_path(module_path): + """ + This method takes the path to a file/a folder and returns the import structure. + If a file is given, it will return the import structure of the parent folder. + + Import structures are designed to be digestible by `_LazyModule` objects. They are + created from the __all__ definitions in each files as well as the `@export` decorators + above methods and objects. + + The import structure allows explicit display of the required backends for a given object. + These backends are specified in two ways: + + 1. Through their `@export`, if they are exported with that decorator. This `@export` decorator + accepts a `backend` tuple kwarg mentioning which backends are required to run this object. + + 2. If an object is defined in a file with "default" backends, it will have, at a minimum, this + backend specified. The default backends are defined according to the filename: + + - If a file is named like `modeling_*.py`, it will have a `torch` backend + - If a file is named like `modeling_tf_*.py`, it will have a `tf` backend + - If a file is named like `modeling_flax_*.py`, it will have a `flax` backend + - If a file is named like `tokenization_*_fast.py`, it will have a `tokenizers` backend + + Backends serve the purpose of displaying a clear error message to the user in case the backends are not installed. + Should an object be imported without its required backends being in the environment, any attempt to use the + object will raise an error mentioning which backend(s) should be added to the environment in order to use + that object. 
+ + Here's an example of an input import structure at the src.transformers.models level: + + { + 'albert': { + frozenset(): { + 'configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'} + }, + frozenset({'tokenizers'}): { + 'tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + }, + 'align': { + frozenset(): { + 'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'processing_align': {'AlignProcessor'} + }, + }, + 'altclip': { + frozenset(): { + 'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'processing_altclip': {'AltCLIPProcessor'}, + } + } + } + """ + import_structure = {} + if os.path.isdir(module_path): + directory = module_path + adjacent_modules = [] + + for f in os.listdir(module_path): + if f != "__pycache__" and os.path.isdir(os.path.join(module_path, f)): + import_structure[f] = create_import_structure_from_path(os.path.join(module_path, f)) + + elif not os.path.isdir(os.path.join(directory, f)): + adjacent_modules.append(f) + + else: + directory = os.path.dirname(module_path) + adjacent_modules = [f for f in os.listdir(directory) if not os.path.isdir(os.path.join(directory, f))] + + # We're only taking a look at files different from __init__.py + # We could theoretically export things directly from the __init__.py + # files, but this is not supported at this time. + if "__init__.py" in adjacent_modules: + adjacent_modules.remove("__init__.py") + + # Modular files should not be imported + def find_substring(substring, list_): + return any(substring in x for x in list_) + + if find_substring("modular_", adjacent_modules) and find_substring("modeling_", adjacent_modules): + adjacent_modules = [module for module in adjacent_modules if "modular_" not in module] + + module_requirements = {} + for module_name in adjacent_modules: + # Only modules ending in `.py` are accepted here. + if not module_name.endswith(".py"): + continue + + with open(os.path.join(directory, module_name), encoding="utf-8") as f: + file_content = f.read() + + # Remove the .py suffix + module_name = module_name[:-3] + + previous_line = "" + previous_index = 0 + + # Some files have some requirements by default. + # For example, any file named `modeling_tf_xxx.py` + # should have TensorFlow as a required backend. + base_requirements = () + for string_check, requirements in BASE_FILE_REQUIREMENTS.items(): + if string_check(module_name): + base_requirements = requirements + break + + # Objects that have a `@export` assigned to them will get exported + # with the backends specified in the decorator as well as the file backends. + exported_objects = set() + if "@export" in file_content: + lines = file_content.split("\n") + for index, line in enumerate(lines): + # This allows exporting items with other decorators. We'll take a look + # at the line that follows at the same indentation level. + if line.startswith((" ", "\t", "@", ")")) and not line.startswith("@export"): + continue + + # Skipping line enables putting whatever we want between the + # export() call and the actual class/method definition. + # This is what enables having # Copied from statements, docs, etc. 
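+ # The branches below recover the backends tuple from the @export(...) call, whether it is written on one line or spread across several lines.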
+ skip_line = False + + if "@export" in previous_line: + skip_line = False + + # Backends are defined on the same line as export + if "backends" in previous_line: + backends_string = previous_line.split("backends=")[1].split("(")[1].split(")")[0] + backends = tuple(sorted([b.strip("'\",") for b in backends_string.split(", ") if b])) + + # Backends are defined in the lines following export, for example such as: + # @export( + # backends=( + # "sentencepiece", + # "torch", + # "tf", + # ) + # ) + # + # or + # + # @export( + # backends=( + # "sentencepiece", "tf" + # ) + # ) + elif "backends" in lines[previous_index + 1]: + backends = [] + for backend_line in lines[previous_index:index]: + if "backends" in backend_line: + backend_line = backend_line.split("=")[1] + if '"' in backend_line or "'" in backend_line: + if ", " in backend_line: + backends.extend(backend.strip("()\"', ") for backend in backend_line.split(", ")) + else: + backends.append(backend_line.strip("()\"', ")) + + # If the line is only a ')', then we reached the end of the backends and we break. + if backend_line.strip() == ")": + break + backends = tuple(backends) + + # No backends are registered for export + else: + backends = () + + backends = frozenset(backends + base_requirements) + if backends not in module_requirements: + module_requirements[backends] = {} + if module_name not in module_requirements[backends]: + module_requirements[backends][module_name] = set() + + if not line.startswith("class") and not line.startswith("def"): + skip_line = True + else: + start_index = 6 if line.startswith("class") else 4 + object_name = line[start_index:].split("(")[0].strip(":") + module_requirements[backends][module_name].add(object_name) + exported_objects.add(object_name) + + if not skip_line: + previous_line = line + previous_index = index + + # All objects that are in __all__ should be exported by default. + # These objects are exported with the file backends. + if "__all__" in file_content: + for _all_object in fetch__all__(file_content): + if _all_object not in exported_objects: + backends = frozenset(base_requirements) + if backends not in module_requirements: + module_requirements[backends] = {} + if module_name not in module_requirements[backends]: + module_requirements[backends][module_name] = set() + + module_requirements[backends][module_name].add(_all_object) + + import_structure = {**module_requirements, **import_structure} + return import_structure + + +def spread_import_structure(nested_import_structure): + """ + This method takes as input an unordered import structure and brings the required backends at the top-level, + aggregating modules and objects under their required backends. 
+ + Here's an example of an input import structure at the src.transformers.models level: + + { + 'albert': { + frozenset(): { + 'configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'} + }, + frozenset({'tokenizers'}): { + 'tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + }, + 'align': { + frozenset(): { + 'configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'processing_align': {'AlignProcessor'} + }, + }, + 'altclip': { + frozenset(): { + 'configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'processing_altclip': {'AltCLIPProcessor'}, + } + } + } + + Here's an example of an output import structure at the src.transformers.models level: + + { + frozenset({'tokenizers'}): { + 'albert.tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + frozenset(): { + 'albert.configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'}, + 'align.processing_align': {'AlignProcessor'}, + 'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'altclip.processing_altclip': {'AltCLIPProcessor'} + } + } + + """ + + def propagate_frozenset(unordered_import_structure): + tuple_first_import_structure = {} + for _key, _value in unordered_import_structure.items(): + if not isinstance(_value, dict): + tuple_first_import_structure[_key] = _value + + elif any(isinstance(v, frozenset) for v in _value.keys()): + # Here we want to switch around key and v + for k, v in _value.items(): + if isinstance(k, frozenset): + if k not in tuple_first_import_structure: + tuple_first_import_structure[k] = {} + tuple_first_import_structure[k][_key] = v + + else: + tuple_first_import_structure[_key] = propagate_frozenset(_value) + + return tuple_first_import_structure + + def flatten_dict(_dict, previous_key=None): + items = [] + for _key, _value in _dict.items(): + _key = f"{previous_key}.{_key}" if previous_key is not None else _key + if isinstance(_value, dict): + items.extend(flatten_dict(_value, _key).items()) + else: + items.append((_key, _value)) + return dict(items) + + # The tuples contain the necessary backends. We want these first, so we propagate them up the + # import structure. + ordered_import_structure = nested_import_structure + + # 6 is a number that gives us sufficient depth to go through all files and foreseeable folder depths + # while not taking too long to parse. + for i in range(6): + ordered_import_structure = propagate_frozenset(ordered_import_structure) + + # We then flatten the dict so that it references a module path. + flattened_import_structure = {} + for key, value in ordered_import_structure.copy().items(): + if isinstance(key, str): + del ordered_import_structure[key] + else: + flattened_import_structure[key] = flatten_dict(value) + + return flattened_import_structure + + +def define_import_structure(module_path: str) -> IMPORT_STRUCTURE_T: + """ + This method takes a module_path as input and creates an import structure digestible by a _LazyModule. 
+ + Here's an example of an output import structure at the src.transformers.models level: + + { + frozenset({'tokenizers'}): { + 'albert.tokenization_albert_fast': {'AlbertTokenizerFast'} + }, + frozenset(): { + 'albert.configuration_albert': {'AlbertConfig', 'AlbertOnnxConfig'}, + 'align.processing_align': {'AlignProcessor'}, + 'align.configuration_align': {'AlignConfig', 'AlignTextConfig', 'AlignVisionConfig'}, + 'altclip.configuration_altclip': {'AltCLIPConfig', 'AltCLIPTextConfig', 'AltCLIPVisionConfig'}, + 'altclip.processing_altclip': {'AltCLIPProcessor'} + } + } + + The import structure is a dict defined with frozensets as keys, and dicts of strings to sets of objects. + """ + import_structure = create_import_structure_from_path(module_path) + return spread_import_structure(import_structure) \ No newline at end of file diff --git a/gptqmodel/integration/transformers/utils/quantization_config.py b/gptqmodel/integration/transformers/utils/quantization_config.py new file mode 100644 index 000000000..51ec9a294 --- /dev/null +++ b/gptqmodel/integration/transformers/utils/quantization_config.py @@ -0,0 +1,1378 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import importlib.metadata +import json +import os +from dataclasses import dataclass +from enum import Enum +from inspect import Parameter, signature +from typing import Any, Dict, List, Optional, Union + +from packaging import version + +from ..utils import ( + is_auto_awq_available, + is_gptqmodel_available, + is_hqq_available, + is_torch_available, + is_torchao_available, + logging, +) +from .import_utils import is_auto_gptq_available + + +if is_torch_available(): + import torch + +logger = logging.get_logger(__name__) + + +class QuantizationMethod(str, Enum): + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + QUANTO = "quanto" + EETQ = "eetq" + HQQ = "hqq" + COMPRESSED_TENSORS = "compressed-tensors" + FBGEMM_FP8 = "fbgemm_fp8" + TORCHAO = "torchao" + BITNET = "bitnet" + + +class AWQLinearVersion(str, Enum): + GEMM = "gemm" + GEMV = "gemv" + EXLLAMA = "exllama" + IPEX = "ipex" + + @staticmethod + def from_str(version: str): + version = version.lower() + if version == "gemm": + return AWQLinearVersion.GEMM + elif version == "gemv": + return AWQLinearVersion.GEMV + elif version == "exllama": + return AWQLinearVersion.EXLLAMA + elif version == "ipex": + return AWQLinearVersion.IPEX + else: + raise ValueError(f"Unknown AWQLinearVersion {version}") + + +class AwqBackendPackingMethod(str, Enum): + AUTOAWQ = "autoawq" + LLMAWQ = "llm-awq" + + +@dataclass +class QuantizationConfigMixin: + """ + Mixin class for quantization config + """ + + quant_method: QuantizationMethod + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`QuantizationConfigMixin`] from a Python dictionary of parameters. 
+ + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. + """ + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `QuantizationConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return copy.deepcopy(self.__dict__) + + def __iter__(self): + """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin""" + for attr, value in copy.deepcopy(self.__dict__).items(): + yield attr, value + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + + Args: + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` + is serialized to JSON string. + + Returns: + `str`: String containing all the attributes that make up this configuration instance in JSON format. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes, + returning all the unused kwargs. + + Args: + kwargs (`Dict[str, Any]`): + Dictionary of attributes to tentatively update this class. + + Returns: + `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # Remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + +@dataclass +class HqqConfig(QuantizationConfigMixin): + """ + This is wrapper around hqq's BaseQuantizeConfig. + + Args: + nbits (`int`, *optional*, defaults to 4): + Number of bits. Supported values are (8, 4, 3, 2, 1). 
+ group_size (`int`, *optional*, defaults to 64): + Group-size value. Supported values are any value that is divisble by weight.shape[axis]). + view_as_float (`bool`, *optional*, defaults to `False`): + View the quantized weight as float (used in distributed training) if set to `True`. + axis (`Optional[int]`, *optional*): + Axis along which grouping is performed. Supported values are 0 or 1. + dynamic_config (dict, *optional*): + Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config. + If set, each layer specified by its id will use its dedicated quantization configuration. + skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`): + List of `nn.Linear` layers to skip. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + nbits: int = 4, + group_size: int = 64, + view_as_float: bool = False, + axis: Optional[int] = None, + dynamic_config: Optional[dict] = None, + skip_modules: List[str] = ["lm_head"], + **kwargs, + ): + if is_hqq_available(): + from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig + + for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]: + if deprecated_key in kwargs: + logger.info( + deprecated_key + " is deprecated. This parameter will be ignored in quantization settings." + ) + + if axis is None: + axis = 1 + logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.") + + if axis not in [0, 1]: + raise ValueError("Invalid axis value. Only 0 and 1 are allowed.") + + if dynamic_config is not None: + self.quant_config = {} + for key in dynamic_config: + self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key]) + else: + self.quant_config = HQQBaseQuantizeConfig( + **{ + "nbits": nbits, + "group_size": group_size, + "view_as_float": view_as_float, + "axis": axis, + } + ) + + self.quant_method = QuantizationMethod.HQQ + self.skip_modules = skip_modules + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + pass + + @classmethod + def from_dict(cls, config: Dict[str, Any]): + """ + Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py + """ + instance = cls() + instance.quant_config = config["quant_config"] + instance.skip_modules = config["skip_modules"] + return instance + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + return { + "quant_config": self.quant_config, + "quant_method": self.quant_method, + "skip_modules": self.skip_modules, + } + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. 
+ Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = HqqConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + +@dataclass +class BitsAndBytesConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `bitsandbytes`. + + This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. + + Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`, + then more arguments will be added to this class. + + Args: + load_in_8bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 8-bit quantization with LLM.int8(). + load_in_4bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from + `bitsandbytes`. + llm_int8_threshold (`float`, *optional*, defaults to 6.0): + This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix + Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value + that is above this threshold will be considered an outlier and the operation on those values will be done + in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but + there are some exceptional systematic outliers that are very differently distributed for large models. + These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of + magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, + but a lower threshold might be needed for more unstable models (small models, fine-tuning). + llm_int8_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarily at the last position. For example + for `CausalLM` models, the last `lm_head` is kept in its original `dtype`. + llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): + This flag is used for advanced use cases and users that are aware of this feature. If you want to split + your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use + this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 + operations will not be run on CPU. + llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`): + This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not + have to be converted back and forth for the backward pass. + bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`): + This sets the computational type which might be different than the input type. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. 
+        bnb_4bit_quant_type (`str`, *optional*, defaults to `"fp4"`):
+            This sets the quantization data type in the bnb.nn.Linear4bit layers. Options are the FP4 and NF4 data
+            types, which are specified by `fp4` or `nf4`.
+        bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
+            This flag is used for nested quantization, where the quantization constants from the first quantization
+            are quantized again.
+        bnb_4bit_quant_storage (`torch.dtype` or str, *optional*, defaults to `torch.uint8`):
+            This sets the storage type used to pack the quantized 4-bit params.
+        kwargs (`Dict[str, Any]`, *optional*):
+            Additional parameters from which to initialize the configuration object.
+    """
+
+    def __init__(
+        self,
+        load_in_8bit=False,
+        load_in_4bit=False,
+        llm_int8_threshold=6.0,
+        llm_int8_skip_modules=None,
+        llm_int8_enable_fp32_cpu_offload=False,
+        llm_int8_has_fp16_weight=False,
+        bnb_4bit_compute_dtype=None,
+        bnb_4bit_quant_type="fp4",
+        bnb_4bit_use_double_quant=False,
+        bnb_4bit_quant_storage=None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.BITS_AND_BYTES
+
+        if load_in_4bit and load_in_8bit:
+            raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time")
+
+        self._load_in_8bit = load_in_8bit
+        self._load_in_4bit = load_in_4bit
+        self.llm_int8_threshold = llm_int8_threshold
+        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
+        self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
+        self.bnb_4bit_quant_type = bnb_4bit_quant_type
+        self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
+
+        if bnb_4bit_compute_dtype is None:
+            self.bnb_4bit_compute_dtype = torch.float32
+        elif isinstance(bnb_4bit_compute_dtype, str):
+            self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
+        elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
+            self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
+        else:
+            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
+
+        if bnb_4bit_quant_storage is None:
+            self.bnb_4bit_quant_storage = torch.uint8
+        elif isinstance(bnb_4bit_quant_storage, str):
+            if bnb_4bit_quant_storage not in ["float16", "float32", "int8", "uint8", "float64", "bfloat16"]:
+                raise ValueError(
+                    "`bnb_4bit_quant_storage` must be a valid string (one of 'float16', 'float32', 'int8', 'uint8', 'float64', 'bfloat16') "
+                )
+            self.bnb_4bit_quant_storage = getattr(torch, bnb_4bit_quant_storage)
+        elif isinstance(bnb_4bit_quant_storage, torch.dtype):
+            self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
+        else:
+            raise ValueError("bnb_4bit_quant_storage must be a string or a torch.dtype")
+
+        if kwargs:
+            logger.warning(f"Unused kwargs: {list(kwargs.keys())}.
These kwargs are not used in {self.__class__}.") + + self.post_init() + + @property + def load_in_4bit(self): + return self._load_in_4bit + + @load_in_4bit.setter + def load_in_4bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_4bit must be a boolean") + + if self.load_in_8bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_4bit = value + + @property + def load_in_8bit(self): + return self._load_in_8bit + + @load_in_8bit.setter + def load_in_8bit(self, value: bool): + if not isinstance(value, bool): + raise TypeError("load_in_8bit must be a boolean") + + if self.load_in_4bit and value: + raise ValueError("load_in_4bit and load_in_8bit are both True, but only one can be used at the same time") + self._load_in_8bit = value + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.load_in_4bit, bool): + raise TypeError("load_in_4bit must be a boolean") + + if not isinstance(self.load_in_8bit, bool): + raise TypeError("load_in_8bit must be a boolean") + + if not isinstance(self.llm_int8_threshold, float): + raise TypeError("llm_int8_threshold must be a float") + + if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + raise TypeError("llm_int8_skip_modules must be a list of strings") + if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool): + raise TypeError("llm_int8_enable_fp32_cpu_offload must be a boolean") + + if not isinstance(self.llm_int8_has_fp16_weight, bool): + raise TypeError("llm_int8_has_fp16_weight must be a boolean") + + if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise TypeError("bnb_4bit_compute_dtype must be torch.dtype") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise TypeError("bnb_4bit_quant_type must be a string") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise TypeError("bnb_4bit_use_double_quant must be a boolean") + + if self.load_in_4bit and not version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse( + "0.39.0" + ): + raise ValueError( + "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version" + ) + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return self.load_in_8bit or self.load_in_4bit + + def quantization_method(self): + r""" + This method returns the quantization method used for the model. If the model is not quantizable, it returns + `None`. + """ + if self.load_in_8bit: + return "llm_int8" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4": + return "fp4" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4": + return "nf4" + else: + return None + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. 
+        """
+        output = copy.deepcopy(self.__dict__)
+        output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
+        output["bnb_4bit_quant_storage"] = str(output["bnb_4bit_quant_storage"]).split(".")[1]
+        output["load_in_4bit"] = self.load_in_4bit
+        output["load_in_8bit"] = self.load_in_8bit
+
+        return output
+
+    def __repr__(self):
+        config_dict = self.to_dict()
+        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
+
+    def to_diff_dict(self) -> Dict[str, Any]:
+        """
+        Removes all attributes from the config which correspond to the default config attributes, for better
+        readability, and serializes to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        config_dict = self.to_dict()
+
+        # get the default config dict
+        default_config_dict = BitsAndBytesConfig().to_dict()
+
+        serializable_config_dict = {}
+
+        # only serialize values that differ from the default config
+        for key, value in config_dict.items():
+            if value != default_config_dict[key]:
+                serializable_config_dict[key] = value
+
+        return serializable_config_dict
+
+
+class ExllamaVersion(int, Enum):
+    ONE = 1
+    TWO = 2
+
+
+@dataclass
+class GPTQConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for all the possible attributes and features that you can play with for a model that has
+    been loaded using the `optimum` API for GPTQ quantization, relying on the auto_gptq or gptqmodel backend.
+
+    Args:
+        bits (`int`):
+            The number of bits to quantize to; supported values are (2, 3, 4, 8).
+        tokenizer (`str` or `PreTrainedTokenizerBase`, *optional*):
+            The tokenizer used to process the dataset. You can pass either:
+                - A custom tokenizer object.
+                - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+                - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+                  using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+        dataset (`Union[List[str], str]`, *optional*):
+            The dataset used for quantization. You can provide your own dataset as a list of strings, or use one of
+            the original datasets used in the GPTQ paper: ['wikitext2', 'c4', 'c4-new'].
+        group_size (`int`, *optional*, defaults to 128):
+            The group size to use for quantization. The recommended value is 128, and -1 uses per-column quantization.
+        damp_percent (`float`, *optional*, defaults to 0.1):
+            The percent of the average Hessian diagonal to use for dampening. The recommended value is 0.1.
+        desc_act (`bool`, *optional*, defaults to `False`):
+            Whether to quantize columns in order of decreasing activation size. Setting it to `False` can
+            significantly speed up inference, but the perplexity may become slightly worse. Also known as act-order.
+        sym (`bool`, *optional*, defaults to `True`):
+            Whether to use symmetric quantization.
+        true_sequential (`bool`, *optional*, defaults to `True`):
+            Whether to perform sequential quantization even within a single Transformer block. Instead of quantizing
+            the entire block at once, we perform layer-wise quantization. As a result, each layer undergoes
+            quantization using inputs that have passed through the previously quantized layers.
+        checkpoint_format (`str`, *optional*, defaults to `"gptq"`):
+            GPTQ weight format. `gptq` (v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only.
+        meta (`Dict[str, any]`, *optional*):
+            Properties, such as tooling:version, that do not directly contribute to quantization or quantized
+            inference are stored in meta, e.g. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"].
+        backend (`str`, *optional*):
+            Controls which GPTQ kernel is used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For
+            auto-gptq, the only valid values are `None` and `auto_trainable`. See the gptqmodel backend reference:
+            https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
+        use_cuda_fp16 (`bool`, *optional*, defaults to `False`):
+            Whether to use the optimized CUDA kernel for fp16 models. The model needs to be in fp16. Auto-gptq only.
+        model_seqlen (`int`, *optional*):
+            The maximum sequence length that the model can take.
+        block_name_to_quantize (`str`, *optional*):
+            The transformers block name to quantize. If None, we will infer the block name using common patterns
+            (e.g. model.layers).
+        module_name_preceding_first_block (`List[str]`, *optional*):
+            The layers preceding the first Transformer block.
+        batch_size (`int`, *optional*, defaults to 1):
+            The batch size used when processing the dataset.
+        pad_token_id (`int`, *optional*):
+            The pad token id. Needed to prepare the dataset when `batch_size` > 1.
+        use_exllama (`bool`, *optional*):
+            Whether to use the exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
+        max_input_length (`int`, *optional*):
+            The maximum input length. This is needed to initialize a buffer that depends on the maximum expected
+            input length. It is specific to the exllama backend with act-order.
+        exllama_config (`Dict[str, Any]`, *optional*):
+            The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
+            to `{"version": 1}` if unset.
+        cache_block_outputs (`bool`, *optional*, defaults to `True`):
+            Whether to cache block outputs to reuse as inputs for the succeeding block.
+        modules_in_block_to_quantize (`List[List[str]]`, *optional*):
+            List of lists of module names to quantize in the specified block. This argument is useful to exclude
+            certain linear modules from being quantized. The block to quantize can be specified by setting
+            `block_name_to_quantize`. We will quantize each list sequentially. If not set, we will quantize all linear
+            layers. Example: `modules_in_block_to_quantize = [["self_attn.k_proj", "self_attn.v_proj",
+            "self_attn.q_proj"], ["self_attn.o_proj"]]`. In this example, we will first quantize the q, k, v layers
+            simultaneously since they are independent. Then, we will quantize the `self_attn.o_proj` layer with the
+            q, k, v layers already quantized. This way, we get better results since it reflects the real input
+            `self_attn.o_proj` will get when the model is quantized.
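+
+    Example (an illustrative sketch of on-the-fly quantization; the model id and calibration dataset below are
+    placeholders, and a GPTQ backend such as gptqmodel or auto-gptq, together with `optimum`, needs to be installed):
+
+    ```python
+    from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig
+
+    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+    gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer=tokenizer)
+    # quantization happens on the fly while loading the full-precision checkpoint
+    model = AutoModelForCausalLM.from_pretrained(
+        "facebook/opt-125m", device_map="auto", quantization_config=gptq_config
+    )
+    ```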
+ """ + + def __init__( + self, + bits: int, + tokenizer: Any = None, + dataset: Optional[Union[List[str], str]] = None, + group_size: int = 128, + damp_percent: float = 0.1, + desc_act: bool = False, + sym: bool = True, + true_sequential: bool = True, + checkpoint_format: str = "gptq", + meta: Optional[Dict[str, any]] = None, + backend: Optional[str] = None, + use_cuda_fp16: bool = False, + model_seqlen: Optional[int] = None, + block_name_to_quantize: Optional[str] = None, + module_name_preceding_first_block: Optional[List[str]] = None, + batch_size: int = 1, + pad_token_id: Optional[int] = None, + use_exllama: Optional[bool] = None, + max_input_length: Optional[int] = None, + exllama_config: Optional[Dict[str, Any]] = None, + cache_block_outputs: bool = True, + modules_in_block_to_quantize: Optional[List[List[str]]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.GPTQ + self.bits = bits + self.tokenizer = tokenizer + self.dataset = dataset + self.group_size = group_size + self.damp_percent = damp_percent + self.desc_act = desc_act + self.sym = sym + self.true_sequential = true_sequential + self.checkpoint_format = checkpoint_format.lower() + self.meta = meta + self.backend = backend.lower() if isinstance(backend, str) else backend + self.use_cuda_fp16 = use_cuda_fp16 + self.model_seqlen = model_seqlen + self.block_name_to_quantize = block_name_to_quantize + self.module_name_preceding_first_block = module_name_preceding_first_block + self.batch_size = batch_size + self.pad_token_id = pad_token_id + self.use_exllama = use_exllama + self.max_input_length = max_input_length + self.exllama_config = exllama_config + self.disable_exllama = kwargs.pop("disable_exllama", None) + self.cache_block_outputs = cache_block_outputs + self.modules_in_block_to_quantize = modules_in_block_to_quantize + self.post_init() + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + if self.bits not in [2, 3, 4, 8]: + raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}") + if self.group_size != -1 and self.group_size <= 0: + raise ValueError("group_size must be greater than 0 or equal to -1") + if not (0 < self.damp_percent < 1): + raise ValueError("damp_percent must between 0 and 1.") + if self.dataset is not None: + if isinstance(self.dataset, str): + if self.dataset in ["ptb", "ptb-new"]: + raise ValueError( + f"""{self.dataset} dataset was deprecated. You can only choose between + ['wikitext2','c4','c4-new']""" + ) + if self.dataset not in ["wikitext2", "c4", "c4-new"]: + raise ValueError( + f"""You have entered a string value for dataset. 
You can only choose between + ['wikitext2','c4','c4-new'], but we found {self.dataset}""" + ) + elif not isinstance(self.dataset, list): + raise ValueError( + f"""dataset needs to be either a list of string or a value in + ['wikitext2','c4','c4-new'], but we found {self.dataset}""" + ) + + # make sure backend is back/forward compatible with both gptqmodel (full) and auto-gptq (partial) + if is_gptqmodel_available(): + # convert auto-gptq control into gptqmodel backend + if self.backend is None: + self.backend = "auto_trainable" if not self.use_exllama else "auto" + else: + # convert gptqmodel backend `auto_trainable` into auto-gptq control + if self.backend == "auto_trainable": + self.use_exllama = False + + # auto-gptq specific kernel control logic + if self.disable_exllama is None and self.use_exllama is None: + # New default behaviour + self.use_exllama = True + elif self.disable_exllama is not None and self.use_exllama is None: + # Follow pattern of old config + logger.warning( + "Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`." + "The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file." + ) + self.use_exllama = not self.disable_exllama + self.disable_exllama = None + elif self.disable_exllama is not None and self.use_exllama is not None: + # Only happens if user explicitly passes in both arguments + raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`") + + if self.exllama_config is None: + self.exllama_config = {"version": ExllamaVersion.ONE} + else: + if "version" not in self.exllama_config: + raise ValueError("`exllama_config` needs to have a `version` key.") + elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]: + exllama_version = self.exllama_config["version"] + raise ValueError( + f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}" + ) + + if self.bits == 4 and self.use_exllama: + if self.exllama_config["version"] == ExllamaVersion.ONE: + logger.info( + "You have activated exllama backend. Note that you can get better inference " + "speed using exllamav2 kernel by setting `exllama_config`." + ) + elif self.exllama_config["version"] == ExllamaVersion.TWO: + if is_auto_gptq_available(): + optimum_version = version.parse(importlib.metadata.version("optimum")) + autogptq_version = version.parse(importlib.metadata.version("auto_gptq")) + if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"): + raise ValueError( + f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}" + ) + if self.modules_in_block_to_quantize is not None: + optimum_version = version.parse(importlib.metadata.version("optimum")) + if optimum_version < version.parse("1.15.0"): + raise ValueError( + "You current version of `optimum` does not support `modules_in_block_to_quantize` quantization argument, please upgrade `optimum` package to a version superior than 1.15.0 ." 
+                )
+
+    def to_dict(self):
+        config_dict = super().to_dict()
+        config_dict.pop("disable_exllama", None)
+        return config_dict
+
+    def to_dict_optimum(self):
+        """
+        Get a dict compatible with the optimum GPTQ config.
+        """
+        quant_dict = self.to_dict()
+        # make it compatible with optimum config
+        quant_dict["disable_exllama"] = not self.use_exllama
+        return quant_dict
+
+    @classmethod
+    def from_dict_optimum(cls, config_dict):
+        """
+        Instantiate the class from a dict compatible with the optimum GPTQ config.
+        """
+
+        if "disable_exllama" in config_dict:
+            config_dict["use_exllama"] = not config_dict["disable_exllama"]
+            # switch to None to not trigger the warning
+            config_dict["disable_exllama"] = None
+
+        config = cls(**config_dict)
+        return config
+
+
+@dataclass
+class AwqConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for all the possible attributes and features that you can play with for a model that has
+    been loaded using AWQ quantization from the `auto-awq` library, relying on the auto_awq backend.
+
+    Args:
+        bits (`int`, *optional*, defaults to 4):
+            The number of bits to quantize to.
+        group_size (`int`, *optional*, defaults to 128):
+            The group size to use for quantization. The recommended value is 128, and -1 uses per-column quantization.
+        zero_point (`bool`, *optional*, defaults to `True`):
+            Whether to use zero point quantization.
+        version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
+            The version of the quantization algorithm to use. GEMM is better for a big batch size (e.g. >= 8);
+            otherwise, GEMV is better (e.g. < 8). GEMM models are compatible with Exllama kernels.
+        backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
+            The quantization backend. Some models might be quantized using the `llm-awq` backend. This is useful for
+            users that quantize their own models using the `llm-awq` library.
+        do_fuse (`bool`, *optional*, defaults to `False`):
+            Whether to fuse attention and MLP layers together for faster inference.
+        fuse_max_seq_len (`int`, *optional*):
+            The maximum sequence length to generate when using fusing.
+        modules_to_fuse (`dict`, *optional*, defaults to `None`):
+            Overwrite the natively supported fusing scheme with the one specified by the user.
+        modules_to_not_convert (`list`, *optional*, defaults to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require some modules to
+            be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). Note that
+            you cannot quantize directly with transformers; please refer to the `AutoAWQ` documentation for
+            quantizing HF models.
+        exllama_config (`Dict[str, Any]`, *optional*):
+            You can specify the version of the exllama kernel through the `version` key, the maximum sequence length
+            through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key. Defaults to
+            `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
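+
+    Example (an illustrative sketch of loading a pre-quantized checkpoint with fused modules; the checkpoint name is
+    a placeholder for any AWQ-quantized model, and `autoawq` needs to be installed):
+
+    ```python
+    from transformers import AutoModelForCausalLM, AwqConfig
+
+    # load an already AWQ-quantized checkpoint and enable fused modules for faster inference
+    quantization_config = AwqConfig(bits=4, do_fuse=True, fuse_max_seq_len=512)
+    model = AutoModelForCausalLM.from_pretrained(
+        "TheBloke/Mistral-7B-OpenOrca-AWQ", quantization_config=quantization_config
+    )
+    ```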
+ """ + + def __init__( + self, + bits: int = 4, + group_size: int = 128, + zero_point: bool = True, + version: AWQLinearVersion = AWQLinearVersion.GEMM, + backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ, + do_fuse: Optional[bool] = None, + fuse_max_seq_len: Optional[int] = None, + modules_to_fuse: Optional[dict] = None, + modules_to_not_convert: Optional[List] = None, + exllama_config: Optional[Dict[str, int]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.AWQ + + self.bits = bits + self.group_size = group_size + self.zero_point = zero_point + self.version = version + self.backend = backend + self.fuse_max_seq_len = fuse_max_seq_len + self.modules_to_not_convert = modules_to_not_convert + self.exllama_config = exllama_config + + self.modules_to_fuse = modules_to_fuse + if do_fuse is None: + self.do_fuse = modules_to_fuse is not None and len(modules_to_fuse) > 0 + else: + self.do_fuse = do_fuse + self.fuse_max_seq_len = fuse_max_seq_len + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]: + raise ValueError( + f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}" + ) + + self.version = AWQLinearVersion.from_str(self.version) + if self.version not in [ + AWQLinearVersion.GEMM, + AWQLinearVersion.GEMV, + AWQLinearVersion.EXLLAMA, + AWQLinearVersion.IPEX, + ]: + raise ValueError( + f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA, AWQLinearVersion.IPEX] - not recognized version {self.version}" + ) + + if self.backend == AwqBackendPackingMethod.LLMAWQ: + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + if major < 8: + raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0") + + if self.do_fuse and self.fuse_max_seq_len is None: + raise ValueError( + "You cannot enable fused modules without specifying a `fuse_max_seq_len`, make sure to pass a valid `fuse_max_seq_len` for your usecase" + ) + + if self.do_fuse: + awq_version_supports_fusing = False + MIN_AWQ_VERSION = "0.1.7" + if is_auto_awq_available(): + awq_version_supports_fusing = version.parse(importlib.metadata.version("autoawq")) >= version.parse( + MIN_AWQ_VERSION + ) + + if not awq_version_supports_fusing: + raise ValueError( + f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}." + ) + + if self.modules_to_not_convert is not None: + awq_version_supports_non_conversion = False + MIN_AWQ_VERSION = "0.1.8" + if is_auto_awq_available(): + awq_version_supports_non_conversion = version.parse( + importlib.metadata.version("autoawq") + ) >= version.parse(MIN_AWQ_VERSION) + + if not awq_version_supports_non_conversion: + raise ValueError( + f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}." 
+ ) + + if self.do_fuse and self.modules_to_fuse is not None: + required_keys = [ + "hidden_size", + "num_attention_heads", + "num_key_value_heads", + "mlp", + "attention", + "layernorm", + "use_alibi", + ] + if not all(key in self.modules_to_fuse for key in required_keys): + raise ValueError( + f"Required fields are missing in the fusing mapping, required fields are {required_keys}" + ) + + if self.version == AWQLinearVersion.EXLLAMA: + awq_version_supports_exllama = False + MIN_AWQ_VERSION = "0.2.0" + if is_auto_awq_available(): + awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse( + MIN_AWQ_VERSION + ) + + if not awq_version_supports_exllama: + raise ValueError( + f"You current version of `autoawq` does not support exllama backend, " + f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}." + ) + + if self.exllama_config is None: + self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8} + else: + if "version" not in self.exllama_config: + raise ValueError("`exllama_config` needs to have a `version` key.") + elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]: + exllama_version = self.exllama_config["version"] + raise ValueError( + f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}" + ) + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len", "exllama_config"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict + + +@dataclass +class AqlmConfig(QuantizationConfigMixin): + """ + This is a wrapper class about `aqlm` parameters. + + Args: + in_group_size (`int`, *optional*, defaults to 8): + The group size along the input dimension. + out_group_size (`int`, *optional*, defaults to 1): + The group size along the output dimension. It's recommended to always use 1. + num_codebooks (`int`, *optional*, defaults to 1): + Number of codebooks for the Additive Quantization procedure. + nbits_per_codebook (`int`, *optional*, defaults to 16): + Number of bits encoding a single codebook vector. Codebooks size is 2**nbits_per_codebook. + linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*): + List of full paths of `nn.Linear` weight parameters that shall not be quantized. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. + """ + + def __init__( + self, + in_group_size: int = 8, + out_group_size: int = 1, + num_codebooks: int = 1, + nbits_per_codebook: int = 16, + linear_weights_not_to_quantize: Optional[List[str]] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.AQLM + self.in_group_size = in_group_size + self.out_group_size = out_group_size + self.num_codebooks = num_codebooks + self.nbits_per_codebook = nbits_per_codebook + self.linear_weights_not_to_quantize = linear_weights_not_to_quantize + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. 
+ """ + if not isinstance(self.in_group_size, int): + raise TypeError("in_group_size must be a float") + if not isinstance(self.out_group_size, int): + raise TypeError("out_group_size must be a float") + if not isinstance(self.num_codebooks, int): + raise TypeError("num_codebooks must be a float") + if not isinstance(self.nbits_per_codebook, int): + raise TypeError("nbits_per_codebook must be a float") + + if self.linear_weights_not_to_quantize is not None and not isinstance( + self.linear_weights_not_to_quantize, list + ): + raise ValueError("linear_weights_not_to_quantize must be a list of strings") + + if self.linear_weights_not_to_quantize is None: + self.linear_weights_not_to_quantize = [] + + +@dataclass +class QuantoConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `quanto`. + + Args: + weights (`str`, *optional*, defaults to `"int8"`): + The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2") + activations (`str`, *optional*): + The target dtype for the activations after quantization. Supported values are (None,"int8","float8") + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers). + """ + + def __init__( + self, + weights="int8", + activations=None, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.QUANTO + self.weights = weights + self.activations = activations + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accepted_weights = ["float8", "int8", "int4", "int2"] + accepted_activations = [None, "int8", "float8"] + if self.weights not in accepted_weights: + raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}") + if self.activations not in accepted_activations: + raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}") + + +@dataclass +class EetqConfig(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `eetq`. + + Args: + weights (`str`, *optional*, defaults to `"int8"`): + The target dtype for the weights. Supported value is only "int8" + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + """ + + def __init__( + self, + weights: str = "int8", + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.EETQ + self.weights = weights + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + accepted_weights = ["int8"] + if self.weights not in accepted_weights: + raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}") + + +class CompressedTensorsConfig(QuantizationConfigMixin): + """ + This is a wrapper class that handles compressed-tensors quantization config options. 
+ It is a wrapper around `compressed_tensors.QuantizationConfig` + Args: + config_groups (`typing.Dict[str, typing.Union[ForwardRef('QuantizationScheme'), typing.List[str]]]`, *optional*): + dictionary mapping group name to a quantization scheme definition + format (`str`, *optional*, defaults to `"dense"`): + format the model is represented as + quantization_status (`QuantizationStatus`, *optional*, defaults to `"initialized"`): + status of model in the quantization lifecycle, ie 'initialized', 'calibration', 'frozen' + kv_cache_scheme (`typing.Union[QuantizationArgs, NoneType]`, *optional*): + specifies quantization of the kv cache. If None, kv cache is not quantized. + global_compression_ratio (`typing.Union[float, NoneType]`, *optional*): + 0-1 float percentage of model compression + ignore (`typing.Union[typing.List[str], NoneType]`, *optional*): + layer names or types to not quantize, supports regex prefixed by 're:' + sparsity_config (`typing.Dict[str, typing.Any]`, *optional*): + configuration for sparsity compression + quant_method (`str`, *optional*, defaults to `"compressed-tensors"`): + do not override, should be compressed-tensors + """ + + def __init__( + self, + config_groups: Dict[str, Union["QuantizationScheme", List[str]]] = None, # noqa: F821 + format: str = "dense", + quantization_status: "QuantizationStatus" = "initialized", # noqa: F821 + kv_cache_scheme: Optional["QuantizationArgs"] = None, # noqa: F821 + global_compression_ratio: Optional[float] = None, + ignore: Optional[List[str]] = None, + sparsity_config: Dict[str, Any] = None, + quant_method: str = "compressed-tensors", + **kwargs, + ): + from compressed_tensors import QuantizationConfig + from compressed_tensors.config import SparsityCompressionConfig + + self.quantization_config = None + self.sparsity_config = None + + # parse from dict to load nested QuantizationScheme objects + if config_groups or kv_cache_scheme: + self.quantization_config = QuantizationConfig.parse_obj( + { + "config_groups": config_groups, + "quant_method": quant_method, + "format": format, + "quantization_status": quantization_status, + "kv_cache_scheme": kv_cache_scheme, + "global_compression_ratio": global_compression_ratio, + "ignore": ignore, + **kwargs, + } + ) + + if sparsity_config: + self.sparsity_config = SparsityCompressionConfig.load_from_registry( + sparsity_config.get("format"), **sparsity_config + ) + + super().__init__(quant_method=QuantizationMethod.COMPRESSED_TENSORS) + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs): + """ + Instantiates a [`CompressedTensorsConfig`] from a Python dictionary of parameters. + Optionally unwraps any args from the nested quantization_config + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`,*optional*, defaults to `False`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters. 
+ """ + + if "quantization_config" in config_dict: + config_dict = dict( + sparsity_config=config_dict.get("sparsity_config"), + **config_dict["quantization_config"], + ) + + return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def to_dict(self) -> Dict[str, Any]: + """ + Quantization config to be added to config.json + + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + quantization_config = {} + if self.quantization_config is not None: + quantization_config = self.quantization_config.dict() + else: + quantization_config["quant_method"] = QuantizationMethod.COMPRESSED_TENSORS + + if self.sparsity_config is not None: + quantization_config["sparsity_config"] = self.sparsity_config.dict() + else: + quantization_config["sparsity_config"] = {} + + return quantization_config + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = CompressedTensorsConfig().to_dict() + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if value != default_config_dict[key]: + serializable_config_dict[key] = value + + return serializable_config_dict + + +@dataclass +class FbgemmFp8Config(QuantizationConfigMixin): + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using fbgemm fp8 quantization. + + Args: + activation_scale_ub (`float`, *optional*, defaults to 1200.0): + The activation scale upper bound. This is used when quantizing the input activation. + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + """ + + def __init__( + self, + activation_scale_ub: float = 1200.0, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.FBGEMM_FP8 + self.activation_scale_ub = activation_scale_ub + self.modules_to_not_convert = modules_to_not_convert + + def get_loading_attributes(self): + attibutes_dict = copy.deepcopy(self.__dict__) + loading_attibutes = ["activation_scale_ub"] + loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes} + return loading_attibutes_dict + + +@dataclass +class TorchAoConfig(QuantizationConfigMixin): + """This is a config class for torchao quantization/sparsity techniques. + + Args: + quant_type (`str`): + The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`. + modules_to_not_convert (`list`, *optional*, default to `None`): + The list of modules to not quantize, useful for quantizing models that explicitly require to have + some modules left in their original precision. + kwargs (`Dict[str, Any]`, *optional*): + The keyword arguments for the chosen type of quantization, for example, int4_weight_only quantization supports two keyword arguments + `group_size` and `inner_k_tiles` currently. 
More API examples and documentation of arguments can be found in + https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques + + Example: + + ```python + quantization_config = TorchAoConfig("int4_weight_only", group_size=32) + # int4_weight_only quant is only working with *torch.bfloat16* dtype right now + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config) + ``` + """ + + def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs): + self.quant_method = QuantizationMethod.TORCHAO + self.quant_type = quant_type + self.modules_to_not_convert = modules_to_not_convert + # when we load from serailized config, "quant_type_kwargs" will be the key + if "quant_type_kwargs" in kwargs: + self.quant_type_kwargs = kwargs["quant_type_kwargs"] + else: + self.quant_type_kwargs = kwargs + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if is_torchao_available(): + if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): + raise ValueError("Requires torchao 0.4.0 version and above") + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + if self.quant_type not in _STR_TO_METHOD.keys(): + raise ValueError( + f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer." + ) + + method = _STR_TO_METHOD[self.quant_type] + sig = signature(method) + all_kwargs = [ + param.name + for param in sig.parameters.values() + if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD] + ] + for k in self.quant_type_kwargs: + if k not in all_kwargs: + raise ValueError( + f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}" + ) + + def _get_torchao_quant_type_to_method(self): + if is_torchao_available(): + from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + ) + + return { + "int4_weight_only": int4_weight_only, + "int8_weight_only": int8_weight_only, + "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight, + } + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) + + def get_apply_tensor_subclass(self): + _STR_TO_METHOD = self._get_torchao_quant_type_to_method() + return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs) + + def __repr__(self): + config_dict = self.to_dict() + return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + + +@dataclass +class BitNetConfig(QuantizationConfigMixin): + def __init__( + self, + modules_to_not_convert: Optional[List] = None, + **kwargs, + ): + self.quant_method = QuantizationMethod.BITNET + self.modules_to_not_convert = modules_to_not_convert + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct + """ + pass \ No newline at end of file From 3e7086311b2e1c5e31f44e44d9abf47a0dcf0e8f Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 15:07:53 +0800 Subject: [PATCH 07/30] fix transformers import --- .../transformers/quantizers/quantizer_gptq.py | 9 +++++---- 
gptqmodel/integration/transformers/utils/import_utils.py | 2 +- .../transformers/utils/quantization_config.py | 6 ++---- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py index 51a5f7df4..2b2f35d7f 100644 --- a/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py +++ b/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py @@ -16,14 +16,15 @@ from packaging import version -from .base import HfQuantizer +from transformers.quantizers.base import HfQuantizer +from gptqmodel.integration.transformers.utils.import_utils import is_gptqmodel_available if TYPE_CHECKING: - from ..modeling_utils import PreTrainedModel + from transformers.modeling_utils import PreTrainedModel -from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging -from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin +from transformers.utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging +from transformers.utils.quantization_config import GPTQConfig, QuantizationConfigMixin if is_torch_available(): diff --git a/gptqmodel/integration/transformers/utils/import_utils.py b/gptqmodel/integration/transformers/utils/import_utils.py index 506ebafa4..1eba515a7 100644 --- a/gptqmodel/integration/transformers/utils/import_utils.py +++ b/gptqmodel/integration/transformers/utils/import_utils.py @@ -32,7 +32,7 @@ from packaging import version -from . import logging +from transformers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/gptqmodel/integration/transformers/utils/quantization_config.py b/gptqmodel/integration/transformers/utils/quantization_config.py index 51ec9a294..b1e667978 100644 --- a/gptqmodel/integration/transformers/utils/quantization_config.py +++ b/gptqmodel/integration/transformers/utils/quantization_config.py @@ -25,16 +25,14 @@ from packaging import version -from ..utils import ( +from transformers.utils import ( is_auto_awq_available, - is_gptqmodel_available, is_hqq_available, is_torch_available, is_torchao_available, logging, ) -from .import_utils import is_auto_gptq_available - +from .import_utils import is_auto_gptq_available, is_gptqmodel_available if is_torch_available(): import torch From 091b594836fd36bd7ed756493e297009bb58308c Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 16:13:54 +0800 Subject: [PATCH 08/30] add patch --- gptqmodel/integration/peft/tuners/adalora/model.py | 2 +- gptqmodel/integration/peft/{tuners => }/utils/__init__.py | 0 gptqmodel/integration/peft/{tuners => }/utils/other.py | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename gptqmodel/integration/peft/{tuners => }/utils/__init__.py (100%) rename gptqmodel/integration/peft/{tuners => }/utils/other.py (100%) diff --git a/gptqmodel/integration/peft/tuners/adalora/model.py b/gptqmodel/integration/peft/tuners/adalora/model.py index 6cfe78158..5dc0d24b1 100644 --- a/gptqmodel/integration/peft/tuners/adalora/model.py +++ b/gptqmodel/integration/peft/tuners/adalora/model.py @@ -30,7 +30,7 @@ ) from peft.utils.integrations import gather_params_ctx -from ..utils import get_gptqmodel_quant_linear +from gptqmodel.integration.peft.utils import get_gptqmodel_quant_linear from ...import_utils import is_gptqmodel_available diff --git a/gptqmodel/integration/peft/tuners/utils/__init__.py 
b/gptqmodel/integration/peft/utils/__init__.py similarity index 100% rename from gptqmodel/integration/peft/tuners/utils/__init__.py rename to gptqmodel/integration/peft/utils/__init__.py diff --git a/gptqmodel/integration/peft/tuners/utils/other.py b/gptqmodel/integration/peft/utils/other.py similarity index 100% rename from gptqmodel/integration/peft/tuners/utils/other.py rename to gptqmodel/integration/peft/utils/other.py From f97211992d7fce09ae3ff00983d86b9402630cb7 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 16:20:56 +0800 Subject: [PATCH 09/30] add patch --- gptqmodel/integration/integration.py | 50 ++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 gptqmodel/integration/integration.py diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py new file mode 100644 index 000000000..39ab4eeeb --- /dev/null +++ b/gptqmodel/integration/integration.py @@ -0,0 +1,50 @@ +from optimum.gptq import quantizer +from .optimum.gptq import quantizer as patched_quantizer +from optimum.utils import import_utils +from .optimum.utils import import_utils as patched_import_utils +from optimum.utils import testing_utils +from .optimum.utils import testing_utils as patched_testing_utils + +from peft import import_utils as import_utils +from peft.tuners.adalora.model import AdaLoraModel +from .peft.tuners.adalora.model import AdaLoraModel as patched_AdaLoraModel +from peft.tuners.lora import gptq +from .peft.tuners.lora import gptq as patched_gptq +from peft.tuners.lora import model +from .peft.tuners.lora import model as patched_model +from .peft import import_utils as patched_import_utils + +from peft.utils import other +from .peft.utils import other as patched_other + +from transformers.quantizers.quantizer_gptq import GptqHfQuantizer +from .transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_GptqHfQuantizer + + +def monkey_patch_peft(): + import_utils.is_gptqmodel_available = patched_import_utils.is_gptqmodel_available + + AdaLoraModel._create_and_replace = patched_AdaLoraModel._create_and_replace + + gptq.dispatch_gptq = patched_gptq.dispatch_gptq + + model.LoraModel = patched_model.LoraModel + + other.get_auto_gptq_quant_linear = patched_other.get_auto_gptq_quant_linear + other.get_gptqmodel_quant_linear = patched_other.get_gptqmodel_quant_linear + + +def monkey_patch_optimum(): + quantizer.is_gptqmodel_available = patched_quantizer.is_gptqmodel_available + quantizer.has_device_more_than_cpu = patched_quantizer.has_device_more_than_cpu + quantizer.GPTQQuantizer = patched_quantizer.GPTQQuantizer + + import_utils._gptqmodel_available = patched_import_utils._gptqmodel_available + import_utils.is_gptqmodel_available = patched_import_utils.is_gptqmodel_available + testing_utils.require_gptq = patched_testing_utils.require_gptq + + +def monkey_patch_transformers(): + GptqHfQuantizer.required_packages = patched_GptqHfQuantizer.required_packages + GptqHfQuantizer.validate_environment = patched_GptqHfQuantizer.validate_environment + GptqHfQuantizer.update_torch_dtype = patched_GptqHfQuantizer.update_torch_dtype From 395912c9596e58be5af4de70237404a25ffa33f8 Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 16:27:43 +0800 Subject: [PATCH 10/30] fix patch --- gptqmodel/integration/__init__.py | 1 - gptqmodel/integration/optimum/utils/testing_utils.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) delete mode 100644 gptqmodel/integration/__init__.py diff --git a/gptqmodel/integration/__init__.py 
b/gptqmodel/integration/__init__.py deleted file mode 100644 index 88638ca39..000000000 --- a/gptqmodel/integration/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .optimum import monkey_patch_gptqmodel_into_transformers diff --git a/gptqmodel/integration/optimum/utils/testing_utils.py b/gptqmodel/integration/optimum/utils/testing_utils.py index 3137d453b..97eba86bb 100644 --- a/gptqmodel/integration/optimum/utils/testing_utils.py +++ b/gptqmodel/integration/optimum/utils/testing_utils.py @@ -25,16 +25,15 @@ import torch -from . import ( +from optimum.utils import ( is_accelerate_available, is_auto_gptq_available, - is_datasets_available, is_diffusers_available, - is_gptqmodel_available, is_sentence_transformers_available, is_timm_available, ) +from gptqmodel.integration.optimum.utils.import_utils import is_datasets_available, is_gptqmodel_available # Used to test the hub USER = "__DUMMY_OPTIMUM_USER__" From 3514852fae21bae25d41c589b6165e1c4b1eff9c Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 16:36:29 +0800 Subject: [PATCH 11/30] fix patch imports --- gptqmodel/integration/integration.py | 14 ++++++------ .../integration/peft/tuners/lora/gptq.py | 6 +++-- .../integration/peft/tuners/lora/model.py | 22 +++++++++---------- 3 files changed, 22 insertions(+), 20 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 39ab4eeeb..99c0f35c2 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -1,18 +1,18 @@ from optimum.gptq import quantizer from .optimum.gptq import quantizer as patched_quantizer -from optimum.utils import import_utils -from .optimum.utils import import_utils as patched_import_utils +from optimum.utils import import_utils as optimum_import_utils +from .optimum.utils import import_utils as patched_optimum_import_utils from optimum.utils import testing_utils from .optimum.utils import testing_utils as patched_testing_utils -from peft import import_utils as import_utils +from peft import import_utils as peft_import_utils +from .peft import import_utils as patched_peft_import_utils from peft.tuners.adalora.model import AdaLoraModel from .peft.tuners.adalora.model import AdaLoraModel as patched_AdaLoraModel from peft.tuners.lora import gptq from .peft.tuners.lora import gptq as patched_gptq from peft.tuners.lora import model from .peft.tuners.lora import model as patched_model -from .peft import import_utils as patched_import_utils from peft.utils import other from .peft.utils import other as patched_other @@ -22,7 +22,7 @@ def monkey_patch_peft(): - import_utils.is_gptqmodel_available = patched_import_utils.is_gptqmodel_available + peft_import_utils.is_gptqmodel_available = patched_peft_import_utils.is_gptqmodel_available AdaLoraModel._create_and_replace = patched_AdaLoraModel._create_and_replace @@ -39,8 +39,8 @@ def monkey_patch_optimum(): quantizer.has_device_more_than_cpu = patched_quantizer.has_device_more_than_cpu quantizer.GPTQQuantizer = patched_quantizer.GPTQQuantizer - import_utils._gptqmodel_available = patched_import_utils._gptqmodel_available - import_utils.is_gptqmodel_available = patched_import_utils.is_gptqmodel_available + optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available + optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available testing_utils.require_gptq = patched_testing_utils.require_gptq diff --git a/gptqmodel/integration/peft/tuners/lora/gptq.py 
b/gptqmodel/integration/peft/tuners/lora/gptq.py index d33bbb2e5..bf9d3f972 100644 --- a/gptqmodel/integration/peft/tuners/lora/gptq.py +++ b/gptqmodel/integration/peft/tuners/lora/gptq.py @@ -18,8 +18,10 @@ from peft.tuners.lora.layer import LoraLayer from peft.tuners.tuners_utils import BaseTunerLayer -from peft.utils import get_auto_gptq_quant_linear, get_gptqmodel_quant_linear -from peft.import_utils import is_gptqmodel_available +from peft.utils import get_auto_gptq_quant_linear + +from gptqmodel.integration.peft.import_utils import is_gptqmodel_available +from gptqmodel.integration.peft.utils import get_gptqmodel_quant_linear class QuantLinear(torch.nn.Module, LoraLayer): diff --git a/gptqmodel/integration/peft/tuners/lora/model.py b/gptqmodel/integration/peft/tuners/lora/model.py index 847f276ec..de0598385 100644 --- a/gptqmodel/integration/peft/tuners/lora/model.py +++ b/gptqmodel/integration/peft/tuners/lora/model.py @@ -45,15 +45,15 @@ from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties from peft.utils.other import get_pattern_key -from .aqlm import dispatch_aqlm -from .awq import dispatch_awq -from .config import LoraConfig -from .eetq import dispatch_eetq -from .gptq import dispatch_gptq -from .hqq import dispatch_hqq -from .layer import Conv2d, LoraLayer, dispatch_default -from .torchao import dispatch_torchao -from .tp_layer import dispatch_megatron +from peft.tuners.lora.aqlm import dispatch_aqlm +from peft.tuners.lora.awq import dispatch_awq +from peft.tuners.lora.config import LoraConfig +from peft.tuners.lora.eetq import dispatch_eetq +from peft.tuners.lora.gptq import dispatch_gptq +from peft.tuners.lora.hqq import dispatch_hqq +from peft.tuners.lora.layer import Conv2d, LoraLayer, dispatch_default +from peft.tuners.lora.torchao import dispatch_torchao +from peft.tuners.lora.tp_layer import dispatch_megatron def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names): @@ -327,12 +327,12 @@ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs): # avoid eager bnb import if is_bnb_available(): - from .bnb import dispatch_bnb_8bit + from peft.tuners.lora.bnb import dispatch_bnb_8bit dispatchers.append(dispatch_bnb_8bit) if is_bnb_4bit_available(): - from .bnb import dispatch_bnb_4bit + from peft.tuners.lora.bnb import dispatch_bnb_4bit dispatchers.append(dispatch_bnb_4bit) From fd37bb3ad81be63a352a60bc7fb872e8271b61cb Mon Sep 17 00:00:00 2001 From: CSY Date: Tue, 10 Dec 2024 16:45:10 +0800 Subject: [PATCH 12/30] add prefix for imports & move to other dirs --- gptqmodel/integration/integration.py | 55 +++++++++---------- gptqmodel/integration/src/optimum/__init__.py | 0 .../integration/src/optimum/gptq/__init__.py | 0 .../{ => src}/optimum/gptq/quantizer.py | 2 +- .../integration/src/optimum/utils/__init__.py | 0 .../{ => src}/optimum/utils/import_utils.py | 0 .../{ => src}/optimum/utils/testing_utils.py | 2 +- gptqmodel/integration/src/peft/__init__.py | 0 .../{ => src}/peft/import_utils.py | 0 .../integration/src/peft/tuners/__init__.py | 0 .../src/peft/tuners/adalora/__init__.py | 0 .../{ => src}/peft/tuners/adalora/model.py | 2 +- .../src/peft/tuners/lora/__init__.py | 0 .../{ => src}/peft/tuners/lora/gptq.py | 4 +- .../{ => src}/peft/tuners/lora/model.py | 0 .../{ => src}/peft/utils/__init__.py | 0 .../integration/{ => src}/peft/utils/other.py | 2 +- .../integration/src/transformers/__init__.py | 0 .../src/transformers/quantizers/__init__.py | 0 .../transformers/quantizers/quantizer_gptq.py | 2 
+- .../{ => src}/transformers/testing_utils.py | 0 .../src/transformers/utils/__init__.py | 0 .../transformers/utils/import_utils.py | 0 .../transformers/utils/quantization_config.py | 0 24 files changed, 32 insertions(+), 37 deletions(-) create mode 100644 gptqmodel/integration/src/optimum/__init__.py create mode 100644 gptqmodel/integration/src/optimum/gptq/__init__.py rename gptqmodel/integration/{ => src}/optimum/gptq/quantizer.py (99%) create mode 100644 gptqmodel/integration/src/optimum/utils/__init__.py rename gptqmodel/integration/{ => src}/optimum/utils/import_utils.py (100%) rename gptqmodel/integration/{ => src}/optimum/utils/testing_utils.py (98%) create mode 100644 gptqmodel/integration/src/peft/__init__.py rename gptqmodel/integration/{ => src}/peft/import_utils.py (100%) create mode 100644 gptqmodel/integration/src/peft/tuners/__init__.py create mode 100644 gptqmodel/integration/src/peft/tuners/adalora/__init__.py rename gptqmodel/integration/{ => src}/peft/tuners/adalora/model.py (99%) create mode 100644 gptqmodel/integration/src/peft/tuners/lora/__init__.py rename gptqmodel/integration/{ => src}/peft/tuners/lora/gptq.py (96%) rename gptqmodel/integration/{ => src}/peft/tuners/lora/model.py (100%) rename gptqmodel/integration/{ => src}/peft/utils/__init__.py (100%) rename gptqmodel/integration/{ => src}/peft/utils/other.py (99%) create mode 100644 gptqmodel/integration/src/transformers/__init__.py create mode 100644 gptqmodel/integration/src/transformers/quantizers/__init__.py rename gptqmodel/integration/{ => src}/transformers/quantizers/quantizer_gptq.py (98%) rename gptqmodel/integration/{ => src}/transformers/testing_utils.py (100%) create mode 100644 gptqmodel/integration/src/transformers/utils/__init__.py rename gptqmodel/integration/{ => src}/transformers/utils/import_utils.py (100%) rename gptqmodel/integration/{ => src}/transformers/utils/quantization_config.py (100%) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 99c0f35c2..b0be1c5b8 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -1,50 +1,45 @@ -from optimum.gptq import quantizer -from .optimum.gptq import quantizer as patched_quantizer -from optimum.utils import import_utils as optimum_import_utils -from .optimum.utils import import_utils as patched_optimum_import_utils -from optimum.utils import testing_utils -from .optimum.utils import testing_utils as patched_testing_utils +from optimum.gptq import quantizer as optimum_quantizer +from .src.optimum.gptq import quantizer as patched_optimum_quantizer +from optimum.utils import testing_utils as optimum_testing_utils , import_utils as optimum_import_utils +from .src.optimum.utils import testing_utils as patched_optimum_testing_utils, import_utils as patched_optimum_import_utils from peft import import_utils as peft_import_utils -from .peft import import_utils as patched_peft_import_utils -from peft.tuners.adalora.model import AdaLoraModel -from .peft.tuners.adalora.model import AdaLoraModel as patched_AdaLoraModel -from peft.tuners.lora import gptq -from .peft.tuners.lora import gptq as patched_gptq -from peft.tuners.lora import model -from .peft.tuners.lora import model as patched_model +from .src.peft import import_utils as patched_peft_import_utils +from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel +from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel +from peft.tuners.lora import gptq as peft_gptq, model as 
peft_model +from .src.peft.tuners.lora import gptq as patched_peft_gptq, model as patched_peft_model +from peft.utils import other as peft_other +from .src.peft.utils import other as patched_peft_other -from peft.utils import other -from .peft.utils import other as patched_other - -from transformers.quantizers.quantizer_gptq import GptqHfQuantizer -from .transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_GptqHfQuantizer +from transformers.quantizers.quantizer_gptq import GptqHfQuantizer as transformers_GptqHfQuantizer +from .src.transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_transformers_GptqHfQuantizer def monkey_patch_peft(): peft_import_utils.is_gptqmodel_available = patched_peft_import_utils.is_gptqmodel_available - AdaLoraModel._create_and_replace = patched_AdaLoraModel._create_and_replace + peft_AdaLoraModel._create_and_replace = patched_peft_AdaLoraModel._create_and_replace - gptq.dispatch_gptq = patched_gptq.dispatch_gptq + peft_gptq.dispatch_gptq = patched_peft_gptq.dispatch_gptq - model.LoraModel = patched_model.LoraModel + peft_model.LoraModel = patched_peft_model.LoraModel - other.get_auto_gptq_quant_linear = patched_other.get_auto_gptq_quant_linear - other.get_gptqmodel_quant_linear = patched_other.get_gptqmodel_quant_linear + peft_other.get_auto_gptq_quant_linear = patched_peft_other.get_auto_gptq_quant_linear + peft_other.get_gptqmodel_quant_linear = patched_peft_other.get_gptqmodel_quant_linear def monkey_patch_optimum(): - quantizer.is_gptqmodel_available = patched_quantizer.is_gptqmodel_available - quantizer.has_device_more_than_cpu = patched_quantizer.has_device_more_than_cpu - quantizer.GPTQQuantizer = patched_quantizer.GPTQQuantizer + optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available + optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu + optimum_quantizer.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available - testing_utils.require_gptq = patched_testing_utils.require_gptq + optimum_testing_utils.require_gptq = patched_optimum_testing_utils.require_gptq def monkey_patch_transformers(): - GptqHfQuantizer.required_packages = patched_GptqHfQuantizer.required_packages - GptqHfQuantizer.validate_environment = patched_GptqHfQuantizer.validate_environment - GptqHfQuantizer.update_torch_dtype = patched_GptqHfQuantizer.update_torch_dtype + transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages + transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment + transformers_GptqHfQuantizer.update_torch_dtype = patched_transformers_GptqHfQuantizer.update_torch_dtype diff --git a/gptqmodel/integration/src/optimum/__init__.py b/gptqmodel/integration/src/optimum/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/src/optimum/gptq/__init__.py b/gptqmodel/integration/src/optimum/gptq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/optimum/gptq/quantizer.py b/gptqmodel/integration/src/optimum/gptq/quantizer.py similarity index 99% rename from gptqmodel/integration/optimum/gptq/quantizer.py rename to gptqmodel/integration/src/optimum/gptq/quantizer.py index 96b4897dd..d11c7a180 
100644 --- a/gptqmodel/integration/optimum/gptq/quantizer.py +++ b/gptqmodel/integration/src/optimum/gptq/quantizer.py @@ -35,7 +35,7 @@ from optimum.gptq.utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen from optimum.version import __version__ as optimum_version -from gptqmodel.integration.optimum.utils.import_utils import is_gptqmodel_available +from gptqmodel.integration.src.optimum.utils.import_utils import is_gptqmodel_available if is_accelerate_available(): from accelerate import ( diff --git a/gptqmodel/integration/src/optimum/utils/__init__.py b/gptqmodel/integration/src/optimum/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/optimum/utils/import_utils.py b/gptqmodel/integration/src/optimum/utils/import_utils.py similarity index 100% rename from gptqmodel/integration/optimum/utils/import_utils.py rename to gptqmodel/integration/src/optimum/utils/import_utils.py diff --git a/gptqmodel/integration/optimum/utils/testing_utils.py b/gptqmodel/integration/src/optimum/utils/testing_utils.py similarity index 98% rename from gptqmodel/integration/optimum/utils/testing_utils.py rename to gptqmodel/integration/src/optimum/utils/testing_utils.py index 97eba86bb..9f140d4ab 100644 --- a/gptqmodel/integration/optimum/utils/testing_utils.py +++ b/gptqmodel/integration/src/optimum/utils/testing_utils.py @@ -33,7 +33,7 @@ is_timm_available, ) -from gptqmodel.integration.optimum.utils.import_utils import is_datasets_available, is_gptqmodel_available +from gptqmodel.integration.src.optimum.utils.import_utils import is_datasets_available, is_gptqmodel_available # Used to test the hub USER = "__DUMMY_OPTIMUM_USER__" diff --git a/gptqmodel/integration/src/peft/__init__.py b/gptqmodel/integration/src/peft/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/peft/import_utils.py b/gptqmodel/integration/src/peft/import_utils.py similarity index 100% rename from gptqmodel/integration/peft/import_utils.py rename to gptqmodel/integration/src/peft/import_utils.py diff --git a/gptqmodel/integration/src/peft/tuners/__init__.py b/gptqmodel/integration/src/peft/tuners/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/src/peft/tuners/adalora/__init__.py b/gptqmodel/integration/src/peft/tuners/adalora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/peft/tuners/adalora/model.py b/gptqmodel/integration/src/peft/tuners/adalora/model.py similarity index 99% rename from gptqmodel/integration/peft/tuners/adalora/model.py rename to gptqmodel/integration/src/peft/tuners/adalora/model.py index 5dc0d24b1..e0bf8dd6c 100644 --- a/gptqmodel/integration/peft/tuners/adalora/model.py +++ b/gptqmodel/integration/src/peft/tuners/adalora/model.py @@ -30,7 +30,7 @@ ) from peft.utils.integrations import gather_params_ctx -from gptqmodel.integration.peft.utils import get_gptqmodel_quant_linear +from gptqmodel.integration.src.peft.utils import get_gptqmodel_quant_linear from ...import_utils import is_gptqmodel_available diff --git a/gptqmodel/integration/src/peft/tuners/lora/__init__.py b/gptqmodel/integration/src/peft/tuners/lora/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/peft/tuners/lora/gptq.py b/gptqmodel/integration/src/peft/tuners/lora/gptq.py similarity index 96% rename from gptqmodel/integration/peft/tuners/lora/gptq.py rename to 
gptqmodel/integration/src/peft/tuners/lora/gptq.py index bf9d3f972..7ef124974 100644 --- a/gptqmodel/integration/peft/tuners/lora/gptq.py +++ b/gptqmodel/integration/src/peft/tuners/lora/gptq.py @@ -20,8 +20,8 @@ from peft.tuners.tuners_utils import BaseTunerLayer from peft.utils import get_auto_gptq_quant_linear -from gptqmodel.integration.peft.import_utils import is_gptqmodel_available -from gptqmodel.integration.peft.utils import get_gptqmodel_quant_linear +from gptqmodel.integration.src.peft.import_utils import is_gptqmodel_available +from gptqmodel.integration.src.peft.utils import get_gptqmodel_quant_linear class QuantLinear(torch.nn.Module, LoraLayer): diff --git a/gptqmodel/integration/peft/tuners/lora/model.py b/gptqmodel/integration/src/peft/tuners/lora/model.py similarity index 100% rename from gptqmodel/integration/peft/tuners/lora/model.py rename to gptqmodel/integration/src/peft/tuners/lora/model.py diff --git a/gptqmodel/integration/peft/utils/__init__.py b/gptqmodel/integration/src/peft/utils/__init__.py similarity index 100% rename from gptqmodel/integration/peft/utils/__init__.py rename to gptqmodel/integration/src/peft/utils/__init__.py diff --git a/gptqmodel/integration/peft/utils/other.py b/gptqmodel/integration/src/peft/utils/other.py similarity index 99% rename from gptqmodel/integration/peft/utils/other.py rename to gptqmodel/integration/src/peft/utils/other.py index f37b09d4d..201e22105 100644 --- a/gptqmodel/integration/peft/utils/other.py +++ b/gptqmodel/integration/src/peft/utils/other.py @@ -32,7 +32,7 @@ from peft.utils.constants import * from safetensors.torch import storage_ptr, storage_size -from gptqmodel.integration.peft.import_utils import is_auto_gptq_available, is_torch_tpu_available, is_gptqmodel_available +from gptqmodel.integration.src.peft.import_utils import is_auto_gptq_available, is_torch_tpu_available, is_gptqmodel_available mlu_available = False if version.parse(accelerate.__version__) >= version.parse("0.29.0"): diff --git a/gptqmodel/integration/src/transformers/__init__.py b/gptqmodel/integration/src/transformers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/src/transformers/quantizers/__init__.py b/gptqmodel/integration/src/transformers/quantizers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py similarity index 98% rename from gptqmodel/integration/transformers/quantizers/quantizer_gptq.py rename to gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py index 2b2f35d7f..c87261508 100644 --- a/gptqmodel/integration/transformers/quantizers/quantizer_gptq.py +++ b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py @@ -18,7 +18,7 @@ from transformers.quantizers.base import HfQuantizer -from gptqmodel.integration.transformers.utils.import_utils import is_gptqmodel_available +from gptqmodel.integration.src.transformers.utils.import_utils import is_gptqmodel_available if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel diff --git a/gptqmodel/integration/transformers/testing_utils.py b/gptqmodel/integration/src/transformers/testing_utils.py similarity index 100% rename from gptqmodel/integration/transformers/testing_utils.py rename to gptqmodel/integration/src/transformers/testing_utils.py diff --git a/gptqmodel/integration/src/transformers/utils/__init__.py 
b/gptqmodel/integration/src/transformers/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/integration/transformers/utils/import_utils.py b/gptqmodel/integration/src/transformers/utils/import_utils.py similarity index 100% rename from gptqmodel/integration/transformers/utils/import_utils.py rename to gptqmodel/integration/src/transformers/utils/import_utils.py diff --git a/gptqmodel/integration/transformers/utils/quantization_config.py b/gptqmodel/integration/src/transformers/utils/quantization_config.py similarity index 100% rename from gptqmodel/integration/transformers/utils/quantization_config.py rename to gptqmodel/integration/src/transformers/utils/quantization_config.py From 46a4223b99fc6f0634f06e6ac0e8081e4bf779da Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 10:11:13 +0800 Subject: [PATCH 13/30] fix GPTQQuantizer patch --- gptqmodel/integration/integration.py | 3 ++- .../integration/src/transformers/quantizers/quantizer_gptq.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index b0be1c5b8..d19339e3c 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -32,7 +32,8 @@ def monkey_patch_peft(): def monkey_patch_optimum(): optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu - optimum_quantizer.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer + optimum_quantizer.GPTQQuantizer.quantize_model = patched_optimum_quantizer.GPTQQuantizer.quantize_model + optimum_quantizer.GPTQQuantizer.__init__ = patched_optimum_quantizer.GPTQQuantizer.__init__ optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available diff --git a/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py index c87261508..eb8fcbc74 100644 --- a/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py +++ b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py @@ -75,7 +75,7 @@ def validate_environment(self, *args, **kwargs): ) elif is_gptqmodel_available() and ( version.parse(importlib.metadata.version("gptqmodel")) <= version.parse("1.3.1") - or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.99") + or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.3") ): raise ImportError("The gptqmodel version should be >= 1.3.2, optimum version should >= 1.24.0") From 257bea262b5b718109bf7edbcc2042e6beb1781b Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 11:06:07 +0800 Subject: [PATCH 14/30] fix patch error --- gptqmodel/integration/integration.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index d19339e3c..02f57c507 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -1,3 +1,6 @@ +import importlib + +import optimum.gptq as optimum_gptq from optimum.gptq import quantizer as optimum_quantizer from .src.optimum.gptq import quantizer as patched_optimum_quantizer from optimum.utils import testing_utils as optimum_testing_utils , import_utils as 
optimum_import_utils @@ -30,10 +33,17 @@ def monkey_patch_peft(): def monkey_patch_optimum(): + optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer + # optimum_quantizer.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer + # importlib.reload(optimum_quantizer) optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu - optimum_quantizer.GPTQQuantizer.quantize_model = patched_optimum_quantizer.GPTQQuantizer.quantize_model - optimum_quantizer.GPTQQuantizer.__init__ = patched_optimum_quantizer.GPTQQuantizer.__init__ + # optimum_quantizer.GPTQQuantizer.quantize_model = patched_optimum_quantizer.GPTQQuantizer.quantize_model + # optimum_quantizer.GPTQQuantizer.__init__ = patched_optimum_quantizer.GPTQQuantizer.__init__ + # optimum_quantizer.GPTQQuantizer.pack_model = patched_optimum_quantizer.GPTQQuantizer.pack_model + # optimum_quantizer.GPTQQuantizer.select_quant_linear = patched_optimum_quantizer.GPTQQuantizer.select_quant_linear + # optimum_quantizer.GPTQQuantizer._replace_by_quant_layers = patched_optimum_quantizer.GPTQQuantizer._replace_by_quant_layers + # optimum_quantizer.GPTQQuantizer.post_init_model = patched_optimum_quantizer.GPTQQuantizer.post_init_model optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available @@ -41,6 +51,7 @@ def monkey_patch_optimum(): def monkey_patch_transformers(): + transformers_GptqHfQuantizer._process_model_after_weight_loading = patched_transformers_GptqHfQuantizer._process_model_after_weight_loading transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment transformers_GptqHfQuantizer.update_torch_dtype = patched_transformers_GptqHfQuantizer.update_torch_dtype From d4005e18a0811ef61c59aba0444cc5924f24dec3 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 13:37:50 +0800 Subject: [PATCH 15/30] remove unused --- gptqmodel/integration/integration.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 02f57c507..690c56cc4 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -1,5 +1,3 @@ -import importlib - import optimum.gptq as optimum_gptq from optimum.gptq import quantizer as optimum_quantizer from .src.optimum.gptq import quantizer as patched_optimum_quantizer @@ -34,16 +32,8 @@ def monkey_patch_peft(): def monkey_patch_optimum(): optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer - # optimum_quantizer.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer - # importlib.reload(optimum_quantizer) optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu - # optimum_quantizer.GPTQQuantizer.quantize_model = patched_optimum_quantizer.GPTQQuantizer.quantize_model - # optimum_quantizer.GPTQQuantizer.__init__ = patched_optimum_quantizer.GPTQQuantizer.__init__ - # optimum_quantizer.GPTQQuantizer.pack_model = patched_optimum_quantizer.GPTQQuantizer.pack_model - # optimum_quantizer.GPTQQuantizer.select_quant_linear = 
patched_optimum_quantizer.GPTQQuantizer.select_quant_linear - # optimum_quantizer.GPTQQuantizer._replace_by_quant_layers = patched_optimum_quantizer.GPTQQuantizer._replace_by_quant_layers - # optimum_quantizer.GPTQQuantizer.post_init_model = patched_optimum_quantizer.GPTQQuantizer.post_init_model optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available From 41fdbca3374ad5150470661f90e342208e8a31e5 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 13:41:32 +0800 Subject: [PATCH 16/30] check if lib is installed --- gptqmodel/integration/integration.py | 64 ++++++++++++++++++---------- 1 file changed, 42 insertions(+), 22 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 690c56cc4..9fbd4bd72 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -1,23 +1,40 @@ -import optimum.gptq as optimum_gptq -from optimum.gptq import quantizer as optimum_quantizer -from .src.optimum.gptq import quantizer as patched_optimum_quantizer -from optimum.utils import testing_utils as optimum_testing_utils , import_utils as optimum_import_utils -from .src.optimum.utils import testing_utils as patched_optimum_testing_utils, import_utils as patched_optimum_import_utils - -from peft import import_utils as peft_import_utils -from .src.peft import import_utils as patched_peft_import_utils -from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel -from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel -from peft.tuners.lora import gptq as peft_gptq, model as peft_model -from .src.peft.tuners.lora import gptq as patched_peft_gptq, model as patched_peft_model -from peft.utils import other as peft_other -from .src.peft.utils import other as patched_peft_other - -from transformers.quantizers.quantizer_gptq import GptqHfQuantizer as transformers_GptqHfQuantizer -from .src.transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_transformers_GptqHfQuantizer - - -def monkey_patch_peft(): +HAS_OPTIMUM = True +try: + import optimum.gptq as optimum_gptq + from optimum.gptq import quantizer as optimum_quantizer + from .src.optimum.gptq import quantizer as patched_optimum_quantizer + from optimum.utils import testing_utils as optimum_testing_utils, import_utils as optimum_import_utils + from .src.optimum.utils import testing_utils as patched_optimum_testing_utils, import_utils as patched_optimum_import_utils +except BaseException: + HAS_OPTIMUM = False + +HAS_PEFT = True +try: + from peft import import_utils as peft_import_utils + from .src.peft import import_utils as patched_peft_import_utils + from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel + from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel + from peft.tuners.lora import gptq as peft_gptq, model as peft_model + from .src.peft.tuners.lora import gptq as patched_peft_gptq, model as patched_peft_model + from peft.utils import other as peft_other + from .src.peft.utils import other as patched_peft_other +except BaseException: + HAS_PEFT = False + +from transformers.quantizers.quantizer_gptq import GptqHfQuantizer as transformers_GptqHfQuantizer # noqa: E402 +from .src.transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_transformers_GptqHfQuantizer # noqa: E402 + + +def monkey_patch_transformers(): + 
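The HAS_OPTIMUM and HAS_PEFT flags introduced in this hunk follow a common optional-dependency pattern: attempt the imports once at module load, remember whether they succeeded, and let each patch helper return early when its target library is missing. A minimal sketch of that pattern, using a placeholder package name ("somepkg") rather than any real dependency of this repo:

    # Sketch only: "somepkg" and its utils module are placeholders.
    HAS_SOMEPKG = True
    try:
        from somepkg import utils as somepkg_utils
    except BaseException:
        # Catching BaseException, as the hunk above does, means even a broken
        # install simply disables the patch instead of crashing at import time.
        HAS_SOMEPKG = False

    def _patch_somepkg():
        if not HAS_SOMEPKG:
            return  # nothing to patch when the library is not installed
        # Rebind the module-level symbol; callers that look it up at call time
        # will see the replacement.
        somepkg_utils.is_gptqmodel_available = lambda: True  # stand-in for the patched check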
_patch_peft() + _patch_optimum() + _patch_transformers() + + +def _patch_peft(): + if not HAS_PEFT: + return + peft_import_utils.is_gptqmodel_available = patched_peft_import_utils.is_gptqmodel_available peft_AdaLoraModel._create_and_replace = patched_peft_AdaLoraModel._create_and_replace @@ -30,7 +47,10 @@ def monkey_patch_peft(): peft_other.get_gptqmodel_quant_linear = patched_peft_other.get_gptqmodel_quant_linear -def monkey_patch_optimum(): +def _patch_optimum(): + if not HAS_OPTIMUM: + return + optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu @@ -40,7 +60,7 @@ def monkey_patch_optimum(): optimum_testing_utils.require_gptq = patched_optimum_testing_utils.require_gptq -def monkey_patch_transformers(): +def _patch_transformers(): transformers_GptqHfQuantizer._process_model_after_weight_loading = patched_transformers_GptqHfQuantizer._process_model_after_weight_loading transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment From 7e115e6d878a1bf1668e31bf23f3ca2504424577 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:07:32 +0800 Subject: [PATCH 17/30] replace all for transformers --- gptqmodel/integration/integration.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 9fbd4bd72..fb74845c6 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -21,8 +21,8 @@ except BaseException: HAS_PEFT = False -from transformers.quantizers.quantizer_gptq import GptqHfQuantizer as transformers_GptqHfQuantizer # noqa: E402 -from .src.transformers.quantizers.quantizer_gptq import GptqHfQuantizer as patched_transformers_GptqHfQuantizer # noqa: E402 +from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq #import GptqHfQuantizer as transformers_GptqHfQuantizer # noqa: E402 +from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # import GptqHfQuantizer as patched_transformers_GptqHfQuantizer # noqa: E402 def monkey_patch_transformers(): @@ -61,7 +61,8 @@ def _patch_optimum(): def _patch_transformers(): - transformers_GptqHfQuantizer._process_model_after_weight_loading = patched_transformers_GptqHfQuantizer._process_model_after_weight_loading - transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages - transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment - transformers_GptqHfQuantizer.update_torch_dtype = patched_transformers_GptqHfQuantizer.update_torch_dtype + transformers_quantizer_gptq.GptqHfQuantizer = patched_transformers_quantizer_gptq.GptqHfQuantizer + # transformers_GptqHfQuantizer._process_model_after_weight_loading = patched_transformers_GptqHfQuantizer._process_model_after_weight_loading + # transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages + # transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment + # transformers_GptqHfQuantizer.update_torch_dtype = 
patched_transformers_GptqHfQuantizer.update_torch_dtype From d924bcdb7cffd9d132fc5ee4746925cedd99c5ad Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:34:27 +0800 Subject: [PATCH 18/30] add ExllamaVersion patch --- gptqmodel/integration/integration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index fb74845c6..ae918ab97 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -54,6 +54,7 @@ def _patch_optimum(): optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu + optimum_quantizer.ExllamaVersion = patched_optimum_quantizer.ExllamaVersion optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available From 39f6ca85ac7a6a4a37e46a0be80d6d5e6a8ee471 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:40:27 +0800 Subject: [PATCH 19/30] add missing patch --- gptqmodel/integration/integration.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index ae918ab97..fa0008f2d 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -21,8 +21,12 @@ except BaseException: HAS_PEFT = False -from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq #import GptqHfQuantizer as transformers_GptqHfQuantizer # noqa: E402 -from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # import GptqHfQuantizer as patched_transformers_GptqHfQuantizer # noqa: E402 +from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq # noqa: E402 +from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # noqa: E402 +from transformers.utils import import_utils as transformers_import_utils # noqa: E402 +from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402 +from transformers.utils import quantization_config as transformers_quantization_config # noqa: E402 +from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402 def monkey_patch_transformers(): @@ -63,7 +67,8 @@ def _patch_optimum(): def _patch_transformers(): transformers_quantizer_gptq.GptqHfQuantizer = patched_transformers_quantizer_gptq.GptqHfQuantizer - # transformers_GptqHfQuantizer._process_model_after_weight_loading = patched_transformers_GptqHfQuantizer._process_model_after_weight_loading - # transformers_GptqHfQuantizer.required_packages = patched_transformers_GptqHfQuantizer.required_packages - # transformers_GptqHfQuantizer.validate_environment = patched_transformers_GptqHfQuantizer.validate_environment - # transformers_GptqHfQuantizer.update_torch_dtype = patched_transformers_GptqHfQuantizer.update_torch_dtype + + transformers_import_utils._gptqmodel_available = patched_transformers_import_utils._gptqmodel_available + transformers_import_utils.is_gptqmodel_available = patched_transformers_import_utils.is_gptqmodel_available + + transformers_quantization_config.AWQLinearVersion = 
patched_transformers_quantization_config.AWQLinearVersion From 3c5e007da171896a685f5c8a4813b34132cf11f0 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:45:56 +0800 Subject: [PATCH 20/30] patch transformers_testing_utils --- gptqmodel/integration/integration.py | 5 ++++- gptqmodel/integration/src/transformers/testing_utils.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index fa0008f2d..f861690d2 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -27,7 +27,8 @@ from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402 from transformers.utils import quantization_config as transformers_quantization_config # noqa: E402 from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402 - +import transformers.testing_utils as transformers_testing_utils +from .src.transformers import testing_utils as patched_transformers_testing_utils def monkey_patch_transformers(): _patch_peft() @@ -72,3 +73,5 @@ def _patch_transformers(): transformers_import_utils.is_gptqmodel_available = patched_transformers_import_utils.is_gptqmodel_available transformers_quantization_config.AWQLinearVersion = patched_transformers_quantization_config.AWQLinearVersion + + transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq diff --git a/gptqmodel/integration/src/transformers/testing_utils.py b/gptqmodel/integration/src/transformers/testing_utils.py index bc6f98b9e..46fff122c 100644 --- a/gptqmodel/integration/src/transformers/testing_utils.py +++ b/gptqmodel/integration/src/transformers/testing_utils.py @@ -46,7 +46,7 @@ from transformers import logging as transformers_logging -from .integrations import ( +from transformers.integrations import ( is_clearml_available, is_optuna_available, is_ray_available, @@ -54,8 +54,8 @@ is_tensorboard_available, is_wandb_available, ) -from .integrations.deepspeed import is_deepspeed_available -from .utils import ( +from transformers.integrations.deepspeed import is_deepspeed_available +from transformers.utils import ( ACCELERATE_MIN_VERSION, GGUF_MIN_VERSION, is_accelerate_available, @@ -82,7 +82,6 @@ is_g2p_en_available, is_galore_torch_available, is_gguf_available, - is_gptqmodel_available, is_grokadamw_available, is_ipex_available, is_jieba_available, @@ -145,6 +144,7 @@ strtobool, ) +from .utils.import_utils import is_gptqmodel_available if is_accelerate_available(): from accelerate.state import AcceleratorState, PartialState From 1d9766a95afc36cd92ad34987aaa352eefb481a5 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:49:36 +0800 Subject: [PATCH 21/30] patch GPTQConfig --- gptqmodel/integration/integration.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index f861690d2..7fc84ef83 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -30,6 +30,8 @@ import transformers.testing_utils as transformers_testing_utils from .src.transformers import testing_utils as patched_transformers_testing_utils + + def monkey_patch_transformers(): _patch_peft() _patch_optimum() @@ -72,6 +74,6 @@ def _patch_transformers(): transformers_import_utils._gptqmodel_available = patched_transformers_import_utils._gptqmodel_available transformers_import_utils.is_gptqmodel_available 
= patched_transformers_import_utils.is_gptqmodel_available - transformers_quantization_config.AWQLinearVersion = patched_transformers_quantization_config.AWQLinearVersion + transformers_quantization_config.GPTQConfig = patched_transformers_quantization_config.GPTQConfig transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq From 3a27d0485d44ae93b561cff994c1e0c058d0a369 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 15:55:26 +0800 Subject: [PATCH 22/30] update init.py --- .../src/transformers/utils/__init__.py | 316 ++++++++++++++++++ 1 file changed, 316 insertions(+) diff --git a/gptqmodel/integration/src/transformers/utils/__init__.py b/gptqmodel/integration/src/transformers/utils/__init__.py index e69de29bb..d126ce47b 100644 --- a/gptqmodel/integration/src/transformers/utils/__init__.py +++ b/gptqmodel/integration/src/transformers/utils/__init__.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import lru_cache +from typing import FrozenSet + +from huggingface_hub import get_full_repo_name # for backward compatibility +from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility +from packaging import version + +from transformers import __version__ +from transformers.utils.backbone_utils import BackboneConfigMixin, BackboneMixin +from transformers.utils.chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema +from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD +from transformers.utils.doc import ( + add_code_sample_docstrings, + add_end_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + copy_func, + replace_return_docstrings, +) +from transformers.utils.generic import ( + ContextManagers, + ExplicitEnum, + LossKwargs, + ModelOutput, + PaddingStrategy, + TensorType, + add_model_info_to_auto_map, + add_model_info_to_custom_pipelines, + cached_property, + can_return_loss, + expand_dims, + filter_out_non_signature_kwargs, + find_labels, + flatten_dict, + infer_framework, + is_jax_tensor, + is_numpy_array, + is_tensor, + is_tf_symbolic_tensor, + is_tf_tensor, + is_torch_device, + is_torch_dtype, + is_torch_tensor, + reshape, + squeeze, + strtobool, + tensor_size, + to_numpy, + to_py_obj, + torch_float, + torch_int, + transpose, + working_or_temp_dir, +) +from transformers.utils.hub import ( + CLOUDFRONT_DISTRIB_PREFIX, + HF_MODULES_CACHE, + HUGGINGFACE_CO_PREFIX, + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + PYTORCH_PRETRAINED_BERT_CACHE, + PYTORCH_TRANSFORMERS_CACHE, + S3_BUCKET_PREFIX, + TRANSFORMERS_CACHE, + TRANSFORMERS_DYNAMIC_MODULE_NAME, + EntryNotFoundError, + PushInProgress, + PushToHubMixin, + RepositoryNotFoundError, + RevisionNotFoundError, + cached_file, + default_cache_path, + define_sagemaker_information, + 
download_url, + extract_commit_hash, + get_cached_models, + get_file_from_repo, + has_file, + http_user_agent, + is_offline_mode, + is_remote_url, + move_cache, + send_example_telemetry, + try_to_load_from_cache, +) +from .import_utils import ( + ACCELERATE_MIN_VERSION, + ENV_VARS_TRUE_AND_AUTO_VALUES, + ENV_VARS_TRUE_VALUES, + GGUF_MIN_VERSION, + TORCH_FX_REQUIRED_VERSION, + USE_JAX, + USE_TF, + USE_TORCH, + XLA_FSDPV2_MIN_VERSION, + DummyObject, + OptionalDependencyNotAvailable, + _LazyModule, + ccl_version, + direct_transformers_import, + get_torch_version, + is_accelerate_available, + is_apex_available, + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_av_available, + is_bitsandbytes_available, + is_bitsandbytes_multi_backend_available, + is_bs4_available, + is_coloredlogs_available, + is_compressed_tensors_available, + is_cv2_available, + is_cython_available, + is_datasets_available, + is_detectron2_available, + is_eetq_available, + is_essentia_available, + is_faiss_available, + is_fbgemm_gpu_available, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal, + is_flash_attn_greater_or_equal_2_10, + is_flax_available, + is_fsdp_available, + is_ftfy_available, + is_g2p_en_available, + is_galore_torch_available, + is_gguf_available, + is_gptqmodel_available, + is_grokadamw_available, + is_hqq_available, + is_in_notebook, + is_ipex_available, + is_jieba_available, + is_jinja_available, + is_jumanpp_available, + is_kenlm_available, + is_keras_nlp_available, + is_levenshtein_available, + is_librosa_available, + is_liger_kernel_available, + is_lomo_available, + is_mlx_available, + is_natten_available, + is_ninja_available, + is_nltk_available, + is_onnx_available, + is_openai_available, + is_optimum_available, + is_optimum_quanto_available, + is_pandas_available, + is_peft_available, + is_phonemizer_available, + is_pretty_midi_available, + is_protobuf_available, + is_psutil_available, + is_py3nvml_available, + is_pyctcdecode_available, + is_pytesseract_available, + is_pytest_available, + is_pytorch_quantization_available, + is_rjieba_available, + is_sacremoses_available, + is_safetensors_available, + is_sagemaker_dp_enabled, + is_sagemaker_mp_enabled, + is_schedulefree_available, + is_scipy_available, + is_sentencepiece_available, + is_seqio_available, + is_sklearn_available, + is_soundfile_availble, + is_spacy_available, + is_speech_available, + is_sudachi_available, + is_sudachi_projection_available, + is_tensorflow_probability_available, + is_tensorflow_text_available, + is_tf2onnx_available, + is_tf_available, + is_tiktoken_available, + is_timm_available, + is_tokenizers_available, + is_torch_available, + is_torch_bf16_available, + is_torch_bf16_available_on_device, + is_torch_bf16_cpu_available, + is_torch_bf16_gpu_available, + is_torch_compile_available, + is_torch_cuda_available, + is_torch_deterministic, + is_torch_flex_attn_available, + is_torch_fp16_available_on_device, + is_torch_fx_available, + is_torch_fx_proxy, + is_torch_greater_or_equal, + is_torch_mlu_available, + is_torch_mps_available, + is_torch_musa_available, + is_torch_neuroncore_available, + is_torch_npu_available, + is_torch_sdpa_available, + is_torch_tensorrt_fx_available, + is_torch_tf32_available, + is_torch_tpu_available, + is_torch_xla_available, + is_torch_xpu_available, + is_torchao_available, + is_torchaudio_available, + is_torchdistx_available, + is_torchdynamo_available, + is_torchdynamo_compiling, + is_torchvision_available, + is_torchvision_v2_available, + 
is_training_run_on_sagemaker, + is_uroman_available, + is_vision_available, + requires_backends, + torch_only_method, +) +from transformers.utils.peft_utils import ( + ADAPTER_CONFIG_NAME, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + check_peft_version, + find_adapter_config_file, +) + + +WEIGHTS_NAME = "pytorch_model.bin" +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" +TF2_WEIGHTS_NAME = "tf_model.h5" +TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" +TF_WEIGHTS_NAME = "model.ckpt" +FLAX_WEIGHTS_NAME = "flax_model.msgpack" +FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" +SAFE_WEIGHTS_NAME = "model.safetensors" +SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" +CONFIG_NAME = "config.json" +FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" +IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME +PROCESSOR_NAME = "processor_config.json" +CHAT_TEMPLATE_NAME = "chat_template.json" +GENERATION_CONFIG_NAME = "generation_config.json" +MODEL_CARD_NAME = "modelcard.json" + +SENTENCEPIECE_UNDERLINE = "▁" +SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility + +MULTIPLE_CHOICE_DUMMY_INPUTS = [ + [[0, 1, 0, 1], [1, 0, 0, 1]] + ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too. +DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] +DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] + + +def check_min_version(min_version): + if version.parse(__version__) < version.parse(min_version): + if "dev" in min_version: + error_message = ( + "This example requires a source install from HuggingFace Transformers (see " + "`https://huggingface.co/docs/transformers/installation#install-from-source`)," + ) + else: + error_message = f"This example requires a minimum version of {min_version}," + error_message += f" but the version found is {__version__}.\n" + raise ImportError( + error_message + + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other " + "versions of HuggingFace Transformers." + ) + + +@lru_cache() +def get_available_devices() -> FrozenSet[str]: + """ + Returns a frozenset of devices available for the current PyTorch installation. 
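The long __init__.py being added here works as a re-export shim: it mirrors the public surface of transformers.utils while taking is_gptqmodel_available and the related availability helpers from the patched import_utils that sits next to it. A rough, runnable sketch of the same mirror-and-override idea, using a stdlib module purely as a stand-in for the mirrored package:

    import json   # stand-in for the module being mirrored; any module works
    import types

    shim = types.ModuleType("json_shim")
    shim.__dict__.update(json.__dict__)             # re-export everything unchanged
    shim.dumps = lambda obj, **kwargs: "patched"    # override a single helper
    assert shim.loads('{"a": 1}') == {"a": 1}       # untouched helpers still work
    assert shim.dumps({"a": 1}) == "patched"        # the override is picked up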
+ """ + devices = {"cpu"} # `cpu` is always supported as a device in PyTorch + + if is_torch_cuda_available(): + devices.add("cuda") + + if is_torch_mps_available(): + devices.add("mps") + + if is_torch_xpu_available(): + devices.add("xpu") + + if is_torch_npu_available(): + devices.add("npu") + + if is_torch_mlu_available(): + devices.add("mlu") + + if is_torch_musa_available(): + devices.add("musa") + + return frozenset(devices) \ No newline at end of file From 9ebf031a5980a9334452d1a8ae09053de8f2ccf5 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 16:13:37 +0800 Subject: [PATCH 23/30] delete unused --- gptqmodel/integration/src/transformers/__init__.py | 0 gptqmodel/integration/src/transformers/quantizers/__init__.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 gptqmodel/integration/src/transformers/__init__.py delete mode 100644 gptqmodel/integration/src/transformers/quantizers/__init__.py diff --git a/gptqmodel/integration/src/transformers/__init__.py b/gptqmodel/integration/src/transformers/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/gptqmodel/integration/src/transformers/quantizers/__init__.py b/gptqmodel/integration/src/transformers/quantizers/__init__.py deleted file mode 100644 index e69de29bb..000000000 From f4123728476efb8a4a8e24568c4ca6c091fb4257 Mon Sep 17 00:00:00 2001 From: CSY Date: Wed, 11 Dec 2024 16:14:18 +0800 Subject: [PATCH 24/30] add another patch --- gptqmodel/integration/integration.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 7fc84ef83..3737de1fc 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -77,3 +77,19 @@ def _patch_transformers(): transformers_quantization_config.GPTQConfig = patched_transformers_quantization_config.GPTQConfig transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq + + # if 'transformers.quantizers.quantizer_gptq' in sys.modules: + # del sys.modules['transformers.quantizers.quantizer_gptq'] + # sys.modules['transformers.quantizers.quantizer_gptq'] = patched_transformers_quantizer_gptq + # + # if 'transformers.utils.import_utils' in sys.modules: + # del sys.modules['transformers.utils.import_utils'] + # sys.modules['transformers.utils.import_utils'] = patched_transformers_import_utils + # + # if 'transformers.utils.quantization_config' in sys.modules: + # del sys.modules['transformers.utils.quantization_config'] + # sys.modules['transformers.utils.quantization_config'] = patched_transformers_quantization_config + # + # if 'transformers.testing_utils' in sys.modules: + # del sys.modules['transformers.testing_utils'] + # sys.modules['transformers.testing_utils'] = patched_transformers_testing_utils \ No newline at end of file From 53f6fc6f11901974e55256d376ee86d514d09aee Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 12 Dec 2024 21:26:27 +0800 Subject: [PATCH 25/30] fix patch --- gptqmodel/integration/integration.py | 9 +- .../src/transformers/utils/__init__.py | 316 ------------------ 2 files changed, 6 insertions(+), 319 deletions(-) delete mode 100644 gptqmodel/integration/src/transformers/utils/__init__.py diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 3737de1fc..74cdbe2c9 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -32,7 +32,7 @@ -def monkey_patch_transformers(): +def 
patch_hf(): _patch_peft() _patch_optimum() _patch_transformers() @@ -69,12 +69,15 @@ def _patch_optimum(): def _patch_transformers(): - transformers_quantizer_gptq.GptqHfQuantizer = patched_transformers_quantizer_gptq.GptqHfQuantizer + transformers_quantizer_gptq.GptqHfQuantizer.required_packages = patched_transformers_quantizer_gptq.GptqHfQuantizer.required_packages + transformers_quantizer_gptq.GptqHfQuantizer.validate_environment = patched_transformers_quantizer_gptq.GptqHfQuantizer.validate_environment + transformers_quantizer_gptq.GptqHfQuantizer._process_model_before_weight_loading = patched_transformers_quantizer_gptq.GptqHfQuantizer._process_model_before_weight_loading transformers_import_utils._gptqmodel_available = patched_transformers_import_utils._gptqmodel_available transformers_import_utils.is_gptqmodel_available = patched_transformers_import_utils.is_gptqmodel_available - transformers_quantization_config.GPTQConfig = patched_transformers_quantization_config.GPTQConfig + transformers_quantization_config.GPTQConfig.__init__ = patched_transformers_quantization_config.GPTQConfig.__init__ + transformers_quantization_config.GPTQConfig.post_init = patched_transformers_quantization_config.GPTQConfig.post_init transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq diff --git a/gptqmodel/integration/src/transformers/utils/__init__.py b/gptqmodel/integration/src/transformers/utils/__init__.py deleted file mode 100644 index d126ce47b..000000000 --- a/gptqmodel/integration/src/transformers/utils/__init__.py +++ /dev/null @@ -1,316 +0,0 @@ -#!/usr/bin/env python -# coding=utf-8 - -# Copyright 2021 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
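The _patch_transformers and GPTQConfig hunks above switch from rebinding whole classes inside their modules to assigning individual methods onto the existing class objects; this way, code that already holds a reference to the original class still picks up the change. A small self-contained illustration with made-up classes (not the real HF quantizer types):

    class Original:
        def validate(self):
            return "old behaviour"

    class Patched:
        def validate(self):
            return "new behaviour"

    handle = Original                    # simulates `from module import Original` done elsewhere

    Original = Patched                   # name rebinding: handle().validate() is still "old behaviour"
    handle.validate = Patched.validate   # in-place method patch: every holder of the class sees it
    print(handle().validate())           # -> new behaviour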
- -from functools import lru_cache -from typing import FrozenSet - -from huggingface_hub import get_full_repo_name # for backward compatibility -from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility -from packaging import version - -from transformers import __version__ -from transformers.utils.backbone_utils import BackboneConfigMixin, BackboneMixin -from transformers.utils.chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema -from transformers.utils.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD -from transformers.utils.doc import ( - add_code_sample_docstrings, - add_end_docstrings, - add_start_docstrings, - add_start_docstrings_to_model_forward, - copy_func, - replace_return_docstrings, -) -from transformers.utils.generic import ( - ContextManagers, - ExplicitEnum, - LossKwargs, - ModelOutput, - PaddingStrategy, - TensorType, - add_model_info_to_auto_map, - add_model_info_to_custom_pipelines, - cached_property, - can_return_loss, - expand_dims, - filter_out_non_signature_kwargs, - find_labels, - flatten_dict, - infer_framework, - is_jax_tensor, - is_numpy_array, - is_tensor, - is_tf_symbolic_tensor, - is_tf_tensor, - is_torch_device, - is_torch_dtype, - is_torch_tensor, - reshape, - squeeze, - strtobool, - tensor_size, - to_numpy, - to_py_obj, - torch_float, - torch_int, - transpose, - working_or_temp_dir, -) -from transformers.utils.hub import ( - CLOUDFRONT_DISTRIB_PREFIX, - HF_MODULES_CACHE, - HUGGINGFACE_CO_PREFIX, - HUGGINGFACE_CO_RESOLVE_ENDPOINT, - PYTORCH_PRETRAINED_BERT_CACHE, - PYTORCH_TRANSFORMERS_CACHE, - S3_BUCKET_PREFIX, - TRANSFORMERS_CACHE, - TRANSFORMERS_DYNAMIC_MODULE_NAME, - EntryNotFoundError, - PushInProgress, - PushToHubMixin, - RepositoryNotFoundError, - RevisionNotFoundError, - cached_file, - default_cache_path, - define_sagemaker_information, - download_url, - extract_commit_hash, - get_cached_models, - get_file_from_repo, - has_file, - http_user_agent, - is_offline_mode, - is_remote_url, - move_cache, - send_example_telemetry, - try_to_load_from_cache, -) -from .import_utils import ( - ACCELERATE_MIN_VERSION, - ENV_VARS_TRUE_AND_AUTO_VALUES, - ENV_VARS_TRUE_VALUES, - GGUF_MIN_VERSION, - TORCH_FX_REQUIRED_VERSION, - USE_JAX, - USE_TF, - USE_TORCH, - XLA_FSDPV2_MIN_VERSION, - DummyObject, - OptionalDependencyNotAvailable, - _LazyModule, - ccl_version, - direct_transformers_import, - get_torch_version, - is_accelerate_available, - is_apex_available, - is_aqlm_available, - is_auto_awq_available, - is_auto_gptq_available, - is_av_available, - is_bitsandbytes_available, - is_bitsandbytes_multi_backend_available, - is_bs4_available, - is_coloredlogs_available, - is_compressed_tensors_available, - is_cv2_available, - is_cython_available, - is_datasets_available, - is_detectron2_available, - is_eetq_available, - is_essentia_available, - is_faiss_available, - is_fbgemm_gpu_available, - is_flash_attn_2_available, - is_flash_attn_greater_or_equal, - is_flash_attn_greater_or_equal_2_10, - is_flax_available, - is_fsdp_available, - is_ftfy_available, - is_g2p_en_available, - is_galore_torch_available, - is_gguf_available, - is_gptqmodel_available, - is_grokadamw_available, - is_hqq_available, - is_in_notebook, - is_ipex_available, - is_jieba_available, - is_jinja_available, - is_jumanpp_available, - is_kenlm_available, - is_keras_nlp_available, - is_levenshtein_available, - is_librosa_available, - 
is_liger_kernel_available, - is_lomo_available, - is_mlx_available, - is_natten_available, - is_ninja_available, - is_nltk_available, - is_onnx_available, - is_openai_available, - is_optimum_available, - is_optimum_quanto_available, - is_pandas_available, - is_peft_available, - is_phonemizer_available, - is_pretty_midi_available, - is_protobuf_available, - is_psutil_available, - is_py3nvml_available, - is_pyctcdecode_available, - is_pytesseract_available, - is_pytest_available, - is_pytorch_quantization_available, - is_rjieba_available, - is_sacremoses_available, - is_safetensors_available, - is_sagemaker_dp_enabled, - is_sagemaker_mp_enabled, - is_schedulefree_available, - is_scipy_available, - is_sentencepiece_available, - is_seqio_available, - is_sklearn_available, - is_soundfile_availble, - is_spacy_available, - is_speech_available, - is_sudachi_available, - is_sudachi_projection_available, - is_tensorflow_probability_available, - is_tensorflow_text_available, - is_tf2onnx_available, - is_tf_available, - is_tiktoken_available, - is_timm_available, - is_tokenizers_available, - is_torch_available, - is_torch_bf16_available, - is_torch_bf16_available_on_device, - is_torch_bf16_cpu_available, - is_torch_bf16_gpu_available, - is_torch_compile_available, - is_torch_cuda_available, - is_torch_deterministic, - is_torch_flex_attn_available, - is_torch_fp16_available_on_device, - is_torch_fx_available, - is_torch_fx_proxy, - is_torch_greater_or_equal, - is_torch_mlu_available, - is_torch_mps_available, - is_torch_musa_available, - is_torch_neuroncore_available, - is_torch_npu_available, - is_torch_sdpa_available, - is_torch_tensorrt_fx_available, - is_torch_tf32_available, - is_torch_tpu_available, - is_torch_xla_available, - is_torch_xpu_available, - is_torchao_available, - is_torchaudio_available, - is_torchdistx_available, - is_torchdynamo_available, - is_torchdynamo_compiling, - is_torchvision_available, - is_torchvision_v2_available, - is_training_run_on_sagemaker, - is_uroman_available, - is_vision_available, - requires_backends, - torch_only_method, -) -from transformers.utils.peft_utils import ( - ADAPTER_CONFIG_NAME, - ADAPTER_SAFE_WEIGHTS_NAME, - ADAPTER_WEIGHTS_NAME, - check_peft_version, - find_adapter_config_file, -) - - -WEIGHTS_NAME = "pytorch_model.bin" -WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" -TF2_WEIGHTS_NAME = "tf_model.h5" -TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json" -TF_WEIGHTS_NAME = "model.ckpt" -FLAX_WEIGHTS_NAME = "flax_model.msgpack" -FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json" -SAFE_WEIGHTS_NAME = "model.safetensors" -SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json" -CONFIG_NAME = "config.json" -FEATURE_EXTRACTOR_NAME = "preprocessor_config.json" -IMAGE_PROCESSOR_NAME = FEATURE_EXTRACTOR_NAME -PROCESSOR_NAME = "processor_config.json" -CHAT_TEMPLATE_NAME = "chat_template.json" -GENERATION_CONFIG_NAME = "generation_config.json" -MODEL_CARD_NAME = "modelcard.json" - -SENTENCEPIECE_UNDERLINE = "▁" -SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility - -MULTIPLE_CHOICE_DUMMY_INPUTS = [ - [[0, 1, 0, 1], [1, 0, 0, 1]] - ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too. 
-DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]] -DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]] - - -def check_min_version(min_version): - if version.parse(__version__) < version.parse(min_version): - if "dev" in min_version: - error_message = ( - "This example requires a source install from HuggingFace Transformers (see " - "`https://huggingface.co/docs/transformers/installation#install-from-source`)," - ) - else: - error_message = f"This example requires a minimum version of {min_version}," - error_message += f" but the version found is {__version__}.\n" - raise ImportError( - error_message - + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other " - "versions of HuggingFace Transformers." - ) - - -@lru_cache() -def get_available_devices() -> FrozenSet[str]: - """ - Returns a frozenset of devices available for the current PyTorch installation. - """ - devices = {"cpu"} # `cpu` is always supported as a device in PyTorch - - if is_torch_cuda_available(): - devices.add("cuda") - - if is_torch_mps_available(): - devices.add("mps") - - if is_torch_xpu_available(): - devices.add("xpu") - - if is_torch_npu_available(): - devices.add("npu") - - if is_torch_mlu_available(): - devices.add("mlu") - - if is_torch_musa_available(): - devices.add("musa") - - return frozenset(devices) \ No newline at end of file From 2882e0e90d7dae6e2c37aadf4b19ce41b07bb373 Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 12 Dec 2024 21:30:28 +0800 Subject: [PATCH 26/30] fix ruff --- gptqmodel/integration/integration.py | 30 +++++++++++++++++----------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py index 74cdbe2c9..81d199982 100644 --- a/gptqmodel/integration/integration.py +++ b/gptqmodel/integration/integration.py @@ -2,34 +2,40 @@ try: import optimum.gptq as optimum_gptq from optimum.gptq import quantizer as optimum_quantizer + from optimum.utils import import_utils as optimum_import_utils + from optimum.utils import testing_utils as optimum_testing_utils + from .src.optimum.gptq import quantizer as patched_optimum_quantizer - from optimum.utils import testing_utils as optimum_testing_utils, import_utils as optimum_import_utils - from .src.optimum.utils import testing_utils as patched_optimum_testing_utils, import_utils as patched_optimum_import_utils + from .src.optimum.utils import import_utils as patched_optimum_import_utils + from .src.optimum.utils import testing_utils as patched_optimum_testing_utils except BaseException: HAS_OPTIMUM = False HAS_PEFT = True try: from peft import import_utils as peft_import_utils - from .src.peft import import_utils as patched_peft_import_utils from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel - from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel - from peft.tuners.lora import gptq as peft_gptq, model as peft_model - from .src.peft.tuners.lora import gptq as patched_peft_gptq, model as patched_peft_model + from peft.tuners.lora import gptq as peft_gptq + from peft.tuners.lora import model as peft_model from peft.utils import other as peft_other + + from .src.peft import import_utils as patched_peft_import_utils + from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel + from .src.peft.tuners.lora import gptq as patched_peft_gptq + from .src.peft.tuners.lora import model as 
patched_peft_model from .src.peft.utils import other as patched_peft_other except BaseException: HAS_PEFT = False +import transformers.testing_utils as transformers_testing_utils # noqa: E402 from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq # noqa: E402 -from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # noqa: E402 from transformers.utils import import_utils as transformers_import_utils # noqa: E402 -from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402 from transformers.utils import quantization_config as transformers_quantization_config # noqa: E402 -from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402 -import transformers.testing_utils as transformers_testing_utils -from .src.transformers import testing_utils as patched_transformers_testing_utils +from .src.transformers import testing_utils as patched_transformers_testing_utils # noqa: E402 +from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # noqa: E402 +from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402 +from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402 def patch_hf(): @@ -95,4 +101,4 @@ def _patch_transformers(): # # if 'transformers.testing_utils' in sys.modules: # del sys.modules['transformers.testing_utils'] - # sys.modules['transformers.testing_utils'] = patched_transformers_testing_utils \ No newline at end of file + # sys.modules['transformers.testing_utils'] = patched_transformers_testing_utils From dd8e5e066b8f1472c45690be9933cc0b60188c3e Mon Sep 17 00:00:00 2001 From: CSY-ModelCloud Date: Thu, 12 Dec 2024 21:30:42 +0800 Subject: [PATCH 27/30] fix ruff --- gptqmodel/models/definitions/__init__.py | 2 +- gptqmodel/models/definitions/qwen2_vl.py | 4 +++- tests/models/test_qwen2_vl.py | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index 8a3023b75..79edf001e 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -36,9 +36,9 @@ from .qwen import QwenGPTQ from .qwen2 import Qwen2GPTQ from .qwen2_moe import Qwen2MoeGPTQ +from .qwen2_vl import Qwen2VLGPTQ from .rw import RWGPTQ from .stablelmepoch import StableLMEpochGPTQ from .starcoder2 import Starcoder2GPTQ from .xverse import XverseGPTQ from .yi import YiGPTQ -from .qwen2_vl import Qwen2VLGPTQ diff --git a/gptqmodel/models/definitions/qwen2_vl.py b/gptqmodel/models/definitions/qwen2_vl.py index 475474617..85dc6f251 100644 --- a/gptqmodel/models/definitions/qwen2_vl.py +++ b/gptqmodel/models/definitions/qwen2_vl.py @@ -1,6 +1,8 @@ -from ..base import BaseGPTQModel from transformers import AutoModelForVision2Seq +from ..base import BaseGPTQModel + + class Qwen2VLGPTQ(BaseGPTQModel): loader = AutoModelForVision2Seq diff --git a/tests/models/test_qwen2_vl.py b/tests/models/test_qwen2_vl.py index 91b3bd2c3..b14158042 100644 --- a/tests/models/test_qwen2_vl.py +++ b/tests/models/test_qwen2_vl.py @@ -1,5 +1,6 @@ from model_test import ModelTest + class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2-VL-2B-Instruct" QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 @@ -10,4 +11,4 @@ class TestQwen2_VL(ModelTest): BATCH_SIZE = 6 def test_qwen2_vl(self): - self.quant_lm_eval() \ No newline at 
end of file + self.quant_lm_eval() From 69d57cb673f3e5ed5319c7bfdf591b2db893c38a Mon Sep 17 00:00:00 2001 From: CSY Date: Fri, 13 Dec 2024 09:55:10 +0800 Subject: [PATCH 28/30] check cuda when there's only cuda device --- gptqmodel/nn_modules/qlinear/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index a3003863b..6e48a13e2 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -26,7 +26,7 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat if err: raise err - if DEVICE.CUDA in self.SUPPORTS_DEVICES: + if len(self.SUPPORTS_DEVICES) == 1 and DEVICE.CUDA in self.SUPPORTS_DEVICES: check_cuda() @classmethod From e61b1084142aeafb5329e95fccb7239bc0027b6c Mon Sep 17 00:00:00 2001 From: CSY Date: Fri, 13 Dec 2024 12:14:12 +0800 Subject: [PATCH 29/30] exclude triton --- gptqmodel/nn_modules/qlinear/tritonv2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index c0bb94796..033aa4d4e 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -101,6 +101,8 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None, outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[ bool, Optional[Exception]]: + if trainable: + return False, ValueError("exclude triton from trainable") if not TRITON_AVAILABLE: return False, ValueError(TRITON_INSTALL_HINT) return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable) From 59cb8d2315a17676a0b890370a05c3c0f8743e31 Mon Sep 17 00:00:00 2001 From: CSY Date: Fri, 13 Dec 2024 12:17:14 +0800 Subject: [PATCH 30/30] set SUPPORTS_TRAINING to exclude --- gptqmodel/nn_modules/qlinear/tritonv2.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 033aa4d4e..6e617eb59 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -36,7 +36,7 @@ class TritonV2QuantLinear(BaseQuantLinear, TritonModuleMixin): SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = True + SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = True SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [32] SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [32] @@ -101,8 +101,6 @@ def __init__(self, bits: int, group_size: int, desc_act: bool, sym: bool, infeat def validate(cls, bits: int, group_size: int, desc_act: bool, sym: bool, infeatures:int=None, outfeatures:int=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None) -> Tuple[ bool, Optional[Exception]]: - if trainable: - return False, ValueError("exclude triton from trainable") if not TRITON_AVAILABLE: return False, ValueError(TRITON_INSTALL_HINT) return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, dynamic=dynamic, device=device, trainable=trainable)
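
A note on the _patch_transformers hunk earlier in the series: replacing the whole class (transformers_quantizer_gptq.GptqHfQuantizer = patched_...GptqHfQuantizer) only rebinds the module attribute, so any code that already imported the class object keeps using the unpatched version; assigning the individual methods and attributes instead mutates the one class object every caller shares. A minimal, self-contained sketch of that difference follows; the Original/patched_init names are invented for illustration and are not gptqmodel or transformers APIs.

# Illustration of why the integration patches individual attributes instead of
# rebinding whole classes. Names here are made up for the example; only the
# patching pattern mirrors the _patch_transformers hunk above.


class Original:
    def __init__(self):
        self.flag = "original"


# Simulates code elsewhere that already holds a reference to the class object.
AlreadyImportedAlias = Original


def patched_init(self):
    self.flag = "patched"


# Rebinding the module-level name would leave AlreadyImportedAlias pointing at
# the old class; patching the attribute updates every existing reference.
Original.__init__ = patched_init

assert AlreadyImportedAlias().flag == "patched"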
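
On PATCH 28/30 ("check cuda when there's only cuda device"): the guard now calls check_cuda() only when CUDA is the sole entry in SUPPORTS_DEVICES, so kernels that also support CPU or XPU can still be constructed on machines without a CUDA device. A rough sketch of that intent is below; DEVICE and check_cuda are reduced to stand-ins here, not the real gptqmodel helpers.

# Hypothetical sketch of the PATCH 28/30 logic: fail fast only for CUDA-only
# kernels, let multi-backend kernels fall through to a non-CUDA device.
from enum import Enum


class DEVICE(str, Enum):
    CUDA = "cuda"
    CPU = "cpu"
    XPU = "xpu"


def check_cuda() -> None:
    # Stand-in for the real check: raise if no CUDA device is present.
    import torch
    if not torch.cuda.is_available():
        raise EnvironmentError("CUDA device required but not available")


def maybe_check_cuda(supports_devices: list) -> None:
    # A CUDA-only kernel must verify CUDA up front; a kernel that also
    # supports other devices is allowed to proceed without a GPU.
    if len(supports_devices) == 1 and DEVICE.CUDA in supports_devices:
        check_cuda()


maybe_check_cuda([DEVICE.CUDA, DEVICE.CPU])  # no error on a GPU-less machine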
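
On PATCH 29/30 and 30/30: the explicit trainable rejection first added to TritonV2QuantLinear.validate() is dropped again in favour of SUPPORTS_TRAINING = False, on the assumption that cls._validate() already consults the SUPPORTS_* capability flags. The sketch below shows that flag-driven pattern; the _validate body here is a simplified assumption, not the actual gptqmodel implementation.

# Simplified illustration of capability-flag validation (assumed behaviour of
# cls._validate; the real gptqmodel method checks many more flags).
from typing import Optional, Tuple


class BaseQuantLinearSketch:
    # Capability flags a concrete kernel can override.
    SUPPORTS_BITS = [2, 3, 4, 8]
    SUPPORTS_TRAINING = True

    @classmethod
    def _validate(cls, bits: int, trainable: Optional[bool] = None) -> Tuple[bool, Optional[Exception]]:
        # Reject unsupported bit widths.
        if bits not in cls.SUPPORTS_BITS:
            return False, ValueError(f"{cls.__name__} does not support bits={bits}")
        # Reject training requests when the kernel opts out via the class flag.
        if trainable and not cls.SUPPORTS_TRAINING:
            return False, ValueError(f"{cls.__name__} does not support training")
        return True, None


class TritonV2Sketch(BaseQuantLinearSketch):
    # Mirrors PATCH 30/30: opt out of training via the class flag instead of
    # hard-coding the check inside validate().
    SUPPORTS_TRAINING = False


ok, err = TritonV2Sketch._validate(bits=4, trainable=True)
assert not ok and isinstance(err, ValueError)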