Commit
type annotate transformer/modules.py (#3545)
spencerp authored Mar 24, 2021
1 parent 7171d7c commit f11c077
Showing 7 changed files with 77 additions and 32 deletions.
20 changes: 16 additions & 4 deletions parlai/agents/transformer/modules/attention.py
@@ -8,7 +8,7 @@
 """

 import math
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, Union

 import torch
 import torch.nn as nn
@@ -22,7 +22,13 @@ class BasicAttention(nn.Module):
     Implements simple/classical attention.
     """

-    def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
+    def __init__(
+        self,
+        dim: int = 1,
+        attn: str = 'cosine',
+        residual: bool = False,
+        get_weights: bool = True,
+    ):
         super().__init__()
         if attn == 'cosine':
             self.cosine = nn.CosineSimilarity(dim=dim)
@@ -31,7 +37,13 @@ def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
         self.get_weights = get_weights
         self.residual = residual

-    def forward(self, xs, ys, mask_ys=None, values=None):
+    def forward(
+        self,
+        xs: torch.Tensor,
+        ys: torch.Tensor,
+        mask_ys: Optional[torch.Tensor] = None,
+        values: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         """
         Compute attention.
@@ -80,7 +92,7 @@ class MultiHeadAttention(nn.Module):
     See Vaswani (2017) for an extensive description.
     """

-    def __init__(self, n_heads, dim, dropout=0):
+    def __init__(self, n_heads: int, dim: int, dropout: float = 0):
         super(MultiHeadAttention, self).__init__()
         self.n_heads = n_heads
         self.dim = dim
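
For readers skimming the diff, a minimal usage sketch of the newly annotated BasicAttention API. The constructor and forward signatures come from the hunks above; the tensor shapes, the dim=2/attn='sqrt' settings, and the assumption that the module returns an (attended, weights) pair when get_weights=True (the Tuple branch of the Union return type) are illustrative, not taken from this commit.

import torch
from parlai.agents.transformer.modules import BasicAttention

attention = BasicAttention(dim=2, attn='sqrt', get_weights=True)
xs = torch.randn(4, 7, 64)   # query states, assumed (batch, query_len, dim)
ys = torch.randn(4, 11, 64)  # key/value states, assumed (batch, key_len, dim)
attended, weights = attention(xs, ys)
# Under these assumptions: attended is (4, 7, 64) and weights is (4, 7, 11).
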
31 changes: 21 additions & 10 deletions parlai/agents/transformer/modules/decoder.py
@@ -194,7 +194,12 @@ def forward_layers(

         return tensor, new_incr_state

-    def forward(self, input, encoder_state, incr_state=None):
+    def forward(
+        self,
+        input: torch.Tensor,
+        encoder_state,
+        incr_state: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Forward pass.
@@ -282,14 +287,14 @@ class TransformerDecoderLayer(nn.Module):

     def __init__(
         self,
-        n_heads,
-        embedding_size,
-        ffn_size,
-        attention_dropout=0.0,
-        relu_dropout=0.0,
-        dropout=0.0,
-        activation='relu',
-        variant='aiayn',
+        n_heads: int,
+        embedding_size: int,
+        ffn_size: int,
+        attention_dropout: float = 0.0,
+        relu_dropout: float = 0.0,
+        dropout: float = 0.0,
+        activation: str = 'relu',
+        variant: str = 'aiayn',
     ):
         super().__init__()
         self.dim = embedding_size
@@ -313,7 +318,13 @@ def __init__(
         )
         self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

-    def forward(self, x, encoder_output, encoder_mask, incr_state=None):
+    def forward(
+        self,
+        x: torch.Tensor,
+        encoder_output: torch.Tensor,
+        encoder_mask: torch.Tensor,
+        incr_state: Optional[Dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Forward pass.
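
A hedged sketch of how the annotated TransformerDecoderLayer reads in practice. The constructor arguments and the Tuple[Tensor, Dict[str, Tensor]] return type come from the hunks above; the import path, the tensor shapes, and the convention that encoder_mask marks valid (non-padding) positions are assumptions.

import torch
from parlai.agents.transformer.modules import TransformerDecoderLayer  # assumed export

layer = TransformerDecoderLayer(n_heads=2, embedding_size=64, ffn_size=128)
x = torch.randn(4, 5, 64)                          # decoder states, assumed (batch, out_len, emb)
encoder_output = torch.randn(4, 9, 64)             # encoder states, assumed (batch, in_len, emb)
encoder_mask = torch.ones(4, 9, dtype=torch.bool)  # assumed: True marks real tokens
out, incr_state = layer(x, encoder_output, encoder_mask)
# out has the same shape as x; incr_state caches per-layer state for incremental decoding.
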
18 changes: 9 additions & 9 deletions parlai/agents/transformer/modules/encoder.py
@@ -315,14 +315,14 @@ class TransformerEncoderLayer(nn.Module):

     def __init__(
         self,
-        n_heads,
-        embedding_size,
-        ffn_size,
-        attention_dropout=0.0,
-        relu_dropout=0.0,
-        dropout=0.0,
-        activation='relu',
-        variant=None,
+        n_heads: int,
+        embedding_size: int,
+        ffn_size: int,
+        attention_dropout: float = 0.0,
+        relu_dropout: float = 0.0,
+        dropout: float = 0.0,
+        activation: str = 'relu',
+        variant: Optional[str] = None,
     ):
         super().__init__()
         self.dim = embedding_size
@@ -342,7 +342,7 @@ def __init__(
         self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
         self.dropout = nn.Dropout(p=dropout)

-    def forward(self, tensor, mask):
+    def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
         """
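
Similarly, a minimal sketch for the annotated TransformerEncoderLayer. The signature is from the hunks above; the import path, shapes, and mask convention (True for non-padding tokens) are assumptions.

import torch
from parlai.agents.transformer.modules import TransformerEncoderLayer  # assumed export

layer = TransformerEncoderLayer(n_heads=2, embedding_size=64, ffn_size=128)
tensor = torch.randn(4, 9, 64)             # assumed (batch, seq_len, emb)
mask = torch.ones(4, 9, dtype=torch.bool)  # assumed: True marks real tokens
out = layer(tensor, mask)                  # plain torch.Tensor, same shape as the input
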
11 changes: 9 additions & 2 deletions parlai/agents/transformer/modules/ffn.py
@@ -7,6 +7,7 @@
 Feedforward neural network, as used in transformer implementation.
 """

+import torch
 import torch.nn as nn
 import torch.nn.functional as F

@@ -16,7 +17,13 @@ class TransformerFFN(nn.Module):
     Implements the FFN part of the transformer.
     """

-    def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
+    def __init__(
+        self,
+        dim: int,
+        dim_hidden: int,
+        relu_dropout: float = 0,
+        activation: str = 'relu',
+    ):
         super(TransformerFFN, self).__init__()
         self.relu_dropout = nn.Dropout(p=relu_dropout)
         if activation == 'relu':
@@ -33,7 +40,7 @@ def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
         nn.init.xavier_uniform_(self.lin2.weight)
         # TODO: initialize biases to 0

-    def forward(self, x):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
         Forward pass.
         """
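
The TransformerFFN annotations describe a position-wise block that maps a tensor of width dim back to width dim. A short sketch; the import path and shapes are assumptions.

import torch
from parlai.agents.transformer.modules import TransformerFFN  # assumed export

ffn = TransformerFFN(dim=64, dim_hidden=256, relu_dropout=0.1)
x = torch.randn(4, 9, 64)  # assumed (batch, seq_len, dim)
y = ffn(x)                 # torch.Tensor with the same shape as x
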
4 changes: 3 additions & 1 deletion parlai/agents/transformer/modules/generator.py
@@ -27,6 +27,8 @@
     TransformerDecoder,
     TransformerEncoder,
 )
+from parlai.core.opt import Opt
+from parlai.core.torch_agent import DictionaryAgent
 from parlai.core.torch_generator_agent import TorchGeneratorModel
 from parlai.utils.torch import neginf

@@ -52,7 +54,7 @@ def build_encoder(
     def build_decoder(cls, opt, embedding=None):
         return TransformerDecoder(opt=opt, embedding=embedding)

-    def __init__(self, opt, dictionary):
+    def __init__(self, opt: Opt, dictionary: DictionaryAgent):
         self.pad_idx = dictionary[dictionary.null_token]
         self.start_idx = dictionary[dictionary.start_token]
         self.end_idx = dictionary[dictionary.end_token]
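
The new Opt/DictionaryAgent annotations on TransformerGeneratorModel.__init__ make the dictionary-lookup contract explicit. A small sketch mirroring the lookups visible in the hunk above; the helper name is hypothetical and only illustrates what the annotations promise.

from typing import Tuple

from parlai.core.torch_agent import DictionaryAgent


def special_token_indices(dictionary: DictionaryAgent) -> Tuple[int, int, int]:
    # Hypothetical helper mirroring TransformerGeneratorModel.__init__:
    # indexing a DictionaryAgent with a special token string yields its vocab index.
    pad_idx = dictionary[dictionary.null_token]
    start_idx = dictionary[dictionary.start_token]
    end_idx = dictionary[dictionary.end_token]
    return pad_idx, start_idx, end_idx
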
14 changes: 12 additions & 2 deletions parlai/agents/transformer/modules/mem_net.py
@@ -3,14 +3,18 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import torch
 import torch.nn as nn
+from typing import Optional, Tuple

 from parlai.agents.transformer.modules import (
     create_embeddings,
     BasicAttention,
     TransformerEncoder,
     TransformerResponseWrapper,
 )
+from parlai.core.opt import Opt
+from parlai.core.torch_agent import DictionaryAgent


 class TransformerMemNetModel(nn.Module):
@@ -30,7 +34,7 @@ def build_encoder(
             reduction_type=reduction_type,
         )

-    def __init__(self, opt, dictionary):
+    def __init__(self, opt: Opt, dictionary: DictionaryAgent):
         super().__init__()
         self.opt = opt
         self.pad_idx = dictionary[dictionary.null_token]
@@ -136,7 +140,13 @@ def encode_context_memory(self, context_w, memories_w, context_segments=None):

         return weights, context_h

-    def forward(self, xs, mems, cands, context_segments=None):
+    def forward(
+        self,
+        xs: torch.LongTensor,
+        mems: torch.LongTensor,
+        cands: torch.LongTensor,
+        context_segments: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.LongTensor, torch.LongTensor]:
         """
         Forward pass.
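
The LongTensor annotations on TransformerMemNetModel.forward state that its inputs are int64 token-index batches. A sketch of such inputs; the shapes are assumptions, and model stands for an already-constructed TransformerMemNetModel.

import torch

xs = torch.zeros(4, 16, dtype=torch.long)         # context tokens, assumed (batch, seq_len)
mems = torch.zeros(4, 5, 16, dtype=torch.long)    # memory tokens, assumed (batch, n_mems, seq_len)
cands = torch.zeros(4, 20, 16, dtype=torch.long)  # candidates, assumed (batch, n_cands, seq_len)
# With a constructed model, the annotated call would be:
#     context_h, cands_h = model(xs, mems, cands)
# which is declared to return Tuple[torch.LongTensor, torch.LongTensor].
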
11 changes: 7 additions & 4 deletions parlai/agents/transformer/modules/wrappers.py
@@ -3,8 +3,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+import torch
 import torch.nn as nn
+
+from parlai.agents.transformer.modules import TransformerEncoder


 class TransformerResponseWrapper(nn.Module):
     """
@@ -13,7 +16,7 @@ class TransformerResponseWrapper(nn.Module):
     Pushes input through transformer and MLP.
     """

-    def __init__(self, transformer, hdim):
+    def __init__(self, transformer: TransformerEncoder, hdim: int):
         super(TransformerResponseWrapper, self).__init__()
         dim = transformer.out_dim
         self.transformer = transformer
@@ -23,7 +26,7 @@ def __init__(self, transformer, hdim):
             nn.Linear(hdim, dim),
         )

-    def forward(self, *args):
+    def forward(self, *args) -> torch.Tensor:
         """
         Forward pass.
         """
@@ -35,13 +38,13 @@ class TransformerLinearWrapper(nn.Module):
     Wrap a transformer in a linear layer.
     """

-    def __init__(self, transformer, output_dim):
+    def __init__(self, transformer: TransformerEncoder, output_dim: int):
         super().__init__()
         self.transformer = transformer
         input_dim = transformer.out_dim
         self.additional_linear_layer = nn.Linear(input_dim, output_dim)

-    def forward(self, *args):
+    def forward(self, *args) -> torch.Tensor:
         """
         Forward pass.
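
To illustrate the annotated wrapper contract without building a full Opt-configured encoder, the sketch below substitutes a stub for TransformerEncoder. The stub is hypothetical and only provides the out_dim attribute and a tensor-returning forward, which is what the wrapper reads here; a type checker would of course expect a genuine TransformerEncoder.

import torch
import torch.nn as nn

from parlai.agents.transformer.modules import TransformerResponseWrapper


class _StubEncoder(nn.Module):
    """Stand-in for TransformerEncoder, for illustration only."""

    out_dim = 64  # attribute the wrapper reads to size its MLP

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # Pretend encoding: one out_dim-sized vector per batch element.
        return torch.randn(tokens.size(0), self.out_dim)


wrapped = TransformerResponseWrapper(_StubEncoder(), hdim=128)
out = wrapped(torch.zeros(4, 16, dtype=torch.long))  # torch.Tensor of shape (4, 64)
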
