Commit: Unify & simplify model adapter integration (#263)
* `AdapterLayer` module

* Simplify adapter model mixin implementations

* Remove unused imports

* Fix AdapterSetup context test
calpt committed Feb 2, 2022
1 parent 972869e commit 7574056
Showing 23 changed files with 339 additions and 1,126 deletions.
4 changes: 2 additions & 2 deletions adapter_docs/classes/adapter_layer.rst
@@ -1,5 +1,5 @@
AdapterLayerBaseMixin
AdapterLayer
=======================

.. autoclass:: transformers.AdapterLayerBaseMixin
.. autoclass:: transformers.AdapterLayer
:members:
4 changes: 2 additions & 2 deletions src/transformers/__init__.py
@@ -1385,7 +1385,7 @@
]
_import_structure["adapters.context"] = ["AdapterSetup"]
_import_structure["adapters.heads"] = ["ModelWithFlexibleHeadsAdaptersMixin"]
_import_structure["adapters.layer"] = ["AdapterLayerBaseMixin"]
_import_structure["adapters.layer"] = ["AdapterLayer"]
_import_structure["adapters.loading"] = [
"AdapterFusionLoader",
"AdapterLoader",
@@ -3172,7 +3172,7 @@
)
from .adapters.context import AdapterSetup
from .adapters.heads import ModelWithFlexibleHeadsAdaptersMixin
from .adapters.layer import AdapterLayerBaseMixin
from .adapters.layer import AdapterLayer
from .adapters.loading import (
AdapterFusionLoader,
AdapterLoader,
16 changes: 16 additions & 0 deletions src/transformers/adapters/composition.py
@@ -179,3 +179,19 @@ def parse_heads_from_composition(adapter_composition, reference_heads: list = No
)
else:
return None


def adjust_tensors_for_parallel(hidden_states, *tensors):
"""
Replicates a given list of tensors based on the shape of the reference tensor (first argument).
"""
outputs = []
for tensor in tensors:
if tensor is not None and hidden_states.shape[0] != tensor.shape[0]:
repeats = [1] * len(tensor.shape)
repeats[0] = hidden_states.shape[0] // tensor.shape[0]
new_tensor = tensor.repeat(*repeats)
outputs.append(new_tensor)
else:
outputs.append(tensor)
return tuple(outputs)
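
The new `adjust_tensors_for_parallel` helper repeats each passed tensor along the batch dimension until it matches the reference `hidden_states`, which is what a `Parallel` composition needs after replicating the input. A minimal usage sketch (the shapes are made up, and the import path is assumed from the file shown above):

```python
import torch

from transformers.adapters.composition import adjust_tensors_for_parallel

# A Parallel block replicated the batch from 4 to 8 examples,
# but the attention mask still has the original batch size.
hidden_states = torch.randn(8, 128, 768)
attention_mask = torch.ones(4, 128)

(attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask)
print(attention_mask.shape)  # torch.Size([8, 128])
```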
126 changes: 77 additions & 49 deletions src/transformers/adapters/layer.py

Large diffs are not rendered by default.

119 changes: 81 additions & 38 deletions src/transformers/adapters/model_mixin.py
@@ -2,8 +2,9 @@
import os
import warnings
from abc import ABC, abstractmethod
from collections import defaultdict
from os.path import join
from typing import List, Optional, Union
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -12,6 +13,7 @@
from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash
from .context import AdapterSetup
from .hub_mixin import PushAdapterToHubMixin
from .layer import AdapterLayer
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock
from .utils import EMBEDDING_FILE, TOKENIZER_PATH, inherit_doc
@@ -130,19 +132,45 @@ def _init_adapter_modules(self):
"""
# Initialize adapters from config
for adapter_name in self.config.adapters:
self._add_adapter(adapter_name)
self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
# Initialize fusion from config
for fusion_name in self.config.adapters.fusions:
self._add_fusion_layer(fusion_name)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(fusion_name))

self.loaded_embeddings["default"] = self.get_input_embeddings()

# These methods have to be implemented by every deriving class:

@abstractmethod
def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
"""
Iterates over all layers of the model.
This abstract method has to be implemented by every deriving model class.
"""
pass

def apply_to_adapter_layers(self, fn):
"""
Applies a function to all adapter layers of the model.
"""
for i, layer in self.iter_layers():
for module in layer.modules():
if isinstance(module, AdapterLayer):
fn(i, module)
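
`iter_layers` is the only hook a deriving model still has to provide; `apply_to_adapter_layers` then walks each layer's submodules and applies the given function to every `AdapterLayer` it finds. A rough sketch of such an implementation, assuming a BERT-style model that exposes its blocks as `self.encoder.layer` (the class and attribute names here are illustrative, not part of this diff):

```python
from typing import Iterable, Tuple

from torch import nn

from transformers.adapters.model_mixin import ModelAdaptersMixin


class ExampleBertAdaptersMixin(ModelAdaptersMixin):
    """Illustrative mixin; assumes the model stores its blocks in `self.encoder.layer`."""

    def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
        # Yield (index, transformer block); apply_to_adapter_layers then searches
        # each block's submodules for AdapterLayer instances.
        for i, layer in enumerate(self.encoder.layer):
            yield i, layer
```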

def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False):
"""Sets the model into mode for training the given adapters."""
pass
self.train()
self.freeze_model(True)
adapter_setup = parse_composition(adapter_setup)
self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False))
if isinstance(self, InvertibleAdaptersMixin):
self.enable_invertible_adapters(adapter_setup.flatten())
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_setup)
if train_embeddings:
self.get_input_embeddings().train()
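
`train_adapter` thus replaces the per-model implementations: it freezes the base model, unfreezes the given adapters through `apply_to_adapter_layers`, and activates them for the forward pass. A short usage sketch, assuming the adapter-enabled `BertModel` from this repository and a placeholder adapter name:

```python
from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
model.add_adapter("example_task")    # register a new adapter
model.train_adapter("example_task")  # freeze the base model, unfreeze + activate the adapter

# Only adapter parameters should now require gradients.
trainable = [name for name, param in model.named_parameters() if param.requires_grad]
print(len(trainable))
```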

def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False):
"""Sets the model into mode for training of adapter fusion determined by a list of adapter names."""
@@ -152,18 +180,15 @@ def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfr
)
self.train_adapter_fusion(adapter_setup, unfreeze_adapters=unfreeze_adapters)

@abstractmethod
def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False):
"""Sets the model into mode for training of adapter fusion determined by a list of adapter names."""
pass

@abstractmethod
def _add_adapter(self, adapter_name):
pass

@abstractmethod
def _add_fusion_layer(self, adapter_names):
pass
self.train()
self.freeze_model(True)
adapter_setup = parse_composition(adapter_setup)
self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, unfreeze_adapters, True))
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_setup)
# TODO implement fusion for invertible adapters

def has_adapters(self):
return len(self.config.adapters.adapters) > 0
@@ -225,7 +250,9 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False
if overwrite_ok and adapter_name in self.config.adapters:
self.delete_adapter(adapter_name)
self.config.adapters.add(adapter_name, config=config)
self.base_model._add_adapter(adapter_name)
self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i))
if isinstance(self, InvertibleAdaptersMixin):
self.add_invertible_adapter(adapter_name)
if set_active:
self.set_active_adapters(adapter_name)

@@ -272,7 +299,7 @@ def add_adapter_fusion(
if overwrite_ok and self.config.adapters.get_fusion(adapter_names) is not None:
self.delete_adapter_fusion(adapter_names)
self.config.adapters.add_fusion(adapter_names, config=config)
self.base_model._add_fusion_layer(adapter_names)
self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names))
if set_active:
if not isinstance(adapter_names, list):
adapter_names = adapter_names.split(",")
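
AdapterFusion follows the same pattern: `add_adapter_fusion` registers the fusion layer through `apply_to_adapter_layers`, and `train_adapter_fusion` (no longer abstract, see above) freezes everything except the fusion weights and, optionally, the fused adapters. A sketch with placeholder adapter names, again assuming the adapter-enabled `BertModel`:

```python
from transformers import BertModel
from transformers.adapters.composition import Fuse

model = BertModel.from_pretrained("bert-base-uncased")
model.add_adapter("adapter_a")
model.add_adapter("adapter_b")
model.add_adapter_fusion(["adapter_a", "adapter_b"], set_active=True)
model.train_adapter_fusion(Fuse("adapter_a", "adapter_b"))
```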
@@ -289,7 +316,9 @@ def delete_adapter(self, adapter_name: str):
logger.info("No adapter '%s' found for deletion. Skipping.", adapter_name)
return
del self.config.adapters.adapters[adapter_name]
self.base_model._delete_adapter(adapter_name)
self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name))
if isinstance(self, InvertibleAdaptersMixin):
self.delete_invertible_adapter(adapter_name)
# Reset active adapters if this was the only active adapter
if self.active_adapters == Stack(adapter_name):
self.active_adapters = None
@@ -314,7 +343,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]):
logger.info("No AdapterFusion '%s' found for deletion. Skipping.", adapter_fusion_name)
return
del self.config.adapters.fusions[adapter_fusion_name]
self.base_model._delete_fusion_layer(adapter_fusion_name)
self.apply_to_adapter_layers(lambda i, layer: layer.delete_fusion_layer(adapter_fusion_name))
# Reset active adapters if this was the active setup
if self.active_adapters == adapter_names:
self.active_adapters = None
@@ -661,6 +690,31 @@ def set_active_embeddings(self, name):
def active_embeddings(self):
return self._active_embedding

def get_fusion_regularization_loss(self):
reg_loss = 0.0

target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device)
for i, layer in self.iter_layers():
for module in layer.modules():
if isinstance(module, AdapterLayer):
for _, layer_fusion in module.adapter_fusion_layer.items():
if hasattr(layer_fusion, "value"):
reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum()

return reg_loss
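
The regularization term pulls each fusion layer's value matrix toward the identity and is typically added to the task loss during AdapterFusion training. Continuing the fusion sketch above (`task_loss` is a placeholder for a loss produced by a prediction head):

```python
# `task_loss` is a placeholder tensor, e.g. the output of a prediction head.
reg_loss = model.base_model.get_fusion_regularization_loss()
loss = task_loss + reg_loss
loss.backward()
```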

def get_adapter(self, name):
destination = defaultdict(dict)

# use a custom index to ensure numbering is from 0 to N layers
for i, (_, layer) in enumerate(self.iter_layers()):
for module in layer.modules():
if isinstance(module, AdapterLayer):
if name in module.adapters:
destination[i][module.location_key] = module.adapters[name]

return dict(destination)
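
`get_adapter` returns a nested mapping from layer index to location key (e.g. the attention or output adapter slot) to the adapter module itself, which is handy for inspecting where an adapter's weights live. Continuing the sketch above with the placeholder name `adapter_a`:

```python
modules_per_layer = model.get_adapter("adapter_a")
for layer_idx, modules_by_location in modules_per_layer.items():
    for location_key, adapter_module in modules_by_location.items():
        n_params = sum(p.numel() for p in adapter_module.parameters())
        print(f"layer {layer_idx} / {location_key}: {n_params} parameters")
```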


@inherit_doc
class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin):
@@ -672,6 +726,15 @@ def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)
self._convert_to_flex_head = False

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
"""
Iterates over all layers of the model.
"""
if self.base_model is self:
return super().iter_layers()
else:
return self.base_model.iter_layers()

def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False, set_active: bool = False):
"""
Adds a new adapter module of the specified type to the model.
@@ -714,26 +777,6 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc
else:
self.base_model.train_adapter_fusion(adapter_setup, unfreeze_adapters=unfreeze_adapters)

def _add_adapter(self, adapter_name):
"""
If self.base_model is self, must inherit from a class that implements this method, to preclude infinite
recursion
"""
if self.base_model is self:
super()._add_adapter(adapter_name)
else:
self.base_model._add_adapter(adapter_name)

def _add_fusion_layer(self, adapter_names):
"""
If self.base_model is self, must inherit from a class that implements this method, to preclude infinite
recursion
"""
if self.base_model is self:
super()._add_fusion_layer(adapter_names)
else:
self.base_model._add_fusion_layer(adapter_names)

def save_head(self, save_directory: str, head_name: str = None):
loader = PredictionHeadLoader(self)
loader.save(save_directory, name=head_name)