This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Type annotate transformer/modules.py #3545

Merged
2 commits merged on Mar 24, 2021
Changes from 1 commit
180 changes: 108 additions & 72 deletions parlai/agents/transformer/modules.py
@@ -26,6 +26,7 @@
import torch.nn.functional as F

from parlai.core.opt import Opt
from parlai.core.torch_agent import DictionaryAgent
from parlai.core.torch_generator_agent import TorchGeneratorModel
from parlai.utils.misc import warn_once
from parlai.utils.torch import neginf, PipelineHelper
@@ -82,7 +83,7 @@ def build_encoder(
reduction_type=reduction_type,
)

def __init__(self, opt, dictionary):
def __init__(self, opt: Opt, dictionary: DictionaryAgent):
super().__init__()
self.opt = opt
self.pad_idx = dictionary[dictionary.null_token]
@@ -188,7 +189,13 @@ def encode_context_memory(self, context_w, memories_w, context_segments=None):

return weights, context_h

def forward(self, xs, mems, cands, context_segments=None):
def forward(
self,
xs: torch.LongTensor,
mems: torch.LongTensor,
cands: torch.LongTensor,
context_segments: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
"""
Forward pass.

@@ -230,51 +237,6 @@ def create_position_codes(n_pos, dim, out):
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc)).type_as(out)


class TransformerResponseWrapper(nn.Module):
Contributor Author

I had to move these down so TransformerEncoder is declared before using it as a type. If we want to avoid moving these, I could just wrap the type in quotes "TransformerEncoder", but that disables static checking of these.
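
For context, a minimal sketch (not part of this PR; class bodies simplified) of why declaration order matters and what the quoted-annotation workaround looks like. An unquoted annotation is evaluated when the def statement runs, so the referenced class must already exist:

import torch.nn as nn


class TransformerEncoder(nn.Module):
    # Defined first so it can be used as a plain (unquoted) annotation below.
    out_dim = 300  # placeholder attribute for the sketch


class TransformerResponseWrapper(nn.Module):
    def __init__(self, transformer: TransformerEncoder, hdim: int):
        super().__init__()
        self.transformer = transformer
        self.hdim = hdim


# Alternative that keeps the original class order: quote the annotation so it is
# only a forward reference and is never evaluated at runtime:
#     def __init__(self, transformer: "TransformerEncoder", hdim: int): ...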

Contributor

Ah the secret to that is to add at the top:

from __future__ import annotations

This causes Python to postpone evaluation of annotations, so classes can be referenced in type hints before they are defined.
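
A minimal sketch of that approach (not part of this PR; class bodies simplified). With postponed evaluation (PEP 563), annotations are stored as strings and never evaluated at runtime, so the wrapper can reference TransformerEncoder before it is defined:

from __future__ import annotations  # must appear at the top of the file, before other imports

import torch.nn as nn


class TransformerResponseWrapper(nn.Module):
    # TransformerEncoder is defined further down, but the annotation is never
    # evaluated at runtime, so no NameError is raised here.
    def __init__(self, transformer: TransformerEncoder, hdim: int):
        super().__init__()
        self.transformer = transformer


class TransformerEncoder(nn.Module):
    out_dim = 300  # placeholder attribute for the sketch


wrapper = TransformerResponseWrapper(TransformerEncoder(), hdim=512)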

Contributor Author

This ended up being resolved anyway by the breaking up of modules.py

"""
Wrap transformer response.

Pushes input through transformer and MLP.
"""

def __init__(self, transformer, hdim):
super(TransformerResponseWrapper, self).__init__()
dim = transformer.out_dim
self.transformer = transformer
self.mlp = nn.Sequential(
nn.Linear(dim, hdim),
nn.ReLU(), # TODO: should this also be gelu?
nn.Linear(hdim, dim),
)

def forward(self, *args):
"""
Forward pass.
"""
return self.mlp(self.transformer(*args))


class TransformerLinearWrapper(nn.Module):
"""
Wrap a transformer in a linear layer.
"""

def __init__(self, transformer, output_dim):
super().__init__()
self.transformer = transformer
input_dim = transformer.out_dim
self.additional_linear_layer = nn.Linear(input_dim, output_dim)

def forward(self, *args):
"""
Forward pass.

Apply transformer, then additional linear layer.
"""
context_h = self.transformer(*args)
return self.additional_linear_layer(context_h)


class TransformerEncoder(nn.Module):
"""
Transformer encoder module.
@@ -558,21 +520,66 @@ def _apply_model_parallel(self, tensor, mask):
return tensor_out


class TransformerResponseWrapper(nn.Module):
"""
Wrap transformer response.

Pushes input through transformer and MLP.
"""

def __init__(self, transformer: TransformerEncoder, hdim: int):
super(TransformerResponseWrapper, self).__init__()
dim = transformer.out_dim
self.transformer = transformer
self.mlp = nn.Sequential(
nn.Linear(dim, hdim),
nn.ReLU(), # TODO: should this also be gelu?
nn.Linear(hdim, dim),
)

def forward(self, *args) -> torch.Tensor:
"""
Forward pass.
"""
return self.mlp(self.transformer(*args))


class TransformerLinearWrapper(nn.Module):
"""
Wrap a transformer in a linear layer.
"""

def __init__(self, transformer: TransformerEncoder, output_dim: int):
super().__init__()
self.transformer = transformer
input_dim = transformer.out_dim
self.additional_linear_layer = nn.Linear(input_dim, output_dim)

def forward(self, *args) -> torch.Tensor:
"""
Forward pass.

Apply transformer, then additional linear layer.
"""
context_h = self.transformer(*args)
return self.additional_linear_layer(context_h)


class TransformerEncoderLayer(nn.Module):
"""
Implements a single Transformer encoder layer.
"""

def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
activation='relu',
variant=None,
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: Optional[str] = None,
):
super().__init__()
self.dim = embedding_size
@@ -592,7 +599,7 @@ def __init__(
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
self.dropout = nn.Dropout(p=dropout)

def forward(self, tensor, mask):
def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
"""
@@ -782,7 +789,12 @@ def forward_layers(

return tensor, new_incr_state

def forward(self, input, encoder_state, incr_state=None):
def forward(
self,
input: torch.Tensor,
encoder_state,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

@@ -870,14 +882,14 @@ class TransformerDecoderLayer(nn.Module):

def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
activation='relu',
variant='aiayn',
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: str = 'aiayn',
):
super().__init__()
self.dim = embedding_size
@@ -901,7 +913,13 @@ def __init__(
)
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

def forward(self, x, encoder_output, encoder_mask, incr_state=None):
def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

@@ -1012,7 +1030,7 @@ def build_encoder(
def build_decoder(cls, opt, embedding=None):
return TransformerDecoder(opt=opt, embedding=embedding)

def __init__(self, opt, dictionary):
def __init__(self, opt: Opt, dictionary: DictionaryAgent):
self.pad_idx = dictionary[dictionary.null_token]
self.start_idx = dictionary[dictionary.start_token]
self.end_idx = dictionary[dictionary.end_token]
@@ -1072,7 +1090,13 @@ class BasicAttention(nn.Module):
Implements simple/classical attention.
"""

def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
def __init__(
self,
dim: int = 1,
attn: str = 'cosine',
residual: bool = False,
get_weights: bool = True,
):
super().__init__()
if attn == 'cosine':
self.cosine = nn.CosineSimilarity(dim=dim)
@@ -1081,7 +1105,13 @@ def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
self.get_weights = get_weights
self.residual = residual

def forward(self, xs, ys, mask_ys=None, values=None):
def forward(
self,
xs: torch.Tensor,
ys: torch.Tensor,
mask_ys: Optional[torch.Tensor] = None,
values: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Compute attention.

@@ -1130,7 +1160,7 @@ class MultiHeadAttention(nn.Module):
See Vaswani (2017) for an extensive description.
"""

def __init__(self, n_heads, dim, dropout=0):
def __init__(self, n_heads: int, dim: int, dropout: float = 0):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.dim = dim
@@ -1306,7 +1336,13 @@ class TransformerFFN(nn.Module):
Implements the FFN part of the transformer.
"""

def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
def __init__(
self,
dim: int,
dim_hidden: int,
relu_dropout: float = 0,
activation: str = 'relu',
):
super(TransformerFFN, self).__init__()
self.relu_dropout = nn.Dropout(p=relu_dropout)
if activation == 'relu':
@@ -1323,7 +1359,7 @@ def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
nn.init.xavier_uniform_(self.lin2.weight)
# TODO: initialize biases to 0

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
"""