From 1931f91ddd50f349c48c38de4c24503f9b4df2dd Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 29 Jul 2020 14:36:14 +0530 Subject: [PATCH 1/9] add prepare_seq2seq_batch method --- src/transformers/tokenization_t5.py | 93 ++++++++++++++++++++++++++++- tests/test_tokenization_t5.py | 26 ++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 8c2b48db7ea..8aec181858f 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -19,8 +19,9 @@ import os import re from shutil import copyfile +from typing import List, Optional -from .tokenization_utils import PreTrainedTokenizer +from .tokenization_utils import BatchEncoding, PreTrainedTokenizer logger = logging.getLogger(__name__) @@ -96,6 +97,9 @@ class T5Tokenizer(PreTrainedTokenizer): max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES model_input_names = ["attention_mask"] + prefix_tokens: List[int] = [] + suffix_tokens: List[int] = [] + def __init__( self, vocab_file, @@ -206,3 +210,90 @@ def save_vocabulary(self, save_directory): copyfile(self.vocab_file, out_vocab_file) return (out_vocab_file,) + + def build_inputs_with_special_tokens( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Build model inputs from a sequence or a pair of sequence for sequence classification tasks + by concatenating and adding special tokens. The special tokens depend on calling source text or target text. + An T5 sequence has the following format, where ``X`` represents the sequence: + - ``input_ids`` (for encoder) ``X [eos]`` + - ``decoder_input_ids``: (for decoder) ``[pad] X [eos]`` + Pairs of sequences are not the expected use case, but they will be handled without a separator. + + Args: + token_ids_0 (:obj:`List[int]`): + List of IDs to which the special tokens will be added + token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + Optional second list of IDs for sequence pairs. + + Returns: + :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + """ + if token_ids_1 is None: + return self.prefix_tokens + token_ids_0 + self.suffix_tokens + # We don't expect to process pairs, but leave the pair logic for API consistency + return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = None, + **kwargs, + ) -> BatchEncoding: + """Prepare a batch that can be passed directly to an instance of T5Model. + Arguments: + src_texts: list of src language texts + tgt_texts: list of tgt language texts + max_length: (default=None, which defers to the config value of 512 for t5* + padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest. + **kwargs: passed to self.__call__ + + Returns: + :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. + """ + if max_length is None: + max_length = self.max_len + self.set_src_special_tokens() + model_inputs: BatchEncoding = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=True, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + self.set_tgt_special_tokens() + decoder_inputs: BatchEncoding = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=True, + **kwargs, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + + self.set_src_special_tokens() + return model_inputs + + def set_src_special_tokens(self) -> None: + self.prefix_tokens = [] + self.suffix_tokens = [self.eos_token_id] + + def set_tgt_special_tokens(self) -> None: + self.prefix_tokens = [self.pad_token_id] + self.suffix_tokens = [self.eos_token_id] diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index bee735921c7..5d1eeafd0a2 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -17,6 +17,8 @@ import os import unittest +from transformers import BatchEncoding +from transformers.testing_utils import _torch_available from transformers.tokenization_t5 import T5Tokenizer from transformers.tokenization_xlnet import SPIECE_UNDERLINE @@ -25,6 +27,8 @@ SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model") +FRAMEWORK = "pt" if _torch_available else "tf" + class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase): @@ -102,3 +106,25 @@ def test_full_tokenizer(self): ".", ], ) + + def test_prepare_seq2seq_batch(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK + ) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 10), batch.input_ids.shape) + self.assertEqual((2, 10), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(expected_src_tokens, result) + self.assertEqual(1, batch.decoder_input_ids[0, -1]) # EOS + # Test that special tokens are reset + self.assertEqual(tokenizer.prefix_tokens, []) + self.assertEqual(tokenizer.suffix_tokens, [tokenizer.eos_token_id]) From ab3ce4eeae0acdd24021ca8857c3840b858aac26 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 29 Jul 2020 21:04:08 +0530 Subject: [PATCH 2/9] remove suffix_tokens --- src/transformers/tokenization_t5.py | 7 ++----- tests/test_tokenization_t5.py | 9 ++++----- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 8aec181858f..79d3aea2f35 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -98,7 +98,6 @@ class T5Tokenizer(PreTrainedTokenizer): model_input_names = ["attention_mask"] prefix_tokens: List[int] = [] - suffix_tokens: List[int] = [] def __init__( self, @@ -232,9 +231,9 @@ def build_inputs_with_special_tokens( :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. """ if token_ids_1 is None: - return self.prefix_tokens + token_ids_0 + self.suffix_tokens + return self.prefix_tokens + token_ids_0 # We don't expect to process pairs, but leave the pair logic for API consistency - return self.prefix_tokens + token_ids_0 + token_ids_1 + self.suffix_tokens + return self.prefix_tokens + token_ids_0 + token_ids_1 def prepare_seq2seq_batch( self, @@ -292,8 +291,6 @@ def prepare_seq2seq_batch( def set_src_special_tokens(self) -> None: self.prefix_tokens = [] - self.suffix_tokens = [self.eos_token_id] def set_tgt_special_tokens(self) -> None: self.prefix_tokens = [self.pad_token_id] - self.suffix_tokens = [self.eos_token_id] diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 5d1eeafd0a2..d5bda218d38 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -114,17 +114,16 @@ def test_prepare_seq2seq_batch(self): "Summary of the text.", "Another summary.", ] - expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1] + expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5] batch = tokenizer.prepare_seq2seq_batch( src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK ) self.assertIsInstance(batch, BatchEncoding) - self.assertEqual((2, 10), batch.input_ids.shape) - self.assertEqual((2, 10), batch.attention_mask.shape) + self.assertEqual((2, 9), batch.input_ids.shape) + self.assertEqual((2, 9), batch.attention_mask.shape) result = batch.input_ids.tolist()[0] self.assertListEqual(expected_src_tokens, result) - self.assertEqual(1, batch.decoder_input_ids[0, -1]) # EOS # Test that special tokens are reset self.assertEqual(tokenizer.prefix_tokens, []) - self.assertEqual(tokenizer.suffix_tokens, [tokenizer.eos_token_id]) + From 7fcc4a797694a6c7cf2d7e057a42cd7c015c73ce Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Wed, 29 Jul 2020 21:59:51 +0530 Subject: [PATCH 3/9] more tests --- tests/test_tokenization_t5.py | 53 +++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index d5bda218d38..8344429295c 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -127,3 +127,56 @@ def test_prepare_seq2seq_batch(self): # Test that special tokens are reset self.assertEqual(tokenizer.prefix_tokens, []) + def test_empty_target_text(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK) + # check if input_ids are returned and no decoder_input_ids + self.assertIn("input_ids", batch.keys()) + self.assertIn("attention_mask", batch.keys()) + self.assertNotIn("decoder_input_ids", batch.keys()) + self.assertNotIn("decoder_attention_mask", batch.keys()) + + def test_max_target_length(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors=FRAMEWORK + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + # test None max_target_length + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors=FRAMEWORK + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + def test_outputs_not_longer_than_maxlen(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + + batch = tokenizer.prepare_seq2seq_batch( + ["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK + ) + self.assertIsInstance(batch, BatchEncoding) + self.assertEqual(batch.input_ids.shape, (2, 512)) + + def test_eos_in_input(self): + tokenizer = T5Tokenizer.from_pretrained("t5-small") + src_text = ["A long paragraph for summrization. "] + tgt_text = ["Summary of the text. "] + expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1] + expected_tgt_tokens = [0, 20698, 13, 8, 1499, 5, 1] + + batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK) + + src_ids = batch.input_ids.tolist()[0] + tgt_ids = batch.decoder_input_ids.tolist()[0] + + self.assertEqual(expected_src_tokens, src_ids) + self.assertEqual(expected_tgt_tokens, tgt_ids) From 10f5898d7cfb845d727796ca3677621d6883e39c Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Fri, 31 Jul 2020 23:03:41 +0530 Subject: [PATCH 4/9] better assertIn --- tests/test_tokenization_t5.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 8344429295c..40e80222c07 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -132,10 +132,10 @@ def test_empty_target_text(self): src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK) # check if input_ids are returned and no decoder_input_ids - self.assertIn("input_ids", batch.keys()) - self.assertIn("attention_mask", batch.keys()) - self.assertNotIn("decoder_input_ids", batch.keys()) - self.assertNotIn("decoder_attention_mask", batch.keys()) + self.assertIn("input_ids", batch) + self.assertIn("attention_mask", batch) + self.assertNotIn("decoder_input_ids", batch) + self.assertNotIn("decoder_attention_mask", batch) def test_max_target_length(self): tokenizer = T5Tokenizer.from_pretrained("t5-small") From f8cd95f2aac02a9b1b1f030dd2f329b5d7233d92 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 1 Aug 2020 11:52:12 +0530 Subject: [PATCH 5/9] fix docs --- src/transformers/tokenization_t5.py | 30 +++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index 79d3aea2f35..c8130ce6b6e 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -216,7 +216,7 @@ def build_inputs_with_special_tokens( """ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and adding special tokens. The special tokens depend on calling source text or target text. - An T5 sequence has the following format, where ``X`` represents the sequence: + A T5 sequence has the following format, where ``X`` represents the sequence: - ``input_ids`` (for encoder) ``X [eos]`` - ``decoder_input_ids``: (for decoder) ``[pad] X [eos]`` Pairs of sequences are not the expected use case, but they will be handled without a separator. @@ -224,11 +224,11 @@ def build_inputs_with_special_tokens( Args: token_ids_0 (:obj:`List[int]`): List of IDs to which the special tokens will be added - token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`): + token_ids_1 (:obj:`List[int]`, `optional`): Optional second list of IDs for sequence pairs. Returns: - :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. + :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. """ if token_ids_1 is None: return self.prefix_tokens + token_ids_0 @@ -246,15 +246,25 @@ def prepare_seq2seq_batch( **kwargs, ) -> BatchEncoding: """Prepare a batch that can be passed directly to an instance of T5Model. - Arguments: - src_texts: list of src language texts - tgt_texts: list of tgt language texts - max_length: (default=None, which defers to the config value of 512 for t5* - padding: strategy for padding input_ids and decoder_input_ids. Should be max_length or longest. - **kwargs: passed to self.__call__ + + Args: + src_texts (:obj:`List[str]`): + list of src texts + tgt_texts (:obj:`List[str]`, `optional`): + list of tgt texts + max_length (:obj:`int`, `optional`): + maximum length for the source text which defers to the config value of 512 for t5* + max_target_length (:obj:`int`, `optional`): + maximum length for the target text which defers to the config value of 512 for t5* + padding (:obj:`str`, `optional`, defaults to "longest"): + strategy for padding `input_ids` and `decoder_input_ids`. Should be "max_length" or "longest". + return_tensors (:obj:`str`, `optional`): + Can be set to ‘tf’, ‘pt’ or ‘np’ to return respectively TensorFlow `tf.constant`, PyTorch `torch.Tensor` or Numpy :oj: np.ndarray instead of a list of python integers. + **kwargs: + passed to self.__call__ Returns: - :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. + :class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. """ if max_length is None: max_length = self.max_len From fb788430e01a51da922eefbd70b2f32c3dc4db61 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 1 Aug 2020 11:53:36 +0530 Subject: [PATCH 6/9] cleanup --- src/transformers/tokenization_t5.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index c8130ce6b6e..b37728094b8 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -268,7 +268,7 @@ def prepare_seq2seq_batch( """ if max_length is None: max_length = self.max_len - self.set_src_special_tokens() + self.prefix_tokens = [] model_inputs: BatchEncoding = self( src_texts, add_special_tokens=True, @@ -283,7 +283,8 @@ def prepare_seq2seq_batch( # Process tgt_texts if max_target_length is None: max_target_length = max_length - self.set_tgt_special_tokens() + # set prefix_tokens for target text + self.prefix_tokens = [self.pad_token_id] decoder_inputs: BatchEncoding = self( tgt_texts, add_special_tokens=True, @@ -296,11 +297,6 @@ def prepare_seq2seq_batch( for k, v in decoder_inputs.items(): model_inputs[f"decoder_{k}"] = v - self.set_src_special_tokens() - return model_inputs - - def set_src_special_tokens(self) -> None: self.prefix_tokens = [] + return model_inputs - def set_tgt_special_tokens(self) -> None: - self.prefix_tokens = [self.pad_token_id] From fbe7ea16367c3c9d02eed8c3b13bda095f1f9582 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Sat, 1 Aug 2020 12:02:38 +0530 Subject: [PATCH 7/9] fix style --- src/transformers/tokenization_t5.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index b37728094b8..b90e0c85f0d 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -246,7 +246,7 @@ def prepare_seq2seq_batch( **kwargs, ) -> BatchEncoding: """Prepare a batch that can be passed directly to an instance of T5Model. - + Args: src_texts (:obj:`List[str]`): list of src texts @@ -299,4 +299,3 @@ def prepare_seq2seq_batch( self.prefix_tokens = [] return model_inputs - From cda6984fdae89e13e142b386903b0994ba6d2fb7 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 17 Aug 2020 22:30:14 +0530 Subject: [PATCH 8/9] fix tests --- tests/test_tokenization_t5.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_tokenization_t5.py b/tests/test_tokenization_t5.py index 40e80222c07..ee55f5b5278 100644 --- a/tests/test_tokenization_t5.py +++ b/tests/test_tokenization_t5.py @@ -122,7 +122,7 @@ def test_prepare_seq2seq_batch(self): self.assertEqual((2, 9), batch.input_ids.shape) self.assertEqual((2, 9), batch.attention_mask.shape) - result = batch.input_ids.tolist()[0] + result = list(batch.input_ids.numpy()[0]) self.assertListEqual(expected_src_tokens, result) # Test that special tokens are reset self.assertEqual(tokenizer.prefix_tokens, []) @@ -175,8 +175,8 @@ def test_eos_in_input(self): batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK) - src_ids = batch.input_ids.tolist()[0] - tgt_ids = batch.decoder_input_ids.tolist()[0] + src_ids = list(batch.input_ids.numpy()[0]) + tgt_ids = list(batch.decoder_input_ids.numpy()[0]) self.assertEqual(expected_src_tokens, src_ids) self.assertEqual(expected_tgt_tokens, tgt_ids) From a84bb5b9e98104e6398271ccc93a23aad7a1e2dc Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 17 Aug 2020 23:05:05 +0530 Subject: [PATCH 9/9] better doc --- src/transformers/tokenization_t5.py | 68 +++++++++++++++++++++-------- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/src/transformers/tokenization_t5.py b/src/transformers/tokenization_t5.py index b90e0c85f0d..1be9d648aa0 100644 --- a/src/transformers/tokenization_t5.py +++ b/src/transformers/tokenization_t5.py @@ -243,28 +243,62 @@ def prepare_seq2seq_batch( max_target_length: Optional[int] = None, padding: str = "longest", return_tensors: str = None, + truncation: bool = True, **kwargs, ) -> BatchEncoding: - """Prepare a batch that can be passed directly to an instance of T5Model. - + r""" + Prepare a batch that can be passed directly to an instance of :class:`~transformers.T5Model`. Args: - src_texts (:obj:`List[str]`): - list of src texts - tgt_texts (:obj:`List[str]`, `optional`): - list of tgt texts + src_texts: (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts: (:obj:`List[str]`, `optional`): + List of summaries or target language texts. max_length (:obj:`int`, `optional`): - maximum length for the source text which defers to the config value of 512 for t5* + Controls the maximum length for encoder inputs (documents to summarize or source language texts). + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. max_target_length (:obj:`int`, `optional`): - maximum length for the target text which defers to the config value of 512 for t5* - padding (:obj:`str`, `optional`, defaults to "longest"): - strategy for padding `input_ids` and `decoder_input_ids`. Should be "max_length" or "longest". - return_tensors (:obj:`str`, `optional`): - Can be set to ‘tf’, ‘pt’ or ‘np’ to return respectively TensorFlow `tf.constant`, PyTorch `torch.Tensor` or Numpy :oj: np.ndarray instead of a list of python integers. + Controls the maximum length of decoder inputs (target language texts or summaries). + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). **kwargs: - passed to self.__call__ - + Additional keyword arguments passed along to :obj:`self.__call__`. Returns: - :class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask. + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. """ if max_length is None: max_length = self.max_len @@ -275,7 +309,7 @@ def prepare_seq2seq_batch( return_tensors=return_tensors, max_length=max_length, padding=padding, - truncation=True, + truncation=truncation, **kwargs, ) if tgt_texts is None: @@ -291,7 +325,7 @@ def prepare_seq2seq_batch( return_tensors=return_tensors, padding=padding, max_length=max_target_length, - truncation=True, + truncation=truncation, **kwargs, ) for k, v in decoder_inputs.items():