[BartTokenizer] add prepare s2s batch #6212
Conversation
LGTM, nice tests!
return_tensors: str = "None",
**kwargs,
) -> BatchEncoding:
    """Prepare a batch that can be passed directly to an instance of BartModel.
"""Prepare a batch that can be passed directly to an instance of BartModel. | |
""" | |
Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`. |
(nit)
Thanks for the PR!
Lots of nits on the docs: in general, if the argument you are documenting is passed along to another method, don't hesitate to copy-paste the docstring from that method. And when documenting an argument, don't use abbreviations and write full sentences :-)
    maximum length for the source text which defers to the config value of 1024 for facebook/bart*
max_target_length (:obj:`int`, `optional`):
    maximum length for the target text which defers to the config value of 1024 for facebook/bart*
padding (:obj:`str`, `optional`, defaults to "longest"):
This can be a bool, a string or a PaddingStrategy, I believe? See the documentation of PreTrainedTokenizerBase.__call__:
padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`False`):
Activates and controls padding. Accepts the following values:
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a
  single sequence is provided).
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
maximum acceptable input length for the model if that argument is not provided.
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
different lengths).
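A minimal sketch of the three padding modes described above, assuming the transformers library and the facebook/bart-large checkpoint are available; the variable names are illustrative only.

```python
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
texts = ["A short sentence.", "A considerably longer sentence that the short one gets padded against."]

# True / "longest": pad every sequence to the longest one in the batch.
batch_longest = tokenizer(texts, padding=True)

# "max_length": pad every sequence to the length given by max_length.
batch_fixed = tokenizer(texts, padding="max_length", max_length=32)

# False / "do_not_pad": leave sequences at their natural lengths.
batch_unpadded = tokenizer(texts, padding=False)

print(len(batch_longest["input_ids"][0]), len(batch_fixed["input_ids"][0]), len(batch_unpadded["input_ids"][0]))
```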
    maximum length for the target text which defers to the config value of 1024 for facebook/bart*
padding (:obj:`str`, `optional`, defaults to "longest"):
    strategy for padding `input_ids` and `decoder_input_ids`. Should be "max_length" or "longest".
return_tensors (:obj:`str`, `optional`):
This can be a string or a TensorType (same as above, just copy from the documentation of PreTrainedTokenizerBase.__call__):
return_tensors (:obj:`str` or :class:`~transformers.tokenization_utils_base.TensorType`, `optional`):
If set, will return tensors instead of list of python integers. Acceptable values are:
* :obj:`'tf'`: Return TensorFlow :obj:`tf.constant` objects.
* :obj:`'pt'`: Return PyTorch :obj:`torch.Tensor` objects.
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
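As a rough illustration of the return_tensors values listed above (a sketch, assuming PyTorch and NumPy are installed alongside transformers):

```python
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
texts = ["Hello world.", "Another example sentence."]

# Plain Python lists by default, framework tensors when return_tensors is set.
batch_pt = tokenizer(texts, padding=True, return_tensors="pt")  # torch.Tensor values
batch_np = tokenizer(texts, padding=True, return_tensors="np")  # np.ndarray values

print(type(batch_pt["input_ids"]), type(batch_np["input_ids"]))
```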
return_tensors (:obj:`str`, `optional`):
    Can be set to 'tf', 'pt' or 'np' to return respectively TensorFlow `tf.constant`, PyTorch `torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
**kwargs:
    passed to self.__call__
Suggested change:
-    passed to self.__call__
+    Additional keyword arguments passed along to :obj:`self.__call__`.
"""Prepare a batch that can be passed directly to an instance of BartModel. | ||
Args: | ||
src_texts (:obj:`List[str]`): | ||
list of src texts |
Suggested change:
-    list of src texts
+    List of input texts.
src_texts (:obj:`List[str]`):
    list of src texts
tgt_texts (:obj:`List[str]`, `optional`):
    list of tgt texts
Suggested change:
-    list of tgt texts
+    List of target texts.
tgt_texts (:obj:`List[str]`, `optional`):
    list of tgt texts
max_length (:obj:`int`, `optional`):
    maximum length for the source text which defers to the config value of 1024 for facebook/bart*
Suggested change:
-    maximum length for the source text which defers to the config value of 1024 for facebook/bart*
+    Maximum length for the source texts. If not provided, this will use the predefined model maximum length.
Don't mention a specific model here since several could be used.
max_length (:obj:`int`, `optional`):
    maximum length for the source text which defers to the config value of 1024 for facebook/bart*
max_target_length (:obj:`int`, `optional`):
    maximum length for the target text which defers to the config value of 1024 for facebook/bart*
Suggested change:
-    maximum length for the target text which defers to the config value of 1024 for facebook/bart*
+    Maximum length for the target texts. If not provided, this will use the predefined model maximum length.
Thanks @sgugger for these helpful suggestions! Will keep these in mind for future PRs.
@sgugger, can you help me with the build_doc failure? Thanks!
Fixed, you needed to have the beginning of the docstrings on a new line for Sphinx to understand the indentation.
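A small before/after sketch of that fix (hypothetical function names, trimmed docstrings):

```python
def prepare_batch_before():
    """Prepare a batch that can be passed directly to an instance of BartModel.
    Args:
        src_texts (:obj:`List[str]`): List of input texts.
    """


def prepare_batch_after():
    """
    Prepare a batch that can be passed directly to an instance of :class:`~transformers.BartModel`.

    Args:
        src_texts (:obj:`List[str]`): List of input texts.
    """
```

In the first form the text starts on the same line as the opening quotes, so Sphinx misreads the indentation of the following lines; in the second form the text starts on a new line and the indentation is unambiguous.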
Codecov Report
@@            Coverage Diff             @@
##           master    #6212      +/-   ##
==========================================
+ Coverage   78.29%   78.35%   +0.05%
==========================================
  Files         146      146
  Lines       26607    26619      +12
==========================================
+ Hits        20832    20856      +24
+ Misses       5775     5763      -12

Continue to review full report at Codecov.
Thanks @sgugger!
This looks useful! It would be nice to upstream it so that other sequence-to-sequence models may make use of it. Also, you added it to BartTokenizer and not the fast tokenizer, is there a reason for that?
If we consider this to be a conversion to an s2s task, I think this would be better suited in an s2s processor like we have for squad_convert_examples_to_features or glue_convert... . I don't see any reason for having it linked to BART especially.
Pinging @thomwolf
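A hypothetical sketch of what such a tokenizer-agnostic s2s processor could look like; the function name, the argument names and the decoder_input_ids key are assumptions for illustration, not an existing transformers API.

```python
from typing import List, Optional

from transformers import BatchEncoding, PreTrainedTokenizer


def seq2seq_convert_examples_to_features(
    tokenizer: PreTrainedTokenizer,
    src_texts: List[str],
    tgt_texts: Optional[List[str]] = None,
    max_length: Optional[int] = None,
    max_target_length: Optional[int] = None,
    return_tensors: str = "pt",
) -> BatchEncoding:
    # Encode the source side with whatever tokenizer is passed in.
    batch = tokenizer(
        src_texts,
        max_length=max_length,
        padding="longest",
        truncation=True,
        return_tensors=return_tensors,
    )
    if tgt_texts is not None:
        # Encode the target side the same way and attach it to the batch.
        targets = tokenizer(
            tgt_texts,
            max_length=max_target_length,
            padding="longest",
            truncation=True,
            return_tensors=return_tensors,
        )
        batch["decoder_input_ids"] = targets["input_ids"]
    return batch
```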
Hi @LysandreJik,
No, I just forgot to add that. Upstreaming will be useful, but we will need to handle a few cases differently for each seq2seq model, i.e. in the case of T5 we manually need to add the decoder_start_token_id, as T5 doesn't have a …
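A rough sketch of that T5-specific wrinkle (not part of this PR): T5-style models expect decoder inputs that start with decoder_start_token_id, so the target ids have to be shifted right and the start token prepended. The helper name and the token ids below are made up for illustration, and PyTorch is assumed.

```python
import torch


def prepend_decoder_start_token(target_ids: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    """Shift target ids one position to the right and put the start token first."""
    decoder_input_ids = target_ids.new_zeros(target_ids.shape)
    decoder_input_ids[:, 1:] = target_ids[:, :-1].clone()
    decoder_input_ids[:, 0] = decoder_start_token_id
    return decoder_input_ids


# Example: a batch of two target sequences and a (made-up) start token id of 0.
targets = torch.tensor([[250, 251, 252], [260, 261, 262]])
print(prepend_decoder_start_token(targets, decoder_start_token_id=0))
```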
Hi @sshleifer, @LysandreJik, any update?
* :obj:`'np'`: Return Numpy :obj:`np.ndarray` objects.
**kwargs:
    Additional keyword arguments passed along to :obj:`self.__call__`.
Returns:
There is a new docstring on master in tokenization_utils_base.py that you may want to (a) reuse or (b) modify.
@sshleifer, updated the docs.
src_texts: (:obj:`list`):
    list of documents to summarize or source language texts
tgt_texts: (:obj:`list`, `optional`):
    list of tgt language texts or summaries.
The type annotations here were better before. The docstrings should not have abbreviations (nit: they should also start with a capital letter and end with a full stop).
Aah, I blindly copy-pasted, will make the changes. Also, can you tell me where the doc error is coming from?
You're missing newlines before your lists, I'd say.
LGTM as-is (after the doc building test fixes), but we really should add the same method on the Fast tokenizer. Having parity on both tokenizers is one of our goals.
@LysandreJik I will add this for the fast tokenizer too once this PR is merged.
Sounds good!
@LysandreJik, the doc error is fixed; not sure if the current failure is related to this PR.
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
This reverts commit c0b35c9.
This PR adds the prepare_seq2seq_batch method to BartTokenizer, as per the proposal in #6080.
@sshleifer
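For reference, a usage sketch of the new method based on the signature shown in the review diff above; it assumes a transformers version that includes this PR, and the exact keys of the returned BatchEncoding are not spelled out here.

```python
from transformers import BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")

batch = tokenizer.prepare_seq2seq_batch(
    src_texts=["A long news article to summarize."],
    tgt_texts=["A short summary."],
    max_length=1024,
    max_target_length=56,
    padding="longest",
    return_tensors="pt",
)

# The resulting BatchEncoding is meant to be passed directly to BartModel
# or BartForConditionalGeneration.
print(batch.keys())
```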