Swappable Transformer Components (#3567)
* minimum viable component-swappable transformer
* add example
* woops forgot to pass in manifest in example
* make ComponentSpec immutable
* keep build_encoder and build_decoder backwards compatible, add some comments
* autoformat.sh
* update signature of overridden build_decoder
* address comments: rename manifest/tcomponent and provide example of leaf component substitution
* update comment with new naming
* explicitly labeling some model components as static
* tweak ModuleComponentSpec
* add fully-specified example
* add another example
* remove weird auto-import
* simplify API
* stick everything in a decorator
* remove StaticComponent, clean up implementation a little, add comments, tests
* add website docs
* obey header level rule in docs
* in docs, Model->Transformer
Showing 10 changed files with 785 additions and 213 deletions.
@@ -0,0 +1,60 @@
# Swapping Out Transformer Subcomponents

__Author__: Spencer Poff

Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, this is not easily accomplished with subclassing without copying and pasting much of the original implementation.

To make this easier and avoid copypasta, we provide the `@swappable` decorator.

## Making a Module Swappable

Let's say you have an existing class, `TransformerLayer`, that uses a module that you'd like to modify, `TransformerFFN`. You can make that FFN swappable in two steps:

1. Decorate `TransformerLayer` with `@swappable`, passing in a name for the component you'd like to swap and its default class/constructor:
```python
@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
    ...
```

2. At runtime, the class for `ffn` will be added to a property `swappables` of `TransformerLayer`. Replace your instantiation of `TransformerFFN` with a call to that constructor:

```python
self.feedforward = self.swappables.ffn(opt, ...)
```

That's it!

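Putting both steps together, a decorated layer ends up looking roughly like the sketch below. This is illustrative only: the `__init__`/`forward` signatures shown are placeholders rather than the real `TransformerLayer` API, and the `swappable` import path is assumed to be the implementation module referenced at the end of this doc.

```python
import torch
import torch.nn as nn

from parlai.agents.transformer.modules import TransformerFFN
from parlai.agents.transformer.modules.modular import swappable  # assumed import path


@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
    # Placeholder constructor for illustration; the real layer takes more arguments.
    def __init__(self, opt, **kwargs):
        super().__init__()
        # Step 2: build the FFN through self.swappables instead of naming
        # TransformerFFN directly. Constructor args are elided, as in the snippet above.
        self.feedforward = self.swappables.ffn(opt, ...)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.feedforward(x)
```
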
## Making the Swap

You can now replace `TransformerFFN` with whatever class or constructor you want before instantiating `TransformerLayer`:
```python
layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...)
```

As long as `NewCustomFFN` has the same `__init__` and `forward` method signatures as `TransformerFFN`, everything should just work.

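For example, a drop-in replacement can subclass `TransformerFFN` and override only what it needs to change (a minimal sketch; `NewCustomFFN` is just an illustrative name):

```python
import torch

from parlai.agents.transformer.modules import TransformerFFN


class NewCustomFFN(TransformerFFN):
    """
    Keeps the same __init__ and forward signatures as TransformerFFN, so it can
    be swapped in without touching anything else.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Put custom behavior here; delegating to the parent for illustration.
        return super().forward(x)
```
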
## Composability

Since the swapping happens before instantiation, decorated components can be transparently composed. For example:
```python
model = TransformerGeneratorModel.with_components(
    encoder=TransformerEncoder.with_components(
        layer=TransformerEncoderLayer.with_components(
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
    decoder=TransformerDecoder.with_components(
        layer=TransformerDecoderLayer.with_components(
            encoder_attention=MultiHeadAttention,
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
)(opt=self.opt, dictionary=self.dict)
```

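Components you don't mention keep their defaults, so you only need to spell out the path to the piece you are changing. For instance, overriding just the decoder's feedforward network (with `MyFFN` standing in for any class that matches `TransformerFFN`'s signatures) might look like:

```python
# Only the decoder FFN is overridden; every other component keeps its default.
model = TransformerGeneratorModel.with_components(
    decoder=TransformerDecoder.with_components(
        layer=TransformerDecoderLayer.with_components(feedforward=MyFFN)
    )
)(opt=self.opt, dictionary=self.dict)
```
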
## Implementation

See `parlai/agents/transformer/modules/modular.py`
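For intuition, a decorator with this behavior could be written along the following lines. This is a simplified sketch under assumed semantics, not the actual ParlAI implementation; refer to `modular.py` for the authoritative version.

```python
from types import SimpleNamespace


def swappable(**defaults):
    """
    Simplified sketch: remember default component classes on the decorated
    class, and let `with_components` produce a factory that injects overrides.
    """

    def decorator(cls):
        cls._component_defaults = {
            **getattr(cls, '_component_defaults', {}),
            **defaults,
        }
        original_init = cls.__init__

        def wrapped_init(self, *args, _components=None, **kwargs):
            # Expose the chosen classes as self.swappables before the original
            # constructor runs, so it can call e.g. self.swappables.ffn(...).
            chosen = {**type(self)._component_defaults, **(_components or {})}
            object.__setattr__(self, 'swappables', SimpleNamespace(**chosen))
            original_init(self, *args, **kwargs)

        def with_components(cls, **overrides):
            # Return a factory that constructs the class with the overrides applied.
            def factory(*args, **kwargs):
                return cls(*args, _components=overrides, **kwargs)

            return factory

        cls.__init__ = wrapped_init
        cls.with_components = classmethod(with_components)
        return cls

    return decorator
```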
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example code for specifying custom transformer variants.

TransformerVariantAgent:
- Minimal changes needed to:
    - Swap out a high-level component (encoder)
    - Swap out a low-level component (decoder->layer->self_attention)

VerboseTransformerAgent:
- Doesn't swap out anything
- Fully specifies all components, for illustration

ConfigurableTransformerAgent:
- Swaps out components based on command line args
"""
from __future__ import annotations
import torch
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from parlai.agents.transformer.modules import (
    TransformerFFN,
    MultiHeadAttention,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerGeneratorModel,
)
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
import parlai.utils.logging as logging


###########################################
# Transformer With Two Components Swapped #
###########################################


class TransformerVariantAgent(TransformerGeneratorAgent):
    """
    Swapping out two things:
    1. Encoder (high-level component)
    2. Decoder self attention (low-level component)
    """

    def build_model(self, states=None):
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=MyCustomEncoder,
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    self_attention=MyCustomAttention
                )
            ),
        )
        return wrapped_class(self.opt, self.dict)


class MyCustomEncoder(TransformerEncoder):
    """
    For brevity this subclasses TransformerEncoder, but you could write your own
    nn.Module from scratch as long as the __init__ and forward signatures match
    TransformerEncoder.
    """

    def forward(
        self,
        input: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]:
        logging.info("Custom encoder called!")
        # Comment out the following line and write your custom `forward` instead.
        return super().forward(input, positions, segments)  # type: ignore


class MyCustomAttention(MultiHeadAttention):
    """
    For brevity this just renames MultiHeadAttention, but ideally you'd define a new
    nn.Module with the same __init__ and forward signature as MultiHeadAttention.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        mask: torch.Tensor = None,
        incr_state: Optional[Dict[str, torch.Tensor]] = None,
        static_kv: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        logging.info("Custom attention called!")
        # Comment out the following line and write your custom `forward` instead.
        return super().forward(
            query,
            key=key,
            value=value,
            mask=mask,
            incr_state=incr_state,
            static_kv=static_kv,
        )


#######################################
# Fully-specified Default Transformer #
#######################################


class VerboseTransformerAgent(TransformerGeneratorAgent):
    """
    Doesn't make any changes to TransformerGeneratorModel, just specifies all
    subcomponents explicitly.
    This is meant to be a reference for how to swap any component within
    TransformerGeneratorModel.
    """

    def build_model(self, states=None):
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=TransformerEncoder.with_components(
                layer=TransformerEncoderLayer.with_components(
                    self_attention=MultiHeadAttention, feedforward=TransformerFFN
                )
            ),
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    encoder_attention=MultiHeadAttention,
                    self_attention=MultiHeadAttention,
                    feedforward=TransformerFFN,
                )
            ),
        )
        return wrapped_class(opt=self.opt, dictionary=self.dict)


################################################
# Command-line Configurable Custom Transformer #
################################################


class DecoderFeedForwardVariant(Enum):
    ONE = 'one'
    TWO = 'two'


class DecoderFFNOne(TransformerFFN):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logging.info("Using Decoder FFN Variant One")
        return super().forward(x)


class DecoderFFNTwo(TransformerFFN):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logging.info("Using Decoder FFN Variant Two")
        return super().forward(x)


class ConfigurableTransformerAgent(TransformerGeneratorAgent):
    """
    Illustrates swapping out components based on command line args.
    Specifically, swaps out the decoder ffn between two options.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        agent = parser.add_argument_group('MyCustom Transformer Arguments')
        agent.add_argument(
            '--decoder-ffn-variants',
            type=DecoderFeedForwardVariant,
            default=DecoderFeedForwardVariant.ONE,
            help='Which decoder FFN variant to use',
        )
        return parser

    def build_model(self, states=None):
        decoder_variant: DecoderFeedForwardVariant = self.opt['decoder_ffn_variants']
        if decoder_variant == DecoderFeedForwardVariant.ONE:
            decoder_ffn_class = DecoderFFNOne
        elif decoder_variant == DecoderFeedForwardVariant.TWO:
            decoder_ffn_class = DecoderFFNTwo
        else:
            logging.error(
                'Invalid --decoder-ffn-variants option, defaulting to original ffn implementation.'
            )
            decoder_ffn_class = TransformerFFN

        wrapped_class = TransformerGeneratorModel.with_components(
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    feedforward=decoder_ffn_class
                )
            )
        )
        return wrapped_class(opt=self.opt, dictionary=self.dict)