FIX: Disabling adapter works with modules_to_save #736

Merged
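For context, this PR makes the `disable_adapter()` context manager also bypass modules wrapped via `modules_to_save`. Below is a minimal usage sketch of the fixed behavior; the `MLP` model, its layer names, and the random perturbation standing in for a training step are illustrative assumptions mirroring the test cases added in this PR, not code taken from the PR itself.

```python
# Hedged sketch, not part of the PR: shows that disabling adapters now also
# bypasses the trainable copies created for `modules_to_save`.
import torch
from peft import LoraConfig, get_peft_model


class MLP(torch.nn.Module):
    # Illustrative model; layer names mirror the "Vanilla MLP" test cases.
    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(10, 20)
        self.lin1 = torch.nn.Linear(20, 2)

    def forward(self, x):
        return self.lin1(self.lin0(x))


base = MLP()
x = torch.randn(5, 10)
with torch.no_grad():
    out_base = base(x)

config = LoraConfig(target_modules=["lin0"], modules_to_save=["lin1"])
model = get_peft_model(base, config)

# Stand-in for a training step: perturb the trainable parameters so that the
# adapted output differs from the base output.
with torch.no_grad():
    for name, param in model.named_parameters():
        if ("lora_" in name) or ("modules_to_save" in name):
            param.add_(torch.randn_like(param))

with torch.no_grad():
    out_adapted = model(x)
with torch.no_grad(), model.disable_adapter():
    out_disabled = model(x)

assert not torch.allclose(out_base, out_adapted)
# Before this fix, the modules_to_save copy of lin1 was still used inside
# disable_adapter(), so the next assertion could fail.
assert torch.allclose(out_base, out_disabled)
```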
2 changes: 2 additions & 0 deletions src/peft/tuners/ia3.py
@@ -303,6 +303,8 @@ def _set_adapter_layers(self, enabled=True):
         for module in self.model.modules():
             if isinstance(module, IA3Layer):
                 module.disable_adapters = False if enabled else True
+            elif isinstance(module, ModulesToSaveWrapper):
+                module.disable_adapters = False if enabled else True
 
     def enable_adapter_layers(self):
         self._set_adapter_layers(enabled=True)
2 changes: 2 additions & 0 deletions src/peft/tuners/lora.py
@@ -398,6 +398,8 @@ def _set_adapter_layers(self, enabled=True):
         for module in self.model.modules():
             if isinstance(module, LoraLayer):
                 module.disable_adapters = False if enabled else True
+            elif isinstance(module, ModulesToSaveWrapper):
+                module.disable_adapters = False if enabled else True
 
     def enable_adapter_layers(self):
         self._set_adapter_layers(enabled=True)
3 changes: 2 additions & 1 deletion src/peft/utils/other.py
@@ -135,12 +135,13 @@ def __init__(self, module_to_save, adapter_name):
         self.modules_to_save = torch.nn.ModuleDict({})
         self.update(adapter_name)
         self.active_adapter = adapter_name
+        self.disable_adapters = False
 
     def update(self, adapter_name):
         self.modules_to_save.update(torch.nn.ModuleDict({adapter_name: copy.deepcopy(self.original_module)}))
 
     def forward(self, *args, **kwargs):
-        if self.active_adapter not in self.modules_to_save:
+        if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
             return self.original_module(*args, **kwargs)
         return self.modules_to_save[self.active_adapter](*args, **kwargs)
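Taken together, the two changes above make the wrapper route calls back to the frozen original module whenever adapters are disabled. A condensed, self-contained sketch of that routing (a toy class for illustration, not the actual PEFT implementation):

```python
import copy

import torch


class ToyModulesToSaveWrapper(torch.nn.Module):
    # Toy illustration of the routing performed by ModulesToSaveWrapper
    # after this PR; not the real class.
    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict(
            {adapter_name: copy.deepcopy(module_to_save)}
        )
        self.active_adapter = adapter_name
        self.disable_adapters = False  # the new flag introduced by this PR

    def forward(self, *args, **kwargs):
        # When adapters are disabled (or no copy exists for the active
        # adapter), fall back to the untouched original module.
        if self.disable_adapters or (self.active_adapter not in self.modules_to_save):
            return self.original_module(*args, **kwargs)
        return self.modules_to_save[self.active_adapter](*args, **kwargs)
```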
18 changes: 17 additions & 1 deletion tests/test_custom_models.py
@@ -35,6 +35,17 @@
     ("Vanilla MLP 2", "MLP", LoraConfig, {"target_modules": ["lin0"]}),
     ("Vanilla MLP 3", "MLP", LoraConfig, {"target_modules": ["lin1"]}),
     ("Vanilla MLP 4", "MLP", LoraConfig, {"target_modules": ["lin0", "lin1"]}),
+    ("Vanilla MLP 5", "MLP", LoraConfig, {"target_modules": ["lin0"], "modules_to_save": ["lin1"]}),
+    (
+        "Vanilla MLP 6",
+        "MLP",
+        LoraConfig,
+        {
+            "target_modules": ["lin0"],
+            "lora_alpha": 4,
+            "lora_dropout": 0.1,
+        },
+    ),
     ("Embedding + transformers Conv1D 1", "EmbConv1D", LoraConfig, {"target_modules": ["conv1d"]}),
     ("Embedding + transformers Conv1D 2", "EmbConv1D", LoraConfig, {"target_modules": ["emb"]}),
     ("Embedding + transformers Conv1D 3", "EmbConv1D", LoraConfig, {"target_modules": ["emb", "conv1d"]}),
@@ -227,7 +238,8 @@ def test_only_params_are_updated(self, test_name, model_id, config_cls, config_kwargs):
         self.assertEqual(params_before.keys(), params_after.keys())
         for name, param_before in params_before.items():
             param_after = params_after[name]
-            if "lora_" in name:
+            if ("lora_" in name) or ("modules_to_save" in name):
+                # target_modules and modules_to_save _are_ updated
                 self.assertFalse(torch.allclose(param_before, param_after, atol=tol, rtol=tol))
             else:
                 self.assertTrue(torch.allclose(param_before, param_after, atol=tol, rtol=tol))
@@ -262,8 +274,12 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):
         with model.disable_adapter():
             outputs_disabled = model(**X)
 
+        # check that after leaving the disable_adapter context, everything is enabled again
+        outputs_enabled_after_disable = model(**X)
+
         self.assertFalse(torch.allclose(outputs_before, outputs_after))
         self.assertTrue(torch.allclose(outputs_before, outputs_disabled))
+        self.assertTrue(torch.allclose(outputs_after, outputs_enabled_after_disable))
 
     @parameterized.expand(TEST_CASES)
     def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs):
2 changes: 1 addition & 1 deletion tests/testing_common.py
@@ -467,7 +467,7 @@ def _test_training(self, model_id, config_cls, config_kwargs):
         loss.backward()
         parameter_prefix = "ia3" if config_cls == IA3Config else "lora"
         for n, param in model.named_parameters():
-            if parameter_prefix in n:
+            if (parameter_prefix in n) or ("modules_to_save" in n):
                 self.assertIsNotNone(param.grad)
             else:
                 self.assertIsNone(param.grad)