[BartTokenizer] add prepare s2s batch #6212

Merged on Aug 17, 2020 (11 commits)
59 changes: 59 additions & 0 deletions src/transformers/tokenization_bart.py
@@ -45,6 +45,65 @@ class BartTokenizer(RobertaTokenizer):
"merges_file": {m: merges_url for m in _all_bart_models},
}

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of BartModel.
Collaborator

@sgugger, Aug 3, 2020

Suggested change
"""Prepare a batch that can be passed directly to an instance of BartModel.
"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.

(nit)

Args:
src_texts (:obj:`List[str]`):
list of src texts
Collaborator

Suggested change
list of src texts
List of input texts.

tgt_texts (:obj:`List[str]`, `optional`):
list of tgt texts
Collaborator

Suggested change
list of tgt texts
List of target texts.

max_length (:obj:`int`, `optional`):
maximum length for the source text which defers to the config value of 1024 for facebook/bart*
Collaborator

Suggested change
maximum length for the source text which defers to the config value of 1024 for facebook/bart*
Maximum length for the source texts. If not provided, this will use the predefined model maximum length.

Don't mention a specific model here since several could be used.

max_target_length (:obj:`int`, `optional`):
maximum length for the target text which defers to the config value of 1024 for facebook/bart*
Collaborator

Suggested change
maximum length for the target text which defers to the config value of 1024 for facebook/bart*
Maximum length for the target texts. If not provided, this will use the predefined model maximum length.

padding (:obj:`str`, `optional`, defaults to "longest"):
Collaborator

This can be bool, string or PaddingStrategy I believe? See documentation of PreTrainedTokenizerBase.__call__:

            padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
                Activates and controls padding. Accepts the following values:

                * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
                  single sequence is provided).
                * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
                  maximum acceptable input length for the model if that argument is not provided.
                * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
                  different lengths).
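
As a rough illustration of the three padding modes listed above, here is a minimal sketch written as a toy function. It is not the library's implementation: `pad_batch` is a hypothetical helper, `pad_id=1` is a placeholder value, and requiring `max_length` for the `'max_length'` mode is a simplification (the real tokenizer falls back to the model's maximum input length).

```python
from typing import List, Optional, Union


def pad_batch(
    sequences: List[List[int]],
    padding: Union[bool, str] = False,
    max_length: Optional[int] = None,
    pad_id: int = 1,  # placeholder pad id, chosen for illustration only
) -> List[List[int]]:
    """Toy version of the padding modes: no padding, 'longest', 'max_length'."""
    if padding is False or padding == "do_not_pad":
        # No padding: rows may keep different lengths.
        return [list(seq) for seq in sequences]
    if padding is True or padding == "longest":
        # Pad to the longest row in this batch.
        target = max(len(seq) for seq in sequences)
    elif padding == "max_length":
        if max_length is None:
            raise ValueError("this toy version requires max_length")
        target = max_length
    else:
        raise ValueError(f"unknown padding strategy: {padding!r}")
    # Right-pad each row; rows already at or beyond the target are left as-is
    # (padding never truncates).
    return [list(seq) + [pad_id] * (target - len(seq)) for seq in sequences]
```

With `padding="longest"` the batch width follows the data, while `padding="max_length"` gives a fixed width, which is what the `max_target_length=32` test in this PR relies on.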

strategy for padding `input_ids` and `decoder_input_ids`. Should be "max_length" or "longest".
return_tensors (:obj:`str`, `optional`):
Collaborator

This can be string or TensorType (same as above, just copy from PreTrainedTokenizerBase.__call__):

            return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
                If set, will return tensors instead of list of python integers. Acceptable values are:

                * :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
                * :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
                * :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.

Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow :obj:`tf.constant`, PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
**kwargs:
passed to self.__call__
Collaborator

Suggested change
passed to self.__call__
Additional keyword arguments passed along to :obj:`self.__call__`.

Returns:
Contributor

There is a new docstring on master in tokenization_utils_base.py that you may want to (a) reuse or (b) modify.

:class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
"""
        if max_length is None:
            max_length = self.model_max_length
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=True,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=True,
            **kwargs,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        return model_inputs
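
The control flow of the method above (encode sources, optionally encode targets, then merge the target encoding under a `decoder_` prefix) can be sketched without the library. This is an illustrative stand-in under stated assumptions, not the PR's code: `toy_encode` and `prepare_batch_sketch` are hypothetical names, and the "tokenizer" simply maps each whitespace-separated word to its length.

```python
from typing import Callable, Dict, List, Optional

Encoding = Dict[str, List[List[int]]]


def toy_encode(texts: List[str], max_length: int) -> Encoding:
    """Hypothetical stand-in for tokenizer.__call__: one id per word, truncated."""
    input_ids = [[len(word) for word in text.split()][:max_length] for text in texts]
    attention_mask = [[1] * len(ids) for ids in input_ids]
    return {"input_ids": input_ids, "attention_mask": attention_mask}


def prepare_batch_sketch(
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: int = 8,
    max_target_length: Optional[int] = None,
    encode: Callable[[List[str], int], Encoding] = toy_encode,
) -> Encoding:
    # Encode the source side first.
    model_inputs = dict(encode(src_texts, max_length))
    if tgt_texts is None:
        # Inference-style call: only encoder inputs are returned.
        return model_inputs
    # Targets fall back to the source limit when no target limit is given.
    if max_target_length is None:
        max_target_length = max_length
    decoder_inputs = encode(tgt_texts, max_target_length)
    # Merge target fields under a "decoder_" prefix.
    for key, value in decoder_inputs.items():
        model_inputs[f"decoder_{key}"] = value
    return model_inputs
```

Calling `prepare_batch_sketch(["a bb ccc"], ["dd e"])` yields the four keys named in the Returns section of the docstring, while omitting the targets yields only `input_ids` and `attention_mask`.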


class BartTokenizerFast(RobertaTokenizerFast):
    # merges and vocab same as Roberta
77 changes: 76 additions & 1 deletion tests/test_modeling_bart.py
@@ -18,7 +18,8 @@

import timeout_decorator # noqa

-from transformers import is_torch_available
+from transformers import BatchEncoding, is_torch_available
+from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
@@ -415,6 +416,10 @@ def _long_tensor(tok_lst):

@require_torch
class BartModelIntegrationTests(unittest.TestCase):
    @cached_property
    def default_tokenizer(self):
        return BartTokenizer.from_pretrained("facebook/bart-large")

    @slow
    def test_inference_no_head(self):
        model = BartModel.from_pretrained("facebook/bart-large").to(torch_device)
@@ -559,6 +564,76 @@ def test_cnn_summarization_same_as_fairseq(self):
        # TODO(SS): run fairseq again with num_beams=2, min_len=20.
        # TODO(SS): add test case that hits max_length

    def test_prepare_seq2seq_batch(self):
        tokenizer = self.default_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        expected_src_tokens = [0, 250, 251, 17818, 13, 32933, 21645, 1258, 4, 2]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors="pt"
        )
        self.assertIsInstance(batch, BatchEncoding)

        self.assertEqual((2, 10), batch.input_ids.shape)
        self.assertEqual((2, 10), batch.attention_mask.shape)
        result = batch.input_ids.tolist()[0]
        self.assertListEqual(expected_src_tokens, result)
        # Test that special tokens are reset

    def test_empty_target_text(self):
        tokenizer = self.default_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
        # check if input_ids are returned and no decoder_input_ids
        self.assertIn("input_ids", batch)
        self.assertIn("attention_mask", batch)
        self.assertNotIn("decoder_input_ids", batch)
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_target_length(self):
        tokenizer = self.default_tokenizer
        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
            "Another summary.",
        ]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_target_length=32, padding="max_length", return_tensors="pt"
        )
        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

        # test None max_target_length
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
        )
        self.assertEqual(32, batch["decoder_input_ids"].shape[1])
        self.assertEqual(32, batch["decoder_attention_mask"].shape[1])

    def test_outputs_not_longer_than_maxlen(self):
        tokenizer = self.default_tokenizer

        batch = tokenizer.prepare_seq2seq_batch(["I am a small frog" * 1024, "I am a small frog"], return_tensors="pt")
        self.assertIsInstance(batch, BatchEncoding)
        self.assertEqual(batch.input_ids.shape, (2, 1024))

    def test_special_tokens(self):
        tokenizer = self.default_tokenizer
        src_text = ["A long paragraph for summrization."]
        tgt_text = [
            "Summary of the text.",
        ]
        batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors="pt")
        input_ids = batch["input_ids"]
        decoder_input_ids = batch["decoder_input_ids"]
        self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
        self.assertTrue((decoder_input_ids[:, 0] == tokenizer.bos_token_id).all().item())
        self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
        self.assertTrue((decoder_input_ids[:, -1] == tokenizer.eos_token_id).all().item())
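
The special-token assertions above, together with the length checks in the earlier tests, depend on truncation leaving room for both BOS and EOS. A minimal sketch of that behaviour follows, assuming the content is cut to `max_length - 2` before the special tokens are attached. `wrap_and_truncate` is a hypothetical helper, and the default ids 0 and 2 are taken from the first and last entries of `expected_src_tokens` rather than from the tokenizer itself.

```python
from typing import List


def wrap_and_truncate(
    ids: List[int],
    max_length: int,
    bos_id: int = 0,  # assumed from expected_src_tokens[0]
    eos_id: int = 2,  # assumed from expected_src_tokens[-1]
) -> List[int]:
    """Truncate the content so that BOS + content + EOS fits in max_length."""
    if max_length < 2:
        raise ValueError("max_length must leave room for BOS and EOS")
    # Reserve two positions for the special tokens before cutting.
    body = ids[: max_length - 2]
    return [bos_id] + body + [eos_id]
```

Under this assumption a truncated row still starts with BOS and ends with EOS, which is the property `test_special_tokens` checks on unpadded single-example batches.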


@require_torch
class TestSinusoidalPositionalEmbeddings(unittest.TestCase):