This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Pass opt and kwargs all the way through transformer #3708

Merged · 3 commits · Jun 14, 2021
Changes from 2 commits
13 changes: 12 additions & 1 deletion parlai/agents/transformer/modules/attention.py
@@ -14,6 +14,7 @@
import torch.nn as nn
import torch.nn.functional as F

from parlai.core.opt import Opt
from parlai.utils.torch import neginf


@@ -92,8 +93,18 @@ class MultiHeadAttention(nn.Module):
See Vaswani (2017) for an extensive description.
"""

def __init__(self, n_heads: int, dim: int, dropout: float = 0):
def __init__(
self, opt: Opt, n_heads: int = None, dim: int = None, dropout: float = 0
):
super(MultiHeadAttention, self).__init__()

def _default(val, default):
""" shorthand for explicit None check for optional arguments """
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
dim = _default(dim, opt['embedding_size'])

self.n_heads = n_heads
self.dim = dim

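With this change, `MultiHeadAttention` can be built from an `Opt` alone: `n_heads` and `dim` fall back to `opt['n_heads']` and `opt['embedding_size']` when not given. A minimal usage sketch, not part of the diff; the import path and option values are assumed from the file shown above:

```python
from parlai.core.opt import Opt
from parlai.agents.transformer.modules.attention import MultiHeadAttention

# Opt behaves like a dict of parsed options; only the keys read above are set.
opt = Opt({'n_heads': 8, 'embedding_size': 512})

attn = MultiHeadAttention(opt)                   # n_heads and dim come from opt
attn_small = MultiHeadAttention(opt, n_heads=4)  # explicit arguments still override opt
```
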
58 changes: 40 additions & 18 deletions parlai/agents/transformer/modules/decoder.py
@@ -44,9 +44,10 @@ class TransformerDecoderLayer(nn.Module):

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
opt: Opt,
n_heads: int = None,
embedding_size: int = None,
ffn_size: int = None,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
@@ -55,24 +56,38 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

def _default(val, default):
""" shorthand for explicit None check for optional arguments """
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
embedding_size = _default(embedding_size, opt['embedding_size'])
ffn_size = _default(ffn_size, opt['ffn_size'])

self.opt = opt
self.dim = embedding_size
self.ffn_dim = ffn_size
self.variant = variant
self.activation = activation
self.dropout = nn.Dropout(p=dropout)

self.self_attention = self.swappables.self_attention(
n_heads, embedding_size, dropout=attention_dropout
opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout
) # type: ignore
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.encoder_attention = self.swappables.encoder_attention(
n_heads, embedding_size, dropout=attention_dropout
opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout
) # type: ignore
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.ffn = self.swappables.feedforward(
embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation
opt=self.opt,
dim=embedding_size,
dim_hidden=ffn_size,
relu_dropout=relu_dropout,
activation=activation,
) # type: ignore
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

@@ -106,6 +121,7 @@ def forward(
mask=decoder_mask,
incr_state=incr_state.get('self_attn'),
static_kv=False,
**kwargs,
)[:2]
x = self.dropout(x) # --dropout
x = x + residual
@@ -123,6 +139,7 @@ def forward(
mask=encoder_mask,
incr_state=incr_state.get('encoder_attn'),
static_kv=True,
**kwargs,
)[:2]
x = self.dropout(x) # --dropout
x = residual + x
@@ -133,7 +150,7 @@ def forward(
residual = x
if self.variant == 'prelayernorm':
x = self.norm3(x)
x = self.ffn(x)
x = self.ffn(x, **kwargs)
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
@@ -195,6 +212,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.opt = opt

def _default(val, default):
return val if val is not None else default
@@ -212,8 +230,7 @@ def _default(val, default):
self.variant = opt.get('variant', 'aiayn')

self.embeddings_scale = opt.get('embeddings_scale', True)
dropout_frac = opt.get('dropout', 0.0)
self.dropout = nn.Dropout(p=dropout_frac) # --dropout
self.dropout = nn.Dropout(p=opt.get('dropout', 0.0)) # --dropout

self.n_positions = _default(n_positions, get_n_positions_from_options(opt))
self.out_dim = self.embedding_size
@@ -253,26 +270,29 @@ def _default(val, default):
)

# build the model
self.layers = nn.ModuleList()
self.layers = self.build_layers()

def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
layers.append(
self.swappables.layer(
self.n_heads,
self.embedding_size,
self.ffn_size,
attention_dropout=opt.get('attention_dropout', 0.0),
relu_dropout=opt.get('relu_dropout', 0.0),
dropout=dropout_frac,
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.opt.get('dropout', 0.0),
activation=self.activation,
variant=self.variant,
) # type: ignore
)
return layers

def forward_embedding(
self,
input: torch.LongTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
**kwargs,
):
"""
Embed tokens prior to feeding into transformer.
@@ -311,6 +331,7 @@ def forward_layers(
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]],
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass of decoder layers.
@@ -340,6 +361,7 @@ def forward_layers(
encoder_output=encoder_output,
encoder_mask=encoder_mask,
incr_state=incr_state.get(idx),
**kwargs,
)

return tensor, new_incr_state
@@ -377,12 +399,12 @@ def forward(
else:
incr_state = {}

tensor = self.forward_embedding(input, positions)
tensor = self.forward_embedding(input, positions, **kwargs)

tensor = self.dropout(tensor) # --dropout

tensor, new_incr_state = self.forward_layers(
tensor, encoder_output, encoder_mask, incr_state
tensor, encoder_output, encoder_mask, incr_state, **kwargs
)

if self.variant == 'prelayernorm':
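
Beyond the `opt`-first signatures, the decoder now builds its layer stack in a dedicated `build_layers()` method rather than inline in `__init__`. A hedged sketch of why that is convenient: a subclass can adjust layer construction without re-implementing the whole constructor (the subclass and its tweak below are hypothetical, not part of this PR):

```python
import torch.nn as nn
from parlai.agents.transformer.modules.decoder import TransformerDecoder

class FrozenFirstLayerDecoder(TransformerDecoder):
    """Hypothetical subclass: reuse the stock layers, then adjust them."""

    def build_layers(self) -> nn.ModuleList:
        layers = super().build_layers()
        # Purely illustrative: freeze the parameters of the first decoder layer.
        for p in layers[0].parameters():
            p.requires_grad = False
        return layers
```
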
46 changes: 32 additions & 14 deletions parlai/agents/transformer/modules/encoder.py
@@ -35,9 +35,10 @@ class TransformerEncoderLayer(nn.Module):

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
opt: Opt,
n_heads: int = None,
embedding_size: int = None,
ffn_size: int = None,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
@@ -46,17 +47,31 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

def _default(val, default):
""" shorthand for explicit None check for optional arguments """
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
embedding_size = _default(embedding_size, opt['embedding_size'])
ffn_size = _default(ffn_size, opt['ffn_size'])

self.opt = opt
self.dim = embedding_size
self.ffn_dim = ffn_size
self.activation = activation
self.variant = variant
self.attention = self.swappables.self_attention( # type: ignore
n_heads, embedding_size, dropout=attention_dropout # --attention-dropout
opt=self.opt,
n_heads=n_heads,
dim=embedding_size,
dropout=attention_dropout, # --attention-dropout
)
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
self.ffn = self.swappables.feedforward( # type: ignore
embedding_size,
ffn_size,
opt=self.opt,
dim=embedding_size,
dim_hidden=ffn_size,
relu_dropout=relu_dropout,
activation=self.activation,
)
@@ -145,6 +160,7 @@ def _default(val, default):
# this is --dropout, not --relu-dropout or --attention-dropout
self.dropout_frac = _default(dropout, opt.get('dropout', 0.0))
self.dropout = nn.Dropout(p=self.dropout_frac)
self.activation = _default(activation, opt.get('activation', 'relu'))
self.variant = _default(variant, opt.get('variant', 'aiayn'))
self.n_segments = _default(n_segments, opt.get('n_segments', 0))

@@ -203,21 +219,23 @@ def _default(val, default):
nn.init.normal_(self.segment_embeddings.weight, 0, self.dim ** -0.5)

# build the model
self.layers = nn.ModuleList()
self.layers = self.build_layers()
self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0))

def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
layers.append(
self.swappables.layer( # type: ignore
self.n_heads,
self.embedding_size,
self.ffn_size,
attention_dropout=opt.get('attention_dropout', 0.0),
relu_dropout=opt.get('relu_dropout', 0.0),
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.dropout_frac,
variant=self.variant,
activation=_default(activation, opt.get('activation', 'relu')),
activation=self.activation,
)
)
self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0))
return layers

def forward_embedding(
self,
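
The reason for threading `opt` into every sub-module is that swapped-in components can read their own configuration from the same options object. A hedged sketch, assuming the swappable classes expose `with_components` and can be instantiated directly with their default components; the custom FFN class and the `gate_scale` key are hypothetical:

```python
from parlai.core.opt import Opt
from parlai.agents.transformer.modules.encoder import TransformerEncoderLayer
from parlai.agents.transformer.modules.ffn import TransformerFFN

class GatedFFN(TransformerFFN):
    """Hypothetical feedforward variant that reads an extra (made-up) option."""

    def __init__(self, opt: Opt, **kwargs):
        super().__init__(opt, **kwargs)
        self.gate_scale = opt.get('gate_scale', 1.0)  # hypothetical option key

opt = Opt({'n_heads': 8, 'embedding_size': 512, 'ffn_size': 2048})

# Swap the feedforward component; it now receives the same opt as the layer.
CustomLayer = TransformerEncoderLayer.with_components(feedforward=GatedFFN)
layer = CustomLayer(opt)
```
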
19 changes: 16 additions & 3 deletions parlai/agents/transformer/modules/ffn.py
@@ -11,6 +11,8 @@
import torch.nn as nn
import torch.nn.functional as F

from parlai.core.opt import Opt


class TransformerFFN(nn.Module):
"""
@@ -19,12 +21,23 @@ class TransformerFFN(nn.Module):

def __init__(
self,
dim: int,
dim_hidden: int,
opt: Opt,
dim: int = None,
dim_hidden: int = None,
relu_dropout: float = 0,
activation: str = 'relu',
**kwargs,
):
super(TransformerFFN, self).__init__()
super(TransformerFFN, self).__init__(**kwargs)

def _default(val, default):
""" shorthand for explicit None check for optional arguments """
return val if val is not None else default

dim = _default(dim, opt['embedding_size'])
dim_hidden = _default(dim_hidden, opt['ffn_size'])

self.opt = opt
self.relu_dropout = nn.Dropout(p=relu_dropout)
if activation == 'relu':
self.nonlinear = F.relu
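
As with the other modules, `TransformerFFN` now takes `opt` first, keeps it on `self.opt`, and falls back to `opt['embedding_size']` and `opt['ffn_size']` for `dim` and `dim_hidden`. A small usage sketch, not part of the diff; the option values and tensor shapes are illustrative:

```python
import torch
from parlai.core.opt import Opt
from parlai.agents.transformer.modules.ffn import TransformerFFN

opt = Opt({'embedding_size': 512, 'ffn_size': 2048})

ffn = TransformerFFN(opt)                        # dims pulled from opt
ffn_wide = TransformerFFN(opt, dim_hidden=4096)  # explicit arguments still win

x = torch.randn(2, 16, 512)  # (batch, sequence, embedding_size)
y = ffn(x)                   # projects 512 -> 2048 -> 512, same shape back
```
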