From a9ea6161dfd5e21fb8bca72f7a06591608857f6c Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 28 Jan 2022 20:06:00 +0100 Subject: [PATCH] Unify & simplify model adapter integration (#263) * `AdapterLayer` module * Simplify adapter model mixin implementations * Remove unused imports * Fix AdapterSetup context test --- adapter_docs/classes/adapter_layer.rst | 4 +- src/transformers/__init__.py | 4 +- src/transformers/adapters/composition.py | 16 + src/transformers/adapters/layer.py | 126 +++++--- src/transformers/adapters/model_mixin.py | 119 +++++--- src/transformers/adapters/models/bart.py | 286 ++---------------- src/transformers/adapters/models/bert.py | 156 +--------- .../adapters/models/distilbert.py | 145 +-------- .../adapters/models/encoder_decoder.py | 51 +--- src/transformers/adapters/models/gpt2.py | 174 +---------- src/transformers/adapters/models/t5.py | 233 ++------------ src/transformers/models/bart/modeling_bart.py | 28 +- src/transformers/models/bert/modeling_bert.py | 13 +- .../models/distilbert/modeling_distilbert.py | 11 +- src/transformers/models/gpt2/modeling_gpt2.py | 9 +- .../models/mbart/modeling_mbart.py | 28 +- .../models/roberta/modeling_roberta.py | 13 +- src/transformers/models/t5/modeling_t5.py | 30 +- tests/test_adapter.py | 1 + tests/test_adapter_common.py | 8 +- tests/test_adapter_fusion_common.py | 2 + tests/test_adapter_heads.py | 2 +- tests/test_adapter_setup_context.py | 6 +- 23 files changed, 339 insertions(+), 1126 deletions(-) diff --git a/adapter_docs/classes/adapter_layer.rst b/adapter_docs/classes/adapter_layer.rst index bb954484d1..92e7d2252a 100644 --- a/adapter_docs/classes/adapter_layer.rst +++ b/adapter_docs/classes/adapter_layer.rst @@ -1,5 +1,5 @@ -AdapterLayerBaseMixin +AdapterLayer ======================= -.. autoclass:: transformers.AdapterLayerBaseMixin +.. autoclass:: transformers.AdapterLayer :members: diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 3bf0a33d7f..ee3a1dcf15 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1385,7 +1385,7 @@ ] _import_structure["adapters.context"] = ["AdapterSetup"] _import_structure["adapters.heads"] = ["ModelWithFlexibleHeadsAdaptersMixin"] - _import_structure["adapters.layer"] = ["AdapterLayerBaseMixin"] + _import_structure["adapters.layer"] = ["AdapterLayer"] _import_structure["adapters.loading"] = [ "AdapterFusionLoader", "AdapterLoader", @@ -3172,7 +3172,7 @@ ) from .adapters.context import AdapterSetup from .adapters.heads import ModelWithFlexibleHeadsAdaptersMixin - from .adapters.layer import AdapterLayerBaseMixin + from .adapters.layer import AdapterLayer from .adapters.loading import ( AdapterFusionLoader, AdapterLoader, diff --git a/src/transformers/adapters/composition.py b/src/transformers/adapters/composition.py index fdc3aa7c09..de37a51f83 100644 --- a/src/transformers/adapters/composition.py +++ b/src/transformers/adapters/composition.py @@ -179,3 +179,19 @@ def parse_heads_from_composition(adapter_composition, reference_heads: list = No ) else: return None + + +def adjust_tensors_for_parallel(hidden_states, *tensors): + """ + Replicates a given list of tensors based on the shape of the reference tensor (first argument). 
+ """ + outputs = [] + for tensor in tensors: + if tensor is not None and hidden_states.shape[0] != tensor.shape[0]: + repeats = [1] * len(tensor.shape) + repeats[0] = hidden_states.shape[0] // tensor.shape[0] + new_tensor = tensor.repeat(*repeats) + outputs.append(new_tensor) + else: + outputs.append(tensor) + return tuple(outputs) diff --git a/src/transformers/adapters/layer.py b/src/transformers/adapters/layer.py index 3552748bf1..e245c4a4f6 100644 --- a/src/transformers/adapters/layer.py +++ b/src/transformers/adapters/layer.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from typing import List, Mapping, Union import torch @@ -9,22 +8,11 @@ from .modeling import Adapter, BertFusion -class AdapterLayerBaseMixin(ABC): - """ - An abstract base implementation of adapter integration into a Transformer block. In BERT, subclasses of this module - are placed in the BertSelfOutput module and in the BertOutput module. - """ - - # override this property if layer norm has a different name - @property - def transformer_layer_norm(self): - return self.LayerNorm - - @property - @abstractmethod - def adapter_config_key(self): - """Gets the name of the key by which this adapter location is identified in the adapter configuration.""" - pass +class AdapterLayer(nn.Module): + def __init__(self, location_key: str, config): + super().__init__() + self.location_key = location_key + self.config = config @property def layer_idx(self): @@ -43,7 +31,12 @@ def _init_adapter_modules(self): def add_adapter(self, adapter_name: str, layer_idx: int): self.layer_idx = layer_idx adapter_config = self.config.adapters.get(adapter_name) - if adapter_config and adapter_config.get(self.adapter_config_key, None): + if adapter_config and adapter_config.get(self.location_key, None): + # Check whether to skip this layer. + leave_out = adapter_config.get("leave_out", []) + if self.layer_idx in leave_out: + return + reduction_factor = adapter_config["reduction_factor"] if isinstance(reduction_factor, Mapping): if str(self.layer_idx) in reduction_factor: @@ -75,7 +68,7 @@ def delete_adapter(self, adapter_name: str): def add_fusion_layer(self, adapter_names: Union[List, str]): """See BertModel.add_fusion_layer""" adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") - if self.config.adapters.common_config_value(adapter_names, self.adapter_config_key): + if self.config.adapters.common_config_value(adapter_names, self.location_key): fusion_config = self.config.adapters.get_fusion(adapter_names) fusion = BertFusion( fusion_config, @@ -115,11 +108,20 @@ def enable_adapters(self, adapter_setup: AdapterCompositionBlock, unfreeze_adapt for param in self.adapter_fusion_layer[sub_setup.name].parameters(): param.requires_grad = True + def adapter_state_dict(self, adapter_name: str, destination=None, prefix=""): + if adapter_name in self.adapters: + return self.adapters[adapter_name].state_dict( + destination=destination, prefix=prefix + f"{self.location_key}.adapters.{adapter_name}." 
+ ) + else: + return destination + def get_adapter_preparams( self, adapter_config, hidden_states, input_tensor, + layer_norm, fusion_config=None, ): """ @@ -142,8 +144,8 @@ def get_adapter_preparams( query = hidden_states if adapter_config["original_ln_before"]: - if self.transformer_layer_norm: - hidden_states = self.transformer_layer_norm(hidden_states + input_tensor) + if layer_norm: + hidden_states = layer_norm(hidden_states + input_tensor) else: hidden_states = hidden_states + input_tensor @@ -155,7 +157,7 @@ def get_adapter_preparams( return hidden_states, query, residual - def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, lvl=0): + def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, layer_norm, lvl=0): """ Forwards the given input through the given stack of adapters. """ @@ -169,23 +171,31 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, lvl=0 ) # Case 1: We have a nested fusion layer -> call fusion method if isinstance(adapter_stack_layer, Fuse): - hidden_states = self.adapter_fusion(adapter_stack_layer, hidden_states, input_tensor, lvl=lvl + 1) + hidden_states = self.adapter_fusion( + adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 + ) # Case 2: We have a nested split layer -> call split method elif isinstance(adapter_stack_layer, Split): - hidden_states = self.adapter_split(adapter_stack_layer, hidden_states, input_tensor, lvl=lvl + 1) + hidden_states = self.adapter_split( + adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 + ) # Case 3: We have a nested parallel layer -> call parallel method elif isinstance(adapter_stack_layer, Parallel): hidden_states, input_tensor = self.adapter_parallel( - adapter_stack_layer, hidden_states, input_tensor, lvl=lvl + 1 + adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 ) # Case 4: We have a nested batch split block -> call batchsplit method elif isinstance(adapter_stack_layer, BatchSplit): - hidden_states = self.adapter_batchsplit(adapter_stack_layer, hidden_states, input_tensor, lvl=lvl + 1) + hidden_states = self.adapter_batchsplit( + adapter_stack_layer, hidden_states, input_tensor, layer_norm, lvl=lvl + 1 + ) # Case 5: We have a single adapter which is part of this module -> forward pass elif adapter_stack_layer in self.adapters: adapter_layer = self.adapters[adapter_stack_layer] adapter_config = self.config.adapters.get(adapter_stack_layer) - hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + hidden_states, _, residual = self.get_adapter_preparams( + adapter_config, hidden_states, input_tensor, layer_norm + ) hidden_states, _, up = adapter_layer(hidden_states, residual_input=residual) # as this stack might be part of a fusion block, return the adapter up-projection output here # together with the final output (with potential residuals & norms) if we reached the last block of the stack @@ -197,7 +207,7 @@ def adapter_stack(self, adapter_setup: Stack, hidden_states, input_tensor, lvl=0 # or no adapter was found. In both cases, we don't need to set the second return value for fusion return hidden_states, None, input_tensor - def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0): + def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, layer_norm, lvl=0): """ Performs adapter fusion with the given adapters for the given input. 
""" @@ -205,7 +215,7 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0 adapter_config = self.config.adapters.get(adapter_setup.last()) fusion_config = self.config.adapters.get_fusion(adapter_setup.name) hidden_states, query, residual = self.get_adapter_preparams( - adapter_config, hidden_states, input_tensor, fusion_config=fusion_config + adapter_config, hidden_states, input_tensor, layer_norm, fusion_config=fusion_config ) up_list = [] @@ -213,7 +223,7 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0 for adapter_block in adapter_setup: # Case 1: We have a nested stack -> call stack method if isinstance(adapter_block, Stack): - _, up, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, lvl=lvl + 1) + _, up, _ = self.adapter_stack(adapter_block, hidden_states, input_tensor, layer_norm, lvl=lvl + 1) if up is not None: # could be none if stack is empty up_list.append(up) # Case 2: We have a single adapter which is part of this module -> forward pass @@ -243,13 +253,15 @@ def adapter_fusion(self, adapter_setup: Fuse, hidden_states, input_tensor, lvl=0 return hidden_states - def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, lvl=0): + def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, layer_norm, lvl=0): """ Splits the given input between the given adapters. """ # config of _first_ of splitted adapters is significant adapter_config = self.config.adapters.get(adapter_setup.first()) - hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + hidden_states, query, residual = self.get_adapter_preparams( + adapter_config, hidden_states, input_tensor, layer_norm + ) # split hidden representations and residuals at split index split_hidden_states = [ @@ -269,17 +281,17 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, lvl=0 # Case 1: We have a nested stack -> call stack method if isinstance(adapter_block, Stack): split_hidden_states[i], _, _ = self.adapter_stack( - adapter_block, split_hidden_states[i], split_input_tensor[i], lvl=lvl + 1 + adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 ) # Case 2: We have a nested split -> recursively call split elif isinstance(adapter_block, Split): split_hidden_states[i] = self.adapter_split( - adapter_block, split_hidden_states[i], split_input_tensor[i], lvl=lvl + 1 + adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 ) # Case 3: We have a nested batch split -> call batch split method elif isinstance(adapter_block, BatchSplit): split_hidden_states[i] = self.adapter_batchsplit( - adapter_block, split_hidden_states[i], split_input_tensor[i], lvl=lvl + 1 + adapter_block, split_hidden_states[i], split_input_tensor[i], layer_norm, lvl=lvl + 1 ) # Case 4: We have a single adapter which is part of this module -> forward pass elif adapter_block in self.adapters: @@ -297,7 +309,7 @@ def adapter_split(self, adapter_setup: Split, hidden_states, input_tensor, lvl=0 hidden_states = torch.cat(split_hidden_states, dim=1) return hidden_states - def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, lvl=0): + def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, layer_norm, lvl=0): """ For parallel execution of the adapters on the same input. This means that the input is repeated N times before feeding it to the adapters (where N is the number of adapters). 
@@ -319,7 +331,9 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, ) orig_batch_size = hidden_states.shape[0] // adapter_setup.parallel_channels - hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + hidden_states, _, residual = self.get_adapter_preparams( + adapter_config, hidden_states, input_tensor, layer_norm + ) # sequentially feed different parts of the blown-up batch into different adapters children_hidden = [] @@ -330,6 +344,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, child, hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], + layer_norm, lvl=lvl + 1, ) children_hidden.append(child_hidden_states) @@ -339,6 +354,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, child, hidden_states[i * orig_batch_size : (i + 1) * orig_batch_size], input_tensor[i * orig_batch_size : (i + 1) * orig_batch_size], + layer_norm, lvl=lvl + 1, ) children_hidden.append(child_hidden_states) @@ -365,7 +381,7 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, hidden_states = torch.cat(children_hidden, 0) return hidden_states, input_tensor - def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor, lvl=0): + def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_tensor, layer_norm, lvl=0): if not sum(adapter_setup.batch_sizes) == hidden_states.shape[0]: raise IndexError( "The given batch has a size of {} which is not compatible with batch_sizes {}".format( @@ -374,7 +390,9 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten ) adapter_config = self.config.adapters.get(adapter_setup.first()) - hidden_states, _, residual = self.get_adapter_preparams(adapter_config, hidden_states, input_tensor) + hidden_states, _, residual = self.get_adapter_preparams( + adapter_config, hidden_states, input_tensor, layer_norm + ) children_hidden = [] for i, adapter_block in enumerate(adapter_setup): # compute ids of sequences thet should be passed to the ith adapter @@ -388,6 +406,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten adapter_block, hidden_states[batch_idx[0] : batch_idx[1]], input_tensor[batch_idx[0] : batch_idx[1]], + layer_norm, lvl=lvl + 1, ) children_hidden.append(child) @@ -397,6 +416,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten adapter_block, hidden_states[batch_idx[0] : batch_idx[1]], input_tensor[batch_idx[0] : batch_idx[1]], + layer_norm, lvl=lvl + 1, ) children_hidden.append(child) @@ -406,6 +426,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten adapter_block, hidden_states[batch_idx[0] : batch_idx[1]], input_tensor[batch_idx[0] : batch_idx[1]], + layer_norm, lvl=lvl + 1, ) children_hidden.append(child) @@ -431,7 +452,7 @@ def adapter_batchsplit(self, adapter_setup: BatchSplit, hidden_states, input_ten hidden_states = torch.cat(children_hidden, 0) return hidden_states - def adapters_forward(self, hidden_states, input_tensor): + def adapter_layer_forward(self, hidden_states, input_tensor, layer_norm): """ Called for each forward pass through adapters. 
""" @@ -449,30 +470,37 @@ def adapters_forward(self, hidden_states, input_tensor): ) if not skip_adapters and (len(set(self.adapters.keys()) & adapter_setup.flatten()) > 0): if isinstance(adapter_setup, Stack): - hidden_states, _, input_tensor = self.adapter_stack(adapter_setup, hidden_states, input_tensor) + hidden_states, _, input_tensor = self.adapter_stack( + adapter_setup, hidden_states, input_tensor, layer_norm + ) elif isinstance(adapter_setup, Fuse): - hidden_states = self.adapter_fusion(adapter_setup, hidden_states, input_tensor) + hidden_states = self.adapter_fusion(adapter_setup, hidden_states, input_tensor, layer_norm) elif isinstance(adapter_setup, Split): - hidden_states = self.adapter_split(adapter_setup, hidden_states, input_tensor) + hidden_states = self.adapter_split(adapter_setup, hidden_states, input_tensor, layer_norm) elif isinstance(adapter_setup, Parallel): # notice that we are overriding input tensor here to keep the same dim as hidden_states for the residual # in case we were blowing up the batch for parallel processing of multiple adapters for the same input - hidden_states, input_tensor = self.adapter_parallel(adapter_setup, hidden_states, input_tensor) + hidden_states, input_tensor = self.adapter_parallel( + adapter_setup, hidden_states, input_tensor, layer_norm + ) elif isinstance(adapter_setup, BatchSplit): - hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, input_tensor) + hidden_states = self.adapter_batchsplit(adapter_setup, hidden_states, input_tensor, layer_norm) else: raise ValueError(f"Invalid adapter setup {adapter_setup}") last_config = self.config.adapters.get(adapter_setup.last()) if last_config["original_ln_after"]: - if self.transformer_layer_norm: - hidden_states = self.transformer_layer_norm(hidden_states + input_tensor) + if layer_norm: + hidden_states = layer_norm(hidden_states + input_tensor) else: hidden_states = hidden_states + input_tensor - elif self.transformer_layer_norm: - hidden_states = self.transformer_layer_norm(hidden_states + input_tensor) + elif layer_norm: + hidden_states = layer_norm(hidden_states + input_tensor) else: hidden_states = hidden_states + input_tensor return hidden_states + + def forward(self, hidden_states, input_tensor, layer_norm): + return self.adapter_layer_forward(hidden_states, input_tensor, layer_norm) diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 9ceea4446d..131f241505 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -2,8 +2,9 @@ import os import warnings from abc import ABC, abstractmethod +from collections import defaultdict from os.path import join -from typing import List, Optional, Union +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -12,6 +13,7 @@ from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash from .context import AdapterSetup from .hub_mixin import PushAdapterToHubMixin +from .layer import AdapterLayer from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock from .utils import EMBEDDING_FILE, TOKENIZER_PATH, inherit_doc @@ -130,19 +132,45 @@ def _init_adapter_modules(self): """ # Initialize adapters from config for adapter_name in self.config.adapters: - self._add_adapter(adapter_name) + self.apply_to_adapter_layers(lambda i, layer: 
layer.add_adapter(adapter_name, i)) # Initialize fusion from config for fusion_name in self.config.adapters.fusions: - self._add_fusion_layer(fusion_name) + self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(fusion_name)) self.loaded_embeddings["default"] = self.get_input_embeddings() # These methods have to be implemented by every deriving class: @abstractmethod + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + """ + Iterates over all layers of the model. + + This abstract method has to ne implemented by every implementing model. + """ + pass + + def apply_to_adapter_layers(self, fn): + """ + Applies a function to all adapter layers of the model. + """ + for i, layer in self.iter_layers(): + for module in layer.modules(): + if isinstance(module, AdapterLayer): + fn(i, module) + def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): """Sets the model into mode for training the given adapters.""" - pass + self.train() + self.freeze_model(True) + adapter_setup = parse_composition(adapter_setup) + self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, True, False)) + if isinstance(self, InvertibleAdaptersMixin): + self.enable_invertible_adapters(adapter_setup.flatten()) + # use the adapters to be trained by default in every forward pass + self.set_active_adapters(adapter_setup) + if train_embeddings: + self.get_input_embeddings().train() def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" @@ -152,18 +180,15 @@ def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfr ) self.train_adapter_fusion(adapter_setup, unfreeze_adapters=unfreeze_adapters) - @abstractmethod def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - pass - - @abstractmethod - def _add_adapter(self, adapter_name): - pass - - @abstractmethod - def _add_fusion_layer(self, adapter_names): - pass + self.train() + self.freeze_model(True) + adapter_setup = parse_composition(adapter_setup) + self.apply_to_adapter_layers(lambda i, layer: layer.enable_adapters(adapter_setup, unfreeze_adapters, True)) + # use the adapters to be trained by default in every forward pass + self.set_active_adapters(adapter_setup) + # TODO implement fusion for invertible adapters def has_adapters(self): return len(self.config.adapters.adapters) > 0 @@ -225,7 +250,9 @@ def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False if overwrite_ok and adapter_name in self.config.adapters: self.delete_adapter(adapter_name) self.config.adapters.add(adapter_name, config=config) - self.base_model._add_adapter(adapter_name) + self.apply_to_adapter_layers(lambda i, layer: layer.add_adapter(adapter_name, i)) + if isinstance(self, InvertibleAdaptersMixin): + self.add_invertible_adapter(adapter_name) if set_active: self.set_active_adapters(adapter_name) @@ -272,7 +299,7 @@ def add_adapter_fusion( if overwrite_ok and self.config.adapters.get_fusion(adapter_names) is not None: self.delete_adapter_fusion(adapter_names) self.config.adapters.add_fusion(adapter_names, config=config) - self.base_model._add_fusion_layer(adapter_names) + self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) if 
set_active: if not isinstance(adapter_names, list): adapter_names = adapter_names.split(",") @@ -289,7 +316,9 @@ def delete_adapter(self, adapter_name: str): logger.info("No adapter '%s' found for deletion. Skipping.", adapter_name) return del self.config.adapters.adapters[adapter_name] - self.base_model._delete_adapter(adapter_name) + self.apply_to_adapter_layers(lambda i, layer: layer.delete_adapter(adapter_name)) + if isinstance(self, InvertibleAdaptersMixin): + self.delete_invertible_adapter(adapter_name) # Reset active adapters if this was the only active adapter if self.active_adapters == Stack(adapter_name): self.active_adapters = None @@ -314,7 +343,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): logger.info("No AdapterFusion '%s' found for deletion. Skipping.", adapter_fusion_name) return del self.config.adapters.fusions[adapter_fusion_name] - self.base_model._delete_fusion_layer(adapter_fusion_name) + self.apply_to_adapter_layers(lambda i, layer: layer.delete_fusion_layer(adapter_fusion_name)) # Reset active adapters if this was the active setup if self.active_adapters == adapter_names: self.active_adapters = None @@ -661,6 +690,31 @@ def set_active_embeddings(self, name): def active_embeddings(self): return self._active_embedding + def get_fusion_regularization_loss(self): + reg_loss = 0.0 + + target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) + for i, layer in self.iter_layers(): + for module in layer.modules(): + if isinstance(module, AdapterLayer): + for _, layer_fusion in module.adapter_fusion_layer.items(): + if hasattr(layer_fusion, "value"): + reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() + + return reg_loss + + def get_adapter(self, name): + destination = defaultdict(dict) + + # use a custom index to ensure numbering is from 0 to N layers + for i, (_, layer) in enumerate(self.iter_layers()): + for module in layer.modules(): + if isinstance(module, AdapterLayer): + if name in module.adapters: + destination[i][module.location_key] = module.adapters[name] + + return dict(destination) + @inherit_doc class ModelWithHeadsAdaptersMixin(ModelAdaptersMixin): @@ -672,6 +726,15 @@ def __init__(self, config, *args, **kwargs): super().__init__(config, *args, **kwargs) self._convert_to_flex_head = False + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + """ + Iterates over all layers of the model. + """ + if self.base_model is self: + return super().iter_layers() + else: + return self.base_model.iter_layers() + def add_adapter(self, adapter_name: str, config=None, overwrite_ok: bool = False, set_active: bool = False): """ Adds a new adapter module of the specified type to the model. 
@@ -714,26 +777,6 @@ def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBloc else: self.base_model.train_adapter_fusion(adapter_setup, unfreeze_adapters=unfreeze_adapters) - def _add_adapter(self, adapter_name): - """ - If self.base_model is self, must inherit from a class that implements this method, to preclude infinite - recursion - """ - if self.base_model is self: - super()._add_adapter(adapter_name) - else: - self.base_model._add_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - """ - If self.base_model is self, must inherit from a class that implements this method, to preclude infinite - recursion - """ - if self.base_model is self: - super()._add_fusion_layer(adapter_names) - else: - self.base_model._add_fusion_layer(adapter_names) - def save_head(self, save_directory: str, head_name: str = None): loader = PredictionHeadLoader(self) loader.save(save_directory, name=head_name) diff --git a/src/transformers/adapters/models/bart.py b/src/transformers/adapters/models/bart.py index b6323456aa..590b573566 100644 --- a/src/transformers/adapters/models/bart.py +++ b/src/transformers/adapters/models/bart.py @@ -1,9 +1,7 @@ -from typing import Union +from typing import Iterable, Tuple -import torch -from torch import nn +import torch.nn as nn -from ..composition import AdapterCompositionBlock, parse_composition from ..heads import ( ClassificationHead, ModelWithFlexibleHeadsAdaptersMixin, @@ -11,106 +9,18 @@ QuestionAnsweringHead, Seq2SeqLMHead, ) -from ..layer import AdapterLayerBaseMixin -from ..model_mixin import ModelAdaptersMixin - - -class BartSelfAttentionAdaptersModule(AdapterLayerBaseMixin, nn.Module): - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def adapter_config_key(self): - return "mh_adapter" - - @property - def transformer_layer_norm(self): - # MBart has layer norms before each component - if self.config.model_type == "mbart": - return None - else: - return self.parent.self_attn_layer_norm - - -class BartCrossAttentionAdaptersModule(AdapterLayerBaseMixin, nn.Module): - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def adapter_config_key(self): - return "cross_adapter" - - @property - def transformer_layer_norm(self): - # MBart has layer norms before each component - if self.config.model_type == "mbart": - return None - else: - return self.parent.encoder_attn_layer_norm - - -class BartOutputAdaptersModule(AdapterLayerBaseMixin, nn.Module): - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def adapter_config_key(self): - return "output_adapter" - - @property - def transformer_layer_norm(self): - # MBart has layer norms before each component - if self.config.model_type == "mbart": - return None - else: - return self.parent.final_layer_norm +from ..layer import AdapterLayer +from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin class BartEncoderLayerAdaptersMixin: """Adds adapters to the BartEncoderLayer module of BART.""" def _init_adapter_modules(self): - self.attention_adapters = 
BartSelfAttentionAdaptersModule(self) - self.output_adapters = BartOutputAdaptersModule(self) + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.output_adapters = AdapterLayer("output_adapter", self.config) self.attention_adapters._init_adapter_modules() self.output_adapters._init_adapter_modules() - self.register_forward_pre_hook(self._adapter_block_pre_hook) - - def add_fusion_layer(self, adapter_names): - self.attention_adapters.add_fusion_layer(adapter_names) - self.output_adapters.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx: int): - self.attention_adapters.add_adapter(adapter_name, layer_idx) - self.output_adapters.add_adapter(adapter_name, layer_idx) - - def delete_adapter(self, adapter_name): - self.attention_adapters.delete_adapter(adapter_name) - self.output_adapters.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - self.attention_adapters.delete_fusion_layer(adapter_names) - self.output_adapters.delete_fusion_layer(adapter_names) - - def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool): - self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - - # Makes sure the "parent" reference always points to the correct module. - # This is especially relevant when using torch data parallelism. - @staticmethod - def _adapter_block_pre_hook(module, input_tensors): - object.__setattr__(module.attention_adapters, "parent", module) - object.__setattr__(module.output_adapters, "parent", module) class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin): @@ -118,191 +28,33 @@ class BartDecoderLayerAdaptersMixin(BartEncoderLayerAdaptersMixin): def _init_adapter_modules(self): super()._init_adapter_modules() - self.cross_attention_adapters = BartCrossAttentionAdaptersModule(self) + self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config) self.cross_attention_adapters._init_adapter_modules() - def add_fusion_layer(self, adapter_names): - super().add_fusion_layer(adapter_names) - self.cross_attention_adapters.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx: int): - super().add_adapter(adapter_name, layer_idx) - self.cross_attention_adapters.add_adapter(adapter_name, layer_idx) - - def delete_adapter(self, adapter_name): - super().delete_adapter(adapter_name) - self.cross_attention_adapters.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - super().delete_fusion_layer(adapter_names) - self.cross_attention_adapters.delete_fusion_layer(adapter_names) - - def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool): - super().enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - self.cross_attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - - # Makes sure the "parent" reference always points to the correct module. - # This is especially relevant when using torch data parallelism. 
- @staticmethod - def _adapter_block_pre_hook(module, input_tensors): - object.__setattr__(module.attention_adapters, "parent", module) - object.__setattr__(module.output_adapters, "parent", module) - object.__setattr__(module.cross_attention_adapters, "parent", module) - - -class BartEncoderDecoderAdaptersMixin: - """Adds adapters to the BartEncoder or BartDecoder module.""" - - def add_fusion_layer(self, adapter_names): - for layer in self.layers: - layer.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx_offset: int = 0): - adapter_config = self.config.adapters.get(adapter_name) - leave_out = adapter_config.get("leave_out", []) - for i, layer in enumerate(self.layers, start=layer_idx_offset): - if i not in leave_out: - layer.add_adapter(adapter_name, i) - - def delete_adapter(self, adapter_name: str): - for layer in self.layers: - layer.delete_adapter(adapter_name) - def delete_fusion_layer(self, adapter_names): - for layer in self.layers: - layer.delete_fusion_layer(adapter_names) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - for layer in self.layers: - layer.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - def adjust_attention_mask_for_parallel(self, hidden_states, attention_mask): - if attention_mask is not None and hidden_states.shape[0] != attention_mask.shape[0]: - repeats = [1] * len(attention_mask.shape) - repeats[0] = hidden_states.shape[0] // attention_mask.shape[0] - attention_mask = attention_mask.repeat(*repeats) - return attention_mask - - -class BartModelAdaptersMixin(ModelAdaptersMixin): +class BartModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the BartModel class.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.layers): + yield i, layer def _init_adapter_modules(self): - super()._init_adapter_modules() if hasattr(self, "encoder"): # In BART, the invertible adapters are implemented by the encoder module. # Therefore, relay mixin calls to the encoder here. 
self.invertible_adapters = self.encoder.invertible_adapters self.add_invertible_adapter = self.encoder.add_invertible_adapter self.get_invertible_adapter = self.encoder.get_invertible_adapter + self.enable_invertible_adapters = self.encoder.enable_invertible_adapters self.invertible_adapters_forward = self.encoder.invertible_adapters_forward - - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): - """Sets the model into mode for training the given adapters.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - if hasattr(self, "encoder"): - self.encoder.enable_adapters(adapter_setup, True, False) - self.encoder.enable_invertible_adapters(adapter_setup.flatten()) - self.decoder.enable_adapters(adapter_setup, True, False) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - if train_embeddings: - self.get_input_embeddings().train() - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - if hasattr(self, "encoder"): - self.encoder.enable_adapters(adapter_setup, unfreeze_adapters, True) - self.decoder.enable_adapters(adapter_setup, unfreeze_adapters, True) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - - def _add_adapter(self, adapter_name): - if hasattr(self, "encoder"): - self.encoder.add_adapter(adapter_name) - # make sure the layers in encoder & decoder are numbered from 0 to len(encoder+decoder) - self.decoder.add_adapter(adapter_name, layer_idx_offset=len(self.encoder.layers)) - self.encoder.add_invertible_adapter(adapter_name) - else: - self.decoder.add_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - if hasattr(self, "encoder"): - self.encoder.add_fusion_layer(adapter_names) - self.decoder.add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - if hasattr(self, "encoder"): - self.encoder.delete_adapter(adapter_name) - self.encoder.delete_invertible_adapter(adapter_name) - self.decoder.delete_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - if hasattr(self, "encoder"): - self.encoder.delete_fusion_layer(adapter_names) - self.decoder.delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - reg_loss = 0.0 - target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) - # encoder - if hasattr(self, "encoder"): - for _, v in self.encoder.layers._modules.items(): - for _, layer_fusion in v.output_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.attention_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - # decoder - for _, v in self.decoder.layers._modules.items(): - for _, layer_fusion in v.output_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.attention_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - 
reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - return reg_loss - - def adjust_tensors_for_parallel(self, hidden_states, *tensors): - outputs = [] - for tensor in tensors: - if tensor is not None and hidden_states.shape[0] != tensor.shape[0]: - repeats = [1] * len(tensor.shape) - repeats[0] = hidden_states.shape[0] // tensor.shape[0] - new_tensor = tensor.repeat(*repeats) - outputs.append(new_tensor) - else: - outputs.append(tensor) - return tuple(outputs) - - def get_adapter(self, name): - return_adapters = {} - for idx, layer in enumerate(self.encoder.layers): - adapters = { - "attention": layer.attention_adapters.adapters, - "output": layer.output_adapters.adapters, - } - for key, adapt in adapters.items(): - if hasattr(adapt, name): - if idx not in return_adapters: - return_adapters[idx] = {} - return_adapters[idx][key] = getattr(adapt, name) - - return return_adapters + super()._init_adapter_modules() class BartModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): diff --git a/src/transformers/adapters/models/bert.py b/src/transformers/adapters/models/bert.py index 85b539dfef..6fd51986f4 100644 --- a/src/transformers/adapters/models/bert.py +++ b/src/transformers/adapters/models/bert.py @@ -1,9 +1,8 @@ import logging -from typing import Union +from typing import Iterable, Tuple -import torch +import torch.nn as nn -from ..composition import AdapterCompositionBlock, parse_composition from ..heads import ( BertStyleMaskedLMHead, BiaffineParsingHead, @@ -15,162 +14,35 @@ QuestionAnsweringHead, TaggingHead, ) -from ..layer import AdapterLayerBaseMixin +from ..layer import AdapterLayer from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin logger = logging.getLogger(__name__) -class BertSelfOutputAdaptersMixin(AdapterLayerBaseMixin): +# For backwards compatibility, BertSelfOutput inherits directly from AdapterLayer +class BertSelfOutputAdaptersMixin(AdapterLayer): """Adds adapters to the BertSelfOutput module.""" - @property - def adapter_config_key(self): - return "mh_adapter" + def __init__(self): + super().__init__("mh_adapter", None) -class BertOutputAdaptersMixin(AdapterLayerBaseMixin): +# For backwards compatibility, BertOutput inherits directly from AdapterLayer +class BertOutputAdaptersMixin(AdapterLayer): """Adds adapters to the BertOutput module.""" - @property - def adapter_config_key(self): - return "output_adapter" - - -class BertLayerAdaptersMixin: - """Adds adapters to the BertLayer module.""" - - def add_fusion_layer(self, adapter_names): - self.attention.output.add_fusion_layer(adapter_names) - self.output.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx: int): - self.attention.output.add_adapter(adapter_name, layer_idx) - self.output.add_adapter(adapter_name, layer_idx) - - def delete_adapter(self, adapter_name): - self.attention.output.delete_adapter(adapter_name) - self.output.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - self.attention.output.delete_fusion_layer(adapter_names) - self.output.delete_fusion_layer(adapter_names) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - self.attention.output.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - self.output.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - -class BertEncoderAdaptersMixin: - """Adds adapters to the BertEncoder module.""" - - def add_fusion_layer(self, adapter_names): - 
for layer in self.layer: - layer.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str): - adapter_config = self.config.adapters.get(adapter_name) - leave_out = adapter_config.get("leave_out", []) - for i, layer in enumerate(self.layer): - if i not in leave_out: - layer.add_adapter(adapter_name, i) - - def delete_adapter(self, adapter_name: str): - for layer in self.layer: - layer.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - for layer in self.layer: - layer.delete_fusion_layer(adapter_names) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - for layer in self.layer: - layer.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - def adjust_attention_mask_for_parallel(self, hidden_states, attention_mask): - if attention_mask is not None and hidden_states.shape[0] != attention_mask.shape[0]: - repeats = [1] * len(attention_mask.shape) - repeats[0] = hidden_states.shape[0] // attention_mask.shape[0] - attention_mask = attention_mask.repeat(*repeats) - return attention_mask + def __init__(self): + super().__init__("output_adapter", None) class BertModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the BertModel module.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): - """Sets the model into mode for training the given adapters.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.encoder.enable_adapters(adapter_setup, True, False) - self.enable_invertible_adapters(adapter_setup.flatten()) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - if train_embeddings: - self.get_input_embeddings().train() - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.encoder.enable_adapters(adapter_setup, unfreeze_adapters, True) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - # TODO implement fusion for invertible adapters - - def _add_adapter(self, adapter_name): - self.encoder.add_adapter(adapter_name) - self.add_invertible_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - self.encoder.add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - self.encoder.delete_adapter(adapter_name) - self.delete_invertible_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - self.encoder.delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - reg_loss = 0.0 - - target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) - for _, v in self.encoder.layer._modules.items(): - - for _, layer_fusion in v.output.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.attention.output.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - return 
reg_loss - - def get_adapter(self, name): - return_adapters = {} - for idx, layer in enumerate(self.encoder.layer): - adapters = { - "attention": layer.attention.output.adapters, - "output": layer.output.adapters, - } - for key, adapt in adapters.items(): - if hasattr(adapt, name): - if idx not in return_adapters: - return_adapters[idx] = {} - return_adapters[idx][key] = getattr(adapt, name) - - return return_adapters + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in enumerate(self.encoder.layer): + yield i, layer class BertModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): diff --git a/src/transformers/adapters/models/distilbert.py b/src/transformers/adapters/models/distilbert.py index d7a30b97d9..7b81664c9a 100644 --- a/src/transformers/adapters/models/distilbert.py +++ b/src/transformers/adapters/models/distilbert.py @@ -1,155 +1,28 @@ -from typing import Union +from typing import Iterable, Tuple -import torch -from torch import nn +import torch.nn as nn -from ..composition import AdapterCompositionBlock, parse_composition +from ..layer import AdapterLayer from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin -from .bert import BertEncoderAdaptersMixin, BertModelHeadsMixin, BertOutputAdaptersMixin, BertSelfOutputAdaptersMixin - - -class DistilBertSelfAttentionAdaptersModule(BertSelfOutputAdaptersMixin, nn.Module): - """Adds attention adapters to the Transformer module of DistilBert.""" - - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def transformer_layer_norm(self): - return self.parent.sa_layer_norm - - -class DistilBertOutputAdaptersModule(BertOutputAdaptersMixin, nn.Module): - """Adds output adapters to the Transformer module of DistilBert.""" - - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def transformer_layer_norm(self): - return self.parent.output_layer_norm +from .bert import BertModelHeadsMixin class DistilBertTransfomerBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of DistilBert.""" def _init_adapter_modules(self): - self.attention_adapters = DistilBertSelfAttentionAdaptersModule(self) - self.output_adapters = DistilBertOutputAdaptersModule(self) + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.output_adapters = AdapterLayer("output_adapter", self.config) self.attention_adapters._init_adapter_modules() self.output_adapters._init_adapter_modules() - self.register_forward_pre_hook(self._adapter_block_pre_hook) - - def add_fusion_layer(self, adapter_names): - self.attention_adapters.add_fusion_layer(adapter_names) - self.output_adapters.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx: int): - self.attention_adapters.add_adapter(adapter_name, layer_idx) - self.output_adapters.add_adapter(adapter_name, layer_idx) - - def delete_adapter(self, adapter_name): - self.attention_adapters.delete_adapter(adapter_name) - self.output_adapters.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - self.attention_adapters.delete_fusion_layer(adapter_names) - self.output_adapters.delete_fusion_layer(adapter_names) - - def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, 
unfreeze_attention: bool): - self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - - # Makes sure the "parent" reference always points to the correct module. - # This is especially relevant when using torch data parallelism. - @staticmethod - def _adapter_block_pre_hook(module, input_tensors): - object.__setattr__(module.attention_adapters, "parent", module) - object.__setattr__(module.output_adapters, "parent", module) - - -class DistilBertTransformerAdaptersMixin(BertEncoderAdaptersMixin): - """Adds adapters to the Transformer module of DistilBert.""" - - pass class DistilBertModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the DistilBert module.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): - """Sets the model into mode for training the given adapters.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.transformer.enable_adapters(adapter_setup, True, False) - self.enable_invertible_adapters(adapter_setup.flatten()) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - if train_embeddings: - self.get_input_embeddings().train() - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.transformer.enable_adapters(adapter_setup, unfreeze_adapters, True) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - - def _add_adapter(self, adapter_name): - self.transformer.add_adapter(adapter_name) - self.add_invertible_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - self.transformer.add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - self.transformer.delete_adapter(adapter_name) - self.delete_invertible_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - self.transformer.delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - reg_loss = 0.0 - target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) - for _, v in self.transformer.layer._modules.items(): - - for _, layer_fusion in v.output_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.attention_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - return reg_loss - - def get_adapter(self, name): - return_adapters = {} - for idx, layer in enumerate(self.transformer.layer): - adapters = { - "attention": layer.attention_adapters.adapters, - "output": layer.output_adapters.adapters, - } - for key, adapt in adapters.items(): - if hasattr(adapt, name): - if idx not in return_adapters: - return_adapters[idx] = {} - return_adapters[idx][key] = getattr(adapt, name) - - return return_adapters + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, 
layer in enumerate(self.transformer.layer): + yield i, layer class DistilBertModelHeadsMixin(BertModelHeadsMixin): diff --git a/src/transformers/adapters/models/encoder_decoder.py b/src/transformers/adapters/models/encoder_decoder.py index 8a6911b929..3d11018c6c 100644 --- a/src/transformers/adapters/models/encoder_decoder.py +++ b/src/transformers/adapters/models/encoder_decoder.py @@ -1,10 +1,11 @@ -from typing import Union +from typing import Iterable, Tuple -from ..composition import AdapterCompositionBlock -from ..model_mixin import ModelAdaptersMixin +import torch.nn as nn +from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin -class EncoderDecoderModelAdaptersMixin(ModelAdaptersMixin): + +class EncoderDecoderModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the EncoderDecoderModel class.""" def __init__(self, *args, **kwargs): @@ -14,11 +15,11 @@ def _init_adapter_modules(self): if self.config.adapters is None: return - super()._init_adapter_modules() # Relay all invertible adapter calls to encoder self.invertible_adapters = self.encoder.base_model.invertible_adapters self.add_invertible_adapter = self.encoder.base_model.add_invertible_adapter self.get_invertible_adapter = self.encoder.base_model.get_invertible_adapter + self.enable_invertible_adapters = self.encoder.base_model.enable_invertible_adapters self.invertible_adapters_forward = self.encoder.base_model.invertible_adapters_forward # Decoder should use invertible adapters of encoder self.decoder.base_model.invertible_adapters = self.encoder.base_model.invertible_adapters @@ -36,38 +37,10 @@ def decoder_invertible_adapters_forward(hidden_states, rev=False): self.decoder.base_model.invertible_adapters_forward = decoder_invertible_adapters_forward - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock]): - """Sets the model into mode for training the given adapters.""" - self.encoder.train_adapter(adapter_setup) - self.decoder.train_adapter(adapter_setup) - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - self.encoder.train_adapter_fusion(adapter_setup, unfreeze_adapters) - self.decoder.train_adapter_fusion(adapter_setup, unfreeze_adapters) - - def _add_adapter(self, adapter_name): - self.encoder.base_model._add_adapter(adapter_name) - self.decoder.base_model._add_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - self.encoder.base_model._add_fusion_layer(adapter_names) - self.decoder.base_model._add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - self.encoder.base_model._delete_adapter(adapter_name) - self.decoder.base_model._delete_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - self.encoder.base_model._delete_fusion_layer(adapter_names) - self.decoder.base_model._delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - return self.encoder.get_fusion_regularization_loss() + self.decoder.get_fusion_regularization_loss() - - def get_adapter(self, name): - return_adapters = self.encoder.get_adapter(name) - for idx, items in self.decoder.get_adapter(name).items(): - return_adapters[len(return_adapters) + idx] = items + super()._init_adapter_modules() - return return_adapters + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + for i, layer in self.encoder.iter_layers(): + yield i, 
layer + for i, layer in self.decoder.iter_layers(): + yield i, layer diff --git a/src/transformers/adapters/models/gpt2.py b/src/transformers/adapters/models/gpt2.py index d79b26ae0b..b43ebd2577 100644 --- a/src/transformers/adapters/models/gpt2.py +++ b/src/transformers/adapters/models/gpt2.py @@ -1,183 +1,27 @@ -from typing import Union +from typing import Iterable, Tuple -import torch -from torch import nn +import torch.nn as nn -from ..composition import AdapterCompositionBlock, parse_composition from ..heads import CausalLMHead, ClassificationHead, MultiLabelClassificationHead, TaggingHead +from ..layer import AdapterLayer from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin -from .bert import ( - BertEncoderAdaptersMixin, - BertOutputAdaptersMixin, - BertSelfOutputAdaptersMixin, - ModelWithFlexibleHeadsAdaptersMixin, -) +from .bert import ModelWithFlexibleHeadsAdaptersMixin -class GPT2AttentionAdaptersModule(BertSelfOutputAdaptersMixin, nn.Module): - """Adds attention adapters to the Transformer module of DistilBert.""" - - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def transformer_layer_norm(self): - return None - - -class GPT2OutputAdaptersModule(BertOutputAdaptersMixin, nn.Module): - """Adds output adapters to the Transformer module of DistilBert.""" - - def __init__(self, parent): - super().__init__() - # keep a reference to the parent module without registering as a submodule - object.__setattr__(self, "parent", parent) - self.config = parent.config - - @property - def transformer_layer_norm(self): - return None - - -class GPT2DecoderBlockAdaptersMixin(BertEncoderAdaptersMixin): +class GPT2DecoderBlockAdaptersMixin: """Adds adapters to the TransformerBlock module of DistilBert.""" def _init_adapter_modules(self): - self.attention_adapters = GPT2AttentionAdaptersModule(self) - self.output_adapters = GPT2OutputAdaptersModule(self) + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.output_adapters = AdapterLayer("output_adapter", self.config) self.attention_adapters._init_adapter_modules() self.output_adapters._init_adapter_modules() - self.register_forward_pre_hook(self._adapter_block_pre_hook) - - def add_fusion_layer(self, adapter_names): - self.attention_adapters.add_fusion_layer(adapter_names) - self.output_adapters.add_fusion_layer(adapter_names) - - def add_adapter(self, adapter_name: str, layer_idx: int): - self.attention_adapters.add_adapter(adapter_name, layer_idx) - self.output_adapters.add_adapter(adapter_name, layer_idx) - - def delete_adapter(self, adapter_name): - self.attention_adapters.delete_adapter(adapter_name) - self.output_adapters.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - self.attention_adapters.delete_fusion_layer(adapter_names) - self.output_adapters.delete_fusion_layer(adapter_names) - - def enable_adapters(self, adapter_names: list, unfreeze_adapters: bool, unfreeze_attention: bool): - self.attention_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - self.output_adapters.enable_adapters(adapter_names, unfreeze_adapters, unfreeze_attention) - - # Makes sure the "parent" reference always points to the correct module. - # This is especially relevant when using torch data parallelism. 
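For orientation, the integration pattern that replaces the per-model plumbing being removed in these files is deliberately small: each transformer block instantiates an AdapterLayer per adapter location, and the model mixin only exposes its blocks through iter_layers(); adding, deleting, enabling and collecting adapters is then handled by the shared mixins. A minimal sketch under that assumption — MyBlockAdaptersMixin, MyModelAdaptersMixin and base_model.layers are hypothetical names, not part of this patch:

    from typing import Iterable, Tuple

    import torch.nn as nn

    from transformers.adapters.layer import AdapterLayer
    from transformers.adapters.model_mixin import ModelAdaptersMixin


    class MyBlockAdaptersMixin:
        """Hypothetical per-block integration (illustrative only)."""

        def _init_adapter_modules(self):
            # AdapterLayer is a registered nn.Module child of the block, so data
            # parallelism replicates it together with the block; no "parent"
            # re-binding pre-hook is required anymore.
            self.attention_adapters = AdapterLayer("mh_adapter", self.config)
            self.output_adapters = AdapterLayer("output_adapter", self.config)
            self.attention_adapters._init_adapter_modules()
            self.output_adapters._init_adapter_modules()


    class MyModelAdaptersMixin(ModelAdaptersMixin):
        """Hypothetical per-model integration: only layer iteration is model-specific."""

        def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
            for i, layer in enumerate(self.base_model.layers):
                yield i, layer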
- @staticmethod - def _adapter_block_pre_hook(module, input_tensors): - object.__setattr__(module.attention_adapters, "parent", module) - object.__setattr__(module.output_adapters, "parent", module) class GPT2ModelAdapterMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _init_adapter_modules(self): - super()._init_adapter_modules() - - # add adapters specified in config; invertible adapter will only be added if required - for adapter_name in self.config.adapters.adapters: - self._add_adapter(adapter_name) - # fusion - if hasattr(self.config, "fusion_models"): - for fusion_adapter_names in self.config.fusion_models: - self.add_fusion_layer(fusion_adapter_names) - - def _add_adapter(self, adapter_name: str): - adapter_config = self.config.adapters.get(adapter_name) - leave_out = adapter_config.get("leave_out", []) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: for i, layer in enumerate(self.base_model.h): - if i not in leave_out: - layer.add_adapter(adapter_name, i) - - self.add_invertible_adapter(adapter_name) - - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.enable_adapters(adapter_setup, True, False) - self.enable_invertible_adapters(adapter_setup.flatten()) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - if train_embeddings: - self.get_input_embeddings().train() - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - self.enable_adapters(adapter_setup, unfreeze_adapters, True) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - for layer in self.base_model.h: - layer.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - def adjust_attention_mask_for_parallel(self, hidden_states, attention_mask): - if attention_mask is not None and hidden_states.shape[0] != attention_mask.shape[0]: - repeats = [1] * len(attention_mask.shape) - repeats[0] = hidden_states.shape[0] // attention_mask.shape[0] - attention_mask = attention_mask.repeat(*repeats) - return attention_mask - - def _add_fusion_layer(self, adapter_names): - for layer in self.base_model.h: - layer.add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - for layer in self.base_model.h: - layer.delete_adapter(adapter_name) - self.delete_invertible_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - for layer in self.base_model.h: - layer.delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - reg_loss = 0.0 - target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) - for _, v in self.base_model.h._modules.items(): - - for _, layer_fusion in v.output_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.attention_adapters.adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - 
layer_fusion.value.weight).pow(2).sum() - - return reg_loss - - def get_adapter(self, name): - return_adapters = {} - for idx, layer in enumerate(self.h): - adapters = { - "attention": layer.attention_adapters.adapters, - "output": layer.output_adapters.adapters, - } - for key, adapt in adapters.items(): - if hasattr(adapt, name): - if idx not in return_adapters: - return_adapters[idx] = {} - return_adapters[idx][key] = getattr(adapt, name) - - return return_adapters + yield i, layer class GPT2ModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): diff --git a/src/transformers/adapters/models/t5.py b/src/transformers/adapters/models/t5.py index a14604ca6b..fa344bc5b4 100644 --- a/src/transformers/adapters/models/t5.py +++ b/src/transformers/adapters/models/t5.py @@ -1,235 +1,52 @@ -from typing import Union +from typing import Iterable, Tuple -import torch +import torch.nn as nn -from ..composition import AdapterCompositionBlock, parse_composition from ..heads import Seq2SeqLMHead -from ..layer import AdapterLayerBaseMixin -from ..model_mixin import ModelAdaptersMixin +from ..layer import AdapterLayer +from ..model_mixin import InvertibleAdaptersMixin, ModelAdaptersMixin from .bert import ModelWithFlexibleHeadsAdaptersMixin -class T5SelfAttentionLayerAdaptersMixin(AdapterLayerBaseMixin): - @property - def adapter_config_key(self): - return "mh_adapter" +class T5SelfAttentionLayerAdaptersMixin(AdapterLayer): + def __init__(self): + super().__init__("mh_adapter", None) - @property - def transformer_layer_norm(self): - # T5 has layer norms after each component - return None - - -class T5CrossAttentionLayerAdaptersMixin(AdapterLayerBaseMixin): - @property - def adapter_config_key(self): - return "cross_adapter" - @property - def transformer_layer_norm(self): - # T5 has layer norms after each component - return None +class T5CrossAttentionLayerAdaptersMixin(AdapterLayer): + def __init__(self): + super().__init__("cross_adapter", None) -class T5FFLayerAdaptersMixin(AdapterLayerBaseMixin): - @property - def adapter_config_key(self): - return "output_adapter" +class T5FFLayerAdaptersMixin(AdapterLayer): + def __init__(self): + super().__init__("output_adapter", None) - @property - def transformer_layer_norm(self): - # T5 has layer norms after each component - return None - -class T5BlockAdaptersMixin: - """Adds adapters to the T5Block module of T5.""" - - def __init__(self, config, *args, **kwargs): - super().__init__(*args, **kwargs) - self.config = config - - def add_fusion_layer(self, adapter_names): - self.layer[0].add_fusion_layer(adapter_names) # attention adapters - self.layer[-1].add_fusion_layer(adapter_names) # output adapters - - def add_adapter(self, adapter_name: str, layer_idx: int): - for layer in self.layer: - layer.add_adapter(adapter_name, layer_idx) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - for layer in self.layer: - layer.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - def delete_adapter(self, adapter_name): - for layer in self.layer: - layer.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - for layer in self.layer: - layer.delete_fusion_layer(adapter_names) - - -class T5StackAdaptersMixin: - """Adds adapters to the T5Stack module of T5.""" - - def point_adapter_configs(self, parent_config): - self.config = parent_config - - def add_fusion_layer(self, adapter_names): - for block in self.block: - block.add_fusion_layer(adapter_names) - - 
def add_adapter(self, adapter_name: str, layer_idx_offset: int = 0): - adapter_config = self.config.adapters.get(adapter_name) - leave_out = adapter_config.get("leave_out", []) - for i, block in enumerate(self.block, start=layer_idx_offset): - if i not in leave_out: - block.add_adapter(adapter_name, i) - - def delete_adapter(self, adapter_name: str): - for layer in self.block: - layer.delete_adapter(adapter_name) - - def delete_fusion_layer(self, adapter_names): - for layer in self.block: - layer.delete_fusion_layer(adapter_names) - - def enable_adapters( - self, adapter_setup: AdapterCompositionBlock, unfreeze_adapters: bool, unfreeze_attention: bool - ): - for block in self.block: - block.enable_adapters(adapter_setup, unfreeze_adapters, unfreeze_attention) - - def adjust_attention_mask_for_parallel(self, hidden_states, attention_mask): - if attention_mask is not None and hidden_states.shape[0] != attention_mask.shape[0]: - repeats = [1] * len(attention_mask.shape) - repeats[0] = hidden_states.shape[0] // attention_mask.shape[0] - attention_mask = attention_mask.repeat(*repeats) - return attention_mask - - def adjust_tensors_for_parallel(self, hidden_states, *tensors): - outputs = [] - for tensor in tensors: - if tensor is not None and hidden_states.shape[0] != tensor.shape[0]: - repeats = [1] * len(tensor.shape) - repeats[0] = hidden_states.shape[0] // tensor.shape[0] - new_tensor = tensor.repeat(*repeats) - outputs.append(new_tensor) - else: - outputs.append(tensor) - return tuple(outputs) - - -class T5ModelAdaptersMixin(ModelAdaptersMixin): +class T5ModelAdaptersMixin(InvertibleAdaptersMixin, ModelAdaptersMixin): """Adds adapters to the T5Model class.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.block): + yield i, layer + for i, layer in enumerate(self.decoder.block, start=len(self.encoder.block)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.block): + yield i, layer def _init_adapter_modules(self): - super()._init_adapter_modules() if hasattr(self, "encoder"): # In T5, the invertible adapters are implemented by the encoder module. # Therefore, relay mixin calls to the encoder here. 
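            # (i.e. the model object shares the encoder's single set of invertible
            #  adapters instead of creating a separate copy of its own)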
self.invertible_adapters = self.encoder.invertible_adapters self.add_invertible_adapter = self.encoder.add_invertible_adapter self.get_invertible_adapter = self.encoder.get_invertible_adapter + self.enable_invertible_adapters = self.encoder.enable_invertible_adapters self.invertible_adapters_forward = self.encoder.invertible_adapters_forward self.delete_invertible_adapter = self.encoder.delete_invertible_adapter - - def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], train_embeddings=False): - """Sets the model into mode for training the given adapters.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - if hasattr(self, "encoder"): - self.encoder.enable_adapters(adapter_setup, True, False) - self.encoder.enable_invertible_adapters(adapter_setup.flatten()) - self.decoder.enable_adapters(adapter_setup, True, False) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - if train_embeddings: - self.get_input_embeddings().train() - - def train_adapter_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): - """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" - self.train() - self.freeze_model(True) - adapter_setup = parse_composition(adapter_setup) - if hasattr(self, "encoder"): - self.encoder.enable_adapters(adapter_setup, unfreeze_adapters, True) - self.decoder.enable_adapters(adapter_setup, unfreeze_adapters, True) - # use the adapters to be trained by default in every forward pass - self.set_active_adapters(adapter_setup) - - def _add_adapter(self, adapter_name): - if hasattr(self, "encoder"): - self.encoder.add_adapter(adapter_name) - # make sure the layers in encoder & decoder are numbered from 0 to len(encoder+decoder) - self.decoder.add_adapter(adapter_name, layer_idx_offset=len(self.encoder.block)) - else: - self.decoder.add_adapter(adapter_name) - self.encoder.add_invertible_adapter(adapter_name) - - def _add_fusion_layer(self, adapter_names): - if hasattr(self, "encoder"): - self.encoder.add_fusion_layer(adapter_names) - self.decoder.add_fusion_layer(adapter_names) - - def _delete_adapter(self, adapter_name: str): - if hasattr(self, "encoder"): - self.encoder.delete_adapter(adapter_name) - self.encoder.delete_invertible_adapter(adapter_name) - self.decoder.delete_adapter(adapter_name) - - def _delete_fusion_layer(self, adapter_names): - if hasattr(self, "encoder"): - self.encoder.delete_fusion_layer(adapter_names) - self.decoder.delete_fusion_layer(adapter_names) - - def get_fusion_regularization_loss(self): - reg_loss = 0.0 - target = torch.zeros((self.config.hidden_size, self.config.hidden_size)).fill_diagonal_(1.0).to(self.device) - # encoder - if hasattr(self, "encoder"): - for _, v in self.encoder.block._modules.items(): - for _, layer_fusion in v.layer[-1].adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.layer[0].adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - # decoder - for _, v in self.decoder.block._modules.items(): - for _, layer_fusion in v.layer[-1].adapter_fusion_layer.items(): - if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - for _, layer_fusion in v.layer[0].adapter_fusion_layer.items(): - 
if hasattr(layer_fusion, "value"): - reg_loss += 0.01 * (target - layer_fusion.value.weight).pow(2).sum() - - return reg_loss - - def get_adapter(self, name): - return_adapters = {} - for idx, block in enumerate(self.encoder.block): - # In each block of T5Stack that is an encoder, the first layer is T5LayerSelfAttention, the second is T5LayerFF - adapters = { - "attention": block.layer[0].adapters, - "output": block.layer[1].adapters, - } - for key, adapt in adapters.items(): - if hasattr(adapt, name): - if idx not in return_adapters: - return_adapters[idx] = {} - return_adapters[idx][key] = getattr(adapt, name) - - return return_adapters + super()._init_adapter_modules() class T5ModelHeadsMixin(ModelWithFlexibleHeadsAdaptersMixin): diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 1ea097993f..c60ec7d697 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -25,10 +25,10 @@ from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.bart import ( BartDecoderLayerAdaptersMixin, - BartEncoderDecoderAdaptersMixin, BartEncoderLayerAdaptersMixin, BartModelAdaptersMixin, BartModelHeadsMixin, @@ -319,14 +319,14 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.output_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() @@ -416,7 +416,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) # Cross-Attention Block cross_attn_present_key_value = None @@ -435,7 +435,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.cross_attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.cross_attention_adapters(hidden_states, residual, self.encoder_attn_layer_norm) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value @@ -446,7 +446,7 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = 
self.output_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) outputs = (hidden_states,) @@ -673,7 +673,7 @@ def __init_subclass__(self): """ -class BartEncoder(InvertibleAdaptersMixin, BartEncoderDecoderAdaptersMixin, BartPretrainedModel): +class BartEncoder(InvertibleAdaptersMixin, BartPretrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a :class:`BartEncoderLayer`. @@ -835,7 +835,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) @@ -850,7 +850,7 @@ def custom_forward(*inputs): ) -class BartDecoder(BartEncoderDecoderAdaptersMixin, BartPretrainedModel): +class BartDecoder(BartPretrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a :class:`BartDecoderLayer` @@ -861,8 +861,6 @@ class BartDecoder(BartEncoderDecoderAdaptersMixin, BartPretrainedModel): def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) - self.config = config - self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id @@ -1095,7 +1093,7 @@ def custom_forward(*inputs): use_cache=use_cache, ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) @@ -1218,7 +1216,7 @@ def forward( ) # inflate all decoder inputs according to encoder output - decoder_input_ids, decoder_attention_mask, attention_mask = self.adjust_tensors_for_parallel( + decoder_input_ids, decoder_attention_mask, attention_mask = adjust_tensors_for_parallel( encoder_outputs[0], decoder_input_ids, decoder_attention_mask, attention_mask ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) @@ -1317,7 +1315,7 @@ def forward( # sequence classification based on last token in sequence x = outputs[0] # last hidden state eos_mask = input_ids.eq(self.config.eos_token_id) - eos_mask = self.model.encoder.adjust_attention_mask_for_parallel(x, eos_mask) + (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask) if len(torch.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] @@ -1574,7 +1572,7 @@ def forward( hidden_states = outputs[0] # last hidden state eos_mask = input_ids.eq(self.config.eos_token_id) - eos_mask = self.model.encoder.adjust_attention_mask_for_parallel(hidden_states, eos_mask) + (eos_mask,) = adjust_tensors_for_parallel(hidden_states, eos_mask) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 8e06bb5acc..69fabfbb91 100644 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -29,11 +29,10 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations 
import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import AdapterSetup from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.bert import ( - BertEncoderAdaptersMixin, - BertLayerAdaptersMixin, BertModelAdaptersMixin, BertModelHeadsMixin, BertOutputAdaptersMixin, @@ -371,7 +370,7 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapters_forward(hidden_states, input_tensor) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -452,11 +451,11 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapters_forward(hidden_states, input_tensor) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states -class BertLayer(BertLayerAdaptersMixin, nn.Module): +class BertLayer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -541,7 +540,7 @@ def feed_forward_chunk(self, attention_output): return layer_output -class BertEncoder(BertEncoderAdaptersMixin, nn.Module): +class BertEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config @@ -607,7 +606,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if use_cache: next_decoder_cache += (layer_outputs[-1],) diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index cfc1d25712..e1b4579466 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -27,12 +27,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import gelu +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.distilbert import ( DistilBertModelAdaptersMixin, DistilBertModelHeadsMixin, DistilBertTransfomerBlockAdaptersMixin, - DistilBertTransformerAdaptersMixin, ) from ...deepspeed import is_deepspeed_zero3_enabled from ...file_utils import ( @@ -292,11 +292,11 @@ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False): else: # To handle these `output_attentions` or `output_hidden_states` cases returning tuples assert type(sa_output) == tuple sa_output = sa_output[0] - sa_output = self.attention_adapters.adapters_forward(sa_output, x) # (bs, seq_length, dim) + sa_output = self.attention_adapters(sa_output, x, self.sa_layer_norm) # (bs, seq_length, dim) # Feed Forward Network ffn_output = self.ffn(sa_output) # (bs, seq_length, dim) - ffn_output = self.output_adapters.adapters_forward(ffn_output, sa_output) # (bs, seq_length, dim) + ffn_output = self.output_adapters(ffn_output, sa_output, self.output_layer_norm) # (bs, seq_length, dim) output = (ffn_output,) if output_attentions: @@ -304,10 +304,9 @@ def forward(self, x, attn_mask=None, head_mask=None, output_attentions=False): return output -class 
Transformer(DistilBertTransformerAdaptersMixin, nn.Module): +class Transformer(nn.Module): def __init__(self, config): super().__init__() - self.config = config self.n_layers = config.n_layers self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)]) @@ -340,7 +339,7 @@ def forward( x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions ) hidden_state = layer_outputs[-1] - attn_mask = self.adjust_attention_mask_for_parallel(hidden_state, attn_mask) + (attn_mask,) = adjust_tensors_for_parallel(hidden_state, attn_mask) if output_attentions: assert len(layer_outputs) == 2 diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index d8cab65c76..2929abf0ea 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -34,6 +34,7 @@ is_amp_available = False from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.gpt2 import GPT2DecoderBlockAdaptersMixin, GPT2ModelAdapterMixin, GPT2ModelHeadsMixin from ...file_utils import ( @@ -406,7 +407,7 @@ def forward( ) attn_output = attn_outputs[0] # output_attn: a, present, (attentions) outputs = attn_outputs[1:] - hidden_states = self.attention_adapters.adapters_forward(attn_output, residual) + hidden_states = self.attention_adapters(attn_output, residual, None) if encoder_hidden_states is not None: # add one self-attention block for cross-attention @@ -434,7 +435,7 @@ def forward( hidden_states = self.ln_2(hidden_states) feed_forward_hidden_states = self.mlp(hidden_states) # residual connection - hidden_states = self.output_adapters.adapters_forward(feed_forward_hidden_states, residual) + hidden_states = self.output_adapters(feed_forward_hidden_states, residual, None) if use_cache: outputs = (hidden_states,) + outputs @@ -902,7 +903,7 @@ def custom_forward(*inputs): ) hidden_states = outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) # HACK: if output_shape is identical to hidden states shape except for batch size, update output_shape if output_shape[1:] == hidden_states.size()[1:]: output_shape = hidden_states.size() @@ -1603,7 +1604,7 @@ def forward( else: if input_ids is not None: sequence_lengths = torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1 - sequence_lengths = self.transformer.adjust_attention_mask_for_parallel(outputs[0], sequence_lengths) + (sequence_lengths,) = adjust_tensors_for_parallel(outputs[0], sequence_lengths) else: sequence_lengths = -1 logger.warning( diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 953ea7c022..3175428b64 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -24,10 +24,10 @@ from torch.nn import CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.bart import ( BartDecoderLayerAdaptersMixin, - BartEncoderDecoderAdaptersMixin, BartEncoderLayerAdaptersMixin, BartModelAdaptersMixin, BartModelHeadsMixin, @@ -327,7 +327,7 @@ def forward( 
output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -335,7 +335,7 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.output_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() @@ -426,7 +426,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.attention_adapters(hidden_states, residual, self.self_attn_layer_norm) # Cross-Attention Block cross_attn_present_key_value = None @@ -446,7 +446,7 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.cross_attention_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.cross_attention_adapters(hidden_states, residual, self.encoder_attn_layer_norm) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value @@ -458,7 +458,7 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = self.output_adapters.adapters_forward(hidden_states, residual) + hidden_states = self.output_adapters(hidden_states, residual, self.final_layer_norm) outputs = (hidden_states,) @@ -675,7 +675,7 @@ def dummy_inputs(self): """ -class MBartEncoder(InvertibleAdaptersMixin, BartEncoderDecoderAdaptersMixin, MBartPreTrainedModel): +class MBartEncoder(InvertibleAdaptersMixin, MBartPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a :class:`MBartEncoderLayer`. @@ -832,7 +832,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if output_attentions: all_attentions = all_attentions + (layer_outputs[1],) @@ -849,7 +849,7 @@ def custom_forward(*inputs): ) -class MBartDecoder(BartEncoderDecoderAdaptersMixin, MBartPreTrainedModel): +class MBartDecoder(MBartPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. 
Each layer is a :class:`MBartDecoderLayer` @@ -860,8 +860,6 @@ class MBartDecoder(BartEncoderDecoderAdaptersMixin, MBartPreTrainedModel): def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) - self.config = config - self.dropout = config.dropout self.layerdrop = config.decoder_layerdrop self.padding_idx = config.pad_token_id @@ -1096,7 +1094,7 @@ def custom_forward(*inputs): use_cache=use_cache, ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if use_cache: next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) @@ -1218,7 +1216,7 @@ def forward( ) # inflate all decoder inputs according to encoder output - decoder_input_ids, decoder_attention_mask, attention_mask = self.adjust_tensors_for_parallel( + decoder_input_ids, decoder_attention_mask, attention_mask = adjust_tensors_for_parallel( encoder_outputs[0], decoder_input_ids, decoder_attention_mask, attention_mask ) # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) @@ -1317,7 +1315,7 @@ def forward( # sequence classification based on last token in sequence x = outputs[0] # last hidden state eos_mask = input_ids.eq(self.config.eos_token_id) - eos_mask = self.model.encoder.adjust_attention_mask_for_parallel(x, eos_mask) + (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask) if len(torch.unique(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] @@ -1579,7 +1577,7 @@ def forward( hidden_states = outputs[0] # last hidden state eos_mask = input_ids.eq(self.config.eos_token_id) - eos_mask = self.model.encoder.adjust_attention_mask_for_parallel(hidden_states, eos_mask) + (eos_mask,) = adjust_tensors_for_parallel(hidden_states, eos_mask) if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: raise ValueError("All examples must have the same number of tokens.") diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index e5f00e9280..dbbe151a05 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -24,11 +24,10 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, gelu +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.context import AdapterSetup from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.bert import ( - BertEncoderAdaptersMixin, - BertLayerAdaptersMixin, BertModelAdaptersMixin, BertModelHeadsMixin, BertOutputAdaptersMixin, @@ -308,7 +307,7 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapters_forward(hidden_states, input_tensor) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) return hidden_states @@ -392,12 +391,12 @@ def __init__(self, config): def forward(self, hidden_states, input_tensor): hidden_states = self.dense(hidden_states) hidden_states = self.dropout(hidden_states) - hidden_states = self.adapters_forward(hidden_states, input_tensor) + hidden_states = self.adapter_layer_forward(hidden_states, input_tensor, self.LayerNorm) 
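        # At this point the active adapter setup (if any) has been applied to
        # hidden_states, with input_tensor as the residual; adapter_layer_forward
        # also applies the LayerNorm passed to it, while pre-norm models such as
        # GPT-2 and T5 pass None and normalize outside the call.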
return hidden_states # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta -class RobertaLayer(BertLayerAdaptersMixin, nn.Module): +class RobertaLayer(nn.Module): def __init__(self, config): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -483,7 +482,7 @@ def feed_forward_chunk(self, attention_output): # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta -class RobertaEncoder(BertEncoderAdaptersMixin, nn.Module): +class RobertaEncoder(nn.Module): def __init__(self, config): super().__init__() self.config = config @@ -549,7 +548,7 @@ def custom_forward(*inputs): ) hidden_states = layer_outputs[0] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) + (attention_mask,) = adjust_tensors_for_parallel(hidden_states, attention_mask) if use_cache: next_decoder_cache += (layer_outputs[-1],) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 1a6346aaa4..859aaf5ec4 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -26,15 +26,14 @@ from torch.utils.checkpoint import checkpoint from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.t5 import ( - T5BlockAdaptersMixin, T5CrossAttentionLayerAdaptersMixin, T5FFLayerAdaptersMixin, T5ModelAdaptersMixin, T5ModelHeadsMixin, T5SelfAttentionLayerAdaptersMixin, - T5StackAdaptersMixin, ) from ...file_utils import ( DUMMY_INPUTS, @@ -311,7 +310,7 @@ def __init__(self, config): def forward(self, hidden_states): forwarded_states = self.layer_norm(hidden_states) forwarded_states = self.DenseReluDense(forwarded_states) - hidden_states = self.adapters_forward(hidden_states, self.dropout(forwarded_states)) + hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(forwarded_states), None) return hidden_states @@ -566,7 +565,7 @@ def forward( use_cache=use_cache, output_attentions=output_attentions, ) - hidden_states = self.adapters_forward(hidden_states, self.dropout(attention_output[0])) + hidden_states = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs @@ -604,14 +603,14 @@ def forward( query_length=query_length, output_attentions=output_attentions, ) - layer_output = self.adapters_forward(hidden_states, self.dropout(attention_output[0])) + layer_output = self.adapter_layer_forward(hidden_states, self.dropout(attention_output[0]), None) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them return outputs -class T5Block(T5BlockAdaptersMixin, nn.Module): +class T5Block(nn.Module): def __init__(self, config, has_relative_attention_bias=False): - super().__init__(config) + super().__init__() self.is_decoder = config.is_decoder self.layer = nn.ModuleList() self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) @@ -817,13 +816,12 @@ def _shift_right(self, input_ids): return shifted_input_ids -class T5Stack(InvertibleAdaptersMixin, T5StackAdaptersMixin, T5PreTrainedModel): +class T5Stack(InvertibleAdaptersMixin, T5PreTrainedModel): def __init__(self, config, embed_tokens=None): super().__init__(config) self.embed_tokens = embed_tokens self.is_decoder = 
config.is_decoder - self.use_cache = config.use_cache self.block = nn.ModuleList( [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] @@ -902,9 +900,8 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict if self.is_decoder and encoder_hidden_states is not None: - (input_ids,) = self.adjust_tensors_for_parallel(encoder_hidden_states, input_ids) - encoder_attention_mask = self.adjust_attention_mask_for_parallel( - encoder_hidden_states, encoder_attention_mask + input_ids, encoder_attention_mask = adjust_tensors_for_parallel( + encoder_hidden_states, input_ids, encoder_attention_mask ) if input_ids is not None and inputs_embeds is not None: @@ -1045,8 +1042,9 @@ def custom_forward(*inputs): hidden_states, present_key_value_state = layer_outputs[:2] - attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, attention_mask) - extended_attention_mask = self.adjust_attention_mask_for_parallel(hidden_states, extended_attention_mask) + attention_mask, extended_attention_mask = adjust_tensors_for_parallel( + hidden_states, attention_mask, extended_attention_mask + ) # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1059,9 +1057,9 @@ def custom_forward(*inputs): present_key_value_states = present_key_value_states + (present_key_value_state,) if position_bias is not None: - position_bias = self.adjust_tensors_for_parallel(hidden_states, position_bias)[0] + position_bias = adjust_tensors_for_parallel(hidden_states, position_bias)[0] if encoder_decoder_position_bias is not None: - encoder_decoder_position_bias = self.adjust_tensors_for_parallel( + encoder_decoder_position_bias = adjust_tensors_for_parallel( hidden_states, encoder_decoder_position_bias )[0] diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 2a032b51f5..e09fefc41b 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -357,6 +357,7 @@ def forward_pre_hook(module, input): inv_adapter.register_forward_pre_hook(forward_pre_hook) in_data = self.get_input_samples((1, 128), config=model.config) + model.to(torch_device) out = model(**in_data) self.assertEqual((1, 128, model.config.decoder.vocab_size), out[0].shape) diff --git a/tests/test_adapter_common.py b/tests/test_adapter_common.py index 2eb8d2fe5f..c02d3a78bb 100644 --- a/tests/test_adapter_common.py +++ b/tests/test_adapter_common.py @@ -148,13 +148,13 @@ def test_add_adapter_multiple_reduction_factors(self): adapter = model.get_adapter(name) self.assertEqual( - adapter[0]["output"].adapter_down[0].in_features - / adapter[0]["output"].adapter_down[0].out_features, + adapter[0]["output_adapter"].adapter_down[0].in_features + / adapter[0]["output_adapter"].adapter_down[0].out_features, reduction_factor["default"], ) self.assertEqual( - adapter[1]["output"].adapter_down[0].in_features - / adapter[1]["output"].adapter_down[0].out_features, + adapter[1]["output_adapter"].adapter_down[0].in_features + / adapter[1]["output_adapter"].adapter_down[0].out_features, reduction_factor["1"], ) diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index d05883bc5b..2d6aa3d74e 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -189,6 +189,8 @@ def test_adapter_fusion_save_with_head(self): # assert equal forward pass in_data = self.get_input_samples((1, 128), 
config=model1.config) + model1.to(torch_device) + model2.to(torch_device) output1 = model1(**in_data) output2 = model2(**in_data) self.assertEqual(len(output1), len(output2)) diff --git a/tests/test_adapter_heads.py b/tests/test_adapter_heads.py index b02422c17e..3dee3d9bb1 100644 --- a/tests/test_adapter_heads.py +++ b/tests/test_adapter_heads.py @@ -363,7 +363,7 @@ def forward_pre_hook(module, input): nonlocal calls calls += 1 - adapter = model.get_adapter("a")[0]["output"] + adapter = model.get_adapter("a")[0]["output_adapter"] adapter.register_forward_pre_hook(forward_pre_hook) with AdapterSetup("a"): diff --git a/tests/test_adapter_setup_context.py b/tests/test_adapter_setup_context.py index 035d691caf..f9a00ff7ab 100644 --- a/tests/test_adapter_setup_context.py +++ b/tests/test_adapter_setup_context.py @@ -40,9 +40,9 @@ def forward_pre_hook_b(module, input): nonlocal calls_b calls_b += 1 - adapter_a = model.get_adapter("a")[0]["output"] + adapter_a = model.get_adapter("a")[0]["output_adapter"] adapter_a.register_forward_pre_hook(forward_pre_hook_a) - adapter_b = model.get_adapter("b")[0]["output"] + adapter_b = model.get_adapter("b")[0]["output_adapter"] adapter_b.register_forward_pre_hook(forward_pre_hook_b) with AdapterSetup("a"): @@ -75,7 +75,7 @@ def forward_pre_hook(module, input): nonlocal calls calls += 1 - adapter = model.get_adapter(adapter_setup)[0]["output"] + adapter = model.get_adapter(adapter_setup)[0]["output_adapter"] adapter.register_forward_pre_hook(forward_pre_hook) with AdapterSetup(adapter_setup):
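The test updates above follow from get_adapter() now keying each layer's adapter modules by their location key rather than by the old "attention"/"output" names. A minimal usage sketch — the model class, checkpoint and adapter name "a" are illustrative, not taken from this patch:

    from transformers import BertModel

    model = BertModel.from_pretrained("bert-base-uncased")
    model.add_adapter("a")
    # per-layer dict of adapter modules, keyed by location (e.g. "output_adapter")
    bottleneck = model.get_adapter("a")[0]["output_adapter"]
    print(bottleneck.adapter_down[0].in_features, bottleneck.adapter_down[0].out_features)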