Add ForwardContext to wrap model forward pass (#267)
calpt committed Feb 2, 2022
1 parent 7574056 commit 2472ccc
Showing 12 changed files with 80 additions and 28 deletions.
2 changes: 0 additions & 2 deletions src/transformers/adapters/configuration.py
@@ -207,8 +207,6 @@ def __init__(self, **kwargs):
# TODO-V2 Save this with config?
self.active_setup: Optional[AdapterCompositionBlock] = None
self.skip_layers = None
# TODO This flag will be set & reset in every forward pass. Check if there is a better solution without state mutation.
self.is_parallelized = False

def __contains__(self, item):
return item in self.adapters.keys()
56 changes: 53 additions & 3 deletions src/transformers/adapters/context.py
@@ -1,3 +1,4 @@
import functools
import threading

from .composition import parse_composition, parse_heads_from_composition
@@ -41,9 +42,9 @@ def __exit__(self, type, value, traceback):

@classmethod
def get_contexts(cls):
if not hasattr(cls.storage, "setups"):
cls.storage.setups = []
return cls.storage.setups
if not hasattr(cls.storage, "contexts"):
cls.storage.contexts = []
return cls.storage.contexts

@classmethod
def get_context(cls):
@@ -58,3 +59,52 @@ def get_context_head_setup(cls):
if context:
return context.head_setup
return None


class ForwardContext:
"""
Holds context information during a forward pass through a model. This class should be used via the
``ForwardContext.wrap()`` method.
Note that the context is thread-local.
"""

# thread-local storage that holds a stack of active contexts
storage = threading.local()

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
if hasattr(model, "forward_context"):
model.forward_context(self, *args, **kwargs)

@classmethod
def wrap(cls, f):
"""
Decorator method that wraps a ``forward()`` function of a model class.
"""

@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.config.adapters is not None:
context = cls(self, *args, **kwargs)
cls.get_contexts().append(context)
results = f(self, *args, **kwargs)
cls.get_contexts().pop()
return results
else:
return f(self, *args, **kwargs)

return wrapper_func

@classmethod
def get_contexts(cls):
if not hasattr(cls.storage, "contexts"):
cls.storage.contexts = []
return cls.storage.contexts

@classmethod
def get_context(cls):
try:
return cls.get_contexts()[-1]
except IndexError:
return None
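
A minimal usage sketch of the new class (the `ToyModel` below and its config stand-in are hypothetical; the snippet assumes a version of adapter-transformers that includes this change is installed):

```python
import types

import torch
from torch import nn

from transformers.adapters.context import ForwardContext


class ToyModel(nn.Module):
    """Hypothetical model illustrating the ForwardContext protocol."""

    def __init__(self):
        super().__init__()
        # Stand-in for a real ModelAdaptersConfig: ForwardContext.wrap only needs
        # `self.config.adapters` to be non-None to create a context for the pass.
        self.config = types.SimpleNamespace(adapters=object())
        self.linear = nn.Linear(4, 4)

    def forward_context(self, context, *args, **kwargs):
        # Called by ForwardContext.__init__ at the start of every wrapped forward pass.
        context.adapters_parallelized = False

    @ForwardContext.wrap
    def forward(self, x):
        # Any submodule (e.g. an AdapterLayer) can read the per-pass state from here.
        context = ForwardContext.get_context()
        assert context is not None and context.adapters_parallelized is False
        return self.linear(x)


model = ToyModel()
model(torch.randn(2, 4))
# After the pass, the context has been popped from the thread-local stack again.
assert ForwardContext.get_context() is None
```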
7 changes: 4 additions & 3 deletions src/transformers/adapters/layer.py
@@ -4,7 +4,7 @@
from torch import nn

from .composition import AdapterCompositionBlock, BatchSplit, Fuse, Parallel, Split, Stack
from .context import AdapterSetup
from .context import AdapterSetup, ForwardContext
from .modeling import Adapter, BertFusion


@@ -317,11 +317,12 @@ def adapter_parallel(self, adapter_setup: Parallel, hidden_states, input_tensor,
# We assume that all adapters have the same config
adapter_config = self.config.adapters.get(adapter_setup.first())

if not self.config.adapters.is_parallelized:
context = ForwardContext.get_context()
if not context.adapters_parallelized:
orig_batch_size = input_tensor.shape[0]
input_tensor = input_tensor.repeat(self.config.adapters.active_setup.parallel_channels, 1, 1)
hidden_states = hidden_states.repeat(self.config.adapters.active_setup.parallel_channels, 1, 1)
self.config.adapters.is_parallelized = True
context.adapters_parallelized = True
else:
# The base model should handle replication of input.
# Therefore, we assume the (replicated) input batch to be divisible by the number of parallel channels.
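
A small tensor-only sketch of the replication step above (shapes are illustrative, not taken from a real model):

```python
import torch

# With a Parallel block of 3 adapters, the first adapter layer that sees the input repeats the
# batch dimension once per parallel channel; later layers find `adapters_parallelized` set on the
# ForwardContext and skip the replication.
parallel_channels = 3
hidden_states = torch.randn(2, 8, 16)  # (batch, seq_len, hidden_dim)

replicated = hidden_states.repeat(parallel_channels, 1, 1)
print(replicated.shape)  # torch.Size([6, 8, 16]) -> 3 channels x original batch of 2
```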
8 changes: 4 additions & 4 deletions src/transformers/adapters/model_mixin.py
@@ -11,7 +11,7 @@

from .composition import AdapterCompositionBlock, Fuse, Stack, parse_composition
from .configuration import AdapterConfig, AdapterFusionConfig, ModelAdaptersConfig, get_adapter_config_hash
from .context import AdapterSetup
from .context import AdapterSetup, ForwardContext
from .hub_mixin import PushAdapterToHubMixin
from .layer import AdapterLayer
from .loading import AdapterFusionLoader, AdapterLoader, PredictionHeadLoader, WeightsLoader
@@ -562,16 +562,16 @@ def freeze_model(self, freeze=True):
param.requires_grad = not freeze
self.model_freezed = freeze

def pre_transformer_forward(self):
def forward_context(self, context: ForwardContext, *args, **kwargs):
"""
This method should be called by every adapter-implementing model at the very beginning of the forward() method.
This method is called by the ``ForwardContext`` at the beginning of the forward pass.
"""
# some warnings if we don't use available adapters
active_adapters = self.active_adapters or AdapterSetup.get_context()
if not active_adapters and self.has_adapters():
logger.warning("There are adapters available but none are activated for the forward pass.")

self.config.adapters.is_parallelized = False
context.adapters_parallelized = False

def load_embeddings(self, path: str, name: str):
"""
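Because the context lives in `threading.local` storage rather than on the shared adapters config (compare the `is_parallelized` flag and TODO removed from `configuration.py` above), per-pass state cannot leak between threads. A small illustration that pushes a context onto the stack by hand purely for demonstration; real code is expected to go through `ForwardContext.wrap`:

```python
import threading

from transformers.adapters.context import ForwardContext


class _Dummy:
    # Minimal stand-in without a forward_context() method, so ForwardContext(...) can be
    # constructed directly for this demonstration.
    pass


ctx = ForwardContext(_Dummy())
ForwardContext.get_contexts().append(ctx)  # pretend a forward pass is active in this thread

seen_elsewhere = []

def worker():
    # The worker thread has its own thread-local stack and does not see the main thread's context.
    seen_elsewhere.append(ForwardContext.get_context())

t = threading.Thread(target=worker)
t.start()
t.join()

assert ForwardContext.get_context() is ctx
assert seen_elsewhere == [None]
ForwardContext.get_contexts().pop()  # clean up
```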
5 changes: 3 additions & 2 deletions src/transformers/models/bart/modeling_bart.py
@@ -26,6 +26,7 @@

from ...activations import ACT2FN
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import ForwardContext
from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
from ...adapters.models.bart import (
BartDecoderLayerAdaptersMixin,
@@ -1163,6 +1164,7 @@ def get_decoder(self):
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -1195,7 +1197,6 @@ def forward(
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

if encoder_outputs is None:
encoder_outputs = self.encoder(
@@ -1741,8 +1742,8 @@ def __init__(self, config):

self._init_adapter_modules()

@ForwardContext.wrap
def forward(self, *args, **kwargs):
self.pre_transformer_forward()

return self.decoder(*args, **kwargs)

4 changes: 2 additions & 2 deletions src/transformers/models/bert/modeling_bert.py
@@ -30,7 +30,7 @@

from ...activations import ACT2FN
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import AdapterSetup
from ...adapters.context import AdapterSetup, ForwardContext
from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
from ...adapters.models.bert import (
BertModelAdaptersMixin,
@@ -912,6 +912,7 @@ class PreTrainedModel
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -953,7 +954,6 @@ def forward(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
3 changes: 2 additions & 1 deletion src/transformers/models/distilbert/modeling_distilbert.py
@@ -28,6 +28,7 @@

from ...activations import gelu
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import ForwardContext
from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
from ...adapters.models.distilbert import (
DistilBertModelAdaptersMixin,
@@ -527,6 +528,7 @@ class PreTrainedModel
output_type=BaseModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -542,7 +544,6 @@ def forward(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
5 changes: 2 additions & 3 deletions src/transformers/models/encoder_decoder/modeling_encoder_decoder.py
@@ -20,6 +20,7 @@
import torch
from torch.nn import CrossEntropyLoss

from ...adapters.context import ForwardContext
from ...adapters.models.encoder_decoder import EncoderDecoderModelAdaptersMixin
from ...configuration_utils import PretrainedConfig
from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
@@ -408,6 +409,7 @@ def from_encoder_decoder_pretrained(

@add_start_docstrings_to_model_forward(ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -462,9 +464,6 @@ def forward(
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}

if self.config.adapters:
self.pre_transformer_forward()

if encoder_outputs is None:
encoder_outputs = self.encoder(
input_ids=input_ids,
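Note that the explicit `if self.config.adapters:` guard removed here becomes unnecessary: `ForwardContext.wrap` itself falls back to the plain `forward()` call when `self.config.adapters` is `None`. A toy sketch (hypothetical `PlainModel`, same installation assumption as above):

```python
import types

from transformers.adapters.context import ForwardContext


class PlainModel:
    """Hypothetical model without adapter support configured."""

    def __init__(self):
        self.config = types.SimpleNamespace(adapters=None)

    @ForwardContext.wrap
    def forward(self, x):
        # No context is created or pushed for this pass.
        assert ForwardContext.get_context() is None
        return x


print(PlainModel().forward(3))  # -> 3
```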
4 changes: 2 additions & 2 deletions src/transformers/models/gpt2/modeling_gpt2.py
@@ -35,6 +35,7 @@

from ...activations import ACT2FN
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import ForwardContext
from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
from ...adapters.models.gpt2 import GPT2DecoderBlockAdaptersMixin, GPT2ModelAdapterMixin, GPT2ModelHeadsMixin
from ...file_utils import (
@@ -742,6 +743,7 @@ def _prune_heads(self, heads_to_prune):
output_type=BaseModelOutputWithPastAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -758,8 +760,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
self.pre_transformer_forward()

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
5 changes: 3 additions & 2 deletions src/transformers/models/mbart/modeling_mbart.py
@@ -25,6 +25,7 @@

from ...activations import ACT2FN
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import ForwardContext
from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
from ...adapters.models.bart import (
BartDecoderLayerAdaptersMixin,
@@ -1166,6 +1167,7 @@ def get_decoder(self):
output_type=Seq2SeqModelOutput,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -1190,7 +1192,6 @@ def forward(
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

# different to other models, MBart automatically creates decoder_input_ids from
# input_ids if no decoder_input_ids are provided
@@ -1748,8 +1749,8 @@ def __init__(self, config):

self._init_adapter_modules()

@ForwardContext.wrap
def forward(self, *args, **kwargs):
self.pre_transformer_forward()

return self.decoder(*args, **kwargs)

4 changes: 2 additions & 2 deletions src/transformers/models/roberta/modeling_roberta.py
@@ -25,7 +25,7 @@

from ...activations import ACT2FN, gelu
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import AdapterSetup
from ...adapters.context import AdapterSetup, ForwardContext
from ...adapters.model_mixin import ModelWithHeadsAdaptersMixin
from ...adapters.models.bert import (
BertModelAdaptersMixin,
@@ -764,6 +764,7 @@ class PreTrainedModel
output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC,
)
@ForwardContext.wrap
# Copied from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
@@ -806,7 +807,6 @@ def forward(
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

if self.config.is_decoder:
use_cache = use_cache if use_cache is not None else self.config.use_cache
5 changes: 3 additions & 2 deletions src/transformers/models/t5/modeling_t5.py
@@ -27,6 +27,7 @@

from ...activations import ACT2FN
from ...adapters.composition import adjust_tensors_for_parallel
from ...adapters.context import ForwardContext
from ...adapters.model_mixin import InvertibleAdaptersMixin, ModelWithHeadsAdaptersMixin
from ...adapters.models.t5 import (
T5CrossAttentionLayerAdaptersMixin,
@@ -1354,6 +1355,7 @@ class PreTrainedModel

@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -1391,7 +1393,6 @@ def forward(
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None:
Expand Down Expand Up @@ -1554,6 +1555,7 @@ def get_decoder(self):

@add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -1603,7 +1605,6 @@ def forward(
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
self.pre_transformer_forward()

# FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
if head_mask is not None and decoder_head_mask is None: