Add adapter error handling (#800)
When a user tries to add a second adapter, Lora and AdaLora perform some checks to
ensure that the new adapter is compatible with the existing adapters. Currently, these
checks are performed partway through the method. This means that if a check fails,
the new adapter is already partially applied, leaving the model in a bad state.
The main purpose of this PR is to ensure that the model state is correct after
such a failure is encountered.

Tests were added to catch this potential bug.
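
A minimal sketch of the intended behavior (the model id, ranks, and adapter names below
are illustrative, not part of this commit): a second LoRA adapter with bias="all" should
be rejected, and the rejected adapter should not remain registered anywhere on the model.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("gpt2")  # illustrative base model
model = get_peft_model(base, LoraConfig(r=8, bias="all"), adapter_name="adapter0")

try:
    # Adding a second adapter while bias != "none" is not supported, so this raises.
    model.add_adapter("adapter1", LoraConfig(r=16, bias="all"))
except ValueError:
    pass

# The failed adapter must not be left half-applied in either config dict.
assert "adapter1" not in model.peft_config
assert "adapter1" not in model.base_model.peft_config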

While working on this, I also made some related, but not strictly necessary,
changes to the add_adapter methods:

- Previously, the peft_config from the PeftModel was passed to the base
  model. This meant that the base model would sometimes hold a reference
  to PeftModel.peft_config, but not always, since some base models would
  create new dicts. This was problematic because some code relied on the
  objects being the same. Now, they are never the same, which is more
  consistent (see the short sketch after this list).
- I think that the check whether multiple adapters have biases (which is not
  supported) was accidentally removed by #749. It is added back in.
- Add some type annotations
- Extend docstrings to contain adapter_name
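
As a hypothetical illustration of the peft_config point above (the model id and config
are examples, not taken from this commit), after this change the two config dicts are
distinct objects that track the same adapters:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model = get_peft_model(AutoModelForCausalLM.from_pretrained("gpt2"), LoraConfig(r=8))

# The PeftModel and the underlying tuner model each keep their own dict now,
# so adding or removing an entry in one does not silently change the other.
assert model.peft_config is not model.base_model.peft_config
assert set(model.peft_config) == set(model.base_model.peft_config) == {"default"}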
BenjaminBossan authored Aug 8, 2023
1 parent ed396a6 commit aac7722
Showing 9 changed files with 114 additions and 20 deletions.
30 changes: 18 additions & 12 deletions src/peft/peft_model.py
@@ -81,7 +81,7 @@ class PeftModel(PushToHubMixin, torch.nn.Module):
Args:
model ([`~transformers.PreTrainedModel`]): The base transformer model used for Peft.
peft_config ([`PeftConfig`]): The configuration of the Peft model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
**Attributes**:
- **base_model** ([`~transformers.PreTrainedModel`]) -- The base transformer model used for Peft.
@@ -478,19 +478,25 @@ def add_adapter(self, adapter_name: str, peft_config: PeftConfig):
f"Cannot combine adapters with different peft types. "
f"Found {self.peft_type} and {peft_config.peft_type}."
)

self.peft_config[adapter_name] = peft_config
if peft_config.is_prompt_learning:
if hasattr(self.config, "to_dict"):
dict_config = self.config.to_dict()
else:
dict_config = self.config

peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
self._setup_prompt_encoder(adapter_name)
elif peft_config.is_adaption_prompt:
self.base_model.add_adapter(adapter_name, peft_config)
else:
self.inject_adapter(self, adapter_name)
try:
if peft_config.is_prompt_learning:
if hasattr(self.config, "to_dict"):
dict_config = self.config.to_dict()
else:
dict_config = self.config

peft_config = _prepare_prompt_learning_config(peft_config, dict_config)
self._setup_prompt_encoder(adapter_name)
elif peft_config.is_adaption_prompt:
self.base_model.add_adapter(adapter_name, peft_config)
else:
self.base_model.inject_adapter(self, adapter_name)
except Exception: # something went wrong, roll back
del self.peft_config[adapter_name]
raise

self.set_additional_trainable_modules(peft_config, adapter_name)

26 changes: 22 additions & 4 deletions src/peft/tuners/adalora.py
@@ -111,6 +111,7 @@ class AdaLoraModel(LoraModel):
Args:
model ([`transformers.PreTrainedModel`]): The model to be adapted.
config ([`AdaLoraConfig`]): The configuration of the AdaLora model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The AdaLora model.
@@ -132,10 +133,6 @@ class AdaLoraModel(LoraModel):
def __init__(self, model, config, adapter_name):
super().__init__(model, config, adapter_name)

if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
raise ValueError(
"AdaLoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
)
traininable_mode_counter = 0
for config in self.peft_config.values():
if not config.inference_mode:
@@ -153,6 +150,27 @@ def __init__(self, model, config, adapter_name):
self.trainable_adapter_name = adapter_name
self.rankallocator = RankAllocator(self.model, self.peft_config[adapter_name], self.trainable_adapter_name)

def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
super()._check_new_adapter_config(config)

traininable_mode_counter = 0
for config_ in self.peft_config.values():
if not config_.inference_mode:
traininable_mode_counter += 1

if traininable_mode_counter > 1:
raise ValueError(
f"{self.__class__.__name__} supports only 1 trainable adapter. "
"When using multiple adapters, set inference_mode to True for all adapters except the one "
"you want to train."
)

def _create_and_replace(
self,
lora_config,
1 change: 1 addition & 0 deletions src/peft/tuners/ia3.py
@@ -133,6 +133,7 @@ class IA3Model(BaseTuner):
Args:
model ([`~transformers.PreTrainedModel`]): The model to be adapted.
config ([`IA3Config`]): The configuration of the (IA)^3 model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The (IA)^3 model.
18 changes: 17 additions & 1 deletion src/peft/tuners/lora.py
@@ -219,6 +219,7 @@ class LoraModel(BaseTuner):
Args:
model ([`~transformers.PreTrainedModel`]): The model to be adapted.
config ([`LoraConfig`]): The configuration of the Lora model.
adapter_name (`str`): The name of the adapter, defaults to `"default"`.
Returns:
`torch.nn.Module`: The Lora model.
@@ -268,9 +269,24 @@ class LoraModel(BaseTuner):
- **peft_config** ([`LoraConfig`]): The configuration of the Lora model.
"""

def __init__(self, model, config, adapter_name):
def __init__(self, model, config, adapter_name) -> None:
super().__init__(model, config, adapter_name)

def _check_new_adapter_config(self, config: LoraConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
# TODO: there should be a check if any of the existing adapters actually has bias != "none", or else the check
# does not fully correspond to the error message.
if (len(self.peft_config) > 1) and (config.bias != "none"):
raise ValueError(
f"{self.__class__.__name__} supports only 1 adapter with bias. When using multiple adapters, "
"set bias to 'none' for all adapters."
)

@staticmethod
def _check_target_module_exists(lora_config, key):
if isinstance(lora_config.target_modules, str):
25 changes: 22 additions & 3 deletions src/peft/tuners/tuners_utils.py
@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Union

from torch import nn

@@ -59,7 +61,7 @@ class BaseTuner(nn.Module, ABC):
The model configuration object, it should be a dictionary of `str` to `Any` objects.
"""

def __init__(self, model, peft_config, adapter_name):
def __init__(self, model, peft_config: Union[PeftConfig, dict[str, PeftConfig]], adapter_name: str) -> None:
super().__init__()

self.model = model
@@ -74,7 +76,11 @@ def __init__(self, model, peft_config, adapter_name):
"Already found a `peft_config` attribute in the model. This will lead to having multiple adapters"
" in the model. Make sure to know what you are doing!"
)
self.peft_config[adapter_name] = peft_config
if isinstance(peft_config, PeftConfig):
self.peft_config[adapter_name] = peft_config
else:
# user is adding a dict of PeftConfigs
self.peft_config.update(peft_config)

# transformers models have a .config attribute, whose presence is assumed later on
if not hasattr(self, "config"):
@@ -159,6 +165,15 @@ def _mark_only_adapters_as_trainable(self):
"""
...

def _check_new_adapter_config(self, config: PeftConfig) -> None:
"""
A helper method to check the config when a new adapter is being added.
Raise a ValueError if there is something wrong with the config or if it conflicts with existing adapters.
"""
pass

def inject_adapter(self, model: nn.Module, adapter_name: str):
r"""
Creates adapter layers and replaces the target modules with the adapter layers. This method is called under the
@@ -173,6 +188,10 @@ def inject_adapter(self, model: nn.Module, adapter_name: str):
The adapter name.
"""
peft_config = self.peft_config[adapter_name]
# Note: If possible, all checks should be performed *at the start of this method*.
# This way, we can raise early if something goes wrong, without leaving the model
# in a bad (half-initialized) state.
self._check_new_adapter_config(peft_config)

is_target_modules_in_base_model = False
key_list = [key for key, _ in model.named_modules()]
4 changes: 4 additions & 0 deletions tests/test_custom_models.py
@@ -324,3 +324,7 @@ def run_with_disable(config_kwargs, bias):
if bias_warning_was_given:
# This is bad, there was a warning about the bias when there should not have been any.
self.fail("There should be no warning when bias is set to 'none'")

@parameterized.expand(TEST_CASES)
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)
4 changes: 4 additions & 0 deletions tests/test_decoder_models.py
@@ -144,6 +144,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
4 changes: 4 additions & 0 deletions tests/test_encoder_decoder_models.py
@@ -125,6 +125,10 @@ def test_peft_model_device_map(self, test_name, model_id, config_cls, config_kwa
def test_delete_adapter(self, test_name, model_id, config_cls, config_kwargs):
self._test_delete_adapter(model_id, config_cls, config_kwargs)

@parameterized.expand(PeftTestConfigManager.get_grid_parameters(FULL_GRID))
def test_adding_multiple_adapters_with_bias_raises(self, test_name, model_id, config_cls, config_kwargs):
self._test_adding_multiple_adapters_with_bias_raises(model_id, config_cls, config_kwargs)

@parameterized.expand(
PeftTestConfigManager.get_grid_parameters(
{
22 changes: 22 additions & 0 deletions tests/testing_common.py
@@ -854,6 +854,28 @@ def get_output(model):

# TODO: add tests to check if disabling adapters works after calling merge_adapter

def _test_adding_multiple_adapters_with_bias_raises(self, model_id, config_cls, config_kwargs):
# When trying to add multiple adapters with bias in Lora or AdaLora, an error should be
# raised. Also, the peft model should not be left in a half-initialized state.
if not issubclass(config_cls, (LoraConfig, AdaLoraConfig)):
return

config_kwargs = config_kwargs.copy()
config_kwargs["bias"] = "all"
config = config_cls(
base_model_name_or_path=model_id,
**config_kwargs,
)

model = self.transformers_class.from_pretrained(model_id)
model = get_peft_model(model, config, "adapter0")
with self.assertRaises(ValueError):
model.add_adapter("adapter1", replace(config, r=20))

# (superficial) test that the model is not left in a half-initialized state when adding an adapter fails
self.assertFalse("adapter1" in model.peft_config)
self.assertFalse("adapter1" in model.base_model.peft_config)

def _test_passing_input_embeds_works(self, test_name, model_id, config_cls, config_kwargs):
# https://github.com/huggingface/peft/issues/727
model = self.transformers_class.from_pretrained(model_id)
