Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with unloading double wrapped modules #1490

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/peft/tuners/ia3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,13 @@ def _unload_and_optionally_merge(
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)

return self.model

Expand Down
8 changes: 7 additions & 1 deletion src/peft/tuners/lora/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,13 @@ def _unload_and_optionally_merge(
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)

return self.model

Expand Down
8 changes: 7 additions & 1 deletion src/peft/tuners/lycoris_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,13 @@ def _unload_and_optionally_merge(
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)

return self.model

Expand Down
8 changes: 7 additions & 1 deletion src/peft/tuners/mixed/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,13 @@ def merge_recursively(module):
self._replace_module(parent, target_name, target.get_base_layer(), target)
elif isinstance(target, ModulesToSaveWrapper):
# save any additional trainable modules part of `modules_to_save`
setattr(parent, target_name, target.modules_to_save[target.active_adapter])
new_module = target.modules_to_save[target.active_adapter]
if hasattr(new_module, "base_layer"):
# check if the module is itself a tuner layer
if merge:
new_module.merge(safe_merge=safe_merge, adapter_names=adapter_names)
new_module = new_module.get_base_layer()
setattr(parent, target_name, new_module)

return self.model

Expand Down
24 changes: 24 additions & 0 deletions tests/test_custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from peft import AdaLoraConfig, IA3Config, LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, PeftModel, get_peft_model
from peft.tuners.tuners_utils import BaseTunerLayer
from peft.utils import ModulesToSaveWrapper

from .testing_common import PeftCommonTester
from .testing_utils import get_state_dict
Expand Down Expand Up @@ -923,6 +924,29 @@ def test_adapter_name_makes_no_difference(self, config0):
assert torch.allclose(output_custom1, output_custom2)
assert torch.allclose(output_default, output_custom1)

@parameterized.expand(["merge_and_unload", "unload"])
def test_double_wrapping_merge_and_unload(self, method):
# see issue #1485
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained("hf-internal-testing/tiny-random-RobertaModel")
config = LoraConfig(task_type="TOKEN_CLS", target_modules="all-linear")
model = get_peft_model(model, config)

# first check that double-wrapping happened
# Note: this may get fixed in a future PR, in which case this test can be removed
assert isinstance(model.base_model.model.classifier, ModulesToSaveWrapper)
assert hasattr(model.base_model.model.classifier.original_module, "lora_A")
assert hasattr(model.base_model.model.classifier.modules_to_save.default, "lora_A")

# after unloading, despite double wrapping, the classifier module should be a normal nn.Linear layer
if method == "merge_and_unload":
unloaded = model.merge_and_unload()
else:
unloaded = model.unload()

assert isinstance(unloaded.classifier, nn.Linear)


class TestMultiRankAdapter(unittest.TestCase):
"""Tests related to multirank LoRA adapters"""
Expand Down
Loading