diff --git a/src/peft/tuners/lora.py b/src/peft/tuners/lora.py index 2440deadd5..6ddbc4716d 100644 --- a/src/peft/tuners/lora.py +++ b/src/peft/tuners/lora.py @@ -54,7 +54,9 @@ class LoraConfig(PeftConfig): lora_dropout (`float`): The dropout probability for Lora layers. fan_in_fan_out (`bool`): Set this to True if the layer to replace stores weight like (fan_in, fan_out). For example, gpt-2 uses `Conv1D` which stores weights like (fan_in, fan_out) and hence this should be set to `True`.: - bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only' + bias (`str`): Bias type for Lora. Can be 'none', 'all' or 'lora_only'. If 'all' or 'lora_only', the + corresponding biases will be updated during training. Be aware that this means that, even when disabling + the adapters, the model will not produce the same output as the base model would have without adaptation. modules_to_save (`List[str]`):List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. layers_to_transform (`Union[List[int],int]`): @@ -400,7 +402,27 @@ def _set_adapter_layers(self, enabled=True): def enable_adapter_layers(self): self._set_adapter_layers(enabled=True) + def _get_active_adapter(self) -> str: + active_adapter = None + for module in self.model.modules(): + if isinstance(module, LoraLayer): + active_adapter = module.active_adapter + + if active_adapter is None: + raise ValueError( + "Something went wrong, no active adapter could be found, please report the issue on GitHub" + ) + return active_adapter + def disable_adapter_layers(self): + active_adapter = self._get_active_adapter() + val = self.peft_config[active_adapter].bias + if val != "none": + msg = ( + f"Careful, disabling adapter layers with bias configured to be '{val}' does not produce the same " + "output as the the base model would without adaption." + ) + warnings.warn(msg) self._set_adapter_layers(enabled=False) def set_adapter(self, adapter_name): diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py index 8de1d6c80c..ac78286e89 100644 --- a/tests/test_custom_models.py +++ b/tests/test_custom_models.py @@ -264,3 +264,47 @@ def test_disable_adapters(self, test_name, model_id, config_cls, config_kwargs): self.assertFalse(torch.allclose(outputs_before, outputs_after)) self.assertTrue(torch.allclose(outputs_before, outputs_disabled)) + + @parameterized.expand(TEST_CASES) + def test_disable_adapter_with_bias_warns(self, test_name, model_id, config_cls, config_kwargs): + # When training biases in lora, disabling adapters does not reset the biases, so the output is not what users + # might expect. Therefore, a warning should be given. + + # Note: We test only with custom models since they run really fast. There is really no point in testing the same + # thing with decoder, encoder_decoder, etc. + + def run_with_disable(config_kwargs, bias): + config_kwargs = config_kwargs.copy() + config_kwargs["bias"] = bias + model = self.transformers_class.from_pretrained(model_id).to(self.torch_device) + config = config_cls( + base_model_name_or_path=model_id, + **config_kwargs, + ) + peft_model = get_peft_model(model, config) + with peft_model.disable_adapter(): + pass # there is nothing to be done + + # check that bias=all and bias=lora_only give a warning with the correct message + msg_start = "Careful, disabling adapter layers with bias configured to be" + with self.assertWarns(UserWarning, msg=msg_start): + run_with_disable(config_kwargs, bias="lora_only") + with self.assertWarns(UserWarning, msg=msg_start): + run_with_disable(config_kwargs, bias="all") + + # For bias=none, there is no warning. Unfortunately, AFAIK unittest has no option to assert that no warning is + # given, therefore, we check that the unittest gives us an AssertionError if we check for a warning + bias_warning_was_given = False + try: + with self.assertWarns(UserWarning) as cm: + run_with_disable(config_kwargs, bias="none") + # if we get here, it means there was no AssertionError, i.e. there are warnings -- let's check that they + # are not related to the bias setting + if any(warning.message.args[0].startswith(msg_start) for warning in cm.warnings): + bias_warning_was_given = True + except AssertionError: + # This is good, there was an AssertionError, i.e. there was no warning + pass + if bias_warning_was_given: + # This is bad, there was a warning about the bias when there should not have been any. + self.fail("There should be no warning when bias is set to 'none'")