
[BartTokenizer] add prepare s2s batch #6212

Merged (11 commits) on Aug 17, 2020
73 changes: 73 additions & 0 deletions src/transformers/tokenization_bart.py
@@ -45,6 +45,79 @@ class BartTokenizer(RobertaTokenizer):
"merges_file": {m: merges_url for m in _all_bart_models},
}

def prepare_seq2seq_batch(
self,
src_texts: List[str],
tgt_texts: Optional[List[str]] = None,
max_length: Optional[int] = None,
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "None",
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.

Args:
src_texts (:obj:`List[str]`):
List of input texts.
tgt_texts (:obj:`List[str]`, `optional`):
List of target texts.
max_length (:obj:`int`, `optional`):
Maximum length for the source texts. If not provided, this will use the predefined model maximum length.
max_target_length (:obj:`int`, `optional`):
Maximum length for the target texts. If not provided, this will use the predefined model maximum length.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`"longest"`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:

* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'`: No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
If set, will return tensors instead of lists of python integers. Acceptable values are:

* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
[Inline review comment from a contributor]: There is a new docstring on master/ tokenization_utils_base.py that you may want to (a) reuse or (b) modify.

A :class:`~transformers.BatchEncoding` with the keys :obj:`input_ids`, :obj:`attention_mask`, :obj:`decoder_input_ids` and :obj:`decoder_attention_mask`.
"""
if max_length is None:
max_length = self.model_max_length
model_inputs: BatchEncoding = self(
src_texts,
add_special_tokens=True,
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=True,
**kwargs,
)
if tgt_texts is None:
return model_inputs
# Process tgt_texts
if max_target_length is None:
max_target_length = max_length
decoder_inputs: BatchEncoding = self(
tgt_texts,
add_special_tokens=True,
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=True,
**kwargs,
)
for k, v in decoder_inputs.items():
model_inputs[f"decoder_{k}"] = v

return model_inputs


class BartTokenizerFast(RobertaTokenizerFast):
# merges and vocab same as Roberta
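For context, here is a minimal usage sketch of the new method, outside the diff. It is illustrative only: it assumes the `facebook/bart-large` checkpoint, uses `BartForConditionalGeneration` rather than the bare `BartModel` named in the docstring, and the example texts are made up.

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

src_texts = ["A long paragraph to summarize.", "Another paragraph to summarize."]
tgt_texts = ["Summary of the text.", "Another summary."]

# One call tokenizes sources and targets together and returns
# input_ids, attention_mask, decoder_input_ids and decoder_attention_mask.
batch = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts=tgt_texts, return_tensors="pt")

# The keys match the model's forward signature, so the batch can be unpacked directly.
with torch.no_grad():
    outputs = model(**batch)
logits = outputs[0]  # (batch_size, target_length, vocab_size)
```

Note that the helper does not build `labels`; for a training loss you would still derive them from the target tokens yourself.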
77 changes: 76 additions & 1 deletion tests/test_modeling_bart.py
@@ -18,7 +18,8 @@

import timeout_decorator # noqa

-from transformers import is_torch_available
+from transformers import BatchEncoding, is_torch_available
+from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
@@ -415,6 +416,10 @@ def _long_tensor(tok_lst):

@require_torch
class BartModelIntegrationTests(unittest.TestCase):
@cached_property
def default_tokenizer(self):
return BartTokenizer.from_pretrained("facebook/bart-large")

@slow
def test_inference_no_head(self):
model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
@@ -559,6 +564,76 @@ def test_cnn_summarization_same_as_fairseq(self):
# TODO(SS): run fairseq again with num_beams=2, min_len=20.
# TODO(SS): add test case that hits max_length

def test_prepare_seq2seq_batch(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
)
self.assertIsInstance(batch, BatchEncoding)

self.assertEqual((2, 10), batch.input_ids.shape)
self.assertEqual((2, 10), batch.attention_mask.shape)
result = batch.input_ids.tolist()[0]
self.assertListEqual(expected_src_tokens, result)
# Test that special tokens are reset

def test_empty_target_text(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertIn("attention_mask", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)

def test_max_target_length(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
tgt_text = [
"Summary of the text.",
"Another summary.",
]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["decoder_input_ids"].shape[1])
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

def test_outputs_not_longer_than_maxlen(self):
tokenizer = self.default_tokenizer

batch = tokenizer.prepare_seq2seq_batch(["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt")
self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 1024))

def test_special_tokens(self):
tokenizer = self.default_tokenizer
src_text = ["A long paragraph for summrization."]
tgt_text = [
"Summary of the text.",
]
batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
input_ids = batch["input_ids"]
decoder_input_ids = batch["decoder_input_ids"]
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())


@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):
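As a closing sketch of the length behaviour these tests exercise (illustrative only; it assumes the `facebook/bart-large` checkpoint, whose `model_max_length` is 1024, and PyTorch available for `return_tensors="pt"`):

```python
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

# With no max_length given, the source side falls back to tokenizer.model_max_length,
# and truncation=True caps an over-long input at exactly that many tokens.
batch = tokenizer.prepare_seq2seq_batch(["I am a small frog " * 2000], return_tensors="pt")
assert batch["input_ids"].shape[1] == tokenizer.model_max_length  # 1024

# An explicit max_length caps the source side; max_target_length, when omitted,
# falls back to max_length, so both sides are padded/truncated to the same width here.
batch = tokenizer.prepare_seq2seq_batch(
    ["source text"],
    tgt_texts=["target text"],
    max_length=32,
    padding="max_length",
    return_tensors="pt",
)
assert batch["input_ids"].shape[1] == 32
assert batch["decoder_input_ids"].shape[1] == 32
```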