diff --git a/src/transformers/tokenization_bart.py b/src/transformers/tokenization_bart.py index 499895e0bda666..bf456be302135f 100644 --- a/src/transformers/tokenization_bart.py +++ b/src/transformers/tokenization_bart.py @@ -45,6 +45,109 @@ class BartTokenizer(RobertaTokenizer): "merges_file": {m: merges_url for m in _all_bart_models}, } + def prepare_seq2seq_batch( + self, + src_texts: List[str], + tgt_texts: Optional[List[str]] = None, + max_length: Optional[int] = None, + max_target_length: Optional[int] = None, + padding: str = "longest", + return_tensors: str = "None", + truncation=True, + **kwargs, + ) -> BatchEncoding: + r""" + + Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`. + + Args: + src_texts: (:obj:`List[str]`): + List of documents to summarize or source language texts. + tgt_texts: (:obj:`List[str]`, `optional`): + List of summaries or target language texts. + max_length (:obj:`int`, `optional`): + Controls the maximum length for encoder inputs (documents to summarize or source language texts). + If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum + length is required by one of the truncation/padding parameters. If the model has no specific maximum + input length (like XLNet) truncation/padding to a maximum length will be deactivated. + max_target_length (:obj:`int`, `optional`): + Controls the maximum length of decoder inputs (target language texts or summaries). + If left unset or set to :obj:`None`, this will use the max_length value. + padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`): + Activates and controls padding. Accepts the following values: + + * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a + single sequence if provided). + * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the + maximum acceptable input length for the model if that argument is not provided. + * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of + different lengths). + return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"): + If set, will return tensors instead of list of python integers. Acceptable values are: + + * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects. + * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects. + * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects. + truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`): + Activates and controls truncation. Accepts the following values: + + * :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument + :obj:`max_length` or to the maximum acceptable input length for the model if that argument is not + provided. This will truncate token by token, removing a token from the longest sequence in the pair + if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to + the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or + to the maximum acceptable input length for the model if that argument is not provided. This will only + truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided. + * :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with + sequence lengths greater than the model maximum admissible input size). + **kwargs: + Additional keyword arguments passed along to :obj:`self.__call__`. + + Returns: + :class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields: + + - **input_ids** -- List of token ids to be fed to the encoder. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + - **decoder_input_ids** -- List of token ids to be fed to the decoder. + - **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder. + This does not include causal mask, which is built by the model. + + The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``, + will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys. + """ + if max_length is None: + max_length = self.model_max_length + model_inputs: BatchEncoding = self( + src_texts, + add_special_tokens=True, + return_tensors=return_tensors, + max_length=max_length, + padding=padding, + truncation=truncation, + **kwargs, + ) + if tgt_texts is None: + return model_inputs + # Process tgt_texts + if max_target_length is None: + max_target_length = max_length + decoder_inputs: BatchEncoding = self( + tgt_texts, + add_special_tokens=True, + return_tensors=return_tensors, + padding=padding, + max_length=max_target_length, + truncation=truncation, + **kwargs, + ) + for k, v in decoder_inputs.items(): + model_inputs[f"decoder_{k}"] = v + + return model_inputs + class BartTokenizerFast(RobertaTokenizerFast): # merges and vocab same as Roberta diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index e86e46812e2ecd..d28e5fc3bc820c 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -18,7 +18,8 @@ import timeout_decorator # noqa -from transformers import is_torch_available +from transformers import BatchEncoding, is_torch_available +from transformers.file_utils import cached_property from transformers.testing_utils import require_torch, slow, torch_device from .test_configuration_common import ConfigTester @@ -415,6 +416,10 @@ def _long_tensor(tok_lst): @require_torch class BartModelIntegrationTests(unittest.TestCase): + @cached_property + def default_tokenizer(self): + return BartTokenizer.from_pretrained("facebook/bart-large") + @slow def test_inference_no_head(self): model = BartModel.from_pretrained("facebook/bart-large").to(torch_device) @@ -559,6 +564,76 @@ def test_cnn_summarization_same_as_fairseq(self): # TODO(SS): run fairseq again with num_beams=2, min_len=20. # TODO(SS): add test case that hits max_length + def test_prepare_seq2seq_batch(self): + tokenizer = self.default_tokenizer + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt" + ) + self.assertIsInstance(batch, BatchEncoding) + + self.assertEqual((2, 10), batch.input_ids.shape) + self.assertEqual((2, 10), batch.attention_mask.shape) + result = batch.input_ids.tolist()[0] + self.assertListEqual(expected_src_tokens, result) + # Test that special tokens are reset + + def test_empty_target_text(self): + tokenizer = self.default_tokenizer + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt") + # check if input_ids are returned and no decoder_input_ids + self.assertIn("input_ids", batch) + self.assertIn("attention_mask", batch) + self.assertNotIn("decoder_input_ids", batch) + self.assertNotIn("decoder_attention_mask", batch) + + def test_max_target_length(self): + tokenizer = self.default_tokenizer + src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + "Another summary.", + ] + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt" + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + # test None max_target_length + batch = tokenizer.prepare_seq2seq_batch( + src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt" + ) + self.assertEqual(32, batch["decoder_input_ids"].shape[1]) + self.assertEqual(32, batch["decoder_attention_mask"].shape[1]) + + def test_outputs_not_longer_than_maxlen(self): + tokenizer = self.default_tokenizer + + batch = tokenizer.prepare_seq2seq_batch(["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt") + self.assertIsInstance(batch, BatchEncoding) + self.assertEqual(batch.input_ids.shape, (2, 1024)) + + def test_special_tokens(self): + tokenizer = self.default_tokenizer + src_text = ["A long paragraph for summrization."] + tgt_text = [ + "Summary of the text.", + ] + batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt") + input_ids = batch["input_ids"] + decoder_input_ids = batch["decoder_input_ids"] + self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item()) + self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item()) + self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item()) + self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item()) + @require_torch class TestSinusoidalPositionalEmbeddings(unittest.TestCase):