102 changes: 47 additions & 55 deletions src/transformers/integrations/peft.py
@@ -15,7 +15,6 @@
import importlib
import inspect
import re
import warnings
from typing import Any, Optional, Union

from packaging import version
@@ -70,14 +69,9 @@ class PeftAdapterMixin:
more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT
library: https://huggingface.co/docs/peft/index

Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods
that anyone can load, train and run with this mixin class:
- Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
- IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
- AdaLora: https://huggingface.co/papers/2303.10512

Other PEFT models such as prompt tuning, prompt learning are out of scope as these adapters are not "injectable"
into a torch module. For using these methods, please refer to the usage guide of PEFT library.
Currently supported PEFT methods are all non-prompt learning methods (LoRA, IA³, etc.). Other PEFT methods such as
prompt tuning and prompt learning are out of scope, as these adapters are not "injectable" into a torch module. For
using these methods, please refer to the usage guide of the PEFT library.

With this mixin, if the correct PEFT version is installed, it is possible to:

@@ -110,24 +104,21 @@ def load_adapter(
Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft

Requires peft as a backend to load the adapter weights.
Requires PEFT to be installed as a backend to load the adapter weights.

Args:
peft_model_id (`str`, *optional*):
The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
and adapter weights.
adapter_name (`str`, *optional*):
The adapter name to use. If not set, will use the default adapter.
The adapter name to use. If not set, will use the name "default".
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.

<Tip>

To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

</Tip>
> [!TIP]
> To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

token (`str`, `optional`):
Whether to use authentication token to load the remote folder. Useful to load private repositories
@@ -151,11 +142,11 @@ def load_adapter(
offload_index (`int`, `optional`):
`offload_index` argument to be passed to `accelerate.dispatch_model` method.
peft_config (`dict[str, Any]`, *optional*):
The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
methods. This argument is used in case users directly pass PEFT state dicts
The configuration of the adapter to add, supported adapters are all non-prompt learning configs (LoRA,
IA³, etc.). This argument is used in case users directly pass PEFT state dicts.
adapter_state_dict (`dict[str, torch.Tensor]`, *optional*):
The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
dicts
dicts.
low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
Requires PEFT version 0.13.0 or higher.
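
For context, a minimal sketch of loading an adapter through this mixin. The base model `facebook/opt-350m` and the adapter repo `ybelkada/opt-350m-lora` are illustrative choices, not part of this diff:

```py
from transformers import AutoModelForCausalLM

# Illustrative IDs: any base model with a matching PEFT adapter repo works.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter(
    "ybelkada/opt-350m-lora",
    adapter_name="my_adapter",
    low_cpu_mem_usage=True,  # requires PEFT >= 0.13.0, as noted above
)
model.set_adapter("my_adapter")
print(model.active_adapters())  # ["my_adapter"]
```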
@@ -320,10 +311,12 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
default adapter name).

Note that the newly added adapter is not automatically activated. To activate it, use `model.set_adapter`.

Args:
adapter_config (`~peft.PeftConfig`):
The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
methods
The configuration of the adapter to add, supported adapters are non-prompt learning methods (LoRA,
IA³, etc.).
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
@@ -470,13 +463,6 @@ def active_adapters(self) -> list[str]:

return active_adapters

def active_adapter(self) -> str:
warnings.warn(
"The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
)

return self.active_adapters()[0]

def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
"""
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
@@ -564,34 +550,47 @@ def _dispatch_accelerate_model(

def delete_adapter(self, adapter_names: Union[list[str], str]) -> None:
"""
Delete an adapter's LoRA layers from the underlying model.
Delete a PEFT adapter from the underlying model.

Args:
adapter_names (`Union[list[str], str]`):
The name(s) of the adapter(s) to delete.

Example:

```py
from diffusers import AutoPipelineForText2Image
import torch

pipeline = AutoPipelineForText2Image.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights(
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
)
pipeline.delete_adapters("cinematic")
```
Comment on lines -573 to -586

Collaborator: we don't want to keep a doc anymore here?

Member Author: The doc was unfitting, as it used diffusers, not transformers. Moreover, the API is so simple that it really doesn't need an example.

"""

check_peft_version(min_version=MIN_PEFT_VERSION)
min_version_delete_adapter = "0.18.0"

if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")

from peft.tuners.tuners_utils import BaseTunerLayer
# TODO: delete old version once support for PEFT < 0.18.0 is dropped
def old_delete_adapter(model, adapter_name, prefix=None):
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

has_modules_to_save = False
for module in model.modules():
if isinstance(module, ModulesToSaveWrapper):
has_modules_to_save |= True
continue
if isinstance(module, BaseTunerLayer):
if hasattr(module, "delete_adapter"):
module.delete_adapter(adapter_name)
else:
raise ValueError(
"The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
)

if has_modules_to_save:
logger.warning(
"The deleted adapter contains modules_to_save, which could not be deleted. For this to work, PEFT version "
f">= {min_version_delete_adapter} is required."
)

if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_delete_adapter):
from peft.functional import delete_adapter
else:
delete_adapter = old_delete_adapter

if isinstance(adapter_names, str):
adapter_names = [adapter_names]
@@ -603,16 +602,9 @@ def delete_adapter(self, adapter_names: Union[list[str], str]) -> None:
f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
)

for adapter_name in adapter_names:
for module in self.modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "delete_adapter"):
module.delete_adapter(adapter_name)
else:
raise ValueError(
"The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
)

prefixes = [f"{self.peft_config[adapter_name].peft_type.value.lower()}_" for adapter_name in adapter_names]
for adapter_name, prefix in zip(adapter_names, prefixes):
delete_adapter(self, adapter_name=adapter_name, prefix=prefix)
# For transformers integration - we need to pop the adapter from the config
if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
self.peft_config.pop(adapter_name, None)
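
Since the diffusers-based example was dropped from the docstring, here is a transformers-only sketch of the resulting API, with a placeholder model and adapter name that are not taken from the PR:

```py
from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # placeholder
model.add_adapter(LoraConfig(), adapter_name="my_lora")
model.delete_adapter("my_lora")

# Once the last adapter is gone, the peft_config attribute is removed entirely
# (see test_delete_adapter and test_delete_adapter_with_modules_to_save below).
print(hasattr(model, "peft_config"))  # False
```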
66 changes: 63 additions & 3 deletions tests/peft_integration/test_peft_integration.py
@@ -20,6 +20,7 @@
from datasets import Dataset, DatasetDict
from huggingface_hub import hf_hub_download
from packaging import version
from torch import nn

from transformers import (
AutoModelForCausalLM,
@@ -337,11 +338,9 @@ def test_peft_add_multi_adapter(self):

model.set_adapter("default")
self.assertTrue(model.active_adapters() == ["default"])
self.assertTrue(model.active_adapter() == "default")

model.set_adapter("adapter-2")
self.assertTrue(model.active_adapters() == ["adapter-2"])
self.assertTrue(model.active_adapter() == "adapter-2")

# Logits comparison
self.assertFalse(
@@ -351,7 +350,6 @@ def test_peft_add_multi_adapter(self):

model.set_adapter(["adapter-2", "default"])
self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
self.assertTrue(model.active_adapter() == "adapter-2")

logits_adapter_mixed = model(dummy_input)
self.assertFalse(
@@ -429,6 +427,68 @@ def test_delete_adapter(self):
self.assertNotIn("adapter_1", model.peft_config)
self.assertIn("adapter_2", model.peft_config)

def test_delete_adapter_with_modules_to_save(self):
"""
Ensure that modules_to_save is accounted for when deleting an adapter.
"""
min_version_delete_adapter = "0.18.0"
if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_delete_adapter):
self.skipTest("Correctly deleting modules_to_save only works with PEFT >= 0.18.0")

from peft import LoraConfig

# the test assumes a specific model architecture, so only test this one:
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
model.add_adapter(peft_config, adapter_name="adapter_1")

# sanity checks
self.assertIn("adapter_1", model.peft_config)
self.assertNotIsInstance(model.lm_head, nn.Linear) # a ModulesToSaveWrapper
self.assertTrue(hasattr(model.lm_head, "modules_to_save"))
self.assertTrue("adapter_1" in model.lm_head.modules_to_save)

# now delete the adapter
model.delete_adapter("adapter_1")
self.assertFalse(hasattr(model, "peft_config"))
self.assertFalse("adapter_1" in model.lm_head.modules_to_save)
self.assertFalse(model.lm_head.modules_to_save) # i.e. empty ModuleDict

def test_delete_adapter_with_modules_to_save_old_peft_warns(self):
"""
When PEFT < 0.18.0 is being used, modules_to_save are not deleted but the user should get a warning.
"""
from peft import LoraConfig

peft_ge_018 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.18.0")
logger = logging.get_logger("transformers.integrations.peft")
warn_msg = "The deleted adapter contains modules_to_save"
# the test assumes a specific model architecture, so only test this one:
model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"

# first a sanity check: when there is no modules_to_save, there is also no warning
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
peft_config_0 = LoraConfig(init_lora_weights=False)
model.add_adapter(peft_config_0, adapter_name="adapter_1")
with CaptureLogger(logger) as cl:
model.delete_adapter("adapter_1")
assert warn_msg not in cl.out

# now test a model with modules_to_save
model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
peft_config_1 = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
model.add_adapter(peft_config_1, adapter_name="adapter_1")
with CaptureLogger(logger) as cl:
model.delete_adapter("adapter_1")

if peft_ge_018:
self.assertTrue("adapter_1" not in model.lm_head.modules_to_save)
assert warn_msg not in cl.out
else:
self.assertTrue("adapter_1" in model.lm_head.modules_to_save)
assert warn_msg in cl.out

@require_torch_accelerator
@require_bitsandbytes
def test_peft_from_pretrained_kwargs(self):