Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexrs committed Oct 1, 2023
1 parent 7f73e6d commit 583dcb7
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 51 deletions.
1 change: 0 additions & 1 deletion src/peft/tuners/ia3/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def _linear(self, input: torch.Tensor) -> torch.Tensor:

def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = previous_dtype = x.dtype

if self.disable_adapters:
if self.merged:
self.unmerge()
Expand Down
21 changes: 19 additions & 2 deletions src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def delete_adapter(self, adapter_name: str):
raise ValueError(f"Adapter {adapter_name} does not exist")
del self.peft_config[adapter_name]

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
Expand Down Expand Up @@ -354,6 +354,8 @@ def add_weighted_adapter(self, adapters, weights, adapter_name):

target_modules_type = type(self.peft_config[adapters[0]].target_modules)
new_target_modules = set() if target_modules_type == list else ""
feedforward_modules_type = type(self.peft_config[adapters[0]].feedforward_modules)
new_feedforward_modules = set() if feedforward_modules_type == list else ""
for adapter in adapters:
if type(self.peft_config[adapter].target_modules) != target_modules_type:
raise ValueError(
Expand All @@ -365,17 +367,32 @@ def add_weighted_adapter(self, adapters, weights, adapter_name):
else:
new_target_modules += f"({self.peft_config[adapter].target_modules})|"

if type(self.peft_config[adapter].feedforward_modules) != feedforward_modules_type:
raise ValueError(
"all adapter configs should follow the same feedforward modules type. "
"Combining adapters with `feedforward_modules` type being a mix of list and string is not supported."
)
if feedforward_modules_type == list:
new_feedforward_modules |= set(self.peft_config[adapter].feedforward_modules)
else:
new_feedforward_modules += f"({self.peft_config[adapter].feedforward_modules})|"

new_target_modules = list(new_target_modules) if target_modules_type == list else new_target_modules[:-1]
new_feedforward_modules = (
list(new_feedforward_modules) if target_modules_type == list else new_feedforward_modules[:-1]
)

self.peft_config[adapter_name] = replace(
self.peft_config[adapters[0]],
target_modules=new_target_modules,
feedforward_modules=new_feedforward_modules,
)
self.inject_adapter(self.model, adapter_name)

# Do we really need that?
_freeze_adapter(self.model, adapter_name)

key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
key_list = [key for key, _ in self.model.named_modules() if "ia3" not in key]
for key in key_list:
_, target, _ = _get_submodules(self.model, key)
if isinstance(target, IA3Layer):
Expand Down
130 changes: 82 additions & 48 deletions tests/testing_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
get_peft_model_state_dict,
prepare_model_for_int8_training,
)
from peft.tuners.ia3 import IA3Layer
from peft.tuners.lora import LoraLayer
from peft.utils import _get_submodules, infer_device

Expand Down Expand Up @@ -704,7 +705,7 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
model.set_adapter(adapter_to_delete)
model = model.to(self.torch_device)

if config.peft_type not in ("LORA"):
if config.peft_type not in ("LORA", "IA3"):
with self.assertRaises(AttributeError):
model.delete_adapter(adapter_to_delete)
else:
Expand All @@ -725,6 +726,8 @@ def _test_delete_adapter(self, model_id, config_cls, config_kwargs):
"lora_dropout",
]:
self.assertFalse(adapter_to_delete in getattr(target, attr))
if isinstance(target, IA3Layer):
self.assertFalse(adapter_to_delete in getattr(target, "ia3_l"))

def _test_unload_adapter(self, model_id, config_cls, config_kwargs):
model = self.transformers_class.from_pretrained(model_id)
Expand Down Expand Up @@ -764,70 +767,101 @@ def _test_weighted_combination_of_adapters(self, model_id, config_cls, config_kw
base_model_name_or_path=model_id,
**config_kwargs,
)
if not isinstance(config, (LoraConfig)):
if not isinstance(config, (LoraConfig)) or not isinstance(config, (IA3Config)):
return
model = get_peft_model(model, config, adapter_list[0])
model.add_adapter(adapter_list[1], config)
model.add_adapter(adapter_list[2], replace(config, r=20))
model = model.to(self.torch_device)

# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")
if isinstance(config, (LoraConfig)):
# test re-weighting single adapter
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")

# test svd re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting")
# test svd re-weighting with multiple adapters
model.add_weighted_adapter(adapter_list[1:], weight_list[1:], "multi_adapter_svd_reweighting")

# test cat re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
)

# test linear re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
)
# test cat re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:], weight_list[1:], "multi_adapter_cat_reweighting", combination_type="cat"
)

with self.assertRaises(ValueError):
# test linear re-weighting with multiple adapters
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_linear_reweighting_uneven_r",
combination_type="linear",
adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting", combination_type="linear"
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for adapter_name in new_adapters:
if "single" in adapter_name:
with self.assertRaises(ValueError):
model.add_weighted_adapter(
adapter_list[1:],
weight_list[1:],
"multi_adapter_linear_reweighting_uneven_r",
combination_type="linear",
)

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_svd_reweighting",
"multi_adapter_cat_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "lora" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, LoraLayer):
for adapter_name in new_adapters:
if "single" in adapter_name:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
self.assertTrue(
torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
)
elif "svd" in adapter_name:
self.assertTrue(target.r[adapter_name] == 20)
elif "linear" in adapter_name:
self.assertTrue(target.r[adapter_name] == 8)
elif "cat" in adapter_name:
self.assertTrue(target.r[adapter_name] == 28)

for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]

elif isinstance(config, (IA3Config)):
# single adapter re-weighting and multi adapter linear re-weighting
# Note: IA3 only supports linear re-weighting
model.add_weighted_adapter([adapter_list[0]], [weight_list[0]], "single_adapter_reweighting")
model.add_weighted_adapter(adapter_list[:2], weight_list[:2], "multi_adapter_linear_reweighting")

new_adapters = [
"single_adapter_reweighting",
"multi_adapter_linear_reweighting",
]
for new_adapter in new_adapters:
self.assertTrue(new_adapter in model.peft_config)

key_list = [key for key, _ in model.named_modules() if "ia3" not in key]
for key in key_list:
_, target, _ = _get_submodules(model, key)
if isinstance(target, IA3Layer):
for adapter_name in new_adapters:
new_delta_weight = target.get_delta_weight(adapter_name)
weighted_original_delta_weights = target.get_delta_weight(adapter_list[0]) * weight_list[0]
self.assertTrue(
torch.allclose(new_delta_weight, weighted_original_delta_weights, atol=1e-4, rtol=1e-4)
)
elif "svd" in adapter_name:
self.assertTrue(target.r[adapter_name] == 20)
elif "linear" in adapter_name:
self.assertTrue(target.r[adapter_name] == 8)
elif "cat" in adapter_name:
self.assertTrue(target.r[adapter_name] == 28)

for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]
for adapter_name in new_adapters:
# ensuring new adapters pass the forward loop
model.set_adapter(adapter_name)
dummy_input = self.prepare_inputs_for_testing()
model.eval()
_ = model(**dummy_input)[0]

def _test_disable_adapter(self, model_id, config_cls, config_kwargs):
task_type = config_kwargs.get("task_type")
Expand Down

0 comments on commit 583dcb7

Please sign in to comment.