Add delete_adapter(), delete_adapter_fusion() and delete_head() methods #189

Merged · 4 commits · Jun 29, 2021
Changes from 3 commits
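
For context, here is a minimal usage sketch of the three deletion methods this PR introduces, next to their existing `add_*` counterparts. The model class, checkpoint, and adapter/head names are illustrative assumptions, not taken from this PR:

```python
from transformers import AutoModelWithHeads

# Illustrative setup (checkpoint and names are placeholders).
model = AutoModelWithHeads.from_pretrained("bert-base-uncased")
model.add_adapter("sst-2")
model.add_adapter("mnli")
model.add_fusion(["sst-2", "mnli"])
model.add_classification_head("sst-2", num_labels=2)

# New in this PR: remove the fusion layer, the adapters and the head again.
model.delete_adapter_fusion(["sst-2", "mnli"])
model.delete_adapter("sst-2")
model.delete_adapter("mnli")
model.delete_head("sst-2")
```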
6 changes: 6 additions & 0 deletions adapter_docs/prediction_heads.md
@@ -55,6 +55,12 @@ model.save_head("/path/to/dir", "mrpc")
model.load_head("/path/to/dir")
```

Lastly, it is also possible to delete a previously added head:

```python
model.delete_head("mrpc")
```

## HuggingFace heads

The `transformers` library provides strongly typed model classes with heads for various different tasks (e.g. `RobertaForSequenceClassification`, `AutoModelForMultipleChoice` ...).
13 changes: 11 additions & 2 deletions adapter_docs/quickstart.md
@@ -3,8 +3,8 @@
## Introduction

Currently, *adapter-transformers* adds adapter components to the PyTorch implementations of all transformer models listed in the *Supported Models* section.
For working with adapters, a couple of methods for creation (e.g. `add_adapter()`), loading (e.g. `load_adapter()`) and
storing (e.g. `save_adapter()`) are added to the model classes. In the following, we will briefly go through some examples.
For working with adapters, a couple of methods for creation (`add_adapter()`), loading (`load_adapter()`),
storing (`save_adapter()`) and deletion (`delete_adapter()`) are added to the model classes. In the following, we will briefly go through some examples.

```eval_rst
.. note::
@@ -80,6 +80,15 @@ model.load_adapter('./path/to/adapter/directory/')

Similar to how the weights of the full model are saved, the `save_adapter()` will create a file for saving the adapter weights and a file for saving the adapter configuration in the specified directory.

Finally, once we have finished working with adapters, we can restore the base Transformer to its original form by deactivating and deleting the added adapter:

```python
# deactivate all adapters
model.set_active_adapters(None)
# delete the added adapter
model.delete_adapter('sst-2')
```

## Quick Tour: Adapter training

_We also have a Quickstart Colab notebook for adapter training:_ [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Adapter-Hub/adapter-transformers/blob/master/notebooks/01_Adapter_Training.ipynb)
4 changes: 2 additions & 2 deletions src/transformers/adapters/configuration.py
@@ -2,7 +2,7 @@
import logging
from collections.abc import Collection, Mapping
from dataclasses import FrozenInstanceError, asdict, dataclass, field, is_dataclass, replace
from typing import List, Optional, Sequence, Union
from typing import List, Optional, Union

from .composition import AdapterCompositionBlock
from .utils import get_adapter_config_hash, resolve_adapter_config
@@ -195,7 +195,7 @@ def __init__(self, **kwargs):
adapters_list = dict(
map(lambda t: (t[0], t[1][1] or t[1][0] if isinstance(t[1], tuple) else t[1]), adapters_list.items())
)
self.adapters: Sequence[str] = adapters_list
self.adapters: Mapping[str, str] = adapters_list
self.config_map = kwargs.pop("config_map", {})
# TODO-V2 Save this with config?
self.active_setup: Optional[AdapterCompositionBlock] = None
15 changes: 15 additions & 0 deletions src/transformers/adapters/heads.py
@@ -520,6 +520,21 @@ def add_prediction_head(
f"Model already contains a head with name '{head.name}'. Use overwrite_ok=True to force overwrite."
)

def delete_head(self, head_name: str):
"""
Deletes the prediction head with the specified name from the model.

Args:
head_name (str): The name of the prediction head to delete.
"""
if head_name not in self.config.prediction_heads:
logger.info("No prediction head '%s' found for deletion. Skipping.", head_name)
return
del self.config.prediction_heads[head_name]
del self.heads[head_name]
if self.active_head == head_name:
self.active_head = None

Member (reviewer): When deleting a head that is the active head, the active head is set to `None`, but when deleting an adapter that is part of the active adapter setup, the setup remains as it is. Maybe it would be better to handle these two cases similarly.

Member Author: Makes sense, thanks!

def forward_head(
self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
):
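Regarding the review note above: as of this commit, `delete_adapter()` leaves the active adapter setup untouched, so the caller resets it explicitly, mirroring the quickstart docs. A minimal sketch (the adapter name is a placeholder):

```python
# delete_adapter() does not touch the active setup in this commit,
# so deactivate the adapter before deleting it.
model.set_active_adapters(None)  # drop the active setup referencing the adapter
model.delete_adapter("sst-2")    # then remove its weights and config entry
```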
9 changes: 9 additions & 0 deletions src/transformers/adapters/layer.py
@@ -67,6 +67,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
adapter.train(self.training) # make sure training mode is consistent
self.adapters[adapter_name] = adapter

def delete_adapter(self, adapter_name: str):
if adapter_name in self.adapters:
del self.adapters[adapter_name]

def add_fusion_layer(self, adapter_names: Union[List, str]):
"""See BertModel.add_fusion_layer"""
adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",")
@@ -75,6 +79,11 @@ def add_fusion_layer(self, adapter_names: Union[List, str]):
fusion.train(self.training) # make sure training mode is consistent
self.adapter_fusion_layer[",".join(adapter_names)] = fusion

def delete_fusion_layer(self, adapter_names: Union[List, str]):
adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names)
if adapter_names in self.adapter_fusion_layer:
del self.adapter_fusion_layer[adapter_names]

def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_fusion: bool):
"""
Unfreezes a given list of adapters, the adapter fusion layer, or both
51 changes: 48 additions & 3 deletions src/transformers/adapters/model_mixin.py
@@ -58,6 +58,10 @@ def add_invertible_adapter(self, adapter_name: str):
self.invertible_adapters[adapter_name] = inv_adap
self.invertible_adapters[adapter_name].apply(Adapter.init_bert_weights)

def delete_invertible_adapter(self, adapter_name: str):
if adapter_name in self.invertible_adapters:
del self.invertible_adapters[adapter_name]

def get_invertible_adapter(self):
# TODO: Currently no fusion over invertible adapters, takes only very first language adapter position
if self.config.adapters.active_setup is not None and len(self.config.adapters.active_setup) > 0:
@@ -206,7 +210,7 @@ def set_adapter_fusion_config(self, adapter_fusion_config, override_kwargs=None)
else:
raise ValueError("Invalid adapter type {}".format(adapter_fusion_config))

def add_adapter(self, adapter_name: str, config=None):
def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False):
"""
Adds a new adapter module of the specified type to the model.

@@ -217,9 +221,13 @@ def add_adapter(self, adapter_name: str, config=None):
- the string identifier of a pre-defined configuration dictionary
- a configuration dictionary specifying the full config
- if not given, the default configuration for this adapter type will be used
overwrite_ok (bool, optional): Overwrite an adapter with the same name if it exists. By default (False), an exception is thrown.
"""
if isinstance(config, dict):
config = AdapterConfig.from_dict(config) # ensure config is ok and up-to-date
# In case adapter already exists and we allow overwriting, explicitly delete the existing one first
if overwrite_ok and adapter_name in self.config.adapters:
self.delete_adapter(adapter_name)
self.config.adapters.add(adapter_name, config=config)
self.base_model._add_adapter(adapter_name)

@@ -260,6 +268,42 @@ def add_fusion(self, adapter_names: Union[Fuse, list], adapter_fusion_config=Non
self.config.adapter_fusion_models.append(adapter_fusion_name)
self.base_model._add_fusion_layer(adapter_names)

def delete_adapter(self, adapter_name: str):
"""
Deletes the adapter with the specified name from the model.

Args:
adapter_name (str): The name of the adapter.
"""
if adapter_name not in self.config.adapters:
logger.info("No adapter '%s' found for deletion. Skipping.", adapter_name)
return
del self.config.adapters.adapters[adapter_name]
self.base_model._delete_adapter(adapter_name)

def delete_adapter_fusion(self, adapter_names: Union[Fuse, list]):
"""
Deletes the AdapterFusion layer of the specified adapters.

Args:
adapter_names (Union[Fuse, list]): List of adapters for which to delete the AdapterFusion layer.
"""
if isinstance(adapter_names, Fuse):
adapter_fusion_name = ",".join(adapter_names.children)
elif isinstance(adapter_names, list):
adapter_fusion_name = ",".join(adapter_names)
else:
adapter_fusion_name = adapter_names

if (
not hasattr(self.config, "adapter_fusion_models")
or adapter_fusion_name not in self.config.adapter_fusion_models
):
logger.info("No AdapterFusion '%s' found for deletion. Skipping.", adapter_fusion_name)
return
self.config.adapter_fusion_models.remove(adapter_fusion_name)
self.base_model._delete_fusion_layer(adapter_fusion_name)

def save_adapter(
self,
save_directory: str,
@@ -477,7 +521,7 @@ def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self._convert_to_flex_head = False

def add_adapter(self, adapter_name: str, config=None):
def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False):
"""
Adds a new adapter module of the specified type to the model.

@@ -488,8 +532,9 @@ def add_adapter(self, adapter_name: str, config=None):
- the string identifier of a pre-defined configuration dictionary
- a configuration dictionary specifying the full config
- if not given, the default configuration for this adapter type will be used
overwrite_ok (bool, optional): Overwrite an adapter with the same name if it exists. By default (False), an exception is thrown.
"""
self.base_model.add_adapter(adapter_name, config)
self.base_model.add_adapter(adapter_name, config, overwrite_ok=overwrite_ok)

def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock]):
"""Sets the model into mode for training the given adapters."""
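For reference, a short sketch of the new `overwrite_ok` flag and `delete_adapter_fusion()` from the caller's side. The `Fuse` import path and the adapter names are assumptions based on the v2 API, not part of this diff:

```python
from transformers.adapters.composition import Fuse  # assumed import path for the Fuse block

# Re-adding an adapter under an existing name now works with overwrite_ok=True:
# the old "sst-2" adapter is deleted first, then a fresh one is added.
model.add_adapter("sst-2")
model.add_adapter("sst-2", overwrite_ok=True)

# An AdapterFusion layer can be deleted via a Fuse block or a plain list of names.
model.add_adapter("mnli")
model.add_fusion(["sst-2", "mnli"])
model.delete_adapter_fusion(Fuse("sst-2", "mnli"))  # or: model.delete_adapter_fusion(["sst-2", "mnli"])
```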
35 changes: 35 additions & 0 deletions src/transformers/adapters/models/bart.py
@@ -91,6 +91,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
self.attention_adapters.add_adapter(adapter_name, layer_idx)
self.output_adapters.add_adapter(adapter_name, layer_idx)

def delete_adapter(self, adapter_name):
self.attention_adapters.delete_adapter(adapter_name)
self.output_adapters.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
self.attention_adapters.delete_fusion_layer(adapter_names)
self.output_adapters.delete_fusion_layer(adapter_names)

def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool):
self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
@@ -112,6 +120,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
super().add_adapter(adapter_name, layer_idx)
self.cross_attention_adapters.add_adapter(adapter_name, layer_idx)

def delete_adapter(self, adapter_name):
super().delete_adapter(adapter_name)
self.cross_attention_adapters.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
super().delete_fusion_layer(adapter_names)
self.cross_attention_adapters.delete_fusion_layer(adapter_names)

def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool):
super().enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
self.cross_attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
@@ -131,6 +147,14 @@ def add_adapter(self, adapter_name: str, layer_idx_offset: int = 0):
if i not in leave_out:
layer.add_adapter(adapter_name, i)

def delete_adapter(self, adapter_name: str):
for layer in self.layers:
layer.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
for layer in self.layers:
layer.delete_fusion_layer(adapter_names)

def enable_adapters(
self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool
):
@@ -197,6 +221,17 @@ def _add_fusion_layer(self, adapter_names):
self.encoder.add_fusion_layer(adapter_names)
self.decoder.add_fusion_layer(adapter_names)

def _delete_adapter(self, adapter_name: str):
if hasattr(self, "encoder"):
self.encoder.delete_adapter(adapter_name)
self.encoder.delete_invertible_adapter(adapter_name)
self.decoder.delete_adapter(adapter_name)

def _delete_fusion_layer(self, adapter_names):
if hasattr(self, "encoder"):
self.encoder.delete_fusion_layer(adapter_names)
self.decoder.delete_fusion_layer(adapter_names)

def get_fusion_regularization_loss(self):
reg_loss = 0.0
target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)
23 changes: 23 additions & 0 deletions src/transformers/adapters/models/bert.py
@@ -46,6 +46,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
self.attention.output.add_adapter(adapter_name, layer_idx)
self.output.add_adapter(adapter_name, layer_idx)

def delete_adapter(self, adapter_name):
self.attention.output.delete_adapter(adapter_name)
self.output.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
self.attention.output.delete_fusion_layer(adapter_names)
self.output.delete_fusion_layer(adapter_names)

def enable_adapters(
self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool
):
@@ -67,6 +75,14 @@ def add_adapter(self, adapter_name: str):
if i not in leave_out:
layer.add_adapter(adapter_name, i)

def delete_adapter(self, adapter_name: str):
for layer in self.layer:
layer.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
for layer in self.layer:
layer.delete_fusion_layer(adapter_names)

def enable_adapters(
self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool
):
@@ -114,6 +130,13 @@ def _add_adapter(self, adapter_name):
def _add_fusion_layer(self, adapter_names):
self.encoder.add_fusion_layer(adapter_names)

def _delete_adapter(self, adapter_name: str):
self.encoder.delete_adapter(adapter_name)
self.delete_invertible_adapter(adapter_name)

def _delete_fusion_layer(self, adapter_names):
self.encoder.delete_fusion_layer(adapter_names)

def get_fusion_regularization_loss(self):
reg_loss = 0.0

15 changes: 15 additions & 0 deletions src/transformers/adapters/models/distilbert.py
@@ -53,6 +53,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
self.attention_adapters.add_adapter(adapter_name, layer_idx)
self.output_adapters.add_adapter(adapter_name, layer_idx)

def delete_adapter(self, adapter_name):
self.attention_adapters.delete_adapter(adapter_name)
self.output_adapters.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
self.attention_adapters.delete_fusion_layer(adapter_names)
self.output_adapters.delete_fusion_layer(adapter_names)

def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool):
self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
@@ -96,6 +104,13 @@ def _add_adapter(self, adapter_name):
def _add_fusion_layer(self, adapter_names):
self.transformer.add_fusion_layer(adapter_names)

def _delete_adapter(self, adapter_name: str):
self.transformer.delete_adapter(adapter_name)
self.delete_invertible_adapter(adapter_name)

def _delete_fusion_layer(self, adapter_names):
self.transformer.delete_fusion_layer(adapter_names)

def get_fusion_regularization_loss(self):
reg_loss = 0.0
target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)
32 changes: 17 additions & 15 deletions src/transformers/adapters/models/gpt2.py
@@ -59,6 +59,14 @@ def add_adapter(self, adapter_name: str, layer_idx: int):
self.attention_adapters.add_adapter(adapter_name, layer_idx)
self.output_adapters.add_adapter(adapter_name, layer_idx)

def delete_adapter(self, adapter_name):
self.attention_adapters.delete_adapter(adapter_name)
self.output_adapters.delete_adapter(adapter_name)

def delete_fusion_layer(self, adapter_names):
self.attention_adapters.delete_fusion_layer(adapter_names)
self.output_adapters.delete_fusion_layer(adapter_names)

def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool):
self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention)
@@ -79,21 +87,6 @@ def _init_adapter_modules(self):
for fusion_adapter_names in self.config.fusion_models:
self.add_fusion_layer(fusion_adapter_names)

def add_adapter(self, adapter_name: str, config=None):
"""
Adds a new adapter module of the specified type to the model.

Args:
adapter_name (str): The name of the adapter module to be added.
config (str or dict or AdapterConfig, optional): The adapter configuration, can be either:

- the string identifier of a pre-defined configuration dictionary
- a configuration dictionary specifying the full config
- if not given, the default configuration for this adapter type will be used
"""
self.config.adapters.add(adapter_name, config=config)
self._add_adapter(adapter_name)

def _add_adapter(self, adapter_name: str):
adapter_config = self.config.adapters.get(adapter_name)
leave_out = adapter_config.get("leave_out", [])
@@ -137,6 +130,15 @@ def _add_fusion_layer(self, adapter_names):
for layer in self.base_model.h:
layer.add_fusion_layer(adapter_names)

def _delete_adapter(self, adapter_name: str):
for layer in self.base_model.h:
layer.delete_adapter(adapter_name)
self.delete_invertible_adapter(adapter_name)

def _delete_fusion_layer(self, adapter_names):
for layer in self.base_model.h:
layer.delete_fusion_layer(adapter_names)

def get_fusion_regularization_loss(self):
reg_loss = 0.0
target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)