diff --git a/README.md b/README.md
index 09846dc61c..06a757ed90 100644
--- a/README.md
+++ b/README.md
@@ -367,6 +367,8 @@ any GPU memory savings. Please refer issue [[FSDP] FSDP with CPU offload consume
 
 ## 🤗 PEFT as a utility library
 
+### Injecting adapters directly into the model
+
 Inject trainable adapters on any `torch` model using `inject_adapter_in_model` method. Note the method will make no further change to the model.
 
 ```python
@@ -403,6 +405,35 @@ dummy_outputs = model(dummy_inputs)
 
 Learn more about the [low level API in the docs](https://huggingface.co/docs/peft/developer_guides/low_level_api).
 
+### Mixing different adapter types
+
+Usually, it is not possible to combine different adapter types in the same model, e.g. combining LoRA with AdaLoRA, LoHa, or LoKr. Using a mixed model, however, this can be achieved:
+
+```python
+from peft import PeftMixedModel
+
+model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM").eval()
+peft_model = PeftMixedModel.from_pretrained(model, <path_to_adapter0>, "adapter0")
+peft_model.load_adapter(<path_to_adapter1>, "adapter1")
+peft_model.set_adapter(["adapter0", "adapter1"])
+result = peft_model(**inputs)
+```
+
+The main intent is to load already trained adapters and use this mixed model only for inference. However, it is also possible to create a PEFT model for training by passing `mixed=True` to `get_peft_model`:
+
+```python
+from peft import get_peft_model, LoraConfig, LoKrConfig
+
+base_model = ...
+config0 = LoraConfig(...)
+config1 = LoKrConfig(...)
+peft_model = get_peft_model(base_model, config0, "adapter0", mixed=True)
+peft_model.add_adapter("adapter1", config1)
+peft_model.set_adapter(["adapter0", "adapter1"])
+for batch in dataloader:
+    ...
+```
+
 ## Contributing
 
 If you would like to contribute to PEFT, please check out our [contributing guide](https://huggingface.co/docs/peft/developer_guides/contributing).
diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml
index 88bedf31d7..25992b3966 100644
--- a/docs/source/_toctree.yml
+++ b/docs/source/_toctree.yml
@@ -34,6 +34,8 @@
     title: Working with custom models
   - local: developer_guides/low_level_api
     title: PEFT low level API
+  - local: developer_guides/mixed_models
+    title: Mixing different adapter types
   - local: developer_guides/contributing
     title: Contributing to PEFT
   - local: developer_guides/troubleshooting
diff --git a/docs/source/developer_guides/mixed_models.md b/docs/source/developer_guides/mixed_models.md
new file mode 100644
index 0000000000..93414eee04
--- /dev/null
+++ b/docs/source/developer_guides/mixed_models.md
@@ -0,0 +1,39 @@
+
+
+# Working with mixed adapter types
+
+Normally, it is not possible to mix different adapter types in 🤗 PEFT. For example, even though it is possible to create a PEFT model that has two different LoRA adapters (that can have different config options), it is not possible to combine a LoRA adapter with a LoHa adapter. However, by using a mixed model, this works as long as the adapter types are compatible.
+
+## Loading different adapter types into a PEFT model
+
+To load different adapter types into a PEFT model, proceed the same as if you were loading two adapters of the same type, but use `PeftMixedModel` instead of `PeftModel`:
+
+```py
+from peft import PeftMixedModel
+
+base_model = ...  # load the base model, e.g. from transformers
+# load first adapter, which will be called "default"
+peft_model = PeftMixedModel.from_pretrained(base_model, <path_to_adapter1>)
+peft_model.load_adapter(<path_to_adapter2>, adapter_name="other")
+peft_model.set_adapter(["default", "other"])
+```
+
+The last line is necessary if you want to activate both adapters; otherwise, only the first adapter would be active. Of course, you can add more different adapters by calling `add_adapter` repeatedly.
+
+Currently, the main purpose of mixed adapter types is to combine trained adapters for inference. Although it is technically also possible to train a mixed adapter model, this has not been tested and is not recommended.
+
+## Tips
+
+- Not all adapter types can be combined. See `peft.tuners.mixed.COMPATIBLE_TUNER_TYPES` for a list of compatible types. An error will be raised if you are trying to combine incompatible adapter types.
+- It is possible to mix multiple adapters of the same type. This can be useful to combine adapters with very different configs.
+- If you want to combine a lot of different adapters, it is most performant to add the same types of adapters consecutively. E.g., add LoRA1, LoRA2, LoHa1, LoHa2 in this order, instead of LoRA1, LoHa1, LoRA2, LoHa2. The order will make a difference for the outcome in most cases, but since no order is better a priori, it is best to choose the order that is most performant. See the example below.
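+
+## Example
+
+For illustration, here is a minimal sketch of combining several trained adapters for inference. The adapter paths and names below are placeholders, and the choice of two LoRA adapters plus one LoHa adapter is just an example:
+
+```py
+from peft import PeftMixedModel
+
+base_model = ...  # load the base model, e.g. from transformers
+
+# add adapters of the same type consecutively (see the tip above)
+peft_model = PeftMixedModel.from_pretrained(base_model, <path_to_lora_1>)  # loaded as "default"
+peft_model.load_adapter(<path_to_lora_2>, adapter_name="lora_2")
+peft_model.load_adapter(<path_to_loha_1>, adapter_name="loha_1")
+
+# activate all adapters at once
+peft_model.set_adapter(["default", "lora_2", "loha_1"])
+result = peft_model(**inputs)  # inputs: a batch prepared for the base model
+```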
diff --git a/src/peft/__init__.py b/src/peft/__init__.py
index 75ddda498c..2b1883ebd7 100644
--- a/src/peft/__init__.py
+++ b/src/peft/__init__.py
@@ -35,6 +35,7 @@
     get_peft_model,
     inject_adapter_in_model,
 )
+from .mixed_model import PeftMixedModel
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
diff --git a/src/peft/mapping.py b/src/peft/mapping.py
index 60503fa985..f34bdb51c5 100644
--- a/src/peft/mapping.py
+++ b/src/peft/mapping.py
@@ -20,6 +20,7 @@
 import torch
 
 from .config import PeftConfig
+from .mixed_model import PeftMixedModel
 from .peft_model import (
     PeftModel,
     PeftModelForCausalLM,
@@ -99,13 +100,21 @@ def get_peft_config(config_dict: Dict[str, Any]) -> PeftConfig:
     return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict)
 
 
-def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default") -> PeftModel:
+def get_peft_model(
+    model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
+) -> PeftModel | PeftMixedModel:
     """
     Returns a Peft model object from a model and a config.
 
     Args:
-        model ([`transformers.PreTrainedModel`]): Model to be wrapped.
-        peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model.
+        model ([`transformers.PreTrainedModel`]):
+            Model to be wrapped.
+        peft_config ([`PeftConfig`]):
+            Configuration object containing the parameters of the Peft model.
+        adapter_name (`str`, `optional`, defaults to `"default"`):
+            The name of the adapter to be injected; if not provided, the default adapter name ("default") is used.
+        mixed (`bool`, `optional`, defaults to `False`):
+            Whether to allow mixing different (compatible) adapter types.
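+
+    Returns:
+        [`PeftModel`] or [`PeftMixedModel`]: The wrapped model (a [`PeftMixedModel`] when `mixed=True`).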
""" model_config = getattr(model, "config", {"model_type": "custom"}) if hasattr(model_config, "to_dict"): @@ -113,8 +122,12 @@ def get_peft_model(model: PreTrainedModel, peft_config: PeftConfig, adapter_name peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) + if mixed: + return PeftMixedModel(model, peft_config, adapter_name=adapter_name) + if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys() and not peft_config.is_prompt_learning: return PeftModel(model, peft_config, adapter_name=adapter_name) + if peft_config.is_prompt_learning: peft_config = _prepare_prompt_learning_config(peft_config, model_config) return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config, adapter_name=adapter_name) diff --git a/src/peft/mixed_model.py b/src/peft/mixed_model.py new file mode 100644 index 0000000000..55892851e9 --- /dev/null +++ b/src/peft/mixed_model.py @@ -0,0 +1,394 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import os +from contextlib import contextmanager +from typing import Any, Optional, Union + +import torch +from accelerate.hooks import remove_hook_from_submodules +from torch import nn +from transformers.utils import PushToHubMixin + +from peft.tuners.mixed import COMPATIBLE_TUNER_TYPES + +from .config import PeftConfig +from .peft_model import PeftModel +from .tuners import ( + AdaLoraModel, + IA3Model, + LoHaModel, + LoKrModel, + LoraModel, + MixedModel, +) +from .utils import PeftType, _set_adapter, _set_trainable + + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.ADALORA: AdaLoraModel, + PeftType.IA3: IA3Model, +} + + +def _prepare_model_for_gradient_checkpointing(model: nn.Module) -> None: + r""" + Prepares the model for gradient checkpointing if necessary + """ + # Note: same as PeftModel._prepare_model_for_gradient_checkpointing + if not getattr(model, "is_gradient_checkpointing", True): + return model + + if not ( + getattr(model, "is_loaded_in_8bit", False) + or getattr(model, "is_loaded_in_4bit", False) + or getattr(model, "is_quantized", False) + ): + if hasattr(model, "enable_input_require_grads"): + model.enable_input_require_grads() + elif hasattr(model, "get_input_embeddings"): + + def make_inputs_require_grad(module, input, output): + output.requires_grad_(True) + + model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) + + +def _check_config_compatible(peft_config: PeftConfig) -> None: + if peft_config.peft_type not in COMPATIBLE_TUNER_TYPES: + raise ValueError( + f"The provided `peft_type` '{peft_config.peft_type.value}' is not compatible with the `PeftMixedModel`. " + f"Compatible types are: {COMPATIBLE_TUNER_TYPES}" + ) + + +class PeftMixedModel(PushToHubMixin, torch.nn.Module): + """ + Peft model for mixing different types of adapters. 
+ + This class currently does not support saving and loading. Instead, it is assumed that the adapters are already + trained and loading the model requires a script to be run each time. + + Currently, the main purpose of mixed adapter types is to combine trained adapters for inference. Although it is + technically possible to train a mixed adapter model, this has not been tested and is not recommended. + + Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument + `mixed=True`. + + Below is an example that shows how to load a mixed model with two different types of adapters. + + ```py + >>> from peft import get_peft_model + + >>> base_model = ... # load the base model, e.g. from transformers + >>> peft_model = PeftMixedModel.from_pretrained(base_model, path_to_adapter1, "adapter1").eval() + >>> peft_model.load_adapter(path_to_adapter2, "adapter2") + >>> peft_model.set_adapter(["adapter1", "adapter2"]) # activate both adapters + >>> peft_model(data) # forward pass using both adapters + ``` + + Tips: + + - Not all adapter types can be combined. See `peft.tuners.mixed.COMPATIBLE_TUNER_TYPES` for a list of compatible + types. An error will be raised if you are trying to combine incompatible adapter types. + - It is possible to mix multiple adapters of the same type. This can be useful to combine adapters with very + different configs. + - If you want to combine a lot of different adapters, it is most performant to add the same types of adapters + consecutively. E.g., add LoRA1, LoRA2, LoHa1, LoHa2 in this order, instead of LoRA1, LoHa1, LoRA2, LoHa2. As long + as the adapters are commutative, the order does not matter for the final result. + + Args: + model (`torch.nn.Module`): + The model to be tuned. + config (`PeftConfig`): + The config of the model to be tuned. The adapter type must be compatible. + adapter_name (`str`, `optional`, defaults to `"default"`): + The name of the first adapter. + """ + + def __init__(self, model: nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__() + _check_config_compatible(peft_config) + _prepare_model_for_gradient_checkpointing(model) + self.modules_to_save = None + self.base_model = MixedModel(model, {adapter_name: peft_config}, adapter_name) + self.set_modules_to_save(peft_config, adapter_name) + + self.config = getattr(model, "config", {"model_type": "custom"}) + + # the `pretraining_tp` is set for some models to simulate Tensor Parallelism during inference to avoid + # numerical differences, https://github.com/pytorch/pytorch/issues/76232 - to avoid any unexpected + # behavior we disable that in this line. + if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"): + self.base_model.config.pretraining_tp = 1 + + @property + def peft_config(self) -> dict[str, PeftConfig]: + return self.base_model.peft_config + + @property + def active_adapter(self) -> str: + return self.base_model.active_adapter + + @property + def active_adapters(self) -> list[str]: + return self.base_model.active_adapters + + def get_nb_trainable_parameters(self): + r""" + Returns the number of trainable parameters and number of all parameters in the model. 
+ """ + # note: same as PeftModel.get_nb_trainable_parameters + trainable_params = 0 + all_param = 0 + for _, param in self.named_parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes + # one needs to multiply the number of parameters by 2 to get + # the correct number of parameters + if param.__class__.__name__ == "Params4bit": + num_params = num_params * 2 + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + def print_trainable_parameters(self): + """ + Prints the number of trainable parameters in the model. + """ + # note: same as PeftModel.print_trainable_parameters + trainable_params, all_param = self.get_nb_trainable_parameters() + + print( + f"trainable params: {trainable_params:,d} || " + f"all params: {all_param:,d} || " + f"trainable%: {100 * trainable_params / all_param:.4f}" + ) + + def forward(self, *args: Any, **kwargs: Any): + """ + Forward pass of the model. + """ + return self.base_model(*args, **kwargs) + + def generate(self, *args: Any, **kwargs: Any): + """ + Generate output. + """ + return self.base_model.generate(*args, **kwargs) + + @contextmanager + def disable_adapter(self): + """ + Disables the adapter module. + """ + try: + self.base_model.disable_adapter_layers() + yield + finally: + self.base_model.enable_adapter_layers() + + def add_adapter(self, adapter_name: str, peft_config: PeftConfig): + _check_config_compatible(peft_config) + + try: + self.peft_config[adapter_name] = peft_config + self.base_model.inject_adapter(self, adapter_name) + except Exception: # somthing went wrong, roll back + if adapter_name in self.peft_config: + del self.peft_config[adapter_name] + raise + + self.set_modules_to_save(peft_config, adapter_name) + + def set_modules_to_save(self, peft_config: PeftConfig, adapter_name: str) -> None: + if (modules_to_save := getattr(peft_config, "modules_to_save", None)) is None: + return + + if self.modules_to_save is None: + self.modules_to_save = set(modules_to_save) + else: + self.modules_to_save.update(modules_to_save) + _set_trainable(self, adapter_name) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + """ + Sets the active adapter(s) for the model. + + Note that the order in which the adapters are applied during the forward pass may not be the same as the order + in which they are passed to this function. Instead, the order during the forward pass is determined by the + order in which the adapters were loaded into the model. The active adapters only determine which adapters are + active during the forward pass, but not the order in which they are applied. + + Args: + adapter_name (`str` or `List[str]`): + The name of the adapter(s) to be activated. 
+ """ + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.set_adapter(adapter_name) + _set_adapter(self, adapter_name) + + def delete_adapter(self, adapter_name: Union[str, list[str]]) -> None: + if isinstance(adapter_name, str): + adapter_name = [adapter_name] + + mismatched = set(adapter_name) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + self.base_model.delete_adapter(adapter_name) + + def merge_and_unload(self, *args: Any, **kwargs: Any): + r""" + This method merges the adapter layers into the base model. This is needed if someone wants to use the base + model as a standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self.base_model.merge_and_unload(*args, **kwargs) + + def unload(self, *args: Any, **kwargs: Any): + """ + Gets back the base model by removing all the adapter modules without merging. This gives back the original base + model. + """ + return self.base_model.unload(*args, **kwargs) + + @classmethod + def _split_kwargs(cls, kwargs: dict[str, Any]): + return PeftModel._split_kwargs(kwargs) + + def load_adapter(self, model_id: str, adapter_name: str, *args: Any, **kwargs: Any): + output = PeftModel.load_adapter(self, model_id, adapter_name, *args, **kwargs) + # TODO: not quite clear why this is necessary but tests fail without it + self.set_adapter(self.active_adapters) + return output + + def create_or_update_model_card(self, output_dir: str): + raise NotImplementedError(f"Model card creation is not supported for {self.__class__.__name__} (yet).") + + def save_pretrained( + self, + save_directory: str, + safe_serialization: bool = False, + selected_adapters: Optional[list[str]] = None, + **kwargs: Any, + ): + raise NotImplementedError(f"Saving is not supported for {self.__class__.__name__} (yet).") + + @classmethod + def from_pretrained( + cls, + model: nn.Module, + model_id: str | os.PathLike, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + **kwargs: Any, + ): + r""" + Instantiate a PEFT mixed model from a pretrained model and loaded PEFT weights. + + Note that the passed `model` may be modified inplace. + + Args: + model (`nn.Module`): + The model to be adapted. + model_id (`str` or `os.PathLike`): + The name of the PEFT configuration to use. Can be either: + - A string, the `model id` of a PEFT configuration hosted inside a model repo on the Hugging Face + Hub. + - A path to a directory containing a PEFT configuration file saved using the `save_pretrained` + method (`./my_peft_config_directory/`). + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to be loaded. This is useful for loading multiple adapters. + is_trainable (`bool`, *optional*, defaults to `False`): + Whether the adapter should be trainable or not. 
If `False`, the adapter will be frozen and use for + inference + config ([`~peft.PeftConfig`], *optional*): + The configuration object to use instead of an automatically loaded configuation. This configuration + object is mutually exclusive with `model_id` and `kwargs`. This is useful when configuration is already + loaded before calling `from_pretrained`. + kwargs: (`optional`): + Additional keyword arguments passed along to the specific PEFT configuration class. + """ + # note: adapted from PeftModel.from_pretrained + from .mapping import PEFT_TYPE_TO_CONFIG_MAPPING + + # load the config + if config is None: + config = PEFT_TYPE_TO_CONFIG_MAPPING[ + PeftConfig._get_peft_type( + model_id, + subfolder=kwargs.get("subfolder", None), + revision=kwargs.get("revision", None), + cache_dir=kwargs.get("cache_dir", None), + use_auth_token=kwargs.get("use_auth_token", None), + ) + ].from_pretrained(model_id, **kwargs) + elif isinstance(config, PeftConfig): + config.inference_mode = not is_trainable + else: + raise ValueError(f"The input config must be a PeftConfig, got {config.__class__}") + + # note: this is different from PeftModel.from_pretrained + if config.peft_type not in PEFT_TYPE_TO_MODEL_MAPPING: + raise ValueError(f"Adapter of type {config.peft_type} is not supported for mixed models.") + + if (getattr(model, "hf_device_map", None) is not None) and len( + set(model.hf_device_map.values()).intersection({"cpu", "disk"}) + ) > 0: + remove_hook_from_submodules(model) + + if config.is_prompt_learning and is_trainable: + # note: should not be possible to reach, but just in case + raise ValueError("Cannot set a prompt learning adapter to trainable when loading pretrained adapter.") + else: + config.inference_mode = not is_trainable + + # note: this is different from PeftModel.from_pretrained, we always return a PeftMixedModel + model = cls(model, config, adapter_name) + model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs) + return model diff --git a/src/peft/tuners/__init__.py b/src/peft/tuners/__init__.py index f5f665dd99..9211cfb4f8 100644 --- a/src/peft/tuners/__init__.py +++ b/src/peft/tuners/__init__.py @@ -28,3 +28,4 @@ from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit from .multitask_prompt_tuning import MultitaskPromptEmbedding, MultitaskPromptTuningConfig, MultitaskPromptTuningInit from .oft import OFTConfig, OFTModel +from .mixed import MixedModel diff --git a/src/peft/tuners/mixed/__init__.py b/src/peft/tuners/mixed/__init__.py new file mode 100644 index 0000000000..f21cff3b29 --- /dev/null +++ b/src/peft/tuners/mixed/__init__.py @@ -0,0 +1,19 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
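+
+# The "mixed" tuner makes it possible to combine adapters of different, compatible PEFT
+# types (LoRA, LoHa, LoKr, AdaLoRA) in a single model; see `COMPATIBLE_TUNER_TYPES` in model.py.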
+
+from .model import COMPATIBLE_TUNER_TYPES, MixedModel
+
+
+__all__ = ["COMPATIBLE_TUNER_TYPES", "MixedModel"]
diff --git a/src/peft/tuners/mixed/model.py b/src/peft/tuners/mixed/model.py
new file mode 100644
index 0000000000..5e7acf1cfe
--- /dev/null
+++ b/src/peft/tuners/mixed/model.py
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2023-present the HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import warnings
+from typing import Any, Optional, Union
+
+from torch import nn
+from tqdm import tqdm
+
+from peft.tuners import adalora, loha, lokr, lora
+from peft.tuners.tuners_utils import BaseTuner, BaseTunerLayer, check_target_module_exists
+from peft.utils import (
+    TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING,
+    ModulesToSaveWrapper,
+    PeftType,
+    _get_submodules,
+    get_auto_gptq_quant_linear,
+)
+
+
+# Collection of constants used for all tuners
+COMPATIBLE_TUNER_TYPES = (PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.ADALORA)
+PREFIXES = [lora.LoraModel.prefix, lokr.LoKrModel.prefix, loha.LoHaModel.prefix]
+Configs = Union[lora.LoraConfig, loha.LoHaConfig, lokr.LoKrConfig, adalora.AdaLoraConfig]
+Layers = (lora.layer.LoraLayer, loha.layer.LoHaLayer, lokr.layer.LoKrLayer, adalora.layer.AdaLoraLayer)
+
+
+class MixedModel(BaseTuner):
+    """
+    A class that allows mixing different types of adapters in a single model.
+
+    Note: This class should usually not be initialized directly. Instead, use `get_peft_model` with the argument
+    `mixed=True`.
+
+    Args:
+        model (:obj:`nn.Module`):
+            The model to be tuned.
+        config (:obj:`PeftConfig`):
+            The config of the model to be tuned. The adapter type must be compatible.
+        adapter_name (:obj:`str`):
+            The name of the first adapter.
+    """
+
+    def __init__(self, model: nn.Module, config: Configs, adapter_name: str) -> None:
+        super().__init__(model, config, adapter_name)
+
+    def _check_new_adapter_config(self, config: Configs) -> None:
+        """
+        A helper method to check the config when a new adapter is being added.
+
+        Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
+
+        """
+        if not isinstance(config, Configs.__args__):
+            raise ValueError(
+                f"{self.__class__.__name__} only supports {COMPATIBLE_TUNER_TYPES} configs, but got {type(config)}."
+            )
+
+        biases = (getattr(config, "bias", None) for config in self.peft_config.values())
+        biases = [bias for bias in biases if bias not in (None, "none")]
+        if len(biases) > 1:
+            raise ValueError(
+                f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
+                "set bias to 'none' for all adapters."
+ ) + + @staticmethod + def _check_target_module_exists(config: Configs, key: str): + return check_target_module_exists(config, key) + + def _create_and_replace( + self, + config: Configs, + *args: Any, + **kwargs: Any, + ) -> None: + if isinstance(config, adalora.AdaLoraConfig): + adalora.AdaLoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lora.LoraConfig): + lora.LoraModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, loha.LoHaConfig): + loha.LoHaModel._create_and_replace(self, config, *args, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + lokr.LoKrModel._create_and_replace(self, config, *args, **kwargs) + else: + raise ValueError(f"Unsupported config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + + def _replace_module(self, parent, child_name, new_module, child) -> None: + setattr(parent, child_name, new_module) + # It's not necessary to set requires_grad here, as that is handled by + # _mark_only_adapters_as_trainable + + # child layer wraps the original module, unpack it + if hasattr(child, "base_layer"): + child = child.get_base_layer() + elif hasattr(child, "quant_linear_module"): + # TODO maybe not necessary to have special treatment? + child = child.quant_linear_module + + if not hasattr(new_module, "base_layer"): + new_module.weight = child.weight + if hasattr(child, "bias"): + new_module.bias = child.bias + + if getattr(child, "state", None) is not None: + if hasattr(new_module, "base_layer"): + new_module.base_layer.state = child.state + else: + new_module.state = child.state + new_module.to(child.weight.device) + + # dispatch to correct device + for name, module in new_module.named_modules(): + if any(prefix in name for prefix in PREFIXES): + module.to(child.weight.device) + if "ranknum" in name: + module.to(child.weight.device) + + def _mark_only_adapters_as_trainable(self) -> None: + for n, p in self.model.named_parameters(): + if not any(prefix in n for prefix in PREFIXES): + p.requires_grad = False + + for active_adapter in self.active_adapters: + bias = getattr(self.peft_config[active_adapter], "bias", "none") + if bias == "none": + continue + + if bias == "all": + for n, p in self.model.named_parameters(): + if "bias" in n: + p.requires_grad = True + elif bias == "lora_only": + # TODO: check if this is needed for other supported types + for m in self.model.modules(): + if isinstance(m, Layers) and hasattr(m, "bias") and m.bias is not None: + m.bias.requires_grad = True + else: + raise ValueError(f"Requested bias: {bias}, is not implemented.") + + @staticmethod + def _create_new_module(config, adapter_name, target, **kwargs): + gptq_quantization_config = kwargs.get("gptq_quantization_config", None) + AutoGPTQQuantLinear = get_auto_gptq_quant_linear(gptq_quantization_config) + if (gptq_quantization_config is not None) or (AutoGPTQQuantLinear is not None): + raise ValueError(f"GPTQ quantization not supported for {config.peft_type.value} (yet).") + + loaded_in_8bit = kwargs.pop("loaded_in_8bit", False) + loaded_in_4bit = kwargs.pop("loaded_in_4bit", False) + if loaded_in_8bit or loaded_in_4bit: + raise ValueError(f"8bit and 4bit quantization not supported for {config.peft_type.value} (yet).") + + if isinstance(config, adalora.AdaLoraConfig): + new_module = adalora.AdaLoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, lora.LoraConfig): + new_module = lora.LoraModel._create_new_module(config, adapter_name, target, **kwargs) + elif 
isinstance(config, loha.LoHaConfig): + new_module = loha.LoHaModel._create_new_module(config, adapter_name, target, **kwargs) + elif isinstance(config, lokr.LoKrConfig): + new_module = lokr.LoKrModel._create_new_module(config, adapter_name, target, **kwargs) + else: + raise ValueError(f"Unknown config type {type(config)}, should be one of {COMPATIBLE_TUNER_TYPES}.") + return new_module + + def _set_adapter_layers(self, enabled=True): + for module in self.model.modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + module.enable_adapters(enabled) + + def enable_adapter_layers(self): + self._set_adapter_layers(enabled=True) + + def disable_adapter_layers(self): + for active_adapter in self.active_adapters: + val = getattr(self.peft_config[active_adapter], "bias", "none") + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) + self._set_adapter_layers(enabled=False) + + def set_adapter(self, adapter_name: Union[str, list[str]]) -> None: + for module in self.model.modules(): + if isinstance(module, Layers): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + self.active_adapter = adapter_name + + @staticmethod + def _prepare_adapter_config(peft_config, model_config): + if peft_config.target_modules is None: + if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: + raise ValueError("Please specify `target_modules` in `peft_config`") + + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) + return peft_config + + def _unload_and_optionally_merge( + self, + merge=True, + progressbar: bool = False, + safe_merge: bool = False, + adapter_names: Optional[list[str]] = None, + ): + if merge: + if getattr(self.model, "quantization_method", None) == "gptq": + raise ValueError("Cannot merge layers when the model is gptq quantized") + + def merge_recursively(module): + # helper function to recursively merge the base_layer of the target + path = [] + layer = module + while hasattr(layer, "base_layer"): + path.append(layer) + layer = layer.base_layer + for layer_before, layer_after in zip(path[:-1], path[1:]): + layer_after.merge(safe_merge=safe_merge, adapter_names=adapter_names) + layer_before.base_layer = layer_after.base_layer + module.merge(safe_merge=safe_merge, adapter_names=adapter_names) + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + desc = "Unloading " + ("and merging " if merge else "") + "model" + + for key in tqdm(key_list, disable=not progressbar, desc=desc): + try: + parent, target, target_name = _get_submodules(self.model, key) + except AttributeError: + continue + + if hasattr(target, "base_layer"): + if merge: + merge_recursively(target) + self._replace_module(parent, target_name, target.get_base_layer(), target) + elif isinstance(target, ModulesToSaveWrapper): + # save any additional trainable modules part of `modules_to_save` + setattr(parent, target_name, target.modules_to_save[target.active_adapter]) + + return self.model + + def add_weighted_adapter(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError(f"Weighted adapters are not supported for {self.__class__.__name__} (yet).") + + def delete_adapter(self, 
adapter_name: Union[str, list[str]]) -> None: + """ + Deletes an existing adapter. + + Args: + adapter_name (Union[str, list[str]]): Name of the adapter(s) to delete. + """ + if isinstance(adapter_name, str): + adapter_names = [adapter_name] + else: + adapter_names = adapter_name + + mismatched = set(adapter_names) - set(self.peft_config.keys()) + if mismatched: + raise ValueError( + f"Adapter(s) {sorted(mismatched)} not found, available adapters: {sorted(self.peft_config.keys())}" + ) + + for adapter_name in adapter_names: + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if not any(prefix in key for prefix in PREFIXES)] + new_adapter = None + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, BaseTunerLayer): + target.delete_adapter(adapter_name) + if new_adapter is None: + new_adapter = target.active_adapters[:] + + self.active_adapter = new_adapter or [] + + def merge_and_unload( + self, progressbar: bool = False, safe_merge: bool = False, adapter_names: Optional[list[str]] = None + ) -> nn.Module: + r""" + This method merges the layers into the base model. This is needed if someone wants to use the base model as a + standalone model. + + Args: + progressbar (`bool`): + whether to show a progressbar indicating the unload and merge process + safe_merge (`bool`): + whether to activate the safe merging check to check if there is any potential Nan in the adapter + weights + adapter_names (`List[str]`, *optional*): + The list of adapter names that should be merged. If None, all active adapters will be merged. Defaults + to `None`. + """ + return self._unload_and_optionally_merge( + progressbar=progressbar, safe_merge=safe_merge, adapter_names=adapter_names + ) + + def unload(self) -> nn.Module: + """ + Gets back the base model by removing all the lora modules without merging. This gives back the original base + model. + """ + return self._unload_and_optionally_merge(merge=False) + + def generate(self, *args: Any, **kwargs: Any): + return self.model.generate(*args, **kwargs) diff --git a/tests/test_mixed.py b/tests/test_mixed.py new file mode 100644 index 0000000000..bd8f455e99 --- /dev/null +++ b/tests/test_mixed.py @@ -0,0 +1,794 @@ +# coding=utf-8 +# Copyright 2023-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import copy +import itertools +import os +import re +import tempfile +import unittest + +import torch +from parameterized import parameterized +from torch import nn +from transformers import AutoModelForCausalLM + +from peft import AdaLoraConfig, LoHaConfig, LoKrConfig, LoraConfig, PeftMixedModel, PrefixTuningConfig, get_peft_model +from peft.tuners.tuners_utils import BaseTunerLayer +from peft.utils import infer_device + + +class SimpleNet(nn.Module): + def __init__(self, bias=True): + super().__init__() + self.lin0 = nn.Linear(10, 20, bias=bias) + self.relu = nn.ReLU() + self.lin1 = nn.Linear(20, 2, bias=bias) + + def forward(self, X): + X = X.float() + X = self.lin0(X) + X = self.relu(X) + X = self.lin1(X) + return X + + +def _param_name_func(testcase_func, param_num, params): + # for parameterized tests in TextMixedAdapterTypes + config0, config1 = params[0] + name0 = config0.__class__.__name__ + name1 = config1.__class__.__name__ + if name0 != name1: + return f"{testcase_func.__name__}_{param_num}_{name0}_{name1}" + return f"{testcase_func.__name__}_{param_num}_{name0}_x2" + + +class TestMixedAdapterTypes(unittest.TestCase): + torch_device = infer_device() + + def _get_model(self, model_cls, peft_config=None, adapter_name=None, seed=0, mixed=True): + torch.manual_seed(0) # always use seed 0 for base model, seed for adapters may differ + base_model = model_cls().eval().to(self.torch_device) + if peft_config is None: + return base_model + + torch.manual_seed(seed) + assert adapter_name is not None + peft_model = get_peft_model(base_model, peft_config, adapter_name=adapter_name, mixed=mixed) + return peft_model.eval().to(self.torch_device) + + def _check_mixed_outputs(self, model_cls, config0, config1, input, *, is_commutative): + # This test checks different combinations of adapter0, adapter1, or combinations of the two, and whether + # outputs are the same/different, depending on context. If we pass is_commutative=True, it means that the order + # of adapters does not matter, and we expect the same output regardless of the order in which adapters are + # applied. + # We have to very careful with resetting the random seed each time it is used, otherwise the adapters may be + # initialized with different values, and the test will fail. 
+ + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # base model + base_model = self._get_model(model_cls) + output_base = base_model(input) + self.assertTrue(torch.isfinite(output_base).all()) + + # adapter 0 + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + output_config0 = peft_model_0(input) + + self.assertTrue(torch.isfinite(output_config0).all()) + self.assertFalse(torch.allclose(output_base, output_config0, atol=atol, rtol=rtol)) + + # adapter 1 + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + output_config1 = peft_model_1(input) + + self.assertTrue(torch.isfinite(output_config1).all()) + self.assertFalse(torch.allclose(output_base, output_config1, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config0, output_config1, atol=atol, rtol=rtol)) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_01.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_01.active_adapters, ["adapter0", "adapter1"]) + self.assertTrue(torch.isfinite(output_mixed_01).all()) + self.assertFalse(torch.allclose(output_config0, output_mixed_01, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config1, output_mixed_01, atol=atol, rtol=rtol)) + if is_commutative: + delta0 = output_config0 - output_base + delta1 = output_config1 - output_base + delta_mixed_01 = output_mixed_01 - output_base + self.assertTrue(torch.allclose(delta0 + delta1, delta_mixed_01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_10.active_adapters, ["adapter1", "adapter0"]) + self.assertTrue(torch.isfinite(output_mixed_10).all()) + self.assertFalse(torch.allclose(output_config0, output_mixed_10, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_config1, output_mixed_10, atol=atol, rtol=rtol)) + if is_commutative: + self.assertTrue(torch.allclose(output_mixed_01, output_mixed_10, atol=atol, rtol=rtol)) + + # turn around the order of the adapters of the 0 + 1 mixed model, should behave like the 0 + 1 mixed model + peft_model_10.set_adapter(["adapter0", "adapter1"]) + output_mixed_reversed = peft_model_10(input) + + # check the number of tuner layer types + tuner_layers = [mod for mod in peft_model_10.modules() if isinstance(mod, BaseTunerLayer)] + tuner_types = {type(tuner_layer) for tuner_layer in tuner_layers} + if type(config0) == type(config1): + self.assertEqual(len(tuner_types), 1) + else: + 
self.assertEqual(len(tuner_types), 2) + + self.assertEqual(peft_model_10.active_adapters, ["adapter0", "adapter1"]) + self.assertTrue(torch.isfinite(output_mixed_reversed).all()) + self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_01, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_mixed_reversed, output_config0, atol=atol, rtol=rtol)) + self.assertFalse(torch.allclose(output_mixed_reversed, output_config1, atol=atol, rtol=rtol)) + if is_commutative: + self.assertTrue(torch.allclose(output_mixed_reversed, output_mixed_10, atol=atol, rtol=rtol)) + + def _check_merging(self, model_cls, config0, config1, input): + # Ensure that when merging mixed adapters, the result is the same as when applying the adapters separately. + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + model_merged_01 = peft_model_01.merge_and_unload() + output_merged_01 = model_merged_01(input) + self.assertTrue(torch.allclose(output_mixed_01, output_merged_01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + + model_merged_10 = peft_model_10.merge_and_unload() + output_merged_10 = model_merged_10(input) + self.assertTrue(torch.allclose(output_mixed_10, output_merged_10, atol=atol, rtol=rtol)) + + def _check_unload(self, model_cls, config0, config1, input): + # Ensure that we can unload the base model without merging + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + base_model = self._get_model(model_cls) + output_base = base_model(input) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed = peft_model_01(input) + + # unload + model_unloaded = peft_model_01.unload() + output_unloaded = model_unloaded(input) + + self.assertFalse(torch.allclose(output_mixed, output_unloaded, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol)) + + def _check_disable(self, model_cls, config0, config1, input): + # Ensure that we can disable adapters + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + # base model + base_model = self._get_model(model_cls) + output_base = base_model(input) + + # adapter 0 + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + output_config0 = peft_model_0(input) + with peft_model_0.disable_adapter(): + output_disabled0 = peft_model_0(input) + + self.assertFalse(torch.allclose(output_base, output_config0, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled0, atol=atol, rtol=rtol)) + + # adapter 1 + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + output_config1 = peft_model_1(input) + with peft_model_1.disable_adapter(): + output_disabled1 = peft_model_1(input) + + self.assertFalse(torch.allclose(output_base, output_config1, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled1, atol=atol, rtol=rtol)) + + 
# adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + with peft_model_01.disable_adapter(): + output_disabled01 = peft_model_01(input) + + self.assertFalse(torch.allclose(output_base, output_mixed_01, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + peft_model_10 = self._get_model(model_cls, config1, "adapter1", seed=seed1) + torch.manual_seed(seed0) + peft_model_10.add_adapter("adapter0", config0) + peft_model_10.set_adapter(["adapter1", "adapter0"]) + output_mixed_10 = peft_model_10(input) + with peft_model_10.disable_adapter(): + output_disabled10 = peft_model_10(input) + + self.assertFalse(torch.allclose(output_base, output_mixed_10, atol=atol, rtol=rtol)) + self.assertTrue(torch.allclose(output_base, output_disabled10, atol=atol, rtol=rtol)) + + def _check_loading(self, model_cls, config0, config1, input): + # Check that we can load two adapters into the same model + # Note that we save the adapters using a normal PeftModel because PeftMixModel doesn't support saving yet + atol = 1e-5 + rtol = 1e-5 + seed0 = 0 + seed1 = 1 + + with tempfile.TemporaryDirectory() as tmp_dirname: + # SAVING + # adapter 0: note that we set mixed=False because mixed models don't support saving (yet) + peft_model_0 = self._get_model(model_cls, config0, "adapter0", seed=seed0, mixed=False) + output_config0 = peft_model_0(input) + peft_model_0.save_pretrained(os.path.join(tmp_dirname, "adapter0")) + + # adapter 1: note that we set mixed=False because mixed models don't support saving (yet) + peft_model_1 = self._get_model(model_cls, config1, "adapter1", seed=seed1, mixed=False) + output_config1 = peft_model_1(input) + peft_model_1.save_pretrained(os.path.join(tmp_dirname, "adapter1")) + + # adapter 0 + 1 + peft_model_01 = self._get_model(model_cls, config0, "adapter0", seed=seed0) + torch.manual_seed(seed1) + peft_model_01.add_adapter("adapter1", config1) + peft_model_01.set_adapter(["adapter0", "adapter1"]) + output_mixed_01 = peft_model_01(input) + + # LOADING + # adapter 0 + base_model = self._get_model(model_cls) + # Notes: + # Path is tmp_dirname/adapter0/adapter0 because non-default adapters are saved in a subfolder. + # As a sanity check, we should set a completely different seed here. That way, we ensure that the the + # weights are not just randomly initialized exactly to the same values as before. 
+ torch.manual_seed(123456) + peft_model_loaded0 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" + ) + output_loaded0 = peft_model_loaded0(input) + self.assertTrue(torch.allclose(output_config0, output_loaded0, atol=atol, rtol=rtol)) + + # adapter 1 + base_model = self._get_model(model_cls) + torch.manual_seed(654321) # setting a completely different seed here should not affect the result + peft_model_loaded1 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" + ) + output_loaded1 = peft_model_loaded1(input) + self.assertTrue(torch.allclose(output_config1, output_loaded1, atol=atol, rtol=rtol)) + + # adapter 0 + 1 + base_model = self._get_model(model_cls) + torch.manual_seed(97531) # setting a completely different seed here should not affect the result + peft_model_loaded_01 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0" + ) + peft_model_loaded_01.load_adapter(os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1") + # at this point, "config0" should still be active + self.assertEqual(peft_model_loaded_01.active_adapters, ["adapter0"]) + output_loaded01_0 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_config0, output_loaded01_0, atol=atol, rtol=rtol)) + # activate adapter1 + peft_model_loaded_01.set_adapter(["adapter1"]) + self.assertEqual(peft_model_loaded_01.active_adapters, ["adapter1"]) + output_loaded01_1 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_config1, output_loaded01_1, atol=atol, rtol=rtol)) + # activate both adapters + peft_model_loaded_01.set_adapter(["adapter0", "adapter1"]) + output_loaded01 = peft_model_loaded_01(input) + self.assertTrue(torch.allclose(output_mixed_01, output_loaded01, atol=atol, rtol=rtol)) + + # adapter 1 + 0 + base_model = self._get_model(model_cls) + torch.manual_seed(445566) # setting a completely different seed here should not affect the result + peft_model_loaded_10 = PeftMixedModel.from_pretrained( + base_model, os.path.join(tmp_dirname, "adapter1", "adapter1"), "adapter1" + ) + peft_model_loaded_10.load_adapter(os.path.join(tmp_dirname, "adapter0", "adapter0"), "adapter0") + # at this point, "config0" should still be active + self.assertEqual(peft_model_loaded_10.active_adapters, ["adapter1"]) + output_loaded10_1 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_config1, output_loaded10_1, atol=atol, rtol=rtol)) + # activate adapter1 + peft_model_loaded_10.set_adapter(["adapter0"]) + self.assertEqual(peft_model_loaded_10.active_adapters, ["adapter0"]) + output_loaded10_0 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_config0, output_loaded10_0, atol=atol, rtol=rtol)) + # activate both adapters + peft_model_loaded_10.set_adapter(["adapter1", "adapter0"]) + output_loaded10 = peft_model_loaded_10(input) + self.assertTrue(torch.allclose(output_mixed_01, output_loaded10, atol=atol, rtol=rtol)) + + @parameterized.expand( + itertools.combinations( + [ + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoKrConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + ], + r=2, + ), + name_func=_param_name_func, + ) + def test_target_first_layer(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + 
self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + @parameterized.expand( + itertools.combinations( + [ + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ], + r=2, + ), + name_func=_param_name_func, + ) + def test_target_last_layer(self, config0, config1): + # We are targeting the last layer of the SimpleNet. Therefore, since the adapters only add their activations + # to the output, the results should be commutative. This would *not* work if the adapters do something more + # complex or if we target an earlier layer, because of the non-linearity would destroy the commutativity. + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + @parameterized.expand( + [ + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoraConfig(target_modules=["lin0"], init_lora_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoHaConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoHaConfig(target_modules=["lin1"], init_weights=False), + ), + ( + LoKrConfig(target_modules=["lin0"], init_weights=False), + AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False), + ), + ( + AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False), + LoKrConfig(target_modules=["lin1"], init_weights=False), + ), + ], + name_func=_param_name_func, + ) + def test_target_different_layers(self, config0, config1): + input = torch.arange(90).reshape(9, 10).to(self.torch_device) + self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False) + self._check_merging(SimpleNet, config0, config1, input) + self._check_unload(SimpleNet, config0, config1, input) + self._check_disable(SimpleNet, config1, config0, input) + self._check_loading(SimpleNet, config0, config1, input) + + 
+    @parameterized.expand(
+        [
+            (
+                LoraConfig(target_modules=["lin1"], init_lora_weights=False),
+                LoraConfig(target_modules=["lin1"], init_lora_weights=False),
+            ),
+            (
+                LoHaConfig(target_modules=["lin1"], init_weights=False),
+                LoHaConfig(target_modules=["lin1"], init_weights=False),
+            ),
+            (
+                LoKrConfig(target_modules=["lin1"], init_weights=False),
+                LoKrConfig(target_modules=["lin1"], init_weights=False),
+            ),
+            (
+                AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False),
+                AdaLoraConfig(target_modules=["lin1"], init_lora_weights=False),
+            ),
+        ],
+        name_func=_param_name_func,
+    )
+    def test_target_last_layer_same_type(self, config0, config1):
+        input = torch.arange(90).reshape(9, 10).to(self.torch_device)
+        self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=True)
+        self._check_merging(SimpleNet, config0, config1, input)
+        self._check_unload(SimpleNet, config0, config1, input)
+        self._check_disable(SimpleNet, config1, config0, input)
+
+    @parameterized.expand(
+        [
+            (
+                LoraConfig(target_modules=["lin0"], init_lora_weights=False),
+                LoraConfig(target_modules=["lin0"], init_lora_weights=False),
+            ),
+            (
+                LoHaConfig(target_modules=["lin0"], init_weights=False),
+                LoHaConfig(target_modules=["lin0"], init_weights=False),
+            ),
+            (
+                LoKrConfig(target_modules=["lin0"], init_weights=False),
+                LoKrConfig(target_modules=["lin0"], init_weights=False),
+            ),
+            (
+                AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False),
+                AdaLoraConfig(target_modules=["lin0"], init_lora_weights=False),
+            ),
+        ],
+        name_func=_param_name_func,
+    )
+    def test_target_first_layer_same_type(self, config0, config1):
+        input = torch.arange(90).reshape(9, 10).to(self.torch_device)
+        self._check_mixed_outputs(SimpleNet, config0, config1, input, is_commutative=False)
+        self._check_merging(SimpleNet, config0, config1, input)
+        self._check_unload(SimpleNet, config0, config1, input)
+        self._check_disable(SimpleNet, config1, config0, input)
+        self._check_loading(SimpleNet, config0, config1, input)
+
+    def test_deeply_nested(self):
+        # a somewhat absurdly nested model using different adapter types
+        atol = 1e-5
+        rtol = 1e-5
+        torch.manual_seed(0)
+
+        model = SimpleNet().eval().to(self.torch_device)
+        input = torch.arange(90).reshape(9, 10).to(self.torch_device)
+        output_base = model(input)
+
+        config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False)
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+
+        config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False)
+        peft_model.add_adapter("adapter1", config1)
+
+        config2 = AdaLoraConfig(r=4, lora_alpha=4, target_modules=["lin1"], init_lora_weights=False)
+        peft_model.add_adapter("adapter2", config2)
+
+        config3 = LoKrConfig(r=4, alpha=4, target_modules=["lin0", "lin1"], init_weights=False)
+        peft_model.add_adapter("adapter3", config3)
+
+        config4 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False)
+        peft_model.add_adapter("adapter4", config4)
+
+        peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"])
+        output_mixed = peft_model(input)
+        self.assertTrue(torch.isfinite(output_base).all())
+        self.assertFalse(torch.allclose(output_base, output_mixed, atol=atol, rtol=rtol))
+
+        # test disabling all adapters
+        with peft_model.disable_adapter():
+            output_disabled = peft_model(input)
+        self.assertTrue(torch.isfinite(output_disabled).all())
+        self.assertTrue(torch.allclose(output_base, output_disabled, atol=atol, rtol=rtol))
+        self.assertFalse(torch.allclose(output_mixed, output_disabled, atol=atol, rtol=rtol))
+
+        # merge and unload all adapters
+        model_copy = copy.deepcopy(peft_model)
+        model = model_copy.merge_and_unload()
+        output_merged = model(input)
+        self.assertTrue(torch.isfinite(output_merged).all())
+        self.assertTrue(torch.allclose(output_mixed, output_merged, atol=atol, rtol=rtol))
+
+        # merge and unload only adapter1 and adapter3
+        model_copy = copy.deepcopy(peft_model)
+        model_copy.set_adapter(["adapter1", "adapter3"])
+        output_13 = model_copy(input)
+        self.assertTrue(torch.isfinite(output_13).all())
+        self.assertFalse(torch.allclose(output_mixed, output_13, atol=atol, rtol=rtol))
+
+        model_copy.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3", "adapter4"])
+        model_merged_unloaded = model_copy.merge_and_unload(adapter_names=["adapter1", "adapter3"])
+        output_merged_13 = model_merged_unloaded(input)
+        self.assertTrue(torch.isfinite(output_merged_13).all())
+        self.assertTrue(torch.allclose(output_13, output_merged_13, atol=atol, rtol=rtol))
+
+        # test unloading
+        model_copy = copy.deepcopy(peft_model)
+        model_unloaded = model_copy.unload()
+        output_unloaded = model_unloaded(input)
+        self.assertTrue(torch.isfinite(output_unloaded).all())
+        self.assertTrue(torch.allclose(output_base, output_unloaded, atol=atol, rtol=rtol))
+
+    def test_delete_adapter(self):
+        atol = 1e-5
+        rtol = 1e-5
+        torch.manual_seed(0)
+
+        model = SimpleNet().eval().to(self.torch_device)
+        input = torch.arange(90).reshape(9, 10).to(self.torch_device)
+        output_base = model(input)
+
+        # create adapter0
+        torch.manual_seed(0)
+        config0 = LoraConfig(r=4, lora_alpha=4, target_modules=["lin0", "lin1"], init_lora_weights=False)
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+        output_0 = peft_model(input)
+        self.assertFalse(torch.allclose(output_base, output_0, atol=atol, rtol=rtol))
+
+        # add adapter1
+        torch.manual_seed(1)
+        config1 = LoHaConfig(r=4, alpha=4, target_modules=["lin0"], init_weights=False)
+        peft_model.add_adapter("adapter1", config1)
+        peft_model.set_adapter(["adapter0", "adapter1"])
+        output_01 = peft_model(input)
+        self.assertFalse(torch.allclose(output_base, output_01, atol=atol, rtol=rtol))
+        self.assertFalse(torch.allclose(output_0, output_01, atol=atol, rtol=rtol))
+
+        # delete adapter1
+        peft_model.delete_adapter("adapter1")
+        self.assertEqual(peft_model.active_adapters, ["adapter0"])
+        output_deleted_1 = peft_model(input)
+        self.assertTrue(torch.allclose(output_0, output_deleted_1, atol=atol, rtol=rtol))
+
+        msg = re.escape("Adapter(s) ['adapter1'] not found, available adapters: ['adapter0']")
+        with self.assertRaisesRegex(ValueError, expected_regex=msg):
+            peft_model.set_adapter(["adapter0", "adapter1"])
+
+        # re-add adapter1
+        torch.manual_seed(1)
+        peft_model.add_adapter("adapter1", config1)
+        peft_model.set_adapter(["adapter0", "adapter1"])
+        output_01_readded = peft_model(input)
+        self.assertFalse(torch.allclose(output_base, output_01_readded, atol=atol, rtol=rtol))
+
+        # same as above, but this time delete adapter0 first
+        torch.manual_seed(0)
+        model = SimpleNet().eval().to(self.torch_device)
+        torch.manual_seed(0)
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+        torch.manual_seed(1)
+        peft_model.add_adapter("adapter1", config1)
+        peft_model.delete_adapter("adapter0")
+        self.assertEqual(peft_model.active_adapters, ["adapter1"])
+        output_deleted_0 = peft_model(input)
+        self.assertFalse(torch.allclose(output_deleted_0, output_base, atol=atol, rtol=rtol))
+        self.assertFalse(torch.allclose(output_deleted_0, output_01, atol=atol, rtol=rtol))
+
+        msg = re.escape("Adapter(s) ['adapter0'] not found, available adapters: ['adapter1']")
+        with self.assertRaisesRegex(ValueError, expected_regex=msg):
+            peft_model.set_adapter(["adapter0", "adapter1"])
+
+        peft_model.delete_adapter("adapter1")
+        self.assertEqual(peft_model.active_adapters, [])
+        output_deleted_01 = peft_model(input)
+        self.assertTrue(torch.allclose(output_deleted_01, output_base, atol=atol, rtol=rtol))
+
+    def test_modules_to_save(self):
+        model = SimpleNet().eval().to(self.torch_device)
+        config0 = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+
+        # adding a second adapter with same modules_to_save is not allowed
+        # TODO: theoretically, we could allow this if it's the same target layer
+        config1 = LoHaConfig(target_modules=["lin0"], modules_to_save=["lin1"])
+        peft_model.add_adapter("adapter1", config1)
+        msg = "Only one adapter can be set at a time for modules_to_save"
+        with self.assertRaisesRegex(ValueError, expected_regex=msg):
+            peft_model.set_adapter(["adapter0", "adapter1"])
+
+    def test_get_nb_trainable_parameters(self):
+        model = SimpleNet().eval().to(self.torch_device)
+        config0 = LoraConfig(target_modules=["lin0"])
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+        trainable_params0, all_param0 = peft_model.get_nb_trainable_parameters()
+
+        params_base = 262
+        params_lora = sum(p.numel() for n, p in model.named_parameters() if "adapter0" in n)
+        self.assertEqual(trainable_params0, params_lora)
+        self.assertEqual(all_param0, params_base + params_lora)
+
+        config1 = LoHaConfig(target_modules=["lin1"])
+        peft_model.add_adapter("adapter1", config1)
+        peft_model.set_adapter(["adapter0", "adapter1"])
+        params_loha = sum(p.numel() for n, p in model.named_parameters() if "adapter1" in n)
+        trainable_params1, all_param1 = peft_model.get_nb_trainable_parameters()
+        self.assertEqual(trainable_params1, params_lora + params_loha)
+        self.assertEqual(all_param1, params_base + params_lora + params_loha)
+
+        config2 = AdaLoraConfig(target_modules=["lin0", "lin1"])
+        peft_model.add_adapter("adapter2", config2)
+        peft_model.set_adapter(["adapter0", "adapter1", "adapter2"])
+        params_adalora = sum(p.numel() for n, p in model.named_parameters() if "adapter2" in n)
+        trainable_params2, all_param2 = peft_model.get_nb_trainable_parameters()
+        # remove 2 params because we need to exclude "ranknum" for AdaLora trainable params
+        self.assertEqual(trainable_params2, params_lora + params_loha + params_adalora - 2)
+        self.assertEqual(all_param2, params_base + params_lora + params_loha + params_adalora)
+
+    def test_incompatible_config_raises(self):
+        model = SimpleNet().eval().to(self.torch_device)
+        config0 = LoraConfig(target_modules=["lin0"])
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+
+        config1 = PrefixTuningConfig()
+        msg = "The provided `peft_type` 'PREFIX_TUNING' is not compatible with the `PeftMixedModel`."
+        with self.assertRaisesRegex(ValueError, expected_regex=msg):
+            peft_model.add_adapter("adapter1", config1)
+
+    def test_decoder_model(self):
+        # test a somewhat realistic model instead of a toy model
+        torch.manual_seed(0)
+
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device)
+        input_ids = torch.tensor([[1, 1, 1], [1, 2, 1]]).to(self.torch_device)
+        attention_mask = torch.tensor([[1, 1, 1], [1, 0, 1]]).to(self.torch_device)
+        input_dict = {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+        }
+        output_base = model.generate(**input_dict)
+
+        torch.manual_seed(0)
+        config0 = LoraConfig(task_type="CAUSAL_LM", init_lora_weights=False)
+        peft_model = get_peft_model(model, config0, "adapter0", mixed=True)
+        output0 = peft_model.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output0).all())
+        self.assertFalse(torch.allclose(output_base, output0))
+
+        torch.manual_seed(1)
+        config1 = LoHaConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False)
+        peft_model.add_adapter("adapter1", config1)
+        peft_model.set_adapter(["adapter0", "adapter1"])
+        output1 = peft_model.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output1).all())
+        self.assertFalse(torch.allclose(output0, output1))
+
+        torch.manual_seed(2)
+        config2 = AdaLoraConfig(task_type="CAUSAL_LM", init_lora_weights=False)
+        peft_model.add_adapter("adapter2", config2)
+        peft_model.set_adapter(["adapter0", "adapter1", "adapter2"])
+        output2 = peft_model.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output2).all())
+        self.assertFalse(torch.allclose(output1, output2))
+
+        torch.manual_seed(3)
+        config3 = LoKrConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"], init_weights=False)
+        peft_model.add_adapter("adapter3", config3)
+        peft_model.set_adapter(["adapter0", "adapter1", "adapter2", "adapter3"])
+        output3 = peft_model.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output3).all())
+        self.assertFalse(torch.allclose(output2, output3))
+
+        with peft_model.disable_adapter():
+            output_disabled = peft_model.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output_disabled).all())
+        self.assertTrue(torch.allclose(output_base, output_disabled))
+
+        model_unloaded = peft_model.merge_and_unload()
+        output_unloaded = model_unloaded.generate(**input_dict)
+        self.assertTrue(torch.isfinite(output_unloaded).all())
+        self.assertTrue(torch.allclose(output3, output_unloaded))
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # save adapter0 (use normal PeftModel, because PeftMixedModel does not support saving)
+            torch.manual_seed(0)
+            model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device)
+            torch.manual_seed(0)
+            peft_model = get_peft_model(model, config0, "adapter0")
+            output0_save = peft_model(**input_dict).logits
+            self.assertTrue(torch.isfinite(output0_save).all())
+            peft_model.save_pretrained(tmp_dir)
+
+            # save adapter1
+            torch.manual_seed(0)
+            model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device)
+            torch.manual_seed(1)
+            peft_model = get_peft_model(model, config1, "adapter1")
+            output1_save = peft_model(**input_dict).logits
+            self.assertTrue(torch.isfinite(output1_save).all())
+            peft_model.save_pretrained(tmp_dir)
+
+            # load adapter0 and adapter1
+            model = AutoModelForCausalLM.from_pretrained(model_id).eval().to(self.torch_device)
+            peft_model = PeftMixedModel.from_pretrained(model, os.path.join(tmp_dir, "adapter0"), "adapter0")
+            peft_model.load_adapter(os.path.join(tmp_dir, "adapter1"), "adapter1")
+            peft_model.set_adapter(["adapter0", "adapter1"])
+            output01_loaded = peft_model(**input_dict).logits
+
+            atol, rtol = 1e-3, 1e-3
+            self.assertTrue(torch.isfinite(output01_loaded).all())
+            self.assertFalse(torch.allclose(output0_save, output01_loaded, atol=atol, rtol=rtol))
+            self.assertFalse(torch.allclose(output1_save, output01_loaded, atol=atol, rtol=rtol))