
[T5Tokenizer] add prepare_seq2seq_batch method #6122

Merged: 9 commits, Aug 17, 2020
129 changes: 128 additions & 1 deletion src/transformers/tokenization_t5.py
@@ -19,8 +19,9 @@
import os
import re
from shutil import copyfile
from typing import List, Optional

-from .tokenization_utils import PreTrainedTokenizer
+from .tokenization_utils import BatchEncoding, PreTrainedTokenizer


logger = logging.getLogger(__name__)
@@ -96,6 +97,8 @@ class T5Tokenizer(PreTrainedTokenizer):
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = ["attention_mask"]

    prefix_tokens: List[int] = []

    def __init__(
        self,
        vocab_file,
@@ -206,3 +209,127 @@ def save_vocabulary(self, save_directory):
        copyfile(self.vocab_file, out_vocab_file)

        return (out_vocab_file,)

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequences by concatenating and adding special tokens.
        The special tokens depend on whether source text or target text is being processed. A T5 sequence has
        the following format, where ``X`` represents the sequence:

        - ``input_ids`` (for the encoder): ``X [eos]``
        - ``decoder_input_ids`` (for the decoder): ``[pad] X [eos]``

        Pairs of sequences are not the expected use case, but they will be handled without a separator.

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.

        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return self.prefix_tokens + token_ids_0
        # We don't expect to process pairs, but leave the pair logic for API consistency
        return self.prefix_tokens + token_ids_0 + token_ids_1

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: str = None,
        truncation: bool = True,
        **kwargs,
    ) -> BatchEncoding:
        r"""
        Prepare a batch that can be passed directly to an instance of :class:`~transformers.T5Model`.

        Args:
            src_texts (:obj:`List[str]`):
                List of documents to summarize or source language texts.
            tgt_texts (:obj:`List[str]`, `optional`):
                List of summaries or target language texts.
            max_length (:obj:`int`, `optional`):
                Controls the maximum length for encoder inputs (documents to summarize or source language texts).
                If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
                length is required by one of the truncation/padding parameters. If the model has no specific maximum
                input length (like XLNet), truncation/padding to a maximum length will be deactivated.
            max_target_length (:obj:`int`, `optional`):
                Controls the maximum length of decoder inputs (target language texts or summaries).
                If left unset or set to :obj:`None`, this will use the :obj:`max_length` value.
            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`'longest'`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
                  different lengths).
            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of lists of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
            truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
                Activates and controls truncation. Accepts the following values:

                * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
                  :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
                  provided. This will truncate token by token, removing a token from the longest sequence in the pair
                  if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
                  the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
                  to the maximum acceptable input length for the model if that argument is not provided. This will only
                  truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
                * :obj:`False` or :obj:`'do_not_truncate'`: No truncation (i.e., can output a batch with
                  sequence lengths greater than the model maximum admissible input size).
            **kwargs:
                Additional keyword arguments passed along to :obj:`self.__call__`.

        Returns:
            :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:

            - **input_ids** -- List of token ids to be fed to the encoder.
            - **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
            - **decoder_input_ids** -- List of token ids to be fed to the decoder.
            - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the
              decoder. This does not include the causal mask, which is built by the model.

            The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]`` will
            only be returned if :obj:`tgt_texts` is passed. Otherwise, :obj:`input_ids` and :obj:`attention_mask`
            will be the only keys.
        """
        if max_length is None:
            max_length = self.max_len
        self.prefix_tokens = []
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
        # set prefix_tokens so target text is built as `[pad] X [eos]`
        self.prefix_tokens = [self.pad_token_id]
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=truncation,
            **kwargs,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        self.prefix_tokens = []
        return model_inputs
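
For context, a minimal usage sketch of the new method (assuming a transformers build from this PR; T5ForConditionalGeneration, generate, and decode are existing library API, while the input strings and prompts are illustrative only):

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    # With tgt_texts, the batch carries all four keys and can be fed to the model directly.
    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["A long article to summarize. </s>"],
        tgt_texts=["A short summary. </s>"],
        return_tensors="pt",
    )
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["decoder_input_ids"],
        decoder_attention_mask=batch["decoder_attention_mask"],
    )

    # Without tgt_texts, only input_ids/attention_mask are returned, e.g. for generation.
    batch = tokenizer.prepare_seq2seq_batch(
        src_texts=["summarize: A long article to summarize. </s>"], return_tensors="pt"
    )
    summary_ids = model.generate(batch["input_ids"], attention_mask=batch["attention_mask"])
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))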
78 changes: 78 additions & 0 deletions tests/test_tokenization_t5.py
@@ -17,6 +17,8 @@
import os
import unittest

from transformers import BatchEncoding
from transformers.testing_utils import _torch_available
from transformers.tokenization_t5 import T5Tokenizer
from transformers.tokenization_xlnet import SPIECE_UNDERLINE

@@ -25,6 +27,8 @@

SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

FRAMEWORK = "pt" if _torch_available else "tf"


class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):

@@ -102,3 +106,77 @@ def test_full_tokenizer(self):
                ".",
            ],
        )

    def test_prepare_seq2seq_batch(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
        )
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 9), batch.input_ids.shape)
        self.assertEqual((2, 9), batch.attention_mask.shape)
        result = list(batch.input_ids.numpy()[0])
        self.assertListEqual(expected_src_tokens, result)
        # Test that special tokens are reset
        self.assertEqual(tokenizer.prefix_tokens, [])

Contributor:
More cases to test:

  • test the max_target_length kwarg and allow it to be passed through, affecting decoder_input_ids.shape[1]
  • empty tgt_texts
  • empty src_texts -> raises something

Contributor Author:
Thanks, I will cover these cases.

Contributor Author:
Re: empty tgt_texts: can I just check that input_ids and attention_mask are returned, and no decoder_input_ids or decoder_attention_mask?

Contributor:
These tests look great now!

    def test_empty_target_text(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
        # check that input_ids are returned and no decoder_input_ids
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertNotIn("decoder_input_ids", batch)
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_target_length(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK
        )
        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

        # test None max_target_length
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK
        )
        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

Contributor:
Tip: you can use

    @cached_property
    def default_tok(self):
        return T5Tokenizer.from_pretrained("t5-small")

to initialize only once. This barely matters for tokenizers; it is more useful for models, where __init__ can take 20 seconds.

    def test_outputs_not_longer_than_maxlen(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")

        batch = tokenizer.prepare_seq2seq_batch(
            ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
        )
        self.assertIsInstance(batch, BatchEncoding)
        self.assertEqual(batch.input_ids.shape, (2, 512))

    def test_eos_in_input(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
        src_text = ["A long paragraph for summrization. </s>"]
        tgt_text = ["Summary of the text. </s>"]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
        expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1]

        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)

        src_ids = list(batch.input_ids.numpy()[0])
        tgt_ids = list(batch.decoder_input_ids.numpy()[0])

        self.assertEqual(expected_src_tokens, src_ids)
        self.assertEqual(expected_tgt_tokens, tgt_ids)

Contributor:
Would be cool to migrate one or more of the integration tests in test_modeling_t5.py to the new method.
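
As a quick sanity check of the layout that the last test asserts, here is a sketch (assuming the `</s>`-in-input convention used throughout these tests; T5's pad_token_id is 0 and eos_token_id is 1):

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    batch = tokenizer.prepare_seq2seq_batch(
        ["A long paragraph for summrization. </s>"],
        tgt_texts=["Summary of the text. </s>"],
        return_tensors="pt",
    )
    # Source side is `X [eos]`: the ids end with eos_token_id (1).
    print(batch.input_ids[0])
    # Target side is `[pad] X [eos]`: the ids start with pad_token_id (0).
    print(batch.decoder_input_ids[0])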