diff --git a/gptqmodel/integration/integration.py b/gptqmodel/integration/integration.py
new file mode 100644
index 00000000..81d19998
--- /dev/null
+++ b/gptqmodel/integration/integration.py
@@ -0,0 +1,104 @@
+HAS_OPTIMUM = True
+try:
+ import optimum.gptq as optimum_gptq
+ from optimum.gptq import quantizer as optimum_quantizer
+ from optimum.utils import import_utils as optimum_import_utils
+ from optimum.utils import testing_utils as optimum_testing_utils
+
+ from .src.optimum.gptq import quantizer as patched_optimum_quantizer
+ from .src.optimum.utils import import_utils as patched_optimum_import_utils
+ from .src.optimum.utils import testing_utils as patched_optimum_testing_utils
+except BaseException:
+ HAS_OPTIMUM = False
+
+HAS_PEFT = True
+try:
+ from peft import import_utils as peft_import_utils
+ from peft.tuners.adalora.model import AdaLoraModel as peft_AdaLoraModel
+ from peft.tuners.lora import gptq as peft_gptq
+ from peft.tuners.lora import model as peft_model
+ from peft.utils import other as peft_other
+
+ from .src.peft import import_utils as patched_peft_import_utils
+ from .src.peft.tuners.adalora.model import AdaLoraModel as patched_peft_AdaLoraModel
+ from .src.peft.tuners.lora import gptq as patched_peft_gptq
+ from .src.peft.tuners.lora import model as patched_peft_model
+ from .src.peft.utils import other as patched_peft_other
+except BaseException:
+ HAS_PEFT = False
+
+import transformers.testing_utils as transformers_testing_utils # noqa: E402
+from transformers.quantizers import quantizer_gptq as transformers_quantizer_gptq # noqa: E402
+from transformers.utils import import_utils as transformers_import_utils # noqa: E402
+from transformers.utils import quantization_config as transformers_quantization_config # noqa: E402
+
+from .src.transformers import testing_utils as patched_transformers_testing_utils # noqa: E402
+from .src.transformers.quantizers import quantizer_gptq as patched_transformers_quantizer_gptq # noqa: E402
+from .src.transformers.utils import import_utils as patched_transformers_import_utils # noqa: E402
+from .src.transformers.utils import quantization_config as patched_transformers_quantization_config # noqa: E402
+
+
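+# Usage sketch (illustrative, not part of the patched upstream code): call `patch_hf()` once,
+# before loading or quantizing GPTQ models through the transformers/optimum/peft APIs, so the
+# patched symbols below are in place. For example:
+#
+#   from gptqmodel.integration.integration import patch_hf
+#
+#   patch_hf()
+#   # ...then use the transformers / optimum / peft GPTQ paths as usual.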
+def patch_hf():
+ _patch_peft()
+ _patch_optimum()
+ _patch_transformers()
+
+
+def _patch_peft():
+ if not HAS_PEFT:
+ return
+
+ peft_import_utils.is_gptqmodel_available = patched_peft_import_utils.is_gptqmodel_available
+
+ peft_AdaLoraModel._create_and_replace = patched_peft_AdaLoraModel._create_and_replace
+
+ peft_gptq.dispatch_gptq = patched_peft_gptq.dispatch_gptq
+
+ peft_model.LoraModel = patched_peft_model.LoraModel
+
+ peft_other.get_auto_gptq_quant_linear = patched_peft_other.get_auto_gptq_quant_linear
+ peft_other.get_gptqmodel_quant_linear = patched_peft_other.get_gptqmodel_quant_linear
+
+
+def _patch_optimum():
+ if not HAS_OPTIMUM:
+ return
+
+ optimum_gptq.GPTQQuantizer = patched_optimum_quantizer.GPTQQuantizer
+ optimum_quantizer.is_gptqmodel_available = patched_optimum_quantizer.is_gptqmodel_available
+ optimum_quantizer.has_device_more_than_cpu = patched_optimum_quantizer.has_device_more_than_cpu
+ optimum_quantizer.ExllamaVersion = patched_optimum_quantizer.ExllamaVersion
+
+ optimum_import_utils._gptqmodel_available = patched_optimum_import_utils._gptqmodel_available
+ optimum_import_utils.is_gptqmodel_available = patched_optimum_import_utils.is_gptqmodel_available
+ optimum_testing_utils.require_gptq = patched_optimum_testing_utils.require_gptq
+
+
+def _patch_transformers():
+ transformers_quantizer_gptq.GptqHfQuantizer.required_packages = patched_transformers_quantizer_gptq.GptqHfQuantizer.required_packages
+ transformers_quantizer_gptq.GptqHfQuantizer.validate_environment = patched_transformers_quantizer_gptq.GptqHfQuantizer.validate_environment
+ transformers_quantizer_gptq.GptqHfQuantizer._process_model_before_weight_loading = patched_transformers_quantizer_gptq.GptqHfQuantizer._process_model_before_weight_loading
+
+ transformers_import_utils._gptqmodel_available = patched_transformers_import_utils._gptqmodel_available
+ transformers_import_utils.is_gptqmodel_available = patched_transformers_import_utils.is_gptqmodel_available
+
+ transformers_quantization_config.GPTQConfig.__init__ = patched_transformers_quantization_config.GPTQConfig.__init__
+ transformers_quantization_config.GPTQConfig.post_init = patched_transformers_quantization_config.GPTQConfig.post_init
+
+ transformers_testing_utils.require_gptq = patched_transformers_testing_utils.require_gptq
+
+ # if 'transformers.quantizers.quantizer_gptq' in sys.modules:
+ # del sys.modules['transformers.quantizers.quantizer_gptq']
+ # sys.modules['transformers.quantizers.quantizer_gptq'] = patched_transformers_quantizer_gptq
+ #
+ # if 'transformers.utils.import_utils' in sys.modules:
+ # del sys.modules['transformers.utils.import_utils']
+ # sys.modules['transformers.utils.import_utils'] = patched_transformers_import_utils
+ #
+ # if 'transformers.utils.quantization_config' in sys.modules:
+ # del sys.modules['transformers.utils.quantization_config']
+ # sys.modules['transformers.utils.quantization_config'] = patched_transformers_quantization_config
+ #
+ # if 'transformers.testing_utils' in sys.modules:
+ # del sys.modules['transformers.testing_utils']
+ # sys.modules['transformers.testing_utils'] = patched_transformers_testing_utils
diff --git a/gptqmodel/integration/src/optimum/__init__.py b/gptqmodel/integration/src/optimum/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/optimum/gptq/__init__.py b/gptqmodel/integration/src/optimum/gptq/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/optimum/gptq/quantizer.py b/gptqmodel/integration/src/optimum/gptq/quantizer.py
new file mode 100644
index 00000000..d11c7a18
--- /dev/null
+++ b/gptqmodel/integration/src/optimum/gptq/quantizer.py
@@ -0,0 +1,925 @@
+
+# coding=utf-8
+# Copyright 2023 HuggingFace Inc. team and GPTQ and AutoGPTQ authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import json
+import os
+from enum import Enum
+from logging import getLogger
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from packaging import version
+from torch import nn
+from tqdm.auto import tqdm
+from transformers import AutoTokenizer
+from transformers.pytorch_utils import Conv1D
+from transformers.utils.quantization_config import QuantizationMethod
+
+from optimum.utils import is_accelerate_available, is_auto_gptq_available
+from optimum.utils.modeling_utils import recurse_getattr
+from optimum.gptq.constants import GPTQ_CONFIG
+from optimum.gptq.data import get_dataset, prepare_dataset
+from optimum.gptq.utils import get_block_name_with_pattern, get_device, get_layers, get_preceding_modules, get_seqlen
+from optimum.version import __version__ as optimum_version
+
+from gptqmodel.integration.src.optimum.utils.import_utils import is_gptqmodel_available
+
+if is_accelerate_available():
+ from accelerate import (
+ cpu_offload_with_hook,
+ load_checkpoint_and_dispatch,
+ )
+ from accelerate.hooks import remove_hook_from_module
+
+if is_auto_gptq_available():
+ from auto_gptq import exllama_set_max_input_length
+ from auto_gptq.modeling._utils import autogptq_post_init as gptq_post_init
+ from auto_gptq.quantization import GPTQ
+ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear as hf_select_quant_linear
+ from auto_gptq import __version__ as autogptq_version
+
+if is_gptqmodel_available():
+ from gptqmodel import exllama_set_max_input_length
+ from gptqmodel.quantization import GPTQ
+ from gptqmodel.utils.importer import hf_select_quant_linear
+ from gptqmodel.utils.model import hf_convert_gptq_v1_to_v2_format, hf_convert_gptq_v2_to_v1_format
+ from gptqmodel.utils.model import hf_gptqmodel_post_init as gptq_post_init
+ from gptqmodel.version import __version__ as gptqmodel_version
+
+logger = getLogger(__name__)
+
+
+def has_device_more_than_cpu():
+ return torch.cuda.is_available() or (hasattr(torch, "xpu") and torch.xpu.is_available())
+
+
+class ExllamaVersion(int, Enum):
+ ONE = 1
+ TWO = 2
+
+
+class GPTQQuantizer(object):
+ r"""
+ A simple API for GPTQ Quantization
+ """
+
+ def __init__(
+ self,
+ bits: int,
+ dataset: Optional[Union[List[str], str]] = None,
+ group_size: int = 128,
+ damp_percent: float = 0.1,
+ desc_act: bool = False,
+ sym: bool = True,
+ true_sequential: bool = True,
+ checkpoint_format: str = "gptq",
+ meta: Optional[Dict[str, Any]] = None,
+ backend: Optional[str] = None,
+ use_cuda_fp16: bool = False,
+ model_seqlen: Optional[int] = None,
+ block_name_to_quantize: Optional[str] = None,
+ module_name_preceding_first_block: Optional[List[str]] = None,
+ batch_size: int = 1,
+ pad_token_id: Optional[int] = None,
+ disable_exllama: bool = False,
+ exllama_config: Optional[Dict[str, Any]] = None,
+ max_input_length: Optional[int] = None,
+ cache_block_outputs: Optional[bool] = True,
+ modules_in_block_to_quantize: Optional[List[List[str]]] = None,
+ *args,
+ **kwargs,
+ ):
+ """
+ Args:
+ bits (`int`):
+ The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
+ dataset (`Union[List[str], str, Any]`, defaults to `None`):
+ The dataset used for quantization. You can provide your own dataset as a list of strings or as a list of tokenized data
+ (e.g. [{ "input_ids": [ 1, 100, 15, ... ], "attention_mask": [ 1, 1, 1, ... ]}, ...])
+ or just use the original datasets used in the GPTQ paper ['wikitext2', 'c4', 'c4-new'].
+ group_size (int, defaults to 128):
+ The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
+ damp_percent (`float`, defaults to `0.1`):
+ The percent of the average Hessian diagonal to use for dampening, recommended value is 0.1.
+ desc_act (`bool`, defaults to `False`):
+ Whether to quantize columns in order of decreasing activation size.
+ Setting it to False can significantly speed up inference but the perplexity may become slightly worse.
+ Also known as act-order.
+ sym (`bool`, defaults to `True`):
+ Whether to use symmetric quantization.
+ true_sequential (`bool`, defaults to `True`):
+ Whether to perform sequential quantization even within a single Transformer block.
+ Instead of quantizing the entire block at once, we perform layer-wise quantization.
+ As a result, each layer undergoes quantization using inputs that have passed through the previously quantized layers.
+ checkpoint_format (`str`, *optional*, defaults to `gptq`):
+ GPTQ weight format. `gptq`(v1) is supported by both gptqmodel and auto-gptq. `gptq_v2` is gptqmodel only.
+ meta (`Dict[str, Any]`, *optional*):
+ Properties, such as tooling:version, that do not directly contribute to quantization or quant inference are stored in meta.
+ e.g. `meta.quantizer`: ["optimum:_version_", "gptqmodel:_version_"]
+ backend (`str`, *optional*):
+ Controls which gptq kernel is used. Valid values for gptqmodel are `auto`, `auto_trainable` and more. For auto-gptq, the only
+ valid values are None and `auto_trainable`. Ref gptqmodel backends: https://github.com/ModelCloud/GPTQModel/blob/main/gptqmodel/utils/backend.py
+ use_cuda_fp16 (`bool`, defaults to `False`):
+ Whether or not to use the optimized cuda kernel for fp16 models. The model needs to be in fp16.
+ model_seqlen (`Optional[int]`, defaults to `None`):
+ The maximum sequence length that the model can take.
+ block_name_to_quantize (`Optional[str]`, defaults to `None`):
+ The transformers block name to quantize. If None, we will infer the block name using common patterns (e.g. model.layers)
+ module_name_preceding_first_block (`Optional[List[str]]`, defaults to `None`):
+ The layers that are preceding the first Transformer block.
+ batch_size (`int`, defaults to `1`):
+ The batch size of the dataset
+ pad_token_id (`Optional[int]`, defaults to `None`):
+ The pad token id. Needed to prepare the dataset when `batch_size` > 1.
+ disable_exllama (`bool`, defaults to `False`):
+ Whether to disable the exllama backend. The exllama backend only works with `bits` = 4.
+ exllama_config (`Dict[str, Any]`, *optional*):
+ The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset.
+ max_input_length (`Optional[int]`, defaults to `None`):
+ The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
+ It is specific to the exllama backend with act-order.
+ cache_block_outputs (`bool`, defaults to `True`):
+ Whether to cache block outputs to reuse as inputs for the succeeding block. It allows optimization of non-standard models
+ (e.g. ChatGLM) but can require more time.
+ modules_in_block_to_quantize (`Optional[List[List[str]]]`, defaults to `None`):
+ List of lists of module names to quantize in the specified block. This argument is useful to exclude certain linear modules from being quantized.
+ The block to quantize can be specified by setting `block_name_to_quantize`. We will quantize each list sequentially.
+ If not set, we will quantize all linear layers. Example: `modules_in_block_to_quantize=[["self_attention.query_key_value"], ["mlp.dense_h_to_4h"]]`
+ """
+
+ self.bits = bits
+ self.dataset = dataset
+ self.group_size = group_size
+ self.damp_percent = damp_percent
+ self.desc_act = desc_act
+ self.sym = sym
+ self.true_sequential = true_sequential
+ self.checkpoint_format = checkpoint_format.lower()
+ self.meta = meta
+ self.backend = backend.lower() if backend is not None else None
+ self.use_cuda_fp16 = use_cuda_fp16
+ self.model_seqlen = model_seqlen
+ self.block_name_to_quantize = block_name_to_quantize
+ self.module_name_preceding_first_block = module_name_preceding_first_block
+ self.batch_size = batch_size
+ self.pad_token_id = pad_token_id
+ self.disable_exllama = disable_exllama
+ self.exllama_config = exllama_config
+ self.max_input_length = max_input_length
+ self.quant_method = QuantizationMethod.GPTQ
+ self.cache_block_outputs = cache_block_outputs
+ self.modules_in_block_to_quantize = modules_in_block_to_quantize
+
+ self.serialization_keys = [
+ "bits",
+ "dataset",
+ "group_size",
+ "damp_percent",
+ "desc_act",
+ "sym",
+ "true_sequential",
+ "quant_method",
+ "modules_in_block_to_quantize",
+ "checkpoint_format",
+ "meta",
+ ]
+
+ if self.bits not in [2, 3, 4, 8]:
+ raise ValueError("only support quantize to [2,3,4,8] bits.")
+ if self.group_size != -1 and self.group_size <= 0:
+ raise ValueError("group_size must be greater than 0 or equal to -1")
+ if not (0 < self.damp_percent < 1):
+ raise ValueError("damp_percent must between 0 and 1.")
+
+ if self.exllama_config is None:
+ self.exllama_config = {"version": ExllamaVersion.TWO}
+ else:
+ if "version" not in self.exllama_config:
+ raise ValueError("`exllama_config` needs to have a `version` key")
+ elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
+ version = self.exllama_config["version"]
+ raise ValueError(
+ f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {version}"
+ )
+ self.exllama_version = self.exllama_config["version"]
+
+ def select_quant_linear(self, device_map: Union[str, dict]):
+ if is_gptqmodel_available():
+ self.quant_linear = hf_select_quant_linear(
+ bits=self.bits,
+ group_size=self.group_size,
+ desc_act=self.desc_act,
+ sym=self.sym,
+ checkpoint_format=self.checkpoint_format,
+ meta=self.meta,
+ device_map=device_map,
+ backend=self.backend,
+ )
+ else:
+ self.quant_linear = hf_select_quant_linear(
+ use_triton=False,
+ desc_act=self.desc_act,
+ group_size=self.group_size,
+ bits=self.bits,
+ disable_exllama=self.disable_exllama or self.exllama_version != ExllamaVersion.ONE,
+ disable_exllamav2=self.disable_exllama or self.exllama_version != ExllamaVersion.TWO,
+ )
+
+ def to_dict(self):
+ """
+ Returns the args in dict format.
+ """
+ gptq_dict = {}
+ for key in self.serialization_keys:
+ gptq_dict[key] = getattr(self, key)
+
+ if gptq_dict.get("meta") is None:
+ gptq_dict["meta"] = {}
+
+ meta = gptq_dict["meta"]
+ # store both optimum:version and gptq_lib:version into quantize_config.meta.quantizer
+ if meta.get("quantizer") is None:
+ meta["quantizer"] = [f"optimum:{optimum_version}"]
+
+ if is_gptqmodel_available():
+ meta["quantizer"].append(f"gptqmodel:{gptqmodel_version}")
+ elif is_auto_gptq_available():
+ meta["quantizer"].append(f"auto_gptq:{autogptq_version}")
+
+ return gptq_dict
+
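+ # Illustration (values are made up): `from_dict` expects a dict shaped like the output of
+ # `to_dict`, roughly:
+ #   {"bits": 4, "dataset": "c4", "group_size": 128, "damp_percent": 0.1, "desc_act": False,
+ #    "sym": True, "true_sequential": True, "quant_method": "gptq", "checkpoint_format": "gptq",
+ #    "modules_in_block_to_quantize": None,
+ #    "meta": {"quantizer": ["optimum:<version>", "gptqmodel:<version>"]}}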
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any]):
+ """
+ Instantiates a `GPTQQuantizer` using config_dict as kwargs
+
+ Args:
+ config_dict (`Dict[str,Any]`):
+ quantization config
+
+ Returns:
+ `GPTQQuantizer`: The quantizer object instantiated from those parameters.
+ """
+ return cls(**config_dict)
+
+ def convert_model(self, model: nn.Module, **kwargs):
+ """
+ Convert the model to a GPTQ model by getting and replacing the layers.
+
+ Args:
+ model (`nn.Module`):
+ Model to be converted
+
+ """
+ if self.block_name_to_quantize is None:
+ self.block_name_to_quantize = get_block_name_with_pattern(model)
+ block_name = self.block_name_to_quantize
+ layers_to_be_replaced = get_layers(model, prefix=block_name)
+ if self.modules_in_block_to_quantize is not None:
+ layers_to_keep = sum(self.modules_in_block_to_quantize, [])
+ for name in list(layers_to_be_replaced.keys()):
+ if not any(name.endswith(layer) for layer in layers_to_keep):
+ logger.info(
+ f"Quantization disabled for {name} (only modules_in_block_to_quantize={self.modules_in_block_to_quantize} are quantized)"
+ )
+ del layers_to_be_replaced[name]
+
+ self.select_quant_linear(device_map=kwargs.get("device_map", None))
+
+ self._replace_by_quant_layers(model, layers_to_be_replaced)
+
+ return model
+
+ def get_no_split_module_classes(self, model):
+ """
+ Get the modules that should not be split across multiple devices.
+ Args:
+ model (`nn.Module`):
+ The input model
+ """
+
+ block_class_name = recurse_getattr(model, self.block_name_to_quantize)[0].__class__.__name__
+ no_split_module_classes = [block_class_name]
+ return no_split_module_classes
+
+ def _replace_by_quant_layers(self, module: nn.Module, names: List[str], name: str = ""):
+ """
+ Replaces linear layers in `module` by `QuantLinear`
+
+ Args:
+ module (`nn.Module`):
+ Module to quantize
+ names (`List[str]`):
+ List of names of the module to quantize
+ name (`str`, defaults to `""`):
+ To keep track of the name of the current module
+ """
+ if isinstance(module, self.quant_linear):
+ return
+ for attr in dir(module):
+ layer = getattr(module, attr)
+ name1 = name + "." + attr if name != "" else attr
+ if name1 in names:
+ device = get_device(layer)
+ delattr(module, attr)
+ if isinstance(layer, nn.Linear):
+ in_features = layer.in_features
+ out_features = layer.out_features
+ elif isinstance(layer, nn.Conv2d):
+ in_features = layer.in_channels
+ out_features = layer.out_channels
+ elif isinstance(layer, Conv1D):
+ in_features = layer.weight.shape[0]
+ out_features = layer.weight.shape[1]
+ bias = layer.bias is not None
+ if is_gptqmodel_available():
+ new_layer = self.quant_linear(
+ self.bits,
+ self.group_size,
+ self.desc_act,
+ self.sym,
+ in_features,
+ out_features,
+ bias,
+ weight_dtype=layer.weight.dtype,
+ )
+ else:
+ if not (self.desc_act) or self.group_size == -1:
+ new_layer = self.quant_linear(
+ self.bits,
+ self.group_size,
+ in_features,
+ out_features,
+ bias,
+ use_cuda_fp16=self.use_cuda_fp16,
+ weight_dtype=layer.weight.dtype,
+ )
+ else:
+ new_layer = self.quant_linear(
+ self.bits,
+ self.group_size,
+ in_features,
+ out_features,
+ bias,
+ weight_dtype=layer.weight.dtype,
+ )
+ new_layer.device = device
+ setattr(module, attr, new_layer.to(device))
+ for name1, child in module.named_children():
+ self._replace_by_quant_layers(child, names, name + "." + name1 if name != "" else name1)
+
+ @torch.no_grad()
+ def quantize_model(self, model: nn.Module, tokenizer: Optional[Any] = None):
+ """
+ Quantizes the model using the dataset
+
+ Args:
+ model (`nn.Module`):
+ The model to quantize
+ tokenizer (Optional[`Any`], defaults to `None`):
+ The tokenizer to use in order to prepare the dataset. You can pass either:
+ - A custom tokenizer object.
+ - A string, the *model id* of a predefined tokenizer hosted inside a model repo on huggingface.co.
+ Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
+ user or organization name, like `dbmdz/bert-base-german-cased`.
+ - A path to a *directory* containing vocabulary files required by the tokenizer, for instance saved
+ using the [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
+ Returns:
+ `nn.Module`: The quantized model
+ """
+
+ if not is_auto_gptq_available() and not is_gptqmodel_available():
+ raise RuntimeError(
+ "gptqmodel or auto-gptq is required in order to perform gptq quantzation: `pip install gptqmodel` or `pip install auto-gptq`. Please notice that auto-gptq will be deprecated in the future."
+ )
+ elif is_gptqmodel_available() and is_auto_gptq_available():
+ logger.warning(
+ "Detected gptqmodel and auto-gptq, will use gptqmodel. The auto_gptq will be deprecated in the future."
+ )
+
+ gptq_supports_cpu = (
+ is_auto_gptq_available()
+ and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+ ) or is_gptqmodel_available()
+
+ if not gptq_supports_cpu and not torch.cuda.is_available():
+ raise RuntimeError(
+ "No cuda gpu or cpu support using Intel/IPEX found. A gpu or cpu with Intel/IPEX is required for quantization."
+ )
+
+ if not self.sym and not is_gptqmodel_available():
+ raise ValueError(
+ "Asymmetric sym=False quantization is not supported with auto-gptq. Please use gptqmodel: `pip install gptqmodel`"
+ )
+
+ if self.checkpoint_format == "gptq_v2" and not is_gptqmodel_available():
+ raise ValueError(
+ "gptq_v2 format only supported with gptqmodel. Please install gptqmodel: `pip install gptqmodel`"
+ )
+
+ model.eval()
+
+ # gptqmodel internal is gptq_v2 for asym support, gptq(v1) can only support sym=True
+ if is_gptqmodel_available() and self.checkpoint_format != "gptq_v2":
+ self.checkpoint_format = "gptq_v2"
+
+ # For Transformer model
+ has_config = False
+ has_device_map = False
+ if hasattr(model, "config"):
+ has_config = True
+ use_cache = model.config.use_cache
+ model.config.use_cache = False
+
+ # If the model has a device_map, we don't move the model. We have already dispatched the hooks that will do the work.
+ if hasattr(model, "hf_device_map"):
+ devices = list(model.hf_device_map.values())
+ has_device_map = True
+ if "disk" in devices:
+ raise ValueError("disk offload is not supported with GPTQ quantization")
+ if "cpu" in devices or torch.device("cpu") in devices:
+ if len(model.hf_device_map) > 1:
+ logger.info("Cpu offload is not recommended. There might be some issues with the memory")
+ hook = None
+ for name, device in model.hf_device_map.items():
+ if device == "cpu":
+ module = recurse_getattr(model, name)
+ remove_hook_from_module(module, recurse=True)
+ module, hook = cpu_offload_with_hook(module, prev_module_hook=hook)
+ else:
+ has_device_map = False
+
+ if hasattr(model, "dtype"):
+ self.use_cuda_fp16 = model.dtype == torch.float16
+
+ if self.model_seqlen is None:
+ # We allow a max value of 4028 to avoid passing data with huge length to the model during the calibration step
+ self.model_seqlen = min(4028, get_seqlen(model))
+
+ device = get_device(model)
+
+ # Step 1: Prepare the data
+ if isinstance(self.dataset, list) and not isinstance(self.dataset[0], str):
+ dataset = self.dataset
+ logger.info("GPTQQuantizer dataset appears to be already tokenized. Skipping tokenization.")
+ else:
+ if isinstance(tokenizer, str):
+ try:
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer)
+ except Exception:
+ raise ValueError(
+ f"""We were not able to get the tokenizer using `AutoTokenizer.from_pretrained`
+ with the string that you have passed {tokenizer}. If you have a custom tokenizer, you can pass it as input.
+ For now, we only support quantization for text models. Support for vision, speech and multimodal models will come later."""
+ )
+ if self.dataset is None:
+ raise ValueError("You need to pass `dataset` in order to quantize your model")
+ elif isinstance(self.dataset, str):
+ dataset = get_dataset(self.dataset, tokenizer, seqlen=self.model_seqlen, split="train")
+ elif isinstance(self.dataset, list):
+ dataset = [tokenizer(data, return_tensors="pt") for data in self.dataset]
+ else:
+ raise ValueError(
+ f"You need to pass a list of string, a list of tokenized data or a string for `dataset`. Found: {type(self.dataset)}."
+ )
+
+ dataset = prepare_dataset(dataset, pad_token_id=self.pad_token_id, batch_size=self.batch_size)
+
+ # Step 2: get the input of the 1st block
+ # To do that, we need to put the modules preceding the first block on the same device as the first block.
+ # Then we run the model, and it will stop at the first block because we added a pre-hook that raises an exception after storing the inputs.
+
+ layer_inputs = []
+ layer_outputs = []
+ layer_input_kwargs = []
+
+ if self.block_name_to_quantize is None:
+ self.block_name_to_quantize = get_block_name_with_pattern(model)
+
+ if self.module_name_preceding_first_block is None:
+ self.module_name_preceding_first_block = get_preceding_modules(model, self.block_name_to_quantize)
+
+ blocks = recurse_getattr(model, self.block_name_to_quantize)
+
+ if not has_device_map:
+ # put modules from module_name_preceding_first_block on cuda or xpu or cpu
+ to_device = 0 if has_device_more_than_cpu() else "cpu"
+ for module_name in self.module_name_preceding_first_block:
+ module = recurse_getattr(model, module_name)
+ if module is None:
+ raise ValueError(f"Module {module_name} was not found in model")
+ module = module.to(to_device)
+ blocks[0] = blocks[0].to(to_device)
+
+ def store_input_hook(_, input, *args):
+ kwargs = args[0]
+ if input is None:
+ if "hidden_states" in kwargs:
+ input = (kwargs["hidden_states"],)
+ else:
+ raise ValueError("No input value found in the foward pass")
+ layer_inputs.append(input)
+ other_kwargs = {}
+ for k, v in kwargs.items(): # make sure other arguments are also captured
+ if k not in ["hidden_states"]:
+ other_kwargs[k] = v
+ layer_input_kwargs.append(other_kwargs)
+ raise ValueError
+
+ if self.cache_block_outputs:
+ handle = blocks[0].register_forward_pre_hook(store_input_hook, with_kwargs=True)
+ for data in dataset:
+ for k, v in data.items():
+ # put the data on gpu, we won't put them back to cpu
+ if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
+ data[k] = v.to(0)
+ else:
+ data[k] = v.to(device)
+ try:
+ model(**data)
+ except ValueError:
+ pass
+ handle.remove()
+
+ if not has_device_map:
+ blocks[0].to(device)
+ for module_name in self.module_name_preceding_first_block:
+ module = recurse_getattr(model, module_name)
+ if module is None:
+ raise ValueError(f"Module {module_name} was not found in model")
+
+ torch.cuda.empty_cache()
+ if hasattr(torch, "xpu"):
+ torch.xpu.empty_cache()
+
+ # Step 3: Quantize the blocks
+ quantizers = {}
+ for i, block in enumerate(tqdm(blocks, desc=f"Quantizing {self.block_name_to_quantize} blocks ")):
+ logger.info(f"Start quantizing block {self.block_name_to_quantize} {i + 1}/{len(blocks)}")
+
+ if not self.cache_block_outputs:
+ handle = block.register_forward_pre_hook(store_input_hook, with_kwargs=True)
+ for data in dataset:
+ for k, v in data.items():
+ # put the data on gpu, we won't put them back to cpu
+ if (not has_device_map or device.type == "cpu") and has_device_more_than_cpu():
+ data[k] = v.to(0)
+ else:
+ data[k] = v.to(device)
+ try:
+ model(**data)
+ except ValueError:
+ pass
+ handle.remove()
+
+ # move block to cuda if needed
+ # in case we have offload modules, we need to put them on cuda because of GPTQ object
+ if (not has_device_map or get_device(block) == torch.device("cpu")) and has_device_more_than_cpu():
+ block = block.to(0)
+ layers = get_layers(block)
+ if isinstance(self.modules_in_block_to_quantize, list) and len(self.modules_in_block_to_quantize) > 0:
+ if self.true_sequential:
+ layers_name_list = self.modules_in_block_to_quantize
+ else:
+ layers_name_list = [sum(self.modules_in_block_to_quantize, [])]
+ else:
+ if self.true_sequential:
+ # lazy sequential but works well
+ layers_name_list = [[key] for key in layers.keys()]
+ else:
+ layers_name_list = [list(layers.keys())]
+ logger.info(f"Module to quantize {layers_name_list}")
+ for subset_name_list in tqdm(layers_name_list, leave=False, desc="Quantizing layers inside the block"):
+ subset_layers = {name: layers[name] for name in subset_name_list}
+ gptq = {}
+ handles = []
+ # add hook for each layer in subset_layers
+ for name in subset_layers:
+ gptq[name] = GPTQ(subset_layers[name])
+ gptq[name].quantizer.configure(bits=self.bits, sym=self.sym, perchannel=True)
+
+ def add_batch(name):
+ def tmp(_, input, output):
+ gptq[name].add_batch(input[0].data, output.data)
+
+ return tmp
+
+ # keep the handles so the hooks can be removed once the Hessian statistics have been collected
+ handles.append(subset_layers[name].register_forward_hook(add_batch(name)))
+ # update Hessian for each layer in subset_layers thanks to the hook
+ for j in range(len(dataset)):
+ # the args are already on the gpu
+ # don't need to store the output
+ block(*layer_inputs[j], **layer_input_kwargs[j])
+ # remove hook
+ for h in handles:
+ h.remove()
+ for name in subset_name_list:
+ logger.info(f"Quantizing {name} in block {i + 1}/{len(blocks)}...")
+ quant_outputs = gptq[name].fasterquant(
+ percdamp=self.damp_percent, group_size=self.group_size, actorder=self.desc_act
+ )
+ scale, zero, g_idx = quant_outputs[0], quant_outputs[1], quant_outputs[2]
+ quantizers[f"{self.block_name_to_quantize}.{i}.{name}"] = (
+ gptq[name].quantizer,
+ scale,
+ zero,
+ g_idx,
+ )
+ gptq[name].free()
+ del subset_layers
+ # we get the new outputs from the partially quantized block
+ if self.cache_block_outputs:
+ for j in range(len(dataset)):
+ layer_output = block(*layer_inputs[j], **layer_input_kwargs[j])
+ layer_outputs.append(layer_output)
+
+ # put back to device
+ if not has_device_map:
+ blocks[i] = block.to(device)
+ del layers
+ del layer_inputs
+ layer_inputs, layer_outputs = layer_outputs, []
+ else:
+ del layers
+ del layer_inputs
+ layer_inputs = []
+ torch.cuda.empty_cache()
+ if hasattr(torch, "xpu"):
+ torch.xpu.empty_cache()
+
+ if self.bits == 4:
+ # device not on gpu
+ if device.type != "cuda" or (has_device_map and any(d in devices for d in ["cpu", "disk", "hpu"])):
+ if not self.disable_exllama and not is_gptqmodel_available():
+ logger.warning(
+ "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
+ )
+ self.disable_exllama = True
+ # act order and exllama
+ elif self.desc_act and not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE:
+ logger.warning(
+ "Using Exllama backend with act_order will reorder the weights offline, thus you will not be able to save the model with the right weights."
+ "Setting `disable_exllama=True`. You should only use Exllama backend with act_order for inference. "
+ )
+ self.disable_exllama = True
+ elif not self.disable_exllama and self.exllama_version == ExllamaVersion.TWO:
+ logger.warning(
+ "Using Exllamav2 backend will reorder the weights offline, thus you will not be able to save the model with the right weights."
+ "Setting `disable_exllama=True`. You should only use Exllamav2 backend for inference. "
+ )
+ self.disable_exllama = True
+ # Step 4: Pack the model at the end (Replacing the layers)
+ self.pack_model(model=model, quantizers=quantizers)
+
+ model.is_quantized = True
+ model.quantization_method = QuantizationMethod.GPTQ
+ if has_config:
+ model.config.use_cache = use_cache
+ model.config.quantization_config = self.to_dict()
+
+ # Step 5: Any post-initialization that requires device information, for example buffer initialization on device.
+ model = self.post_init_model(model)
+
+ torch.cuda.empty_cache()
+ if hasattr(torch, "xpu"):
+ torch.xpu.empty_cache()
+ return model
+
+ def post_init_model(self, model):
+ """
+ Post-initialization that requires device information, for example buffer initialization on device.
+
+ Args:
+ model (`nn.Module`):
+ The input model
+ """
+ if self.bits == 4 and not self.disable_exllama:
+ if get_device(model).type != "cuda" or (
+ hasattr(model, "hf_device_map") and any(d in model.hf_device_map for d in ["cpu", "disk", "hpu"])
+ ):
+ if not self.disable_exllama:
+ logger.warning(
+ "Found modules on cpu/disk. Using Exllama/Exllamav2 backend requires all the modules to be on GPU. Setting `disable_exllama=True`"
+ )
+ self.disable_exllama = True
+
+ class StoreAttr(object):
+ pass
+
+ if is_gptqmodel_available():
+ model, _ = hf_convert_gptq_v1_to_v2_format(model, self.bits, self.quant_linear, self.checkpoint_format, self.meta)
+
+ model.quantize_config = StoreAttr()
+ model.quantize_config.desc_act = self.desc_act
+ model = gptq_post_init(model, use_act_order=self.desc_act)
+ if (
+ self.desc_act
+ and (not self.disable_exllama and self.exllama_version == ExllamaVersion.ONE)
+ and self.max_input_length is not None
+ ):
+ model = exllama_set_max_input_length(model, self.max_input_length)
+ return model
+
+ def pack_model(
+ self,
+ model: nn.Module,
+ quantizers: Dict[str, Tuple],
+ ):
+ """
+ Pack the model by replacing the layers by quantized layers
+
+ Args:
+ model (`nn.Module`):
+ The model to pack
+ quantizers (`Dict[str,Tuple]`):
+ A mapping of the layer name and the data needed to pack the layer
+ """
+ logger.info("Packing model...")
+ layers = get_layers(model)
+ layers = {n: layers[n] for n in quantizers}
+
+ self.select_quant_linear(device_map=model.hf_device_map)
+
+ self._replace_by_quant_layers(model, quantizers)
+ qlayers = get_layers(model, [self.quant_linear])
+ for name in qlayers:
+ logger.info(name)
+ quantizers[name], scale, zero, g_idx = quantizers[name]
+ # so far can only pack layer on CPU
+ layer_device = qlayers[name].device
+ qlayers[name].to("cpu")
+ layers[name], scale, zero, g_idx = layers[name].to("cpu"), scale.to("cpu"), zero.to("cpu"), g_idx.to("cpu")
+ qlayers[name].pack(layers[name], scale, zero, g_idx)
+ qlayers[name].to(layer_device)
+
+ logger.info("Model packed.")
+
+ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True):
+ """
+ Save model state dict and configs
+
+ Args:
+ model (`nn.Module`):
+ Model to be saved. The model can be wrapped or unwrapped.
+ save_dir (`str`):
+ Directory to which to save. Will be created if it doesn't exist.
+ max_shard_size (`str`, defaults to `"10GB"`):
+ The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
+ than this size. If expressed as a string, it needs to be digits followed by a unit (like `"5MB"`).
+
+
+ If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
+ which will be bigger than `max_shard_size`.
+
+
+ safe_serialization (`bool`, defaults to `True`):
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+
+ """
+
+ # convert gptqmodel internal gptq_v2 format to v1 for max compatibility
+ model, converted = hf_convert_gptq_v2_to_v1_format(model, self.sym, self.bits, self.quant_linear, self.checkpoint_format, self.meta)
+ if converted:
+ self.checkpoint_format = "gptq"
+
+ os.makedirs(save_dir, exist_ok=True)
+ model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+ with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
+ json.dump(self.to_dict(), f, indent=2)
+
+
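+# Usage sketch (a rough illustration; the model id and save path are assumptions, not part of this API):
+#
+#   from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+#
+#   tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
+#   model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype="auto")
+#   quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128)
+#   quantized_model = quantizer.quantize_model(model, tokenizer)
+#   quantizer.save(quantized_model, "opt-125m-gptq")
+#
+# and, for loading the saved weights back with `load_quantized_model` below:
+#
+#   from accelerate import init_empty_weights
+#
+#   with init_empty_weights():
+#       empty_model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("opt-125m-gptq"))
+#   empty_model.tie_weights()
+#   quantized_model = load_quantized_model(empty_model, save_folder="opt-125m-gptq", device_map="auto")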
+def load_quantized_model(
+ model: nn.Module,
+ save_folder: str,
+ quant_config_name: str = GPTQ_CONFIG,
+ state_dict_name: Optional[str] = None,
+ device_map: Optional[str] = None,
+ max_memory: Optional[Dict] = None,
+ no_split_module_classes: Optional[Dict] = None,
+ offload_folder: Optional[str] = None,
+ offload_buffers: Optional[str] = None,
+ offload_state_dict: bool = False,
+ disable_exllama: bool = False,
+ exllama_config: Optional[Dict[str, Any]] = None,
+ max_input_length: Optional[int] = None,
+):
+ """
+ Load quantized weights from the save_folder into the converted model and dispatch the weights according to the device_map.
+
+ Args:
+ model (`nn.Module`):
+ The model can be empty or not.
+ save_folder (`str`):
+ Directory from which to load the weights.
+ quant_config_name (`str`, defaults to `GPTQ_CONFIG`):
+ Name of the quantization config file
+ state_dict_name (`Optional[str]`, defaults to `None`):
+ Name of the state dict file
+ device_map (`Optional[str]`, defaults to `None`):
+ A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
+ name; once a given module name is included, every submodule of it will be sent to the same device.
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`.
+ max_memory (`Optional[Dict]`, defaults to `None`):
+ A dictionary mapping device identifiers to maximum memory. Will default to the maximum memory available for each GPU
+ and the available CPU RAM if unset.
+ no_split_module_classes (`Optional[Dict]`, defaults to `None`):
+ A list of layer class names that should never be split across devices (for instance any layer that has a
+ residual connection).
+ offload_folder (`Optional[str]`, defaults to `None`):
+ If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+ offload_buffers (`Optional[str]`, defaults to `None`):
+ In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
+ well as the parameters.
+ offload_state_dict (`bool`, defaults to `False`):
+ If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
+ the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
+ picked contains `"disk"` values.
+ disable_exllama (`bool`, defaults to `False`):
+ Whether to disable the exllama backend. The exllama backend only works with `bits` = 4.
+ exllama_config (`Optional[Dict[str, Any]]`, defaults to `None`):
+ The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults to `{"version": 2}` if unset.
+ max_input_length (`Optional[int]`, defaults to `None`):
+ The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input length.
+ It is specific to the exllama backend with act-order.
+
+ Returns:
+ `nn.Module`: The quantized model
+ """
+ if not torch.cuda.is_available() and not is_gptqmodel_available():
+ raise RuntimeError("No GPU found. A GPU is needed to run quantized model by auto_gptq.")
+ if not is_auto_gptq_available() and not is_gptqmodel_available():
+ raise RuntimeError(
+ "gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) is required in order to load quantized weights. Please notice that auto-gptq will be deprecated in the future."
+ )
+ if not is_accelerate_available():
+ raise RuntimeError(
+ "You need to install accelerate in order to load and dispatch weights to"
+ "a quantized model. You can do it with `pip install accelerate`"
+ )
+ if device_map is None:
+ device_map = {"": torch.cuda.current_device()}
+ logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.")
+
+ if exllama_config is None:
+ exllama_config = {"version": ExllamaVersion.TWO}
+ else:
+ if "version" not in exllama_config:
+ raise ValueError("`exllama_config` needs to have a `version` key")
+ elif exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
+ version = exllama_config["version"]
+ raise ValueError(
+ f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {version}"
+ )
+
+ # this branch will check if model is from huggingface
+ try:
+ if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
+ quantize_config_dict = model.config.quantization_config.to_dict()
+ else:
+ with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f:
+ quantize_config_dict = json.load(f)
+ except Exception as err:
+ raise ValueError(
+ f"Failed to load quantization config from {save_folder} (lookup for traceback): {err}\nTip: If the save directory is saved from a transformers.PreTrainedModel, make sure that `config.json` contains a 'quantization_config' key."
+ ) from err
+ quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
+ quantizer.disable_exllama = disable_exllama
+ quantizer.exllama_config = exllama_config
+ quantizer.exllama_version = quantizer.exllama_config["version"]
+ quantizer.max_input_length = max_input_length
+
+ model = quantizer.convert_model(model, device_map=device_map)
+
+ if no_split_module_classes is None:
+ no_split_module_classes = quantizer.get_no_split_module_classes(model)
+
+ model = load_checkpoint_and_dispatch(
+ model,
+ checkpoint=os.path.join(save_folder, state_dict_name) if state_dict_name is not None else save_folder,
+ device_map=device_map,
+ max_memory=max_memory,
+ no_split_module_classes=no_split_module_classes,
+ offload_folder=offload_folder,
+ offload_buffers=offload_buffers,
+ offload_state_dict=offload_state_dict,
+ )
+
+ model = quantizer.post_init_model(model)
+ model.is_quantized = True
+ model.quantization_method = QuantizationMethod.GPTQ
+ model.eval()
+ return model
diff --git a/gptqmodel/integration/src/optimum/utils/__init__.py b/gptqmodel/integration/src/optimum/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/optimum/utils/import_utils.py b/gptqmodel/integration/src/optimum/utils/import_utils.py
new file mode 100644
index 00000000..a29cb063
--- /dev/null
+++ b/gptqmodel/integration/src/optimum/utils/import_utils.py
@@ -0,0 +1,501 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import importlib.util
+import itertools
+import os
+import shutil
+import subprocess
+import sys
+import unittest
+from collections.abc import MutableMapping
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+
+import torch
+
+from optimum.utils import (
+ is_accelerate_available,
+ is_auto_gptq_available,
+ is_diffusers_available,
+ is_sentence_transformers_available,
+ is_timm_available,
+)
+
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Import utilities."""
+
+import importlib.util
+import inspect
+import sys
+from collections import OrderedDict
+from contextlib import contextmanager
+from typing import Tuple, Union
+
+import numpy as np
+from packaging import version
+from transformers.utils import is_torch_available
+
+
+def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
+ # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
+ package_exists = importlib.util.find_spec(pkg_name) is not None
+ package_version = "N/A"
+ if package_exists:
+ try:
+ package_version = importlib.metadata.version(pkg_name)
+ package_exists = True
+ except importlib.metadata.PackageNotFoundError:
+ package_exists = False
+ if return_version:
+ return package_exists, package_version
+ else:
+ return package_exists
+
+
+# The package importlib_metadata is in a different place, depending on the python version.
+if sys.version_info < (3, 8):
+ import importlib_metadata
+else:
+ import importlib.metadata as importlib_metadata
+
+
+TORCH_MINIMUM_VERSION = version.parse("1.11.0")
+TRANSFORMERS_MINIMUM_VERSION = version.parse("4.25.0")
+DIFFUSERS_MINIMUM_VERSION = version.parse("0.22.0")
+AUTOGPTQ_MINIMUM_VERSION = version.parse("0.4.99") # Allows 0.5.0.dev0
+GPTQMODEL_MINIMUM_VERSION = version.parse("1.3.99") # Allows 1.4.0.dev0
+
+
+# This is the minimal required version to support some ONNX Runtime features
+ORT_QUANTIZE_MINIMUM_VERSION = version.parse("1.4.0")
+
+
+_onnx_available = _is_package_available("onnx")
+
+# importlib.metadata.version seem to not be robust with the ONNX Runtime extensions (`onnxruntime-gpu`, etc.)
+_onnxruntime_available = importlib.util.find_spec("onnxruntime") is not None
+
+_pydantic_available = _is_package_available("pydantic")
+_accelerate_available = _is_package_available("accelerate")
+_diffusers_available = _is_package_available("diffusers")
+_auto_gptq_available = _is_package_available("auto_gptq")
+_gptqmodel_available = _is_package_available("gptqmodel")
+_timm_available = _is_package_available("timm")
+_sentence_transformers_available = _is_package_available("sentence_transformers")
+_datasets_available = _is_package_available("datasets")
+
+torch_version = None
+if is_torch_available():
+ torch_version = version.parse(importlib_metadata.version("torch"))
+
+_is_torch_onnx_support_available = is_torch_available() and (
+ TORCH_MINIMUM_VERSION.major,
+ TORCH_MINIMUM_VERSION.minor,
+) <= (
+ torch_version.major,
+ torch_version.minor,
+)
+
+
+_diffusers_version = None
+if _diffusers_available:
+ try:
+ _diffusers_version = importlib_metadata.version("diffusers")
+ except importlib_metadata.PackageNotFoundError:
+ _diffusers_available = False
+
+
+def is_torch_onnx_support_available():
+ return _is_torch_onnx_support_available
+
+
+def is_onnx_available():
+ return _onnx_available
+
+
+def is_onnxruntime_available():
+ try:
+ # Try to import the source file of onnxruntime - if you run the tests from `tests` the function gets
+ # confused since there is a folder named `onnxruntime` in `tests`. Therefore, `_onnxruntime_available`
+ # will be set to `True` even if not installed.
+ mod = importlib.import_module("onnxruntime")
+ inspect.getsourcefile(mod)
+ except Exception:
+ return False
+ return _onnxruntime_available
+
+
+def is_pydantic_available():
+ return _pydantic_available
+
+
+def is_accelerate_available():
+ return _accelerate_available
+
+
+def is_diffusers_available():
+ return _diffusers_available
+
+
+def is_timm_available():
+ return _timm_available
+
+
+def is_sentence_transformers_available():
+ return _sentence_transformers_available
+
+
+def is_datasets_available():
+ return _datasets_available
+
+
+def is_auto_gptq_available():
+ if _auto_gptq_available:
+ v = version.parse(importlib_metadata.version("auto_gptq"))
+ if v >= AUTOGPTQ_MINIMUM_VERSION:
+ return True
+ else:
+ raise ImportError(
+ f"Found an incompatible version of auto-gptq. Found version {v}, but only version >= {AUTOGPTQ_MINIMUM_VERSION} are supported"
+ )
+
+
+def is_gptqmodel_available():
+ if _gptqmodel_available:
+ v = version.parse(importlib_metadata.version("gptqmodel"))
+ if v >= GPTQMODEL_MINIMUM_VERSION:
+ return True
+ else:
+ raise ImportError(
+ f"Found an incompatible version of gptqmodel. Found version {v}, but only version >= {GPTQMODEL_MINIMUM_VERSION} are supported"
+ )
+
+
+@contextmanager
+def check_if_pytorch_greater(target_version: str, message: str):
+ r"""
+ A context manager that does nothing except checking that the installed PyTorch version is at least `target_version`.
+ """
+ import torch
+
+ if not version.parse(torch.__version__) >= version.parse(target_version):
+ raise ImportError(
+ f"Found an incompatible version of PyTorch. Found version {torch.__version__}, but only {target_version} and above are supported. {message}"
+ )
+ try:
+ yield
+ finally:
+ pass
+
+
+def check_if_transformers_greater(target_version: Union[str, version.Version]) -> bool:
+ """
+ Checks whether the current install of transformers is greater than or equal to the target version.
+
+ Args:
+ target_version (`Union[str, packaging.version.Version]`): version used as the reference for comparison.
+
+ Returns:
+ bool: whether the check is True or not.
+ """
+ import transformers
+
+ if isinstance(target_version, str):
+ target_version = version.parse(target_version)
+
+ return version.parse(transformers.__version__) >= target_version
+
+
+def check_if_diffusers_greater(target_version: str) -> bool:
+ """
+ Checks whether the current install of diffusers is greater than or equal to the target version.
+
+ Args:
+ target_version (str): version used as the reference for comparison.
+
+ Returns:
+ bool: whether the check is True or not.
+ """
+ if not _diffusers_available:
+ return False
+
+ return version.parse(_diffusers_version) >= version.parse(target_version)
+
+
+def check_if_torch_greater(target_version: str) -> bool:
+ """
+ Checks whether the current install of torch is greater than or equal to the target version.
+
+ Args:
+ target_version (str): version used as the reference for comparison.
+
+ Returns:
+ bool: whether the check is True or not.
+ """
+ if not is_torch_available():
+ return False
+
+ return torch_version >= version.parse(target_version)
+
+
+@contextmanager
+def require_numpy_strictly_lower(package_version: str, message: str):
+ if not version.parse(np.__version__) < version.parse(package_version):
+ raise ImportError(
+ f"Found an incompatible version of numpy. Found version {np.__version__}, but expected numpy<{version}. {message}"
+ )
+ try:
+ yield
+ finally:
+ pass
+
+
+DIFFUSERS_IMPORT_ERROR = """
+{0} requires the diffusers library but it was not found in your environment. You can install it with pip: `pip install
+diffusers`. Please note that you may need to restart your runtime after installation.
+"""
+
+TRANSFORMERS_IMPORT_ERROR = """requires the transformers>={0} library but it was not found in your environment. You can install it with pip: `pip install
+-U transformers`. Please note that you may need to restart your runtime after installation.
+"""
+
+DATASETS_IMPORT_ERROR = """
+{0} requires the datasets library but it was not found in your environment. You can install it with pip:
+`pip install datasets`. Please note that you may need to restart your runtime after installation.
+"""
+
+
+BACKENDS_MAPPING = OrderedDict(
+ [
+ ("diffusers", (is_diffusers_available, DIFFUSERS_IMPORT_ERROR)),
+ (
+ "transformers_431",
+ (lambda: check_if_transformers_greater("4.31"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.31")),
+ ),
+ (
+ "transformers_432",
+ (lambda: check_if_transformers_greater("4.32"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.32")),
+ ),
+ (
+ "transformers_434",
+ (lambda: check_if_transformers_greater("4.34"), "{0} " + TRANSFORMERS_IMPORT_ERROR.format("4.34")),
+ ),
+ ("datasets", (is_datasets_available, DATASETS_IMPORT_ERROR)),
+ ]
+)
+
+
+def requires_backends(obj, backends):
+ if not isinstance(backends, (list, tuple)):
+ backends = [backends]
+
+ name = obj.__name__ if hasattr(obj, "__name__") else obj.__class__.__name__
+ checks = (BACKENDS_MAPPING[backend] for backend in backends)
+ failed = [msg.format(name) for available, msg in checks if not available()]
+ if failed:
+ raise ImportError("".join(failed))
+
+
+# Copied from: https://github.com/huggingface/transformers/blob/v4.26.0/src/transformers/utils/import_utils.py#L1041
+class DummyObject(type):
+ """
+ Metaclass for the dummy objects. Any class inheriting from it will return the ImportError generated by
+ `requires_backend` each time a user tries to access any method of that class.
+ """
+
+ def __getattr__(cls, key):
+ if key.startswith("_"):
+ return super().__getattr__(cls, key)
+ requires_backends(cls, cls._backends)
+
+
+# Used to test the hub
+USER = "__DUMMY_OPTIMUM_USER__"
+
+
+def flatten_dict(dictionary: Dict):
+ """
+ Flatten a nested dictionary into a flat dictionary.
+ """
+ items = []
+ for k, v in dictionary.items():
+ new_key = k
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_dict(v).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
+def require_gptq(test_case):
+ """
+ Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
+ """
+ return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires gptqmodel or auto-gptq")(
+ test_case
+ )
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+
+
+def require_ort_rocm(test_case):
+ """Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime."""
+ import onnxruntime as ort
+
+ providers = ort.get_available_providers()
+
+ return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")(
+ test_case
+ )
+
+
+def require_hf_token(test_case):
+ """
+ Decorator marking a test that requires huggingface hub token.
+ """
+ # Is HF_AUTH_TOKEN used instead of HF_TOKEN to avoid huggingface_hub picking it up?
+ hf_token = os.environ.get("HF_AUTH_TOKEN", None)
+ if hf_token is None:
+ return unittest.skip("test requires hf token as `HF_AUTH_TOKEN` environment variable")(test_case)
+ else:
+ return test_case
+
+
+def require_sigopt_token_and_project(test_case):
+ """
+ Decorator marking a test that requires sigopt API token.
+ """
+ sigopt_api_token = os.environ.get("SIGOPT_API_TOKEN", None)
+ has_sigopt_project = os.environ.get("SIGOPT_PROJECT", None)
+ if sigopt_api_token is None or has_sigopt_project is None:
+ return unittest.skip("test requires an environment variable `SIGOPT_API_TOKEN` and `SIGOPT_PROJECT`")(
+ test_case
+ )
+ else:
+ return test_case
+
+
+def is_ort_training_available():
+ is_ort_train_available = importlib.util.find_spec("onnxruntime.training") is not None
+
+ # default to False so the return expression below does not fail when torch_ort is not installed
+ is_torch_ort_configured = False
+ if importlib.util.find_spec("torch_ort") is not None:
+ try:
+ is_torch_ort_configured = True
+ subprocess.run([sys.executable, "-m", "torch_ort.configure"], shell=False, check=True)
+ except subprocess.CalledProcessError:
+ is_torch_ort_configured = False
+
+ return is_ort_train_available and is_torch_ort_configured
+
+
+def require_ort_training(test_case):
+ """
+ Decorator marking a test that requires onnxruntime-training and torch_ort correctly installed and configured.
+ These tests are skipped otherwise.
+ """
+ return unittest.skipUnless(
+ is_ort_training_available(),
+ "test requires torch_ort correctly installed and configured",
+ )(test_case)
+
+
+def require_diffusers(test_case):
+ return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case)
+
+
+def require_timm(test_case):
+ return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+
+
+def require_sentence_transformers(test_case):
+ return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case)
+
+
+def require_datasets(test_case):
+ return unittest.skipUnless(is_datasets_available(), "test requires datasets")(test_case)
+
+
+def grid_parameters(
+ parameters: Dict[str, Iterable[Any]],
+ yield_dict: bool = False,
+ add_test_name: bool = True,
+ filter_params_func: Optional[Callable[[Tuple], Tuple]] = None,
+) -> Iterable:
+ """
+ Generates an iterable over the grid of all combinations of parameters.
+
+ Args:
+ `parameters` (`Dict[str, Iterable[Any]]`):
+ Dictionary of multiple values to generate a grid from.
+ `yield_dict` (`bool`, defaults to `False`):
+ If True, a dictionary with all keys and sampled values is yielded. Otherwise, the sampled values are yielded as a list.
+ `add_test_name` (`bool`, defaults to `True`):
+ Whether to add the test name in the yielded list or dictionary.
+ filter_params_func (`Optional[Callable[[Tuple], Tuple]]`, defaults to `None`):
+ A function that can modify or exclude the current set of parameters. The function should take a tuple of the
+ parameters and return the same. If a parameter set is to be excluded, the function should return `None`.
+ """
+ for params in itertools.product(*parameters.values()):
+ if filter_params_func is not None:
+ params = filter_params_func(list(params))
+ if params is None:
+ continue
+
+ test_name = "_".join([str(param) for param in params])
+ if yield_dict is True:
+ res_dict = {}
+ for i, key in enumerate(parameters.keys()):
+ res_dict[key] = params[i]
+ if add_test_name is True:
+ res_dict["test_name"] = test_name
+ yield res_dict
+ else:
+ returned_list = [test_name] + list(params) if add_test_name is True else list(params)
+ yield returned_list
+
+
+def remove_directory(dirpath):
+ """
+ Remove a directory and its content.
+ This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows.
+ Reference: https://github.com/python/cpython/issues/107408
+ """
+ if os.path.exists(dirpath) and os.path.isdir(dirpath):
+ if os.name == "nt":
+ os.system(f"rmdir /S /Q {dirpath}")
+ else:
+ shutil.rmtree(dirpath)
\ No newline at end of file
diff --git a/gptqmodel/integration/src/optimum/utils/testing_utils.py b/gptqmodel/integration/src/optimum/utils/testing_utils.py
new file mode 100644
index 00000000..9f140d4a
--- /dev/null
+++ b/gptqmodel/integration/src/optimum/utils/testing_utils.py
@@ -0,0 +1,205 @@
+# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import importlib.util
+import itertools
+import os
+import shutil
+import subprocess
+import sys
+import unittest
+from collections.abc import MutableMapping
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple
+
+import torch
+
+from optimum.utils import (
+ is_accelerate_available,
+ is_auto_gptq_available,
+ is_diffusers_available,
+ is_sentence_transformers_available,
+ is_timm_available,
+)
+
+from gptqmodel.integration.src.optimum.utils.import_utils import is_datasets_available, is_gptqmodel_available
+
+# Used to test the hub
+USER = "__DUMMY_OPTIMUM_USER__"
+
+
+def flatten_dict(dictionary: Dict):
+ """
+ Flatten a nested dictionary into a flat dictionary.
+ """
+ items = []
+ for k, v in dictionary.items():
+ new_key = k
+ if isinstance(v, MutableMapping):
+ items.extend(flatten_dict(v).items())
+ else:
+ items.append((new_key, v))
+ return dict(items)
+
+
+def require_accelerate(test_case):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+
+
+def require_gptq(test_case):
+ """
+ Decorator marking a test that requires gptqmodel or auto-gptq. These tests are skipped when gptqmodel and auto-gptq are not installed.
+ """
+ return unittest.skipUnless(is_auto_gptq_available() or is_gptqmodel_available(), "test requires gptqmodel or auto-gptq")(
+ test_case
+ )
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+
+
+def require_ort_rocm(test_case):
+ """Decorator marking a test that requires ROCMExecutionProvider for ONNX Runtime."""
+ import onnxruntime as ort
+
+ providers = ort.get_available_providers()
+
+ return unittest.skipUnless("ROCMExecutionProvider" == providers[0], "test requires ROCMExecutionProvider")(
+ test_case
+ )
+
+
+def require_hf_token(test_case):
+ """
+ Decorator marking a test that requires huggingface hub token.
+ """
+ # Is HF_AUTH_TOKEN used instead of HF_TOKEN to avoid huggingface_hub picking it up?
+ hf_token = os.environ.get("HF_AUTH_TOKEN", None)
+ if hf_token is None:
+ return unittest.skip("test requires hf token as `HF_AUTH_TOKEN` environment variable")(test_case)
+ else:
+ return test_case
+
+
+def require_sigopt_token_and_project(test_case):
+ """
+ Decorator marking a test that requires a SigOpt API token and project.
+ """
+ sigopt_api_token = os.environ.get("SIGOPT_API_TOKEN", None)
+ has_sigopt_project = os.environ.get("SIGOPT_PROJECT", None)
+ if sigopt_api_token is None or has_sigopt_project is None:
+ return unittest.skip("test requires an environment variable `SIGOPT_API_TOKEN` and `SIGOPT_PROJECT`")(
+ test_case
+ )
+ else:
+ return test_case
+
+
+def is_ort_training_available():
+ is_ort_train_available = importlib.util.find_spec("onnxruntime.training") is not None
+
+ is_torch_ort_configured = False
+ if importlib.util.find_spec("torch_ort") is not None:
+ try:
+ is_torch_ort_configured = True
+ subprocess.run([sys.executable, "-m", "torch_ort.configure"], shell=False, check=True)
+ except subprocess.CalledProcessError:
+ is_torch_ort_configured = False
+
+ return is_ort_train_available and is_torch_ort_configured
+
+
+def require_ort_training(test_case):
+ """
+ Decorator marking a test that requires onnxruntime-training and torch_ort correctly installed and configured.
+ These tests are skipped otherwise.
+ """
+ return unittest.skipUnless(
+ is_ort_training_available(),
+ "test requires torch_ort correctly installed and configured",
+ )(test_case)
+
+
+def require_diffusers(test_case):
+ return unittest.skipUnless(is_diffusers_available(), "test requires diffusers")(test_case)
+
+
+def require_timm(test_case):
+ return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+
+
+def require_sentence_transformers(test_case):
+ return unittest.skipUnless(is_sentence_transformers_available(), "test requires sentence-transformers")(test_case)
+
+
+def require_datasets(test_case):
+ return unittest.skipUnless(is_datasets_available(), "test requires datasets")(test_case)
+
+
+def grid_parameters(
+ parameters: Dict[str, Iterable[Any]],
+ yield_dict: bool = False,
+ add_test_name: bool = True,
+ filter_params_func: Optional[Callable[[Tuple], Tuple]] = None,
+) -> Iterable:
+ """
+ Generates an iterable over the grid of all combinations of parameters.
+
+ Args:
+ `parameters` (`Dict[str, Iterable[Any]]`):
+ Dictionary of multiple values to generate a grid from.
+ `yield_dict` (`bool`, defaults to `False`):
+ If True, a dictionary with all keys and sampled values is yielded. Otherwise, the sampled values are yielded as a list.
+ `add_test_name` (`bool`, defaults to `True`):
+ Whether to add the test name in the yielded list or dictionary.
+ filter_params_func (`Optional[Callable[[Tuple], Tuple]]`, defaults to `None`):
+ A function that can modify or exclude the current set of parameters. The function should take a tuple of the
+ parameters and return the same. If a parameter set is to be excluded, the function should return `None`.
+ """
+ for params in itertools.product(*parameters.values()):
+ if filter_params_func is not None:
+ params = filter_params_func(list(params))
+ if params is None:
+ continue
+
+ test_name = "_".join([str(param) for param in params])
+ if yield_dict is True:
+ res_dict = {}
+ for i, key in enumerate(parameters.keys()):
+ res_dict[key] = params[i]
+ if add_test_name is True:
+ res_dict["test_name"] = test_name
+ yield res_dict
+ else:
+ returned_list = [test_name] + list(params) if add_test_name is True else list(params)
+ yield returned_list
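+
+# Illustrative use of grid_parameters (the parameter names below are placeholders). With the
+# defaults (yield_dict=False, add_test_name=True) each yielded item is [test_name, *values]:
+#     for test_name, num_beams, batch_size in grid_parameters({"num_beams": [1, 4], "batch_size": [2, 8]}):
+#         ...  # e.g. ["1_2", 1, 2], ["1_8", 1, 8], ["4_2", 4, 2], ["4_8", 4, 8]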
+
+
+def remove_directory(dirpath):
+ """
+ Remove a directory and its content.
+ This is a cross-platform solution to remove a directory and its content that avoids the use of `shutil.rmtree` on Windows.
+ Reference: https://github.com/python/cpython/issues/107408
+ """
+ if os.path.exists(dirpath) and os.path.isdir(dirpath):
+ if os.name == "nt":
+ os.system(f"rmdir /S /Q {dirpath}")
+ else:
+ shutil.rmtree(dirpath)
\ No newline at end of file
diff --git a/gptqmodel/integration/src/peft/__init__.py b/gptqmodel/integration/src/peft/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/peft/import_utils.py b/gptqmodel/integration/src/peft/import_utils.py
new file mode 100644
index 00000000..2314c377
--- /dev/null
+++ b/gptqmodel/integration/src/peft/import_utils.py
@@ -0,0 +1,131 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+import importlib.metadata as importlib_metadata
+from functools import lru_cache
+
+import packaging.version
+
+
+@lru_cache
+def is_bnb_available() -> bool:
+ return importlib.util.find_spec("bitsandbytes") is not None
+
+
+@lru_cache
+def is_bnb_4bit_available() -> bool:
+ if not is_bnb_available():
+ return False
+
+ import bitsandbytes as bnb
+
+ return hasattr(bnb.nn, "Linear4bit")
+
+
+@lru_cache
+def is_auto_gptq_available():
+ if importlib.util.find_spec("auto_gptq") is not None:
+ AUTOGPTQ_MINIMUM_VERSION = packaging.version.parse("0.5.0")
+ version_autogptq = packaging.version.parse(importlib_metadata.version("auto_gptq"))
+ if AUTOGPTQ_MINIMUM_VERSION <= version_autogptq:
+ return True
+ else:
+ raise ImportError(
+ f"Found an incompatible version of auto-gptq. Found version {version_autogptq}, "
+ f"but only versions above {AUTOGPTQ_MINIMUM_VERSION} are supported"
+ )
+
+
+@lru_cache
+def is_gptqmodel_available():
+ if importlib.util.find_spec("gptqmodel") is not None:
+ GPTQMODEL_MINIMUM_VERSION = packaging.version.parse("1.3.0")
+ version_gptqmodel = packaging.version.parse(importlib_metadata.version("gptqmodel"))
+ if GPTQMODEL_MINIMUM_VERSION <= version_gptqmodel:
+ return True
+ else:
+ raise ImportError(
+ f"Found an incompatible version of gptqmodel. Found version {version_gptqmodel}, "
+ f"but only versions above {GPTQMODEL_MINIMUM_VERSION} are supported"
+ )
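+
+# Note on the two helpers above: they return True for a compatible install, raise ImportError for an
+# incompatible version, and fall through (implicitly returning None, which is falsy) when the package
+# is not installed at all.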
+
+
+@lru_cache
+def is_optimum_available() -> bool:
+ return importlib.util.find_spec("optimum") is not None
+
+
+@lru_cache
+def is_torch_tpu_available(check_device=True):
+ "Checks if `torch_xla` is installed and potentially if a TPU is in the environment"
+ if importlib.util.find_spec("torch_xla") is not None:
+ if check_device:
+ # We need to check if `xla_device` can be found, will raise a RuntimeError if not
+ try:
+ import torch_xla.core.xla_model as xm
+
+ _ = xm.xla_device()
+ return True
+ except RuntimeError:
+ return False
+ return True
+ return False
+
+
+@lru_cache
+def is_aqlm_available():
+ return importlib.util.find_spec("aqlm") is not None
+
+
+@lru_cache
+def is_auto_awq_available():
+ return importlib.util.find_spec("awq") is not None
+
+
+@lru_cache
+def is_eetq_available():
+ if importlib.util.find_spec("eetq") is None:
+ return False
+
+ is_available = True
+ try:
+ from eetq import EetqLinear # noqa: F401
+ except ImportError as exc:
+ if "shard_checkpoint" in str(exc):
+ # eetq is currently broken with newer transformers versions because it tries to import shard_checkpoint
+ # see https://github.com/NetEase-FuXi/EETQ/issues/34
+ # TODO: Remove once eetq releases a fix and this release is used in CI
+ is_available = False
+ return is_available
+
+
+@lru_cache
+def is_hqq_available():
+ return importlib.util.find_spec("hqq") is not None
+
+
+@lru_cache
+def is_torchao_available():
+ if importlib.util.find_spec("torchao") is None:
+ return False
+
+ TORCHAO_MINIMUM_VERSION = packaging.version.parse("0.4.0")
+ torchao_version = packaging.version.parse(importlib_metadata.version("torchao"))
+
+ if torchao_version < TORCHAO_MINIMUM_VERSION:
+ raise ImportError(
+ f"Found an incompatible version of torchao. Found version {torchao_version}, "
+ f"but only versions above {TORCHAO_MINIMUM_VERSION} are supported"
+ )
+ return True
\ No newline at end of file
diff --git a/gptqmodel/integration/src/peft/tuners/__init__.py b/gptqmodel/integration/src/peft/tuners/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/peft/tuners/adalora/__init__.py b/gptqmodel/integration/src/peft/tuners/adalora/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/peft/tuners/adalora/model.py b/gptqmodel/integration/src/peft/tuners/adalora/model.py
new file mode 100644
index 00000000..e0bf8dd6
--- /dev/null
+++ b/gptqmodel/integration/src/peft/tuners/adalora/model.py
@@ -0,0 +1,364 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+
+import torch
+from peft.tuners.adalora import RankAllocator, AdaLoraLayer, SVDQuantLinear, SVDLinear
+from transformers.pytorch_utils import Conv1D
+
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.tuners.lora import LoraConfig, LoraModel
+from peft.tuners.tuners_utils import BaseTunerLayer
+from peft.utils import (
+ TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
+ _freeze_adapter,
+ _get_submodules,
+ get_auto_gptq_quant_linear,
+ get_quantization_config,
+)
+from peft.utils.integrations import gather_params_ctx
+
+from gptqmodel.integration.src.peft.utils import get_gptqmodel_quant_linear
+from ...import_utils import is_gptqmodel_available
+
+
+class AdaLoraModel(LoraModel):
+ """
+ Creates AdaLoRA (Adaptive LoRA) model from a pretrained transformers model. Paper:
+ https://openreview.net/forum?id=lq62uWRJjiY
+
+ Args:
+ model ([`transformers.PreTrainedModel`]): The model to be adapted.
+ config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
+ adapter_name (`str`): The name of the adapter, defaults to `"default"`.
+ low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
+ Create empty adapter weights on meta device. Useful to speed up the loading process.
+
+ Returns:
+ `torch.nn.Module`: The AdaLora model.
+
+ Example::
+
+ >>> from transformers import AutoModelForSeq2SeqLM
+ >>> from peft import LoraConfig, AdaLoraModel, AdaLoraConfig
+ >>> config = AdaLoraConfig(
+ ...     peft_type="ADALORA", task_type="SEQ_2_SEQ_LM", init_r=12, lora_alpha=32, target_modules=["q", "v"],
+ ...     lora_dropout=0.01,
+ ... )
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ >>> model = AdaLoraModel(model, config, "default")
+
+ **Attributes**:
+ - **model** ([`transformers.PreTrainedModel`]) -- The model to be adapted.
+ - **peft_config** ([`AdaLoraConfig`]): The configuration of the AdaLora model.
+ """
+
+ # Note: don't redefine prefix here, it should be inherited from LoraModel
+
+ def __init__(self, model, config, adapter_name):
+ super().__init__(model, config, adapter_name)
+
+ traininable_mode_counter = 0
+ for config in self.peft_config.values():
+ if not config.inference_mode:
+ traininable_mode_counter += 1
+
+ if traininable_mode_counter > 1:
+ raise ValueError(
+ "AdaLoraModel supports only 1 trainable adapter. "
+ "When using multiple adapters, set inference_mode to True for all adapters except the one you want to train."
+ )
+
+ if self.peft_config[adapter_name].inference_mode:
+ _freeze_adapter(self.model, adapter_name)
+ else:
+ self.trainable_adapter_name = adapter_name
+ self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)
+
+ def _check_new_adapter_config(self, config: LoraConfig) -> None:
+ """
+ A helper method to check the config when a new adapter is being added.
+
+ Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
+
+ """
+ super()._check_new_adapter_config(config)
+
+ traininable_mode_counter = 0
+ for config_ in self.peft_config.values():
+ if not config_.inference_mode:
+ traininable_mode_counter += 1
+
+ if traininable_mode_counter > 1:
+ raise ValueError(
+ f"{self.__class__.__name__} supports only 1 trainable adapter. "
+ "When using multiple adapters, set inference_mode to True for all adapters except the one "
+ "you want to train."
+ )
+
+ def _create_and_replace(
+ self,
+ lora_config,
+ adapter_name,
+ target,
+ target_name,
+ parent,
+ current_key,
+ ):
+ kwargs = {
+ "r": lora_config.init_r,
+ "lora_alpha": lora_config.lora_alpha,
+ "lora_dropout": lora_config.lora_dropout,
+ "fan_in_fan_out": lora_config.fan_in_fan_out,
+ "init_lora_weights": lora_config.init_lora_weights,
+ "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
+ "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
+ }
+ if (kwargs["loaded_in_8bit"] or kwargs["loaded_in_4bit"]) and not is_bnb_available():
+ raise ImportError(
+ "To use AdaLora with 8-bit quantization, please install the `bitsandbytes` package. "
+ "You can install it with `pip install bitsandbytes`."
+ )
+
+ quantization_config = get_quantization_config(self.model, method="gptq")
+ if quantization_config is not None:
+ kwargs["gptq_quantization_config"] = quantization_config
+
+ # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
+ if not isinstance(target, AdaLoraLayer):
+ new_module = self._create_new_module(lora_config, adapter_name, target, getattr(self.model, "hf_device_map", None), **kwargs)
+ if adapter_name not in self.active_adapters:
+ # adding an additional adapter: it is not automatically trainable
+ new_module.requires_grad_(False)
+ self._replace_module(parent, target_name, new_module, target)
+ else:
+ target.update_layer(
+ adapter_name,
+ lora_config.init_r,
+ lora_config.lora_alpha,
+ lora_config.lora_dropout,
+ lora_config.init_lora_weights,
+ )
+
+ @staticmethod
+ def _create_new_module(lora_config, adapter_name, target, device_map, **kwargs):
+ # avoid eager bnb import
+ if is_bnb_available():
+ import bitsandbytes as bnb
+
+ from peft.tuners.adalora.bnb import SVDLinear8bitLt
+ if is_bnb_4bit_available():
+ from peft.tuners.adalora.bnb import SVDLinear4bit
+
+ gptq_quantization_config = kwargs.get("gptq_quantization_config", None)
+
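+ # Prefer gptqmodel's quant linear when it is available; otherwise fall back to auto-gptq's.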
+ if is_gptqmodel_available():
+ QuantLinear = get_gptqmodel_quant_linear(gptq_quantization_config, device_map)
+ else:
+ QuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config)
+
+ loaded_in_8bit = kwargs.pop("loaded_in_8bit", False)
+ loaded_in_4bit = kwargs.pop("loaded_in_4bit", False)
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ if loaded_in_8bit and isinstance(target_base_layer, bnb.nn.Linear8bitLt):
+ kwargs.update(
+ {
+ "has_fp16_weights": target_base_layer.state.has_fp16_weights,
+ "memory_efficient_backward": target_base_layer.state.memory_efficient_backward,
+ "threshold": target_base_layer.state.threshold,
+ "index": target_base_layer.index,
+ }
+ )
+ new_module = SVDLinear8bitLt(target, adapter_name, **kwargs)
+ elif loaded_in_4bit and is_bnb_4bit_available() and isinstance(target_base_layer, bnb.nn.Linear4bit):
+ fourbit_kwargs = kwargs.copy()
+ fourbit_kwargs.update(
+ {
+ "compute_dtype": target_base_layer.compute_dtype,
+ "compress_statistics": target_base_layer.weight.compress_statistics,
+ "quant_type": target_base_layer.weight.quant_type,
+ }
+ )
+ new_module = SVDLinear4bit(target, adapter_name, **fourbit_kwargs)
+ elif QuantLinear is not None and isinstance(target, QuantLinear):
+ new_module = SVDQuantLinear(target, adapter_name, **kwargs)
+ else:
+ if isinstance(target_base_layer, torch.nn.Linear):
+ if kwargs["fan_in_fan_out"]:
+ warnings.warn(
+ "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
+ "Setting fan_in_fan_out to False."
+ )
+ kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
+ elif isinstance(target_base_layer, Conv1D):
+ if not kwargs["fan_in_fan_out"]:
+ warnings.warn(
+ "fan_in_fan_out is set to False but the target module is `Conv1D`. "
+ "Setting fan_in_fan_out to True."
+ )
+ kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
+ else:
+ raise ValueError(
+ f"Target module {target} is not supported. "
+ f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
+ )
+ new_module = SVDLinear(target, adapter_name, **kwargs)
+
+ return new_module
+
+ @staticmethod
+ def _prepare_adapter_config(peft_config, model_config):
+ if peft_config.target_modules is None:
+ if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING:
+ raise ValueError("Please specify `target_modules` in `peft_config`")
+ peft_config.target_modules = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[
+ model_config["model_type"]
+ ]
+ return peft_config
+
+ def __getattr__(self, name: str):
+ """Forward missing attributes to the wrapped module."""
+ try:
+ return super().__getattr__(name) # defer to nn.Module's logic
+ except AttributeError:
+ if name == "model": # see #1892: prevent infinite recursion if class is not initialized
+ raise
+ return getattr(self.model, name)
+
+ def forward(self, *args, **kwargs):
+ outputs = self.model.forward(*args, **kwargs)
+
+ if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor):
+ # Calculate the orthogonal regularization
+ orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
+
+ if orth_reg_weight <= 0:
+ raise ValueError("orth_reg_weight should be greater than 0. ")
+
+ regu_loss = 0
+ num_param = 0
+ for n, p in self.model.named_parameters():
+ if ("lora_A" in n or "lora_B" in n) and self.trainable_adapter_name in n:
+ if p.shape == torch.Size([0]):
+ with gather_params_ctx(p, fwd_module=self):
+ para_cov = p @ p.T if "lora_A" in n else p.T @ p
+ else:
+ para_cov = p @ p.T if "lora_A" in n else p.T @ p
+ I = torch.eye(*para_cov.size(), out=torch.empty_like(para_cov)) # noqa: E741
+ I.requires_grad = False
+ num_param += 1
+ regu_loss += torch.norm(para_cov - I, p="fro")
+ if num_param > 0:
+ regu_loss = regu_loss / num_param
+ else:
+ regu_loss = 0
+ outputs.loss += orth_reg_weight * regu_loss
+ return outputs
+
+ def resize_modules_by_rank_pattern(self, rank_pattern, adapter_name):
+ lora_config = self.peft_config[adapter_name]
+ for name, rank_idx in rank_pattern.items():
+ if isinstance(rank_idx, list):
+ rank = sum(rank_idx)
+ elif isinstance(rank_idx, torch.Tensor):
+ rank_idx = rank_idx.view(-1)
+ rank = rank_idx.sum().item()
+ else:
+ raise ValueError("Unexpected type of rank_idx")
+ key = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
+ _, target, _ = _get_submodules(self.model, key)
+ lora_E_weights = target.lora_E[adapter_name][rank_idx]
+ lora_A_weights = target.lora_A[adapter_name][rank_idx]
+ lora_B_weights = target.lora_B[adapter_name][:, rank_idx]
+ ranknum = target.ranknum[adapter_name]
+ target.update_layer(
+ adapter_name,
+ rank,
+ lora_config.lora_alpha,
+ lora_config.lora_dropout,
+ lora_config.init_lora_weights,
+ )
+ with torch.no_grad():
+ if rank > 0:
+ target.lora_E[adapter_name].copy_(lora_E_weights)
+ target.lora_A[adapter_name].copy_(lora_A_weights)
+ target.lora_B[adapter_name].copy_(lora_B_weights)
+ # The scaling is exactly as the previous
+ target.ranknum[adapter_name].copy_(ranknum)
+
+ def resize_state_dict_by_rank_pattern(self, rank_pattern, state_dict, adapter_name):
+ for name, rank_idx in rank_pattern.items():
+ rank = sum(rank_idx)
+ prefix = ".".join(name.split(".")[0:-2]) if adapter_name in name else ".".join(name.split(".")[0:-1])
+ for layer in ["lora_E", "lora_A", "lora_B"]:
+ key = f"base_model.model.{prefix}.{layer}.{adapter_name}"
+ if layer != "lora_B":
+ state_dict[key] = (
+ state_dict[key][rank_idx] if rank != state_dict[key].shape[0] else state_dict[key]
+ )
+ else:
+ state_dict[key] = (
+ state_dict[key][:, rank_idx] if rank != state_dict[key].shape[1] else state_dict[key]
+ )
+ return state_dict
+
+ def update_and_allocate(self, global_step):
+ """
+ This method updates Adalora budget and mask.
+
+ This should be called in every training step after `loss.backward()` and before `zero_grad()`.
+
+ `tinit`, `tfinal` and `deltaT` are handled with in the method.
+
+ Args:
+ global_step (`int`): The current training step, it is used to calculate adalora budget.
+
+ Example:
+
+ ```python
+ >>> loss = model(**input).loss
+ >>> loss.backward()
+ >>> optimizer.step()
+ >>> model.base_model.update_and_allocate(i_step)
+ >>> optimizer.zero_grad()
+ ```
+ """
+ lora_config = self.peft_config[self.trainable_adapter_name]
+ # Update the importance score and allocate the budget
+ if global_step < lora_config.total_step - lora_config.tfinal:
+ _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step)
+ if rank_pattern:
+ lora_config.rank_pattern = rank_pattern
+ # Finalize the budget allocation
+ elif global_step == lora_config.total_step - lora_config.tfinal:
+ _, rank_pattern = self.rankallocator.update_and_allocate(self.model, global_step, force_mask=True)
+ # for some reason, this freezes the trainable parameters and nothing gets updated
+ # self.resize_modules_by_rank_pattern(rank_pattern, self.trainable_adapter_name)
+ lora_config.rank_pattern = rank_pattern
+ self.rankallocator.reset_ipt()
+ # Currently using an inefficient way to mask the unimportant weights using the rank pattern
+ # due to problem mentioned above
+ elif global_step > lora_config.total_step - lora_config.tfinal:
+ self.rankallocator.mask_using_rank_pattern(self.model, lora_config.rank_pattern)
+ # Pass the function and do forward propagation
+ else:
+ return None
+
+ def add_weighted_adapter(self, *args, **kwargs):
+ """This method is not supported for AdaLoRA, use LoRA instead."""
+ raise TypeError(f"{self.__class__.__name__} does not support add_weighted_adapter method.")
\ No newline at end of file
diff --git a/gptqmodel/integration/src/peft/tuners/lora/__init__.py b/gptqmodel/integration/src/peft/tuners/lora/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/gptqmodel/integration/src/peft/tuners/lora/gptq.py b/gptqmodel/integration/src/peft/tuners/lora/gptq.py
new file mode 100644
index 00000000..7ef12497
--- /dev/null
+++ b/gptqmodel/integration/src/peft/tuners/lora/gptq.py
@@ -0,0 +1,124 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import torch
+
+from peft.tuners.lora.layer import LoraLayer
+from peft.tuners.tuners_utils import BaseTunerLayer
+from peft.utils import get_auto_gptq_quant_linear
+
+from gptqmodel.integration.src.peft.import_utils import is_gptqmodel_available
+from gptqmodel.integration.src.peft.utils import get_gptqmodel_quant_linear
+
+
+class QuantLinear(torch.nn.Module, LoraLayer):
+ def __init__(
+ self,
+ base_layer,
+ adapter_name: str,
+ r: int = 0,
+ lora_alpha: int = 1,
+ lora_dropout: float = 0.0,
+ init_lora_weights: bool = True,
+ use_rslora: bool = False,
+ use_dora: bool = False,
+ lora_bias: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ LoraLayer.__init__(self, base_layer)
+
+ if use_dora:
+ raise ValueError(f"{self.__class__.__name__} does not support DoRA yet, please set it to False")
+
+ # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter
+ # for backwards compatibility
+ self.quant_linear_module = base_layer
+ self._active_adapter = adapter_name
+ self.update_layer(
+ adapter_name,
+ r,
+ lora_alpha=lora_alpha,
+ lora_dropout=lora_dropout,
+ init_lora_weights=init_lora_weights,
+ use_rslora=use_rslora,
+ use_dora=use_dora,
+ lora_bias=lora_bias,
+ )
+
+ def forward(self, x: torch.Tensor):
+ # note: logic differs from default Linear because merging is not supported
+ result = self.quant_linear_module(x)
+
+ if self.disable_adapters:
+ return result
+
+ for active_adapter in self.active_adapters:
+ if active_adapter not in self.lora_A.keys():
+ continue
+ lora_A = self.lora_A[active_adapter]
+ lora_B = self.lora_B[active_adapter]
+ dropout = self.lora_dropout[active_adapter]
+ scaling = self.scaling[active_adapter]
+
+ requires_conversion = not torch.is_autocast_enabled()
+ if requires_conversion:
+ expected_dtype = result.dtype
+ x = x.to(lora_A.weight.dtype)
+
+ output = lora_B(lora_A(dropout(x)))
+ if requires_conversion:
+ output = output.to(expected_dtype)
+ output = output * scaling
+ result += output
+ return result
+
+ def __repr__(self) -> str:
+ rep = super().__repr__()
+ return "lora." + rep
+
+ # TODO: Check if it is better as suggested by users https://github.com/PanQiWei/AutoGPTQ/pull/102
+ # def reset_lora_parameters(self, adapter_name):
+ # if adapter_name in self.lora_A.keys():
+ # torch.nn.init.xavier_uniform_(self.lora_A[adapter_name].weight)
+ # torch.nn.init.zeros_(self.lora_B[adapter_name].weight)
+
+
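+# dispatch_gptq is called from LoraModel._create_new_module's dispatcher chain (see
+# peft.tuners.lora.model); returning None signals that the target is not a GPTQ quantized linear
+# so the next dispatcher can be tried.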
+def dispatch_gptq(
+ target: torch.nn.Module,
+ adapter_name: str,
+ **kwargs: Any,
+) -> Optional[torch.nn.Module]:
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ cfg = kwargs.get("gptq_quantization_config", None)
+
+ if is_gptqmodel_available():
+ device_map = kwargs.get("device_map", None)
+ quant_linear = get_gptqmodel_quant_linear(cfg, device_map=device_map)
+ else:
+ quant_linear = get_auto_gptq_quant_linear(cfg)
+
+ if quant_linear is not None and isinstance(target_base_layer, quant_linear):
+ new_module = QuantLinear(target, adapter_name, **kwargs)
+ target.qweight = target_base_layer.qweight
+
+ return new_module
\ No newline at end of file
diff --git a/gptqmodel/integration/src/peft/tuners/lora/model.py b/gptqmodel/integration/src/peft/tuners/lora/model.py
new file mode 100644
index 00000000..de059838
--- /dev/null
+++ b/gptqmodel/integration/src/peft/tuners/lora/model.py
@@ -0,0 +1,939 @@
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import math
+import operator
+import warnings
+from contextlib import contextmanager
+from dataclasses import asdict, replace
+from enum import Enum
+from functools import partial, reduce
+from typing import Literal, Optional
+
+import torch
+from torch import nn
+from tqdm import tqdm
+
+from peft.import_utils import is_bnb_4bit_available, is_bnb_available
+from peft.tuners.tuners_utils import (
+ BaseTuner,
+ BaseTunerLayer,
+ check_target_module_exists,
+ onload_layer,
+ replicate_layers,
+)
+from peft.utils import (
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+ ModulesToSaveWrapper,
+ _freeze_adapter,
+ _get_submodules,
+ get_peft_model_state_dict,
+ get_quantization_config,
+)
+from peft.utils.merge_utils import dare_linear, dare_ties, magnitude_prune, task_arithmetic, ties
+from peft.utils.other import get_pattern_key
+
+from peft.tuners.lora.aqlm import dispatch_aqlm
+from peft.tuners.lora.awq import dispatch_awq
+from peft.tuners.lora.config import LoraConfig
+from peft.tuners.lora.eetq import dispatch_eetq
+from peft.tuners.lora.gptq import dispatch_gptq
+from peft.tuners.lora.hqq import dispatch_hqq
+from peft.tuners.lora.layer import Conv2d, LoraLayer, dispatch_default
+from peft.tuners.lora.torchao import dispatch_torchao
+from peft.tuners.lora.tp_layer import dispatch_megatron
+
+
+def _adapter_names_pre_forward_hook(target, args, kwargs, adapter_names):
+ # pre-forward hook to inject the adapter_names argument when using mixed adapter batches inference
+ kwargs["adapter_names"] = adapter_names
+ return args, kwargs
+
+
+class LoraModel(BaseTuner):
+ """
+ Creates Low Rank Adapter (LoRA) model from a pretrained transformers model.
+
+ The method is described in detail in https://arxiv.org/abs/2106.09685.
+
+ Args:
+ model ([`torch.nn.Module`]): The model to be adapted.
+ config ([`LoraConfig`]): The configuration of the Lora model.
+ adapter_name (`str`): The name of the adapter, defaults to `"default"`.
+ low_cpu_mem_usage (`bool`, `optional`, defaults to `False`):
+ Create empty adapter weights on meta device. Useful to speed up the loading process.
+
+ Returns:
+ `torch.nn.Module`: The Lora model.
+
+ Example:
+
+ ```py
+ >>> from transformers import AutoModelForSeq2SeqLM
+ >>> from peft import LoraModel, LoraConfig
+
+ >>> config = LoraConfig(
+ ... task_type="SEQ_2_SEQ_LM",
+ ... r=8,
+ ... lora_alpha=32,
+ ... target_modules=["q", "v"],
+ ... lora_dropout=0.01,
+ ... )
+
+ >>> model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
+ >>> lora_model = LoraModel(model, config, "default")
+ ```
+
+ ```py
+ >>> import torch
+ >>> import transformers
+ >>> from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
+
+ >>> rank = ...
+ >>> target_modules = ["q_proj", "k_proj", "v_proj", "out_proj", "fc_in", "fc_out", "wte"]
+ >>> config = LoraConfig(
+ ... r=4, lora_alpha=16, target_modules=target_modules, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
+ ... )
+ >>> quantization_config = transformers.BitsAndBytesConfig(load_in_8bit=True)
+
+ >>> tokenizer = transformers.AutoTokenizer.from_pretrained(
+ ... "kakaobrain/kogpt",
+ ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
+ ... bos_token="[BOS]",
+ ... eos_token="[EOS]",
+ ... unk_token="[UNK]",
+ ... pad_token="[PAD]",
+ ... mask_token="[MASK]",
+ ... )
+ >>> model = transformers.GPTJForCausalLM.from_pretrained(
+ ... "kakaobrain/kogpt",
+ ... revision="KoGPT6B-ryan1.5b-float16", # or float32 version: revision=KoGPT6B-ryan1.5b
+ ... pad_token_id=tokenizer.eos_token_id,
+ ... use_cache=False,
+ ... device_map={"": rank},
+ ... torch_dtype=torch.float16,
+ ... quantization_config=quantization_config,
+ ... )
+ >>> model = prepare_model_for_kbit_training(model)
+ >>> lora_model = get_peft_model(model, config)
+ ```
+
+ **Attributes**:
+ - **model** ([`~transformers.PreTrainedModel`]) -- The model to be adapted.
+ - **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
+ """
+
+ prefix: str = "lora_"
+
+ def __init__(self, model, config, adapter_name, low_cpu_mem_usage: bool = False) -> None:
+ super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
+
+ def _check_new_adapter_config(self, config: LoraConfig) -> None:
+ """
+ A helper method to check the config when a new adapter is being added.
+
+ Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
+
+ """
+ # TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
+ # does not fully correspond to the error message.
+ if (len(self.peft_config) > 1) and (config.bias != "none"):
+ raise ValueError(
+ f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
+ "set bias to 'none' for all adapters."
+ )
+
+ @staticmethod
+ def _check_target_module_exists(lora_config, key):
+ return check_target_module_exists(lora_config, key)
+
+ def _prepare_model(self, peft_config: LoraConfig, model: nn.Module):
+ r"""
+ A private method to modify the model structure before adapter is applied.
+
+ Args:
+ peft_config (`PeftConfig`):
+ The prepared adapter config.
+ model (`nn.Module`):
+ The model that is going to be adapted.
+ """
+ if peft_config.layer_replication:
+ replicate_layers(model, peft_config.layer_replication)
+
+ def _create_and_replace(
+ self,
+ lora_config,
+ adapter_name,
+ target,
+ target_name,
+ parent,
+ current_key,
+ ):
+ if current_key is None:
+ raise ValueError("Current Key shouldn't be `None`")
+
+ # Regexp matching - Find key which matches current target_name in patterns provided
+ r_key = get_pattern_key(lora_config.rank_pattern.keys(), current_key)
+ alpha_key = get_pattern_key(lora_config.alpha_pattern.keys(), current_key)
+ r = lora_config.rank_pattern.get(r_key, lora_config.r)
+ alpha = lora_config.alpha_pattern.get(alpha_key, lora_config.lora_alpha)
+
+ kwargs = {
+ "r": r,
+ "lora_alpha": alpha,
+ "lora_dropout": lora_config.lora_dropout,
+ "fan_in_fan_out": lora_config.fan_in_fan_out,
+ "init_lora_weights": lora_config.init_lora_weights,
+ "use_rslora": lora_config.use_rslora,
+ "use_dora": lora_config.use_dora,
+ "ephemeral_gpu_offload": lora_config.runtime_config.ephemeral_gpu_offload,
+ "lora_bias": lora_config.lora_bias,
+ "loaded_in_8bit": getattr(self.model, "is_loaded_in_8bit", False),
+ "loaded_in_4bit": getattr(self.model, "is_loaded_in_4bit", False),
+ }
+ # for torchao merging, we need the get_apply_tensor_subclass from the quantization config
+ try:
+ kwargs["get_apply_tensor_subclass"] = operator.attrgetter(
+ "hf_quantizer.quantization_config.get_apply_tensor_subclass"
+ )(self.model)
+ except AttributeError:
+ pass
+
+ quant_methods = ["gptq", "aqlm", "awq"]
+ for quant_method in quant_methods:
+ quantization_config = get_quantization_config(self.model, method=quant_method)
+ if quantization_config is not None:
+ kwargs[f"{quant_method}_quantization_config"] = quantization_config
+
+ # note: AdaLoraLayer is a subclass of LoraLayer, we need to exclude it
+ from peft.tuners.adalora import AdaLoraLayer
+
+ if isinstance(target, LoraLayer) and not isinstance(target, AdaLoraLayer):
+ target.update_layer(
+ adapter_name,
+ r,
+ lora_alpha=alpha,
+ lora_dropout=lora_config.lora_dropout,
+ init_lora_weights=lora_config.init_lora_weights,
+ use_rslora=lora_config.use_rslora,
+ use_dora=lora_config.use_dora,
+ lora_bias=lora_config.lora_bias,
+ )
+ else:
+ new_module = self._create_new_module(lora_config, adapter_name, target, device_map=getattr(self.model, "hf_device_map", None), **kwargs)
+ if adapter_name not in self.active_adapters:
+ # adding an additional adapter: it is not automatically trainable
+ new_module.requires_grad_(False)
+ self._replace_module(parent, target_name, new_module, target)
+
+ def _replace_module(self, parent, child_name, new_module, child):
+ setattr(parent, child_name, new_module)
+ # It's not necessary to set requires_grad here, as that is handled by
+ # _mark_only_adapters_as_trainable
+
+ # child layer wraps the original module, unpack it
+ if hasattr(child, "base_layer"):
+ child = child.base_layer
+
+ if not hasattr(new_module, "base_layer"):
+ if hasattr(new_module, "W_q"): # HQQ
+ new_module.W_q = child.W_q
+ else:
+ new_module.weight = child.weight
+ if hasattr(child, "bias"):
+ new_module.bias = child.bias
+
+ if getattr(child, "state", None) is not None:
+ if hasattr(new_module, "base_layer"):
+ new_module.base_layer.state = child.state
+ else:
+ new_module.state = child.state
+ new_module.to(child.weight.device)
+
+ meta = torch.device("meta")
+ # dispatch to correct device
+ for name, module in new_module.named_modules():
+ if (self.prefix in name) or ("ranknum" in name):
+ weight = (
+ child.qweight
+ if hasattr(child, "qweight")
+ else child.W_q
+ if hasattr(child, "W_q")
+ else child.weight
+ if hasattr(child, "weight")
+ else next(child.parameters())
+ )
+ if not any(p.device == meta for p in module.parameters()):
+ module.to(weight.device)
+
+ def _mark_only_adapters_as_trainable(self, model: nn.Module) -> None:
+ for n, p in model.named_parameters():
+ if self.prefix not in n:
+ p.requires_grad = False
+
+ for active_adapter in self.active_adapters:
+ bias = self.peft_config[active_adapter].bias
+ if bias == "none":
+ continue
+
+ if bias == "all":
+ for n, p in model.named_parameters():
+ if "bias" in n:
+ p.requires_grad = True
+ elif bias == "lora_only":
+ for m in model.modules():
+ if isinstance(m, LoraLayer) and hasattr(m, "bias") and m.bias is not None:
+ m.bias.requires_grad = True
+ else:
+ raise NotImplementedError(f"Requested bias: {bias}, is not implemented.")
+
+ @staticmethod
+ def _create_new_module(lora_config, adapter_name, target, **kwargs):
+ # Collect dispatcher functions to decide what backend to use for the replaced LoRA layer. The order matters,
+ # because the first match is always used. Therefore, the default layers should be checked last.
+ dispatchers = []
+
+ if lora_config._custom_modules:
+ # Experimental custom LoRA module support. Allows users to pass a custom mapping for unsupported layer
+ # types by implementing their own LoRA layers.
+ def dynamic_dispatch_func(target, adapter_name, lora_config, **kwargs):
+ new_module = None
+
+ if isinstance(target, BaseTunerLayer):
+ target_base_layer = target.get_base_layer()
+ else:
+ target_base_layer = target
+
+ for key, custom_cls in lora_config._custom_modules.items():
+ if isinstance(target_base_layer, key):
+ new_module = custom_cls(target, adapter_name, **kwargs)
+ break
+
+ return new_module
+
+ dispatchers.append(dynamic_dispatch_func)
+
+ # avoid eager bnb import
+ if is_bnb_available():
+ from peft.tuners.lora.bnb import dispatch_bnb_8bit
+
+ dispatchers.append(dispatch_bnb_8bit)
+
+ if is_bnb_4bit_available():
+ from peft.tuners.lora.bnb import dispatch_bnb_4bit
+
+ dispatchers.append(dispatch_bnb_4bit)
+
+ dispatchers.extend(
+ [
+ dispatch_eetq,
+ dispatch_aqlm,
+ dispatch_awq,
+ dispatch_gptq,
+ dispatch_hqq,
+ dispatch_torchao,
+ dispatch_megatron,
+ dispatch_default,
+ ]
+ )
+
+ new_module = None
+ for dispatcher in dispatchers:
+ new_module = dispatcher(target, adapter_name, lora_config=lora_config, **kwargs)
+ if new_module is not None: # first match wins
+ break
+
+ if new_module is None:
+ # no module could be matched
+ raise ValueError(
+ f"Target module {target} is not supported. Currently, only the following modules are supported: "
+ "`torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `torch.nn.Conv3d`, "
+ "`transformers.pytorch_utils.Conv1D`."
+ )
+
+ return new_module
+
+ def __getattr__(self, name: str):
+ """Forward missing attributes to the wrapped module."""
+ try:
+ return super().__getattr__(name) # defer to nn.Module's logic
+ except AttributeError:
+ if name == "model": # see #1892: prevent infinite recursion if class is not initialized
+ raise
+ return getattr(self.model, name)
+
+ def get_peft_config_as_dict(self, inference: bool = False):
+ config_dict = {}
+ for key, value in self.peft_config.items():
+ config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
+ if inference:
+ config["inference_mode"] = True
+ config_dict[key] = config
+ return config
+
+ def _set_adapter_layers(self, enabled: bool = True) -> None:
+ for module in self.model.modules():
+ if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)):
+ module.enable_adapters(enabled)
+
+ def enable_adapter_layers(self) -> None:
+ """Enable all adapters.
+
+ Call this if you have previously disabled all adapters and want to re-enable them.
+ """
+ self._set_adapter_layers(enabled=True)
+
+ def disable_adapter_layers(self) -> None:
+ """Disable all adapters.
+
+ When disabling all adapters, the model output corresponds to the output of the base model.
+ """
+ for active_adapter in self.active_adapters:
+ val = self.peft_config[active_adapter].bias
+ if val != "none":
+ msg = (
+ f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
+ "output as the the base model would without adaption."
+ )
+ warnings.warn(msg)
+ self._set_adapter_layers(enabled=False)
+
+ def set_adapter(self, adapter_name: str | list[str]) -> None:
+ """Set the active adapter(s).
+
+ Additionally, this function will set the specified adapters to trainable (i.e., requires_grad=True). If this is
+ not desired, use the following code.
+
+ ```py
+ >>> for name, param in model_peft.named_parameters():
+ ... if ...: # some check on name (ex. if 'lora' in name)
+ ... param.requires_grad = False
+ ```
+
+ Args:
+ adapter_name (`str` or `list[str]`): Name of the adapter(s) to be activated.
+ """
+ for module in self.model.modules():
+ if isinstance(module, LoraLayer):
+ if module.merged:
+ warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
+ module.unmerge()
+ module.set_adapter(adapter_name)
+ self.active_adapter = adapter_name
+
+ @contextmanager
+ def _enable_peft_forward_hooks(self, *args, **kwargs):
+ # If adapter_names is passed as an argument, we inject it into the forward arguments.
+ adapter_names = kwargs.pop("adapter_names", None)
+ if adapter_names is None:
+ # nothing to do
+ yield
+ return
+
+ if self.training:
+ raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
+
+ # Check that users only passed actually existing adapters.
+ # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
+ # to check that there is at least one layer with the given name, otherwise typos could easily slip through.
+ expected_adapters = set()
+ for layer in self.modules():
+ if isinstance(layer, LoraLayer):
+ expected_adapters |= layer.lora_A.keys()
+ expected_adapters |= layer.lora_embedding_A.keys()
+ unique_adapters = {name for name in adapter_names if name != "__base__"}
+ unexpected_adapters = unique_adapters - expected_adapters
+ if unexpected_adapters:
+ raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
+
+ hook_handles = []
+ for module in self.modules():
+ if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
+ pre_forward = partial(_adapter_names_pre_forward_hook, adapter_names=adapter_names)
+ handle = module.register_forward_pre_hook(pre_forward, with_kwargs=True)
+ hook_handles.append(handle)
+
+ yield
+
+ for handle in hook_handles:
+ handle.remove()
+
+ def _check_merge_allowed(self):
+ """Verify that the configuration supports merging.
+
+ Currently gptq quantization and replicated layers do not support merging.
+ """
+ super()._check_merge_allowed()
+ if getattr(self.model, "quantization_method", None) == "gptq":
+ raise ValueError("Cannot merge LORA layers when the model is gptq quantized")
+ if self.peft_config.get("layer_replication"):
+ raise ValueError("Cannot merge LORA layers when base model layers are replicated")
+
+ @staticmethod
+ def _prepare_adapter_config(peft_config, model_config):
+ if peft_config.target_modules is None:
+ if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
+ raise ValueError("Please specify `target_modules` in `peft_config`")
+ peft_config.target_modules = set(
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
+ )
+ return peft_config
+
+ def _unload_and_optionally_merge(
+ self,
+ merge=True,
+ progressbar: bool = False,
+ safe_merge: bool = False,
+ adapter_names: Optional[list[str]] = None,
+ ):
+ if merge:
+ self._check_merge_allowed()
+
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+ desc = "Unloading " + ("and merging " if merge else "") + "model"
+ for key in tqdm(key_list, disable=not progressbar, desc=desc):
+ try:
+ parent, target, target_name = _get_submodules(self.model, key)
+ except AttributeError:
+ continue
+ with onload_layer(target):
+ if hasattr(target, "base_layer"):
+ if merge:
+ target.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ self._replace_module(parent, target_name, target.get_base_layer(), target)
+ elif isinstance(target, ModulesToSaveWrapper):
+ # save any additional trainable modules part of `modules_to_save`
+ new_module = target.modules_to_save[target.active_adapter]
+ if hasattr(new_module, "base_layer"):
+ # check if the module is itself a tuner layer
+ if merge:
+ new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
+ new_module = new_module.get_base_layer()
+ setattr(parent, target_name, new_module)
+
+ return self.model
+
+ def _check_add_weighted_adapter(
+ self, adapters: list[str], combination_type: str, svd_rank: int | None
+ ) -> tuple[str, int, str]:
+ """
+ Helper function to check if the arguments to add_weighted_adapter are valid and compatible with the underlying
+ model.
+ """
+ for adapter in adapters:
+ if adapter not in list(self.peft_config.keys()):
+ raise ValueError(f"Adapter {adapter} does not exist")
+
+ # If more than one of the adapters targets the same module with modules_to_save, raise an error, as these
+ # modules cannot be merged. First, find the ModulesToSaveWrapper instances in the model, then check if they
+ # have modules for the adapters to be merged.
+ modules_to_save_wrappers = [module for module in self.modules() if isinstance(module, ModulesToSaveWrapper)]
+ problematic_wrappers = [
+ wrapper
+ for wrapper in modules_to_save_wrappers
+ if sum(adapter in wrapper.modules_to_save for adapter in adapters) > 1
+ ]
+ if problematic_wrappers:
+ raise ValueError(
+ "Cannot add weighted adapters if they target the same module with modules_to_save, but found "
+ f"{len(problematic_wrappers)} such instance(s)."
+ )
+
+ # if there is only one adapter, we can only use linear merging
+ combination_type = "linear" if len(adapters) == 1 else combination_type
+
+ adapters_ranks = [self.peft_config[adapter].r for adapter in adapters]
+ if combination_type in ("linear", "ties", "dare_ties", "dare_linear", "magnitude_prune"):
+ # all adapters ranks should be same, new rank is just this value
+ if len(set(adapters_ranks)) != 1:
+ raise ValueError(
+ "All adapters must have the same r value when using combination_type linear, ties, dare_ties or "
+ "dare_linear."
+ )
+ new_rank = adapters_ranks[0]
+ elif combination_type == "cat":
+ # adapters ranks may be different, new rank is sum of all ranks
+ # be careful, because output adapter rank may be really big if mixing a lot of adapters
+ new_rank = sum(adapters_ranks)
+ elif combination_type.endswith("svd"):
+ # new rank is the max of all ranks of the adapters if not provided
+ new_rank = svd_rank or max(adapters_ranks)
+ else:
+ raise ValueError(f"Invalid combination_type: {combination_type}")
+
+ target_module_types = [type(self.peft_config[adapter].target_modules) for adapter in adapters]
+ if not target_module_types:
+ raise ValueError(f"Found no adapter matching the names in {adapters}")
+ if len(set(target_module_types)) > 1:
+ raise ValueError(
+ "all adapter configs should follow the same target modules type. "
+ "Combining adapters with `target_modules` type being a mix of list/set and string is not supported."
+ )
+
+ if target_module_types[0] is str:
+ new_target_modules = "|".join(f"({self.peft_config[adapter].target_modules})" for adapter in adapters)
+ elif target_module_types[0] is set:
+ new_target_modules = reduce(
+ operator.or_, (self.peft_config[adapter].target_modules for adapter in adapters)
+ )
+ else:
+ raise TypeError(f"Invalid type {target_module_types[0]} found in target_modules")
+
+ return combination_type, new_rank, new_target_modules
+
+ def add_weighted_adapter(
+ self,
+ adapters: list[str],
+ weights: list[float],
+ adapter_name: str,
+ combination_type: str = "svd",
+ svd_rank: int | None = None,
+ svd_clamp: int | None = None,
+ svd_full_matrices: bool = True,
+ svd_driver: str | None = None,
+ density: float | None = None,
+ majority_sign_method: Literal["total", "frequency"] = "total",
+ ) -> None:
+ """
+ This method adds a new adapter by merging the given adapters with the given weights.
+
+ When using the `cat` combination_type you should be aware that rank of the resulting adapter will be equal to
+ the sum of all adapters ranks. So it's possible that the mixed adapter may become too big and result in OOM
+ errors.
+
+ Args:
+ adapters (`list`):
+ List of adapter names to be merged.
+ weights (`list`):
+ List of weights for each adapter.
+ adapter_name (`str`):
+ Name of the new adapter.
+ combination_type (`str`):
+ The merging type can be one of [`svd`, `linear`, `cat`, `ties`, `ties_svd`, `dare_ties`, `dare_linear`,
+ `dare_ties_svd`, `dare_linear_svd`, `magnitude_prune`, `magnitude_prune_svd`]. When using the `cat`
+ combination_type, the rank of the resulting adapter is equal to the sum of all adapters ranks (the
+ mixed adapter may be too big and result in OOM errors).
+ svd_rank (`int`, *optional*):
+ Rank of the output adapter for svd. If None is provided, the maximum rank of the merged adapters is used.
+ svd_clamp (`float`, *optional*):
+ A quantile threshold for clamping SVD decomposition output. If None is provided, do not perform
+ clamping. Defaults to None.
+ svd_full_matrices (`bool`, *optional*):
+ Controls whether to compute the full or reduced SVD, and consequently, the shape of the returned
+ tensors U and Vh. Defaults to True.
+ svd_driver (`str`, *optional*):
+ Name of the cuSOLVER method to be used. This keyword argument only works when merging on CUDA. Can be
+ one of [None, `gesvd`, `gesvdj`, `gesvda`]. For more info please refer to `torch.linalg.svd`
+ documentation. Defaults to None.
+ density (`float`, *optional*):
+ Value between 0 and 1. 0 means all values are pruned and 1 means no values are pruned. Should be used
+ with [`ties`, `ties_svd`, `dare_ties`, `dare_linear`, `dare_ties_svd`, `dare_linear_svd`,
+ `magnitude_prune`, `magnitude_prune_svd`]
+ majority_sign_method (`str`):
+ The method, should be one of ["total", "frequency"], to use to get the magnitude of the sign values.
+ Should be used with [`ties`, `ties_svd`, `dare_ties`, `dare_ties_svd`]
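+
+ Example (an illustrative sketch; the adapter names are placeholders, and both adapters are
+ assumed to already be loaded and to share the same `r`, as required by `ties`):
+
+ ```py
+ >>> model.add_weighted_adapter(
+ ...     adapters=["adapter_a", "adapter_b"],
+ ...     weights=[0.5, 0.5],
+ ...     adapter_name="merged",
+ ...     combination_type="ties",
+ ...     density=0.2,
+ ... )
+ >>> model.set_adapter("merged")
+ ```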
+ """
+
+ if adapter_name in list(self.peft_config.keys()):
+ return
+
+ combination_type, new_rank, new_target_modules = self._check_add_weighted_adapter(
+ adapters=adapters,
+ combination_type=combination_type,
+ svd_rank=svd_rank,
+ )
+
+ self.peft_config[adapter_name] = replace(
+ self.peft_config[adapters[0]],
+ r=new_rank,
+ lora_alpha=new_rank,
+ target_modules=new_target_modules,
+ )
+ self.inject_adapter(self.model, adapter_name)
+
+ # Do we really need that?
+ _freeze_adapter(self.model, adapter_name)
+
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+ for key in key_list:
+ _, target, _ = _get_submodules(self.model, key)
+ if isinstance(target, LoraLayer):
+ if adapter_name in target.lora_A:
+ target_lora_A = target.lora_A[adapter_name].weight
+ target_lora_B = target.lora_B[adapter_name].weight
+ elif adapter_name in target.lora_embedding_A:
+ target_lora_A = target.lora_embedding_A[adapter_name]
+ target_lora_B = target.lora_embedding_B[adapter_name]
+ else:
+ continue
+
+ target_lora_A.data = target_lora_A.data * 0.0
+ target_lora_B.data = target_lora_B.data * 0.0
+ if combination_type == "cat":
+ loras_A, loras_B = [], []
+ for adapter, weight in zip(adapters, weights):
+ if adapter in target.lora_A:
+ current_adapter_lora_A = target.lora_A[adapter].weight
+ current_adapter_lora_B = target.lora_B[adapter].weight
+ elif adapter in target.lora_embedding_A:
+ current_adapter_lora_A = target.lora_embedding_A[adapter]
+ current_adapter_lora_B = target.lora_embedding_B[adapter]
+ else:
+ continue
+ loras_A.append(current_adapter_lora_A.data * weight * target.scaling[adapter])
+ loras_B.append(current_adapter_lora_B.data)
+
+ if len(loras_A) == 0:
+ raise ValueError("No matching LoRAs found. Please raise an issue on GitHub.")
+ loras_A = torch.cat(loras_A, dim=0)
+ loras_B = torch.cat(loras_B, dim=1)
+ target_lora_A.data[: loras_A.shape[0], :] = loras_A
+ target_lora_B.data[:, : loras_B.shape[1]] = loras_B
+ elif combination_type in [
+ "svd",
+ "ties_svd",
+ "dare_linear_svd",
+ "dare_ties_svd",
+ "magnitude_prune_svd",
+ ]:
+ target_lora_A.data, target_lora_B.data = self._svd_generalized_task_arithmetic_weighted_adapter(
+ combination_type,
+ adapters,
+ weights,
+ new_rank,
+ target,
+ target_lora_A,
+ target_lora_B,
+ density,
+ majority_sign_method,
+ svd_clamp,
+ full_matrices=svd_full_matrices,
+ driver=svd_driver,
+ )
+ elif combination_type in ["linear", "ties", "dare_linear", "dare_ties", "magnitude_prune"]:
+ target_lora_A.data, target_lora_B.data = self._generalized_task_arithmetic_weighted_adapter(
+ combination_type, adapters, weights, target, density, majority_sign_method
+ )
+
+ def _svd_generalized_task_arithmetic_weighted_adapter(
+ self,
+ combination_type,
+ adapters,
+ weights,
+ new_rank,
+ target,
+ target_lora_A,
+ target_lora_B,
+ density,
+ majority_sign_method,
+ clamp=None,
+ full_matrices=True,
+ driver=None,
+ ):
+ valid_adapters = []
+ valid_weights = []
+ is_embedding = any(adapter in target.lora_embedding_A for adapter in adapters)
+ for adapter, weight in zip(adapters, weights):
+ if adapter in target.lora_A or adapter in target.lora_embedding_A:
+ valid_adapters.append(adapter)
+ valid_weights.append(weight * target.scaling[adapter])
+
+ # if no valid adapter, nothing to do
+ if len(valid_adapters) == 0:
+ raise ValueError("No matching LoRAs found. Please raise an issue on Github.")
+ delta_weight = [target.get_delta_weight(adapter) for adapter in valid_adapters]
+ valid_weights = torch.tensor(valid_weights).to(delta_weight[0].device)
+ if combination_type == "svd":
+ delta_weight = task_arithmetic(delta_weight, valid_weights)
+ elif combination_type == "ties_svd":
+ delta_weight = ties(delta_weight, valid_weights, density, majority_sign_method)
+ elif combination_type == "dare_linear_svd":
+ delta_weight = dare_linear(delta_weight, valid_weights, density)
+ elif combination_type == "dare_ties_svd":
+ delta_weight = dare_ties(delta_weight, valid_weights, density, majority_sign_method)
+ elif combination_type == "magnitude_prune_svd":
+ delta_weight = magnitude_prune(delta_weight, valid_weights, density)
+ else:
+ raise ValueError(f"Invalid value passed to combination type: {combination_type}")
+
+ conv2d = isinstance(target, Conv2d)
+ if conv2d:
+ conv2d_1x1 = target.weight.size()[2:4] == (1, 1)
+ if not conv2d_1x1:
+ delta_weight = delta_weight.flatten(start_dim=1)
+ else:
+ delta_weight = delta_weight.squeeze()
+ if (hasattr(target, "fan_in_fan_out") and target.fan_in_fan_out) or is_embedding:
+ delta_weight = delta_weight.T
+
+ # based on https://github.com/kohya-ss/sd-scripts/blob/main/networks/svd_merge_lora.py#L114-L131
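+ # Reduce the merged delta weight to rank `new_rank` via truncated SVD: keep the top `new_rank` singular
+ # triplets, fold the singular values into U (returned as the new lora_B) and return Vh as the new lora_A.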
+ U, S, Vh = torch.linalg.svd(delta_weight, full_matrices=full_matrices, driver=driver)
+ U = U[:, :new_rank]
+ S = S[:new_rank]
+ U = U @ torch.diag(S)
+ Vh = Vh[:new_rank, :]
+ if clamp is not None:
+ dist = torch.cat([U.flatten(), Vh.flatten()])
+ hi_val = torch.quantile(dist, clamp)
+ low_val = -hi_val
+ U = U.clamp(low_val, hi_val)
+ Vh = Vh.clamp(low_val, hi_val)
+ if conv2d:
+ U = U.reshape(target_lora_B.data.shape)
+ Vh = Vh.reshape(target_lora_A.data.shape)
+ return Vh, U
+
+ def _generalized_task_arithmetic_weighted_adapter(
+ self,
+ combination_type,
+ adapters,
+ weights,
+ target,
+ density,
+ majority_sign_method,
+ ):
+ # Split each adapter's weight across the LoRA A and B matrices: sqrt(weight * scaling) is applied to both,
+ # so their product scales by weight * scaling.
+ valid_weights = []
+ lora_A_deltas = []
+ lora_B_deltas = []
+ for adapter, weight in zip(adapters, weights):
+ if adapter in target.lora_A:
+ current_adapter_lora_A = target.lora_A[adapter].weight
+ current_adapter_lora_B = target.lora_B[adapter].weight
+ elif adapter in target.lora_embedding_A:
+ current_adapter_lora_A = target.lora_embedding_A[adapter]
+ current_adapter_lora_B = target.lora_embedding_B[adapter]
+ else:
+ continue
+ valid_weights.append(math.sqrt(weight * target.scaling[adapter]))
+ lora_A_deltas.append(current_adapter_lora_A.data)
+ lora_B_deltas.append(current_adapter_lora_B.data)
+ valid_weights = torch.tensor(valid_weights).to(lora_A_deltas[0].device)
+ lora_deltas = [lora_A_deltas, lora_B_deltas]
+ dtype = lora_A_deltas[0].dtype
+ for i, task_tensors in enumerate(lora_deltas):
+ if combination_type == "linear":
+ lora_deltas[i] = task_arithmetic(task_tensors, valid_weights)
+ elif combination_type == "ties":
+ lora_deltas[i] = ties(task_tensors, valid_weights, density, majority_sign_method)
+ elif combination_type == "dare_linear":
+ lora_deltas[i] = dare_linear(task_tensors, valid_weights, density)
+ elif combination_type == "dare_ties":
+ lora_deltas[i] = dare_ties(task_tensors, valid_weights, density, majority_sign_method)
+ elif combination_type == "magnitude_prune":
+ lora_deltas[i] = magnitude_prune(task_tensors, valid_weights, density)
+ else:
+ raise ValueError("Invalid combination type")
+ lora_deltas = [delta.to(dtype) for delta in lora_deltas]
+ return lora_deltas
+
+ def delete_adapter(self, adapter_name: str) -> None:
+ """
+ Deletes an existing adapter.
+
+ Args:
+ adapter_name (str): Name of the adapter to be deleted.
+ """
+ if adapter_name not in list(self.peft_config.keys()):
+ raise ValueError(f"Adapter {adapter_name} does not exist")
+ del self.peft_config[adapter_name]
+
+ key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+ new_adapter = None
+ for key in key_list:
+ _, target, _ = _get_submodules(self.model, key)
+ if isinstance(target, LoraLayer):
+ target.delete_adapter(adapter_name)
+ if new_adapter is None:
+ new_adapter = target.active_adapters[:]
+
+ self.active_adapter = new_adapter or []
+
+ def merge_and_unload(
+ self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None
+ ) -> torch.nn.Module:
+ r"""
+ This method merges the LoRA layers into the base model. This is needed if someone wants to use the base model
+ as a standalone model.
+
+ Args:
+ progressbar (`bool`):
+ Whether to show a progress bar indicating the unload and merge process.
+ safe_merge (`bool`):
+ Whether to activate the safe merging check, which verifies that the adapter weights do not contain any
+ NaNs.
+ adapter_names (`List[str]`, *optional*):
+ The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults
+ to `None`.
+ Example:
+
+ ```py
+ >>> from transformers import AutoModelForCausalLM
+ >>> from peft import PeftModel
+
+ >>> base_model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-40b")
+ >>> peft_model_id = "smangrul/falcon-40B-int4-peft-lora-sfttrainer-sample"
+ >>> model = PeftModel.from_pretrained(base_model, peft_model_id)
+ >>> merged_model = model.merge_and_unload()
+ ```
+ """
+ return self._unload_and_optionally_merge(
+ progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names
+ )
+
+ def unload(self) -> torch.nn.Module:
+ """
+ Gets back the base model by removing all the LoRA modules without merging. This gives back the original base
+ model.
+ """
+ return self._unload_and_optionally_merge(merge=False)
+
+ def subtract_mutated_init(self, output_state_dict: dict[str, torch.Tensor], adapter_name: str, kwargs=None):
+ """
+ This function calculates the updates of the [PiSSA | OLoRA] adapter by comparing its parameters in
+ `output_state_dict` with the initial values of [PiSSA | OLoRA] stored under `adapter_name`, thus
+ converting [PiSSA | OLoRA] to LoRA.
+ """
+ for name, param in self.model.named_parameters():
+ if (
+ param.data.dtype != torch.float32
+ and param.data.dtype != torch.float16
+ and param.data.dtype != torch.bfloat16
+ ) and adapter_name.startswith("pissa"):
+ warnings.warn(
+ r"Note that Quant(W_res) + AB != Quant(W) + \Delta(AB); "
+ "the converted LoRA, when combined with W or Quant(W), may introduce a certain gap in the fine-tuned model. "
+ "Therefore, we recommend directly using the Quant(W_res) in conjunction with the PiSSA adapter. "
+ )
+ mutated_init_state_dict = get_peft_model_state_dict(
+ self,
+ state_dict=kwargs.get("state_dict", None),
+ adapter_name=adapter_name,
+ )
+ tensors_lora = {}
+ for name in output_state_dict.keys():
+ ## W = W^{res} + A_0 \times B_0,
+ ## W + \Delta W = W^{res} + A \times B,
+ ## \Delta W = A \times B - A_0 \times B_0 = [A | A_0] \times [B | -B_0]^T = A'B'.
+ if "lora_A" in name:
+ tensors_lora[name] = torch.cat(
+ [output_state_dict[name], mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=0
+ )
+ elif "lora_B" in name:
+ tensors_lora[name] = torch.cat(
+ [output_state_dict[name], -mutated_init_state_dict[".".join(name.split(".")[1:])]], dim=1
+ )
+
+ return tensors_lora
\ No newline at end of file
diff --git a/gptqmodel/integration/src/peft/utils/__init__.py b/gptqmodel/integration/src/peft/utils/__init__.py
new file mode 100644
index 00000000..dc96e4d2
--- /dev/null
+++ b/gptqmodel/integration/src/peft/utils/__init__.py
@@ -0,0 +1,53 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all
+
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
+from .other import (
+ TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
+ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING,
+ TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING,
+ CONFIG_NAME,
+ WEIGHTS_NAME,
+ SAFETENSORS_WEIGHTS_NAME,
+ INCLUDE_LINEAR_LAYERS_SHORTHAND,
+ _set_trainable,
+ bloom_model_postprocess_past_key_value,
+ prepare_model_for_kbit_training,
+ shift_tokens_right,
+ transpose,
+ _get_batch_size,
+ _get_submodules,
+ _set_adapter,
+ _freeze_adapter,
+ ModulesToSaveWrapper,
+ _prepare_prompt_learning_config,
+ _is_valid_match,
+ infer_device,
+ get_auto_gptq_quant_linear,
+ get_gptqmodel_quant_linear,
+ get_quantization_config,
+ id_tensor_storage,
+ cast_mixed_precision_params,
+)
diff --git a/gptqmodel/integration/src/peft/utils/other.py b/gptqmodel/integration/src/peft/utils/other.py
new file mode 100644
index 00000000..201e2210
--- /dev/null
+++ b/gptqmodel/integration/src/peft/utils/other.py
@@ -0,0 +1,741 @@
+
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import copy
+import inspect
+import os
+import re
+import warnings
+from contextlib import nullcontext
+from typing import Any, Optional
+
+import accelerate
+import torch
+from accelerate.hooks import add_hook_to_module, remove_hook_from_module
+from accelerate.utils import is_npu_available, is_xpu_available
+from huggingface_hub import file_exists
+from huggingface_hub.errors import EntryNotFoundError, HFValidationError
+from packaging import version
+from peft.utils.constants import *
+from safetensors.torch import storage_ptr, storage_size
+
+from gptqmodel.integration.src.peft.import_utils import is_auto_gptq_available, is_torch_tpu_available, is_gptqmodel_available
+
+mlu_available = False
+if version.parse(accelerate.__version__) >= version.parse("0.29.0"):
+ from accelerate.utils import is_mlu_available
+
+ mlu_available = is_mlu_available()
+
+
+__all__ = [
+ "CONFIG_NAME",
+ "EMBEDDING_LAYER_NAMES",
+ "SAFETENSORS_WEIGHTS_NAME",
+ "TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING",
+ "TRANSFORMERS_MODELS_TO_LNTUNING_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_VERA_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_VBLORA_TARGET_MODULES_MAPPING",
+ "TRANSFORMERS_MODELS_TO_FOURIERFT_TARGET_MODULES_MAPPING",
+ "WEIGHTS_NAME",
+ "INCLUDE_LINEAR_LAYERS_SHORTHAND",
+ "bloom_model_postprocess_past_key_value",
+ "starcoder_model_postprocess_past_key_value",
+]
+
+
+# Get current device name based on available devices
+def infer_device() -> str:
+ if torch.cuda.is_available():
+ return "cuda"
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ return "mps"
+ elif mlu_available:
+ return "mlu"
+ elif is_xpu_available():
+ return "xpu"
+ elif is_npu_available():
+ return "npu"
+ return "cpu"
+
+
+def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
+ r"""
+ Note this method only works for `transformers` models.
+
+ This method wraps the entire protocol for preparing a model before running training. This includes:
+ 1- casting the layernorm in fp32, 2- making the output embedding layer require grads, 3- upcasting the
+ lm head to fp32.
+
+ Args:
+ model (`transformers.PreTrainedModel`):
+ The loaded model from `transformers`
+ use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+ If True, use gradient checkpointing to save memory at the expense of slower backward pass.
+ gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
+ Keyword arguments to pass to the gradient checkpointing function; please refer to the documentation of
+ `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
+ Note this is only available in the latest transformers versions (> 4.34.1).
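+
+ Example:
+
+ A minimal sketch; the checkpoint name is illustrative and assumes a k-bit quantized model loadable through
+ `transformers`.
+
+ ```py
+ >>> from transformers import AutoModelForCausalLM
+
+ >>> model = AutoModelForCausalLM.from_pretrained("some-org/some-4bit-model", device_map="auto")
+ >>> model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
+ ```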
+ """
+ loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
+ is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
+ is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
+ is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"
+ is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao"
+ is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False)
+
+ if gradient_checkpointing_kwargs is None:
+ gradient_checkpointing_kwargs = {}
+
+ for name, param in model.named_parameters():
+ # freeze base model's layers
+ param.requires_grad = False
+
+ if (
+ not is_gptq_quantized
+ and not is_aqlm_quantized
+ and not is_eetq_quantized
+ and not is_hqq_quantized
+ and not is_torchao_quantized
+ ):
+ # cast all non INT8 parameters to fp32
+ for param in model.parameters():
+ if (
+ (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
+ ) and param.__class__.__name__ != "Params4bit":
+ param.data = param.data.to(torch.float32)
+
+ if (
+ loaded_in_kbit
+ or is_gptq_quantized
+ or is_aqlm_quantized
+ or is_eetq_quantized
+ or is_hqq_quantized
+ or is_torchao_quantized
+ ) and use_gradient_checkpointing:
+ # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
+ if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
+ # For backward compatibility
+ if hasattr(model, "enable_input_require_grads"):
+ model.enable_input_require_grads()
+ else:
+
+ def make_inputs_require_grad(module, input, output):
+ output.requires_grad_(True)
+
+ model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+
+ # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
+ _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
+ inspect.signature(model.gradient_checkpointing_enable).parameters
+ )
+
+ if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
+ warnings.warn(
+ "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
+ " if you want to use that feature, please upgrade to the latest version of transformers.",
+ FutureWarning,
+ )
+
+ gc_enable_kwargs = (
+ {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
+ )
+
+ # enable gradient checkpointing for memory efficiency
+ model.gradient_checkpointing_enable(**gc_enable_kwargs)
+ return model
+
+
+# copied from transformers.models.bart.modeling_bart
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
+ """
+ Shift input ids one token to the right.
+
+ Args:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
+ pad_token_id (`int`): The id of the `padding` token.
+ decoder_start_token_id (`int`): The id of the `start` token.
+ """
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
+ shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
+ shifted_input_ids[:, 0] = decoder_start_token_id
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
+
+ return shifted_input_ids
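+
+# e.g. (illustrative) with decoder_start_token_id=2 and pad_token_id=0, input_ids [[5, -100, 6]] becomes
+# [[2, 5, 0]]: every row is shifted right by one and any remaining -100 label is replaced by the pad token.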
+
+
+class ModulesToSaveWrapper(torch.nn.Module):
+ def __init__(self, module_to_save, adapter_name):
+ super().__init__()
+ self.original_module = module_to_save
+ self.modules_to_save = torch.nn.ModuleDict({})
+ self._active_adapter = adapter_name
+ self._disable_adapters = False
+ self.update(adapter_name)
+ self.check_module()
+
+ def check_module(self):
+ """Perform some sanity checks on the module to ensure that it works"""
+ # Try to anticipate some modules that users could try to target that would not work.
+ # Note: It's not possible to check hasattr(module, "forward"), since that returns True for ModuleDict and
+ # ModuleList, even though their forward methods cannot be called
+ forbidden_classes = (torch.nn.ModuleDict, torch.nn.ModuleList, torch.nn.ParameterDict, torch.nn.ParameterList)
+ if isinstance(self.original_module, forbidden_classes):
+ cls_name = self.original_module.__class__
+ raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")
+
+ # local import to avoid circular import
+ from peft.tuners.tuners_utils import BaseTunerLayer
+
+ if isinstance(self.original_module, BaseTunerLayer):
+ # e.g. applying modules_to_save to a lora layer makes no sense
+ cls_name = self.original_module.__class__
+ raise TypeError(f"modules_to_save cannot be applied to modules of type {cls_name}")
+
+ @property
+ def disable_adapters(self) -> bool:
+ # use a property to ensure that disable_adapters is not set directly, instead use the enable_adapters method
+ return self._disable_adapters
+
+ @property
+ def active_adapter(self) -> str:
+ # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method
+ return self._active_adapter
+
+ def __getattr__(self, name: str):
+ # Note: This whole method may seem overly complex at first but PyTorch messes with __getattr__ in a way that
+ # requires very careful handling to avoid infinite recursion.
+ try:
+ return super().__getattr__(name)
+ except AttributeError:
+ pass
+
+ if "_modules" not in self.__dict__:
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+ # Could not find the attribute the PyTorch way. So let's check if it's an attribute on the
+ # original_module/modules_to_save.
+ modules = self.__dict__["_modules"]
+ if self.disable_adapters:
+ module = modules["original_module"]
+ elif self.active_adapter in modules["modules_to_save"]:
+ module = modules["modules_to_save"][self.active_adapter]
+ else:
+ # For some reason, there is no module corresponding to the active adapter; this should normally not be
+ # reached and exists as a failsafe (otherwise, a KeyError would be raised)
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+ return getattr(module, name)
+
+ def update(self, adapter_name):
+ context_manager = nullcontext()
+ for _, param in self.original_module.named_parameters():
+ num_params = param.numel()
+ # if using DS Zero 3 and the weights are initialized empty
+ if num_params == 0 and hasattr(param, "ds_numel"):
+ import deepspeed
+
+ context_manager = deepspeed.zero.GatheredParameters(self.original_module.parameters(), modifier_rank=0)
+ break
+ with context_manager:
+ self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
+
+ if hasattr(self.modules_to_save[adapter_name], "_hf_hook"):
+ old_hook = self.modules_to_save[adapter_name]._hf_hook
+ new_hook = self._create_new_hook(old_hook)
+ remove_hook_from_module(self.modules_to_save[adapter_name])
+ add_hook_to_module(self.modules_to_save[adapter_name], new_hook)
+
+ self.original_module.requires_grad_(False)
+ if adapter_name == self.active_adapter:
+ self.modules_to_save[adapter_name].requires_grad_(True)
+
+ def _create_new_hook(self, old_hook):
+ r"""
+ Creates a new hook based on the old hook. Use it only if you know what you are doing!
+ """
+ old_hook_cls = getattr(accelerate.hooks, old_hook.__class__.__name__)
+ old_hook_attr = old_hook.__dict__
+ filtered_old_hook_attr = {}
+ old_hook_init_signature = inspect.signature(old_hook_cls.__init__)
+ for k in old_hook_attr.keys():
+ if k in old_hook_init_signature.parameters:
+ filtered_old_hook_attr[k] = old_hook_attr[k]
+ new_hook = old_hook_cls(**filtered_old_hook_attr)
+ return new_hook
+
+ def _check_forward_args(self, x, *args, **kwargs):
+ """Check if the arguments are compatible with the configs and state of the model"""
+ adapter_names = kwargs.get("adapter_names", None)
+ if adapter_names is None:
+ return
+
+ if len(x) != len(adapter_names):
+ msg = (
+ "Length of `adapter_names` should be the same as the number of inputs, but got "
+ f"{len(adapter_names)} and {len(x)} respectively."
+ )
+ raise ValueError(msg)
+
+ def _mixed_batch_forward(
+ self, input: torch.Tensor, *args: Any, adapter_names: list[str], **kwargs: Any
+ ) -> torch.Tensor:
+ # This is a special method that handles the case when users pass the argument `adapter_names`. This is an
+ # extra argument that allows mixing different adapters in the same batch at inference time.
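+ # For example (illustrative), adapter_names=["adapter_a", "__base__", "adapter_a"] routes rows 0 and 2
+ # through the module saved for "adapter_a" and row 1 through the unmodified original module.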
+
+ SUPPORTED_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)
+
+ module_names = ", ".join([module.__name__ for module in SUPPORTED_MODULES])
+
+ if not isinstance(self.original_module, SUPPORTED_MODULES):
+ raise TypeError(f"Mixed batching is only supported for the following modules: {module_names}.")
+
+ unique_adapters = set(adapter_names)
+ sub_batch_indices_list = []
+
+ for adapter in unique_adapters:
+ sub_batch_indices_list.append([index for index, item in enumerate(adapter_names) if item == adapter])
+
+ results = [0 for _ in range(len(input))]
+
+ for i, active_adapter in enumerate(unique_adapters):
+ sub_batch = input[sub_batch_indices_list[i]]
+
+ if active_adapter == "__base__":
+ output = self.original_module(sub_batch, *args, **kwargs)
+ else:
+ output = self.modules_to_save[active_adapter](sub_batch, *args, **kwargs)
+
+ for index, j in enumerate(sub_batch_indices_list[i]):
+ results[j] = output[index]
+
+ return torch.stack(results)
+
+ def forward(self, x: torch.Tensor, *args, **kwargs):
+ self._check_forward_args(x, *args, **kwargs)
+ adapter_names = kwargs.pop("adapter_names", None)
+
+ if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
+ return self.original_module(x, *args, **kwargs)
+ if adapter_names is None:
+ return self.modules_to_save[self.active_adapter](x, *args, **kwargs)
+ return self._mixed_batch_forward(x, *args, adapter_names=adapter_names, **kwargs)
+
+ def enable_adapters(self, enabled: bool):
+ """Toggle the enabling and disabling of adapters
+
+ Takes care of setting the requires_grad flag for the adapter weights.
+
+ Args:
+ enabled (bool): True to enable adapters, False to disable adapters
+ """
+ if self._disable_adapters is not enabled:
+ # already in the desired state, do nothing
+ return
+
+ if enabled:
+ self.original_module.requires_grad_(False)
+ self.modules_to_save[self.active_adapter].requires_grad_(True)
+ self._disable_adapters = False
+ else:
+ self.original_module.requires_grad_(True)
+ self.modules_to_save.requires_grad_(False)
+ self._disable_adapters = True
+
+ def set_adapter(self, adapter_name: str):
+ """Set the active adapter
+
+ Additionally, this function will set the specified adapter to trainable (i.e., requires_grad=True). If this is
+ not desired, use the following code.
+
+ ```py
+ >>> for name, param in model_peft.named_parameters():
+ ... if ...: # some check on name (ex. if 'lora' in name)
+ ... param.requires_grad = False
+ ```
+
+ Args:
+ adapter_name (str): The name of the adapter to set as active
+ """
+ if adapter_name not in self.modules_to_save:
+ raise ValueError(f"Adapter {adapter_name} not found in {self.modules_to_save.keys()}")
+
+ self.modules_to_save[self.active_adapter].requires_grad_(False)
+ self.modules_to_save[adapter_name].requires_grad_(True)
+ self._active_adapter = adapter_name
+
+
+def _get_submodules(model, key):
+ parent = model.get_submodule(".".join(key.split(".")[:-1]))
+ target_name = key.split(".")[-1]
+ target = model.get_submodule(key)
+ return parent, target, target_name
+
+
+def _freeze_adapter(model, adapter_name):
+ for n, p in model.named_parameters():
+ if adapter_name in n:
+ p.requires_grad = False
+
+
+def _set_trainable(model, adapter_name):
+ key_list = [key for key, _ in model.named_modules()]
+ for key in key_list:
+ target_module_found = any(key.endswith(target_key) for target_key in model.modules_to_save)
+ if target_module_found:
+ parent, target, target_name = _get_submodules(model, key)
+ if isinstance(target, ModulesToSaveWrapper):
+ target.update(adapter_name)
+ target.set_adapter(target.active_adapter)
+ else:
+ new_module = ModulesToSaveWrapper(target, adapter_name)
+ new_module.set_adapter(adapter_name)
+ setattr(parent, target_name, new_module)
+
+
+def _set_adapter(model, adapter_name):
+ def check_adapter_name(adapter_name):
+ if isinstance(adapter_name, str):
+ return adapter_name
+
+ # adapter_name is a list of str
+ if len(adapter_name) > 1:
+ raise ValueError("Only one adapter can be set at a time for modules_to_save")
+ elif len(adapter_name) == 0:
+ raise ValueError("Please specify at least one adapter to set")
+ adapter_name = adapter_name[0]
+ return adapter_name
+
+ for module in model.modules():
+ if isinstance(module, ModulesToSaveWrapper):
+ # only check the adapter_name if we actually encounter a ModulesToSaveWrapper, otherwise we don't care
+ adapter_name = check_adapter_name(adapter_name)
+
+ # if the adapter is found in this module, set it as the active adapter, else disable the adapters of this
+ # module
+ if adapter_name in module.modules_to_save:
+ module.set_adapter(adapter_name)
+ else:
+ module.enable_adapters(False)
+
+
+def _prepare_prompt_learning_config(peft_config, model_config):
+ if peft_config.num_layers is None:
+ if "num_hidden_layers" in model_config:
+ num_layers = model_config["num_hidden_layers"]
+ elif "num_layers" in model_config:
+ num_layers = model_config["num_layers"]
+ elif "n_layer" in model_config:
+ num_layers = model_config["n_layer"]
+ else:
+ raise ValueError("Please specify `num_layers` in `peft_config`")
+ peft_config.num_layers = num_layers
+
+ if peft_config.token_dim is None:
+ if "hidden_size" in model_config:
+ token_dim = model_config["hidden_size"]
+ elif "n_embd" in model_config:
+ token_dim = model_config["n_embd"]
+ elif "d_model" in model_config:
+ token_dim = model_config["d_model"]
+ else:
+ raise ValueError("Please specify `token_dim` in `peft_config`")
+ peft_config.token_dim = token_dim
+
+ if peft_config.num_attention_heads is None:
+ if "num_attention_heads" in model_config:
+ num_attention_heads = model_config["num_attention_heads"]
+ elif "n_head" in model_config:
+ num_attention_heads = model_config["n_head"]
+ elif "num_heads" in model_config:
+ num_attention_heads = model_config["num_heads"]
+ elif "encoder_attention_heads" in model_config:
+ num_attention_heads = model_config["encoder_attention_heads"]
+ else:
+ raise ValueError("Please specify `num_attention_heads` in `peft_config`")
+ peft_config.num_attention_heads = num_attention_heads
+
+ # For grouped-query attention, see #1901.
+ if peft_config.peft_type == "PREFIX_TUNING" and "num_key_value_heads" in model_config:
+ num_key_value_heads = model_config["num_key_value_heads"]
+ peft_config.token_dim = peft_config.token_dim // peft_config.num_attention_heads * num_key_value_heads
+ peft_config.num_attention_heads = num_key_value_heads
+
+ if getattr(peft_config, "encoder_hidden_size", None) is None:
+ setattr(peft_config, "encoder_hidden_size", peft_config.token_dim)
+
+ return peft_config
+
+
+def fsdp_auto_wrap_policy(model):
+ import functools
+ import os
+
+ from accelerate import FullyShardedDataParallelPlugin
+
+ if hasattr(FullyShardedDataParallelPlugin, "get_module_class_from_name"):
+ get_module_class_from_name = FullyShardedDataParallelPlugin.get_module_class_from_name
+ else:
+ from accelerate.utils.dataclasses import get_module_class_from_name
+ from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
+
+ from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
+
+ default_transformer_cls_names_to_wrap = (
+ ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else ""
+ )
+ transformer_cls_names_to_wrap = os.environ.get(
+ "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap
+ ).split(",")
+ transformer_cls_to_wrap = {PrefixEncoder, PromptEncoder, PromptEmbedding}
+ for layer_class in transformer_cls_names_to_wrap:
+ if len(layer_class) == 0:
+ continue
+ transformer_cls = get_module_class_from_name(model, layer_class)
+ if transformer_cls is None:
+ raise Exception("Could not find the transformer layer class to wrap in the model.")
+ else:
+ transformer_cls_to_wrap.add(transformer_cls)
+
+ def lambda_policy_fn(module):
+ if (
+ len(list(module.named_children())) == 0
+ and getattr(module, "weight", None) is not None
+ and module.weight.requires_grad
+ ):
+ return True
+ return False
+
+ lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
+ transformer_wrap_policy = functools.partial(
+ transformer_auto_wrap_policy,
+ transformer_layer_cls=transformer_cls_to_wrap,
+ )
+
+ auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
+ return auto_wrap_policy
+
+
+def transpose(weight, fan_in_fan_out):
+ if not fan_in_fan_out:
+ return weight
+
+ if isinstance(weight, torch.nn.Parameter):
+ return torch.nn.Parameter(weight.T)
+ return weight.T
+
+
+def _is_valid_match(key: str, target_key: str):
+ """
+ Helper function to match the module names `target_key` and `key`. Makes sure that either `key` is exactly
+ `target_key` or that `target_key` is a proper sub-module of `key`.
+ """
+ if key.endswith(target_key):
+ if len(key) > len(target_key):
+ return key.endswith("." + target_key) # must be a sub module
+ return True
+ return False
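+
+# e.g. _is_valid_match("model.layers.0.q_proj", "q_proj") is True (proper sub-module match), while
+# _is_valid_match("model.layers.0.myq_proj", "q_proj") is False (the suffix overlaps but is not a sub-module).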
+
+
+def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int:
+ """Get the batch size based on either input_ids or input_embeds
+
+ Raises a ValueError if both are None.
+
+ """
+ if (input_ids is None) and (inputs_embeds is None):
+ raise ValueError("You have to provide either input_ids or inputs_embeds")
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ else:
+ batch_size = inputs_embeds.shape[0]
+ return batch_size
+
+
+def get_quantization_config(model: torch.nn.Module, method: str):
+ """
+ Get the quantization config of the related quantization method
+ """
+ if (
+ hasattr(model, "config")
+ and hasattr(model.config, "quantization_config")
+ and (getattr(model, "quantization_method", None) == method)
+ ):
+ return model.config.quantization_config
+ return None
+
+
+def get_auto_gptq_quant_linear(gptq_quantization_config):
+ """
+ Get the right AutoGPTQQuantLinear class based on the quantization config file
+ """
+ if gptq_quantization_config is None:
+ return None
+
+ if is_auto_gptq_available():
+ from auto_gptq.utils.import_utils import dynamically_import_QuantLinear
+ else:
+ return None
+
+ desc_act = gptq_quantization_config.desc_act
+ group_size = gptq_quantization_config.group_size
+ bits = gptq_quantization_config.bits
+ if hasattr(gptq_quantization_config, "use_exllama"):
+ use_exllama = gptq_quantization_config.use_exllama
+ else:
+ use_exllama = not gptq_quantization_config.disable_exllama
+ if hasattr(gptq_quantization_config, "exllama_config"):
+ exllama_version = gptq_quantization_config.exllama_config["version"]
+ else:
+ exllama_version = 1
+
+ QuantLinear = dynamically_import_QuantLinear(
+ use_triton=False,
+ desc_act=desc_act,
+ group_size=group_size,
+ bits=bits,
+ disable_exllama=not (use_exllama and exllama_version == 1),
+ disable_exllamav2=not (use_exllama and exllama_version == 2),
+ )
+
+ return QuantLinear
+
+
+def get_gptqmodel_quant_linear(gptq_quantization_config, device_map=None):
+ """
+ Get the right GPTQQuantLinear class based on the quantization config file
+ """
+ if gptq_quantization_config is None:
+ return None
+
+ if not is_gptqmodel_available():
+ return None
+
+ from gptqmodel.utils.importer import hf_select_quant_linear
+
+ desc_act = gptq_quantization_config.desc_act
+ group_size = gptq_quantization_config.group_size
+ bits = gptq_quantization_config.bits
+ checkpoint_format = gptq_quantization_config.checkpoint_format if hasattr(gptq_quantization_config, "checkpoint_format") else "gptq"
+ sym = gptq_quantization_config.sym
+ meta = gptq_quantization_config.meta if hasattr(gptq_quantization_config, "meta") else None
+
+ QuantLinear = hf_select_quant_linear(bits=bits, group_size=group_size,
+ desc_act=desc_act, sym=sym, device_map=device_map,
+ checkpoint_format=checkpoint_format, meta=meta, backend="auto_trainable")
+
+ return QuantLinear
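+
+# Minimal usage sketch (illustrative; assumes `model` is a transformers model whose config carries a GPTQ
+# quantization_config):
+#
+# gptq_config = get_quantization_config(model, method="gptq")
+# QuantLinear = get_gptqmodel_quant_linear(gptq_config, device_map="auto")
+# # QuantLinear is None when gptqmodel is not installed or the model is not GPTQ-quantized.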
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
+ """
+ Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+ example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+ guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+ non-overlapping lifetimes may have the same id.
+
+ This method is the exact same copy of
+ https://github.com/huggingface/transformers/blob/main/src/transformers/pytorch_utils.py#L282C1-L300C58 but we added
+ it here manually to avoid import issues with old versions of transformers.
+ """
+ if tensor.device.type == "xla" and is_torch_tpu_available():
+ # NOTE: xla tensors dont have storage
+ # use some other unique id to distinguish.
+ # this is a XLA tensor, it must be created using torch_xla's
+ # device. So the following import is safe:
+ import torch_xla
+
+ unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+ else:
+ unique_id = storage_ptr(tensor)
+
+ return tensor.device, unique_id, storage_size(tensor)
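+
+# e.g. a tensor and a view of it (`t` and `t[0]`) share the same storage and therefore the same identifier triple.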
+
+
+def cast_mixed_precision_params(model, dtype):
+ """
+ Cast all non-trainable parameters of the model to the given `dtype`. The `dtype` can be `torch.float16` or
+ `torch.bfloat16` as per the mixed-precision training you are performing. The trainable parameters are cast to full
+ precision. This is meant to reduce the GPU memory usage when using PEFT methods by using half-precision dtype for
+ non-trainable parameters. Having the trainable parameters in full-precision preserves training stability when using
+ automatic mixed-precision training.
+
+ Args:
+ model (`torch.nn.Module`):
+ The model to cast the non-trainable parameters of.
+ dtype (`torch.dtype`):
+ The dtype to cast the non-trainable parameters to. The `dtype` can be `torch.float16` or
+ `torch.bfloat16` as per the mixed-precision training you are performing.
+ """
+ for p in model.parameters():
+ if not p.requires_grad:
+ p.data = p.to(dtype)
+ else:
+ p.data = p.to(torch.float32)
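+
+# Typical use (illustrative), before starting automatic mixed-precision training on a PEFT model:
+# cast_mixed_precision_params(peft_model, dtype=torch.bfloat16)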
+
+
+def str_to_bool(value: str) -> int:
+ """
+ Converts a string representation of truth to `True` (1) or `False` (0).
+
+ True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`.
+ """
+ # same function as in accelerate.utils; it replaces the deprecated distutils.util.strtobool
+ value = value.lower()
+ if value in ("y", "yes", "t", "true", "on", "1"):
+ return 1
+ elif value in ("n", "no", "f", "false", "off", "0"):
+ return 0
+ else:
+ raise ValueError(f"invalid truth value {value}")
+
+
+def check_file_exists_on_hf_hub(repo_id: str, filename: str, **kwargs) -> Optional[bool]:
+ """Check if a file exists on HF Hub, if check was not successful returns None instead of erroring.
+
+ Respect offline mode if set.
+
+ """
+ exists: Optional[bool] = None
+ if str_to_bool(os.environ.get("HF_HUB_OFFLINE", "0")):
+ # user set offline mode, cannot check
+ return exists
+
+ try:
+ exists = file_exists(repo_id, filename, **kwargs)
+ except (HFValidationError, EntryNotFoundError):
+ # error, exists stays None
+ pass
+ except Exception as e:
+ warnings.warn(
+ f"Unable to fetch remote file due to the following error {e} - silently ignoring the lookup"
+ f" for the file {filename} in {repo_id}."
+ )
+
+ return exists
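+
+# e.g. (illustrative repo/file names) check_file_exists_on_hf_hub("some-org/some-repo", "adapter_config.json")
+# returns True or False, or None when HF_HUB_OFFLINE is set or the lookup fails.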
+
+
+def get_pattern_key(pattern_keys, key_to_match):
+ """Match a substring of key_to_match in pattern keys"""
+ return next(filter(lambda key: re.match(rf".*\.{key}$", key_to_match), pattern_keys), key_to_match)
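+
+# e.g. with pattern_keys = {"q_proj": 16, "v_proj": 8} (illustrative), the key "model.layers.0.q_proj"
+# resolves to "q_proj"; a key with no match is returned unchanged.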
diff --git a/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py
new file mode 100644
index 00000000..eb8fcbc7
--- /dev/null
+++ b/gptqmodel/integration/src/transformers/quantizers/quantizer_gptq.py
@@ -0,0 +1,111 @@
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import importlib
+from typing import TYPE_CHECKING, Optional
+
+from packaging import version
+
+from transformers.quantizers.base import HfQuantizer
+
+from gptqmodel.integration.src.transformers.utils.import_utils import is_gptqmodel_available
+
+if TYPE_CHECKING:
+ from transformers.modeling_utils import PreTrainedModel
+
+from transformers.utils import is_auto_gptq_available, is_optimum_available, is_torch_available, logging
+from transformers.utils.quantization_config import GPTQConfig, QuantizationConfigMixin
+
+
+if is_torch_available():
+ import torch
+
+logger = logging.get_logger(__name__)
+
+
+class GptqHfQuantizer(HfQuantizer):
+ """
+ Quantizer of the GPTQ method - for GPTQ, the quantizer supports calibration of the model through the
+ `auto_gptq` or `gptqmodel` package. Quantization is done under the hood for users if they load a non-prequantized model.
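+
+ A minimal sketch of the quantization path that reaches this quantizer (model and tokenizer names are
+ illustrative):
+
+ ```py
+ >>> from transformers import AutoModelForCausalLM, GPTQConfig
+
+ >>> gptq_config = GPTQConfig(bits=4, dataset="c4", tokenizer="some-org/some-model")
+ >>> model = AutoModelForCausalLM.from_pretrained("some-org/some-model", quantization_config=gptq_config)
+ ```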
+ """
+
+ requires_calibration = False
+ required_packages = ["optimum", "auto_gptq", "gptqmodel"]
+ optimum_quantizer = None
+
+ def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
+ super().__init__(quantization_config, **kwargs)
+ from optimum.gptq import GPTQQuantizer
+
+ self.optimum_quantizer = GPTQQuantizer.from_dict(self.quantization_config.to_dict_optimum())
+
+ def validate_environment(self, *args, **kwargs):
+ if not is_optimum_available():
+ raise ImportError("Loading a GPTQ quantized model requires optimum (`pip install optimum`)")
+ if is_auto_gptq_available() and is_gptqmodel_available():
+ logger.warning(
+ "Detected gptqmodel and auto-gptq, will use gptqmodel, auto-gptq will be deprecated in the future."
+ )
+
+ gptq_supports_cpu = (
+ is_auto_gptq_available()
+ and version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+ ) or is_gptqmodel_available()
+ if not gptq_supports_cpu and not torch.cuda.is_available():
+ raise RuntimeError("GPU is required to quantize or run quantize model.")
+ elif not (is_auto_gptq_available() or is_gptqmodel_available()):
+ raise ImportError(
+ "Loading a GPTQ quantized model requires gptqmodel (`pip install gptqmodel`) or auto-gptq (`pip install auto-gptq`) library. Please notice that auto-gptq will be deprecated in the future."
+ )
+ elif is_auto_gptq_available() and version.parse(importlib.metadata.version("auto_gptq")) < version.parse(
+ "0.4.2"
+ ):
+ raise ImportError(
+ "You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq` or use gptqmodel by `pip install gptqmodel`. Please notice that auto-gptq will be deprecated in the future."
+ )
+ elif is_gptqmodel_available() and (
+ version.parse(importlib.metadata.version("gptqmodel")) <= version.parse("1.3.1")
+ or version.parse(importlib.metadata.version("optimum")) < version.parse("1.23.3")
+ ):
+ raise ImportError("The gptqmodel version should be >= 1.3.2, optimum version should >= 1.24.0")
+
+ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+ if torch_dtype is None:
+ torch_dtype = torch.float16
+ elif torch_dtype != torch.float16:
+ logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
+ return torch_dtype
+
+ def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ if model.__class__.main_input_name != "input_ids":
+ raise RuntimeError("We can only quantize pure text model.")
+
+ if self.pre_quantized:
+ model = self.optimum_quantizer.convert_model(model, **kwargs)
+
+ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
+ if self.pre_quantized:
+ model = self.optimum_quantizer.post_init_model(model)
+ else:
+ if self.quantization_config.tokenizer is None:
+ self.quantization_config.tokenizer = model.name_or_path
+
+ self.optimum_quantizer.quantize_model(model, self.quantization_config.tokenizer)
+ model.config.quantization_config = GPTQConfig.from_dict(self.optimum_quantizer.to_dict())
+
+ @property
+ def is_trainable(self, model: Optional["PreTrainedModel"] = None):
+ return True
+
+ def is_serializable(self, safe_serialization=None):
+ return True
\ No newline at end of file
diff --git a/gptqmodel/integration/src/transformers/testing_utils.py b/gptqmodel/integration/src/transformers/testing_utils.py
new file mode 100644
index 00000000..46fff122
--- /dev/null
+++ b/gptqmodel/integration/src/transformers/testing_utils.py
@@ -0,0 +1,2756 @@
+# Copyright 2020 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import contextlib
+import doctest
+import functools
+import gc
+import importlib
+import inspect
+import logging
+import multiprocessing
+import os
+import re
+import shlex
+import shutil
+import subprocess
+import sys
+import tempfile
+import time
+import unittest
+from collections import defaultdict
+from collections.abc import Mapping
+from dataclasses import MISSING, fields
+from functools import wraps
+from io import StringIO
+from pathlib import Path
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union
+from unittest import mock
+from unittest.mock import patch
+
+import huggingface_hub.utils
+import urllib3
+from huggingface_hub import delete_repo
+
+from transformers import logging as transformers_logging
+
+from transformers.integrations import (
+ is_clearml_available,
+ is_optuna_available,
+ is_ray_available,
+ is_sigopt_available,
+ is_tensorboard_available,
+ is_wandb_available,
+)
+from transformers.integrations.deepspeed import is_deepspeed_available
+from transformers.utils import (
+ ACCELERATE_MIN_VERSION,
+ GGUF_MIN_VERSION,
+ is_accelerate_available,
+ is_apex_available,
+ is_aqlm_available,
+ is_auto_awq_available,
+ is_auto_gptq_available,
+ is_av_available,
+ is_bitsandbytes_available,
+ is_bitsandbytes_multi_backend_available,
+ is_bs4_available,
+ is_compressed_tensors_available,
+ is_cv2_available,
+ is_cython_available,
+ is_detectron2_available,
+ is_eetq_available,
+ is_essentia_available,
+ is_faiss_available,
+ is_fbgemm_gpu_available,
+ is_flash_attn_2_available,
+ is_flax_available,
+ is_fsdp_available,
+ is_ftfy_available,
+ is_g2p_en_available,
+ is_galore_torch_available,
+ is_gguf_available,
+ is_grokadamw_available,
+ is_ipex_available,
+ is_jieba_available,
+ is_jinja_available,
+ is_jumanpp_available,
+ is_keras_nlp_available,
+ is_levenshtein_available,
+ is_librosa_available,
+ is_liger_kernel_available,
+ is_lomo_available,
+ is_natten_available,
+ is_nltk_available,
+ is_onnx_available,
+ is_optimum_available,
+ is_optimum_quanto_available,
+ is_pandas_available,
+ is_peft_available,
+ is_phonemizer_available,
+ is_pretty_midi_available,
+ is_pyctcdecode_available,
+ is_pytesseract_available,
+ is_pytest_available,
+ is_pytorch_quantization_available,
+ is_rjieba_available,
+ is_sacremoses_available,
+ is_safetensors_available,
+ is_schedulefree_available,
+ is_scipy_available,
+ is_sentencepiece_available,
+ is_seqio_available,
+ is_soundfile_availble,
+ is_spacy_available,
+ is_sudachi_available,
+ is_sudachi_projection_available,
+ is_tensorflow_probability_available,
+ is_tensorflow_text_available,
+ is_tf2onnx_available,
+ is_tf_available,
+ is_tiktoken_available,
+ is_timm_available,
+ is_tokenizers_available,
+ is_torch_available,
+ is_torch_bf16_available_on_device,
+ is_torch_bf16_cpu_available,
+ is_torch_bf16_gpu_available,
+ is_torch_deterministic,
+ is_torch_fp16_available_on_device,
+ is_torch_neuroncore_available,
+ is_torch_npu_available,
+ is_torch_sdpa_available,
+ is_torch_tensorrt_fx_available,
+ is_torch_tf32_available,
+ is_torch_xla_available,
+ is_torch_xpu_available,
+ is_torchao_available,
+ is_torchaudio_available,
+ is_torchdynamo_available,
+ is_torchvision_available,
+ is_vision_available,
+ strtobool,
+)
+
+from .utils.import_utils import is_gptqmodel_available
+
+if is_accelerate_available():
+ from accelerate.state import AcceleratorState, PartialState
+ from accelerate.utils.imports import is_fp8_available
+
+
+if is_pytest_available():
+ from _pytest.doctest import (
+ Module,
+ _get_checker,
+ _get_continue_on_failure,
+ _get_runner,
+ _is_mocked,
+ _patch_unwrap_mock_aware,
+ get_optionflags,
+ )
+ from _pytest.outcomes import skip
+ from _pytest.pathlib import import_path
+ from pytest import DoctestItem
+else:
+ Module = object
+ DoctestItem = object
+
+
+SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown"
+DUMMY_DIFF_TOKENIZER_IDENTIFIER = "julien-c/dummy-diff-tokenizer"
+# Used to test Auto{Config, Model, Tokenizer} model_type detection.
+
+# Used to test the hub
+USER = "__DUMMY_TRANSFORMERS_USER__"
+ENDPOINT_STAGING = "https://hub-ci.huggingface.co"
+
+# Not critical, only usable on the sandboxed CI instance.
+TOKEN = "hf_94wBhPGp6KrrTH3KDchhKpRxZwd6dmHWLL"
+
+if is_torch_available():
+ import torch
+
+ IS_ROCM_SYSTEM = torch.version.hip is not None
+ IS_CUDA_SYSTEM = torch.version.cuda is not None
+else:
+ IS_ROCM_SYSTEM = False
+ IS_CUDA_SYSTEM = False
+
+
+def parse_flag_from_env(key, default=False):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ # KEY isn't set, default to `default`.
+ _value = default
+ else:
+ # KEY is set, convert it to True or False.
+ try:
+ _value = strtobool(value)
+ except ValueError:
+ # More values are supported, but let's keep the message simple.
+ raise ValueError(f"If set, {key} must be yes or no.")
+ return _value
+
+
+def parse_int_from_env(key, default=None):
+ try:
+ value = os.environ[key]
+ except KeyError:
+ _value = default
+ else:
+ try:
+ _value = int(value)
+ except ValueError:
+ raise ValueError(f"If set, {key} must be a int.")
+ return _value
+
+
+_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
+_run_pt_tf_cross_tests = parse_flag_from_env("RUN_PT_TF_CROSS_TESTS", default=True)
+_run_pt_flax_cross_tests = parse_flag_from_env("RUN_PT_FLAX_CROSS_TESTS", default=True)
+_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
+_run_staging = parse_flag_from_env("HUGGINGFACE_CO_STAGING", default=False)
+_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
+_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True)
+_run_agent_tests = parse_flag_from_env("RUN_AGENT_TESTS", default=False)
+_run_third_party_device_tests = parse_flag_from_env("RUN_THIRD_PARTY_DEVICE_TESTS", default=False)
+
+
+def get_device_count():
+ import torch
+
+ if is_torch_xpu_available():
+ num_devices = torch.xpu.device_count()
+ else:
+ num_devices = torch.cuda.device_count()
+
+ return num_devices
+
+
+def is_pt_tf_cross_test(test_case):
+ """
+ Decorator marking a test as a test that controls interactions between PyTorch and TensorFlow.
+
+ PT+TF tests are skipped by default and we can run only them by setting RUN_PT_TF_CROSS_TESTS environment variable
+ to a truthy value and selecting the is_pt_tf_cross_test pytest mark.
+
+ """
+ if not _run_pt_tf_cross_tests or not is_torch_available() or not is_tf_available():
+ return unittest.skip(reason="test is PT+TF test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_pt_tf_cross_test()(test_case)
+
+
+def is_pt_flax_cross_test(test_case):
+ """
+ Decorator marking a test as a test that controls interactions between PyTorch and Flax
+
+ PT+FLAX tests are skipped by default and we can run only them by setting RUN_PT_FLAX_CROSS_TESTS environment
+ variable to a truthy value and selecting the is_pt_flax_cross_test pytest mark.
+
+ """
+ if not _run_pt_flax_cross_tests or not is_torch_available() or not is_flax_available():
+ return unittest.skip(reason="test is PT+FLAX test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_pt_flax_cross_test()(test_case)
+
+
+def is_staging_test(test_case):
+ """
+ Decorator marking a test as a staging test.
+
+ Those tests will run using the staging environment of huggingface.co instead of the real model hub.
+ """
+ if not _run_staging:
+ return unittest.skip(reason="test is staging test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_staging_test()(test_case)
+
+
+def is_pipeline_test(test_case):
+ """
+ Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be
+ skipped.
+ """
+ if not _run_pipeline_tests:
+ return unittest.skip(reason="test is pipeline test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_pipeline_test()(test_case)
+
+
+def is_agent_test(test_case):
+ """
+ Decorator marking a test as an agent test. If RUN_AGENT_TESTS is set to a falsy value, those tests will be skipped.
+ """
+ if not _run_agent_tests:
+ return unittest.skip(reason="test is an agent test")(test_case)
+ else:
+ try:
+ import pytest # We don't need a hard dependency on pytest in the main library
+ except ImportError:
+ return test_case
+ else:
+ return pytest.mark.is_agent_test()(test_case)
+
+
+def slow(test_case):
+ """
+ Decorator marking a test as slow.
+
+ Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
+
+ """
+ return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+
+
+def tooslow(test_case):
+ """
+ Decorator marking a test as too slow.
+
+ Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as
+ these will not be tested by the CI.
+
+ """
+ return unittest.skip(reason="test is too slow")(test_case)
+
+
+def skip_if_not_implemented(test_func):
+ @functools.wraps(test_func)
+ def wrapper(*args, **kwargs):
+ try:
+ return test_func(*args, **kwargs)
+ except NotImplementedError as e:
+ raise unittest.SkipTest(f"Test skipped due to NotImplementedError: {e}")
+
+ return wrapper
+
+
+def apply_skip_if_not_implemented(cls):
+ """
+ Class decorator to apply @skip_if_not_implemented to all test methods.
+ """
+ for attr_name in dir(cls):
+ if attr_name.startswith("test_"):
+ attr = getattr(cls, attr_name)
+ if callable(attr):
+ setattr(cls, attr_name, skip_if_not_implemented(attr))
+ return cls
+
+
+def custom_tokenizers(test_case):
+ """
+ Decorator marking a test for a custom tokenizer.
+
+ Custom tokenizers require additional dependencies, and are skipped by default. Set the RUN_CUSTOM_TOKENIZERS
+ environment variable to a truthy value to run them.
+ """
+ return unittest.skipUnless(_run_custom_tokenizers, "test of custom tokenizers")(test_case)
+
+
+def require_bs4(test_case):
+ """
+ Decorator marking a test that requires BeautifulSoup4. These tests are skipped when BeautifulSoup4 isn't installed.
+ """
+ return unittest.skipUnless(is_bs4_available(), "test requires BeautifulSoup4")(test_case)
+
+
+def require_galore_torch(test_case):
+ """
+ Decorator marking a test that requires GaLore. These tests are skipped when GaLore isn't installed.
+ https://github.com/jiaweizzhao/GaLore
+ """
+ return unittest.skipUnless(is_galore_torch_available(), "test requires GaLore")(test_case)
+
+
+def require_lomo(test_case):
+ """
+ Decorator marking a test that requires LOMO. These tests are skipped when LOMO-optim isn't installed.
+ https://github.com/OpenLMLab/LOMO
+ """
+ return unittest.skipUnless(is_lomo_available(), "test requires LOMO")(test_case)
+
+
+def require_grokadamw(test_case):
+ """
+ Decorator marking a test that requires GrokAdamW. These tests are skipped when GrokAdamW isn't installed.
+ """
+ return unittest.skipUnless(is_grokadamw_available(), "test requires GrokAdamW")(test_case)
+
+
+def require_schedulefree(test_case):
+ """
+ Decorator marking a test that requires schedulefree. These tests are skipped when schedulefree isn't installed.
+ https://github.com/facebookresearch/schedule_free
+ """
+ return unittest.skipUnless(is_schedulefree_available(), "test requires schedulefree")(test_case)
+
+
+def require_cv2(test_case):
+ """
+ Decorator marking a test that requires OpenCV.
+
+ These tests are skipped when OpenCV isn't installed.
+
+ """
+ return unittest.skipUnless(is_cv2_available(), "test requires OpenCV")(test_case)
+
+
+def require_levenshtein(test_case):
+ """
+ Decorator marking a test that requires Levenshtein.
+
+ These tests are skipped when Levenshtein isn't installed.
+
+ """
+ return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case)
+
+
+def require_nltk(test_case):
+ """
+ Decorator marking a test that requires NLTK.
+
+ These tests are skipped when NLTK isn't installed.
+
+ """
+ return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case)
+
+
+def require_accelerate(test_case, min_version: str = ACCELERATE_MIN_VERSION):
+ """
+ Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
+ """
+ return unittest.skipUnless(
+ is_accelerate_available(min_version), f"test requires accelerate version >= {min_version}"
+ )(test_case)
+
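+# Note: `require_accelerate` takes the test as its first positional argument, so a different
+# minimum version can be pinned via `functools.partial` (an illustrative pattern, not one
+# mandated by this module; the version string is hypothetical):
+#
+#     require_accelerate_0_28 = functools.partial(require_accelerate, min_version="0.28.0")
+#
+#     @require_accelerate_0_28
+#     def test_big_model_dispatch(self):
+#         ...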
+
+def require_gguf(test_case, min_version: str = GGUF_MIN_VERSION):
+ """
+ Decorator marking a test that requires gguf. These tests are skipped when gguf isn't installed.
+ """
+ return unittest.skipUnless(is_gguf_available(min_version), f"test requires gguf version >= {min_version}")(
+ test_case
+ )
+
+
+def require_fsdp(test_case, min_version: str = "1.12.0"):
+ """
+ Decorator marking a test that requires fsdp. These tests are skipped when fsdp isn't installed.
+ """
+ return unittest.skipUnless(is_fsdp_available(min_version), f"test requires torch version >= {min_version}")(
+ test_case
+ )
+
+
+def require_g2p_en(test_case):
+ """
+ Decorator marking a test that requires g2p_en. These tests are skipped when g2p_en isn't installed.
+ """
+ return unittest.skipUnless(is_g2p_en_available(), "test requires g2p_en")(test_case)
+
+
+def require_safetensors(test_case):
+ """
+ Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed.
+ """
+ return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case)
+
+
+def require_rjieba(test_case):
+ """
+ Decorator marking a test that requires rjieba. These tests are skipped when rjieba isn't installed.
+ """
+ return unittest.skipUnless(is_rjieba_available(), "test requires rjieba")(test_case)
+
+
+def require_jieba(test_case):
+ """
+ Decorator marking a test that requires jieba. These tests are skipped when jieba isn't installed.
+ """
+ return unittest.skipUnless(is_jieba_available(), "test requires jieba")(test_case)
+
+
+def require_jinja(test_case):
+ """
+ Decorator marking a test that requires jinja. These tests are skipped when jinja isn't installed.
+ """
+ return unittest.skipUnless(is_jinja_available(), "test requires jinja")(test_case)
+
+
+def require_tf2onnx(test_case):
+ return unittest.skipUnless(is_tf2onnx_available(), "test requires tf2onnx")(test_case)
+
+
+def require_onnx(test_case):
+ return unittest.skipUnless(is_onnx_available(), "test requires ONNX")(test_case)
+
+
+def require_timm(test_case):
+ """
+ Decorator marking a test that requires Timm.
+
+ These tests are skipped when Timm isn't installed.
+
+ """
+ return unittest.skipUnless(is_timm_available(), "test requires Timm")(test_case)
+
+
+def require_natten(test_case):
+ """
+ Decorator marking a test that requires NATTEN.
+
+ These tests are skipped when NATTEN isn't installed.
+
+ """
+ return unittest.skipUnless(is_natten_available(), "test requires natten")(test_case)
+
+
+def require_torch(test_case):
+ """
+ Decorator marking a test that requires PyTorch.
+
+ These tests are skipped when PyTorch isn't installed.
+
+ """
+ return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+
+
+def require_flash_attn(test_case):
+ """
+ Decorator marking a test that requires Flash Attention.
+
+ These tests are skipped when Flash Attention isn't installed.
+
+ """
+ return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
+
+
+def require_torch_sdpa(test_case):
+ """
+ Decorator marking a test that requires PyTorch's SDPA.
+
+ These tests are skipped when requirements are not met (torch version).
+ """
+ return unittest.skipUnless(is_torch_sdpa_available(), "test requires PyTorch SDPA")(test_case)
+
+
+def require_read_token(fn):
+ """
+ A decorator that loads the HF token for tests that need to load gated models.
+ """
+ token = os.getenv("HF_HUB_READ_TOKEN")
+
+ @wraps(fn)
+ def _inner(*args, **kwargs):
+ if token is not None:
+ with patch("huggingface_hub.utils._headers.get_token", return_value=token):
+ return fn(*args, **kwargs)
+ else: # Allow running locally with the default token env variable
+ return fn(*args, **kwargs)
+
+ return _inner
+
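+# Hedged usage sketch for `require_read_token` (the model id is hypothetical): with
+# HF_HUB_READ_TOKEN set, hub calls inside the test pick up that token via the patched
+# `get_token`; without it, the test simply runs with the locally configured credentials.
+#
+#     @require_read_token
+#     def test_gated_model(self):
+#         AutoModel.from_pretrained("some-org/some-gated-model")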
+
+def require_peft(test_case):
+ """
+ Decorator marking a test that requires PEFT.
+
+ These tests are skipped when PEFT isn't installed.
+
+ """
+ return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
+
+
+def require_torchvision(test_case):
+ """
+ Decorator marking a test that requires Torchvision.
+
+ These tests are skipped when Torchvision isn't installed.
+
+ """
+ return unittest.skipUnless(is_torchvision_available(), "test requires Torchvision")(test_case)
+
+
+def require_torch_or_tf(test_case):
+ """
+ Decorator marking a test that requires PyTorch or TensorFlow.
+
+ These tests are skipped when neither PyTorch nor TensorFlow is installed.
+
+ """
+ return unittest.skipUnless(is_torch_available() or is_tf_available(), "test requires PyTorch or TensorFlow")(
+ test_case
+ )
+
+
+def require_intel_extension_for_pytorch(test_case):
+ """
+ Decorator marking a test that requires Intel Extension for PyTorch.
+
+ These tests are skipped when Intel Extension for PyTorch isn't installed or it does not match current PyTorch
+ version.
+
+ """
+ return unittest.skipUnless(
+ is_ipex_available(),
+ "test requires Intel Extension for PyTorch to be installed and match current PyTorch version, see"
+ " https://github.com/intel/intel-extension-for-pytorch",
+ )(test_case)
+
+
+def require_tensorflow_probability(test_case):
+ """
+ Decorator marking a test that requires TensorFlow probability.
+
+ These tests are skipped when TensorFlow probability isn't installed.
+
+ """
+ return unittest.skipUnless(is_tensorflow_probability_available(), "test requires TensorFlow probability")(
+ test_case
+ )
+
+
+def require_torchaudio(test_case):
+ """
+ Decorator marking a test that requires torchaudio. These tests are skipped when torchaudio isn't installed.
+ """
+ return unittest.skipUnless(is_torchaudio_available(), "test requires torchaudio")(test_case)
+
+
+def require_tf(test_case):
+ """
+ Decorator marking a test that requires TensorFlow. These tests are skipped when TensorFlow isn't installed.
+ """
+ return unittest.skipUnless(is_tf_available(), "test requires TensorFlow")(test_case)
+
+
+def require_flax(test_case):
+ """
+ Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
+ """
+ return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+
+
+def require_sentencepiece(test_case):
+ """
+ Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed.
+ """
+ return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case)
+
+
+def require_sacremoses(test_case):
+ """
+ Decorator marking a test that requires Sacremoses. These tests are skipped when Sacremoses isn't installed.
+ """
+ return unittest.skipUnless(is_sacremoses_available(), "test requires Sacremoses")(test_case)
+
+
+def require_seqio(test_case):
+ """
+ Decorator marking a test that requires Seqio. These tests are skipped when Seqio isn't installed.
+ """
+ return unittest.skipUnless(is_seqio_available(), "test requires Seqio")(test_case)
+
+
+def require_scipy(test_case):
+ """
+ Decorator marking a test that requires Scipy. These tests are skipped when Scipy isn't installed.
+ """
+ return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case)
+
+
+def require_tokenizers(test_case):
+ """
+ Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed.
+ """
+ return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case)
+
+
+def require_tensorflow_text(test_case):
+ """
+ Decorator marking a test that requires tensorflow_text. These tests are skipped when tensorflow_text isn't
+ installed.
+ """
+ return unittest.skipUnless(is_tensorflow_text_available(), "test requires tensorflow_text")(test_case)
+
+
+def require_keras_nlp(test_case):
+ """
+ Decorator marking a test that requires keras_nlp. These tests are skipped when keras_nlp isn't installed.
+ """
+ return unittest.skipUnless(is_keras_nlp_available(), "test requires keras_nlp")(test_case)
+
+
+def require_pandas(test_case):
+ """
+ Decorator marking a test that requires pandas. These tests are skipped when pandas isn't installed.
+ """
+ return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case)
+
+
+def require_pytesseract(test_case):
+ """
+ Decorator marking a test that requires PyTesseract. These tests are skipped when PyTesseract isn't installed.
+ """
+ return unittest.skipUnless(is_pytesseract_available(), "test requires PyTesseract")(test_case)
+
+
+def require_pytorch_quantization(test_case):
+ """
+ Decorator marking a test that requires PyTorch Quantization Toolkit. These tests are skipped when PyTorch
+ Quantization Toolkit isn't installed.
+ """
+ return unittest.skipUnless(is_pytorch_quantization_available(), "test requires PyTorch Quantization Toolkit")(
+ test_case
+ )
+
+
+def require_vision(test_case):
+ """
+ Decorator marking a test that requires the vision dependencies. These tests are skipped when the vision
+ dependencies aren't installed.
+ """
+ return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case)
+
+
+def require_ftfy(test_case):
+ """
+ Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed.
+ """
+ return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case)
+
+
+def require_spacy(test_case):
+ """
+ Decorator marking a test that requires SpaCy. These tests are skipped when SpaCy isn't installed.
+ """
+ return unittest.skipUnless(is_spacy_available(), "test requires spacy")(test_case)
+
+
+def require_torch_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires a multi-GPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple GPUs.
+
+ To run *only* the multi_gpu tests, assuming all test names contain multi_gpu: $ pytest -sv ./tests -k "multi_gpu"
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ device_count = get_device_count()
+
+ return unittest.skipUnless(device_count > 1, "test requires multiple GPUs")(test_case)
+
+
+def require_torch_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires a multi-accelerator (in PyTorch). These tests are skipped on a machine
+ without multiple accelerators. To run *only* the multi_accelerator tests, assuming all test names contain
+ multi_accelerator: $ pytest -sv ./tests -k "multi_accelerator"
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) > 1, "test requires multiple accelerators")(
+ test_case
+ )
+
+
+def require_torch_non_multi_gpu(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 GPU setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() < 2, "test requires 0 or 1 GPU")(test_case)
+
+
+def require_torch_non_multi_accelerator(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 accelerator setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) < 2, "test requires 0 or 1 accelerator")(test_case)
+
+
+def require_torch_up_to_2_gpus(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 or 2 GPU setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ import torch
+
+ return unittest.skipUnless(torch.cuda.device_count() < 3, "test requires 0 or 1 or 2 GPUs")(test_case)
+
+
+def require_torch_up_to_2_accelerators(test_case):
+ """
+ Decorator marking a test that requires 0 or 1 or 2 accelerator setup (in PyTorch).
+ """
+ if not is_torch_available():
+ return unittest.skip(reason="test requires PyTorch")(test_case)
+
+ return unittest.skipUnless(backend_device_count(torch_device) < 3, "test requires 0 or 1 or 2 accelerators")(
+ test_case
+ )
+
+
+def require_torch_xla(test_case):
+ """
+ Decorator marking a test that requires TorchXLA (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_xla_available(), "test requires TorchXLA")(test_case)
+
+
+def require_torch_neuroncore(test_case):
+ """
+ Decorator marking a test that requires NeuronCore (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_neuroncore_available(check_device=False), "test requires PyTorch NeuronCore")(
+ test_case
+ )
+
+
+def require_torch_npu(test_case):
+ """
+ Decorator marking a test that requires NPU (in PyTorch).
+ """
+ return unittest.skipUnless(is_torch_npu_available(), "test requires PyTorch NPU")(test_case)
+
+
+def require_torch_multi_npu(test_case):
+ """
+ Decorator marking a test that requires a multi-NPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple NPUs.
+
+ To run *only* the multi_npu tests, assuming all test names contain multi_npu: $ pytest -sv ./tests -k "multi_npu"
+ """
+ if not is_torch_npu_available():
+ return unittest.skip(reason="test requires PyTorch NPU")(test_case)
+
+ return unittest.skipUnless(torch.npu.device_count() > 1, "test requires multiple NPUs")(test_case)
+
+
+def require_torch_xpu(test_case):
+ """
+ Decorator marking a test that requires XPU (in PyTorch).
+
+ These tests are skipped when XPU backend is not available. XPU backend might be available either via stock
+ PyTorch (>=2.4) or via Intel Extension for PyTorch. In the latter case, if IPEX is installed, its version
+ must match the current PyTorch version.
+ """
+ return unittest.skipUnless(is_torch_xpu_available(), "test requires XPU device")(test_case)
+
+
+def require_non_xpu(test_case):
+ """
+ Decorator marking a test that should be skipped for XPU.
+ """
+ return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case)
+
+
+def require_torch_multi_xpu(test_case):
+ """
+ Decorator marking a test that requires a multi-XPU setup (in PyTorch). These tests are skipped on a machine without
+ multiple XPUs.
+
+ To run *only* the multi_xpu tests, assuming all test names contain multi_xpu: $ pytest -sv ./tests -k "multi_xpu"
+ """
+ if not is_torch_xpu_available():
+ return unittest.skip(reason="test requires PyTorch XPU")(test_case)
+
+ return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case)
+
+
+if is_torch_available():
+ # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
+ import torch
+
+ if "TRANSFORMERS_TEST_BACKEND" in os.environ:
+ backend = os.environ["TRANSFORMERS_TEST_BACKEND"]
+ try:
+ _ = importlib.import_module(backend)
+ except ModuleNotFoundError as e:
+ raise ModuleNotFoundError(
+ f"Failed to import `TRANSFORMERS_TEST_BACKEND` '{backend}'! This should be the name of an installed module. The original error (look up to see its"
+ f" traceback):\n{e}"
+ ) from e
+
+ if "TRANSFORMERS_TEST_DEVICE" in os.environ:
+ torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
+ if torch_device == "cuda" and not torch.cuda.is_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but CUDA is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "xpu" and not is_torch_xpu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but XPU is unavailable. Please double-check your testing environment."
+ )
+ if torch_device == "npu" and not is_torch_npu_available():
+ raise ValueError(
+ f"TRANSFORMERS_TEST_DEVICE={torch_device}, but NPU is unavailable. Please double-check your testing environment."
+ )
+
+ try:
+ # try creating device to see if provided device is valid
+ _ = torch.device(torch_device)
+ except RuntimeError as e:
+ raise RuntimeError(
+ f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
+ ) from e
+ elif torch.cuda.is_available():
+ torch_device = "cuda"
+ elif _run_third_party_device_tests and is_torch_npu_available():
+ torch_device = "npu"
+ elif _run_third_party_device_tests and is_torch_xpu_available():
+ torch_device = "xpu"
+ else:
+ torch_device = "cpu"
+else:
+ torch_device = None
+
+if is_tf_available():
+ import tensorflow as tf
+
+if is_flax_available():
+ import jax
+
+ jax_device = jax.default_backend()
+else:
+ jax_device = None
+
+
+def require_torchdynamo(test_case):
+ """Decorator marking a test that requires TorchDynamo"""
+ return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
+def require_torchao(test_case):
+ """Decorator marking a test that requires torchao"""
+ return unittest.skipUnless(is_torchao_available(), "test requires torchao")(test_case)
+
+
+def require_torch_tensorrt_fx(test_case):
+ """Decorator marking a test that requires Torch-TensorRT FX"""
+ return unittest.skipUnless(is_torch_tensorrt_fx_available(), "test requires Torch-TensorRT FX")(test_case)
+
+
+def require_torch_gpu(test_case):
+ """Decorator marking a test that requires CUDA and PyTorch."""
+ return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
+
+
+def require_torch_gpu_if_bnb_not_multi_backend_enabled(test_case):
+ """
+ Decorator marking a test that requires a GPU if bitsandbytes multi-backend feature is not enabled.
+ """
+ if is_bitsandbytes_available() and is_bitsandbytes_multi_backend_available():
+ return test_case
+ return require_torch_gpu(test_case)
+
+
+def require_torch_accelerator(test_case):
+ """Decorator marking a test that requires an accessible accelerator and PyTorch."""
+ return unittest.skipUnless(torch_device is not None and torch_device != "cpu", "test requires accelerator")(
+ test_case
+ )
+
+
+def require_torch_fp16(test_case):
+ """Decorator marking a test that requires a device that supports fp16"""
+ return unittest.skipUnless(
+ is_torch_fp16_available_on_device(torch_device), "test requires device with fp16 support"
+ )(test_case)
+
+
+def require_fp8(test_case):
+ """Decorator marking a test that requires supports for fp8"""
+ return unittest.skipUnless(is_accelerate_available() and is_fp8_available(), "test requires fp8 support")(
+ test_case
+ )
+
+
+def require_torch_bf16(test_case):
+ """Decorator marking a test that requires a device that supports bf16"""
+ return unittest.skipUnless(
+ is_torch_bf16_available_on_device(torch_device), "test requires device with bf16 support"
+ )(test_case)
+
+
+def require_torch_bf16_gpu(test_case):
+ """Decorator marking a test that requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0"""
+ return unittest.skipUnless(
+ is_torch_bf16_gpu_available(),
+ "test requires torch>=1.10, using Ampere GPU or newer arch with cuda>=11.0",
+ )(test_case)
+
+
+def require_torch_bf16_cpu(test_case):
+ """Decorator marking a test that requires torch>=1.10, using CPU."""
+ return unittest.skipUnless(
+ is_torch_bf16_cpu_available(),
+ "test requires torch>=1.10, using CPU",
+ )(test_case)
+
+
+def require_deterministic_for_xpu(test_case):
+ if is_torch_xpu_available():
+ return unittest.skipUnless(is_torch_deterministic(), "test requires torch to use deterministic algorithms")(
+ test_case
+ )
+ else:
+ return test_case
+
+
+def require_torch_tf32(test_case):
+ """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
+ return unittest.skipUnless(
+ is_torch_tf32_available(), "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
+ )(test_case)
+
+
+def require_detectron2(test_case):
+ """Decorator marking a test that requires detectron2."""
+ return unittest.skipUnless(is_detectron2_available(), "test requires `detectron2`")(test_case)
+
+
+def require_faiss(test_case):
+ """Decorator marking a test that requires faiss."""
+ return unittest.skipUnless(is_faiss_available(), "test requires `faiss`")(test_case)
+
+
+def require_optuna(test_case):
+ """
+ Decorator marking a test that requires optuna.
+
+ These tests are skipped when optuna isn't installed.
+
+ """
+ return unittest.skipUnless(is_optuna_available(), "test requires optuna")(test_case)
+
+
+def require_ray(test_case):
+ """
+ Decorator marking a test that requires Ray/tune.
+
+ These tests are skipped when Ray/tune isn't installed.
+
+ """
+ return unittest.skipUnless(is_ray_available(), "test requires Ray/tune")(test_case)
+
+
+def require_sigopt(test_case):
+ """
+ Decorator marking a test that requires SigOpt.
+
+ These tests are skipped when SigOpt isn't installed.
+
+ """
+ return unittest.skipUnless(is_sigopt_available(), "test requires SigOpt")(test_case)
+
+
+def require_wandb(test_case):
+ """
+ Decorator marking a test that requires wandb.
+
+ These tests are skipped when wandb isn't installed.
+
+ """
+ return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case)
+
+
+def require_clearml(test_case):
+ """
+ Decorator marking a test that requires clearml.
+
+ These tests are skipped when clearml isn't installed.
+
+ """
+ return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)
+
+
+def require_soundfile(test_case):
+ """
+ Decorator marking a test that requires soundfile
+
+ These tests are skipped when soundfile isn't installed.
+
+ """
+ return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")(test_case)
+
+
+def require_deepspeed(test_case):
+ """
+ Decorator marking a test that requires deepspeed
+ """
+ return unittest.skipUnless(is_deepspeed_available(), "test requires deepspeed")(test_case)
+
+
+def require_apex(test_case):
+ """
+ Decorator marking a test that requires apex
+ """
+ return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
+
+
+def require_aqlm(test_case):
+ """
+ Decorator marking a test that requires aqlm
+ """
+ return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
+def require_eetq(test_case):
+ """
+ Decorator marking a test that requires eetq
+ """
+ eetq_available = is_eetq_available()
+ if eetq_available:
+ try:
+ import eetq # noqa: F401
+ except ImportError as exc:
+ if "shard_checkpoint" in str(exc):
+ # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
+ # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
+ # TODO: Remove once eetq releases a fix and this release is used in CI
+ eetq_available = False
+ return unittest.skipUnless(eetq_available, "test requires eetq")(test_case)
+
+
+def require_av(test_case):
+ """
+ Decorator marking a test that requires av
+ """
+ return unittest.skipUnless(is_av_available(), "test requires av")(test_case)
+
+
+def require_bitsandbytes(test_case):
+ """
+ Decorator marking a test that requires the bitsandbytes library. Will be skipped when the library or its hard dependency torch is not installed.
+ """
+ if is_bitsandbytes_available() and is_torch_available():
+ try:
+ import pytest
+
+ return pytest.mark.bitsandbytes(test_case)
+ except ImportError:
+ return test_case
+ else:
+ return unittest.skip(reason="test requires bitsandbytes and torch")(test_case)
+
+
+def require_optimum(test_case):
+ """
+ Decorator for optimum dependency
+ """
+ return unittest.skipUnless(is_optimum_available(), "test requires optimum")(test_case)
+
+
+def require_tensorboard(test_case):
+ """
+ Decorator for `tensorboard` dependency
+ """
+ return unittest.skipUnless(is_tensorboard_available(), "test requires tensorboard")(test_case)
+
+
+def require_gptq(test_case):
+ """
+ Decorator for the gptqmodel / auto_gptq dependency
+ """
+ return unittest.skipUnless(
+ is_gptqmodel_available() or is_auto_gptq_available(), "test requires gptqmodel or auto-gptq"
+ )(test_case)
+
+
+def require_auto_awq(test_case):
+ """
+ Decorator for auto_awq dependency
+ """
+ return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
+
+
+def require_optimum_quanto(test_case):
+ """
+ Decorator for quanto dependency
+ """
+ return unittest.skipUnless(is_optimum_quanto_available(), "test requires optimum-quanto")(test_case)
+
+
+def require_compressed_tensors(test_case):
+ """
+ Decorator for compressed_tensors dependency
+ """
+ return unittest.skipUnless(is_compressed_tensors_available(), "test requires compressed_tensors")(test_case)
+
+
+def require_fbgemm_gpu(test_case):
+ """
+ Decorator for fbgemm_gpu dependency
+ """
+ return unittest.skipUnless(is_fbgemm_gpu_available(), "test requires fbgemm-gpu")(test_case)
+
+
+def require_phonemizer(test_case):
+ """
+ Decorator marking a test that requires phonemizer
+ """
+ return unittest.skipUnless(is_phonemizer_available(), "test requires phonemizer")(test_case)
+
+
+def require_pyctcdecode(test_case):
+ """
+ Decorator marking a test that requires pyctcdecode
+ """
+ return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case)
+
+
+def require_librosa(test_case):
+ """
+ Decorator marking a test that requires librosa
+ """
+ return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case)
+
+
+def require_liger_kernel(test_case):
+ """
+ Decorator marking a test that requires liger_kernel
+ """
+ return unittest.skipUnless(is_liger_kernel_available(), "test requires liger_kernel")(test_case)
+
+
+def require_essentia(test_case):
+ """
+ Decorator marking a test that requires essentia
+ """
+ return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case)
+
+
+def require_pretty_midi(test_case):
+ """
+ Decorator marking a test that requires pretty_midi
+ """
+ return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case)
+
+
+def cmd_exists(cmd):
+ return shutil.which(cmd) is not None
+
+
+def require_usr_bin_time(test_case):
+ """
+ Decorator marking a test that requires `/usr/bin/time`
+ """
+ return unittest.skipUnless(cmd_exists("/usr/bin/time"), "test requires /usr/bin/time")(test_case)
+
+
+def require_sudachi(test_case):
+ """
+ Decorator marking a test that requires sudachi
+ """
+ return unittest.skipUnless(is_sudachi_available(), "test requires sudachi")(test_case)
+
+
+def require_sudachi_projection(test_case):
+ """
+ Decorator marking a test that requires sudachi_projection
+ """
+ return unittest.skipUnless(is_sudachi_projection_available(), "test requires sudachi which supports projection")(
+ test_case
+ )
+
+
+def require_jumanpp(test_case):
+ """
+ Decorator marking a test that requires jumanpp
+ """
+ return unittest.skipUnless(is_jumanpp_available(), "test requires jumanpp")(test_case)
+
+
+def require_cython(test_case):
+ """
+ Decorator marking a test that requires cython
+ """
+ return unittest.skipUnless(is_cython_available(), "test requires cython")(test_case)
+
+
+def require_tiktoken(test_case):
+ """
+ Decorator marking a test that requires TikToken. These tests are skipped when TikToken isn't installed.
+ """
+ return unittest.skipUnless(is_tiktoken_available(), "test requires TikToken")(test_case)
+
+
+def get_gpu_count():
+ """
+ Return the number of available gpus (regardless of whether torch, tf or jax is used)
+ """
+ if is_torch_available():
+ import torch
+
+ return torch.cuda.device_count()
+ elif is_tf_available():
+ import tensorflow as tf
+
+ return len(tf.config.list_physical_devices("GPU"))
+ elif is_flax_available():
+ import jax
+
+ return jax.device_count()
+ else:
+ return 0
+
+
+def get_tests_dir(append_path=None):
+ """
+ Args:
+ append_path: optional path to append to the tests dir path
+
+ Return:
+ The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is
+ joined after the `tests` dir if the former is provided.
+
+ """
+ # this function caller's __file__
+ caller__file__ = inspect.stack()[1][1]
+ tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+
+ while not tests_dir.endswith("tests"):
+ tests_dir = os.path.dirname(tests_dir)
+
+ if append_path:
+ return os.path.join(tests_dir, append_path)
+ else:
+ return tests_dir
+
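+# Hedged usage sketch for `get_tests_dir` (the fixture path is hypothetical): called from any
+# module under `tests/`, it resolves paths independently of the directory pytest was started from.
+#
+#     SAMPLE_WAV = get_tests_dir("fixtures/test.wav")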
+
+#
+# Helper functions for dealing with testing text outputs
+# The original code came from:
+# https://github.com/fastai/fastai/blob/master/tests/utils/text.py
+
+
+# When any function contains print() calls that get overwritten, like progress bars,
+# special care needs to be applied, since under pytest -s captured output (capsys
+# or contextlib.redirect_stdout) contains any temporary printed strings, followed by
+# \r's. This helper function ensures that the buffer will contain the same output
+# with and without -s in pytest, by turning:
+# foo bar\r tar mar\r final message
+# into:
+# final message
+# it can handle a single string or a multiline buffer
+def apply_print_resets(buf):
+ return re.sub(r"^.*\r", "", buf, 0, re.M)
+
+
+def assert_screenout(out, what):
+ out_pr = apply_print_resets(out).lower()
+ match_str = out_pr.find(what.lower())
+ assert match_str != -1, f"expecting to find {what} in output: f{out_pr}"
+
+
+class CaptureStd:
+ """
+ Context manager to capture:
+
+ - stdout: replay it, clean it up and make it available via `obj.out`
+ - stderr: replay it and make it available via `obj.err`
+
+ Args:
+ out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not.
+ err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not.
+ replay (`bool`, *optional*, defaults to `True`): Whether to replay or not.
+ By default each captured stream gets replayed back on context's exit, so that one can see what the test was
+ doing. If this behavior is not wanted and the captured data shouldn't be replayed, pass `replay=False` to
+ disable this feature.
+
+ Examples:
+
+ ```python
+ # to capture stdout only with auto-replay
+ with CaptureStdout() as cs:
+ print("Secret message")
+ assert "message" in cs.out
+
+ # to capture stderr only with auto-replay
+ import sys
+
+ with CaptureStderr() as cs:
+ print("Warning: ", file=sys.stderr)
+ assert "Warning" in cs.err
+
+ # to capture both streams with auto-replay
+ with CaptureStd() as cs:
+ print("Secret message")
+ print("Warning: ", file=sys.stderr)
+ assert "message" in cs.out
+ assert "Warning" in cs.err
+
+ # to capture just one of the streams, and not the other, with auto-replay
+ with CaptureStd(err=False) as cs:
+ print("Secret message")
+ assert "message" in cs.out
+ # but best use the stream-specific subclasses
+
+ # to capture without auto-replay
+ with CaptureStd(replay=False) as cs:
+ print("Secret message")
+ assert "message" in cs.out
+ ```"""
+
+ def __init__(self, out=True, err=True, replay=True):
+ self.replay = replay
+
+ if out:
+ self.out_buf = StringIO()
+ self.out = "error: CaptureStd context is unfinished yet, called too early"
+ else:
+ self.out_buf = None
+ self.out = "not capturing stdout"
+
+ if err:
+ self.err_buf = StringIO()
+ self.err = "error: CaptureStd context is unfinished yet, called too early"
+ else:
+ self.err_buf = None
+ self.err = "not capturing stderr"
+
+ def __enter__(self):
+ if self.out_buf:
+ self.out_old = sys.stdout
+ sys.stdout = self.out_buf
+
+ if self.err_buf:
+ self.err_old = sys.stderr
+ sys.stderr = self.err_buf
+
+ return self
+
+ def __exit__(self, *exc):
+ if self.out_buf:
+ sys.stdout = self.out_old
+ captured = self.out_buf.getvalue()
+ if self.replay:
+ sys.stdout.write(captured)
+ self.out = apply_print_resets(captured)
+
+ if self.err_buf:
+ sys.stderr = self.err_old
+ captured = self.err_buf.getvalue()
+ if self.replay:
+ sys.stderr.write(captured)
+ self.err = captured
+
+ def __repr__(self):
+ msg = ""
+ if self.out_buf:
+ msg += f"stdout: {self.out}\n"
+ if self.err_buf:
+ msg += f"stderr: {self.err}\n"
+ return msg
+
+
+# in tests it's the best to capture only the stream that's wanted, otherwise
+# it's easy to miss things, so unless you need to capture both streams, use the
+# subclasses below (less typing). Or alternatively, configure `CaptureStd` to
+# disable the stream you don't need to test.
+
+
+class CaptureStdout(CaptureStd):
+ """Same as CaptureStd but captures only stdout"""
+
+ def __init__(self, replay=True):
+ super().__init__(err=False, replay=replay)
+
+
+class CaptureStderr(CaptureStd):
+ """Same as CaptureStd but captures only stderr"""
+
+ def __init__(self, replay=True):
+ super().__init__(out=False, replay=replay)
+
+
+class CaptureLogger:
+ """
+ Context manager to capture `logging` streams
+
+ Args:
+ logger: 'logging` logger object
+
+ Returns:
+ The captured output is available via `self.out`
+
+ Example:
+
+ ```python
+ >>> from transformers import logging
+ >>> from transformers.testing_utils import CaptureLogger
+
+ >>> msg = "Testing 1, 2, 3"
+ >>> logging.set_verbosity_info()
+ >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart")
+ >>> with CaptureLogger(logger) as cl:
+ ... logger.info(msg)
+ >>> assert cl.out, msg + "\n"
+ ```
+ """
+
+ def __init__(self, logger):
+ self.logger = logger
+ self.io = StringIO()
+ self.sh = logging.StreamHandler(self.io)
+ self.out = ""
+
+ def __enter__(self):
+ self.logger.addHandler(self.sh)
+ return self
+
+ def __exit__(self, *exc):
+ self.logger.removeHandler(self.sh)
+ self.out = self.io.getvalue()
+
+ def __repr__(self):
+ return f"captured: {self.out}\n"
+
+
+@contextlib.contextmanager
+def LoggingLevel(level):
+ """
+ This is a context manager to temporarily change transformers modules logging level to the desired value and have it
+ restored to the original setting at the end of the scope.
+
+ Example:
+
+ ```python
+ with LoggingLevel(logging.INFO):
+ AutoModel.from_pretrained("openai-community/gpt2") # calls logger.info() several times
+ ```
+ """
+ orig_level = transformers_logging.get_verbosity()
+ try:
+ transformers_logging.set_verbosity(level)
+ yield
+ finally:
+ transformers_logging.set_verbosity(orig_level)
+
+
+class TemporaryHubRepo:
+ """Create a temporary Hub repository and return its `RepoUrl` object. This is similar to
+ `tempfile.TemporaryDirectory` and can be used as a context manager. For example:
+
+ with TemporaryHubRepo(token=self._token) as temp_repo:
+ ...
+
+ Upon exiting the context, the repository and everything contained in it are removed.
+
+ Example:
+
+ ```python
+ with TemporaryHubRepo(token=self._token) as temp_repo:
+ model.push_to_hub(temp_repo.repo_id, token=self._token)
+ ```
+ """
+
+ def __init__(self, namespace: Optional[str] = None, token: Optional[str] = None) -> None:
+ self.token = token
+ with tempfile.TemporaryDirectory() as tmp_dir:
+ repo_id = Path(tmp_dir).name
+ if namespace is not None:
+ repo_id = f"{namespace}/{repo_id}"
+ self.repo_url = huggingface_hub.create_repo(repo_id, token=self.token)
+
+ def __enter__(self):
+ return self.repo_url
+
+ def __exit__(self, exc, value, tb):
+ delete_repo(repo_id=self.repo_url.repo_id, token=self.token, missing_ok=True)
+
+
+@contextlib.contextmanager
+# adapted from https://stackoverflow.com/a/64789046/9201239
+def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]:
+ """
+ Temporarily add given path to `sys.path`.
+
+ Usage:
+
+ ```python
+ with ExtendSysPath("/path/to/dir"):
+ mymodule = importlib.import_module("mymodule")
+ ```
+ """
+
+ path = os.fspath(path)
+ try:
+ sys.path.insert(0, path)
+ yield
+ finally:
+ sys.path.remove(path)
+
+
+class TestCasePlus(unittest.TestCase):
+ """
+ This class extends *unittest.TestCase* with additional features.
+
+ Feature 1: A set of fully resolved important file and dir path accessors.
+
+ In tests often we need to know where things are relative to the current test file, and it's not trivial since the
+ test could be invoked from more than one directory or could reside in sub-directories with different depths. This
+ class solves this problem by sorting out all the basic paths and provides easy accessors to them:
+
+ - `pathlib` objects (all fully resolved):
+
+ - `test_file_path` - the current test file path (=`__file__`)
+ - `test_file_dir` - the directory containing the current test file
+ - `tests_dir` - the directory of the `tests` test suite
+ - `examples_dir` - the directory of the `examples` test suite
+ - `repo_root_dir` - the directory of the repository
+ - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides)
+
+ - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects:
+
+ - `test_file_path_str`
+ - `test_file_dir_str`
+ - `tests_dir_str`
+ - `examples_dir_str`
+ - `repo_root_dir_str`
+ - `src_dir_str`
+
+ Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test.
+
+ 1. Create a unique temporary dir:
+
+ ```python
+ def test_whatever(self):
+ tmp_dir = self.get_auto_remove_tmp_dir()
+ ```
+
+ `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the
+ test.
+
+
+ 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't
+ empty it after the test.
+
+ ```python
+ def test_whatever(self):
+ tmp_dir = self.get_auto_remove_tmp_dir("./xxx")
+ ```
+
+ This is useful for debugging when you want to monitor a specific directory and make sure the previous tests
+ didn't leave any data in there.
+
+ 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the
+ following behavior:
+
+ `before=True`: the temporary dir will always be cleared at the beginning of the test.
+
+ `before=False`: if the temporary dir already existed, any existing files will remain there.
+
+ `after=True`: the temporary dir will always be deleted at the end of the test.
+
+ `after=False`: the temporary dir will always be left intact at the end of the test.
+
+ Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are
+ allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem
+ will get nuked. i.e. please always pass paths that start with `./`
+
+ Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested
+ otherwise.
+
+ Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This
+ is useful for invoking external programs from the test suite - e.g. distributed training.
+
+
+ ```python
+ def test_whatever(self):
+ env = self.get_env()
+ ```"""
+
+ def setUp(self):
+ # get_auto_remove_tmp_dir feature:
+ self.teardown_tmp_dirs = []
+
+ # figure out the resolved paths for repo_root, tests, examples, etc.
+ self._test_file_path = inspect.getfile(self.__class__)
+ path = Path(self._test_file_path).resolve()
+ self._test_file_dir = path.parents[0]
+ for up in [1, 2, 3]:
+ tmp_dir = path.parents[up]
+ if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir():
+ break
+ if tmp_dir:
+ self._repo_root_dir = tmp_dir
+ else:
+ raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}")
+ self._tests_dir = self._repo_root_dir / "tests"
+ self._examples_dir = self._repo_root_dir / "examples"
+ self._src_dir = self._repo_root_dir / "src"
+
+ @property
+ def test_file_path(self):
+ return self._test_file_path
+
+ @property
+ def test_file_path_str(self):
+ return str(self._test_file_path)
+
+ @property
+ def test_file_dir(self):
+ return self._test_file_dir
+
+ @property
+ def test_file_dir_str(self):
+ return str(self._test_file_dir)
+
+ @property
+ def tests_dir(self):
+ return self._tests_dir
+
+ @property
+ def tests_dir_str(self):
+ return str(self._tests_dir)
+
+ @property
+ def examples_dir(self):
+ return self._examples_dir
+
+ @property
+ def examples_dir_str(self):
+ return str(self._examples_dir)
+
+ @property
+ def repo_root_dir(self):
+ return self._repo_root_dir
+
+ @property
+ def repo_root_dir_str(self):
+ return str(self._repo_root_dir)
+
+ @property
+ def src_dir(self):
+ return self._src_dir
+
+ @property
+ def src_dir_str(self):
+ return str(self._src_dir)
+
+ def get_env(self):
+ """
+ Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's
+ invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training.
+
+ It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally
+ the preset `PYTHONPATH` if any (all fully resolved paths).
+
+ """
+ env = os.environ.copy()
+ paths = [self.src_dir_str]
+ if "/examples" in self.test_file_dir_str:
+ paths.append(self.examples_dir_str)
+ else:
+ paths.append(self.tests_dir_str)
+ paths.append(env.get("PYTHONPATH", ""))
+
+ env["PYTHONPATH"] = ":".join(paths)
+ return env
+
+ def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None):
+ """
+ Args:
+ tmp_dir (`string`, *optional*):
+ if `None`:
+
+ - a unique temporary path will be created
+ - sets `before=True` if `before` is `None`
+ - sets `after=True` if `after` is `None`
+ else:
+
+ - `tmp_dir` will be created
+ - sets `before=True` if `before` is `None`
+ - sets `after=False` if `after` is `None`
+ before (`bool`, *optional*):
+ If `True` and the `tmp_dir` already exists, make sure to empty it right away; if `False` and the
+ `tmp_dir` already exists, any existing files will remain there.
+ after (`bool`, *optional*):
+ If `True`, delete the `tmp_dir` at the end of the test; if `False`, leave the `tmp_dir` and its contents
+ intact at the end of the test.
+
+ Returns:
+ tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir
+ """
+ if tmp_dir is not None:
+ # defining the most likely desired behavior for when a custom path is provided.
+ # this most likely indicates the debug mode where we want an easily locatable dir that:
+ # 1. gets cleared out before the test (if it already exists)
+ # 2. is left intact after the test
+ if before is None:
+ before = True
+ if after is None:
+ after = False
+
+ # using provided path
+ path = Path(tmp_dir).resolve()
+
+ # to avoid nuking parts of the filesystem, only relative paths are allowed
+ if not tmp_dir.startswith("./"):
+ raise ValueError(
+ f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
+ )
+
+ # ensure the dir is empty to start with
+ if before is True and path.exists():
+ shutil.rmtree(tmp_dir, ignore_errors=True)
+
+ path.mkdir(parents=True, exist_ok=True)
+
+ else:
+ # defining the most likely desired behavior for when a unique tmp path is auto generated
+ # (not a debug mode), here we require a unique tmp dir that:
+ # 1. is empty before the test (it will be empty in this situation anyway)
+ # 2. gets fully removed after the test
+ if before is None:
+ before = True
+ if after is None:
+ after = True
+
+ # using unique tmp dir (always empty, regardless of `before`)
+ tmp_dir = tempfile.mkdtemp()
+
+ if after is True:
+ # register for deletion
+ self.teardown_tmp_dirs.append(tmp_dir)
+
+ return tmp_dir
+
+ def python_one_liner_max_rss(self, one_liner_str):
+ """
+ Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the
+ program.
+
+ Args:
+ one_liner_str (`string`):
+ a python one liner code that gets passed to `python -c`
+
+ Returns:
+ max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run.
+
+ Requirements:
+ this helper needs `/usr/bin/time` to be installed (`apt install time`)
+
+ Example:
+
+ ```
+ one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("google-t5/t5-large")'
+ max_rss = self.python_one_liner_max_rss(one_liner_str)
+ ```
+ """
+
+ if not cmd_exists("/usr/bin/time"):
+ raise ValueError("/usr/bin/time is required, install with `apt install time`")
+
+ cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'")
+ with CaptureStd() as cs:
+ execute_subprocess_async(cmd, env=self.get_env())
+ # returned data is in KB so convert to bytes
+ max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024
+ return max_rss
+
+ def tearDown(self):
+ # get_auto_remove_tmp_dir feature: remove registered temp dirs
+ for path in self.teardown_tmp_dirs:
+ shutil.rmtree(path, ignore_errors=True)
+ self.teardown_tmp_dirs = []
+ if is_accelerate_available():
+ AcceleratorState._reset_state()
+ PartialState._reset_state()
+
+ # delete all the env variables having `ACCELERATE` in them
+ for k in list(os.environ.keys()):
+ if "ACCELERATE" in k:
+ del os.environ[k]
+
+
+def mockenv(**kwargs):
+ """
+ this is a convenience wrapper, that allows this ::
+
+ @mockenv(RUN_SLOW=True, USE_TF=False)
+ def test_something():
+     run_slow = os.getenv("RUN_SLOW", False)
+     use_tf = os.getenv("USE_TF", False)
+
+ """
+ return mock.patch.dict(os.environ, kwargs)
+
+
+# from https://stackoverflow.com/a/34333710/9201239
+@contextlib.contextmanager
+def mockenv_context(*remove, **update):
+ """
+ Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv
+
+ The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations.
+
+ Args:
+ remove: Environment variables to remove.
+ update: Dictionary of environment variables and values to add/update.
+ """
+ env = os.environ
+ update = update or {}
+ remove = remove or []
+
+ # List of environment variables being updated or removed.
+ stomped = (set(update.keys()) | set(remove)) & set(env.keys())
+ # Environment variables and values to restore on exit.
+ update_after = {k: env[k] for k in stomped}
+ # Environment variables and values to remove on exit.
+ remove_after = frozenset(k for k in update if k not in env)
+
+ try:
+ env.update(update)
+ [env.pop(k, None) for k in remove]
+ yield
+ finally:
+ env.update(update_after)
+ [env.pop(k) for k in remove_after]
+
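+# Hedged usage sketch for `mockenv_context` (the variable names are illustrative): positional
+# arguments are removed from `os.environ`, keyword arguments are added/updated, and the
+# original environment is restored on exit.
+#
+#     with mockenv_context("WORLD_SIZE", MASTER_ADDR="localhost", MASTER_PORT="10999"):
+#         ...  # code here sees MASTER_ADDR/MASTER_PORT and no WORLD_SIZE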
+
+# --- pytest conf functions --- #
+
+# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once
+pytest_opt_registered = {}
+
+
+def pytest_addoption_shared(parser):
+ """
+ This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there.
+
+ It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest`
+ option.
+
+ """
+ option = "--make-reports"
+ if option not in pytest_opt_registered:
+ parser.addoption(
+ option,
+ action="store",
+ default=False,
+ help="generate report files. The value of this option is used as a prefix to report names",
+ )
+ pytest_opt_registered[option] = 1
+
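+# Minimal conftest.py wiring sketch for the helpers in this block (assumed usage, not
+# prescribed by this module):
+#
+#     def pytest_addoption(parser):
+#         pytest_addoption_shared(parser)
+#
+#     def pytest_terminal_summary(terminalreporter):
+#         make_reports = terminalreporter.config.getoption("--make-reports")
+#         if make_reports:
+#             pytest_terminal_summary_main(terminalreporter, id=make_reports)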
+
+def pytest_terminal_summary_main(tr, id):
+ """
+ Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current
+ directory. The report files are prefixed with the test suite name.
+
+ This function emulates --duration and -rA pytest arguments.
+
+ This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined
+ there.
+
+ Args:
+ - tr: `terminalreporter` passed from `conftest.py`
+ - id: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is
+ needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other.
+
+ NB: this function taps into a private _pytest API and while unlikely, it could break should pytest do internal
+ changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-`
+ plugins and interfere.
+
+ """
+ from _pytest.config import create_terminal_writer
+
+ if not len(id):
+ id = "tests"
+
+ config = tr.config
+ orig_writer = config.get_terminal_writer()
+ orig_tbstyle = config.option.tbstyle
+ orig_reportchars = tr.reportchars
+
+ dir = f"reports/{id}"
+ Path(dir).mkdir(parents=True, exist_ok=True)
+ report_files = {
+ k: f"{dir}/{k}.txt"
+ for k in [
+ "durations",
+ "errors",
+ "failures_long",
+ "failures_short",
+ "failures_line",
+ "passes",
+ "stats",
+ "summary_short",
+ "warnings",
+ ]
+ }
+
+ # custom durations report
+ # note: there is no need to call pytest --durations=XX to get this separate report
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66
+ dlist = []
+ for replist in tr.stats.values():
+ for rep in replist:
+ if hasattr(rep, "duration"):
+ dlist.append(rep)
+ if dlist:
+ dlist.sort(key=lambda x: x.duration, reverse=True)
+ with open(report_files["durations"], "w") as f:
+ durations_min = 0.05 # sec
+ f.write("slowest durations\n")
+ for i, rep in enumerate(dlist):
+ if rep.duration < durations_min:
+ f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted")
+ break
+ f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n")
+
+ def summary_failures_short(tr):
+ # expecting that the reports were --tb=long (default) so we chop them off here to the last frame
+ reports = tr.getreports("failed")
+ if not reports:
+ return
+ tr.write_sep("=", "FAILURES SHORT STACK")
+ for rep in reports:
+ msg = tr._getfailureheadline(rep)
+ tr.write_sep("_", msg, red=True, bold=True)
+ # chop off the optional leading extra frames, leaving only the last one
+ longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S)
+ tr._tw.line(longrepr)
+ # note: not printing out any rep.sections to keep the report short
+
+ # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each
+ # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814
+ # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g.
+ # pytest-instafail does that)
+
+ # report failures with line/short/long styles
+ config.option.tbstyle = "auto" # full tb
+ with open(report_files["failures_long"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ # config.option.tbstyle = "short" # short tb
+ with open(report_files["failures_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ summary_failures_short(tr)
+
+ config.option.tbstyle = "line" # one line per error
+ with open(report_files["failures_line"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_failures()
+
+ with open(report_files["errors"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_errors()
+
+ with open(report_files["warnings"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_warnings() # normal warnings
+ tr.summary_warnings() # final warnings
+
+ tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary())
+
+ # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it
+ # takes > 10 minutes (as this part doesn't generate any output on the terminal).
+ # (also, it seems there is no useful information in this report, and we rarely need to read it)
+ # with open(report_files["passes"], "w") as f:
+ # tr._tw = create_terminal_writer(config, f)
+ # tr.summary_passes()
+
+ with open(report_files["summary_short"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.short_test_summary()
+
+ with open(report_files["stats"], "w") as f:
+ tr._tw = create_terminal_writer(config, f)
+ tr.summary_stats()
+
+ # restore:
+ tr._tw = orig_writer
+ tr.reportchars = orig_reportchars
+ config.option.tbstyle = orig_tbstyle
+
+
+# --- distributed testing functions --- #
+
+# adapted from https://stackoverflow.com/a/59041913/9201239
+import asyncio # noqa
+
+
+class _RunOutput:
+ def __init__(self, returncode, stdout, stderr):
+ self.returncode = returncode
+ self.stdout = stdout
+ self.stderr = stderr
+
+
+async def _read_stream(stream, callback):
+ while True:
+ line = await stream.readline()
+ if line:
+ callback(line)
+ else:
+ break
+
+
+async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput:
+ if echo:
+ print("\nRunning: ", " ".join(cmd))
+
+ p = await asyncio.create_subprocess_exec(
+ cmd[0],
+ *cmd[1:],
+ stdin=stdin,
+ stdout=asyncio.subprocess.PIPE,
+ stderr=asyncio.subprocess.PIPE,
+ env=env,
+ )
+
+ # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe
+ # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait
+ #
+ # If it starts hanging, will need to switch to the following code. The problem is that no data
+ # will be seen until it's done and if it hangs for example there will be no debug info.
+ # out, err = await p.communicate()
+ # return _RunOutput(p.returncode, out, err)
+
+ out = []
+ err = []
+
+ def tee(line, sink, pipe, label=""):
+ line = line.decode("utf-8").rstrip()
+ sink.append(line)
+ if not quiet:
+ print(label, line, file=pipe)
+
+ # XXX: the timeout doesn't seem to make any difference here
+ await asyncio.wait(
+ [
+ _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")),
+ _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")),
+ ],
+ timeout=timeout,
+ )
+ return _RunOutput(await p.wait(), out, err)
+
+
+def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:
+ loop = asyncio.get_event_loop()
+ result = loop.run_until_complete(
+ _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo)
+ )
+
+ cmd_str = " ".join(cmd)
+ if result.returncode > 0:
+ stderr = "\n".join(result.stderr)
+ raise RuntimeError(
+ f"'{cmd_str}' failed with returncode {result.returncode}\n\n"
+ f"The combined stderr from workers follows:\n{stderr}"
+ )
+
+ # check that the subprocess actually did run and produced some output, should the test rely on
+ # the remote side to do the testing
+ if not result.stdout and not result.stderr:
+ raise RuntimeError(f"'{cmd_str}' produced no output.")
+
+ return result
+
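+# Hedged usage sketch for `execute_subprocess_async` (the script path and arguments are
+# hypothetical): the command is an argv list, and `env=self.get_env()` from TestCasePlus
+# gives the child process the right PYTHONPATH.
+#
+#     cmd = ["python", "examples/some_script.py", "--output_dir", tmp_dir]
+#     execute_subprocess_async(cmd, env=self.get_env())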
+
+def pytest_xdist_worker_id():
+ """
+ Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0
+ if `-n 1` or `pytest-xdist` isn't being used.
+ """
+ worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0")
+ worker = re.sub(r"^gw", "", worker, 0, re.M)
+ return int(worker)
+
+
+def get_torch_dist_unique_port():
+ """
+ Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument.
+
+ Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same
+ port at once.
+ """
+ port = 29500
+ uniq_delta = pytest_xdist_worker_id()
+ return port + uniq_delta
+
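+# Hedged sketch of how `get_torch_dist_unique_port` is typically spliced into a distributed
+# launch command (the launcher invocation and script name are illustrative only):
+#
+#     cmd = f"torchrun --nproc_per_node=2 --master_port={get_torch_dist_unique_port()} train.py".split()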
+
+def nested_simplify(obj, decimals=3):
+ """
+ Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test
+ within tests.
+ """
+ import numpy as np
+
+ if isinstance(obj, list):
+ return [nested_simplify(item, decimals) for item in obj]
+ if isinstance(obj, tuple):
+ return tuple([nested_simplify(item, decimals) for item in obj])
+ elif isinstance(obj, np.ndarray):
+ return nested_simplify(obj.tolist(), decimals)
+ elif isinstance(obj, Mapping):
+ return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()}
+ elif isinstance(obj, (str, int, np.int64)):
+ return obj
+ elif obj is None:
+ return obj
+ elif is_torch_available() and isinstance(obj, torch.Tensor):
+ return nested_simplify(obj.tolist(), decimals)
+ elif is_tf_available() and tf.is_tensor(obj):
+ return nested_simplify(obj.numpy().tolist(), decimals)
+ elif isinstance(obj, float):
+ return round(obj, decimals)
+ elif isinstance(obj, (np.int32, np.float32, np.float16)):
+ return nested_simplify(obj.item(), decimals)
+ else:
+ raise Exception(f"Not supported: {type(obj)}")
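+
+
+# A minimal sketch of how `nested_simplify` is meant to be used in assertions (the values below
+# are made up): the rounded result can be compared with plain `==`.
+def _example_nested_simplify():
+    raw = {"scores": [0.123456, 0.987654], "label": "positive"}
+    assert nested_simplify(raw, decimals=2) == {"scores": [0.12, 0.99], "label": "positive"}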
+
+
+def check_json_file_has_correct_format(file_path):
+ with open(file_path, "r") as f:
+ lines = f.readlines()
+ if len(lines) == 1:
+ # length can only be 1 if dict is empty
+ assert lines[0] == "{}"
+ else:
+ # otherwise make sure json has correct format (at least 3 lines)
+ assert len(lines) >= 3
+ # each key is on its own line and the indent should be 2
+ assert lines[0].strip() == "{"
+ for line in lines[1:-1]:
+ left_indent = len(line) - len(line.lstrip())
+ assert left_indent == 2
+ assert lines[-1].strip() == "}"
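+
+
+# An illustrative sketch (the directory argument and file name are hypothetical): the checker
+# expects the 2-space-indented layout produced by `json.dump(..., indent=2)`.
+def _example_check_json_file_has_correct_format(tmp_dir):
+    import json
+
+    file_path = os.path.join(tmp_dir, "config.json")
+    with open(file_path, "w") as f:
+        json.dump({"hidden_size": 32, "num_layers": 2}, f, indent=2)
+    check_json_file_has_correct_format(file_path)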
+
+
+def to_2tuple(x):
+ if isinstance(x, collections.abc.Iterable):
+ return x
+ return (x, x)
+
+
+# These utils relate to ensuring the right error message is received when running scripts
+class SubprocessCallException(Exception):
+ pass
+
+
+def run_command(command: List[str], return_stdout=False):
+ """
+ Runs `command` with `subprocess.check_output` and optionally returns the `stdout`. Also properly captures the
+ error output if an error occurred while running `command`.
+ """
+ try:
+ output = subprocess.check_output(command, stderr=subprocess.STDOUT)
+ if return_stdout:
+ if hasattr(output, "decode"):
+ output = output.decode("utf-8")
+ return output
+ except subprocess.CalledProcessError as e:
+ raise SubprocessCallException(
+ f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
+ ) from e
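+
+
+# A minimal sketch (the command is a harmless placeholder, the function is never called):
+# `run_command` is typically used to launch example/CLI scripts so that their stderr surfaces
+# in the test failure message.
+def _example_run_command():
+    stdout = run_command([sys.executable, "-c", "print('ok')"], return_stdout=True)
+    assert "ok" in stdout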
+
+
+class RequestCounter:
+ """
+ Helper class that will count all requests made online.
+
+ Might not be robust if urllib3 changes its logging format but should be good enough for us.
+
+ Usage:
+ ```py
+ with RequestCounter() as counter:
+ _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
+ assert counter["GET"] == 0
+ assert counter["HEAD"] == 1
+ assert counter.total_calls == 1
+ ```
+ """
+
+ def __enter__(self):
+ self._counter = defaultdict(int)
+ self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug)
+ self.mock = self.patcher.start()
+ return self
+
+ def __exit__(self, *args, **kwargs) -> None:
+ for call in self.mock.call_args_list:
+ log = call.args[0] % call.args[1:]
+ for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"):
+ if method in log:
+ self._counter[method] += 1
+ break
+ self.patcher.stop()
+
+ def __getitem__(self, key: str) -> int:
+ return self._counter[key]
+
+ @property
+ def total_calls(self) -> int:
+ return sum(self._counter.values())
+
+
+def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
+ """
+ To decorate flaky tests. They will be retried on failures.
+
+ Args:
+ max_attempts (`int`, *optional*, defaults to 5):
+ The maximum number of attempts to retry the flaky test.
+ wait_before_retry (`float`, *optional*):
+ If provided, will wait that number of seconds before retrying the test.
+ description (`str`, *optional*):
+ A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors,
+ etc.)
+ """
+
+ def decorator(test_func_ref):
+ @functools.wraps(test_func_ref)
+ def wrapper(*args, **kwargs):
+ retry_count = 1
+
+ while retry_count < max_attempts:
+ try:
+ return test_func_ref(*args, **kwargs)
+
+ except Exception as err:
+ print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
+ if wait_before_retry is not None:
+ time.sleep(wait_before_retry)
+ retry_count += 1
+
+ return test_func_ref(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
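+
+
+# An illustrative sketch (the test body is hypothetical and never collected): `is_flaky` is
+# applied as a decorator, and the wrapped test is retried up to `max_attempts` times before the
+# final attempt's failure propagates.
+@is_flaky(max_attempts=3, wait_before_retry=0.5, description="example: depends on external timing")
+def _example_flaky_test():
+    # a real flaky test would exercise timing/network/GPU-sensitive behavior here
+    pass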
+
+
+def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
+ """
+ To run a test in a subprocess. In particular, this can avoid (GPU) memory issues.
+
+ Args:
+ test_case (`unittest.TestCase`):
+ The test that will run `target_func`.
+ target_func (`Callable`):
+ The function implementing the actual testing logic.
+ inputs (`dict`, *optional*, defaults to `None`):
+ The inputs that will be passed to `target_func` through an (input) queue.
+ timeout (`int`, *optional*, defaults to `None`):
+ The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
+ variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
+ """
+ if timeout is None:
+ timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))
+
+ start_method = "spawn"
+ ctx = multiprocessing.get_context(start_method)
+
+ input_queue = ctx.Queue(1)
+ output_queue = ctx.JoinableQueue(1)
+
+ # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+ input_queue.put(inputs, timeout=timeout)
+
+ process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
+ process.start()
+ # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
+ # the test from exiting properly.
+ try:
+ results = output_queue.get(timeout=timeout)
+ output_queue.task_done()
+ except Exception as e:
+ process.terminate()
+ test_case.fail(e)
+ process.join(timeout=timeout)
+
+ if results["error"] is not None:
+ test_case.fail(f'{results["error"]}')
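+
+
+# An illustrative sketch of the queue protocol expected by `run_test_in_subprocess` (the target
+# below is hypothetical): `target_func` receives the input queue, the output queue and the
+# timeout, reads its inputs, and must put back a dict with an "error" key (None on success).
+def _example_subprocess_target(in_queue, out_queue, timeout):
+    error = None
+    try:
+        inputs = in_queue.get(timeout=timeout)
+        _ = inputs  # the actual (e.g. GPU-heavy) checks would run here
+    except Exception as e:
+        error = str(e)
+    out_queue.put({"error": error}, timeout=timeout)
+
+
+# Inside a `unittest.TestCase`, it would then be invoked as (hypothetical call):
+# run_test_in_subprocess(self, _example_subprocess_target, inputs={"dummy": 1})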
+
+
+def run_test_using_subprocess(func):
+ """
+ To decorate a test so it runs in a subprocess via the `subprocess` module. This can avoid potential GPU memory
+ issues (GPU OOM, or a test that causes many subsequent failures with `CUDA error: device-side assert triggered`).
+ """
+ import pytest
+
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if os.getenv("_INSIDE_SUB_PROCESS", None) == "1":
+ func(*args, **kwargs)
+ else:
+ test = " ".join(os.environ.get("PYTEST_CURRENT_TEST").split(" ")[:-1])
+ try:
+ import copy
+
+ env = copy.deepcopy(os.environ)
+ env["_INSIDE_SUB_PROCESS"] = "1"
+ # This prevents the entries in `short test summary info` given by the subprocess from being truncated, so the
+ # full information can be passed to the parent pytest process.
+ # See: https://docs.pytest.org/en/stable/explanation/ci.html
+ env["CI"] = "true"
+
+ # If not a subclass of `unittest.TestCase` and `pytestconfig` is used: try to grab and use the arguments
+ if "pytestconfig" in kwargs:
+ command = list(kwargs["pytestconfig"].invocation_params.args)
+ for idx, x in enumerate(command):
+ if x in kwargs["pytestconfig"].args:
+ test = test.split("::")[1:]
+ command[idx] = "::".join([f"{func.__globals__['__file__']}"] + test)
+ command = [f"{sys.executable}", "-m", "pytest"] + command
+ command = [x for x in command if x not in ["--no-summary"]]
+ # Otherwise, simply run the test with no option at all
+ else:
+ command = [f"{sys.executable}", "-m", "pytest", f"{test}"]
+
+ subprocess.run(command, env=env, check=True, capture_output=True)
+ except subprocess.CalledProcessError as e:
+ exception_message = e.stdout.decode()
+ lines = exception_message.split("\n")
+ # Add a first line with more informative information instead of just `= test session starts =`.
+ # This makes the `short test summary info` section more useful.
+ if "= test session starts =" in lines[0]:
+ text = ""
+ for line in lines[1:]:
+ if line.startswith("FAILED "):
+ text = line[len("FAILED ") :]
+ text = "".join(text.split(" - ")[1:])
+ elif line.startswith("=") and line.endswith("=") and " failed in " in line:
+ break
+ elif len(text) > 0:
+ text += f"\n{line}"
+ text = "(subprocess) " + text
+ lines = [text] + lines
+ exception_message = "\n".join(lines)
+ pytest.fail(exception_message, pytrace=False)
+
+ return wrapper
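+
+
+# An illustrative sketch (the test is hypothetical and never collected): the decorator re-invokes
+# the current test via `python -m pytest` with `_INSIDE_SUB_PROCESS=1`, so the body only executes
+# in the re-launched child process.
+@run_test_using_subprocess
+def _example_test_using_subprocess():
+    # e.g. a forward pass that may leave the GPU in a bad state would go here
+    pass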
+
+
+"""
+The following contains utils to run the documentation tests without having to overwrite any files.
+
+The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is
+made, as printing its result would otherwise fail the corresponding line.
+
+To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules`.
+"""
+
+
+def preprocess_string(string, skip_cuda_tests):
+ """Prepare a docstring or a `.md` file to be run by doctest.
+
+ The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of
+ its docstrings. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and
+ CUDA-related code is detected (via a heuristic), this method will return an empty string so no doctest will be run
+ for `string`.
+ """
+ codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)"
+ codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string)
+ is_cuda_found = False
+ for i, codeblock in enumerate(codeblocks):
+ if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock:
+ codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock)
+ if (
+ (">>>" in codeblock or "..." in codeblock)
+ and re.search(r"cuda|to\(0\)|device=0", codeblock)
+ and skip_cuda_tests
+ ):
+ is_cuda_found = True
+ break
+
+ modified_string = ""
+ if not is_cuda_found:
+ modified_string = "".join(codeblocks)
+
+ return modified_string
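+
+
+# A minimal sketch of the skipping behavior (the snippet below is made up): with
+# `skip_cuda_tests=True` a code block that moves a model to CUDA is dropped entirely, while with
+# `skip_cuda_tests=False` the string is returned unchanged.
+def _example_preprocess_string():
+    cuda_doc = '```python\n>>> import torch\n>>> model.to("cuda")\n```'
+    assert preprocess_string(cuda_doc, skip_cuda_tests=True) == ""
+    assert preprocess_string(cuda_doc, skip_cuda_tests=False) == cuda_doc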
+
+
+class HfDocTestParser(doctest.DocTestParser):
+ """
+ Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This
+ means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also
+ added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line.
+
+ Tests involving cuda are skipped based on a naive pattern that should be updated if it is not enough.
+ """
+
+ # This regular expression is used to find doctest examples in a
+ # string. It defines three groups: `source` is the source code
+ # (including leading indentation and prompts); `indent` is the
+ # indentation of the first (PS1) line of the source code; and
+ # `want` is the expected output (including leading indentation).
+ # fmt: off
+ _EXAMPLE_RE = re.compile(r'''
+ # Source consists of a PS1 line followed by zero or more PS2 lines.
+ (?P