diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 8c2b48db7eaaa9..1be9d648aa0916 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -19,8 +19,9 @@ import os import re from shutil import copyfile +from typing import List, Optional -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import BatchEncoding, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -96,6 +97,8 @@ class T5Tokenizer(PreTrainedTokenizer): max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["attention_mask"] + prefix_tokens: List[int] = [] + def __init__( self, vocab_file, @@ -206,3 +209,127 @@ def save_vocabulary(self, save_directory): copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. The special tokens depend on calling source text or target text. + A T5 sequence has the following format, where ``X`` represents the sequence: + - ``input_ids`` (for encoder) ``X [eos]`` + - ``decoder_input_ids``: (for decoder) ``[pad] X [eos]`` + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = None, + truncation: bool = True, + **kwargs, + ) -> BatchEncoding: + r""" + Prepare a batch that can be passed directly to an instance of :class:`~transformers.T5Model`. + Args: + src_texts: (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts: (:obj:`List[str]`, `optional`): + List of summaries or target language texts. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts). + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries). + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to :obj:`self.__call__`. + Returns: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + """ + if max_length is None: + max_length = self.max_len + self.prefix_tokens = [] + model_inputs: BatchEncoding = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + # set prefix_tokens for target text + self.prefix_tokens = [self.pad_token_id] + decoder_inputs: BatchEncoding = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + + self.prefix_tokens = [] + return model_inputs diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index bee735921c7518..ee55f5b5278d28 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -17,6 +17,8 @@ import os import unittest +from transformers import BatchEncoding +from transformers.testing_utils import _torch_available from transformers.tokenization_t5 import T5Tokenizer from transformers.tokenization_xlnet import SPIECE_UNDERLINE @@ -25,6 +27,8 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") +FRAMEWORK = "pt" if _torch_available else "tf" + class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): @@ -102,3 +106,77 @@ def test_full_tokenizer(self): ".", ], ) + + def test_prepare_seq2seq_batch(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK + ) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 9), batch.input_ids.shape) + self.assertEqual((2, 9), batch.attention_mask.shape) + result = list(batch.input_ids.numpy()[0]) + self.assertListEqual(expected_src_tokens, result) + # Test that special tokens are reset + self.assertEqual(tokenizer.prefix_tokens, []) + + def test_empty_target_text(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK) + # check if input_ids are returned and no decoder_input_ids + self.assertIn("input_ids", batch) + self.assertIn("attention_mask", batch) + self.assertNotIn("decoder_input_ids", batch) + self.assertNotIn("decoder_attention_mask", batch) + + def test_max_target_length(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + # test None max_target_length + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + def test_outputs_not_longer_than_maxlen(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + batch = tokenizer.prepare_seq2seq_batch( + ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK + ) + self.assertIsInstance(batch, BatchEncoding) + self.assertEqual(batch.input_ids.shape, (2, 512)) + + def test_eos_in_input(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization. "] + tgt_text = ["Summary of the text. "] + expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1] + expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1] + + batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK) + + src_ids = list(batch.input_ids.numpy()[0]) + tgt_ids = list(batch.decoder_input_ids.numpy()[0]) + + self.assertEqual(expected_src_tokens, src_ids) + self.assertEqual(expected_tgt_tokens, tgt_ids)