Swappable Transformer Components (#3567)
* minimum viable component-swappable transformer

* add example

* woops forgot to pass in manifest in example

* make ComponentSpec immutable

* keep build_encoder and build_decoder backwards compatible, add some comments

* autoformat.sh

* update signature of overridden build_decoder

* address comments: rename manifest/tcomponent and provide example of leaf component substitution

* update comment with new naming

* explicitly labeling some model components as static

* tweak ModuleComponentSpec

* add fully-specified example

* add another example

* remove weird auto-import

* simplify API

* stick everything in a decorator

* remove StaticComponent, clean up implementation a little, add comments, tests

* add website docs

* obey header level rule in docs

* in docs, Model->Transformer
spencerp authored Apr 29, 2021
1 parent f9346e6 commit 6040079
Showing 10 changed files with 785 additions and 213 deletions.
1 change: 1 addition & 0 deletions docs/source/index.md
@@ -17,6 +17,7 @@ tutorial_fast
tutorial_mutators
tutorial_crowdsourcing
tutorial_chat_service
tutorial_swap_components
tutorial_tests
```

60 changes: 60 additions & 0 deletions docs/source/tutorial_swap_components.md
@@ -0,0 +1,60 @@
# Swapping Out Transformer Subcomponents

__Author__: Spencer Poff

Sometimes you find yourself wanting to experiment with an architecture that looks a lot like another, but with one component modified. If that component is buried deep within the model, subclassing alone won't get you there without copying and pasting much of the original implementation.

To make this easier and avoid copypasta, we provide the `@swappable` decorator.

## Making a Module Swappable

Let's say you have an existing class, `TransformerLayer`, that uses a module that you'd like to modify, `TransformerFFN`. You can make that FFN swappable in two steps:

1. Decorate `TransformerLayer` with `@swappable`, passing in a name for the component you'd like to swap and its default class/constructor:
```python
@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
    ...
```

2. At runtime, the class registered for `ffn` is added to the `swappables` property of `TransformerLayer`. Replace your instantiation of `TransformerFFN` with a call to that constructor:

```python
self.feedforward = self.swappables.ffn(opt, ...)
```

That's it!
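
Putting the two steps together, the decorated layer might look like the following sketch (the constructor arguments and `forward` body here are illustrative, not the actual `TransformerLayer`):

```python
@swappable(ffn=TransformerFFN)
class TransformerLayer(nn.Module):
    def __init__(self, opt, **kwargs):
        super().__init__()
        # Defaults to TransformerFFN unless swapped via with_components().
        self.feedforward = self.swappables.ffn(opt, **kwargs)

    def forward(self, x):
        return self.feedforward(x)
```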

## Making the Swap

You can now replace `TransformerFFN` with whatever class or constructor you want before instantiating `TransformerLayer`:
```python
layer = TransformerLayer.with_components(ffn=NewCustomFFN)(opt, ...)
```

As long as `NewCustomFFN` has the same `__init__` and `forward` method signatures as `TransformerFFN`, everything should just work.
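
For example, a minimal drop-in replacement can subclass `TransformerFFN` purely to inherit the matching signatures (`NewCustomFFN` is just an illustrative name, as above):

```python
class NewCustomFFN(TransformerFFN):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Put your custom feedforward logic here; this sketch simply
        # delegates to the original implementation.
        return super().forward(x)
```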

## Composability

Since the swapping happens before instantiation, decorated components can be transparently composed. For example:
```python
model = TransformerGeneratorModel.with_components(
    encoder=TransformerEncoder.with_components(
        layer=TransformerEncoderLayer.with_components(
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
    decoder=TransformerDecoder.with_components(
        layer=TransformerDecoderLayer.with_components(
            encoder_attention=MultiHeadAttention,
            self_attention=MultiHeadAttention,
            feedforward=TransformerFFN,
        )
    ),
)(opt=self.opt, dictionary=self.dict)
```

## Implementation

See `parlai/agents/transformer/modules/modular.py`
204 changes: 204 additions & 0 deletions parlai/agents/examples/transformer_variant.py
@@ -0,0 +1,204 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Example code for specifying custom transformer variants.
TransformerVariantAgent:
- Minimal changes needed to:
- Swap out a high-level component (encoder)
- Swap out a low-level component (decoder->layer->self_attention)
VerboseTransformerAgent:
- Doesn't swap out anything
- Fully specifies all components, for illustration
ConfigurableTransformerAgent:
- Swaps out components based on command line args
"""
from __future__ import annotations
import torch
from enum import Enum
from typing import Dict, Optional, Tuple, Union

from parlai.agents.transformer.modules import (
    TransformerFFN,
    MultiHeadAttention,
    TransformerDecoder,
    TransformerDecoderLayer,
    TransformerEncoder,
    TransformerEncoderLayer,
    TransformerGeneratorModel,
)
from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
import parlai.utils.logging as logging


###########################################
# Transformer With Two Components Swapped #
###########################################


class TransformerVariantAgent(TransformerGeneratorAgent):
    """
    Swapping out two things:
    1. Encoder (high-level component)
    2. Decoder self attention (low-level component)
    """

    def build_model(self, states=None):
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=MyCustomEncoder,
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    self_attention=MyCustomAttention
                )
            ),
        )
        return wrapped_class(self.opt, self.dict)


class MyCustomEncoder(TransformerEncoder):
    """
    For brevity this subclasses TransformerEncoder, but you could write your own
    nn.Module from scratch as long as the __init__ and forward signatures match
    TransformerEncoder.
    """

    def forward(
        self,
        input: torch.LongTensor,
        positions: Optional[torch.LongTensor] = None,
        segments: Optional[torch.LongTensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]:
        logging.info("Custom encoder called!")
        # Comment out the following line and write your custom `forward` instead.
        return super().forward(input, positions, segments)  # type: ignore


class MyCustomAttention(MultiHeadAttention):
    """
    For brevity this just renames MultiHeadAttention, but ideally you'd define a new
    nn.Module with the same __init__ and forward signature as MultiHeadAttention.
    """

    def forward(
        self,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        value: Optional[torch.Tensor] = None,
        mask: torch.Tensor = None,
        incr_state: Optional[Dict[str, torch.Tensor]] = None,
        static_kv: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        logging.info("Custom attention called!")
        # Comment out the following line and write your custom `forward` instead.
        return super().forward(
            query,
            key=key,
            value=value,
            mask=mask,
            incr_state=incr_state,
            static_kv=static_kv,
        )


#######################################
# Fully-specified Default Transformer #
#######################################


class VerboseTransformerAgent(TransformerGeneratorAgent):
    """
    Doesn't make any changes to TransformerGeneratorModel, just specifies all
    subcomponents explicitly.
    This is meant to be a reference for how to swap any component within
    TransformerGeneratorModel.
    """

    def build_model(self, states=None):
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=TransformerEncoder.with_components(
                layer=TransformerEncoderLayer.with_components(
                    self_attention=MultiHeadAttention, feedforward=TransformerFFN
                )
            ),
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    encoder_attention=MultiHeadAttention,
                    self_attention=MultiHeadAttention,
                    feedforward=TransformerFFN,
                )
            ),
        )
        return wrapped_class(opt=self.opt, dictionary=self.dict)


################################################
# Command-line Configurable Custom Transformer #
################################################


class DecoderFeedForwardVariant(Enum):
    ONE = 'one'
    TWO = 'two'


class DecoderFFNOne(TransformerFFN):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logging.info("Using Decoder FFN Variant One")
        return super().forward(x)


class DecoderFFNTwo(TransformerFFN):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logging.info("Using Decoder FFN Variant Two")
        return super().forward(x)


class ConfigurableTransformerAgent(TransformerGeneratorAgent):
    """
    Illustrates swapping out components based on command line args.
    Specifically, swaps out the decoder ffn between two options.
    """

    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        super().add_cmdline_args(parser, partial_opt=partial_opt)
        agent = parser.add_argument_group('MyCustom Transformer Arguments')
        agent.add_argument(
            '--decoder-ffn-variants',
            type=DecoderFeedForwardVariant,
            default=DecoderFeedForwardVariant.ONE,
            help='Some variants in the decoder FFN implementation',
        )
        return parser

    def build_model(self, states=None):
        decoder_variant: DecoderFeedForwardVariant = self.opt['decoder_ffn_variants']
        if decoder_variant == DecoderFeedForwardVariant.ONE:
            decoder_ffn_class = DecoderFFNOne
        elif decoder_variant == DecoderFeedForwardVariant.TWO:
            decoder_ffn_class = DecoderFFNTwo
        else:
            logging.error(
                'Invalid --decoder-ffn-variants option, defaulting to original ffn implementation.'
            )
            decoder_ffn_class = TransformerFFN

        wrapped_class = TransformerGeneratorModel.with_components(
            decoder=TransformerDecoder.with_components(
                layer=TransformerDecoderLayer.with_components(
                    feedforward=decoder_ffn_class
                )
            )
        )
        return wrapped_class(opt=self.opt, dictionary=self.dict)
