Skip to content

Commit

Permalink
[BartTokenizer] add prepare s2s batch (#6212)
Browse files Browse the repository at this point in the history
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
  • Loading branch information
patil-suraj and sgugger authored Aug 17, 2020
1 parent 84d3331 commit 2a77813
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 1 deletion.
103 changes: 103 additions & 0 deletions src/transformers/tokenization_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,109 @@ class BartTokenizer(RobertaTokenizer):
"merges_file": {m: merges_url for m in _all_bart_models},
}

def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "None",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.
Args:
src_texts: (:obj:`List[str]`):
List of documents to summarize or source language texts.
tgt_texts: (:obj:`List[str]`, `optional`):
List of summaries or target language texts.
max_length (:obj:`int`, `optional`):
Controls the maximum length for encoder inputs (documents to summarize or source language texts).
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Controls the maximum length of decoder inputs (target language texts or summaries).
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.
The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.model_max_length
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=truncation,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v

return model_inputs


class BartTokenizerFast(RobertaTokenizerFast):
# merges and vocab same as Roberta
Expand Down
77 changes: 76 additions & 1 deletion tests/test_modeling_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@

import timeout_decorator # noqa

from transformers import is_torch_available
from transformers import BatchEncoding, is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
Expand Down Expand Up @@ -416,6 +417,10 @@ def _long_tensor(tok_lst):

@require_torch
class BartModelIntegrationTests(unittest.TestCase):
@cached_property
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")

@slow
def test_inference_no_head(self):
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
Expand Down Expand Up @@ -559,6 +564,76 @@ def test_cnn_summarization_same_as_fairseq(self):
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length

def test_prepare_seq2seq_batch(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)

self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset

def test_empty_target_text(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)

def test_max_target_length(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

def test_outputs_not_longer_than_maxlen(self):
tokenizer = self.default_tokenizer

batch = tokenizer.prepare_seq2seq_batch(["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt")
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))

def test_special_tokens(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
decoder_input_ids = batch["decoder_input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())


@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
Expand Down

0 comments on commit 2a77813

Please sign in to comment.