Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BartTokenizer] add prepare s2s batch #6212

Merged
merged 11 commits into from
Aug 17, 2020
52 changes: 38 additions & 14 deletions src/transformers/tokenization_bart.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,40 +53,64 @@ def prepare_seq2seq_batch(
max_target_length: Optional[int] = None,
padding: str = "longest",
return_tensors: str = "None",
truncation=True,
**kwargs,
) -> BatchEncoding:
r"""
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.

Args:
src_texts (:obj:`List[str]`):
List of input texts.
tgt_texts (:obj:`List[str]`, `optional`):
List of target texts.
src_texts: (:obj:`list`):
list of documents to summarize or source language texts
tgt_texts: (:obj:`list`, `optional`):
list of tgt language texts or summaries.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotations here were better before. The docstrings should not have abbreviations (and start with a capital and end with a full stop nit).

Copy link
Contributor Author

@patil-suraj patil-suraj Aug 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aah, I blindly copy pasted, will make the changes. Also can you tell me where the doc error is coming from ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're missing new lines before your lists I'd say.

max_length (:obj:`int`, `optional`):
Maximum length for the source texts. If not provided, this will use the predefined model maximum length.
Controls the maximum length for encoder inputs (documents to summarize or source language texts)
If left unset or set to :obj:`None`, this will use the predefined model maximum length if a maximum
length is required by one of the truncation/padding parameters. If the model has no specific maximum
input length (like XLNet) truncation/padding to a maximum length will be deactivated.
max_target_length (:obj:`int`, `optional`):
Maximum length for the target texts. If not provided, this will use the predefined model maximum length.
Controls the maximum length of decoder inputs (target language texts or summaries)
If left unset or set to :obj:`None`, this will use the max_length value.
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding
index) among:

Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
single sequence if provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`, defaults to "pt"):
If set, will return tensors instead of list of python integers. Acceptable values are:

* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
truncation (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.TruncationStrategy`, `optional`, defaults to :obj:`True`):
Activates and controls truncation. Accepts the following values:
* :obj:`True` or :obj:`'longest_first'`: Truncate to a maximum length specified with the argument
:obj:`max_length` or to the maximum acceptable input length for the model if that argument is not
provided. This will truncate token by token, removing a token from the longest sequence in the pair
if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_first'`: Truncate to a maximum length specified with the argument :obj:`max_length` or to
the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`'only_second'`: Truncate to a maximum length specified with the argument :obj:`max_length` or
to the maximum acceptable input length for the model if that argument is not provided. This will only
truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
* :obj:`False` or :obj:`'do_not_truncate'` (default): No truncation (i.e., can output batch with
sequence lengths greater than the model maximum admissible input size).
**kwargs:
Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a new docstring on master/ tokenization_utils_base.py that you may want to (a) reuse or (b) modify.

:class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
:class:`~transformers.BatchEncoding`: A :class:`~transformers.BatchEncoding` with the following fields:
- **input_ids** -- List of token ids to be fed to the encoder.
- **attention_mask** -- List of indices specifying which tokens should be attended to by the model.
- **decoder_input_ids** -- List of token ids to be fed to the decoder.
- **decoder_attention_mask** -- List of indices specifying which tokens should be attended to by the decoder.
This does not include causal mask, which is built by the model.

The full set of keys ``[input_ids, attention_mask, decoder_input_ids, decoder_attention_mask]``,
will only be returned if tgt_texts is passed. Otherwise, input_ids, attention_mask will be the only keys.
"""
if max_length is None:
max_length = self.model_max_length
Expand All @@ -96,7 +120,7 @@ def prepare_seq2seq_batch(
return_tensors=return_tensors,
max_length=max_length,
padding=padding,
truncation=True,
truncation=truncation,
**kwargs,
)
if tgt_texts is None:
Expand All @@ -110,7 +134,7 @@ def prepare_seq2seq_batch(
return_tensors=return_tensors,
padding=padding,
max_length=max_target_length,
truncation=True,
truncation=truncation,
**kwargs,
)
for k, v in decoder_inputs.items():
Expand Down