[BartTokenizer] add prepare s2s batch #6212
@@ -45,6 +45,65 @@ class BartTokenizer(RobertaTokenizer):
        "merges_file": {m: merges_url for m in _all_bart_models},
    }

    def prepare_seq2seq_batch(
        self,
        src_texts: List[str],
        tgt_texts: Optional[List[str]] = None,
        max_length: Optional[int] = None,
        max_target_length: Optional[int] = None,
        padding: str = "longest",
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> BatchEncoding:
        """Prepare a batch that can be passed directly to an instance of BartModel.
        Args:
            src_texts (:obj:`List[str]`):
                list of source texts
            tgt_texts (:obj:`List[str]`, `optional`):
                list of target texts
            max_length (:obj:`int`, `optional`):
                maximum length for the source text, which defaults to the config value of 1024 for facebook/bart*
Review comment: Don't mention a specific model here since several could be used.
            max_target_length (:obj:`int`, `optional`):
                maximum length for the target text, which defaults to the config value of 1024 for facebook/bart*
            padding (:obj:`str`, `optional`, defaults to "longest"):
Review comment: This can be bool, string or PaddingStrategy I believe? See documentation of
                strategy for padding `input_ids` and `decoder_input_ids`. Should be "max_length" or "longest".
            return_tensors (:obj:`str`, `optional`):
Review comment: This can be string or TensorType (same as above, just copy from
                Can be set to "tf", "pt" or "np" to return respectively TensorFlow :obj:`tf.constant`, PyTorch :obj:`torch.Tensor` or Numpy :obj:`np.ndarray` instead of a list of python integers.
            **kwargs:
                passed to self.__call__
        Returns:
Review comment: There is a new docstring on master in tokenization_utils_base.py that you may want to (a) reuse or (b) modify.
            :class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
        """
        if max_length is None:
            max_length = self.model_max_length
        model_inputs: BatchEncoding = self(
            src_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            max_length=max_length,
            padding=padding,
            truncation=True,
            **kwargs,
        )
        if tgt_texts is None:
            return model_inputs
        # Process tgt_texts
        if max_target_length is None:
            max_target_length = max_length
        decoder_inputs: BatchEncoding = self(
            tgt_texts,
            add_special_tokens=True,
            return_tensors=return_tensors,
            padding=padding,
            max_length=max_target_length,
            truncation=True,
            **kwargs,
        )
        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        return model_inputs

class BartTokenizerFast(RobertaTokenizerFast):
    # merges and vocab same as Roberta
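For context when reviewing, here is a minimal usage sketch of the method this PR adds. It is not part of the diff: the facebook/bart-large checkpoint, the example texts, and the use of BartModel for the forward pass are assumptions made for illustration; the keyword arguments simply mirror the signature above.

# Usage sketch only (not part of this PR); assumes PyTorch and the
# facebook/bart-large checkpoint are available.
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartModel.from_pretrained("facebook/bart-large")

src_texts = ["A long document that we want to summarize."]
tgt_texts = ["A short summary."]

# Returns a BatchEncoding whose keys are input_ids, attention_mask,
# decoder_input_ids and decoder_attention_mask, here as PyTorch tensors.
batch = tokenizer.prepare_seq2seq_batch(
    src_texts,
    tgt_texts=tgt_texts,
    max_length=1024,
    padding="longest",
    return_tensors="pt",
)

outputs = model(**batch)  # BartModel.forward accepts all four keys

Because padding and return_tensors are forwarded unchanged to self.__call__, the bool/PaddingStrategy and str/TensorType forms mentioned in the review comments should also work here.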