diff --git a/src/transformers/integrations/peft.py b/src/transformers/integrations/peft.py index 4fba01df425a..22261eecad0b 100644 --- a/src/transformers/integrations/peft.py +++ b/src/transformers/integrations/peft.py @@ -15,7 +15,6 @@ import importlib import inspect import re -import warnings from typing import Any, Optional, Union from packaging import version @@ -70,14 +69,9 @@ class PeftAdapterMixin: more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT library: https://huggingface.co/docs/peft/index - Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods - that anyone can load, train and run with this mixin class: - - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora - - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3 - - AdaLora: https://huggingface.co/papers/2303.10512 - - Other PEFT models such as prompt tuning, prompt learning are out of scope as these adapters are not "injectable" - into a torch module. For using these methods, please refer to the usage guide of PEFT library. + Currently supported PEFT methods are all non-prompt learning methods (LoRA, IA³, etc.). Other PEFT models such as + prompt tuning, prompt learning are out of scope as these adapters are not "injectable" into a torch module. For + using these methods, please refer to the usage guide of PEFT library. With this mixin, if the correct PEFT version is installed, it is possible to: @@ -110,24 +104,21 @@ def load_adapter( Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft - Requires peft as a backend to load the adapter weights. + Requires PEFT to be installed as a backend to load the adapter weights. Args: peft_model_id (`str`, *optional*): The identifier of the model to look for on the Hub, or a local path to the saved adapter config file and adapter weights. adapter_name (`str`, *optional*): - The adapter name to use. If not set, will use the default adapter. + The adapter name to use. If not set, will use the name "default". revision (`str`, *optional*, defaults to `"main"`): The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git. - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. - - + > [!TIP] + > To test a pull request you made on the Hub, you can pass `revision="refs/pr/"`. token (`str`, `optional`): Whether to use authentication token to load the remote folder. Useful to load private repositories @@ -151,11 +142,11 @@ def load_adapter( offload_index (`int`, `optional`): `offload_index` argument to be passed to `accelerate.dispatch_model` method. peft_config (`dict[str, Any]`, *optional*): - The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts - methods. This argument is used in case users directly pass PEFT state dicts + The configuration of the adapter to add, supported adapters are all non-prompt learning configs (LoRA, + IA³, etc). This argument is used in case users directly pass PEFT state dicts. adapter_state_dict (`dict[str, torch.Tensor]`, *optional*): The state dict of the adapter to load. This argument is used in case users directly pass PEFT state - dicts + dicts. low_cpu_mem_usage (`bool`, *optional*, defaults to `False`): Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process. Requires PEFT version 0.13.0 or higher. @@ -320,10 +311,12 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> Non name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the default adapter name). + Note that the newly added adapter is not automatically activated. To activate it, use `model.set_adapter`. + Args: adapter_config (`~peft.PeftConfig`): - The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts - methods + The configuration of the adapter to add, supported adapters are non-prompt learning methods (LoRA, + IA³, etc.). adapter_name (`str`, *optional*, defaults to `"default"`): The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. """ @@ -470,13 +463,6 @@ def active_adapters(self) -> list[str]: return active_adapters - def active_adapter(self) -> str: - warnings.warn( - "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning - ) - - return self.active_adapters()[0] - def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict: """ If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT @@ -564,34 +550,47 @@ def _dispatch_accelerate_model( def delete_adapter(self, adapter_names: Union[list[str], str]) -> None: """ - Delete an adapter's LoRA layers from the underlying model. + Delete a PEFT adapter from the underlying model. Args: adapter_names (`Union[list[str], str]`): The name(s) of the adapter(s) to delete. - - Example: - - ```py - from diffusers import AutoPipelineForText2Image - import torch - - pipeline = AutoPipelineForText2Image.from_pretrained( - "stabilityai/stable-diffusion-xl-base-1.0", dtype=torch.float16 - ).to("cuda") - pipeline.load_lora_weights( - "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic" - ) - pipeline.delete_adapters("cinematic") - ``` """ check_peft_version(min_version=MIN_PEFT_VERSION) + min_version_delete_adapter = "0.18.0" if not self._hf_peft_config_loaded: raise ValueError("No adapter loaded. Please load an adapter first.") - from peft.tuners.tuners_utils import BaseTunerLayer + # TODO: delete old version once support for PEFT < 0.18.0 is dropped + def old_delete_adapter(model, adapter_name, prefix=None): + from peft.tuners.tuners_utils import BaseTunerLayer + from peft.utils import ModulesToSaveWrapper + + has_modules_to_save = False + for module in model.modules(): + if isinstance(module, ModulesToSaveWrapper): + has_modules_to_save |= True + continue + if isinstance(module, BaseTunerLayer): + if hasattr(module, "delete_adapter"): + module.delete_adapter(adapter_name) + else: + raise ValueError( + "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1" + ) + + if has_modules_to_save: + logger.warning( + "The deleted adapter contains modules_to_save, which could not be deleted. For this to work, PEFT version " + f">= {min_version_delete_adapter} is required." + ) + + if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_delete_adapter): + from peft.functional import delete_adapter + else: + delete_adapter = old_delete_adapter if isinstance(adapter_names, str): adapter_names = [adapter_names] @@ -603,16 +602,9 @@ def delete_adapter(self, adapter_names: Union[list[str], str]) -> None: f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}" ) - for adapter_name in adapter_names: - for module in self.modules(): - if isinstance(module, BaseTunerLayer): - if hasattr(module, "delete_adapter"): - module.delete_adapter(adapter_name) - else: - raise ValueError( - "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1" - ) - + prefixes = [f"{self.peft_config[adapter_name].peft_type.value.lower()}_" for adapter_name in adapter_names] + for adapter_name, prefix in zip(adapter_names, prefixes): + delete_adapter(self, adapter_name=adapter_name, prefix=prefix) # For transformers integration - we need to pop the adapter from the config if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"): self.peft_config.pop(adapter_name, None) diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index 616763aa92e7..ad0978164043 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -20,6 +20,7 @@ from datasets import Dataset, DatasetDict from huggingface_hub import hf_hub_download from packaging import version +from torch import nn from transformers import ( AutoModelForCausalLM, @@ -337,11 +338,9 @@ def test_peft_add_multi_adapter(self): model.set_adapter("default") self.assertTrue(model.active_adapters() == ["default"]) - self.assertTrue(model.active_adapter() == "default") model.set_adapter("adapter-2") self.assertTrue(model.active_adapters() == ["adapter-2"]) - self.assertTrue(model.active_adapter() == "adapter-2") # Logits comparison self.assertFalse( @@ -351,7 +350,6 @@ def test_peft_add_multi_adapter(self): model.set_adapter(["adapter-2", "default"]) self.assertTrue(model.active_adapters() == ["adapter-2", "default"]) - self.assertTrue(model.active_adapter() == "adapter-2") logits_adapter_mixed = model(dummy_input) self.assertFalse( @@ -429,6 +427,68 @@ def test_delete_adapter(self): self.assertNotIn("adapter_1", model.peft_config) self.assertIn("adapter_2", model.peft_config) + def test_delete_adapter_with_modules_to_save(self): + """ + Ensure that modules_to_save is accounted for when deleting an adapter. + """ + min_version_delete_adapter = "0.18.0" + if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_delete_adapter): + self.skipTest("Correctly deleting modules_to_save only works with PEFT >= 0.18.0") + + from peft import LoraConfig + + # the test assumes a specific model architecture, so only test this one: + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"]) + model.add_adapter(peft_config, adapter_name="adapter_1") + + # sanity checks + self.assertIn("adapter_1", model.peft_config) + self.assertNotIsInstance(model.lm_head, nn.Linear) # a ModulesToSaveWrapper + self.assertTrue(hasattr(model.lm_head, "modules_to_save")) + self.assertTrue("adapter_1" in model.lm_head.modules_to_save) + + # now delete the adapter + model.delete_adapter("adapter_1") + self.assertFalse(hasattr(model, "peft_config")) + self.assertFalse("adapter_1" in model.lm_head.modules_to_save) + self.assertFalse(model.lm_head.modules_to_save) # i.e. empty ModuleDict + + def test_delete_adapter_with_modules_to_save_old_peft_warns(self): + """ + When PEFT < 0.18.0 is being used, modules_to_save are not deleted but the user should get a warning. + """ + from peft import LoraConfig + + peft_ge_018 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.18.0") + logger = logging.get_logger("transformers.integrations.peft") + warn_msg = "The deleted adapter contains modules_to_save" + # the test assumes a specific model architecture, so only test this one: + model_id = "hf-internal-testing/tiny-random-OPTForCausalLM" + + # first a sanity check: when there is no modules_to_save, there is also no warning + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + peft_config_0 = LoraConfig(init_lora_weights=False) + model.add_adapter(peft_config_0, adapter_name="adapter_1") + with CaptureLogger(logger) as cl: + model.delete_adapter("adapter_1") + assert warn_msg not in cl.out + + # now test a model with modules_to_save + model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device) + peft_config_1 = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"]) + model.add_adapter(peft_config_1, adapter_name="adapter_1") + with CaptureLogger(logger) as cl: + model.delete_adapter("adapter_1") + + if peft_ge_018: + self.assertTrue("adapter_1" not in model.lm_head.modules_to_save) + assert warn_msg not in cl.out + else: + self.assertTrue("adapter_1" in model.lm_head.modules_to_save) + assert warn_msg in cl.out + @require_torch_accelerator @require_bitsandbytes def test_peft_from_pretrained_kwargs(self):