diff --git a/docs/source/index.md b/docs/source/index.md index 3a13565d6a1..6ecf93d2507 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -17,6 +17,7 @@ tutorial_fast tutorial_mutators tutorial_crowdsourcing tutorial_chat_service +tutorial_swap_components tutorial_tests ``` diff --git a/docs/source/tutorial_swap_components.md b/docs/source/tutorial_swap_components.md new file mode 100644 index 00000000000..8e6a32b4d5d --- /dev/null +++ b/docs/source/tutorial_swap_components.md @@ -0,0 +1,60 @@ +# Swapping Out Transformer Subcomponents + +__Author__: Spencer Poff + +Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, this is not easily accomplished with subclassing without copying and pasting much of the original implementation. + +To make this easier and avoid copypasta, we provide the `@swappable` decorator. + +## Making a Module Swappable + +Let's say you have an existing class, `TransformerLayer`, that uses a module that you'd like to modify, `TransformerFFN`. You can make that FFN swappable in two steps: + +1. Decorate `TransformerLayer` with `@swappable`, passing in a name for the component you'd like to swap and its default class/constructor: +```python +@swappable(ffn=TransformerFFN) +class TransformerLayer(nn.Module): + ... +``` + +2. At runtime, the class for ffn will be added to a property `swappables` of `TransformerLayer`. Replace your instantiation of `TransformerFFN` with a call to that constructor: + +```python +self.feedforward = self.swappables.ffn(opt, ...) +``` + +That's it! + +## Making the Swap + +You can now replace `TransformerFFN` with whatever class or constructor you want before instantiating `TransformerLayer`: +```python +layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...) +``` + +As long as `NewCustomFFN` has the same `__init__` and `forward` method signatures as `TransformerFFN`, everything should just work. + +## Composability + +Since the swapping happens before instantiation, decorated components can be transparently composed. For example: +```python +model = TransformerGeneratorModel.with_components( + encoder=TransformerEncoder.with_components( + layer=TransformerEncoderLayer.with_components( + self_attention=MultiHeadAttention, + feedforward=TransformerFFN, + ) + ), + decoder=TransformerDecoder.with_components( + layer=TransformerDecoderLayer.with_components( + encoder_attention=MultiHeadAttention, + self_attention=MultiHeadAttention, + feedforward=TransformerFFN, + ) + ), +)(opt=self.opt, dictionary=self.dict) +``` + +## Implementation + +See `parlai/agents/transformer/modules/modular.py` \ No newline at end of file diff --git a/parlai/agents/examples/transformer_variant.py b/parlai/agents/examples/transformer_variant.py new file mode 100644 index 00000000000..2f4f9367646 --- /dev/null +++ b/parlai/agents/examples/transformer_variant.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Example code for specifying custom transformer variants. + +TransformerVariantAgent: +- Minimal changes needed to: + - Swap out a high-level component (encoder) + - Swap out a low-level component (decoder->layer->self_attention) + +VerboseTransformerAgent: +- Doesn't swap out anything +- Fully specifies all components, for illustration + +ConfigurableTransformerAgent: +- Swaps out components based on command line args +""" +from __future__ import annotations +import torch +from enum import Enum +from typing import Dict, Optional, Tuple, Union + +from parlai.agents.transformer.modules import ( + TransformerFFN, + MultiHeadAttention, + TransformerDecoder, + TransformerDecoderLayer, + TransformerEncoder, + TransformerEncoderLayer, + TransformerGeneratorModel, +) +from parlai.agents.transformer.transformer import TransformerGeneratorAgent +from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser +import parlai.utils.logging as logging + + +########################################### +# Transformer With Two Components Swapped # +########################################### + + +class TransformerVariantAgent(TransformerGeneratorAgent): + """ + Swapping out two things: + + 1. Encoder (high-level component) + 2. Decoder self attention (low-level component) + """ + + def build_model(self, states=None): + wrapped_class = TransformerGeneratorModel.with_components( + encoder=MyCustomEncoder, + decoder=TransformerDecoder.with_components( + layer=TransformerDecoderLayer.with_components( + self_attention=MyCustomAttention + ) + ), + ) + return wrapped_class(self.opt, self.dict) + + +class MyCustomEncoder(TransformerEncoder): + """ + For brevity this subclasses TransformerEncoder, but you could write your own + nn.Module from scratch as long as the __init__ and forward signatures match + TransformerEncoder. + """ + + def forward( + self, + input: torch.LongTensor, + positions: Optional[torch.LongTensor] = None, + segments: Optional[torch.LongTensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]: + logging.info("Custom encoder called!") + # Comment out the following line and write your custom `forward` instead. + return super().forward(input, positions, segments) # type: ignore + + +class MyCustomAttention(MultiHeadAttention): + """ + For brevity this just renames MultiHeadAttention, but ideally you'd define a new + nn.Module with the same __init__ and forward signature as MultiHeadAttention. + """ + + def forward( + self, + query: torch.Tensor, + key: Optional[torch.Tensor] = None, + value: Optional[torch.Tensor] = None, + mask: torch.Tensor = None, + incr_state: Optional[Dict[str, torch.Tensor]] = None, + static_kv: bool = False, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: + logging.info("Custom attention called!") + # Comment out the following line and write your custom `forward` instead. + return super().forward( + query, + key=key, + value=value, + mask=mask, + incr_state=incr_state, + static_kv=static_kv, + ) + + +####################################### +# Fully-specified Default Transformer # +####################################### + + +class VerboseTransformerAgent(TransformerGeneratorAgent): + """ + Doesn't make any changes to TransformerGeneratorModel, just specifies all + subcomponents explicitly. + + This is meant to be a reference for how to swap any component within + TransformerGeneratorModel. + """ + + def build_model(self, states=None): + wrapped_class = TransformerGeneratorModel.with_components( + encoder=TransformerEncoder.with_components( + layer=TransformerEncoderLayer.with_components( + self_attention=MultiHeadAttention, feedforward=TransformerFFN + ) + ), + decoder=TransformerDecoder.with_components( + layer=TransformerDecoderLayer.with_components( + encoder_attention=MultiHeadAttention, + self_attention=MultiHeadAttention, + feedforward=TransformerFFN, + ) + ), + ) + return wrapped_class(opt=self.opt, dictionary=self.dict) + + +################################################ +# Command-line Configurable Custom Transformer # +################################################ + + +class DecoderFeedForwardVariant(Enum): + ONE = 'one' + TWO = 'two' + + +class DecoderFFNOne(TransformerFFN): + def forward(self, x: torch.Tensor) -> torch.Tensor: + logging.info("Using Decoder FFN Variant One") + return super().forward(x) + + +class DecoderFFNTwo(TransformerFFN): + def forward(self, x: torch.Tensor) -> torch.Tensor: + logging.info("Using Decoder FFN Variant Two") + return super().forward(x) + + +class ConfigurableTransformerAgent(TransformerGeneratorAgent): + """ + Illustrates swapping out components based on command line args. + + Specifically, swaps out the decoder ffn between two options. + """ + + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + super().add_cmdline_args(parser, partial_opt=partial_opt) + agent = parser.add_argument_group('MyCustom Transformer Arguments') + parser.add_argument( + '--decoder-ffn-variants', + type=DecoderFeedForwardVariant, + default=DecoderFeedForwardVariant.ONE, + help='Some variants in the decoder FFN implementation', + ) + return agent # type: ignore + + def build_model(self, states=None): + decoder_variant: DecoderFeedForwardVariant = self.opt['decoder_ffn_variants'] + if decoder_variant == DecoderFeedForwardVariant.ONE: + decoder_ffn_class = DecoderFFNOne + elif decoder_variant == DecoderFeedForwardVariant.TWO: + decoder_ffn_class = DecoderFFNTwo + else: + logging.error( + 'Invalid --decoder-ffn-variants option, defaulting to original ffn implementation.' + ) + decoder_ffn_class = TransformerFFN + + wrapped_class = TransformerGeneratorModel.with_components( + decoder=TransformerDecoder.with_components( + layer=TransformerDecoderLayer.with_components( + feedforward=decoder_ffn_class + ) + ) + ) + return wrapped_class(opt=self.opt, dictionary=self.dict) diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py index abd066fbb5e..4bbcb11e345 100644 --- a/parlai/agents/transformer/modules/decoder.py +++ b/parlai/agents/transformer/modules/decoder.py @@ -7,7 +7,8 @@ Transformer decoder implementations. """ -from typing import Dict, Tuple, Optional +from __future__ import annotations +from typing import Dict, Optional, Tuple import numpy as np import torch @@ -20,11 +21,159 @@ MultiHeadAttention, TransformerFFN, ) +from parlai.agents.transformer.modules.modular import swappable from parlai.core.opt import Opt from parlai.utils.misc import warn_once from parlai.utils.torch import PipelineHelper +@swappable( + self_attention=MultiHeadAttention, + encoder_attention=MultiHeadAttention, + feedforward=TransformerFFN, +) +class TransformerDecoderLayer(nn.Module): + """ + Implements a single Transformer decoder layer. + + Decoder layers are similar to encoder layers but: + + 1. Self-attention is limited in a causal (auto-regressive) manner. + 2. Attend over all of the encoder states. + """ + + def __init__( + self, + n_heads: int, + embedding_size: int, + ffn_size: int, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: str = 'relu', + variant: str = 'aiayn', + **kwargs, + ): + super().__init__(**kwargs) + self.dim = embedding_size + self.ffn_dim = ffn_size + self.variant = variant + self.activation = activation + self.dropout = nn.Dropout(p=dropout) + + self.self_attention = self.swappables.self_attention( + n_heads, embedding_size, dropout=attention_dropout + ) # type: ignore + self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + self.encoder_attention = self.swappables.encoder_attention( + n_heads, embedding_size, dropout=attention_dropout + ) # type: ignore + self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + self.ffn = self.swappables.feedforward( + embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation + ) # type: ignore + self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + + def forward( + self, + x: torch.Tensor, + encoder_output: torch.Tensor, + encoder_mask: torch.Tensor, + incr_state: Optional[Dict[str, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Forward pass. + + The incremental state is a dict with values for self- and encoder-attention + states. + """ + + if incr_state is None: + incr_state = {} + + decoder_mask = self._create_selfattn_mask(x) + # first self attn + residual = x + if self.variant == 'prelayernorm': + x = self.norm1(x) + + # don't peak into the future! + x, final_self_attn_incr_state = self.self_attention( + query=x, + mask=decoder_mask, + incr_state=incr_state.get('self_attn'), + static_kv=False, + )[:2] + x = self.dropout(x) # --dropout + x = x + residual + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = self.norm1(x) + + residual = x + # encoder_attn_layer_norm norm 2 + if self.variant == 'prelayernorm': + x = self.norm2(x) + x, final_encoder_attn_incr_state = self.encoder_attention( + query=x, + key=encoder_output, + value=encoder_output, + mask=encoder_mask, + incr_state=incr_state.get('encoder_attn'), + static_kv=True, + )[:2] + x = self.dropout(x) # --dropout + x = residual + x + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = self.norm2(x) + + # finally the ffn + residual = x + if self.variant == 'prelayernorm': + x = self.norm3(x) + x = self.ffn(x) + x = self.dropout(x) # --dropout + x = residual + x + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + x = self.norm3(x) + + new_incr_state = { + 'self_attn': final_self_attn_incr_state, + 'encoder_attn': final_encoder_attn_incr_state, + } + return x, new_incr_state + + def _create_selfattn_mask(self, x): + # figure out how many timestamps we need + bsz = x.size(0) + time = x.size(1) + # make sure that we don't look into the future + mask = torch.tril(x.new(time, time).fill_(1)) + # broadcast across batch + mask = mask.unsqueeze(0).expand(bsz, -1, -1) + return mask + + def reorder_incremental_state( + self, incremental_state: Dict[str, dict], inds: torch.Tensor + ) -> Dict[str, dict]: + """ + Reorder all incremental-state tensors for this layer. + """ + attn_types = { + 'self_attn': self.self_attention, + 'encoder_attn': self.encoder_attention, + } + return { + attn_type: attn.reorder_incremental_state( + incremental_state[attn_type], inds + ) + for attn_type, attn in attn_types.items() + } + + +@swappable(layer=TransformerDecoderLayer) class TransformerDecoder(nn.Module): """ Transformer Decoder module. @@ -43,8 +192,9 @@ def __init__( opt: Opt, embedding: Optional[nn.Embedding] = None, n_positions: Optional[int] = None, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) def _default(val, default): return val if val is not None else default @@ -106,7 +256,7 @@ def _default(val, default): self.layers = nn.ModuleList() for _ in range(self.n_layers): self.layers.append( - TransformerDecoderLayer( + self.swappables.layer( self.n_heads, self.embedding_size, self.ffn_size, @@ -115,7 +265,7 @@ def _default(val, default): dropout=dropout_frac, activation=self.activation, variant=self.variant, - ) + ) # type: ignore ) def forward_embedding( @@ -199,6 +349,7 @@ def forward( input: torch.Tensor, encoder_state, incr_state: Optional[Dict[str, torch.Tensor]] = None, + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Forward pass. @@ -273,142 +424,3 @@ def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state } return tensor_out, new_incr_state - - -class TransformerDecoderLayer(nn.Module): - """ - Implements a single Transformer decoder layer. - - Decoder layers are similar to encoder layers but: - - 1. Self-attention is limited in a casaul (auto-regressive) manner. - 2. Attend over all of the encoder states. - """ - - def __init__( - self, - n_heads: int, - embedding_size: int, - ffn_size: int, - attention_dropout: float = 0.0, - relu_dropout: float = 0.0, - dropout: float = 0.0, - activation: str = 'relu', - variant: str = 'aiayn', - ): - super().__init__() - self.dim = embedding_size - self.ffn_dim = ffn_size - self.variant = variant - self.activation = activation - self.dropout = nn.Dropout(p=dropout) - - self.self_attention = MultiHeadAttention( - n_heads, embedding_size, dropout=attention_dropout - ) - self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - self.encoder_attention = MultiHeadAttention( - n_heads, embedding_size, dropout=attention_dropout - ) - self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - self.ffn = TransformerFFN( - embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation - ) - self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - def forward( - self, - x: torch.Tensor, - encoder_output: torch.Tensor, - encoder_mask: torch.Tensor, - incr_state: Optional[Dict[str, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Forward pass. - - The incremental state is a dict with values for self- and encoder-attention - states. - """ - - if incr_state is None: - incr_state = {} - - decoder_mask = self._create_selfattn_mask(x) - # first self attn - residual = x - if self.variant == 'prelayernorm': - x = self.norm1(x) - - # don't peak into the future! - x, final_self_attn_incr_state = self.self_attention( - query=x, - mask=decoder_mask, - incr_state=incr_state.get('self_attn'), - static_kv=False, - )[:2] - x = self.dropout(x) # --dropout - x = x + residual - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = self.norm1(x) - - residual = x - # encoder_attn_layer_norm norm 2 - if self.variant == 'prelayernorm': - x = self.norm2(x) - x, final_encoder_attn_incr_state = self.encoder_attention( - query=x, - key=encoder_output, - value=encoder_output, - mask=encoder_mask, - incr_state=incr_state.get('encoder_attn'), - static_kv=True, - )[:2] - x = self.dropout(x) # --dropout - x = residual + x - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = self.norm2(x) - - # finally the ffn - residual = x - if self.variant == 'prelayernorm': - x = self.norm3(x) - x = self.ffn(x) - x = self.dropout(x) # --dropout - x = residual + x - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = self.norm3(x) - - new_incr_state = { - 'self_attn': final_self_attn_incr_state, - 'encoder_attn': final_encoder_attn_incr_state, - } - return x, new_incr_state - - def _create_selfattn_mask(self, x): - # figure out how many timestamps we need - bsz = x.size(0) - time = x.size(1) - # make sure that we don't look into the future - mask = torch.tril(x.new(time, time).fill_(1)) - # broadcast across batch - mask = mask.unsqueeze(0).expand(bsz, -1, -1) - return mask - - def reorder_incremental_state( - self, incremental_state: Dict[str, dict], inds: torch.Tensor - ) -> Dict[str, dict]: - """ - Reorder all incremental-state tensors for this layer. - """ - attn_types = { - 'self_attn': self.self_attention, - 'encoder_attn': self.encoder_attention, - } - return { - attn_type: attn.reorder_incremental_state( - incremental_state[attn_type], inds - ) - for attn_type, attn in attn_types.items() - } diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py index ec338e4e004..ed9efdd4a20 100644 --- a/parlai/agents/transformer/modules/encoder.py +++ b/parlai/agents/transformer/modules/encoder.py @@ -7,6 +7,7 @@ Transformer encoder implementations. """ +from __future__ import annotations from typing import Tuple, Optional, Union import numpy as np @@ -20,11 +21,72 @@ MultiHeadAttention, TransformerFFN, ) +from parlai.agents.transformer.modules.modular import swappable from parlai.core.opt import Opt from parlai.utils.misc import warn_once from parlai.utils.torch import PipelineHelper +@swappable(self_attention=MultiHeadAttention, feedforward=TransformerFFN) +class TransformerEncoderLayer(nn.Module): + """ + Implements a single Transformer encoder layer. + """ + + def __init__( + self, + n_heads: int, + embedding_size: int, + ffn_size: int, + attention_dropout: float = 0.0, + relu_dropout: float = 0.0, + dropout: float = 0.0, + activation: str = 'relu', + variant: Optional[str] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.dim = embedding_size + self.ffn_dim = ffn_size + self.activation = activation + self.variant = variant + self.attention = self.swappables.self_attention( # type: ignore + n_heads, embedding_size, dropout=attention_dropout # --attention-dropout + ) + self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.ffn = self.swappables.feedforward( # type: ignore + embedding_size, + ffn_size, + relu_dropout=relu_dropout, + activation=self.activation, + ) + self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) + self.dropout = nn.Dropout(p=dropout) + + def forward( + self, tensor: torch.Tensor, mask: torch.Tensor, **kwargs + ) -> torch.Tensor: + """ + Forward pass. + """ + residual = tensor + if self.variant == 'prelayernorm': + tensor = self.norm1(tensor) + attended_tensor = self.attention(tensor, mask=mask)[0] + tensor = residual + self.dropout(attended_tensor) + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + tensor = self.norm1(tensor) + residual = tensor + if self.variant == 'prelayernorm': + tensor = self.norm2(tensor) + tensor = residual + self.dropout(self.ffn(tensor)) + if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': + tensor = self.norm2(tensor) + tensor *= mask.unsqueeze(-1).type_as(tensor) + return tensor + + +@swappable(layer=TransformerEncoderLayer) class TransformerEncoder(nn.Module): """ Transformer encoder module. @@ -58,12 +120,14 @@ def __init__( activation: Optional[str] = None, variant: Optional[str] = None, output_scaling: Optional[float] = None, + **kwargs, ): - super(TransformerEncoder, self).__init__() + super().__init__(**kwargs) def _default(val, default): return val if val is not None else default + self.opt = opt self.embedding_size = opt['embedding_size'] self.ffn_size = opt['ffn_size'] self.n_layers = ( @@ -141,7 +205,7 @@ def _default(val, default): self.layers = nn.ModuleList() for _ in range(self.n_layers): self.layers.append( - TransformerEncoderLayer( + self.swappables.layer( # type: ignore self.n_heads, self.embedding_size, self.ffn_size, @@ -256,6 +320,7 @@ def forward( # type: ignore input: torch.LongTensor, positions: Optional[torch.LongTensor] = None, segments: Optional[torch.LongTensor] = None, + **kwargs, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]: """ Forward pass. @@ -306,58 +371,3 @@ def _apply_model_parallel(self, tensor, mask): tensor_out, mask_out = PipelineHelper.join(chunks) return tensor_out - - -class TransformerEncoderLayer(nn.Module): - """ - Implements a single Transformer encoder layer. - """ - - def __init__( - self, - n_heads: int, - embedding_size: int, - ffn_size: int, - attention_dropout: float = 0.0, - relu_dropout: float = 0.0, - dropout: float = 0.0, - activation: str = 'relu', - variant: Optional[str] = None, - ): - super().__init__() - self.dim = embedding_size - self.ffn_dim = ffn_size - self.activation = activation - self.variant = variant - self.attention = MultiHeadAttention( - n_heads, embedding_size, dropout=attention_dropout # --attention-dropout - ) - self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - self.ffn = TransformerFFN( - embedding_size, - ffn_size, - relu_dropout=relu_dropout, - activation=self.activation, - ) - self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - self.dropout = nn.Dropout(p=dropout) - - def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - """ - Forward pass. - """ - residual = tensor - if self.variant == 'prelayernorm': - tensor = self.norm1(tensor) - attended_tensor = self.attention(tensor, mask=mask)[0] - tensor = residual + self.dropout(attended_tensor) - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - tensor = self.norm1(tensor) - residual = tensor - if self.variant == 'prelayernorm': - tensor = self.norm2(tensor) - tensor = residual + self.dropout(self.ffn(tensor)) - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - tensor = self.norm2(tensor) - tensor *= mask.unsqueeze(-1).type_as(tensor) - return tensor diff --git a/parlai/agents/transformer/modules/generator.py b/parlai/agents/transformer/modules/generator.py index e6db5dc0289..c96272c5ab4 100644 --- a/parlai/agents/transformer/modules/generator.py +++ b/parlai/agents/transformer/modules/generator.py @@ -16,7 +16,8 @@ literature (BERT and XLM; https://arxiv.org/abs/1901.07291). """ -from typing import Dict +from __future__ import annotations +from typing import Dict, Type import torch import torch.cuda @@ -27,12 +28,14 @@ TransformerDecoder, TransformerEncoder, ) +from parlai.agents.transformer.modules.modular import swappable from parlai.core.opt import Opt from parlai.core.torch_agent import DictionaryAgent from parlai.core.torch_generator_agent import TorchGeneratorModel from parlai.utils.torch import neginf +@swappable(encoder=TransformerEncoder, decoder=TransformerDecoder) class TransformerGeneratorModel(TorchGeneratorModel): """ Implements a full generator model, with one encoder and one decoder. @@ -40,33 +43,57 @@ class TransformerGeneratorModel(TorchGeneratorModel): @classmethod def build_encoder( - cls, opt, dictionary, embedding=None, padding_idx=None, reduction_type='mean' - ): - return TransformerEncoder( + cls, + opt, + dictionary, + embedding=None, + padding_idx=None, + reduction_type='mean', + encoder_class: Type[TransformerEncoder] = TransformerEncoder, + **kwargs, + ) -> TransformerEncoder: + return encoder_class( opt=opt, embedding=embedding, vocabulary_size=len(dictionary), padding_idx=padding_idx, reduction_type=reduction_type, + **kwargs, ) @classmethod - def build_decoder(cls, opt, embedding=None): - return TransformerDecoder(opt=opt, embedding=embedding) - - def __init__(self, opt: Opt, dictionary: DictionaryAgent): + def build_decoder( + cls, + opt, + embedding=None, + decoder_class: Type[TransformerDecoder] = TransformerDecoder, + **kwargs, + ) -> TransformerDecoder: + return decoder_class(opt=opt, embedding=embedding, **kwargs) + + def __init__(self, opt: Opt, dictionary: DictionaryAgent, **kwargs): self.pad_idx = dictionary[dictionary.null_token] self.start_idx = dictionary[dictionary.start_token] self.end_idx = dictionary[dictionary.end_token] - super().__init__(self.pad_idx, self.start_idx, self.end_idx) + super().__init__(self.pad_idx, self.start_idx, self.end_idx, **kwargs) + self.opt = opt self.embeddings = create_embeddings( dictionary, opt['embedding_size'], self.pad_idx ) self.encoder = self.build_encoder( - opt, dictionary, self.embeddings, self.pad_idx, reduction_type=None + opt, + dictionary, + self.embeddings, + self.pad_idx, + reduction_type=None, + encoder_class=self.swappables.encoder, # type: ignore + ) + self.decoder = self.build_decoder( + opt, + embedding=self.embeddings, + decoder_class=self.swappables.decoder, # type: ignore ) - self.decoder = self.build_decoder(opt, self.embeddings) def reorder_encoder_states(self, encoder_states, indices): """ diff --git a/parlai/agents/transformer/modules/modular.py b/parlai/agents/transformer/modules/modular.py new file mode 100644 index 00000000000..1415f198c2d --- /dev/null +++ b/parlai/agents/transformer/modules/modular.py @@ -0,0 +1,151 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Tools for annotating modules with lightweight dependency injection. + +Primarily designed to swap out individual modules deep within the transformer class hierarchy. + +Usage: + +```python +@swappable(component=DefaultClass, ...) +class SomeModel(nn.Module): + ... +``` + +Within the model, access the swappable classes like so: + +```python +self.swappables.component() +``` + +When instantiating the model, swap out the component like so: + +```python +model = SomeModel.with_components(component=NewCustomClass)() +``` +""" +from __future__ import annotations +from dataclasses import dataclass +from typing import Any, Callable, Dict, Generic, Optional, Type, TypeVar + + +C = TypeVar('C', bound=object) + + +class ModularComponent(Generic[C]): + @dataclass + class SwappableSubcomponents: + """ + Define any swappable subcomponents by adding the class (or a constructor) of the + components as attributes of this object. + + When using the @swappable decorator, this class is created programmatically. + """ + + @classmethod + def with_components(cls, **kwargs) -> ModularComponentBuilder[ModularComponent[C]]: + return ModularComponentBuilder( + klass=cls, subcomponents=cls.SwappableSubcomponents(**kwargs) + ) + + def __init__( + self, *args, subcomponents: Optional[SwappableSubcomponents] = None, **kwargs + ): + """ + Unpacks the swappable_components, then forwards the call up the MRO chain. + """ + self.swappables = subcomponents or type(self).SwappableSubcomponents() + assert ( + type(self.swappables) is not ModularComponent.SwappableSubcomponents + ), "Modular components must declare their own SwappableSubcomponents" + super().__init__(*args, **kwargs) + + +MC = TypeVar('MC', bound=ModularComponent) + + +class ModularComponentBuilder(Generic[MC]): + """ + When a component has swappable subcomponents, use this object to specify both the + component type and it's subcomponent types at the same time. + """ + + def __init__( + self, + klass: Type[MC], + subcomponents: Optional[ModularComponent.SwappableSubcomponents] = None, + ) -> None: + self._klass = klass + self._subcomponents = subcomponents or klass.SwappableSubcomponents() + + def __call__(self, *args, **kwargs) -> MC: + """ + Forward calls to this instance to __init__ of wrapped class. + """ + return self._klass(*args, subcomponents=self._subcomponents, **kwargs) + + @property + def swappables(self) -> Any: + return self._subcomponents + + def swap_components(self, **kwargs) -> None: + for name, value in kwargs.items(): + setattr(self._subcomponents, name, value) + + +def swappable(**kwargs) -> Callable[[Type[C]], Type[ModularComponent[C]]]: + """ + Decorator that adds swappable subcomponents to a class. + + Usage: + + ```python + @swappable(component_name=DefaultComponentClass, ...) + ``` + """ + + # Decorators need to return callables that accept only the decorated object. + # To comply, bundle kwargs into a function that accepts only one argument. + def wrap(cls: Type[C]) -> Type[ModularComponent[C]]: + return _make_class_swappable(cls, **kwargs) + + return wrap + + +def _make_class_swappable(cls: Type[C], **kwargs) -> Type[ModularComponent[C]]: + """ + Creates a new class that subclasses ModularComponent. + + Modifies that class to to accept constructors for the components passed to the + decorator. + """ + + def _class_dict(new_fields) -> Dict[str, Any]: + """ + Sets up the class attributes, along with type annotations. + """ + return { + **new_fields, + '__annotations__': {**{k: type(v) for k, v in new_fields.items()}}, + } + + # Create SwappableSubcomponents dataclass with components passed to @swappable + subcomponent_class_name = ModularComponent.SwappableSubcomponents.__name__ + subcomponent_dataclass = dataclass( + type(subcomponent_class_name, (), _class_dict(kwargs)) + ) + + # Create a new class that subclasses the decorated class. This new class holds + # the SwappableSubcomponents dataclass created above (list of components that are + # swappable) and the swappables attribute (the actual swappable component constructors). + return type( + # We append "_Swappable" to the new class name for transparency. + f"{cls.__name__}_Swappable", + # ModularComponent comes before the class so we can intercept __init__ calls. + (ModularComponent, cls), # type: ignore + # Items in this dictionary are converted to class attributes by type() + _class_dict({subcomponent_class_name: subcomponent_dataclass}), + ) diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py index ae92d48964d..25f7338d94b 100644 --- a/parlai/core/torch_generator_agent.py +++ b/parlai/core/torch_generator_agent.py @@ -119,8 +119,9 @@ def __init__( unknown_idx=3, input_dropout=0, longest_label=1, + **kwargs, ): - super().__init__() + super().__init__(**kwargs) self.NULL_IDX = padding_idx self.END_IDX = end_idx self.START_IDX = start_idx diff --git a/projects/style_gen/modules.py b/projects/style_gen/modules.py index 29d8b64c9d2..eee33773b00 100644 --- a/projects/style_gen/modules.py +++ b/projects/style_gen/modules.py @@ -81,7 +81,7 @@ class ClassifierOnGeneratorModel(TransformerGeneratorModel): """ @classmethod - def build_decoder(cls, opt, embedding=None): + def build_decoder(cls, opt, embedding=None, **kwargs): """ Return TransformerDecoderWithEmbeds instead of TransformerDecoder. """ diff --git a/tests/test_transformers.py b/tests/test_transformers.py index d9d71c60bfe..91b56f02667 100644 --- a/tests/test_transformers.py +++ b/tests/test_transformers.py @@ -9,11 +9,20 @@ """ import os +import torch import unittest +from unittest.mock import MagicMock import pytest import parlai.utils.testing as testing_utils +from parlai.agents.transformer.modules import ( + TransformerFFN, + TransformerGeneratorModel, + TransformerEncoder, + TransformerEncoderLayer, +) from parlai.core.agents import create_agent from parlai.core.agents import create_agent_from_model_file +from parlai.core.dict import DictionaryAgent from parlai.core.opt import Opt from .test_dict import DEFAULT_BYTELEVEL_BPE_VOCAB, DEFAULT_BYTELEVEL_BPE_MERGE from parlai.core.params import ParlaiParser @@ -872,5 +881,102 @@ def test_multitask(self): ), f'ImagePolyencoderAgent val-set accuracy on a simple task was {valid["accuracy"].value():0.2f}.' +class TestSwappableComponents(unittest.TestCase): + def _opt(self, **kwargs): + return Opt( + batchsize=4, + optimizer='adam', + n_layers=1, + n_heads=4, + ffn_size=16, + embedding_size=16, + skip_generation=True, + **kwargs, + ) + + def test_swap_encoder_attention(self): + CustomFFN = type('CustomFFN', (TransformerFFN,), {}) + CustomFFN.forward = MagicMock() + wrapped_class = TransformerGeneratorModel.with_components( + encoder=TransformerEncoder.with_components( + layer=TransformerEncoderLayer.with_components(feedforward=CustomFFN) + ) + ) + opt = self._opt() + CustomFFN.forward.assert_not_called + model = wrapped_class(opt=opt, dictionary=DictionaryAgent(opt)) + assert isinstance(model, TransformerGeneratorModel) # type: ignore + try: + model(torch.zeros(1, 1).long(), ys=torch.zeros(1, 1).long()) # type: ignore + except TypeError: + pass + finally: + CustomFFN.forward.assert_called + + def test_swap_is_not_persisted_in_class(self): + opt = self._opt() + dictionary = DictionaryAgent(opt) + + CustomFFN = type('CustomFFN', (TransformerFFN,), {}) + wrapped_class = TransformerGeneratorModel.with_components( + encoder=TransformerEncoder.with_components( + layer=TransformerEncoderLayer.with_components(feedforward=CustomFFN) + ) + ) + model = wrapped_class(opt=opt, dictionary=dictionary) + assert ( + model.swappables.encoder.swappables.layer.swappables.feedforward + == CustomFFN + ) # type: ignore + + another_model = TransformerGeneratorModel(opt, dictionary) + assert another_model.swappables != model.swappables + assert issubclass( + another_model.swappables.encoder, TransformerEncoder + ) # type: ignore + + wrapped_class.swap_components( + encoder=TransformerEncoder.with_components( + layer=TransformerEncoderLayer.with_components( + feedforward=TransformerFFN + ) + ) + ) + one_more_model = wrapped_class(opt=opt, dictionary=dictionary) + assert ( + one_more_model.swappables.encoder.swappables.layer.swappables.feedforward + == TransformerFFN + ) # type: ignore + + def test_examples_variant(self): + opt = ParlaiParser(True, True).parse_kwargs( + model='parlai.agents.examples.transformer_variant:TransformerVariantAgent' + ) + model = create_agent(opt) + # send the model a single training example to ensure it can forward/backward + model.observe({'text': '1 2 3 4', 'labels': ['1 2 3 4'], 'episode_done': True}) + model.act() + # send the model a single validation example + model.observe( + {'text': '1 2 3 4', 'eval_labels': ['1 2 3 4'], 'episode_done': True} + ) + model.act() + + def test_examples_configurable(self): + opt = ParlaiParser(True, True).parse_kwargs( + model='parlai.agents.examples.transformer_variant:ConfigurableTransformerAgent', + decoder_ffn_variants='two', + ) + model = create_agent(opt) + # send the model a single training example to ensure it can forward/backward + model.observe({'text': '1 2 3 4', 'labels': ['1 2 3 4'], 'episode_done': True}) + model.act() + # send the model a single validation example + model.observe( + {'text': '1 2 3 4', 'eval_labels': ['1 2 3 4'], 'episode_done': True} + ) + model.act() + + if __name__ == '__main__': unittest.main()