From 149aeb3508c05f7f78f54dab582f3c0e3b6605e5 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Fri, 29 Sep 2023 13:25:38 +0200 Subject: [PATCH 1/8] Add add_weighted_adapter to ia3 --- src/peft/tuners/ia3/model.py | 93 +++++++++++++++++++++++++++++++++++- tests/test_custom_models.py | 2 +- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index f4a80e8cbc..56e91accc1 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -15,7 +15,7 @@ import re import warnings -from dataclasses import asdict +from dataclasses import asdict, replace from enum import Enum import torch @@ -28,6 +28,8 @@ TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, _get_submodules, + _is_valid_match, + _freeze_adapter, ) from .layer import Conv2d, IA3Layer, Linear @@ -336,3 +338,92 @@ def merge_and_unload(self, safe_merge: bool = False): self._replace_module(parent, target_name, new_module, target) return self.model + + def delete_adapter(self, adapter_name: str): + """ + Deletes an existing adapter. + + Args: + adapter_name (str): Name of the adapter to be deleted. + """ + if adapter_name not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter_name} does not exist") + del self.peft_config[adapter_name] + + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, IA3Layer): + for attr in [ + "ia3_l", + "scaling", + ]: + if adapter_name in getattr(target, attr): + getattr(target, attr).pop(adapter_name) + if adapter_name in target.active_adapters: + resetting_active_adapter = ( + list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default" + ) + warnings.warn( + f"Adapter {adapter_name} was active which is now deleted. Setting active adapter to {resetting_active_adapter}. " + ) + target.set_adapter(resetting_active_adapter) + + + def add_weighted_adapter(self, adapters, weights, adapter_name): + """ + This method adds a new adapter by merging the given adapters with the given weights. + + Args: + adapters (`list`): + List of adapter names to be merged. + weights (`list`): + List of weights for each adapter. + adapter_name (`str`): + Name of the new adapter. + """ + if adapter_name in list(self.peft_config.keys()): + return + for adapter in adapters: + if adapter not in list(self.peft_config.keys()): + raise ValueError(f"Adapter {adapter} does not exist") + + target_modules_type = type(self.peft_config[adapters[0]].target_modules) + new_target_modules = set() if target_modules_type == list else "" + for adapter in adapters: + if type(self.peft_config[adapter].target_modules) != target_modules_type: + raise ValueError( + "all adapter configs should follow the same target modules type. " + "Combining adapters with `target_modules` type being a mix of list and string is not supported." + ) + if target_modules_type == list: + new_target_modules |= set(self.peft_config[adapter].target_modules) + else: + new_target_modules += f"({self.peft_config[adapter].target_modules})|" + + new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1] + self.peft_config[adapter_name] = replace( + self.peft_config[adapters[0]], + target_modules=new_target_modules, + ) + self.inject_adapter(self.model, adapter_name) + + # Do we really need that? 
+ _freeze_adapter(self.model, adapter_name) + + key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + for key in key_list: + _, target, _ = _get_submodules(self.model, key) + if isinstance(target, IA3Layer): + if adapter_name in target.ia3_l: + target_ia3_l = target.ia3_l[adapter_name] + else: + continue + + target_ia3_l.data = target_ia3_l.data * 0.0 + for adapter, weight in zip(adapters, weights): + if adapter in target.ia3_l: + current_adapter_ia3_l = target.ia3_l[adapter] + else: + continue + target_ia3_l.data += current_adapter_ia3_l.data * weight * target.scaling[adapter] diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 51568919d0..5e0865d19e 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -816,7 +816,7 @@ def test_multiple_active_adapters_forward( self.assertFalse(torch.allclose(adapter_1_output, combined_output, atol=1e-5)) self.assertFalse(torch.allclose(adapter_2_output, combined_output, atol=1e-5)) - if tuner_method == "lora": + if tuner_method == "lora" or tuner_method == "ia3": # create a weighted adapter combining both adapters and check that # its output is same as setting multiple active adapters peft_model.add_weighted_adapter( From 31237bdadcf488158ab282e8c8f6d5721aef8c6b Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Fri, 29 Sep 2023 14:41:41 +0200 Subject: [PATCH 2/8] Remove unused scaling from IA3 --- src/peft/tuners/ia3/layer.py | 1 - src/peft/tuners/ia3/model.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index c35f3d875c..c40b2b3242 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -34,7 +34,6 @@ def __init__( out_features: int, is_feedforward: bool, ): - self.scaling = {} self.ia3_l = nn.ParameterDict({}) # Mark the weight as unmerged self._disable_adapters = False diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 56e91accc1..393553ee3c 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -426,4 +426,4 @@ def add_weighted_adapter(self, adapters, weights, adapter_name): current_adapter_ia3_l = target.ia3_l[adapter] else: continue - target_ia3_l.data += current_adapter_ia3_l.data * weight * target.scaling[adapter] + target_ia3_l.data += current_adapter_ia3_l.data * weight From eeb1b381841d370ec49ae91a99eef57b61a21668 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Fri, 29 Sep 2023 16:11:54 +0200 Subject: [PATCH 3/8] Improve delete_adapter --- src/peft/tuners/ia3/model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 393553ee3c..d13c6ee795 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -354,12 +354,8 @@ def delete_adapter(self, adapter_name: str): for key in key_list: _, target, _ = _get_submodules(self.model, key) if isinstance(target, IA3Layer): - for attr in [ - "ia3_l", - "scaling", - ]: - if adapter_name in getattr(target, attr): - getattr(target, attr).pop(adapter_name) + if adapter_name in target.ia3_l: + target.ia3_l.pop(adapter_name) if adapter_name in target.active_adapters: resetting_active_adapter = ( list(self.peft_config.keys())[0] if len(self.peft_config) > 0 else "default" From 0e55abc01022a7ef1e8291cd08f672e4a5d043f7 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Fri, 29 Sep 2023 18:09:18 +0200 Subject: [PATCH 4/8] fix 
style --- src/peft/tuners/ia3/model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index d13c6ee795..eb78787784 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -27,9 +27,9 @@ TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING, TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, ModulesToSaveWrapper, + _freeze_adapter, _get_submodules, _is_valid_match, - _freeze_adapter, ) from .layer import Conv2d, IA3Layer, Linear @@ -365,7 +365,6 @@ def delete_adapter(self, adapter_name: str): ) target.set_adapter(resetting_active_adapter) - def add_weighted_adapter(self, adapters, weights, adapter_name): """ This method adds a new adapter by merging the given adapters with the given weights. From e98e14803faff869ead3fc8157b33c745bc1a2d9 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Sat, 30 Sep 2023 16:04:32 +0200 Subject: [PATCH 5/8] revert test --- tests/test_custom_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 5e0865d19e..51568919d0 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -816,7 +816,7 @@ def test_multiple_active_adapters_forward( self.assertFalse(torch.allclose(adapter_1_output, combined_output, atol=1e-5)) self.assertFalse(torch.allclose(adapter_2_output, combined_output, atol=1e-5)) - if tuner_method == "lora" or tuner_method == "ia3": + if tuner_method == "lora": # create a weighted adapter combining both adapters and check that # its output is same as setting multiple active adapters peft_model.add_weighted_adapter( From 46b3ad896482f8c201b38de054b58b8e866cd129 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Sun, 1 Oct 2023 18:17:38 +0200 Subject: [PATCH 6/8] add tests --- src/peft/tuners/ia3/layer.py | 1 - src/peft/tuners/ia3/model.py | 21 +++++- tests/testing_common.py | 130 ++++++++++++++++++++++------------- 3 files changed, 101 insertions(+), 51 deletions(-) diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index c40b2b3242..af6264ea47 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -165,7 +165,6 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor: dtype = previous_dtype = x.dtype - if self.disable_adapters: if self.merged: self.unmerge() diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index eb78787784..349a3af248 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -350,7 +350,7 @@ def delete_adapter(self, adapter_name: str): raise ValueError(f"Adapter {adapter_name} does not exist") del self.peft_config[adapter_name] - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key] for key in key_list: _, target, _ = _get_submodules(self.model, key) if isinstance(target, IA3Layer): @@ -385,6 +385,8 @@ def add_weighted_adapter(self, adapters, weights, adapter_name): target_modules_type = type(self.peft_config[adapters[0]].target_modules) new_target_modules = set() if target_modules_type == list else "" + feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules) + new_feedforward_modules = set() if feedforward_modules_type == list else "" for adapter in adapters: if type(self.peft_config[adapter].target_modules) != 
target_modules_type: raise ValueError( @@ -396,17 +398,32 @@ def add_weighted_adapter(self, adapters, weights, adapter_name): else: new_target_modules += f"({self.peft_config[adapter].target_modules})|" + if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type: + raise ValueError( + "all adapter configs should follow the same feedforward modules type. " + "Combining adapters with `feedforward_modules` type being a mix of list and string is not supported." + ) + if feedforward_modules_type == list: + new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules) + else: + new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|" + new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1] + new_feedforward_modules = ( + list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1] + ) + self.peft_config[adapter_name] = replace( self.peft_config[adapters[0]], target_modules=new_target_modules, + feedforward_modules=new_feedforward_modules, ) self.inject_adapter(self.model, adapter_name) # Do we really need that? _freeze_adapter(self.model, adapter_name) - key_list = [key for key, _ in self.model.named_modules() if "lora" not in key] + key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key] for key in key_list: _, target, _ = _get_submodules(self.model, key) if isinstance(target, IA3Layer): diff --git a/tests/testing_common.py b/tests/testing_common.py index a69ac1f86d..0939c1837f 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -37,6 +37,7 @@ get_peft_model_state_dict, prepare_model_for_int8_training, ) +from peft.tuners.ia3 import IA3Layer from peft.tuners.lora import LoraLayer from peft.utils import _get_submodules, infer_device @@ -810,7 +811,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): model.set_adapter(adapter_to_delete) model = model.to(self.torch_device) - if config.peft_type not in ("LORA"): + if config.peft_type not in ("LORA", "IA3"): with self.assertRaises(AttributeError): model.delete_adapter(adapter_to_delete) else: @@ -831,6 +832,8 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): "lora_dropout", ]: self.assertFalse(adapter_to_delete in getattr(target, attr)) + if isinstance(target, IA3Layer): + self.assertFalse(adapter_to_delete in getattr(target, "ia3_l")) def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) @@ -870,70 +873,101 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw base_model_name_or_path=model_id, **config_kwargs, ) - if not isinstance(config, (LoraConfig)): + if not isinstance(config, (LoraConfig)) or not isinstance(config, (IA3Config)): return model = get_peft_model(model, config, adapter_list[0]) model.add_adapter(adapter_list[1], config) model.add_adapter(adapter_list[2], replace(config, r=20)) model = model.to(self.torch_device) - # test re-weighting single adapter - model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting") + if isinstance(config, (LoraConfig)): + # test re-weighting single adapter + model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting") - # test svd re-weighting with multiple adapters - model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting") + # test svd re-weighting with multiple 
adapters + model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting") - # test cat re-weighting with multiple adapters - model.add_weighted_adapter( - adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat" - ) - - # test linear re-weighting with multiple adapters - model.add_weighted_adapter( - adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear" - ) + # test cat re-weighting with multiple adapters + model.add_weighted_adapter( + adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat" + ) - with self.assertRaises(ValueError): + # test linear re-weighting with multiple adapters model.add_weighted_adapter( - adapter_list[1:], - weight_list[1:], - "multi_adapter_linear_reweighting_uneven_r", - combination_type="linear", + adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear" ) - new_adapters = [ - "single_adapter_reweighting", - "multi_adapter_svd_reweighting", - "multi_adapter_cat_reweighting", - "multi_adapter_linear_reweighting", - ] - for new_adapter in new_adapters: - self.assertTrue(new_adapter in model.peft_config) - - key_list = [key for key, _ in model.named_modules() if "lora" not in key] - for key in key_list: - _, target, _ = _get_submodules(model, key) - if isinstance(target, LoraLayer): - for adapter_name in new_adapters: - if "single" in adapter_name: + with self.assertRaises(ValueError): + model.add_weighted_adapter( + adapter_list[1:], + weight_list[1:], + "multi_adapter_linear_reweighting_uneven_r", + combination_type="linear", + ) + + new_adapters = [ + "single_adapter_reweighting", + "multi_adapter_svd_reweighting", + "multi_adapter_cat_reweighting", + "multi_adapter_linear_reweighting", + ] + for new_adapter in new_adapters: + self.assertTrue(new_adapter in model.peft_config) + + key_list = [key for key, _ in model.named_modules() if "lora" not in key] + for key in key_list: + _, target, _ = _get_submodules(model, key) + if isinstance(target, LoraLayer): + for adapter_name in new_adapters: + if "single" in adapter_name: + new_delta_weight = target.get_delta_weight(adapter_name) + weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0] + self.assertTrue( + torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4) + ) + elif "svd" in adapter_name: + self.assertTrue(target.r[adapter_name] == 20) + elif "linear" in adapter_name: + self.assertTrue(target.r[adapter_name] == 8) + elif "cat" in adapter_name: + self.assertTrue(target.r[adapter_name] == 28) + + for adapter_name in new_adapters: + # ensuring new adapters pass the forward loop + model.set_adapter(adapter_name) + dummy_input = self.prepare_inputs_for_testing() + model.eval() + _ = model(**dummy_input)[0] + + elif isinstance(config, (IA3Config)): + # single adapter re-weighting and multi adapter linear re-weighting + # Note: IA3 only supports linear re-weighting + model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting") + model.add_weighted_adapter(adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting") + + new_adapters = [ + "single_adapter_reweighting", + "multi_adapter_linear_reweighting", + ] + for new_adapter in new_adapters: + self.assertTrue(new_adapter in model.peft_config) + + key_list = [key for key, _ in model.named_modules() if "ia3" not in key] + for key in key_list: + _, target, _ = 
_get_submodules(model, key) + if isinstance(target, IA3Layer): + for adapter_name in new_adapters: new_delta_weight = target.get_delta_weight(adapter_name) weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0] self.assertTrue( torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4) ) - elif "svd" in adapter_name: - self.assertTrue(target.r[adapter_name] == 20) - elif "linear" in adapter_name: - self.assertTrue(target.r[adapter_name] == 8) - elif "cat" in adapter_name: - self.assertTrue(target.r[adapter_name] == 28) - - for adapter_name in new_adapters: - # ensuring new adapters pass the forward loop - model.set_adapter(adapter_name) - dummy_input = self.prepare_inputs_for_testing() - model.eval() - _ = model(**dummy_input)[0] + for adapter_name in new_adapters: + # ensuring new adapters pass the forward loop + model.set_adapter(adapter_name) + dummy_input = self.prepare_inputs_for_testing() + model.eval() + _ = model(**dummy_input)[0] def _test_disable_adapter(self, model_id, config_cls, config_kwargs): task_type = config_kwargs.get("task_type") From 7f567a0f4928ed59bf9b83f24a463f6120bd2a31 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Mon, 9 Oct 2023 11:58:16 +0200 Subject: [PATCH 7/8] rebase on main --- src/peft/tuners/ia3/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 349a3af248..659ee043bd 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -29,7 +29,6 @@ ModulesToSaveWrapper, _freeze_adapter, _get_submodules, - _is_valid_match, ) from .layer import Conv2d, IA3Layer, Linear From 43483b7acd41940898a8266a17531cd696b1c5f5 Mon Sep 17 00:00:00 2001 From: Alejandro Rodriguez Salamanca Date: Tue, 10 Oct 2023 14:54:46 +0200 Subject: [PATCH 8/8] Feedback from PR --- src/peft/tuners/ia3/model.py | 72 ++++++++++++++++++------------------ tests/testing_common.py | 28 ++++++-------- 2 files changed, 48 insertions(+), 52 deletions(-) diff --git a/src/peft/tuners/ia3/model.py b/src/peft/tuners/ia3/model.py index 659ee043bd..3738a19b61 100644 --- a/src/peft/tuners/ia3/model.py +++ b/src/peft/tuners/ia3/model.py @@ -12,11 +12,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +import operator import re import warnings from dataclasses import asdict, replace from enum import Enum +from functools import reduce import torch from transformers.pytorch_utils import Conv1D @@ -279,13 +280,15 @@ def _prepare_adapter_config(self, peft_config, model_config): if peft_config.target_modules is None: if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING: raise ValueError("Please specify `target_modules` in `peft_config`") - peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]] + peft_config.target_modules = set( + TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]] + ) if peft_config.feedforward_modules is None: if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING: raise ValueError("Please specify `feedforward_modules` in `peft_config`") - peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[ - model_config["model_type"] - ] + peft_config.feedforward_modules = set( + TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]] + ) return peft_config def merge_and_unload(self, safe_merge: bool = False): @@ -345,7 +348,7 @@ def delete_adapter(self, adapter_name: str): Args: adapter_name (str): Name of the adapter to be deleted. """ - if adapter_name not in list(self.peft_config.keys()): + if adapter_name not in self.peft_config: raise ValueError(f"Adapter {adapter_name} does not exist") del self.peft_config[adapter_name] @@ -364,6 +367,32 @@ def delete_adapter(self, adapter_name: str): ) target.set_adapter(resetting_active_adapter) + def _new_modules(self, adapters, module_type): + """ + Args: + adapters (`list`): + List of adapter names to be merged. + module_type (`str`): + Type of the module to be merged. + """ + module_types = [type(getattr(self.peft_config[adapter], module_type)) for adapter in adapters] + if not module_types: + raise ValueError(f"Found no adapter matching the names in {adapters}") + if len(set(module_types)) > 1: + raise ValueError( + "all adapter configs should follow the same target modules type. " + f"Combining adapters with `{module_type}` type being a mix of list/set and string is not supported." + ) + if module_types[0] == str: + new_modules = "|".join(f"({getattr(self.peft_config[adapter], module_type)})" for adapter in adapters) + elif module_types[0] == set: + new_modules = reduce( + operator.or_, (getattr(self.peft_config[adapter], module_type) for adapter in adapters) + ) + else: + raise TypeError(f"Invalid type {module_types[0]} found in {module_type}") + return new_modules + def add_weighted_adapter(self, adapters, weights, adapter_name): """ This method adds a new adapter by merging the given adapters with the given weights. @@ -382,35 +411,8 @@ def add_weighted_adapter(self, adapters, weights, adapter_name): if adapter not in list(self.peft_config.keys()): raise ValueError(f"Adapter {adapter} does not exist") - target_modules_type = type(self.peft_config[adapters[0]].target_modules) - new_target_modules = set() if target_modules_type == list else "" - feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules) - new_feedforward_modules = set() if feedforward_modules_type == list else "" - for adapter in adapters: - if type(self.peft_config[adapter].target_modules) != target_modules_type: - raise ValueError( - "all adapter configs should follow the same target modules type. 
" - "Combining adapters with `target_modules` type being a mix of list and string is not supported." - ) - if target_modules_type == list: - new_target_modules |= set(self.peft_config[adapter].target_modules) - else: - new_target_modules += f"({self.peft_config[adapter].target_modules})|" - - if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type: - raise ValueError( - "all adapter configs should follow the same feedforward modules type. " - "Combining adapters with `feedforward_modules` type being a mix of list and string is not supported." - ) - if feedforward_modules_type == list: - new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules) - else: - new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|" - - new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1] - new_feedforward_modules = ( - list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1] - ) + new_target_modules = self._new_modules(adapters, "target_modules") + new_feedforward_modules = self._new_modules(adapters, "feedforward_modules") self.peft_config[adapter_name] = replace( self.peft_config[adapters[0]], diff --git a/tests/testing_common.py b/tests/testing_common.py index 0939c1837f..bfe8920704 100644 --- a/tests/testing_common.py +++ b/tests/testing_common.py @@ -832,8 +832,8 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs): "lora_dropout", ]: self.assertFalse(adapter_to_delete in getattr(target, attr)) - if isinstance(target, IA3Layer): - self.assertFalse(adapter_to_delete in getattr(target, "ia3_l")) + elif isinstance(target, IA3Layer): + self.assertFalse(adapter_to_delete in target.ia3_l) def _test_unload_adapter(self, model_id, config_cls, config_kwargs): model = self.transformers_class.from_pretrained(model_id) @@ -873,14 +873,14 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw base_model_name_or_path=model_id, **config_kwargs, ) - if not isinstance(config, (LoraConfig)) or not isinstance(config, (IA3Config)): + if not isinstance(config, (LoraConfig, IA3Config)): return - model = get_peft_model(model, config, adapter_list[0]) - model.add_adapter(adapter_list[1], config) - model.add_adapter(adapter_list[2], replace(config, r=20)) - model = model.to(self.torch_device) if isinstance(config, (LoraConfig)): + model = get_peft_model(model, config, adapter_list[0]) + model.add_adapter(adapter_list[1], config) + model.add_adapter(adapter_list[2], replace(config, r=20)) + model = model.to(self.torch_device) # test re-weighting single adapter model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting") @@ -940,6 +940,10 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw _ = model(**dummy_input)[0] elif isinstance(config, (IA3Config)): + model = get_peft_model(model, config, adapter_list[0]) + model.add_adapter(adapter_list[1], config) + model.add_adapter(adapter_list[2], config) + model = model.to(self.torch_device) # single adapter re-weighting and multi adapter linear re-weighting # Note: IA3 only supports linear re-weighting model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting") @@ -952,16 +956,6 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw for new_adapter in new_adapters: self.assertTrue(new_adapter in model.peft_config) - key_list = [key for key, _ in 
model.named_modules() if "ia3" not in key] - for key in key_list: - _, target, _ = _get_submodules(model, key) - if isinstance(target, IA3Layer): - for adapter_name in new_adapters: - new_delta_weight = target.get_delta_weight(adapter_name) - weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0] - self.assertTrue( - torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4) - ) for adapter_name in new_adapters: # ensuring new adapters pass the forward loop model.set_adapter(adapter_name)
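
Below is a minimal usage sketch of the `add_weighted_adapter` API that this series adds for (IA)^3 adapters; it is not part of the patches themselves. The base model name ("gpt2"), the adapter names, and the weights are placeholders chosen for illustration, and the example assumes a model type that has a default (IA)^3 target-module mapping.

    import torch
    from transformers import AutoModelForCausalLM

    from peft import IA3Config, get_peft_model

    # Hypothetical base model; any architecture with a default (IA)^3
    # target-module mapping should work here.
    base_model = AutoModelForCausalLM.from_pretrained("gpt2")

    config = IA3Config(task_type="CAUSAL_LM")
    model = get_peft_model(base_model, config, adapter_name="adapter_a")
    model.add_adapter("adapter_b", config)

    # Merge the two (IA)^3 adapters into a new one as a weighted sum of their
    # ia3_l vectors, as implemented by IA3Model.add_weighted_adapter above.
    model.add_weighted_adapter(
        adapters=["adapter_a", "adapter_b"],
        weights=[0.7, 0.3],
        adapter_name="merged",
    )
    model.set_adapter("merged")

    # Run a forward pass with the merged adapter active.
    input_ids = torch.tensor([[1, 2, 3]])
    with torch.no_grad():
        output = model(input_ids=input_ids)

Because (IA)^3 scales activations with a single learned vector per target module, the merge here is a plain linear combination of those vectors, which is why the series only supports linear re-weighting (no SVD or concatenation variants as in LoRA).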