
Delete IA3 adapter #1153

Merged · 1 commit · Nov 20, 2023
src/peft/tuners/ia3/layer.py (1 addition, 4 deletions)

@@ -27,12 +27,9 @@
 class IA3Layer(BaseTunerLayer):
     # All names of layers that may contain adapter weights
     adapter_layer_names = ("ia3_l",)
-    # All names of other parameters that may contain adapter-related parameters
-    other_layer_names = ("scaling",)

     def __init__(self, base_layer: nn.Module, is_feedforward: bool, **kwargs) -> None:
         self.base_layer = base_layer
-        self.scaling = {}
Comment on lines -30 to -35 — Contributor (author):
As far as I understand, we do not use self.scaling and $(IA)^3$ does not have this parameter.

Member:
Right, thanks for cleaning that up.

         self.ia3_l = nn.ParameterDict({})
         # Mark the weight as unmerged
         self._disable_adapters = False
@@ -294,7 +291,7 @@ def unmerge(self) -> None:
                 base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data)

     def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
-        previous_dtype = x.dtype
+        dtype = previous_dtype = x.dtype
BenjaminBossan marked this conversation as resolved.

         if self.disable_adapters:
             if self.merged:
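For readers unfamiliar with the parameter in question: (IA)^3 learns one rescaling vector per adapted layer, stored in the ia3_l ParameterDict, and multiplies it into the layer's activations, so there is no separate scaling parameter to track. A minimal sketch of that behavior, assuming a simplified linear layer (illustrative only, not the actual PEFT implementation):

```python
import torch
import torch.nn as nn


class SimplifiedIA3Linear(nn.Module):
    """Illustrative (IA)^3 wrapper: one learned vector rescales activations."""

    def __init__(self, base_layer: nn.Linear, is_feedforward: bool):
        super().__init__()
        self.base_layer = base_layer
        self.is_feedforward = is_feedforward
        # Feedforward modules rescale the input, so the vector matches in_features;
        # attention modules rescale the output, so it matches out_features.
        num = base_layer.in_features if is_feedforward else base_layer.out_features
        # Initialized to ones, i.e. identity scaling before training.
        self.ia3_l = nn.ParameterDict({"default": nn.Parameter(torch.ones(num))})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scale = self.ia3_l["default"]
        if self.is_feedforward:
            return self.base_layer(x * scale)
        return self.base_layer(x) * scale
```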
src/peft/tuners/ia3/model.py (28 additions, 5 deletions)

@@ -78,6 +78,8 @@ class IA3Model(BaseTuner):
         - **peft_config** ([`ia3Config`]): The configuration of the (IA)^3 model.
     """

+    prefix: str = "ia3_"
alexrs marked this conversation as resolved.

     def __init__(self, model, config, adapter_name):
         super().__init__(model, config, adapter_name)

@@ -146,7 +148,7 @@ def _check_target_module_exists(ia3_config, key):

     def _mark_only_adapters_as_trainable(self) -> None:
         for n, p in self.model.named_parameters():
-            if "ia3_" not in n:
+            if self.prefix not in n:
                 p.requires_grad = False

     def _create_and_replace(

@@ -202,8 +204,7 @@ def _check_target_module_feedforward(ia3_config, key) -> bool:
         is_feedforward = any(key.endswith(target_key) for target_key in ia3_config.feedforward_modules)
         return is_feedforward

-    @staticmethod
-    def _replace_module(parent, child_name, new_module, child):
+    def _replace_module(self, parent, child_name, new_module, child):
         setattr(parent, child_name, new_module)

         # child layer wraps the original module, unpack it

@@ -225,7 +226,7 @@ def _replace_module(parent, child_name, new_module, child):

         # dispatch to correct device
         for name, module in new_module.named_modules():
-            if "ia3_" in name:
+            if self.prefix in name:
                 module.to(child.weight.device)

     def __getattr__(self, name: str):

@@ -298,7 +299,7 @@ def _unload_and_optionally_merge(
         if getattr(self.model, "is_loaded_in_4bit", False):
             raise ValueError("Cannot merge ia3 layers when the model is loaded in 4-bit mode")

-        key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
         for key in key_list:
             try:
                 parent, target, target_name = _get_submodules(self.model, key)

@@ -348,3 +349,25 @@ def unload(self):
         model.
         """
         return self._unload_and_optionally_merge(merge=False)
+
+    def delete_adapter(self, adapter_name: str):
+        """
+        Deletes an existing adapter.
+
+        Args:
+            adapter_name (str): Name of the adapter to be deleted.
+        """
+        if adapter_name not in self.peft_config:
+            raise ValueError(f"Adapter {adapter_name} does not exist")
alexrs marked this conversation as resolved.
+        del self.peft_config[adapter_name]
+
+        key_list = [key for key, _ in self.model.named_modules() if self.prefix not in key]
+        new_adapter = None
+        for key in key_list:
+            _, target, _ = _get_submodules(self.model, key)
+            if isinstance(target, IA3Layer):
+                target.delete_adapter(adapter_name)
+                if new_adapter is None:
+                    new_adapter = target.active_adapters[:]
+
+        self.active_adapter = new_adapter or []
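With this in place, deleting an (IA)^3 adapter works the same way as for LoRA, LoHa, and LoKr. A minimal usage sketch (the base model and adapter names here are illustrative, not taken from the PR):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import IA3Config, TaskType, get_peft_model

# Illustrative base model; any architecture with default IA3 target modules works.
base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model = get_peft_model(base, IA3Config(task_type=TaskType.SEQ_2_SEQ_LM), adapter_name="default")
model.add_adapter("experiment", IA3Config(task_type=TaskType.SEQ_2_SEQ_LM))

# Remove the second adapter again; "default" remains the active adapter.
model.delete_adapter("experiment")
assert "experiment" not in model.peft_config
assert model.active_adapters == ["default"]
```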
tests/testing_common.py (5 additions, 5 deletions)

@@ -885,7 +885,7 @@ def _test_training_prompt_learning_tasks(self, model_id, config_cls, config_kwargs):
             self.assertIsNotNone(param.grad)

     def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
-        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
         # IA3 does not support deleting adapters yet, but it just needs to be added
         # AdaLora does not support multiple adapters
         config = config_cls(
@@ -905,7 +905,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
         self.assertFalse(adapter_to_delete in model.peft_config)
         self.assertEqual(model.active_adapters, ["default"])

-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
@@ -923,7 +923,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):

     def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
         # same as test_delete_adapter, but this time an inactive adapter is deleted
-        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR]
+        supported_peft_types = [PeftType.LORA, PeftType.LOHA, PeftType.LOKR, PeftType.IA3]
         # IA3 does not support deleting adapters yet, but it just needs to be added
         # AdaLora does not support multiple adapters
         config = config_cls(
@@ -943,7 +943,7 @@ def _test_delete_inactive_adapter(self, model_id, config_cls, config_kwargs):
         self.assertFalse(adapter_to_delete in model.peft_config)
         self.assertEqual(model.active_adapters, ["default"])

-        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        key_list = [key for key, _ in model.named_modules()]
         for key in key_list:
             _, target, _ = _get_submodules(model, key)
             attributes_to_check = getattr(target, "adapter_layer_names", []) + getattr(target, "other_param_names", [])
Expand Down Expand Up @@ -1038,7 +1038,7 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
key_list = [key for key, _ in model.named_modules()]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
Expand Down