Commit: Feedback from PR

alexrs committed Oct 10, 2023
1 parent 7f567a0 commit 43483b7
Showing 2 changed files with 48 additions and 52 deletions.
72 changes: 37 additions & 35 deletions src/peft/tuners/ia3/model.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import operator
import re
import warnings
from dataclasses import asdict, replace
from enum import Enum
from functools import reduce

import torch
from transformers.pytorch_utils import Conv1D
@@ -279,13 +280,15 @@ def _prepare_adapter_config(self, peft_config, model_config):
         if peft_config.target_modules is None:
             if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING:
                 raise ValueError("Please specify `target_modules` in `peft_config`")
-            peft_config.target_modules = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            peft_config.target_modules = set(
+                TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_config["model_type"]]
+            )
         if peft_config.feedforward_modules is None:
             if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING:
                 raise ValueError("Please specify `feedforward_modules` in `peft_config`")
-            peft_config.feedforward_modules = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[
-                model_config["model_type"]
-            ]
+            peft_config.feedforward_modules = set(
+                TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_config["model_type"]]
+            )
         return peft_config
 
     def merge_and_unload(self, safe_merge: bool = False):
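(Illustrative note, not part of the commit.) The hunk above normalizes the per-model-type defaults to Python sets, so that later merging logic can rely on set operations regardless of how the mapping stores its entries. A minimal standalone sketch of that normalization, using a made-up DEFAULT_TARGET_MODULES mapping and ToyConfig in place of the real PEFT objects:

from dataclasses import dataclass
from typing import Optional, Set, Union

# Hypothetical stand-in for TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING.
DEFAULT_TARGET_MODULES = {"t5": ["k", "v", "wo"], "gpt2": ["c_attn", "mlp.c_proj"]}

@dataclass
class ToyConfig:
    target_modules: Optional[Union[Set[str], str]] = None

def prepare_config(config: ToyConfig, model_type: str) -> ToyConfig:
    # Mirrors the hunk above: fall back to the per-model default and store it as a set.
    if config.target_modules is None:
        if model_type not in DEFAULT_TARGET_MODULES:
            raise ValueError("Please specify `target_modules`")
        config.target_modules = set(DEFAULT_TARGET_MODULES[model_type])
    return config

print(prepare_config(ToyConfig(), "t5").target_modules)  # {'k', 'v', 'wo'} (set order may vary)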
@@ -345,7 +348,7 @@ def delete_adapter(self, adapter_name: str):
         Args:
             adapter_name (str): Name of the adapter to be deleted.
         """
-        if adapter_name not in list(self.peft_config.keys()):
+        if adapter_name not in self.peft_config:
             raise ValueError(f"Adapter {adapter_name} does not exist")
         del self.peft_config[adapter_name]

@@ -364,6 +367,32 @@ def delete_adapter(self, adapter_name: str):
                     )
                     target.set_adapter(resetting_active_adapter)
 
+    def _new_modules(self, adapters, module_type):
+        """
+        Args:
+            adapters (`list`):
+                List of adapter names to be merged.
+            module_type (`str`):
+                Type of the module to be merged.
+        """
+        module_types = [type(getattr(self.peft_config[adapter], module_type)) for adapter in adapters]
+        if not module_types:
+            raise ValueError(f"Found no adapter matching the names in {adapters}")
+        if len(set(module_types)) > 1:
+            raise ValueError(
+                "all adapter configs should follow the same target modules type. "
+                f"Combining adapters with `{module_type}` type being a mix of list/set and string is not supported."
+            )
+        if module_types[0] == str:
+            new_modules = "|".join(f"({getattr(self.peft_config[adapter], module_type)})" for adapter in adapters)
+        elif module_types[0] == set:
+            new_modules = reduce(
+                operator.or_, (getattr(self.peft_config[adapter], module_type) for adapter in adapters)
+            )
+        else:
+            raise TypeError(f"Invalid type {module_types[0]} found in {module_type}")
+        return new_modules
+
     def add_weighted_adapter(self, adapters, weights, adapter_name):
         """
         This method adds a new adapter by merging the given adapters with the given weights.
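(Illustrative note, not part of the commit.) The new `_new_modules` helper merges the module specifications of the adapters being combined: string specs are joined into a single regex alternation, while set specs are unioned with `reduce(operator.or_, ...)`. A standalone sketch of the same two branches over toy configs (the dictionary below stands in for `self.peft_config`):

import operator
from functools import reduce

# Toy adapter configs; each entry stands in for a real IA3 config's `target_modules`.
peft_config = {
    "adapter_a": {"target_modules": {"k", "v"}},
    "adapter_b": {"target_modules": {"v", "wo"}},
}

def merge_modules(adapters, module_type="target_modules"):
    module_types = [type(peft_config[name][module_type]) for name in adapters]
    if len(set(module_types)) > 1:
        raise ValueError("Cannot mix set and string module specifications.")
    if module_types[0] is str:
        # Regex alternation: a module targeted by any adapter is targeted by the merge.
        return "|".join(f"({peft_config[name][module_type]})" for name in adapters)
    # Set union, as in the helper above.
    return reduce(operator.or_, (peft_config[name][module_type] for name in adapters))

print(merge_modules(["adapter_a", "adapter_b"]))  # {'k', 'v', 'wo'}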
@@ -382,35 +411,8 @@ def add_weighted_adapter(self, adapters, weights, adapter_name):
             if adapter not in list(self.peft_config.keys()):
                 raise ValueError(f"Adapter {adapter} does not exist")
 
-        target_modules_type = type(self.peft_config[adapters[0]].target_modules)
-        new_target_modules = set() if target_modules_type == list else ""
-        feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules)
-        new_feedforward_modules = set() if feedforward_modules_type == list else ""
-        for adapter in adapters:
-            if type(self.peft_config[adapter].target_modules) != target_modules_type:
-                raise ValueError(
-                    "all adapter configs should follow the same target modules type. "
-                    "Combining adapters with `target_modules` type being a mix of list and string is not supported."
-                )
-            if target_modules_type == list:
-                new_target_modules |= set(self.peft_config[adapter].target_modules)
-            else:
-                new_target_modules += f"({self.peft_config[adapter].target_modules})|"
-
-            if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type:
-                raise ValueError(
-                    "all adapter configs should follow the same feedforward modules type. "
-                    "Combining adapters with `feedforward_modules` type being a mix of list and string is not supported."
-                )
-            if feedforward_modules_type == list:
-                new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules)
-            else:
-                new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|"
-
-        new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1]
-        new_feedforward_modules = (
-            list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1]
-        )
+        new_target_modules = self._new_modules(adapters, "target_modules")
+        new_feedforward_modules = self._new_modules(adapters, "feedforward_modules")
 
         self.peft_config[adapter_name] = replace(
             self.peft_config[adapters[0]],
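(Illustrative note, not part of the commit.) With the refactor above, `add_weighted_adapter` now derives the merged `target_modules` and `feedforward_modules` through `_new_modules`, while its interface stays the same. A hedged usage sketch mirroring the test changes later in this commit; the `t5-small` checkpoint, adapter names, and weights are assumptions for illustration only:

from transformers import AutoModelForSeq2SeqLM
from peft import IA3Config, get_peft_model

base = AutoModelForSeq2SeqLM.from_pretrained("t5-small")  # assumed example checkpoint
config = IA3Config(task_type="SEQ_2_SEQ_LM")  # defaults filled in by _prepare_adapter_config

model = get_peft_model(base, config, "adapter_a")
model.add_adapter("adapter_b", config)

# Linear combination of the two IA3 adapters; the merged config's module sets are
# the union computed by _new_modules.
model.add_weighted_adapter(["adapter_a", "adapter_b"], [0.5, 0.5], "merged")
model.set_adapter("merged")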
28 changes: 11 additions & 17 deletions tests/testing_common.py
@@ -832,8 +832,8 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
                     "lora_dropout",
                 ]:
                     self.assertFalse(adapter_to_delete in getattr(target, attr))
-            if isinstance(target, IA3Layer):
-                self.assertFalse(adapter_to_delete in getattr(target, "ia3_l"))
+            elif isinstance(target, IA3Layer):
+                self.assertFalse(adapter_to_delete in target.ia3_l)
 
     def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
         model = self.transformers_class.from_pretrained(model_id)
@@ -873,14 +873,14 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
             base_model_name_or_path=model_id,
             **config_kwargs,
         )
-        if not isinstance(config, (LoraConfig)) or not isinstance(config, (IA3Config)):
+        if not isinstance(config, (LoraConfig, IA3Config)):
             return
-        model = get_peft_model(model, config, adapter_list[0])
-        model.add_adapter(adapter_list[1], config)
-        model.add_adapter(adapter_list[2], replace(config, r=20))
-        model = model.to(self.torch_device)
 
         if isinstance(config, (LoraConfig)):
+            model = get_peft_model(model, config, adapter_list[0])
+            model.add_adapter(adapter_list[1], config)
+            model.add_adapter(adapter_list[2], replace(config, r=20))
+            model = model.to(self.torch_device)
             # test re-weighting single adapter
             model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

@@ -940,6 +940,10 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
                 _ = model(**dummy_input)[0]
 
         elif isinstance(config, (IA3Config)):
+            model = get_peft_model(model, config, adapter_list[0])
+            model.add_adapter(adapter_list[1], config)
+            model.add_adapter(adapter_list[2], config)
+            model = model.to(self.torch_device)
             # single adapter re-weighting and multi adapter linear re-weighting
             # Note: IA3 only supports linear re-weighting
             model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")
@@ -952,16 +956,6 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
             for new_adapter in new_adapters:
                 self.assertTrue(new_adapter in model.peft_config)
 
-            key_list = [key for key, _ in model.named_modules() if "ia3" not in key]
-            for key in key_list:
-                _, target, _ = _get_submodules(model, key)
-                if isinstance(target, IA3Layer):
-                    for adapter_name in new_adapters:
-                        new_delta_weight = target.get_delta_weight(adapter_name)
-                        weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
-                        self.assertTrue(
-                            torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
-                        )
             for adapter_name in new_adapters:
                 # ensuring new adapters pass the forward loop
                 model.set_adapter(adapter_name)
