Optimize T5 for sequence generation (#2054)
* Separate encoding/decoding logic for T5 model in preparation for generation

* Implement incremental decoding capabilities for T5

* Remove T5Wrapper

* TorchScript changes

* Linting fixes

* Define reusable types

* Add docstring to reorder cache

* Update licenses

* Linting fixes (round 2)

* Add license to t5_transform
joecummings authored Feb 17, 2023
1 parent cbaabea commit 19f8bc9
Showing 7 changed files with 516 additions and 587 deletions.
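
For context on the cache handling mentioned in the commit message, the sketch below illustrates what the _reorder_cache method added to model.py (further down in this diff) does during beam search: every cached key/value tensor is gathered along its batch dimension so that each beam keeps the cache of the hypothesis it was expanded from. This is an editor-added illustration, not part of the commit; the tensor shapes are assumptions.

import torch

# Hypothetical cache for one decoder layer: four key/value states
# (self-attention K/V and cross-attention K/V), assumed shape (batch, heads, seq, dim).
layer_past = tuple(torch.randn(4, 8, 5, 64) for _ in range(4))
past = [layer_past]

# Hypothetical beam reassignment: beams 0 and 1 both continue hypothesis 0, and so on.
beam_idx = torch.tensor([0, 0, 2, 1])

reordered_past = []
for layer_past_states in past:
    # Gather along the batch dimension so each beam inherits the matching cache.
    reordered_past.append(
        tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past_states)
    )

assert reordered_past[0][0].shape == past[0][0].shape
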
65 changes: 2 additions & 63 deletions test/integration_tests/prototype/test_models.py
@@ -2,7 +2,7 @@

import pytest # noqa: F401
import torch
from parameterized import parameterized, parameterized_class
from parameterized import parameterized_class
from torchtext.prototype.models import (
T5_BASE,
T5_BASE_ENCODER,
@@ -13,11 +13,8 @@
T5_SMALL,
T5_SMALL_ENCODER,
T5_SMALL_GENERATION,
T5Conf,
T5Transform,
)
from torchtext.prototype.models.t5.bundler import T5Bundle
from torchtext.prototype.models.t5.wrapper import T5Wrapper
from torchtext_unittest.common.assets import get_asset_path
from torchtext_unittest.common.parameterized_utils import nested_params
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
@@ -79,7 +76,7 @@ def _t5_get_encoder(self, model, model_input, encoder_output):
encoder = model.get_encoder()
# Need to set the tgt_key_padding_mask to ensure the same results
encoder_padding_mask = model_input.eq(model.padding_idx)
output_from_get_encoder = encoder(tgt=model_input, tgt_key_padding_mask=encoder_padding_mask)["encoder_output"]
output_from_get_encoder = encoder(model_input, src_key_padding_mask=encoder_padding_mask)["encoder_output"]
assert torch.all(output_from_get_encoder.eq(encoder_output))

@nested_params(["jit", "not_jit"])
@@ -93,64 +90,6 @@ def test_t5_model(self, name) -> None:
self._t5_model(is_jit=is_jit, t5_model=t5_model, expected_asset_name=expected_asset_name, test_text=test_text)


@parameterized_class(
("configuration",),
[
("small",),
("base",),
("large",),
],
)
class TestT5Wrapper(TorchtextTestCase):
# No longer Torchscriptable
@parameterized.expand(["no_jit"])
def test_t5_wrapper(self, name) -> None:
configuration = self.configuration
test_text = ["translate English to French: I want to eat pizza for dinner."]
if configuration == "small":
expected_text = ["Je veux manger la pizza pour le dîner."]
else:
expected_text = ["Je veux manger de la pizza pour le dîner."]

beam_size = 3
max_seq_len = 512
model = T5Wrapper(configuration=configuration, strict=False)
if name == "jit":
model = torch.jit.script(model)

output_text = model(test_text, beam_size, max_seq_len)
self.assertEqual(output_text, expected_text)


class TestT5WrapperCheckpoint(TorchtextTestCase):
# No longer Torchscriptable
@parameterized.expand(["no_jit"])
def test_t5_wrapper_checkpoint(self, name) -> None:
test_text = ["translate English to French: I want to eat pizza for dinner."]
expected_text = ["Je veux manger de la pizza pour le dîner."]
beam_size = 3
max_seq_len = 512
config = T5Conf(encoder_only=False, linear_head=True)
transform = T5Transform(
"https://download.pytorch.org/models/text/t5_tokenizer_base.model",
max_seq_len=512,
eos_idx=1,
padding_idx=0,
)
model = T5Wrapper(
checkpoint="https://download.pytorch.org/models/text/t5.base.generation.v2.pt",
t5_config=config,
transform=transform,
freeze_model=True,
strict=True,
)
if name == "jit":
model = torch.jit.script(model)

output_text = model(test_text, beam_size, max_seq_len)
self.assertEqual(output_text, expected_text)


class TestLoadFromHFCheckpoints(TorchtextTestCase):
def setUp(self) -> None:
super().setUp()
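
The updated _t5_get_encoder helper above now calls the standalone encoder with src_key_padding_mask instead of the old tgt* keywords. Below is a minimal sketch of the equivalence it checks, assuming a bundle such as T5_SMALL exposes get_model() as suggested by the imports above; this is an editor-added illustration, not part of the commit.

import torch
from torchtext.prototype.models import T5_SMALL

model = T5_SMALL.get_model()
model.eval()

tokens = torch.tensor([[13959, 1566, 12, 2379, 10, 1]])  # arbitrary token ids

with torch.no_grad():
    # A full forward pass returns the encoder output alongside the decoder outputs.
    full_encoder_output = model(encoder_tokens=tokens)["encoder_output"]

    # Running the standalone encoder with the new keyword should reproduce it.
    encoder = model.get_encoder()
    padding_mask = tokens.eq(model.padding_idx)
    standalone_output = encoder(tokens, src_key_padding_mask=padding_mask)["encoder_output"]

assert torch.all(full_encoder_output.eq(standalone_output))
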
2 changes: 2 additions & 0 deletions test/torchtext_unittest/prototype/models/models_test_impl.py
@@ -150,6 +150,8 @@ def test_t5_bundler_train(self) -> None:
from torch.optim import SGD
from torchtext.prototype.models import T5Conf, T5Model, T5Bundle

torch.manual_seed(123)

def _train(model):
optim = SGD(model.parameters(), lr=1)
model_input = torch.tensor([[1, 2, 3, 4, 5]]).to(device=self.device)
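
The seed added above makes the randomly initialized model in this training test deterministic. A minimal sketch of the effect, assuming T5Conf/T5Model construct a randomly initialized model as in their docstring example (editor illustration, not part of the commit):

import torch
from torchtext.prototype.models import T5Conf, T5Model

def build_model():
    # Re-seeding before construction makes the random weight initialization repeatable.
    torch.manual_seed(123)
    return T5Model(T5Conf(encoder_only=True))

model_a, model_b = build_model(), build_model()
assert all(
    torch.equal(p, q)
    for p, q in zip(model_a.state_dict().values(), model_b.state_dict().values())
)
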
10 changes: 10 additions & 0 deletions torchtext/prototype/models/t5/bundler.py
@@ -1,3 +1,13 @@
# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. */
import json
import logging
import os
119 changes: 86 additions & 33 deletions torchtext/prototype/models/t5/model.py
@@ -1,4 +1,13 @@
# logging library is not automatically supported by Torchscript
# /* Portions Copyright (c) Meta Platforms, Inc. and affiliates.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. */
import warnings
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Union
@@ -7,7 +16,7 @@
import torch.nn as nn
from torch import Tensor

from .modules import T5Decoder, T5Encoder
from .modules import DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE, PAST_KEY_VALUES_TYPE, T5Decoder, T5Encoder


@dataclass
@@ -36,9 +45,8 @@ def __post_init__(self):
"""The following is modified from:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/configuration_t5.py
It's to support T5 1.1 and FLAN-T5.
Supports T5 1.1 and FLAN-T5.
"""

if self.feed_forward_proj:
act_info = self.feed_forward_proj.split("-")
self.activation = act_info[-1]
@@ -56,13 +64,13 @@ def __post_init__(self):
self.activation = "gelu_new"


# NOTE: Comparable HuggingFace implementation can be found at https://github.com/huggingface/transformers/blob/8581a798c0a48fca07b29ce2ca2ef55adcae8c7e/src/transformers/models/t5/modeling_t5.py#L1269
class T5Model(nn.Module):
r"""A T5 model. User is able to modify the attributes as needed. The architecture
is based on the paper "Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer".
Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena,
Yanqi Zhou, Wei Li, and Peter J. Liu. 2020. Journal of Machine Learning Research.
Volume 21 Issue 140 pages 1-67. http://jmlr.org/papers/v21/20-074.html
Args:
config.encoder_only: Whether or not model should consist of only the encoder as opposed to encoder-decoder (default=False).
config.linear_head: Whether or not a linear layer should be used to project the output of the decoder's last layer to the vocab (default=False).
@@ -84,8 +92,9 @@ class T5Model(nn.Module):
config.vocab_size: Size of vocabulary (default: 32128)
config.training: Whether or not to apply dropout (default: False)
freeze: Indicates whether or not to freeze the model weights. (default: False)
Examples:
>>> from torchtext.prototype.models import T5Conf, T5Model
>>> from torchtext.models import T5Conf, T5Model
>>> t5_config = T5Conf(encoder_only=False, linear_head=True)
>>> t5_model = T5Model(t5_config)
>>> encoder_input = torch.randint(0, t5_config.vocab_size, (32, 512))
@@ -160,14 +169,52 @@ def __init__(
for p in self.parameters():
p.requires_grad = False

def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}
@torch.jit.export
def _reorder_cache(self, past: List[PAST_KEY_VALUES_TYPE], beam_idx: Tensor) -> List[PAST_KEY_VALUES_TYPE]:
"""Reorder past key value pairs in cache. Only relevant in incremental decoding with beam search generation."""
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past is None:
return past

reordered_decoder_past: List[PAST_KEY_VALUES_TYPE] = []
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (
layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
)

assert len(reordered_layer_past_states) == len(layer_past_states)

reordered_decoder_past.append(reordered_layer_past_states)
return reordered_decoder_past

@torch.jit.export
def prepare_inputs_for_generation(
self,
input_ids: Tensor,
encoder_outputs: ENCODER_OUTPUTS_TYPE,
past: Optional[List[PAST_KEY_VALUES_TYPE]] = None,
return_past_key_values: bool = True,
) -> Dict[str, Union[Tensor, ENCODER_OUTPUTS_TYPE, Optional[List[PAST_KEY_VALUES_TYPE]], bool]]:
# Incremental decoding if past key values are provided
if past is not None:
input_ids = input_ids[:, -1:]

return {
"decoder_tokens": input_ids,
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"return_past_key_values": return_past_key_values,
}

@torch.jit.ignore
def get_encoder(self) -> T5Encoder:
return self.encoder

@torch.jit.ignore
def get_decoder(self) -> Optional[T5Decoder]:
if self.decoder is None:
warnings.warn("Decoder is not set on this model.")
@@ -181,15 +228,16 @@ def forward(
decoder_mask: Optional[Tensor] = None,
encoder_padding_mask: Optional[Tensor] = None,
decoder_padding_mask: Optional[Tensor] = None,
encoder_outputs: Optional[
Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]
] = None,
) -> Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]]:
r"""Pass the inputs (and mask) through the decoder layer in turn.
encoder_outputs: Optional[ENCODER_OUTPUTS_TYPE] = None,
past_key_values: Optional[List[PAST_KEY_VALUES_TYPE]] = None,
return_past_key_values: bool = False,
) -> Union[DECODER_OUTPUTS_TYPE, ENCODER_OUTPUTS_TYPE]:
r"""Pass the inputs (and mask) through the T5Encoder/T5Decoder in turn.
Args:
encoder_tokens: Tokenized input sequence to the encoder.
Must be batch first with shape (B, Ne) where B is the batch size and Ne is the
encoder input sequence length. (required).
encoder input sequence length. (optional if `encoder_outputs` is provided)
decoder_tokens: Tokenized input sequence to the decoder.
Must be batch first with shape (B, Nd) where B is the batch size and Nd is the
decoder input sequence length. If None and model is encoder-decoder, will initialize decoder
@@ -198,6 +246,14 @@
Must have shape (Ne, Ne) (optional).
decoder_mask: Self-attention mask for the decoder input sequence.
Must have shape (Nd, Nd) (optional).
encoder_padding_mask: Padding mask for encoder input sequence.
Must have shape (B, Ne) (optional).
decoder_padding_mask: Padding mask for decoder input sequence.
Must have shape (B, Nd) (optional).
encoder_outputs: Outputs from previous run of T5Encoder. (optional)
past_key_values: Previously calculated key values, used in incremental decoding. (optional)
return_past_key_values: Boolean indicating whether to return key values to user. (default: False)
Returns:
encoder_output: Output Tensor from the final layer of the encoder
encoder_hidden_states: Tuple of output Tensors from each layer of the encoder
@@ -208,15 +264,16 @@
decoder_position_bias: Tensor of relative attention bias computed for input sequence to decoder
encoder_sa_scores: Tuple of self-attention scores computed at each layer of the decoder
encoder_ca_scores: Tuple of cross-attention scores computed at each layer of the decoder
past_key_values: List of Tuples of key values calculated during this run, or None.
"""
if encoder_outputs is None:
assert encoder_tokens is not None, "If `encoder_outputs` is not specified, must provide `encoder_tokens`"

if encoder_padding_mask is None:
encoder_padding_mask = encoder_tokens.eq(self.padding_idx)
encoder_padding_mask = encoder_tokens.eq(self.padding_idx).to(device=encoder_tokens.device)

encoder_outputs = self.encoder(
tgt=encoder_tokens, tgt_mask=encoder_mask, tgt_key_padding_mask=encoder_padding_mask
src=encoder_tokens, mask=encoder_mask, src_key_padding_mask=encoder_padding_mask
)

if not self.encoder_only:
@@ -226,20 +283,15 @@ def forward(
encoder_output = encoder_outputs.get("encoder_output")
assert torch.jit.isinstance(encoder_output, Tensor)

batch_size = encoder_output.size(0)
encoder_output_device = encoder_output.device

# decoder_tokens is None means at start of inference, in which case decoder sequence should begin with padding idx.
if decoder_tokens is None:
batch_size = encoder_output.size()[0]
encoder_output_device = encoder_output.device
decoder_tokens = (
torch.ones((batch_size, 1), device=encoder_output_device, dtype=torch.long) * self.padding_idx
)

if decoder_mask is None:
assert decoder_tokens is not None and decoder_tokens.dim() == 2
tgt_len = decoder_tokens.shape[1]
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1)
decoder_mask = decoder_mask.to(decoder_tokens.device, dtype=torch.bool)

if decoder_padding_mask is None:
decoder_padding_mask = decoder_tokens.eq(self.padding_idx)
# T5 implementation uses padding idx to start sequence. Want to ignore this when masking
@@ -253,6 +305,8 @@ def forward(
memory_mask=encoder_mask,
tgt_key_padding_mask=decoder_padding_mask,
memory_key_padding_mask=encoder_padding_mask,
past_key_values=past_key_values,
return_past_key_values=return_past_key_values,
)

decoder_output = decoder_outputs.get("decoder_output")
@@ -267,13 +321,12 @@ def forward(
decoder_output = self.lm_head(decoder_output)
decoder_outputs["decoder_output"] = decoder_output

encoder_outputs.update(decoder_outputs)
encoder_decoder_outputs = encoder_outputs

assert torch.jit.isinstance(
encoder_decoder_outputs,
Dict[str, Union[Tensor, List[Tensor], Optional[Tensor], List[Optional[Tensor]]]],
)
# Make TorchScript pick up the correct types
encoder_decoder_outputs: DECODER_OUTPUTS_TYPE = {}
for key, val in encoder_outputs.items():
encoder_decoder_outputs[key] = val
for key, val in decoder_outputs.items():
encoder_decoder_outputs[key] = val

return encoder_decoder_outputs
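
Putting the new pieces together, the sketch below shows how the incremental-decoding API could be driven for greedy generation: encode once, then on every step feed only the newest token plus the cached key/values. This is an editor-added illustration rather than the commit's generation code (which uses beam search); it assumes a T5Model built with linear_head=True so that decoder_output holds vocabulary logits, and it calls forward with keyword arguments only, since the full positional signature is not shown in this diff.

import torch
from torchtext.prototype.models import T5Conf, T5Model

@torch.no_grad()
def greedy_generate(model: T5Model, encoder_tokens: torch.Tensor, max_len: int = 20) -> torch.Tensor:
    # Encode once; the encoder outputs are reused at every decoding step.
    encoder = model.get_encoder()
    encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
    encoder_outputs = encoder(encoder_tokens, src_key_padding_mask=encoder_padding_mask)

    batch_size = encoder_tokens.size(0)
    # T5 starts decoding from the padding index.
    decoder_tokens = torch.full(
        (batch_size, 1), model.padding_idx, dtype=torch.long, device=encoder_tokens.device
    )
    past = None

    for _ in range(max_len):
        # With a cache present, prepare_inputs_for_generation trims the input to the newest token.
        inputs = model.prepare_inputs_for_generation(
            decoder_tokens, encoder_outputs, past=past, return_past_key_values=True
        )
        outputs = model(
            decoder_tokens=inputs["decoder_tokens"],
            encoder_padding_mask=encoder_padding_mask,
            encoder_outputs=inputs["encoder_outputs"],
            past_key_values=inputs["past_key_values"],
            return_past_key_values=inputs["return_past_key_values"],
        )
        logits = outputs["decoder_output"][:, -1]          # vocab scores for the last position
        next_token = logits.argmax(dim=-1, keepdim=True)   # greedy choice
        decoder_tokens = torch.cat([decoder_tokens, next_token], dim=-1)
        past = outputs["past_key_values"]                  # reuse the cache on the next step

    return decoder_tokens

# Usage with a randomly initialized model; in practice a pretrained bundle such as
# T5_SMALL_GENERATION would supply the weights and tokenizer.
model = T5Model(T5Conf(encoder_only=False, linear_head=True)).eval()
tokens = torch.randint(1, 32128, (2, 8))
print(greedy_generate(model, tokens).shape)  # (2, max_len + 1)
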
