diff --git a/parlai/agents/transformer/modules/attention.py b/parlai/agents/transformer/modules/attention.py index 16d5219cd53..693415e3f6b 100644 --- a/parlai/agents/transformer/modules/attention.py +++ b/parlai/agents/transformer/modules/attention.py @@ -14,6 +14,7 @@ import torch.nn as nn import torch.nn.functional as F +from parlai.core.opt import Opt from parlai.utils.torch import neginf @@ -92,8 +93,20 @@ class MultiHeadAttention(nn.Module): See Vaswani (2017) for an extensive description. """ - def __init__(self, n_heads: int, dim: int, dropout: float = 0): + def __init__( + self, opt: Opt, n_heads: int = None, dim: int = None, dropout: float = 0 + ): super(MultiHeadAttention, self).__init__() + + def _default(val, default): + """ + shorthand for explicit None check for optional arguments. + """ + return val if val is not None else default + + n_heads = _default(n_heads, opt['n_heads']) + dim = _default(dim, opt['embedding_size']) + self.n_heads = n_heads self.dim = dim @@ -120,6 +133,7 @@ def forward( # type: ignore mask: torch.Tensor = None, incr_state: Optional[Dict[str, torch.Tensor]] = None, static_kv: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """ Forward pass. diff --git a/parlai/agents/transformer/modules/decoder.py b/parlai/agents/transformer/modules/decoder.py index 4bbcb11e345..7ba4968195a 100644 --- a/parlai/agents/transformer/modules/decoder.py +++ b/parlai/agents/transformer/modules/decoder.py @@ -44,9 +44,10 @@ class TransformerDecoderLayer(nn.Module): def __init__( self, - n_heads: int, - embedding_size: int, - ffn_size: int, + opt: Opt, + n_heads: int = None, + embedding_size: int = None, + ffn_size: int = None, attention_dropout: float = 0.0, relu_dropout: float = 0.0, dropout: float = 0.0, @@ -55,6 +56,18 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + + def _default(val, default): + """ + shorthand for explicit None check for optional arguments. 
+ """ + return val if val is not None else default + + n_heads = _default(n_heads, opt['n_heads']) + embedding_size = _default(embedding_size, opt['embedding_size']) + ffn_size = _default(ffn_size, opt['ffn_size']) + + self.opt = opt self.dim = embedding_size self.ffn_dim = ffn_size self.variant = variant @@ -62,17 +75,21 @@ def __init__( self.dropout = nn.Dropout(p=dropout) self.self_attention = self.swappables.self_attention( - n_heads, embedding_size, dropout=attention_dropout + opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout ) # type: ignore self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) self.encoder_attention = self.swappables.encoder_attention( - n_heads, embedding_size, dropout=attention_dropout + opt=self.opt, n_heads=n_heads, dim=embedding_size, dropout=attention_dropout ) # type: ignore self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) self.ffn = self.swappables.feedforward( - embedding_size, ffn_size, relu_dropout=relu_dropout, activation=activation + opt=self.opt, + dim=embedding_size, + dim_hidden=ffn_size, + relu_dropout=relu_dropout, + activation=activation, ) # type: ignore self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) @@ -106,6 +123,7 @@ def forward( mask=decoder_mask, incr_state=incr_state.get('self_attn'), static_kv=False, + **kwargs, )[:2] x = self.dropout(x) # --dropout x = x + residual @@ -123,6 +141,7 @@ def forward( mask=encoder_mask, incr_state=incr_state.get('encoder_attn'), static_kv=True, + **kwargs, )[:2] x = self.dropout(x) # --dropout x = residual + x @@ -133,7 +152,7 @@ def forward( residual = x if self.variant == 'prelayernorm': x = self.norm3(x) - x = self.ffn(x) + x = self.ffn(x, **kwargs) x = self.dropout(x) # --dropout x = residual + x if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': @@ -195,6 +214,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + self.opt = opt def _default(val, default): return val if val is not None else default @@ -212,8 +232,7 @@ def _default(val, default): self.variant = opt.get('variant', 'aiayn') self.embeddings_scale = opt.get('embeddings_scale', True) - dropout_frac = opt.get('dropout', 0.0) - self.dropout = nn.Dropout(p=dropout_frac) # --dropout + self.dropout = nn.Dropout(p=opt.get('dropout', 0.0)) # --dropout self.n_positions = _default(n_positions, get_n_positions_from_options(opt)) self.out_dim = self.embedding_size @@ -253,26 +272,29 @@ def _default(val, default): ) # build the model - self.layers = nn.ModuleList() + self.layers = self.build_layers() + + def build_layers(self) -> nn.ModuleList: + layers = nn.ModuleList() for _ in range(self.n_layers): - self.layers.append( + layers.append( self.swappables.layer( - self.n_heads, - self.embedding_size, - self.ffn_size, - attention_dropout=opt.get('attention_dropout', 0.0), - relu_dropout=opt.get('relu_dropout', 0.0), - dropout=dropout_frac, + self.opt, + attention_dropout=self.opt.get('attention_dropout', 0.0), + relu_dropout=self.opt.get('relu_dropout', 0.0), + dropout=self.opt.get('dropout', 0.0), activation=self.activation, variant=self.variant, ) # type: ignore ) + return layers def forward_embedding( self, input: torch.LongTensor, positions: Optional[torch.LongTensor] = None, segments: Optional[torch.LongTensor] = None, + **kwargs, ): """ Embed tokens prior to feeding into transformer. 
@@ -311,6 +333,7 @@ def forward_layers( encoder_output: torch.Tensor, encoder_mask: torch.Tensor, incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]], + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Forward pass of decoder layers. @@ -340,6 +363,7 @@ def forward_layers( encoder_output=encoder_output, encoder_mask=encoder_mask, incr_state=incr_state.get(idx), + **kwargs, ) return tensor, new_incr_state @@ -377,12 +401,12 @@ def forward( else: incr_state = {} - tensor = self.forward_embedding(input, positions) + tensor = self.forward_embedding(input, positions, **kwargs) tensor = self.dropout(tensor) # --dropout tensor, new_incr_state = self.forward_layers( - tensor, encoder_output, encoder_mask, incr_state + tensor, encoder_output, encoder_mask, incr_state, **kwargs ) if self.variant == 'prelayernorm': diff --git a/parlai/agents/transformer/modules/encoder.py b/parlai/agents/transformer/modules/encoder.py index 95bc70de081..b79981fd9eb 100644 --- a/parlai/agents/transformer/modules/encoder.py +++ b/parlai/agents/transformer/modules/encoder.py @@ -35,9 +35,10 @@ class TransformerEncoderLayer(nn.Module): def __init__( self, - n_heads: int, - embedding_size: int, - ffn_size: int, + opt: Opt, + n_heads: int = None, + embedding_size: int = None, + ffn_size: int = None, attention_dropout: float = 0.0, relu_dropout: float = 0.0, dropout: float = 0.0, @@ -46,17 +47,33 @@ def __init__( **kwargs, ): super().__init__(**kwargs) + + def _default(val, default): + """ + shorthand for explicit None check for optional arguments. + """ + return val if val is not None else default + + n_heads = _default(n_heads, opt['n_heads']) + embedding_size = _default(embedding_size, opt['embedding_size']) + ffn_size = _default(ffn_size, opt['ffn_size']) + + self.opt = opt self.dim = embedding_size self.ffn_dim = ffn_size self.activation = activation self.variant = variant self.attention = self.swappables.self_attention( # type: ignore - n_heads, embedding_size, dropout=attention_dropout # --attention-dropout + opt=self.opt, + n_heads=n_heads, + dim=embedding_size, + dropout=attention_dropout, # --attention-dropout ) self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) self.ffn = self.swappables.feedforward( # type: ignore - embedding_size, - ffn_size, + opt=self.opt, + dim=embedding_size, + dim_hidden=ffn_size, relu_dropout=relu_dropout, activation=self.activation, ) @@ -145,6 +162,7 @@ def _default(val, default): # this is --dropout, not --relu-dropout or --attention-dropout self.dropout_frac = _default(dropout, opt.get('dropout', 0.0)) self.dropout = nn.Dropout(p=self.dropout_frac) + self.activation = _default(activation, opt.get('activation', 'relu')) self.variant = _default(variant, opt.get('variant', 'aiayn')) self.n_segments = _default(n_segments, opt.get('n_segments', 0)) @@ -203,21 +221,23 @@ def _default(val, default): nn.init.normal_(self.segment_embeddings.weight, 0, self.dim ** -0.5) # build the model - self.layers = nn.ModuleList() + self.layers = self.build_layers() + self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0)) + + def build_layers(self) -> nn.ModuleList: + layers = nn.ModuleList() for _ in range(self.n_layers): - self.layers.append( + layers.append( self.swappables.layer( # type: ignore - self.n_heads, - self.embedding_size, - self.ffn_size, - attention_dropout=opt.get('attention_dropout', 0.0), - relu_dropout=opt.get('relu_dropout', 0.0), + self.opt, + attention_dropout=self.opt.get('attention_dropout', 0.0), + 
relu_dropout=self.opt.get('relu_dropout', 0.0), dropout=self.dropout_frac, variant=self.variant, - activation=_default(activation, opt.get('activation', 'relu')), + activation=self.activation, ) ) - self.output_scaling = _default(output_scaling, opt.get('output_scaling', 1.0)) + return layers def forward_embedding( self, diff --git a/parlai/agents/transformer/modules/ffn.py b/parlai/agents/transformer/modules/ffn.py index 32fe1d74393..3537c38b112 100644 --- a/parlai/agents/transformer/modules/ffn.py +++ b/parlai/agents/transformer/modules/ffn.py @@ -11,6 +11,8 @@ import torch.nn as nn import torch.nn.functional as F +from parlai.core.opt import Opt + class TransformerFFN(nn.Module): """ @@ -19,12 +21,25 @@ class TransformerFFN(nn.Module): def __init__( self, - dim: int, - dim_hidden: int, + opt: Opt, + dim: int = None, + dim_hidden: int = None, relu_dropout: float = 0, activation: str = 'relu', + **kwargs, ): - super(TransformerFFN, self).__init__() + super(TransformerFFN, self).__init__(**kwargs) + + def _default(val, default): + """ + shorthand for explicit None check for optional arguments. + """ + return val if val is not None else default + + dim = _default(dim, opt['embedding_size']) + dim_hidden = _default(dim_hidden, opt['ffn_size']) + + self.opt = opt self.relu_dropout = nn.Dropout(p=relu_dropout) if activation == 'relu': self.nonlinear = F.relu @@ -40,7 +55,7 @@ def __init__( nn.init.xavier_uniform_(self.lin2.weight) # TODO: initialize biases to 0 - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: """ Forward pass. """ diff --git a/projects/params_vs_compute/hash_ladder/hash_ladder.py b/projects/params_vs_compute/hash_ladder/hash_ladder.py index 80d4be37ea6..6a40be11ef3 100644 --- a/projects/params_vs_compute/hash_ladder/hash_ladder.py +++ b/projects/params_vs_compute/hash_ladder/hash_ladder.py @@ -10,20 +10,14 @@ from parlai.agents.transformer.modules import ( TransformerDecoder, - TransformerDecoderLayer, TransformerGeneratorModel, ) -from parlai.agents.transformer.modules import ( - create_position_codes, - get_n_positions_from_options, - LAYER_NORM_EPS, -) +from parlai.agents.transformer.modules import LAYER_NORM_EPS from parlai.agents.transformer.transformer import TransformerGeneratorAgent from parlai.core.opt import Opt from parlai.core.params import ParlaiParser -from parlai.utils.misc import warn_once import torch.nn.functional as F from torch.nn import LayerNorm @@ -43,9 +37,8 @@ class HashLadderAgent(TransformerGeneratorAgent): """ Simple implementation of Hash Layers and the Ladder model from the following papers: - https://arxiv.org/abs/2106.04426 - https://arxiv.org/abs/2106.04279 - + - https://arxiv.org/abs/2106.04426 + - https://arxiv.org/abs/2106.04279 """ @classmethod @@ -76,116 +69,30 @@ def build_model(self, states=None): return wrapped_class(self.opt, self.dict) -def _normalize(tensor, norm_layer): - """ - Broadcast layer norm. - """ - is_cpu = tensor.device == 'cpu' or tensor.device.type == 'cpu' - return norm_layer(tensor) - - class Decoder(TransformerDecoder): """ Custom Decoder with Ladder model. 
""" - def __init__( - self, - opt: Opt, - embedding: Optional[nn.Embedding] = None, - n_positions: Optional[int] = None, - **kwargs, - ): - super().__init__(opt, **kwargs) - - def _default(val, default): - return val if val is not None else default - - opt['dict_size'] = embedding.weight.size(0) - self.opt = opt - self.embedding_size = opt['embedding_size'] - self.ffn_size = opt['ffn_size'] - self.n_layers = ( - opt['n_decoder_layers'] - if opt.get('n_decoder_layers', -1) > 0 - else opt['n_layers'] - ) - self.n_heads = opt['n_heads'] - self.dim = self.embedding_size - self.activation = opt.get('activation', 'relu') - self.variant = opt.get('variant', 'aiayn') - - self.embeddings_scale = opt.get('embeddings_scale', True) - dropout_frac = opt.get('dropout', 0.0) - self.dropout = nn.Dropout(p=dropout_frac) # --dropout - - self.n_positions = _default(n_positions, get_n_positions_from_options(opt)) - self.out_dim = self.embedding_size - assert ( - self.embedding_size % self.n_heads == 0 - ), 'Transformer embedding size must be a multiple of n_heads' - - self.embeddings = embedding - - if ( - self.variant == 'xlm' - or self.variant == 'prelayernorm' - or self.variant == 'bart' - ): - self.norm_embeddings = torch.nn.LayerNorm(self.dim, eps=LAYER_NORM_EPS) - if self.variant == 'xlm': - warn_once( - 'DEPRECATED: XLM should only be used for backwards compatibility, ' - 'as it involves a less-stable layernorm operation.' - ) - elif self.variant == 'aiayn': - pass - else: - raise ValueError("Can't handle --variant {}".format(self.variant)) - - # create the positional embeddings - self.position_embeddings = nn.Embedding(self.n_positions, self.embedding_size) - if not opt.get('learn_positional_embeddings', False): - create_position_codes( - self.n_positions, - self.embedding_size, - out=self.position_embeddings.weight, - ) - else: - nn.init.normal_( - self.position_embeddings.weight, 0, self.embedding_size ** -0.5 - ) - - # build the model - self.layers = nn.ModuleList() + def build_layers(self) -> nn.ModuleList: + # HACK: Adding vocab size to opt for use in HashLayerFFN + self.opt['dict_size'] = self.embeddings.weight.size(0) + layers = nn.ModuleList() for i in range(self.n_layers): + layer_class = self.swappables.layer if self.opt['hash_layer'] == i: - self.layers.append( - HashLayer( - self.n_heads, - self.embedding_size, - self.ffn_size, - attention_dropout=opt.get('attention_dropout', 0.0), - relu_dropout=opt.get('relu_dropout', 0.0), - dropout=dropout_frac, - activation=self.activation, - variant=self.variant, - opt=self.opt, - ) # type: ignore - ) - else: - self.layers.append( - self.swappables.layer( - self.n_heads, - self.embedding_size, - self.ffn_size, - attention_dropout=opt.get('attention_dropout', 0.0), - relu_dropout=opt.get('relu_dropout', 0.0), - dropout=dropout_frac, - activation=self.activation, - variant=self.variant, - ) # type: ignore - ) + layer_class = layer_class.with_components(feedforward=HashLayerFFN) + layers.append( + layer_class( + self.opt, + attention_dropout=self.opt.get('attention_dropout', 0.0), + relu_dropout=self.opt.get('relu_dropout', 0.0), + dropout=self.opt.get('dropout', 0.0), + activation=self.activation, + variant=self.variant, + ) # type: ignore + ) + return layers def forward_layers( self, @@ -193,196 +100,35 @@ def forward_layers( encoder_output: torch.Tensor, encoder_mask: torch.Tensor, incr_state: Dict[int, Dict[str, Dict[str, torch.Tensor]]], - original_input: torch.Tensor, + **kwargs, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ - Forward 
pass of decoder layers. - - :param tensor: - embedded input tensor for the decoder - :param enc_out: - encoder outputs - :param enc_mask: - encoder output mask - :param incr_state: - Dict mapping layer_idx to incremental state - - :return (tensor, new_incr_state): - return encoding after applying decoder layers, as well - as new incremental decoding state. + Override of forward_layers of TransformerDecoder. """ new_incr_state = {} - if getattr(self.layers, 'is_model_parallel', False): - tensor, new_incr_state = self._apply_model_parallel( - tensor, encoder_output, encoder_mask, incr_state + for _s in range(0, self.opt['ladder_size']): + tensor, new_incr_state = super().forward_layers( + tensor=tensor, + encoder_output=encoder_output, + encoder_mask=encoder_mask, + incr_state=incr_state, + **kwargs, ) - else: - for _s in range(0, self.opt['ladder_size']): - for idx, layer in enumerate(self.layers): - if idx == self.opt['hash_layer']: - tensor, new_incr_state[idx] = layer( - x=tensor, - encoder_output=encoder_output, - encoder_mask=encoder_mask, - incr_state=incr_state.get(idx), - orig_input=original_input, - ) - else: - tensor, new_incr_state[idx] = layer( - x=tensor, - encoder_output=encoder_output, - encoder_mask=encoder_mask, - incr_state=incr_state.get(idx), - ) - return tensor, new_incr_state - def forward(self, input, encoder_state, incr_state=None): - """ - Forward pass. - - :param LongTensor[batch,seqlen] input: - The decoder inputs (partial or full decoded token IDs). - :param encoder_state: - Output from the encoder module forward pass. - :param incr_state: - The incremental state: a dictionary whose keys index the layers and whose - values contain the incremental state for each layer. - """ - encoder_output, encoder_mask = encoder_state - - seq_len = input.size(1) - positions = input.new(seq_len).long() - positions = torch.arange(seq_len, out=positions).unsqueeze(0) - - if incr_state is not None: - # We're doing incremental decoding, so select only the most recent position - input = input[:, -1:] - if positions is not None: - positions = positions[:, -1:] - else: - incr_state = {} - - tensor = self.forward_embedding(input, positions) - - tensor = self.dropout(tensor) # --dropout - - tensor, new_incr_state = self.forward_layers( - tensor, encoder_output, encoder_mask, incr_state, original_input=input - ) - - if self.variant == 'prelayernorm': - tensor = _normalize(tensor, self.norm_embeddings) - - return tensor, new_incr_state - - -class HashLayer(TransformerDecoderLayer): - def __init__( + def forward( self, - n_heads: int, - embedding_size: int, - ffn_size: int, - opt: Opt, - attention_dropout: float = 0.0, - relu_dropout: float = 0.0, - dropout: float = 0.0, - activation: str = 'relu', - variant: str = 'aiayn', + input: torch.Tensor, + encoder_state, + incr_state: Optional[Dict[str, torch.Tensor]] = None, **kwargs, - ): - super().__init__(n_heads, embedding_size, ffn_size, **kwargs) - self.dim = embedding_size - self.ffn_dim = ffn_size - self.variant = variant - self.activation = activation - self.dropout = nn.Dropout(p=dropout) - - self.self_attention = self.swappables.self_attention( - n_heads, embedding_size, dropout=attention_dropout - ) # type: ignore - self.norm1 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - self.encoder_attention = self.swappables.encoder_attention( - n_heads, embedding_size, dropout=attention_dropout - ) # type: ignore - self.norm2 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - self.ffn = HashLayerFFN( - opt, - 
embedding_size, - ffn_size, - relu_dropout=relu_dropout, - activation=activation, - ) # type: ignore - self.norm3 = torch.nn.LayerNorm(embedding_size, eps=LAYER_NORM_EPS) - - def forward( - self, x, encoder_output, encoder_mask, incr_state=None, orig_input=None - ): + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ - Forward pass. - - The incremental state is a dict with values for self- and encoder-attention - states. + Overrides TransformerDecoder forward. """ - - if incr_state is None: - incr_state = {} - - decoder_mask = self._create_selfattn_mask(x) - # first self attn - residual = x - if self.variant == 'prelayernorm': - x = _normalize(x, self.norm1) - - # don't peak into the future! - x, final_self_attn_incr_state = self.self_attention( - query=x, - mask=decoder_mask, - incr_state=incr_state.get('self_attn'), - static_kv=False, - )[:2] - x = self.dropout(x) # --dropout - x = x + residual - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = _normalize(x, self.norm1) - - residual = x - # encoder_attn_layer_norm norm 2 - if self.variant == 'prelayernorm': - x = _normalize(x, self.norm2) - x, final_encoder_attn_incr_state, dotprod = self.encoder_attention( - query=x, - key=encoder_output, - value=encoder_output, - mask=encoder_mask, - incr_state=incr_state.get('encoder_attn'), - static_kv=True, + return super().forward( + input, encoder_state, incr_state=incr_state, orig_input=input, **kwargs ) - x = self.dropout(x) # --dropout - x = residual + x - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = _normalize(x, self.norm2) - - # finally the ffn - residual = x - if self.variant == 'prelayernorm': - x = _normalize(x, self.norm3) - x = self.ffn(x, orig_input) - x = self.dropout(x) # --dropout - x = residual + x - if self.variant == 'aiayn' or self.variant == 'xlm' or self.variant == 'bart': - x = _normalize(x, self.norm3) - - new_incr_state = { - 'self_attn': final_self_attn_incr_state, - 'encoder_attn': final_encoder_attn_incr_state, - } - - self.output = x - - return x, new_incr_state class HashLayerFFN(nn.Module): @@ -458,7 +204,7 @@ def forward(self, x, orig_input): x1 = self.relu_dropout(x1) # --relu-dropout x1 = self.linears2[i](x1) x1 = residual + x1 - x1 = _normalize(x1, self.norms[0]) + x1 = self.norms[0](x1) final_output[index_list[i][0], index_list[i][1], :] = x1 return final_output
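
A minimal usage sketch (not part of the patch above) of the reworked constructor pattern: submodules now take `opt` first and fall back to `opt['embedding_size']` / `opt['ffn_size']` when the explicit keywords are omitted, and extra keyword arguments passed to a layer's `forward` are threaded down to its components. The `Opt` values, `ScaledFFN`, and the `scale` keyword below are hypothetical illustrations; only the opt-first defaults and the `**kwargs` threading come from the changes in this diff.

    # Assumes ParlAI is installed; values below are made up for illustration.
    import torch

    from parlai.core.opt import Opt
    from parlai.agents.transformer.modules.ffn import TransformerFFN

    opt = Opt({'embedding_size': 16, 'ffn_size': 64})

    # dim / dim_hidden now default from opt; explicit keywords still override them.
    ffn = TransformerFFN(opt, relu_dropout=0.1)
    out = ffn(torch.randn(4, 8, opt['embedding_size']))  # -> shape (4, 8, 16)


    class ScaledFFN(TransformerFFN):
        """
        Hypothetical feedforward that consumes an extra keyword argument,
        which TransformerDecoderLayer.forward now passes through via **kwargs.
        """

        def forward(self, x: torch.Tensor, scale: float = 1.0, **kwargs) -> torch.Tensor:
            return super().forward(x) * scale

To plug a component like this in, the patch's own pattern applies: override `build_layers()` and construct the layer class via `self.swappables.layer.with_components(feedforward=ScaledFFN)`, as the hash_ladder `Decoder.build_layers` does with `HashLayerFFN`.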