FIX Bug with handling of active adapters (#1659)
There was a bug in some models, such as IA3 and LoHa, where calling
set_adapter would not correctly update the active_adapter. This is now
fixed.

Note that this is not about the active_adapter attribute on PeftModel or
layers, which are handled separately.

This PR also ensures that LoraModel, IA3Model, etc. consistently use
self.active_adapters, not self.active_adapter. The latter should be
treated more like a private attribute, but it is not renamed here for
backwards compatibility.
BenjaminBossan authored Apr 17, 2024
1 parent 56773b9 commit ed865e2
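
Before the per-file diffs, a minimal sketch of the behavior this commit fixes (the base model id and the "other" adapter name are illustrative assumptions, mirroring the new test added below):

# Sketch only: assumes a small causal LM and IA3, one of the affected tuners.
from transformers import AutoModelForCausalLM
from peft import IA3Config, get_peft_model

base = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
config = IA3Config(task_type="CAUSAL_LM")
model = get_peft_model(base, config)  # creates and activates the "default" adapter
assert model.active_adapters == ["default"]

model.add_adapter("other", config)    # "other" is added but not activated
assert model.active_adapters == ["default"]

# Before this fix, tuners such as IA3 or LoHa did not update the tuner model's
# active_adapter here, so the newly selected adapter was not reported as active.
model.set_adapter("other")
assert model.active_adapters == ["other"]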
Showing 8 changed files with 40 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/model.py
@@ -134,7 +134,7 @@ def _create_and_replace(
         # If it is not an AdaLoraLayer, create a new module, else update it with new adapters
         if not isinstance(target, AdaLoraLayer):
             new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
-            if adapter_name != self.active_adapter:
+            if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
3 changes: 2 additions & 1 deletion src/peft/tuners/boft/model.py
@@ -116,7 +116,7 @@ def _create_and_replace(
         # If it is not a BOFTLayer, create a new module, else update it with new adapters
         if not isinstance(target, BOFTLayer):
             new_module = self._create_new_module(boft_config, adapter_name, target, **kwargs)
-            if adapter_name != self.active_adapter:
+            if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
@@ -244,6 +244,7 @@ def set_adapter(self, adapter_name):
                     warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                     module.unmerge()
                 module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name

     @staticmethod
     def _prepare_adapter_config(peft_config, model_config):
3 changes: 2 additions & 1 deletion src/peft/tuners/ia3/model.py
@@ -179,7 +179,7 @@ def _create_and_replace(
             )
         else:
             new_module = self._create_new_module(ia3_config, adapter_name, target, **kwargs)
-            if adapter_name != self.active_adapter:
+            if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
@@ -277,6 +277,7 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
                     warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                     module.unmerge()
                 module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name

     def _prepare_adapter_config(self, peft_config, model_config):
         if peft_config.target_modules is None:
2 changes: 1 addition & 1 deletion src/peft/tuners/lora/model.py
@@ -218,7 +218,7 @@ def _create_and_replace(
             )
         else:
             new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
-            if adapter_name != self.active_adapter:
+            if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
1 change: 1 addition & 0 deletions src/peft/tuners/lycoris_utils.py
@@ -404,6 +404,7 @@ def set_adapter(self, adapter_name: str | list[str]) -> None:
                     warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                     module.unmerge()
                 module.set_adapter(adapter_name)
+        self.active_adapter = adapter_name

     def delete_adapter(self, adapter_name: str) -> None:
         """
2 changes: 1 addition & 1 deletion src/peft/tuners/poly/model.py
@@ -57,7 +57,7 @@ def _create_and_replace(
                 adapter_name,
                 target,
             )
-            if adapter_name != self.active_adapter:
+            if adapter_name not in self.active_adapters:
                 # adding an additional adapter: it is not automatically trainable
                 new_module.requires_grad_(False)
             self._replace_module(parent, target_name, new_module, target)
4 changes: 2 additions & 2 deletions src/peft/tuners/tuners_utils.py
@@ -161,7 +161,7 @@ def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]],
             # user is adding a dict of PeftConfigs
             self.peft_config.update(peft_config)

-        self.active_adapter = adapter_name
+        self.active_adapter: str | list[str] = adapter_name
         self.inject_adapter(self.model, adapter_name)

         # Copy the peft_config in the injected model.
@@ -477,7 +477,7 @@ def disable_adapters(self) -> bool:
         return self._disable_adapters

     @property
-    def active_adapter(self) -> str:
+    def active_adapter(self) -> str | list[str]:
         # use a property to ensure that active_adapter is not set directly, instead use the set_adapter method
         return self._active_adapter
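
The comment in the hunk above refers to the read-only property guard on the layer side. A minimal sketch of that pattern (a simplified stand-in, not the actual BaseTunerLayer; the class name is illustrative):

from __future__ import annotations

class LayerSketch:
    # Simplified stand-in for BaseTunerLayer's property-based guard.
    def __init__(self) -> None:
        self._active_adapter: str | list[str] = "default"

    @property
    def active_adapter(self) -> str | list[str]:
        # Read-only: no setter is defined, so `layer.active_adapter = ...`
        # raises AttributeError instead of silently desyncing state.
        return self._active_adapter

    def set_adapter(self, adapter_names: str | list[str]) -> None:
        # The supported way to change which adapter(s) are active.
        self._active_adapter = adapter_names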
30 changes: 30 additions & 0 deletions tests/test_custom_models.py
@@ -945,6 +945,36 @@ def run_with_disable(config_kwargs, bias):
             # This is bad, there was a warning about the bias when there should not have been any.
             self.fail("There should be no warning when bias is set to 'none'")

+    @parameterized.expand(TEST_CASES)
+    def test_active_adapter(self, test_name, model_id, config_cls, config_kwargs):
+        model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
+        config = config_cls(
+            base_model_name_or_path=model_id,
+            **config_kwargs,
+        )
+        model = get_peft_model(model, config)
+        assert model.active_adapters == ["default"]
+        assert model.active_adapter == "default"
+
+        # at this stage, "default" is still the active adapter, "other" is disabled
+        model.add_adapter("other", config)
+        assert model.active_adapters == ["default"]
+        assert model.active_adapter == "default"
+
+        # set "other" as the active adapter
+        model.set_adapter("other")
+        assert model.active_adapters == ["other"]
+        assert model.active_adapter == "other"
+
+        # set both adapters as active
+        # Note: on the PeftModel, there cannot be multiple active adapters, so we have to go through
+        # model.base_model instead
+        model.base_model.set_adapter(["default", "other"])
+        # model.active_adapters works, as it delegates to the base_model
+        assert model.active_adapters == ["default", "other"]
+        # model.active_adapter would not work, thus we have to check the base_model directly
+        assert model.base_model.active_adapter == ["default", "other"]
+
     @parameterized.expand(TEST_CASES)
     def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
         self._test_delete_adapter(model_id, config_cls, config_kwargs)
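
As the new test documents, multiple simultaneously active adapters are only supported on the tuner model: PeftModel.set_adapter keeps a single active adapter, so activating several at once has to go through model.base_model.set_adapter([...]), after which model.active_adapters (which delegates to the base model) reflects the full list.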
