
Commit

Refactor deletion of adapters
Description

Responsibility for deleting an adapter is now transferred to the adapter layer
instead of the adapter model. This makes it easier for users or other libraries
that don't use the adapter model to delete adapters.
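
For illustration, a minimal sketch of layer-level deletion, assuming a small custom torch module similar to the new custom-model tests; the MLP class and the adapter name "scratch" are made up for the example:

import torch.nn as nn
from peft import LoraConfig, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.lin1 = nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))

config = LoraConfig(target_modules=["lin0", "lin1"])
model = get_peft_model(MLP(), config)  # creates the "default" adapter
model.add_adapter("scratch", config)   # a second adapter that is removed again below

# Deletion now lives on the layer itself, so code that never touches the
# adapter model (e.g. LoraModel) can still clean up:
for module in model.modules():
    if isinstance(module, BaseTunerLayer):
        module.delete_adapter("scratch")

# Note that the layer-level call does not update model-level bookkeeping such
# as model.peft_config; the adapter model's delete_adapter still does that.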

Implementation

The code should now be more generic, relying less on hard-coded
attributes.

As a precaution, I also changed the type of adapter_layer_names from
list to tuple, as it should not be mutated.
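
To make the generic approach concrete, here is a simplified sketch of the attribute-driven lookup that the new BaseTunerLayer.delete_adapter relies on; the helper per_adapter_containers is hypothetical and only meant to show the idea:

from peft.tuners.tuners_utils import BaseTunerLayer

def per_adapter_containers(layer: BaseTunerLayer, adapter_name: str) -> list[str]:
    """Names of the dict-like attributes on `layer` that hold `adapter_name`."""
    found = []
    # adapter_layer_names and other_param_names are class-level tuples, so they
    # cannot be mutated by accident; together they describe every per-adapter
    # container without hard-coding "lora_A", "scaling", etc.
    for attr_name in layer.adapter_layer_names + layer.other_param_names:
        attr = getattr(layer, attr_name)
        if hasattr(attr, "keys") and adapter_name in attr:
            found.append(attr_name)
    return found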

When deleting the active adapter, the logic for choosing the new active
adapter has been changed slightly to ensure consistency across layers.
In practice, this should rarely make a difference. An error is now
raised if the last remaining adapter is deleted.
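
A minimal sketch of the resulting model-level behavior, again assuming a tiny custom torch module; the adapter name "other" is illustrative:

import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(8, 8)

    def forward(self, x):
        return self.lin(x)

config = LoraConfig(target_modules=["lin"])
model = get_peft_model(Tiny(), config)  # adapter "default"
model.add_adapter("other", config)
model.set_adapter("other")

model.delete_adapter("other")  # "other" was active; warns and falls back to "default"
assert model.active_adapters == ["default"]

try:
    model.delete_adapter("default")  # deleting the last remaining adapter raises
except ValueError as exc:
    print(exc)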

Test coverage has been increased:

- Deleting adapters is now also tested for custom models.
- It is also tested for LoHa and LoKr, not only for LoRA.
- I added a test for deleting the non-active adapter.

Not implemented

I did not add adapter deletion to IA³, since it is included in huggingface#980. LMK
if it should be added here instead.
BenjaminBossan committed Nov 9, 2023
1 parent face67d commit 7313053
Showing 13 changed files with 158 additions and 67 deletions.
3 changes: 2 additions & 1 deletion src/peft/tuners/adalora/layer.py
@@ -26,7 +26,8 @@
class AdaLoraLayer(LoraLayer):
# List all names of layers that may contain adapter weights
# Note: ranknum doesn't need to be included as it is not an nn.Module
adapter_layer_names = ["lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B"]
adapter_layer_names = ("lora_A", "lora_B", "lora_E", "lora_embedding_A", "lora_embedding_B")
# other_param_names is defined in LoraLayer

def __init__(
self,
6 changes: 4 additions & 2 deletions src/peft/tuners/ia3/layer.py
@@ -25,8 +25,10 @@


class IA3Layer(BaseTunerLayer):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["ia3_l"]
# All names of layers that may contain adapter weights
adapter_layer_names = ("ia3_l",)
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("scaling",)

def __init__(
self,
5 changes: 3 additions & 2 deletions src/peft/tuners/loha/layer.py
@@ -24,8 +24,9 @@


class LoHaLayer(LycorisLayer, nn.Module):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2"]
# All names of layers that may contain adapter weights
adapter_layer_names = ("hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b", "hada_t1", "hada_t2")
# other_param_names is defined on parent class

def __init__(self):
LycorisLayer.__init__(self)
7 changes: 4 additions & 3 deletions src/peft/tuners/lokr/layer.py
@@ -24,16 +24,17 @@


class LoKrLayer(LycorisLayer, nn.Module):
# List all names of layers that may contain adapter weights
adapter_layer_names = [
# All names of layers that may contain adapter weights
adapter_layer_names = (
"lokr_w1",
"lokr_w1_a",
"lokr_w1_b",
"lokr_w2",
"lokr_w2_a",
"lokr_w2_b",
"lokr_t2",
]
)
# other_param_names is defined on parent class

def __init__(self):
LycorisLayer.__init__(self)
6 changes: 4 additions & 2 deletions src/peft/tuners/lora/layer.py
@@ -26,8 +26,10 @@


class LoraLayer(BaseTunerLayer):
# List all names of layers that may contain adapter weights
adapter_layer_names = ["lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B"]
# All names of layers that may contain (trainable) adapter weights
adapter_layer_names = ("lora_A", "lora_B", "lora_embedding_A", "lora_embedding_B")
# All names of other parameters that may contain adapter-related parameters
other_param_names = ("r", "lora_alpha", "scaling", "lora_dropout")

def __init__(self, in_features: int, out_features: int, **kwargs):
self.r = {}
27 changes: 7 additions & 20 deletions src/peft/tuners/lora/model.py
@@ -661,29 +661,16 @@ def delete_adapter(self, adapter_name: str):
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LoraLayer):
for attr in [
"r",
"lora_alpha",
"scaling",
"lora_A",
"lora_B",
"lora_embedding_A",
"lora_embedding_B",
"lora_dropout",
]:
if adapter_name in getattr(target, attr):
getattr(target, attr).pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]

if new_adapter:
self.active_adapter = new_adapter

def merge_and_unload(self, progressbar: bool = False, safe_merge: bool = False):
r"""
20 changes: 9 additions & 11 deletions src/peft/tuners/lycoris_utils.py
@@ -62,6 +62,8 @@ class LycorisLayer(BaseTunerLayer, nn.Module):
r"""
A base layer for LyCORIS like adapters
"""
# adapter_layer_names needs to be defined on the child class
other_param_names = ("r", "alpha", "scaling", "rank_dropout", "module_dropout")

def __init__(self):
self.r = {}
@@ -391,17 +393,13 @@ def delete_adapter(self, adapter_name: str):
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
new_adapter = None
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, LycorisLayer):
for attr in target.adapter_layer_names:
if adapter_name in getattr(target, attr):
getattr(target, attr).pop(adapter_name)
if adapter_name in target.active_adapters:
resetting_active_adapter = (
list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default"
)
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. "
)
target.set_adapter(resetting_active_adapter)
target.delete_adapter(adapter_name)
if new_adapter is None:
new_adapter = target.active_adapters[:]

if new_adapter:
self.active_adapter = new_adapter
55 changes: 53 additions & 2 deletions src/peft/tuners/tuners_utils.py
@@ -16,6 +16,7 @@

import logging
import re
import warnings
from abc import ABC, abstractmethod
from typing import Any, Union

@@ -272,8 +273,10 @@ class BaseTunerLayer(ABC):
"""
active_adapter = None

# List all names of layers that may contain adapter weights
adapter_layer_names: list[str] = []
# All names of layers that may contain adapter (trainable) weights
adapter_layer_names: tuple[str] = ()
# All names of other parameters that may contain adapter-related parameters
other_param_names: tuple[str] = ()

# indicates whether all adapters should be disabled
_disable_adapters: bool = False
@@ -351,6 +354,54 @@ def set_adapter(self, adapter_names: str | list[str]):

self._active_adapter = adapter_names

def _all_available_adapter_names(self) -> list[str]:
"""Return a sorted list of all available adapter names"""
adapter_names = set()
for name in self.adapter_layer_names + self.other_param_names:
# we check each possible attribute and if it's a dict or ModuleDict, we assume that the keys are the adapter
# names
attr = getattr(self, name)
if hasattr(attr, "keys"):
adapter_names.update(attr.keys())
return sorted(adapter_names)

def delete_adapter(self, adapter_name: str) -> None:
"""
Delete an adapter from the layer
This should be called on all adapter layers, or else we will get an inconsistent state.
This method will also set a new active adapter if the deleted adapter was an active adapter. It is important that
the new adapter is chosen in a deterministic way, so that the same adapter is chosen on all layers.
Args:
adapter_name (`str`): The name of the adapter to delete
"""
for attr in self.adapter_layer_names + self.other_param_names:
if adapter_name in getattr(self, attr):
del getattr(self, attr)[adapter_name]

if adapter_name in self.active_adapters:
# choose a new active adapter
active_adapters = self.active_adapters[:]
active_adapters.remove(adapter_name)
if active_adapters:
self.set_adapter(active_adapters)
else:
# no active adapters left, set a new default adapter
# here we get the list of all existing adapter names and choose the first one
remaining_adapters = self._all_available_adapter_names()
if not remaining_adapters:
raise ValueError("You tried to delete the only adapter in the model, this is not possible.")

new_active_adapter = remaining_adapters[0]
warnings.warn(
f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to "
f"{new_active_adapter}."
)
self.set_adapter(remaining_adapters[0])


def check_target_module_exists(config, key: str) -> bool | re.Match[str] | None:
"""A helper method to check if the passed module's key name matches any of the target modules in the adapter_config.
8 changes: 8 additions & 0 deletions tests/test_custom_models.py
@@ -681,6 +681,14 @@ def run_with_disable(config_kwargs, bias):
# This is bad, there was a warning about the bias when there should not have been any.
self.fail("There should be no warning when bias is set to 'none'")

@parameterized.expand(TEST_CASES)
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(TEST_CASES)
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -154,6 +154,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -125,6 +125,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_feature_extraction_models.py
@@ -146,6 +146,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_delete_inactive_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_inactive_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
76 changes: 52 additions & 24 deletions tests/testing_common.py
@@ -29,6 +29,7 @@
IA3Config,
LoraConfig,
PeftModel,
PeftType,
PrefixTuningConfig,
PromptEncoderConfig,
PromptLearningConfig,
@@ -815,42 +816,69 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwar
self.assertIsNotNone(param.grad)

def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
if issubclass(config_cls, AdaLoraConfig):
# AdaLora does not support adding more than 1 adapter
supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
# IA3 does not support deleting adapters yet, but it just needs to be added
# AdaLora does not support multiple adapters
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if config.peft_type not in supported_peft_types:
return

model = self.transformers_class.from_pretrained(model_id)
if isinstance(config.target_modules, str):
# TODO this should be doable
self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

adapter_to_delete = "delete_me"
model = get_peft_model(model, config)
model.add_adapter(adapter_to_delete, config)
model.set_adapter(adapter_to_delete)
model = model.to(self.torch_device)
model.delete_adapter(adapter_to_delete)
self.assertFalse(adapter_to_delete in model.peft_config)
self.assertEqual(model.active_adapters, ["default"])

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
for attr in attributes_to_check:
self.assertFalse(adapter_to_delete in getattr(target, attr))

def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
# same as test_delete_adapter, but this time an inactive adapter is deleted
supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
# IA3 does not support deleting adapters yet, but it just needs to be added
# AdaLora does not support multiple adapters
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)
if config.peft_type not in supported_peft_types:
return

model = self.transformers_class.from_pretrained(model_id)
if isinstance(config.target_modules, str):
# TODO this should be doable
self.skipTest("Multiple adapters cannot currently be added when target_modules is a string.")

adapter_to_delete = "delete_me"
model = get_peft_model(model, config)
model.add_adapter(adapter_to_delete, config)
model.set_adapter(adapter_to_delete)
# "delete_me" is added but not activated
model = model.to(self.torch_device)
model.delete_adapter(adapter_to_delete)
self.assertFalse(adapter_to_delete in model.peft_config)
self.assertEqual(model.active_adapters, ["default"])

if config.peft_type not in ("LORA"):
with self.assertRaises(AttributeError):
model.delete_adapter(adapter_to_delete)
else:
model.delete_adapter(adapter_to_delete)
self.assertFalse(adapter_to_delete in model.peft_config)
key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for attr in [
"r",
"lora_alpha",
"scaling",
"lora_A",
"lora_B",
"lora_embedding_A",
"lora_embedding_B",
"lora_dropout",
]:
self.assertFalse(adapter_to_delete in getattr(target, attr))
key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
for attr in attributes_to_check:
self.assertFalse(adapter_to_delete in getattr(target, attr))

def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
