This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Swappable Transformer Components #3567

Merged on Apr 29, 2021 (23 commits)
Changes shown from 7 of the 23 commits.

Commits
6c8bd4e
minimum viable component-swappable transformer
spencerp Apr 1, 2021
f3456ea
add example
spencerp Apr 1, 2021
725bbc9
woops forgot to pass in manifest in example
spencerp Apr 1, 2021
7052932
make ComponentSpec immutable
spencerp Apr 1, 2021
b3a1cf6
keep build_encoder and build_decoder backwards compatible, add some c…
spencerp Apr 1, 2021
34f33ec
autoformat.sh
spencerp Apr 1, 2021
a52c96c
update signature of overridden build_decoder
spencerp Apr 1, 2021
5e539cc
address comments: rename manifest/tcomponent and provide example of l…
spencerp Apr 8, 2021
265699e
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 9, 2021
f3db720
update comment with new naming
spencerp Apr 9, 2021
8a3aae3
explicitly labeling some model components as static
spencerp Apr 9, 2021
12f88f0
tweak ModuleComponentSpec
spencerp Apr 14, 2021
2f9b08c
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 20, 2021
93ea7e7
add fully-specified example
spencerp Apr 20, 2021
c4b6924
add another example
spencerp Apr 20, 2021
d871fd8
remove weird auto-import
spencerp Apr 20, 2021
b95135e
simplify API
spencerp Apr 26, 2021
8f8d867
stick everything in a decorator
spencerp Apr 27, 2021
6f5817b
remove StaticComponent, clean up implementation a little, add comment…
spencerp Apr 28, 2021
9a16730
add website docs
spencerp Apr 28, 2021
e175f58
obey header level rule in docs
spencerp Apr 28, 2021
e7cdb06
in docs, Model->Transformer
spencerp Apr 29, 2021
9752d2c
Merge branch 'master' into transformer-manifest-mvp
spencerp Apr 29, 2021
31 changes: 31 additions & 0 deletions parlai/agents/examples/transformer_variant.py
@@ -0,0 +1,31 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import replace
import torch
from typing import Tuple, Optional, Union

from parlai.agents.transformer.modules import (
TransformerGeneratorModel,
TransformerEncoder,
)
from parlai.agents.transformer.transformer import TransformerGeneratorAgent


class TransformerVariantAgent(TransformerGeneratorAgent):
def build_model(self, states=None):
manifest = TransformerGeneratorModel.Manifest()
manifest.encoder = replace(manifest.encoder, klass=MyCustomEncoder)
return TransformerGeneratorModel(self.opt, self.dict, manifest)


class MyCustomEncoder(TransformerEncoder):
def forward( # type: ignore
self,
input: torch.LongTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.BoolTensor]]:
# Comment out the following line and write your custom `forward` instead.
return super().forward(input, positions, segments)
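
For illustration, a minimal sketch (not part of this PR) of what a custom `forward` for the `MyCustomEncoder` stub above could look like. It assumes the generator configuration, in which `TransformerEncoder.forward` returns a `(states, mask)` tuple, and the 0.5 rescaling is purely hypothetical:

class MyCustomEncoder(TransformerEncoder):
    def forward(self, input, positions=None, segments=None):  # type: ignore
        # Run the stock encoder, then adjust its output before the decoder attends to it.
        states, mask = super().forward(input, positions, segments)
        return states * 0.5, mask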
304 changes: 162 additions & 142 deletions parlai/agents/transformer/modules/decoder.py
@@ -7,7 +7,9 @@
Transformer decoder implementations.
"""

from typing import Dict, Tuple, Optional
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Type

import numpy as np
import torch
@@ -20,12 +22,160 @@
MultiHeadAttention,
TransformerFFN,
)
from parlai.agents.transformer.modules.interfaces import ComponentSpec, TComponent
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
from parlai.utils.torch import PipelineHelper


class TransformerDecoder(nn.Module):
class TransformerDecoderLayer(nn.Module, TComponent):
"""
Implements a single Transformer decoder layer.

Decoder layers are similar to encoder layers but:

    1. Self-attention is limited in a causal (auto-regressive) manner.
    2. They attend over all of the encoder states.
"""

@dataclass
class Manifest(TComponent.Manifest):
self_attention: Type[MultiHeadAttention] = MultiHeadAttention
encoder_attention: Type[MultiHeadAttention] = MultiHeadAttention
feedforward: Type[TransformerFFN] = TransformerFFN
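        # Hypothetical usage sketch (not part of this diff): any of these fields can be
        # overridden with a subclass before the layer is constructed, e.g.
        #
        #   layer_manifest = TransformerDecoderLayer.Manifest(feedforward=MyCustomFFN)
        #   layer = TransformerDecoderLayer(
        #       n_heads=16, embedding_size=1024, ffn_size=4096, manifest=layer_manifest
        #   )
        #
        # where MyCustomFFN is assumed to be a TransformerFFN subclass.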

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: str = 'aiayn',
manifest: Optional[Manifest] = None,
):
super().__init__()
manifest = manifest or self.Manifest()
self.dim = embedding_size
self.ffn_dim = ffn_size
self.variant = variant
self.activation = activation
self.dropout = nn.Dropout(p=dropout)

self.self_attention = manifest.self_attention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.encoder_attention = manifest.encoder_attention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.ffn = manifest.feedforward(
embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation
)
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

The incremental state is a dict with values for self- and encoder-attention
states.
"""

if incr_state is None:
incr_state = {}

decoder_mask = self._create_selfattn_mask(x)
# first self attn
residual = x
if self.variant == 'prelayernorm':
x = self.norm1(x)

        # don't peek into the future!
x, final_self_attn_incr_state = self.self_attention(
query=x,
mask=decoder_mask,
incr_state=incr_state.get('self_attn'),
static_kv=False,
)[:2]
x = self.dropout(x) # --dropout
x = x + residual
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm1(x)

residual = x
# encoder_attn_layer_norm norm 2
if self.variant == 'prelayernorm':
x = self.norm2(x)
x, final_encoder_attn_incr_state = self.encoder_attention(
query=x,
key=encoder_output,
value=encoder_output,
mask=encoder_mask,
incr_state=incr_state.get('encoder_attn'),
static_kv=True,
)[:2]
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm2(x)

# finally the ffn
residual = x
if self.variant == 'prelayernorm':
x = self.norm3(x)
x = self.ffn(x)
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm3(x)

new_incr_state = {
'self_attn': final_self_attn_incr_state,
'encoder_attn': final_encoder_attn_incr_state,
}
return x, new_incr_state

def _create_selfattn_mask(self, x):
        # figure out how many timesteps we need
bsz = x.size(0)
time = x.size(1)
# make sure that we don't look into the future
mask = torch.tril(x.new(time, time).fill_(1))
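        # e.g. for time=3 the lower-triangular mask (before broadcasting) is
        #   [[1., 0., 0.],
        #    [1., 1., 0.],
        #    [1., 1., 1.]]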
# broadcast across batch
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
return mask

def reorder_incremental_state(
self, incremental_state: Dict[str, dict], inds: torch.Tensor
) -> Dict[str, dict]:
"""
Reorder all incremental-state tensors for this layer.
"""
attn_types = {
'self_attn': self.self_attention,
'encoder_attn': self.encoder_attention,
}
return {
attn_type: attn.reorder_incremental_state(
incremental_state[attn_type], inds
)
for attn_type, attn in attn_types.items()
}


class TransformerDecoder(nn.Module, TComponent):
"""
Transformer Decoder module.

Expand All @@ -38,13 +188,21 @@ class TransformerDecoder(nn.Module):
:param int n_positions: Size of the position embeddings matrix.
"""

@dataclass
class Manifest(TComponent.Manifest):
layer: ComponentSpec[TransformerDecoderLayer] = ComponentSpec(
TransformerDecoderLayer, TransformerDecoderLayer.Manifest()
)
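        # Hypothetical usage sketch (not part of this diff): a custom layer class and its
        # manifest can be bundled into a ComponentSpec and handed to the decoder, e.g.
        #
        #   decoder_manifest = TransformerDecoder.Manifest(
        #       layer=ComponentSpec(MyDecoderLayer, MyDecoderLayer.Manifest())
        #   )
        #   decoder = TransformerDecoder(opt, embedding=embeddings, manifest=decoder_manifest)
        #
        # where MyDecoderLayer is assumed to subclass TransformerDecoderLayer and
        # `embeddings` is the model's shared nn.Embedding.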

def __init__(
self,
opt: Opt,
embedding: Optional[nn.Embedding] = None,
n_positions: Optional[int] = None,
manifest: Optional[Manifest] = None,
):
super().__init__()
manifest = manifest or self.Manifest()

def _default(val, default):
return val if val is not None else default
@@ -106,7 +264,7 @@ def _default(val, default):
self.layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(
TransformerDecoderLayer(
manifest.layer.klass(
self.n_heads,
self.embedding_size,
self.ffn_size,
@@ -115,6 +273,7 @@ def _default(val, default):
dropout=dropout_frac,
activation=self.activation,
variant=self.variant,
manifest=manifest.layer.manifest,
)
)

@@ -273,142 +432,3 @@ def _apply_model_parallel(self, tensor, encoder_output, encoder_mask, incr_state
}

return tensor_out, new_incr_state


class TransformerDecoderLayer(nn.Module):
"""
Implements a single Transformer decoder layer.

Decoder layers are similar to encoder layers but:

    1. Self-attention is limited in a causal (auto-regressive) manner.
    2. They attend over all of the encoder states.
"""

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: str = 'aiayn',
):
super().__init__()
self.dim = embedding_size
self.ffn_dim = ffn_size
self.variant = variant
self.activation = activation
self.dropout = nn.Dropout(p=dropout)

self.self_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.encoder_attention = MultiHeadAttention(
n_heads, embedding_size, dropout=attention_dropout
)
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.ffn = TransformerFFN(
embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation
)
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

The incremental state is a dict with values for self- and encoder-attention
states.
"""

if incr_state is None:
incr_state = {}

decoder_mask = self._create_selfattn_mask(x)
# first self attn
residual = x
if self.variant == 'prelayernorm':
x = self.norm1(x)

        # don't peek into the future!
x, final_self_attn_incr_state = self.self_attention(
query=x,
mask=decoder_mask,
incr_state=incr_state.get('self_attn'),
static_kv=False,
)[:2]
x = self.dropout(x) # --dropout
x = x + residual
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm1(x)

residual = x
# encoder_attn_layer_norm norm 2
if self.variant == 'prelayernorm':
x = self.norm2(x)
x, final_encoder_attn_incr_state = self.encoder_attention(
query=x,
key=encoder_output,
value=encoder_output,
mask=encoder_mask,
incr_state=incr_state.get('encoder_attn'),
static_kv=True,
)[:2]
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm2(x)

# finally the ffn
residual = x
if self.variant == 'prelayernorm':
x = self.norm3(x)
x = self.ffn(x)
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
x = self.norm3(x)

new_incr_state = {
'self_attn': final_self_attn_incr_state,
'encoder_attn': final_encoder_attn_incr_state,
}
return x, new_incr_state

def _create_selfattn_mask(self, x):
        # figure out how many timesteps we need
bsz = x.size(0)
time = x.size(1)
# make sure that we don't look into the future
mask = torch.tril(x.new(time, time).fill_(1))
# broadcast across batch
mask = mask.unsqueeze(0).expand(bsz, -1, -1)
return mask

def reorder_incremental_state(
self, incremental_state: Dict[str, dict], inds: torch.Tensor
) -> Dict[str, dict]:
"""
Reorder all incremental-state tensors for this layer.
"""
attn_types = {
'self_attn': self.self_attention,
'encoder_attn': self.encoder_attention,
}
return {
attn_type: attn.reorder_incremental_state(
incremental_state[attn_type], inds
)
for attn_type, attn in attn_types.items()
}