This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Pass opt and kwargs all the way through transformer (#3708)
* pass opt and kwargs all the way through transformer

* update hash ladder code

* massage
spencerp authored Jun 14, 2021
1 parent c4c2669 commit 4014ae9
Showing 5 changed files with 150 additions and 331 deletions.
16 changes: 15 additions & 1 deletion parlai/agents/transformer/modules/attention.py
@@ -14,6 +14,7 @@
import torch.nn as nn
import torch.nn.functional as F

from parlai.core.opt import Opt
from parlai.utils.torch import neginf


@@ -92,8 +93,20 @@ class MultiHeadAttention(nn.Module):
See Vaswani (2017) for an extensive description.
"""

def __init__(self, n_heads: int, dim: int, dropout: float = 0):
def __init__(
self, opt: Opt, n_heads: int = None, dim: int = None, dropout: float = 0
):
super(MultiHeadAttention, self).__init__()

def _default(val, default):
"""
shorthand for explicit None check for optional arguments.
"""
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
dim = _default(dim, opt['embedding_size'])

self.n_heads = n_heads
self.dim = dim

@@ -120,6 +133,7 @@ def forward( # type: ignore
mask: torch.Tensor = None,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
static_kv: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""
Forward pass.
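The pattern above is the heart of this change: each module now takes the full `opt` first, its remaining constructor arguments become optional, and `None` values fall back to `opt`. The following is a minimal, self-contained sketch of that pattern using a toy class and a plain dict (not ParlAI's actual `MultiHeadAttention`); explicit arguments still take precedence over the `opt` defaults.

```python
# Minimal sketch of the opt-with-fallback constructor pattern
# (toy code, not ParlAI's MultiHeadAttention).
from typing import Optional


def _default(val, default):
    """Shorthand for an explicit None check on optional arguments."""
    return val if val is not None else default


class ToyAttention:
    def __init__(self, opt: dict, n_heads: Optional[int] = None, dim: Optional[int] = None):
        # Fall back to opt when the caller does not pass a value explicitly.
        self.n_heads = _default(n_heads, opt['n_heads'])
        self.dim = _default(dim, opt['embedding_size'])

    def forward(self, x, mask=None, **kwargs):
        # Extra keyword arguments are accepted and ignored here, so callers
        # can thread custom kwargs through without breaking this module.
        return x


opt = {'n_heads': 8, 'embedding_size': 512}
assert ToyAttention(opt).n_heads == 8             # value comes from opt
assert ToyAttention(opt, n_heads=4).n_heads == 4  # explicit argument wins
```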
62 changes: 43 additions & 19 deletions parlai/agents/transformer/modules/decoder.py
@@ -44,9 +44,10 @@ class TransformerDecoderLayer(nn.Module):

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
opt: Opt,
n_heads: int = None,
embedding_size: int = None,
ffn_size: int = None,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
@@ -55,24 +56,40 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

def _default(val, default):
"""
shorthand for explicit None check for optional arguments.
"""
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
embedding_size = _default(embedding_size, opt['embedding_size'])
ffn_size = _default(ffn_size, opt['ffn_size'])

self.opt = opt
self.dim = embedding_size
self.ffn_dim = ffn_size
self.variant = variant
self.activation = activation
self.dropout = nn.Dropout(p=dropout)

self.self_attention = self.swappables.self_attention(
n_heads, embedding_size, dropout=attention_dropout
opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout
) # type: ignore
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.encoder_attention = self.swappables.encoder_attention(
n_heads, embedding_size, dropout=attention_dropout
opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout
) # type: ignore
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

self.ffn = self.swappables.feedforward(
embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation
opt=self.opt,
dim=embedding_size,
dim_hidden=ffn_size,
relu_dropout=relu_dropout,
activation=activation,
) # type: ignore
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

@@ -106,6 +123,7 @@ def forward(
mask=decoder_mask,
incr_state=incr_state.get('self_attn'),
static_kv=False,
**kwargs,
)[:2]
x = self.dropout(x) # --dropout
x = x + residual
@@ -123,6 +141,7 @@ def forward(
mask=encoder_mask,
incr_state=incr_state.get('encoder_attn'),
static_kv=True,
**kwargs,
)[:2]
x = self.dropout(x) # --dropout
x = residual + x
@@ -133,7 +152,7 @@ def forward(
residual = x
if self.variant == 'prelayernorm':
x = self.norm3(x)
x = self.ffn(x)
x = self.ffn(x, **kwargs)
x = self.dropout(x) # --dropout
x = residual + x
if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart':
@@ -195,6 +214,7 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.opt = opt

def _default(val, default):
return val if val is not None else default
@@ -212,8 +232,7 @@ def _default(val, default):
self.variant = opt.get('variant', 'aiayn')

self.embeddings_scale = opt.get('embeddings_scale', True)
dropout_frac = opt.get('dropout', 0.0)
self.dropout = nn.Dropout(p=dropout_frac) # --dropout
self.dropout = nn.Dropout(p=opt.get('dropout', 0.0)) # --dropout

self.n_positions = _default(n_positions, get_n_positions_from_options(opt))
self.out_dim = self.embedding_size
@@ -253,26 +272,29 @@ def _default(val, default):
)

# build the model
self.layers = nn.ModuleList()
self.layers = self.build_layers()

def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(
layers.append(
self.swappables.layer(
self.n_heads,
self.embedding_size,
self.ffn_size,
attention_dropout=opt.get('attention_dropout', 0.0),
relu_dropout=opt.get('relu_dropout', 0.0),
dropout=dropout_frac,
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.opt.get('dropout', 0.0),
activation=self.activation,
variant=self.variant,
) # type: ignore
)
return layers

def forward_embedding(
self,
input: torch.LongTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
**kwargs,
):
"""
Embed tokens prior to feeding into transformer.
@@ -311,6 +333,7 @@ def forward_layers(
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]],
**kwargs,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass of decoder layers.
@@ -340,6 +363,7 @@ def forward_layers(
encoder_output=encoder_output,
encoder_mask=encoder_mask,
incr_state=incr_state.get(idx),
**kwargs,
)

return tensor, new_incr_state
@@ -377,12 +401,12 @@ def forward(
else:
incr_state = {}

tensor = self.forward_embedding(input, positions)
tensor = self.forward_embedding(input, positions, **kwargs)

tensor = self.dropout(tensor) # --dropout

tensor, new_incr_state = self.forward_layers(
tensor, encoder_output, encoder_mask, incr_state
tensor, encoder_output, encoder_mask, incr_state, **kwargs
)

if self.variant == 'prelayernorm':
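Two things change in the decoder: layer construction moves into an overridable `build_layers()` method, and `**kwargs` is threaded from `forward()` through `forward_embedding()`, `forward_layers()`, and into each layer's attention and FFN calls. The toy sketch below (hypothetical classes, not the real `TransformerDecoder` API) illustrates the threading: an extra keyword argument passed to the top-level forward reaches every innermost module untouched, which is what lets customized components such as the hash ladder code receive their own arguments.

```python
# Toy sketch of threading **kwargs through a layer stack (hypothetical
# classes; not the actual TransformerDecoder API).
class ToyFFN:
    def forward(self, x, **kwargs):
        # A swapped-in component can read its own keyword arguments here;
        # 'bias' is purely illustrative.
        return x + kwargs.get('bias', 0)


class ToyDecoderLayer:
    def __init__(self):
        self.ffn = ToyFFN()

    def forward(self, x, **kwargs):
        # Pass everything along untouched, in the same spirit as the real
        # layers forwarding **kwargs to their attention and FFN submodules.
        return self.ffn.forward(x, **kwargs)


class ToyDecoder:
    def __init__(self, n_layers: int = 2):
        self.layers = [ToyDecoderLayer() for _ in range(n_layers)]

    def forward(self, x, **kwargs):
        for layer in self.layers:
            x = layer.forward(x, **kwargs)
        return x


print(ToyDecoder().forward(1, bias=10))  # 21: the kwarg reached every FFN
```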
50 changes: 35 additions & 15 deletions parlai/agents/transformer/modules/encoder.py
@@ -35,9 +35,10 @@ class TransformerEncoderLayer(nn.Module):

def __init__(
self,
n_heads: int,
embedding_size: int,
ffn_size: int,
opt: Opt,
n_heads: int = None,
embedding_size: int = None,
ffn_size: int = None,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
@@ -46,17 +47,33 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)

def _default(val, default):
"""
shorthand for explicit None check for optional arguments.
"""
return val if val is not None else default

n_heads = _default(n_heads, opt['n_heads'])
embedding_size = _default(embedding_size, opt['embedding_size'])
ffn_size = _default(ffn_size, opt['ffn_size'])

self.opt = opt
self.dim = embedding_size
self.ffn_dim = ffn_size
self.activation = activation
self.variant = variant
self.attention = self.swappables.self_attention( # type: ignore
n_heads, embedding_size, dropout=attention_dropout # --attention-dropout
opt=self.opt,
n_heads=n_heads,
dim=embedding_size,
dropout=attention_dropout, # --attention-dropout
)
self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
self.ffn = self.swappables.feedforward( # type: ignore
embedding_size,
ffn_size,
opt=self.opt,
dim=embedding_size,
dim_hidden=ffn_size,
relu_dropout=relu_dropout,
activation=self.activation,
)
@@ -145,6 +162,7 @@ def _default(val, default):
# this is --dropout, not --relu-dropout or --attention-dropout
self.dropout_frac = _default(dropout, opt.get('dropout', 0.0))
self.dropout = nn.Dropout(p=self.dropout_frac)
self.activation = _default(activation, opt.get('activation', 'relu'))
self.variant = _default(variant, opt.get('variant', 'aiayn'))
self.n_segments = _default(n_segments, opt.get('n_segments', 0))

@@ -203,21 +221,23 @@ def _default(val, default):
nn.init.normal_(self.segment_embeddings.weight, 0, self.dim ** -0.5)

# build the model
self.layers = nn.ModuleList()
self.layers = self.build_layers()
self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0))

def build_layers(self) -> nn.ModuleList:
layers = nn.ModuleList()
for _ in range(self.n_layers):
self.layers.append(
layers.append(
self.swappables.layer( # type: ignore
self.n_heads,
self.embedding_size,
self.ffn_size,
attention_dropout=opt.get('attention_dropout', 0.0),
relu_dropout=opt.get('relu_dropout', 0.0),
self.opt,
attention_dropout=self.opt.get('attention_dropout', 0.0),
relu_dropout=self.opt.get('relu_dropout', 0.0),
dropout=self.dropout_frac,
variant=self.variant,
activation=_default(activation, opt.get('activation', 'relu')),
activation=self.activation,
)
)
self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0))
return layers

def forward_embedding(
self,
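Factoring layer construction out of `__init__` into `build_layers()` (done for both the encoder here and the decoder above) gives subclasses one small method to override when they need a custom layer stack. A minimal sketch of that hook with plain `torch.nn` modules rather than the real `TransformerEncoder`:

```python
# Minimal sketch of the build_layers() override hook (toy modules,
# not the real TransformerEncoder).
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, n_layers: int, dim: int):
        super().__init__()
        self.n_layers = n_layers
        self.dim = dim
        # __init__ only calls the hook; subclasses customize the hook.
        self.layers = self.build_layers()

    def build_layers(self) -> nn.ModuleList:
        return nn.ModuleList([nn.Linear(self.dim, self.dim) for _ in range(self.n_layers)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x


class GatedToyEncoder(ToyEncoder):
    def build_layers(self) -> nn.ModuleList:
        # Override just the stack; __init__ stays untouched.
        layers = super().build_layers()
        layers.append(nn.Sigmoid())
        return layers


out = GatedToyEncoder(n_layers=2, dim=16)(torch.randn(4, 16))
print(out.shape)  # torch.Size([4, 16])
```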
23 changes: 19 additions & 4 deletions parlai/agents/transformer/modules/ffn.py
@@ -11,6 +11,8 @@
import torch.nn as nn
import torch.nn.functional as F

from parlai.core.opt import Opt


class TransformerFFN(nn.Module):
"""
@@ -19,12 +21,25 @@ class TransformerFFN(nn.Module):

def __init__(
self,
dim: int,
dim_hidden: int,
opt: Opt,
dim: int = None,
dim_hidden: int = None,
relu_dropout: float = 0,
activation: str = 'relu',
**kwargs,
):
super(TransformerFFN, self).__init__()
super(TransformerFFN, self).__init__(**kwargs)

def _default(val, default):
"""
shorthand for explicit None check for optional arguments.
"""
return val if val is not None else default

dim = _default(dim, opt['embedding_size'])
dim_hidden = _default(dim_hidden, opt['ffn_size'])

self.opt = opt
self.relu_dropout = nn.Dropout(p=relu_dropout)
if activation == 'relu':
self.nonlinear = F.relu
@@ -40,7 +55,7 @@ def __init__(
nn.init.xavier_uniform_(self.lin2.weight)
# TODO: initialize biases to 0

def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
"""
Forward pass.
"""
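`TransformerFFN` now follows the same two conventions as the other modules: dimensions come from `opt` unless passed explicitly, and `forward` accepts `**kwargs` so that extra arguments threaded down from the decoder do not raise a `TypeError`. A self-contained toy sketch in that style (not the actual ParlAI class):

```python
# Toy FFN in the same style as the updated TransformerFFN: dims fall back
# to opt, and forward tolerates extra kwargs.  Not the actual ParlAI class.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyFFN(nn.Module):
    def __init__(self, opt: dict, dim: int = None, dim_hidden: int = None,
                 relu_dropout: float = 0.0):
        super().__init__()
        dim = dim if dim is not None else opt['embedding_size']
        dim_hidden = dim_hidden if dim_hidden is not None else opt['ffn_size']
        self.relu_dropout = nn.Dropout(p=relu_dropout)
        self.lin1 = nn.Linear(dim, dim_hidden)
        self.lin2 = nn.Linear(dim_hidden, dim)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        # Extra kwargs are accepted (and ignored here) so callers can pass
        # custom arguments through the whole layer stack.
        return self.lin2(self.relu_dropout(F.relu(self.lin1(x))))


opt = {'embedding_size': 8, 'ffn_size': 32}
out = ToyFFN(opt)(torch.randn(2, 8), some_custom_flag=True)
print(out.shape)  # torch.Size([2, 8])
```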
