From 7412ef8363f0cb526c6bb10a71868372be7fec1b Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 2 Feb 2022 20:26:35 +0100 Subject: [PATCH] Add `ForwardContext` to wrap model forward pass (#267) --- src/transformers/adapters/configuration.py | 2 - src/transformers/adapters/context.py | 56 ++++++++++++++++++- src/transformers/adapters/layer.py | 7 ++- src/transformers/adapters/model_mixin.py | 8 +-- src/transformers/models/bart/modeling_bart.py | 5 +- src/transformers/models/bert/modeling_bert.py | 4 +- .../models/distilbert/modeling_distilbert.py | 3 +- .../modeling_encoder_decoder.py | 5 +- src/transformers/models/gpt2/modeling_gpt2.py | 4 +- .../models/mbart/modeling_mbart.py | 5 +- .../models/roberta/modeling_roberta.py | 4 +- src/transformers/models/t5/modeling_t5.py | 5 +- 12 files changed, 80 insertions(+), 28 deletions(-) diff --git a/src/transformers/adapters/configuration.py b/src/transformers/adapters/configuration.py index d459360abc..991b42459f 100644 --- a/src/transformers/adapters/configuration.py +++ b/src/transformers/adapters/configuration.py @@ -207,8 +207,6 @@ def __init__(self, **kwargs): # TODO-V2 Save this with config? self.active_setup: Optional[AdapterCompositionBlock] = None self.skip_layers = None - # TODO This flag will be set & reset in every forward pass. Check if there is a better solution without state mutation. - self.is_parallelized = False def __contains__(self, item): return item in self.adapters.keys() diff --git a/src/transformers/adapters/context.py b/src/transformers/adapters/context.py index 79d72fbb7d..a07ac929bb 100644 --- a/src/transformers/adapters/context.py +++ b/src/transformers/adapters/context.py @@ -1,3 +1,4 @@ +import functools import threading from .composition import parse_composition, parse_heads_from_composition @@ -41,9 +42,9 @@ def __exit__(self, type, value, traceback): @classmethod def get_contexts(cls): - if not hasattr(cls.storage, "setups"): - cls.storage.setups = [] - return cls.storage.setups + if not hasattr(cls.storage, "contexts"): + cls.storage.contexts = [] + return cls.storage.contexts @classmethod def get_context(cls): @@ -58,3 +59,52 @@ def get_context_head_setup(cls): if context: return context.head_setup return None + + +class ForwardContext: + """ + Holds context information during a forward pass through a model. This class should be used via the + ``ForwardContext.wrap()`` method. + + Note that the context is thread-local. + """ + + # thread-local storage that holds a stack of active contexts + storage = threading.local() + + def __init__(self, model, *args, **kwargs): + # If the model has a method ``forward_context()``, use it to create the context. + if hasattr(model, "forward_context"): + model.forward_context(self, *args, **kwargs) + + @classmethod + def wrap(cls, f): + """ + Decorator method that wraps a ``forward()`` function of a model class. + """ + + @functools.wraps(f) + def wrapper_func(self, *args, **kwargs): + if self.config.adapters is not None: + context = cls(self, *args, **kwargs) + cls.get_contexts().append(context) + results = f(self, *args, **kwargs) + cls.get_contexts().pop() + return results + else: + return f(self, *args, **kwargs) + + return wrapper_func + + @classmethod + def get_contexts(cls): + if not hasattr(cls.storage, "contexts"): + cls.storage.contexts = [] + return cls.storage.contexts + + @classmethod + def get_context(cls): + try: + return cls.get_contexts()[-1] + except IndexError: + return None diff --git a/src/transformers/adapters/layer.py b/src/transformers/adapters/layer.py index e245c4a4f6..7b0e118913 100644 --- a/src/transformers/adapters/layer.py +++ b/src/transformers/adapters/layer.py @@ -4,7 +4,7 @@ from torch import nn from .composition import AdapterCompositionBlock, BatchSplit, Fuse, Parallel, Split, Stack -from .context import AdapterSetup +from .context import AdapterSetup, ForwardContext from .modeling import Adapter, BertFusion @@ -317,11 +317,12 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor, # We assume that all adapters have the same config adapter_config = self.config.adapters.get(adapter_setup.first()) - if not self.config.adapters.is_parallelized: + context = ForwardContext.get_context() + if not context.adapters_parallelized: orig_batch_size = input_tensor.shape[0] input_tensor = input_tensor.repeat(self.config.adapters.active_setup.parallel_channels, 1, 1) hidden_states = hidden_states.repeat(self.config.adapters.active_setup.parallel_channels, 1, 1) - self.config.adapters.is_parallelized = True + context.adapters_parallelized = True else: # The base model should handle replication of input. # Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels. diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 131f241505..dd1a1e9081 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -11,7 +11,7 @@ from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash -from .context import AdapterSetup +from .context import AdapterSetup, ForwardContext from .hub_mixin import PushAdapterToHubMixin from .layer import AdapterLayer from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader @@ -562,16 +562,16 @@ def freeze_model(self, freeze=True): param.requires_grad = not freeze self.model_freezed = freeze - def pre_transformer_forward(self): + def forward_context(self, context: ForwardContext, *args, **kwargs): """ - This method should be called by every adapter-implementing model at the very beginning of the forward() method. + This method is called by the ``ForwardContext`` at the beginning of the forward pass. """ # some warnings if we don't use available adapters active_adapters = self.active_adapters or AdapterSetup.get_context() if not active_adapters and self.has_adapters(): logger.warning("There are adapters available but none are activated for the forward pass.") - self.config.adapters.is_parallelized = False + context.adapters_parallelized = False def load_embeddings(self, path: str, name: str): """ diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index c60ec7d697..e87b152984 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.bart import ( BartDecoderLayerAdaptersMixin, @@ -1163,6 +1164,7 @@ def get_decoder(self): output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -1195,7 +1197,6 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() if encoder_outputs is None: encoder_outputs = self.encoder( @@ -1741,8 +1742,8 @@ def __init__(self, config): self._init_adapter_modules() + @ForwardContext.wrap def forward(self, *args, **kwargs): - self.pre_transformer_forward() return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 69fabfbb91..704b0141a5 100644 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -30,7 +30,7 @@ from ...activations import ACT2FN from ...adapters.composition import adjust_tensors_for_parallel -from ...adapters.context import AdapterSetup +from ...adapters.context import AdapterSetup, ForwardContext from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.bert import ( BertModelAdaptersMixin, @@ -912,6 +912,7 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -953,7 +954,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() if self.config.is_decoder: use_cache = use_cache if use_cache is not None else self.config.use_cache diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index e1b4579466..d686b61cc5 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -28,6 +28,7 @@ from ...activations import gelu from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.distilbert import ( DistilBertModelAdaptersMixin, @@ -527,6 +528,7 @@ class PreTrainedModel output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -542,7 +544,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") diff --git a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py index 81e36b9382..8eeeb9cbd7 100644 --- a/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_encoder_decoder.py @@ -20,6 +20,7 @@ import torch from torch.nn import CrossEntropyLoss +from ...adapters.context import ForwardContext from ...adapters.models.encoder_decoder import EncoderDecoderModelAdaptersMixin from ...configuration_utils import PretrainedConfig from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings @@ -408,6 +409,7 @@ def from_encoder_decoder_pretrained( @add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -462,9 +464,6 @@ def forward( argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") } - if self.config.adapters: - self.pre_transformer_forward() - if encoder_outputs is None: encoder_outputs = self.encoder( input_ids=input_ids, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 2929abf0ea..0f6fe2c02a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -35,6 +35,7 @@ from ...activations import ACT2FN from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.gpt2 import GPT2DecoderBlockAdaptersMixin, GPT2ModelAdapterMixin, GPT2ModelHeadsMixin from ...file_utils import ( @@ -742,6 +743,7 @@ def _prune_heads(self, heads_to_prune): output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -758,8 +760,6 @@ def forward( output_hidden_states=None, return_dict=None, ): - self.pre_transformer_forward() - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 3175428b64..b5d7381eb8 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -25,6 +25,7 @@ from ...activations import ACT2FN from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.bart import ( BartDecoderLayerAdaptersMixin, @@ -1166,6 +1167,7 @@ def get_decoder(self): output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -1190,7 +1192,6 @@ def forward( ) use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() # different to other models, MBart automatically creates decoder_input_ids from # input_ids if no decoder_input_ids are provided @@ -1748,8 +1749,8 @@ def __init__(self, config): self._init_adapter_modules() + @ForwardContext.wrap def forward(self, *args, **kwargs): - self.pre_transformer_forward() return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index dbbe151a05..cc543e9482 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -25,7 +25,7 @@ from ...activations import ACT2FN, gelu from ...adapters.composition import adjust_tensors_for_parallel -from ...adapters.context import AdapterSetup +from ...adapters.context import AdapterSetup, ForwardContext from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin from ...adapters.models.bert import ( BertModelAdaptersMixin, @@ -764,6 +764,7 @@ class PreTrainedModel output_type=BaseModelOutputWithPoolingAndCrossAttentions, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap # Copied from transformers.models.bert.modeling_bert.BertModel.forward def forward( self, @@ -806,7 +807,6 @@ def forward( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() if self.config.is_decoder: use_cache = use_cache if use_cache is not None else self.config.use_cache diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 859aaf5ec4..180e5f8e9b 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin from ...adapters.models.t5 import ( T5CrossAttentionLayerAdaptersMixin, @@ -1354,6 +1355,7 @@ class PreTrainedModel @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -1391,7 +1393,6 @@ def forward( """ use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: @@ -1554,6 +1555,7 @@ def get_decoder(self): @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -1603,7 +1605,6 @@ def forward( """ use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - self.pre_transformer_forward() # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask if head_mask is not None and decoder_head_mask is None: