diff --git a/parlai/agents/transformer/modules/attention.py b/parlai/agents/transformer/modules/attention.py
index d6b721683a0..4f67369c6ca 100644
--- a/parlai/agents/transformer/modules/attention.py
+++ b/parlai/agents/transformer/modules/attention.py
@@ -8,7 +8,7 @@
 """
 
 import math
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,13 @@ class BasicAttention(nn.Module):
     Implements simple/classical attention.
     """
 
-    def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
+    def __init__(
+        self,
+        dim: int = 1,
+        attn: str = 'cosine',
+        residual: bool = False,
+        get_weights: bool = True,
+    ):
         super().__init__()
         if attn == 'cosine':
             self.cosine = nn.CosineSimilarity(dim=dim)
@@ -31,7 +37,13 @@ def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
         self.get_weights = get_weights
         self.residual = residual
 
-    def forward(self, xs, ys, mask_ys=None, values=None):
+    def forward(
+        self,
+        xs: torch.Tensor,
+        ys: torch.Tensor,
+        mask_ys: Optional[torch.Tensor] = None,
+        values: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Compute attention.
@@ -80,7 +92,7 @@ class MultiHeadAttention(nn.Module):
     See Vaswani (2017) for an extensive description.
     """
 
-    def __init__(self, n_heads, dim, dropout=0):
+    def __init__(self, n_heads: int, dim: int, dropout: float = 0):
         super(MultiHeadAttention, self).__init__()
         self.n_heads = n_heads
         self.dim = dim
diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py
index e081e4df001..abd066fbb5e 100644
--- a/parlai/agents/transformer/modules/decoder.py
+++ b/parlai/agents/transformer/modules/decoder.py
@@ -194,7 +194,12 @@ def forward_layers(
         return tensor, new_incr_state
 
-    def forward(self, input, encoder_state, incr_state=None):
+    def forward(
+        self,
+        input: torch.Tensor,
+        encoder_state,
+        incr_state: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Forward pass.
@@ -282,14 +287,14 @@ class TransformerDecoderLayer(nn.Module):
 
     def __init__(
         self,
-        n_heads,
-        embedding_size,
-        ffn_size,
-        attention_dropout=0.0,
-        relu_dropout=0.0,
-        dropout=0.0,
-        activation='relu',
-        variant='aiayn',
+        n_heads: int,
+        embedding_size: int,
+        ffn_size: int,
+        attention_dropout: float = 0.0,
+        relu_dropout: float = 0.0,
+        dropout: float = 0.0,
+        activation: str = 'relu',
+        variant: str = 'aiayn',
     ):
         super().__init__()
         self.dim = embedding_size
@@ -313,7 +318,13 @@ def __init__(
         )
         self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
 
-    def forward(self, x, encoder_output, encoder_mask, incr_state=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        encoder_output: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        incr_state: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Forward pass.
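The `Union` return annotation on `BasicAttention.forward` reflects the existing runtime behavior: with `get_weights=True` (the default) the module returns an `(output, weights)` pair, otherwise only the attended output. A minimal sketch of how a caller narrows that `Union` for a type checker; the tensor shapes and the `attn='sqrt'` setting are illustrative assumptions, not part of this diff:

```python
import torch
from parlai.agents.transformer.modules.attention import BasicAttention

# Illustrative shapes (assumptions): batch of 4, 1 query, 8 keys, dim 16.
xs = torch.randn(4, 1, 16)
ys = torch.randn(4, 8, 16)

attention = BasicAttention(dim=2, attn='sqrt', get_weights=True)
result = attention(xs, ys)

# The annotated return type is Union[Tensor, Tuple[Tensor, Tensor]], so the
# static type must be narrowed before unpacking. The (output, weights)
# ordering is assumed from the implementation, not from the annotation.
if isinstance(result, tuple):
    out, weights = result
else:
    out = result
```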
diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py
index 8545ce5c5cc..ec338e4e004 100644
--- a/parlai/agents/transformer/modules/encoder.py
+++ b/parlai/agents/transformer/modules/encoder.py
@@ -315,14 +315,14 @@ class TransformerEncoderLayer(nn.Module):
 
     def __init__(
         self,
-        n_heads,
-        embedding_size,
-        ffn_size,
-        attention_dropout=0.0,
-        relu_dropout=0.0,
-        dropout=0.0,
-        activation='relu',
-        variant=None,
+        n_heads: int,
+        embedding_size: int,
+        ffn_size: int,
+        attention_dropout: float = 0.0,
+        relu_dropout: float = 0.0,
+        dropout: float = 0.0,
+        activation: str = 'relu',
+        variant: Optional[str] = None,
     ):
         super().__init__()
         self.dim = embedding_size
@@ -342,7 +342,7 @@ def __init__(
         self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
         self.dropout = nn.Dropout(p=dropout)
 
-    def forward(self, tensor, mask):
+    def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
         """
diff --git a/parlai/agents/transformer/modules/ffn.py b/parlai/agents/transformer/modules/ffn.py
index 741d79f6a14..32fe1d74393 100644
--- a/parlai/agents/transformer/modules/ffn.py
+++ b/parlai/agents/transformer/modules/ffn.py
@@ -7,6 +7,7 @@
 Feedforward neural network, as used in transformer implementation.
 """
 
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -16,7 +17,13 @@ class TransformerFFN(nn.Module):
     Implements the FFN part of the transformer.
     """
 
-    def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
+    def __init__(
+        self,
+        dim: int,
+        dim_hidden: int,
+        relu_dropout: float = 0,
+        activation: str = 'relu',
+    ):
         super(TransformerFFN, self).__init__()
         self.relu_dropout = nn.Dropout(p=relu_dropout)
         if activation == 'relu':
@@ -33,7 +40,7 @@ def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
         nn.init.xavier_uniform_(self.lin2.weight)
         # TODO: initialize biases to 0
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
         """
diff --git a/parlai/agents/transformer/modules/generator.py b/parlai/agents/transformer/modules/generator.py
index 37882fefd74..e6db5dc0289 100644
--- a/parlai/agents/transformer/modules/generator.py
+++ b/parlai/agents/transformer/modules/generator.py
@@ -27,6 +27,8 @@
     TransformerDecoder,
     TransformerEncoder,
 )
+from parlai.core.opt import Opt
+from parlai.core.torch_agent import DictionaryAgent
 from parlai.core.torch_generator_agent import TorchGeneratorModel
 from parlai.utils.torch import neginf
 
@@ -52,7 +54,7 @@ def build_encoder(
     def build_decoder(cls, opt, embedding=None):
         return TransformerDecoder(opt=opt, embedding=embedding)
 
-    def __init__(self, opt, dictionary):
+    def __init__(self, opt: Opt, dictionary: DictionaryAgent):
         self.pad_idx = dictionary[dictionary.null_token]
         self.start_idx = dictionary[dictionary.start_token]
         self.end_idx = dictionary[dictionary.end_token]
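With `TransformerFFN` fully annotated, its typed interface is easy to exercise. A small sketch of construction and shape behavior; the sizes are illustrative assumptions, and the shape-preservation assertion relies on the FFN projecting `dim -> dim_hidden -> dim` as the class docstring implies:

```python
import torch
from parlai.agents.transformer.modules.ffn import TransformerFFN

# Illustrative sizes (assumptions): model dim 256, hidden dim 1024.
ffn = TransformerFFN(dim=256, dim_hidden=1024, relu_dropout=0.1, activation='relu')

x = torch.randn(2, 10, 256)  # (batch, seq_len, dim)
y = ffn(x)                   # forward is now annotated to return torch.Tensor

# lin1 maps dim -> dim_hidden and lin2 maps back, so the shape is unchanged.
assert y.shape == x.shape
```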
diff --git a/parlai/agents/transformer/modules/mem_net.py b/parlai/agents/transformer/modules/mem_net.py
index 508d6a443e0..ff3ddca7f1f 100644
--- a/parlai/agents/transformer/modules/mem_net.py
+++ b/parlai/agents/transformer/modules/mem_net.py
@@ -3,7 +3,9 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 import torch.nn as nn
+from typing import Optional, Tuple
 
 from parlai.agents.transformer.modules import (
     create_embeddings,
@@ -11,6 +13,8 @@
     TransformerEncoder,
     TransformerResponseWrapper,
 )
+from parlai.core.opt import Opt
+from parlai.core.torch_agent import DictionaryAgent
 
 
 class TransformerMemNetModel(nn.Module):
@@ -30,7 +34,7 @@ def build_encoder(
             reduction_type=reduction_type,
         )
 
-    def __init__(self, opt, dictionary):
+    def __init__(self, opt: Opt, dictionary: DictionaryAgent):
         super().__init__()
         self.opt = opt
         self.pad_idx = dictionary[dictionary.null_token]
@@ -136,7 +140,13 @@ def encode_context_memory(self, context_w, memories_w, context_segments=None):
         return weights, context_h
 
-    def forward(self, xs, mems, cands, context_segments=None):
+    def forward(
+        self,
+        xs: torch.LongTensor,
+        mems: torch.LongTensor,
+        cands: torch.LongTensor,
+        context_segments: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
         """
         Forward pass.
diff --git a/parlai/agents/transformer/modules/wrappers.py b/parlai/agents/transformer/modules/wrappers.py
index 9ace8014681..df728f97d95 100644
--- a/parlai/agents/transformer/modules/wrappers.py
+++ b/parlai/agents/transformer/modules/wrappers.py
@@ -3,8 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import torch
 import torch.nn as nn
 
+from parlai.agents.transformer.modules import TransformerEncoder
+
 
 class TransformerResponseWrapper(nn.Module):
     """
@@ -13,7 +16,7 @@ class TransformerResponseWrapper(nn.Module):
     Pushes input through transformer and MLP.
     """
 
-    def __init__(self, transformer, hdim):
+    def __init__(self, transformer: TransformerEncoder, hdim: int):
         super(TransformerResponseWrapper, self).__init__()
         dim = transformer.out_dim
         self.transformer = transformer
@@ -23,7 +26,7 @@ def __init__(self, transformer, hdim):
             nn.Linear(hdim, dim),
         )
 
-    def forward(self, *args):
+    def forward(self, *args) -> torch.Tensor:
         """
         Forward pass.
         """
@@ -35,13 +38,13 @@ class TransformerLinearWrapper(nn.Module):
     """
     Wrap a transformer in a linear layer.
    """
 
-    def __init__(self, transformer, output_dim):
+    def __init__(self, transformer: TransformerEncoder, output_dim: int):
         super().__init__()
         self.transformer = transformer
         input_dim = transformer.out_dim
         self.additional_linear_layer = nn.Linear(input_dim, output_dim)
 
-    def forward(self, *args):
+    def forward(self, *args) -> torch.Tensor:
         """
         Forward pass.
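One consequence of annotating the wrappers with `TransformerEncoder` is that they now advertise a concrete static dependency, even though at runtime they only rely on an `out_dim` attribute and a callable forward. A hedged sketch of that distinction using a stub encoder; the stub is an illustration only, and mypy would flag it precisely because of the new annotation:

```python
import torch
import torch.nn as nn
from parlai.agents.transformer.modules.wrappers import TransformerLinearWrapper


class StubEncoder(nn.Module):
    """Stand-in exposing only what the wrapper touches at runtime (assumption)."""

    def __init__(self, out_dim: int):
        super().__init__()
        self.out_dim = out_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


# Runtime duck typing still works; statically, a type checker now expects a
# real TransformerEncoder here, which is the point of the annotation.
wrapper = TransformerLinearWrapper(transformer=StubEncoder(out_dim=16), output_dim=4)
out = wrapper(torch.randn(2, 16))  # encoder output fed through nn.Linear(16, 4)
assert out.shape == (2, 4)
```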