This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Type annotate transformer/modules.py #3545

Merged (2 commits) on Mar 24, 2021
Changes from all commits
20 changes: 16 additions & 4 deletions parlai/agents/transformer/modules/attention.py
@@ -8,7 +8,7 @@
"""

import math
from typing import Dict, Tuple, Optional
from typing import Dict, Tuple, Optional, Union

import torch
import torch.nn as nn
Expand All @@ -22,7 +22,13 @@ class BasicAttention(nn.Module):
Implements simple/classical attention.
"""

def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
def __init__(
self,
dim: int = 1,
attn: str = 'cosine',
residual: bool = False,
get_weights: bool = True,
):
super().__init__()
if attn == 'cosine':
self.cosine = nn.CosineSimilarity(dim=dim)
@@ -31,7 +37,13 @@ def __init__(self, dim=1, attn='cosine', residual=False, get_weights=True):
self.get_weights = get_weights
self.residual = residual

def forward(self, xs, ys, mask_ys=None, values=None):
def forward(
self,
xs: torch.Tensor,
ys: torch.Tensor,
mask_ys: Optional[torch.Tensor] = None,
values: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Compute attention.

@@ -80,7 +92,7 @@ class MultiHeadAttention(nn.Module):
See Vaswani (2017) for an extensive description.
"""

def __init__(self, n_heads, dim, dropout=0):
def __init__(self, n_heads: int, dim: int, dropout: float = 0):
super(MultiHeadAttention, self).__init__()
self.n_heads = n_heads
self.dim = dim
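
For reference, a minimal usage sketch of the newly annotated `BasicAttention` interface. The shapes and the behaviour of `get_weights` are assumptions drawn from the surrounding module rather than from this diff; the `Union` return type exists because `forward` yields an `(output, weights)` pair when `get_weights=True` and a bare tensor otherwise.

```python
# Sketch only: shapes are illustrative assumptions, not part of this diff.
import torch
from parlai.agents.transformer.modules import BasicAttention

xs = torch.randn(2, 1, 8)  # queries: (batch, query_len, emb)
ys = torch.randn(2, 7, 8)  # keys/values: (batch, key_len, emb)

attn = BasicAttention(dim=2, attn='cosine', get_weights=True)
out, weights = attn(xs, ys)   # Tuple[Tensor, Tensor]

attn_no_w = BasicAttention(dim=2, attn='cosine', get_weights=False)
out_only = attn_no_w(xs, ys)  # Tensor
```
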
31 changes: 21 additions & 10 deletions parlai/agents/transformer/modules/decoder.py
@@ -194,7 +194,12 @@ def forward_layers(

return tensor, new_incr_state

def forward(self, input, encoder_state, incr_state=None):
def forward(
self,
input: torch.Tensor,
encoder_state,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

@@ -282,14 +287,14 @@ class TransformerDecoderLayer(nn.Module):

def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
activation='relu',
variant='aiayn',
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: str = 'aiayn',
):
super().__init__()
self.dim = embedding_size
@@ -313,7 +318,13 @@ def __init__(
)
self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)

def forward(self, x, encoder_output, encoder_mask, incr_state=None):
def forward(
self,
x: torch.Tensor,
encoder_output: torch.Tensor,
encoder_mask: torch.Tensor,
incr_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Forward pass.

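
To illustrate what the annotated decoder signature expresses, here is a hypothetical sketch of how `incr_state` is threaded through successive calls during generation. The `decoder`, `tokens`, and `encoder_state` arguments are placeholders, not objects built in this PR; only the `forward` signature is taken from the diff above.

```python
from typing import Dict, Optional, Tuple
import torch

def two_step_decode(
    decoder,                      # a TransformerDecoder instance (not constructed here)
    tokens: torch.LongTensor,
    encoder_state,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    # First call: no cached state yet, so incr_state is None.
    incr_state: Optional[Dict[str, torch.Tensor]] = None
    latent, incr_state = decoder(tokens, encoder_state, incr_state)
    # Later calls pass the returned state back in so earlier positions are not
    # recomputed; in a real loop `tokens` would be the newly generated token.
    latent, incr_state = decoder(tokens, encoder_state, incr_state)
    return latent, incr_state
```
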
18 changes: 9 additions & 9 deletions parlai/agents/transformer/modules/encoder.py
@@ -315,14 +315,14 @@ class TransformerEncoderLayer(nn.Module):

def __init__(
self,
n_heads,
embedding_size,
ffn_size,
attention_dropout=0.0,
relu_dropout=0.0,
dropout=0.0,
activation='relu',
variant=None,
n_heads: int,
embedding_size: int,
ffn_size: int,
attention_dropout: float = 0.0,
relu_dropout: float = 0.0,
dropout: float = 0.0,
activation: str = 'relu',
variant: Optional[str] = None,
):
super().__init__()
self.dim = embedding_size
@@ -342,7 +342,7 @@ def __init__(
self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS)
self.dropout = nn.Dropout(p=dropout)

def forward(self, tensor, mask):
def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
"""
11 changes: 9 additions & 2 deletions parlai/agents/transformer/modules/ffn.py
@@ -7,6 +7,7 @@
Feedforward neural network, as used in transformer implementation.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

@@ -16,7 +17,13 @@ class TransformerFFN(nn.Module):
Implements the FFN part of the transformer.
"""

def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
def __init__(
self,
dim: int,
dim_hidden: int,
relu_dropout: float = 0,
activation: str = 'relu',
):
super(TransformerFFN, self).__init__()
self.relu_dropout = nn.Dropout(p=relu_dropout)
if activation == 'relu':
@@ -33,7 +40,7 @@ def __init__(self, dim, dim_hidden, relu_dropout=0, activation='relu'):
nn.init.xavier_uniform_(self.lin2.weight)
# TODO: initialize biases to 0

def forward(self, x):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass.
"""
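
A minimal sketch of the freshly annotated `TransformerFFN` in isolation. The sizes are arbitrary, and the import path assumes the class is re-exported from `parlai.agents.transformer.modules` like the other components touched in this PR.

```python
import torch
from parlai.agents.transformer.modules import TransformerFFN  # assumed re-export

ffn = TransformerFFN(dim=16, dim_hidden=64, relu_dropout=0.1, activation='relu')
x = torch.randn(2, 5, 16)    # (batch, seq_len, dim)
out = ffn(x)                 # forward(x: Tensor) -> Tensor
assert out.shape == x.shape  # the FFN projects dim -> dim_hidden -> dim
```
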
4 changes: 3 additions & 1 deletion parlai/agents/transformer/modules/generator.py
@@ -27,6 +27,8 @@
TransformerDecoder,
TransformerEncoder,
)
from parlai.core.opt import Opt
from parlai.core.torch_agent import DictionaryAgent
from parlai.core.torch_generator_agent import TorchGeneratorModel
from parlai.utils.torch import neginf

@@ -52,7 +54,7 @@ def build_encoder(
def build_decoder(cls, opt, embedding=None):
return TransformerDecoder(opt=opt, embedding=embedding)

def __init__(self, opt, dictionary):
def __init__(self, opt: Opt, dictionary: DictionaryAgent):
self.pad_idx = dictionary[dictionary.null_token]
self.start_idx = dictionary[dictionary.start_token]
self.end_idx = dictionary[dictionary.end_token]
14 changes: 12 additions & 2 deletions parlai/agents/transformer/modules/mem_net.py
@@ -3,14 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from typing import Optional, Tuple

from parlai.agents.transformer.modules import (
create_embeddings,
BasicAttention,
TransformerEncoder,
TransformerResponseWrapper,
)
from parlai.core.opt import Opt
from parlai.core.torch_agent import DictionaryAgent


class TransformerMemNetModel(nn.Module):
@@ -30,7 +34,7 @@ def build_encoder(
reduction_type=reduction_type,
)

def __init__(self, opt, dictionary):
def __init__(self, opt: Opt, dictionary: DictionaryAgent):
super().__init__()
self.opt = opt
self.pad_idx = dictionary[dictionary.null_token]
@@ -136,7 +140,13 @@ def encode_context_memory(self, context_w, memories_w, context_segments=None):

return weights, context_h

def forward(self, xs, mems, cands, context_segments=None):
def forward(
self,
xs: torch.LongTensor,
mems: torch.LongTensor,
cands: torch.LongTensor,
context_segments: Optional[torch.LongTensor] = None,
) -> Tuple[torch.LongTensor, torch.LongTensor]:
"""
Forward pass.

11 changes: 7 additions & 4 deletions parlai/agents/transformer/modules/wrappers.py
@@ -3,8 +3,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn

from parlai.agents.transformer.modules import TransformerEncoder


class TransformerResponseWrapper(nn.Module):
"""
@@ -13,7 +16,7 @@ class TransformerResponseWrapper(nn.Module):
Pushes input through transformer and MLP.
"""

def __init__(self, transformer, hdim):
def __init__(self, transformer: TransformerEncoder, hdim: int):
super(TransformerResponseWrapper, self).__init__()
dim = transformer.out_dim
self.transformer = transformer
@@ -23,7 +26,7 @@ def __init__(self, transformer, hdim):
nn.Linear(hdim, dim),
)

def forward(self, *args):
def forward(self, *args) -> torch.Tensor:
"""
Forward pass.
"""
@@ -35,13 +38,13 @@ class TransformerLinearWrapper(nn.Module):
Wrap a transformer in a linear layer.
"""

def __init__(self, transformer, output_dim):
def __init__(self, transformer: TransformerEncoder, output_dim: int):
super().__init__()
self.transformer = transformer
input_dim = transformer.out_dim
self.additional_linear_layer = nn.Linear(input_dim, output_dim)

def forward(self, *args):
def forward(self, *args) -> torch.Tensor:
"""
Forward pass.

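
Finally, a hypothetical sketch of what the wrapper annotations buy: the constructor is now documented as expecting a `TransformerEncoder`, so a type checker would flag the stand-in module below even though it satisfies the runtime contract (an `out_dim` attribute and a tensor-returning `forward`). The stand-in encoder and sizes are invented for illustration, and the example assumes, per the docstring, that the wrapper simply pushes its input through the transformer and then the MLP.

```python
import torch
import torch.nn as nn
from parlai.agents.transformer.modules import TransformerResponseWrapper

class StubEncoder(nn.Module):
    """Stand-in with the minimal runtime interface the wrapper relies on."""
    out_dim = 16

    def forward(self, tokens: torch.LongTensor) -> torch.Tensor:
        return torch.randn(tokens.size(0), self.out_dim)

# Runs fine, but mypy would now report that StubEncoder is not a TransformerEncoder.
wrapper = TransformerResponseWrapper(StubEncoder(), hdim=32)
out = wrapper(torch.zeros(4, 10, dtype=torch.long))
print(out.shape)  # expected (4, 16): the MLP maps hdim back to the encoder's out_dim
```
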