Add padding_masks and tests for T5Model (#1935)
Summary:
Pull Request resolved: #1935

Added the following parameters to the `forward` method of `T5Model`:
* `encoder_padding_mask`
* `decoder_padding_mask`

These allow users to explicitly mask out padding positions in the encoder and decoder input sequences, matching the Transformer implementation in PyTorch core.
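
For reference, a minimal usage sketch mirroring the new unit tests in this diff (the tiny configuration values are illustrative only, and padding index 0 is what the tests assume):

```python
import torch
from torchtext.prototype.models import T5Conf, T5Model

# Tiny config for illustration, matching the shapes used in the new tests.
conf = T5Conf(
    encoder_only=False,
    linear_head=True,
    vocab_size=100,
    embedding_dim=16,
    ffn_dimension=64,
    num_attention_heads=2,
    num_encoder_layers=2,
    num_decoder_layers=2,
    training=False,
)
model = T5Model(conf)

enc_tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
dec_tokens = torch.tensor([[6, 7, 8, 9, 10, 11, 0, 0, 0, 0]])

# Boolean padding masks of shape (B, N): True marks padding positions (index 0 here).
enc_padding_mask = enc_tokens.eq(0)
dec_padding_mask = dec_tokens.eq(0)

with torch.no_grad():
    out = model(
        encoder_tokens=enc_tokens,
        encoder_padding_mask=enc_padding_mask,
        decoder_tokens=dec_tokens,
        decoder_padding_mask=dec_padding_mask,
    )
decoder_output = out["decoder_output"]
```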

Reviewed By: Nayef211

Differential Revision: D40252794

fbshipit-source-id: 0e0a17fdc97ae0bbcaa1aef91e9914fd6225456b
Joe Cummings authored and facebook-github-bot committed Oct 17, 2022
1 parent bac09bf commit aa369b8
Showing 2 changed files with 142 additions and 33 deletions.
87 changes: 75 additions & 12 deletions test/torchtext_unittest/prototype/models/test_models.py
@@ -3,13 +3,12 @@

import torch
from torch.nn import functional as F
from torchtext.prototype.models import T5Bundle, T5Conf, T5Model
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase


class TestModels(TorchtextTestCase):
def test_t5_bundler_build_model(self) -> None:
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

# case: user provides encoder checkpoint state dict
dummy_encoder_conf = T5Conf(
encoder_only=True,
@@ -21,7 +20,9 @@ def test_t5_bundler_build_model(self) -> None:
num_decoder_layers=2,
)
dummy_t5_encoder = T5Model(dummy_encoder_conf)
t5_encoder_model = T5Bundle.build_model(config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict())
t5_encoder_model = T5Bundle.build_model(
config=dummy_encoder_conf, checkpoint=dummy_t5_encoder.state_dict()
)
self.assertEqual(t5_encoder_model.state_dict(), dummy_t5_encoder.state_dict())

# case: user provides encoder-decoder checkpoint state dict
@@ -35,7 +36,9 @@ def test_t5_bundler_build_model(self) -> None:
num_decoder_layers=2,
)
dummy_t5 = T5Model(dummy_t5_conf)
t5_model = T5Bundle.build_model(config=dummy_t5_conf, checkpoint=dummy_t5.state_dict())
t5_model = T5Bundle.build_model(
config=dummy_t5_conf, checkpoint=dummy_t5.state_dict()
)
self.assertEqual(t5_model.state_dict(), dummy_t5.state_dict())

# case: user provides checkpoint state dict for encoder-decoder with generation
@@ -53,12 +56,12 @@ def test_t5_bundler_build_model(self) -> None:
t5_generation_model = T5Bundle.build_model(
config=dummy_t5_generation_conf, checkpoint=dummy_t5_generation.state_dict()
)
self.assertEqual(t5_generation_model.state_dict(), dummy_t5_generation.state_dict())
self.assertEqual(
t5_generation_model.state_dict(), dummy_t5_generation.state_dict()
)

@patch("logging.Logger.warning")
def test_t5_bundler_get_model(self, mock):
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-decoder with generation
dummy_t5_generation_conf = T5Conf(
encoder_only=False,
@@ -77,8 +80,6 @@ def test_t5_bundler_get_model(self, mock):
)

def test_t5_bundler_raise_checkpoint(self) -> None:
from torchtext.prototype.models import T5Conf, T5Bundle

# encoder-only
with self.assertRaises(TypeError):
dummy_encoder_conf = T5Conf(
@@ -132,8 +133,6 @@ def test_t5_bundler_raise_checkpoint(self) -> None:
)

def test_t5_bundler_conf_property(self) -> None:
from torchtext.prototype.models import T5Conf, T5Bundle

dummy_t5_conf = T5Conf(
encoder_only=False,
vocab_size=10,
@@ -148,7 +147,6 @@ def test_t5_bundler_conf_property(self) -> None:

def test_t5_bundler_train(self) -> None:
from torch.optim import SGD
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

def _train(model):
optim = SGD(model.parameters(), lr=1)
@@ -181,3 +179,68 @@ def _train(model):

_train(model)
self.assertNotEqual(model.state_dict(), current_state_dict)

def test_t5_model_forward_with_encoder_mask_encoder_only(self) -> None:
dummy_conf = T5Conf(
encoder_only=True,
linear_head=True,
vocab_size=100,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
training=False,
)
dummy_model = T5Model(dummy_conf)
tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
mask = tokens.eq(0)

with torch.no_grad():
output_with_mask = dummy_model(
encoder_tokens=tokens, encoder_padding_mask=mask
)
output_no_mask = dummy_model(tokens)

torch.testing.assert_close(
output_with_mask["encoder_output"],
output_no_mask["encoder_output"],
atol=1e-04,
rtol=2.5e-06,
)

def test_t5_model_forward_with_encoder_mask_encoder_decoder(self) -> None:
dummy_conf = T5Conf(
encoder_only=False,
linear_head=True,
vocab_size=100,
embedding_dim=16,
ffn_dimension=64,
num_attention_heads=2,
num_encoder_layers=2,
num_decoder_layers=2,
training=False,
)
dummy_model = T5Model(dummy_conf)
enc_tokens = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0, 0, 0]])
dec_tokens = torch.tensor([[6, 7, 8, 9, 10, 11, 0, 0, 0, 0]])
enc_mask = enc_tokens.eq(0)
dec_mask = dec_tokens.eq(0)

with torch.no_grad():
output_with_mask = dummy_model(
encoder_tokens=enc_tokens,
encoder_padding_mask=enc_mask,
decoder_tokens=dec_tokens,
decoder_padding_mask=dec_mask,
)
output_no_mask = dummy_model(
encoder_tokens=enc_tokens, decoder_tokens=dec_tokens
)

torch.testing.assert_close(
output_with_mask["decoder_output"],
output_no_mask["decoder_output"],
atol=1e-04,
rtol=2.5e-06,
)
88 changes: 67 additions & 21 deletions torchtext/prototype/models/t5/model.py
@@ -1,11 +1,11 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Union, Callable
from typing import Callable, Dict, List, Optional, Union

import torch
import torch.nn as nn
from torch import Tensor

from .modules import T5Encoder, T5Decoder, T5LayerNorm
from .modules import T5Decoder, T5Encoder, T5LayerNorm


@dataclass
@@ -88,7 +88,9 @@ def __init__(
self.device = device
self.dtype = dtype

self.token_embeddings = nn.Embedding(config.vocab_size, config.embedding_dim, config.padding_idx)
self.token_embeddings = nn.Embedding(
config.vocab_size, config.embedding_dim, config.padding_idx
)
self.encoder = T5Encoder(
d_model=config.embedding_dim,
nhead=config.num_attention_heads,
@@ -129,7 +131,9 @@ def __init__(
self.decoder = None

if config.linear_head:
self.lm_head = nn.Linear(config.embedding_dim, config.vocab_size, bias=False)
self.lm_head = nn.Linear(
config.embedding_dim, config.vocab_size, bias=False
)
else:
self.lm_head = None

@@ -140,23 +144,31 @@ def __init__(
def forward(
self,
encoder_tokens: Tensor,
decoder_tokens: Optional[Tensor] = None,
encoder_mask: Optional[Tensor] = None,
encoder_padding_mask: Optional[Tensor] = None,
decoder_tokens: Optional[Tensor] = None,
decoder_mask: Optional[Tensor] = None,
) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
decoder_padding_mask: Optional[Tensor] = None,
) -> Dict[
str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]
]:
r"""Pass the inputs (and mask) through the decoder layer in turn.
Args:
encoder_tokens: Tokenized input sequence to the encoder.
Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
encoder input sequence length. (required).
encoder_mask: Additive mask for the encoder input sequence.
Must have shape (Ne, Ne) (optional).
encoder_padding_mask: Padding mask for encoder input sequence.
Must have shape (B, Ne) (optional).
decoder_tokens: Tokenized input sequence to the decoder.
Must be batch first with shape (B, Nd) where B is the batch size and Nd is the
decoder input sequence length. If None and the model is an encoder-decoder, the decoder
input sequence will be initialized to begin with the padding index. (optional).
encoder_mask: Self-attention mask for the encoder input sequence.
Must have shape (Ne, Ne) (optional).
decoder_mask: Self-attention mask for the decoder input sequence.
decoder_mask: Additive mask for the decoder input sequence.
Must have shape (Nd, Nd) (optional).
decoder_padding_mask: Padding mask for decoder input sequence.
Must have shape (B, Nd) (optional).
Returns:
encoder_output: Output Tensor from the final layer of the encoder
encoder_hidden_states: Tuple of output Tensors from each layer of the encoder
@@ -168,10 +180,24 @@ def forward(
encoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder
encoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder
"""
encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
if encoder_padding_mask is None:
encoder_padding_mask = encoder_tokens.eq(self.padding_idx)

batch_size = encoder_tokens.shape[0]
seq_len = encoder_tokens.shape[1]

assert encoder_padding_mask.shape == (batch_size, seq_len)

encoder_embeddings = self.dropout1(self.token_embeddings(encoder_tokens))
encoder_output, encoder_hidden_states, encoder_position_bias, encoder_sa = self.encoder(
encoder_embeddings, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
(
encoder_output,
encoder_hidden_states,
encoder_position_bias,
encoder_sa,
) = self.encoder(
encoder_embeddings,
tgt_mask=encoder_mask,
tgt_key_padding_mask=encoder_padding_mask,
)

encoder_output = self.norm1(encoder_output)
@@ -184,20 +210,34 @@

# decoder_tokens is None at the start of inference, in which case the decoder sequence should begin with the padding idx.
if decoder_tokens is None:
decoder_tokens = torch.ones((encoder_tokens.size(0), 1), dtype=torch.long) * self.padding_idx
decoder_tokens = (
torch.ones((encoder_tokens.size(0), 1), dtype=torch.long)
* self.padding_idx
)

tgt_seq_len = decoder_tokens.shape[1]

if decoder_mask is None:
assert decoder_tokens is not None and decoder_tokens.dim() == 2
tgt_len = decoder_tokens.shape[1]
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
decoder_mask = torch.triu(
torch.ones((tgt_seq_len, tgt_seq_len), dtype=torch.float64),
diagonal=1,
)
decoder_mask = decoder_mask.to(torch.bool)

decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
# The T5 implementation uses the padding idx to start the sequence; ignore this position when masking
decoder_padding_mask[:, 0] = False
if decoder_padding_mask is None:
decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
# The T5 implementation uses the padding idx to start the sequence; ignore this position when masking
decoder_padding_mask[:, 0] = False

decoder_embeddings = self.dropout3(self.token_embeddings(decoder_tokens))
decoder_output, decoder_hidden_states, decoder_position_bias, decoder_sa, decoder_ca = self.decoder(
(
decoder_output,
decoder_hidden_states,
decoder_position_bias,
decoder_sa,
decoder_ca,
) = self.decoder(
decoder_embeddings,
memory=encoder_output,
tgt_mask=decoder_mask,
@@ -215,7 +255,7 @@ def forward(
# Rescale output before projecting on vocab. This happens when the encoder and decoder share the
# same word embeddings, which is always the case in our t5 implementation.
# See https://github.com/huggingface/transformers/blob/d0acc9537829e7d067edbb791473bbceb2ecf056/src/transformers/models/t5/modeling_t5.py#L1661
decoder_output = decoder_output * (self.embedding_dim ** -0.5)
decoder_output = decoder_output * (self.embedding_dim**-0.5)
decoder_output = self.lm_head(decoder_output)

t5_output = {
@@ -238,7 +278,13 @@ def forward(
}

assert torch.jit.isinstance(
t5_output, Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
t5_output,
Dict[
str,
Union[
Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]
],
],
)

return t5_output
