
Commit cf88fbb

BenjaminBossan authored and ArthurZucker committed
FIX: Bug in PEFT integration delete_adapter method (#41252)
The main content of this PR is to fix a bug in the delete_adapter method of the PeftAdapterMixin. Previously, it did not take into account auxiliary modules from PEFT, e.g. those added by modules_to_save. This PR fixes this oversight.

Note that the PR uses new functionality from PEFT that exposes integration functions like delete_adapter. Those will be contained in the next PEFT release, 0.18.0 (yet unreleased). Therefore, the bug is only fixed when users have a PEFT version fulfilling this requirement. I ensured that with old PEFT versions, the integration still works the same as previously. The newly added test for this is skipped if the PEFT version is too low. (Note: I tested locally that the test will pass with PEFT 0.18.0.)

While working on this, I also cleaned up the following:

- The active_adapter property has been deprecated for more than 2 years (#26407). It is safe to remove it now.
- There were numerous small errors or outdated pieces of information in the docstrings, which have been addressed.

When PEFT < 0.18.0 is used, although we cannot delete modules_to_save, we can still detect them and warn about it.
1 parent 531bb75 commit cf88fbb
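For readers skimming the diff, here is a minimal sketch of the behavior being fixed. The model id and adapter name are taken from the new tests added below, not a prescribed setup; the sketch assumes transformers with this patch and peft installed.

```py
# Sketch of the fixed behavior; model id taken from the tests in this PR.
from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
# modules_to_save makes PEFT wrap lm_head and keep a trainable per-adapter copy,
# in addition to injecting LoRA layers into the targeted linear layers.
config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
model.add_adapter(config, adapter_name="adapter_1")

# Before this fix, only the LoRA layers were removed and the lm_head copy was left behind.
# With PEFT >= 0.18.0 the copy is now deleted too; with older PEFT a warning is logged.
model.delete_adapter("adapter_1")
```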

File tree

2 files changed: +110 −58 lines changed


src/transformers/integrations/peft.py

Lines changed: 47 additions & 55 deletions
@@ -15,7 +15,6 @@
 import importlib
 import inspect
 import re
-import warnings
 from typing import Any, Optional, Union
 
 from packaging import version
@@ -70,14 +69,9 @@ class PeftAdapterMixin:
     more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT
     library: https://huggingface.co/docs/peft/index
 
-    Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods
-    that anyone can load, train and run with this mixin class:
-    - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
-    - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
-    - AdaLora: https://huggingface.co/papers/2303.10512
-
-    Other PEFT models such as prompt tuning, prompt learning are out of scope as these adapters are not "injectable"
-    into a torch module. For using these methods, please refer to the usage guide of PEFT library.
+    Currently supported PEFT methods are all non-prompt learning methods (LoRA, IA³, etc.). Other PEFT models such as
+    prompt tuning, prompt learning are out of scope as these adapters are not "injectable" into a torch module. For
+    using these methods, please refer to the usage guide of PEFT library.
 
     With this mixin, if the correct PEFT version is installed, it is possible to:
 
@@ -110,24 +104,21 @@ def load_adapter(
         Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we
         invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft
 
-        Requires peft as a backend to load the adapter weights.
+        Requires PEFT to be installed as a backend to load the adapter weights.
 
         Args:
             peft_model_id (`str`, *optional*):
                 The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
                 and adapter weights.
             adapter_name (`str`, *optional*):
-                The adapter name to use. If not set, will use the default adapter.
+                The adapter name to use. If not set, will use the name "default".
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                 git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                 identifier allowed by git.
 
-                <Tip>
-
-                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
-
-                </Tip>
+                > [!TIP]
+                > To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
 
             token (`str`, `optional`):
                 Whether to use authentication token to load the remote folder. Useful to load private repositories
@@ -151,11 +142,11 @@ def load_adapter(
             offload_index (`int`, `optional`):
                 `offload_index` argument to be passed to `accelerate.dispatch_model` method.
             peft_config (`dict[str, Any]`, *optional*):
-                The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
-                methods. This argument is used in case users directly pass PEFT state dicts
+                The configuration of the adapter to add, supported adapters are all non-prompt learning configs (LoRA,
+                IA³, etc). This argument is used in case users directly pass PEFT state dicts.
             adapter_state_dict (`dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
-                dicts
+                dicts.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
                 Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
                 Requires PEFT version 0.13.0 or higher.
@@ -320,10 +311,12 @@ def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
         name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the
         default adapter name).
 
+        Note that the newly added adapter is not automatically activated. To activate it, use `model.set_adapter`.
+
         Args:
             adapter_config (`~peft.PeftConfig`):
-                The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts
-                methods
+                The configuration of the adapter to add, supported adapters are non-prompt learning methods (LoRA,
+                IA³, etc.).
             adapter_name (`str`, *optional*, defaults to `"default"`):
                 The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
         """
@@ -470,13 +463,6 @@ def active_adapters(self) -> list[str]:
 
         return active_adapters
 
-    def active_adapter(self) -> str:
-        warnings.warn(
-            "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning
-        )
-
-        return self.active_adapters()[0]
-
     def get_adapter_state_dict(self, adapter_name: Optional[str] = None, state_dict: Optional[dict] = None) -> dict:
         """
         If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
@@ -564,34 +550,47 @@ def _dispatch_accelerate_model(
 
     def delete_adapter(self, adapter_names: Union[list[str], str]) -> None:
         """
-        Delete an adapter's LoRA layers from the underlying model.
+        Delete a PEFT adapter from the underlying model.
 
         Args:
            adapter_names (`Union[list[str], str]`):
                The name(s) of the adapter(s) to delete.
-
-        Example:
-
-        ```py
-        from diffusers import AutoPipelineForText2Image
-        import torch
-
-        pipeline = AutoPipelineForText2Image.from_pretrained(
-            "stabilityai/stable-diffusion-xl-base-1.0", dtype=torch.float16
-        ).to("cuda")
-        pipeline.load_lora_weights(
-            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_names="cinematic"
-        )
-        pipeline.delete_adapters("cinematic")
-        ```
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
+        min_version_delete_adapter = "0.18.0"
 
         if not self._hf_peft_config_loaded:
             raise ValueError("No adapter loaded. Please load an adapter first.")
 
-        from peft.tuners.tuners_utils import BaseTunerLayer
+        # TODO: delete old version once support for PEFT < 0.18.0 is dropped
+        def old_delete_adapter(model, adapter_name, prefix=None):
+            from peft.tuners.tuners_utils import BaseTunerLayer
+            from peft.utils import ModulesToSaveWrapper
+
+            has_modules_to_save = False
+            for module in model.modules():
+                if isinstance(module, ModulesToSaveWrapper):
+                    has_modules_to_save |= True
+                    continue
+                if isinstance(module, BaseTunerLayer):
+                    if hasattr(module, "delete_adapter"):
+                        module.delete_adapter(adapter_name)
+                    else:
+                        raise ValueError(
+                            "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
+                        )
+
+            if has_modules_to_save:
+                logger.warning(
+                    "The deleted adapter contains modules_to_save, which could not be deleted. For this to work, PEFT version "
+                    f">= {min_version_delete_adapter} is required."
+                )
+
+        if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_delete_adapter):
+            from peft.functional import delete_adapter
+        else:
+            delete_adapter = old_delete_adapter
 
         if isinstance(adapter_names, str):
             adapter_names = [adapter_names]
@@ -603,16 +602,9 @@ def delete_adapter(self, adapter_names: Union[list[str], str]) -> None:
                f"The following adapter(s) are not present and cannot be deleted: {', '.join(missing_adapters)}"
            )
 
-        for adapter_name in adapter_names:
-            for module in self.modules():
-                if isinstance(module, BaseTunerLayer):
-                    if hasattr(module, "delete_adapter"):
-                        module.delete_adapter(adapter_name)
-                    else:
-                        raise ValueError(
-                            "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
-                        )
-
+        prefixes = [f"{self.peft_config[adapter_name].peft_type.value.lower()}_" for adapter_name in adapter_names]
+        for adapter_name, prefix in zip(adapter_names, prefixes):
+            delete_adapter(self, adapter_name=adapter_name, prefix=prefix)
             # For transformers integration - we need to pop the adapter from the config
             if getattr(self, "_hf_peft_config_loaded", False) and hasattr(self, "peft_config"):
                 self.peft_config.pop(adapter_name, None)
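One detail in the new dispatch worth spelling out: the prefix handed to delete_adapter is derived from the adapter config's peft_type. A small standalone sketch of just that derivation, using only public PEFT config attributes (the print output is illustrative):

```py
# How the prefix passed to delete_adapter is built in the diff above.
# For LoRA, peft_type.value is "LORA", so the resulting prefix is "lora_".
from peft import LoraConfig

config = LoraConfig()
prefix = f"{config.peft_type.value.lower()}_"
print(prefix)  # -> lora_
```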

tests/peft_integration/test_peft_integration.py

Lines changed: 63 additions & 3 deletions
@@ -20,6 +20,7 @@
 from datasets import Dataset, DatasetDict
 from huggingface_hub import hf_hub_download
 from packaging import version
+from torch import nn
 
 from transformers import (
     AutoModelForCausalLM,
@@ -337,11 +338,9 @@ def test_peft_add_multi_adapter(self):
 
         model.set_adapter("default")
         self.assertTrue(model.active_adapters() == ["default"])
-        self.assertTrue(model.active_adapter() == "default")
 
         model.set_adapter("adapter-2")
         self.assertTrue(model.active_adapters() == ["adapter-2"])
-        self.assertTrue(model.active_adapter() == "adapter-2")
 
         # Logits comparison
         self.assertFalse(
@@ -351,7 +350,6 @@ def test_peft_add_multi_adapter(self):
 
         model.set_adapter(["adapter-2", "default"])
         self.assertTrue(model.active_adapters() == ["adapter-2", "default"])
-        self.assertTrue(model.active_adapter() == "adapter-2")
 
         logits_adapter_mixed = model(dummy_input)
         self.assertFalse(
@@ -429,6 +427,68 @@ def test_delete_adapter(self):
         self.assertNotIn("adapter_1", model.peft_config)
         self.assertIn("adapter_2", model.peft_config)
 
+    def test_delete_adapter_with_modules_to_save(self):
+        """
+        Ensure that modules_to_save is accounted for when deleting an adapter.
+        """
+        min_version_delete_adapter = "0.18.0"
+        if version.parse(importlib.metadata.version("peft")) < version.parse(min_version_delete_adapter):
+            self.skipTest("Correctly deleting modules_to_save only works with PEFT >= 0.18.0")
+
+        from peft import LoraConfig
+
+        # the test assumes a specific model architecture, so only test this one:
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        peft_config = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
+        model.add_adapter(peft_config, adapter_name="adapter_1")
+
+        # sanity checks
+        self.assertIn("adapter_1", model.peft_config)
+        self.assertNotIsInstance(model.lm_head, nn.Linear)  # a ModulesToSaveWrapper
+        self.assertTrue(hasattr(model.lm_head, "modules_to_save"))
+        self.assertTrue("adapter_1" in model.lm_head.modules_to_save)
+
+        # now delete the adapter
+        model.delete_adapter("adapter_1")
+        self.assertFalse(hasattr(model, "peft_config"))
+        self.assertFalse("adapter_1" in model.lm_head.modules_to_save)
+        self.assertFalse(model.lm_head.modules_to_save)  # i.e. empty ModuleDict
+
+    def test_delete_adapter_with_modules_to_save_old_peft_warns(self):
+        """
+        When PEFT < 0.18.0 is being used, modules_to_save are not deleted but the user should get a warning.
+        """
+        from peft import LoraConfig
+
+        peft_ge_018 = version.parse(importlib.metadata.version("peft")) >= version.parse("0.18.0")
+        logger = logging.get_logger("transformers.integrations.peft")
+        warn_msg = "The deleted adapter contains modules_to_save"
+        # the test assumes a specific model architecture, so only test this one:
+        model_id = "hf-internal-testing/tiny-random-OPTForCausalLM"
+
+        # first a sanity check: when there is no modules_to_save, there is also no warning
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        peft_config_0 = LoraConfig(init_lora_weights=False)
+        model.add_adapter(peft_config_0, adapter_name="adapter_1")
+        with CaptureLogger(logger) as cl:
+            model.delete_adapter("adapter_1")
+        assert warn_msg not in cl.out
+
+        # now test a model with modules_to_save
+        model = AutoModelForCausalLM.from_pretrained(model_id).to(torch_device)
+        peft_config_1 = LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"])
+        model.add_adapter(peft_config_1, adapter_name="adapter_1")
+        with CaptureLogger(logger) as cl:
+            model.delete_adapter("adapter_1")
+
+        if peft_ge_018:
+            self.assertTrue("adapter_1" not in model.lm_head.modules_to_save)
+            assert warn_msg not in cl.out
+        else:
+            self.assertTrue("adapter_1" in model.lm_head.modules_to_save)
+            assert warn_msg in cl.out
+
     @require_torch_accelerator
     @require_bitsandbytes
     def test_peft_from_pretrained_kwargs(self):
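The assertions in these tests lean on the shape of PEFT's ModulesToSaveWrapper. As a standalone illustration, using the same tiny checkpoint as the tests and assuming current PEFT behavior for the wrapper's attributes:

```py
# What the tests' assertNotIsInstance / modules_to_save checks are inspecting.
from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-OPTForCausalLM")
model.add_adapter(LoraConfig(init_lora_weights=False, modules_to_save=["lm_head"]), adapter_name="adapter_1")

# lm_head is no longer a plain nn.Linear: PEFT wraps it and stores a trainable
# per-adapter copy in a ModuleDict keyed by adapter name.
print(type(model.lm_head).__name__)         # ModulesToSaveWrapper
print(list(model.lm_head.modules_to_save))  # ['adapter_1']
```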
