ENH Delete IA3 adapters (#1153)
alexrs authored Nov 20, 2023
1 parent f1ecfa6 commit 8351331
Showing 3 changed files with 34 additions and 14 deletions.
5 changes: 1 addition & 4 deletions src/peft/tuners/ia3/layer.py
@@ -27,12 +27,9 @@
 class IA3Layer(BaseTunerLayer):
     # All names of layers that may contain adapter weights
     adapter_layer_names = ("ia3_l",)
-    # All names of other parameters that may contain adapter-related parameters
-    other_layer_names = ("scaling",)
 
     def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
         self.base_layer = base_layer
-        self.scaling = {}
         self.ia3_l = nn.ParameterDict({})
         # Mark the weight as unmerged
         self._disable_adapters = False
@@ -294,7 +291,7 @@ def unmerge(self) -> None:
                     base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)
 
     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
+        dtype = previous_dtype = x.dtype
 
         if self.disable_adapters:
             if self.merged:
33 changes: 28 additions & 5 deletions src/peft/tuners/ia3/model.py
@@ -78,6 +78,8 @@ class IA3Model(BaseTuner):
         - **peft_config** ([`ia3Config`]): The configuration of the (IA)^3 model.
     """
 
+    prefix: str = "ia3_"
+
     def __init__(self, model, config, adapter_name):
         super().__init__(model, config, adapter_name)
 
@@ -146,7 +148,7 @@ def _check_target_module_exists(ia3_config, key):
 
     def _mark_only_adapters_as_trainable(self) -> None:
         for n, p in self.model.named_parameters():
-            if "ia3_" not in n:
+            if self.prefix not in n:
                 p.requires_grad = False
 
     def _create_and_replace(
@@ -202,8 +204,7 @@ def _check_target_module_feedforward(ia3_config, key) -> bool:
         is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules)
         return is_feedforward
 
-    @staticmethod
-    def _replace_module(parent, child_name, new_module, child):
+    def _replace_module(self, parent, child_name, new_module, child):
         setattr(parent, child_name, new_module)
 
         # child layer wraps the original module, unpack it
@@ -225,7 +226,7 @@ def _replace_module(parent, child_name, new_module, child):
 
         # dispatch to correct device
         for name, module in new_module.named_modules():
-            if "ia3_" in name:
+            if self.prefix in name:
                 module.to(child.weight.device)
 
     def __getattr__(self, name: str):
@@ -298,7 +299,7 @@ def _unload_and_optionally_merge(
         if getattr(self.model, "is_loaded_in_4bit", False):
             raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")
 
-        key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
         for key in key_list:
             try:
                 parent, target, target_name = _get_submodules(self.model, key)
@@ -348,3 +349,25 @@ def unload(self):
         model.
         """
         return self._unload_and_optionally_merge(merge=False)
+
+    def delete_adapter(self, adapter_name: str):
+        """
+        Deletes an existing adapter.
+
+        Args:
+            adapter_name (str): Name of the adapter to be deleted.
+        """
+        if adapter_name not in self.peft_config:
+            raise ValueError(f"Adapter {adapter_name} does not exist")
+        del self.peft_config[adapter_name]
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
+        for key in key_list:
+            _, target, _ = _get_submodules(self.model, key)
+            if isinstance(target, IA3Layer):
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        self.active_adapter = new_adapter or []
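
For context, a minimal usage sketch of the new method (not part of the commit; the base model id, target modules, and adapter names below are placeholder assumptions):

# Hypothetical example: attach two (IA)^3 adapters to a base model, then
# delete the inactive one. delete_adapter is reached on the PeftModel via
# attribute forwarding to the underlying IA3Model.
from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model id
config = IA3Config(task_type="CAUSAL_LM", target_modules=["c_attn"], feedforward_modules=[])
model = get_peft_model(base, config)        # active adapter: "default"
model.add_adapter("experiment", config)     # second adapter, not active

model.delete_adapter("experiment")          # removes its config and ia3_l parameters
print(model.active_adapters)                # expected: ["default"]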
10 changes: 5 additions & 5 deletions tests/testing_common.py
@@ -885,7 +885,7 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs):
             self.assertIsNotNone(param.grad)
 
     def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
-        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
         # IA3 does not support deleting adapters yet, but it just needs to be added
         # AdaLora does not support multiple adapters
         config = config_cls(
@@ -905,7 +905,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
         self.assertFalse(adapter_to_delete in model.peft_config)
         self.assertEqual(model.active_adapters, ["default"])
 
-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
Expand All @@ -923,7 +923,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):

def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
# same as test_delete_adapter, but this time an inactive adapter is deleted
supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
# IA3 does not support deleting adapters yet, but it just needs to be added
# AdaLora does not support multiple adapters
config = config_cls(
@@ -943,7 +943,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
         self.assertFalse(adapter_to_delete in model.peft_config)
         self.assertEqual(model.active_adapters, ["default"])
 
-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
@@ -1038,7 +1038,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kwargs):
         for new_adapter in new_adapters:
             self.assertTrue(new_adapter in model.peft_config)
 
-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             if isinstance(target, LoraLayer):
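
The updated checks no longer filter module names by "lora"; they walk every submodule and inspect whatever adapter containers the layer declares. A standalone sketch of that pattern (the helper name is hypothetical, not part of the test suite):

# Hypothetical helper mirroring the test's deletion check: after deleting an
# adapter, none of a tuner layer's declared containers should still hold it.
from peft.tuners.tuners_utils import BaseTunerLayer

def assert_adapter_fully_removed(model, adapter_name: str) -> None:
    for module in model.modules():
        if not isinstance(module, BaseTunerLayer):
            continue
        attrs = list(getattr(module, "adapter_layer_names", [])) + list(getattr(module, "other_param_names", []))
        for attr in attrs:
            container = getattr(module, attr)  # dict-like, keyed by adapter name
            assert adapter_name not in container, f"{adapter_name} still present in {attr}"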
