ENH: Warn when disabling adapters and bias != 'none' #741

Merged
24 changes: 23 additions & 1 deletion src/peft/tuners/lora.py
@@ -54,7 +54,9 @@ class LoraConfig(PeftConfig):
        lora_dropout (`float`): The dropout probability for Lora layers.
        fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out).
            For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.
        bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the
            corresponding biases will be updated during training. Be aware that this means that, even when disabling the
            adapters, the model will not produce the same output as the base model would have without adaptation.
        modules_to_save (`List[str]`): List of modules apart from LoRA layers to be set as trainable
            and saved in the final checkpoint.
        layers_to_transform (`Union[List[int],int]`):
@@ -400,7 +402,27 @@ def _set_adapter_layers(self, enabled=True):
    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def _get_active_adapter(self) -> str:
        active_adapter = None
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                active_adapter = module.active_adapter

        if active_adapter is None:
            raise ValueError(
                "Something went wrong, no active adapter could be found, please report the issue on GitHub"
            )
        return active_adapter

    def disable_adapter_layers(self):
        active_adapter = self._get_active_adapter()
        val = self.peft_config[active_adapter].bias
        if val != "none":
            msg = (
                f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same "
                "output as the base model would without adaptation."
            )
            warnings.warn(msg)
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
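For context, here is a minimal sketch (not part of this diff) of when the new warning fires. The `TinyModel` class, its dimensions, and the module name "lin" are made up for illustration; the PEFT calls mirror the ones used in the test below.

import warnings

from torch import nn

from peft import LoraConfig, get_peft_model


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = nn.Linear(10, 10)

    def forward(self, x):
        return self.lin(x)


# bias != 'none', so disabling the adapters is expected to trigger the new warning
config = LoraConfig(target_modules=["lin"], bias="all")
peft_model = get_peft_model(TinyModel(), config)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # disable_adapter() calls disable_adapter_layers() under the hood
    with peft_model.disable_adapter():
        pass

# expected: a UserWarning starting with "Careful, disabling adapter layers ..."
print([str(w.message) for w in caught])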
44 changes: 44 additions & 0 deletions tests/test_custom_models.py
@@ -264,3 +264,47 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs):

        self.assertFalse(torch.allclose(outputs_before, outputs_after))
        self.assertTrue(torch.allclose(outputs_before, outputs_disabled))

    @parameterized.expand(TEST_CASES)
    def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs):
        # When training biases in lora, disabling adapters does not reset the biases, so the output is not what users
        # might expect. Therefore, a warning should be given.

        # Note: We test only with custom models since they run really fast. There is really no point in testing the
        # same thing with decoder, encoder_decoder, etc.

        def run_with_disable(config_kwargs, bias):
            config_kwargs = config_kwargs.copy()
            config_kwargs["bias"] = bias
            model = self.transformers_class.from_pretrained(model_id).to(self.torch_device)
            config = config_cls(
                base_model_name_or_path=model_id,
                **config_kwargs,
            )
            peft_model = get_peft_model(model, config)
            with peft_model.disable_adapter():
                pass  # there is nothing to be done

        # check that bias=all and bias=lora_only give a warning with the correct message
        msg_start = "Careful, disabling adapter layers with bias configured to be"
        with self.assertWarns(UserWarning, msg=msg_start):
            run_with_disable(config_kwargs, bias="lora_only")
        with self.assertWarns(UserWarning, msg=msg_start):
            run_with_disable(config_kwargs, bias="all")

        # For bias="none", there should be no warning. Unfortunately, AFAIK unittest has no option to assert that no
        # warning was given; therefore, we check that unittest raises an AssertionError when we assert that a warning was given.
        bias_warning_was_given = False
        try:
            with self.assertWarns(UserWarning) as cm:
                run_with_disable(config_kwargs, bias="none")
            # if we get here, it means there was no AssertionError, i.e. there are warnings -- let's check that they
            # are not related to the bias setting
            if any(warning.message.args[0].startswith(msg_start) for warning in cm.warnings):
                bias_warning_was_given = True
        except AssertionError:
            # This is good, there was an AssertionError, i.e. there was no warning
            pass
        if bias_warning_was_given:
            # This is bad, there was a warning about the bias when there should not have been any.
            self.fail("There should be no warning when bias is set to 'none'")
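As an aside, a possible alternative (not used in this PR) for asserting that no bias-related warning is raised is to record warnings explicitly instead of going through assertWarns. A sketch, reusing the run_with_disable helper and msg_start string defined in the test above:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    run_with_disable(config_kwargs, bias="none")

# keep only warnings that match the bias message and assert there are none
bias_warnings = [w for w in caught if str(w.message).startswith(msg_start)]
assert not bias_warnings, "There should be no warning when bias is set to 'none'"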