This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Swappable Transformer Components #3567

Merged: 23 commits, Apr 29, 2021

Commits (23)
6c8bd4e
minimum viable component-swappable transformer
spencerp Apr 1, 2021
f3456ea
add example
spencerp Apr 1, 2021
725bbc9
woops forgot to pass in manifest in example
spencerp Apr 1, 2021
7052932
make ComponentSpec immutable
spencerp Apr 1, 2021
b3a1cf6
keep build_encoder and build_decoder backwards compatible, add some c…
spencerp Apr 1, 2021
34f33ec
autoformat.sh
spencerp Apr 1, 2021
a52c96c
update signature of overridden build_decoder
spencerp Apr 1, 2021
5e539cc
address comments: rename manifest/tcomponent and provide example of l…
spencerp Apr 8, 2021
265699e
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 9, 2021
f3db720
update comment with new naming
spencerp Apr 9, 2021
8a3aae3
explicitly labeling some model components as static
spencerp Apr 9, 2021
12f88f0
tweak ModuleComponentSpec
spencerp Apr 14, 2021
2f9b08c
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 20, 2021
93ea7e7
add fully-specified example
spencerp Apr 20, 2021
c4b6924
add another example
spencerp Apr 20, 2021
d871fd8
remove weird auto-import
spencerp Apr 20, 2021
b95135e
simplify API
spencerp Apr 26, 2021
8f8d867
stick everything in a decorator
spencerp Apr 27, 2021
6f5817b
remove StaticComponent, clean up implementation a little, add comment…
spencerp Apr 28, 2021
9a16730
add website docs
spencerp Apr 28, 2021
e175f58
obey header level rule in docs
spencerp Apr 28, 2021
e7cdb06
in docs, Model->Transformer
spencerp Apr 29, 2021
9752d2c
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 29, 2021
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -17,6 +17,7 @@ tutorial_fast
tutorial_mutators
tutorial_crowdsourcing
tutorial_chat_service
tutorial_swap_components
tutorial_tests
```

60 changes: 60 additions & 0 deletions docs/source/tutorial_swap_components.md
@@ -0,0 +1,60 @@
# Swapping Out Transformer Subcomponents

__Author__: Spencer Poff

Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, subclassing alone won't get you there without copying and pasting much of the original implementation.

To make this easier and avoid copypasta, we provide the `@swappable` decorator.

## Making a Module Swappable

Let's say you have an existing class, `TransformerLayer`, that uses a module that you'd like to modify, `TransformerFFN`. You can make that FFN swappable in two steps:

1. Decorate `TransformerLayer` with `@swappable`, passing in a name for the component you'd like to swap and its default class/constructor:
```python
@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
...
```

2. At runtime, the class passed for `ffn` will be exposed on a `swappables` property of `TransformerLayer`. Replace your direct instantiation of `TransformerFFN` with a call to that constructor:

```python
self.feedforward = self.swappables.ffn(opt, ...)
```

That's it!
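
Putting the two steps together, here is a minimal end-to-end sketch. `MyTransformerLayer` and its constructor arguments are purely illustrative, and the import path for `swappable` is an assumption based on the implementation file referenced at the end of this tutorial:

```python
import torch
import torch.nn as nn

from parlai.agents.transformer.modules import TransformerFFN
# Assumed import path, based on parlai/agents/transformer/modules/modular.py.
from parlai.agents.transformer.modules.modular import swappable


@swappable(ffn=TransformerFFN)
class MyTransformerLayer(nn.Module):
    """Hypothetical layer whose feedforward sub-module is swappable."""

    def __init__(self, opt, **kwargs):
        super().__init__()
        # self.swappables.ffn is TransformerFFN by default, or whatever class
        # was supplied via MyTransformerLayer.with_components(ffn=...).
        # Forward whatever constructor arguments your chosen FFN expects.
        self.feedforward = self.swappables.ffn(opt, **kwargs)

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.feedforward(tensor)
```

You would then build it exactly as shown in the next section, e.g. `MyTransformerLayer.with_components(ffn=SomeOtherFFN)(opt, ...)`.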

## Making the Swap

You can now replace `TransformerFFN` with whatever class or constructor you want before instantiating `TransformerLayer`:
```python
layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...)
```

As long as `NewCustomFFN` has the same `__init__` and `forward` method signatures as `TransformerFFN`, everything should just work.
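
For instance, the simplest way to satisfy that requirement is to subclass the component you are replacing and override only `forward`, just as the example agents in this pull request do. `NewCustomFFN` here is the hypothetical class named above:

```python
import torch

from parlai.agents.transformer.modules import TransformerFFN
import parlai.utils.logging as logging


class NewCustomFFN(TransformerFFN):
    """Inherits TransformerFFN's __init__ and forward signatures unchanged."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logging.info("Custom FFN called!")
        # Swap in your own feedforward computation here.
        return super().forward(x)
```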

## Composability

Since the swapping happens before instantiation, decorated components can be transparently composed. For example:
```python
model = TransformerGeneratorModel.with_components(
encoder=TransformerEncoder.with_components(
layer=TransformerEncoderLayer.with_components(
self_attention=MultiHeadAttention,
feedforward=TransformerFFN,
)
),
decoder=TransformerDecoder.with_components(
layer=TransformerDecoderLayer.with_components(
encoder_attention=MultiHeadAttention,
self_attention=MultiHeadAttention,
feedforward=TransformerFFN,
)
),
)(opt=self.opt, dictionary=self.dict)
```

## Implementation

See `parlai/agents/transformer/modules/modular.py`
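
For intuition only, a stripped-down toy version of the mechanism might look like the sketch below. This is not the actual ParlAI implementation (that lives in the file above); it just illustrates how a decorator can record default component classes, expose them on `self.swappables`, and let `with_components` override them before instantiation:

```python
# Toy sketch only -- not the real ParlAI code. It mimics the public behavior
# described in this tutorial: defaults, `self.swappables`, `with_components`.
from types import SimpleNamespace


def swappable_sketch(**default_components):
    def decorate(cls):
        original_init = cls.__init__

        def wrapped_init(self, *args, **kwargs):
            # Use the defaults unless with_components() already attached an
            # overridden namespace to this instance.
            if not hasattr(self, "swappables"):
                self.swappables = SimpleNamespace(**default_components)
            original_init(self, *args, **kwargs)

        def with_components(cls, **overrides):
            unknown = set(overrides) - set(default_components)
            if unknown:
                raise TypeError(f"Unknown swappable components: {sorted(unknown)}")
            components = {**default_components, **overrides}

            def construct(*args, **kwargs):
                # Attach the chosen components, then run the normal __init__.
                instance = cls.__new__(cls)
                instance.swappables = SimpleNamespace(**components)
                instance.__init__(*args, **kwargs)
                return instance

            return construct

        cls.__init__ = wrapped_init
        cls.with_components = classmethod(with_components)
        return cls

    return decorate
```

Again, this is only a mental model; refer to `modular.py` for the actual behavior.
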
204 changes: 204 additions & 0 deletions parlai/agents/examples/transformer_variant.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example code for specifying custom transformer variants.

TransformerVariantAgent:
- Minimal changes needed to:
- Swap out a high-level component (encoder)
- Swap out a low-level component (decoder->layer->self_attention)

VerboseTransformerAgent:
- Doesn't swap out anything
- Fully specifies all components, for illustration

ConfigurableTransformerAgent:
- Swaps out components based on command line args
"""
from __future__ import annotations
import torch
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from parlai.agents.transformer.modules import (
TransformerFFN,
MultiHeadAttention,
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
TransformerGeneratorModel,
)
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
import parlai.utils.logging as logging


###########################################
# Transformer With Two Components Swapped #
###########################################


class TransformerVariantAgent(TransformerGeneratorAgent):
"""
Swapping out two things:

1. Encoder (high-level component)
2. Decoder self attention (low-level component)
"""

def build_model(self, states=None):
wrapped_class = TransformerGeneratorModel.with_components(
encoder=MyCustomEncoder,
decoder=TransformerDecoder.with_components(
layer=TransformerDecoderLayer.with_components(
self_attention=MyCustomAttention
)
),
)
return wrapped_class(self.opt, self.dict)


class MyCustomEncoder(TransformerEncoder):
"""
For brevity this subclasses TransformerEncoder, but you could write your own
nn.Module from scratch as long as the __init__ and forward signatures match
TransformerEncoder.
"""

def forward(
self,
input: torch.LongTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]:
logging.info("Custom encoder called!")
# Comment out the following line and write your custom `forward` instead.
return super().forward(input, positions, segments) # type: ignore


class MyCustomAttention(MultiHeadAttention):
"""
For brevity this just renames MultiHeadAttention, but ideally you'd define a new
nn.Module with the same __init__ and forward signature as MultiHeadAttention.
"""

def forward(
self,
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
value: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
static_kv: bool = False,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
logging.info("Custom attention called!")
# Comment out the following line and write your custom `forward` instead.
return super().forward(
query,
key=key,
value=value,
mask=mask,
incr_state=incr_state,
static_kv=static_kv,
)


#######################################
# Fully-specified Default Transformer #
#######################################


class VerboseTransformerAgent(TransformerGeneratorAgent):
"""
Doesn't make any changes to TransformerGeneratorModel, just specifies all
subcomponents explicitly.

This is meant to be a reference for how to swap any component within
TransformerGeneratorModel.
"""

def build_model(self, states=None):
wrapped_class = TransformerGeneratorModel.with_components(
encoder=TransformerEncoder.with_components(
layer=TransformerEncoderLayer.with_components(
self_attention=MultiHeadAttention, feedforward=TransformerFFN
)
),
decoder=TransformerDecoder.with_components(
layer=TransformerDecoderLayer.with_components(
encoder_attention=MultiHeadAttention,
self_attention=MultiHeadAttention,
feedforward=TransformerFFN,
)
),
)
return wrapped_class(opt=self.opt, dictionary=self.dict)


################################################
# Command-line Configurable Custom Transformer #
################################################


class DecoderFeedForwardVariant(Enum):
ONE = 'one'
TWO = 'two'


class DecoderFFNOne(TransformerFFN):
def forward(self, x: torch.Tensor) -> torch.Tensor:
logging.info("Using Decoder FFN Variant One")
return super().forward(x)


class DecoderFFNTwo(TransformerFFN):
def forward(self, x: torch.Tensor) -> torch.Tensor:
logging.info("Using Decoder FFN Variant Two")
return super().forward(x)


class ConfigurableTransformerAgent(TransformerGeneratorAgent):
"""
Illustrates swapping out components based on command line args.

Specifically, swaps out the decoder ffn between two options.
"""

@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt=partial_opt)
agent = parser.add_argument_group('MyCustom Transformer Arguments')
parser.add_argument(
'--decoder-ffn-variants',
type=DecoderFeedForwardVariant,
default=DecoderFeedForwardVariant.ONE,
help='Some variants in the decoder FFN implementation',
)
return agent # type: ignore

def build_model(self, states=None):
decoder_variant: DecoderFeedForwardVariant = self.opt['decoder_ffn_variants']
if decoder_variant == DecoderFeedForwardVariant.ONE:
decoder_ffn_class = DecoderFFNOne
elif decoder_variant == DecoderFeedForwardVariant.TWO:
decoder_ffn_class = DecoderFFNTwo
else:
logging.error(
'Invalid --decoder-ffn-variants option, defaulting to original ffn implementation.'
)
decoder_ffn_class = TransformerFFN

wrapped_class = TransformerGeneratorModel.with_components(
decoder=TransformerDecoder.with_components(
layer=TransformerDecoderLayer.with_components(
feedforward=decoder_ffn_class
)
)
)
return wrapped_class(opt=self.opt, dictionary=self.dict)