[T5Tokenizer] add prepare_seq2seq_batch method #6122
Conversation
src/transformers/tokenization_t5.py (Outdated)

    def set_tgt_special_tokens(self) -> None:
        self.prefix_tokens = [self.pad_token_id]
        self.suffix_tokens = [self.eos_token_id]
Not entirely sure about adding eos automatically. What do you think @sshleifer?
- I wouldn't add eos in this PR. I think for that we need to either a) get to the bottom of why it impacts zero-shot translation performance, or b) add a flag to support not adding it (for backward compatibility / zero-shot tasks).
- Do we have evidence that adding a prefix token on the decoder side is helpful?
> Do we have evidence that adding a prefix token on the decoder side is helpful?

Yes, T5Model does this in its _shift_right method, and the original TF T5 implementation does the same. AFAIK, seq2seq decoders use a special start token: in BART the tokenizer automatically adds bos, while T5 has no bos and instead uses the pad token as the decoder start id.
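For reference, a minimal sketch of that shift-right behavior (this mirrors the idea only, not the library's exact code; the function name here is illustrative):

```python
import torch

def shift_right(labels: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # Shift the target ids one position to the right and put the decoder start
    # token (the pad token id in T5's case) in the first position.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    return shifted
```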
> get to the bottom of why it impacts zero shot translation performance

I will remove it for now and wait for this issue to be solved.
        ]
        expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
        batch = tokenizer.prepare_seq2seq_batch(
            src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
More cases to test (a rough sketch follows below):
- max_target_length kwarg: allow it to be passed through and affect decoder_input_ids.shape[1]
- empty tgt_texts
- empty src_texts -> raises something
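A rough sketch of what those tests might look like (assumptions: the kwarg names follow the PR's signature, return_tensors="pt" is used, and ValueError for empty src_texts is a guess at the intended behavior):

```python
import pytest
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
src_text = ["A long paragraph for summarization."]
tgt_text = ["Summary of the paragraph."]

# max_target_length should bound the target side independently of max_length.
batch = tokenizer.prepare_seq2seq_batch(
    src_text, tgt_texts=tgt_text, max_length=32, max_target_length=8, return_tensors="pt"
)
assert batch["decoder_input_ids"].shape[1] <= 8

# Without tgt_texts, only encoder-side tensors should be returned.
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
assert "input_ids" in batch and "decoder_input_ids" not in batch

# Empty src_texts should raise.
with pytest.raises(ValueError):
    tokenizer.prepare_seq2seq_batch([], tgt_texts=tgt_text, return_tensors="pt")
```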
Thanks, I will cover these cases.
> empty tgt_texts

For this, can I just check that input_ids and attention_mask are returned and no decoder_input_ids and decoder_attention_mask?
these tests look great now!
one nit, otherwise LGTM
src/transformers/tokenization_t5.py (Outdated)

        for k, v in decoder_inputs.items():
            model_inputs[f"decoder_{k}"] = v

        self.set_src_special_tokens()
(nit) Stylistically, I would just write self.prefix_tokens = [] and self.prefix_tokens = [self.pad_token_id] inline, to avoid adding a layer of abstraction.
Same, unless you expect people to have to subclass your work to inject some custom behavior.
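Concretely, the suggestion amounts to something like this (a rough sketch of the idea only, not the PR's exact code; argument handling is simplified):

```python
def prepare_seq2seq_batch(self, src_texts, tgt_texts=None, **kwargs):
    # Source side: no prefix token for T5 encoder inputs.
    self.prefix_tokens = []
    model_inputs = self(src_texts, **kwargs)
    if tgt_texts is None:
        return model_inputs
    # Target side: T5 uses the pad token as the decoder start token.
    self.prefix_tokens = [self.pad_token_id]
    decoder_inputs = self(tgt_texts, **kwargs)
    for k, v in decoder_inputs.items():
        model_inputs[f"decoder_{k}"] = v
    # Restore the source-side setting afterwards.
    self.prefix_tokens = []
    return model_inputs
```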
tests/test_tokenization_t5.py (Outdated)

        src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
        batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
        # check if input_ids are returned and no decoder_input_ids
        self.assertIn("input_ids", batch.keys())
(nit) don't think you need .keys()
Ah, right, in works on dict keys by default. Thanks 😀
        self.assertIsInstance(batch, BatchEncoding)
        self.assertEqual(batch.input_ids.shape, (2, 512))

    def test_eos_in_input(self):
would be cool to migrate one or more of the integration tests in test_modeling_t5.py to the new method.
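For example, something along these lines (a sketch only; the checkpoint, input text, and generation call are illustrative, not the existing test's exact content):

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Use the new method instead of manually tokenizing and padding the inputs.
batch = tokenizer.prepare_seq2seq_batch(
    ["translate English to German: The house is wonderful."], return_tensors="pt"
)
generated = model.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```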
Very nice, thanks! I have some nits on the docs.
src/transformers/tokenization_t5.py (Outdated)

        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens. The special tokens depend on calling source text or target text.
        An T5 sequence has the following format, where ``X`` represents the sequence:
Suggested change:
-        An T5 sequence has the following format, where ``X`` represents the sequence:
+        A T5 sequence has the following format, where ``X`` represents the sequence:
src/transformers/tokenization_t5.py (Outdated)

        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added
            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
Suggested change:
-            token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+            token_ids_1 (:obj:`List[int]`, `optional`):

(We only indicate real default values. If something is optional, the None default is implied.)
src/transformers/tokenization_t5.py
Outdated
Optional second list of IDs for sequence pairs. | ||
|
||
Returns: | ||
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. |
Suggested change:
-            :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
src/transformers/tokenization_t5.py (Outdated)

        **kwargs,
    ) -> BatchEncoding:
        """Prepare a batch that can be passed directly to an instance of T5Model.
        Arguments:
Please specify the argument types with the same style as above, and make sure you document all arguments (return_tensors is not documented).
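That is, something in this style (a sketch only; the signature and the field descriptions here are illustrative, not the PR's actual docstring):

```python
def prepare_seq2seq_batch(self, src_texts, tgt_texts=None, max_length=None, return_tensors=None, **kwargs):
    """Prepare a batch that can be passed directly to an instance of T5Model.

    Arguments:
        src_texts (:obj:`List[str]`):
            List of source texts to encode.
        tgt_texts (:obj:`List[str]`, `optional`):
            List of target texts to encode.
        max_length (:obj:`int`, `optional`):
            Maximum length for the source sequences.
        return_tensors (:obj:`str`, `optional`):
            Type of tensors to return ('pt', 'tf' or 'np').
        **kwargs:
            Passed to self.__call__.
    """
```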
src/transformers/tokenization_t5.py (Outdated)

            **kwargs: passed to self.__call__

        Returns:
            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
Suggested change:
-            :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
+            :class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
@sshleifer, @sgugger I have made the changes based on the suggestions. Thanks!
LGTM
        self.assertNotIn("decoder_attention_mask", batch)

    def test_max_target_length(self):
        tokenizer = T5Tokenizer.from_pretrained("t5-small")
Tip: to only initialize once, you can use

    @cached_property
    def default_tok(self):
        return T5Tokenizer.from_pretrained("t5-small")

This barely matters for tokenizers; it is more useful for models, where __init__ can take 20 seconds.
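Inside a test class, that tip would look roughly like this (a sketch; the class and test names are made up, and functools.cached_property stands in for whichever cached_property helper the repo uses):

```python
import unittest
from functools import cached_property

from transformers import T5Tokenizer


class T5TokenizationIntegrationTest(unittest.TestCase):
    @cached_property
    def default_tok(self):
        # Built lazily on first access and cached on the instance afterwards.
        return T5Tokenizer.from_pretrained("t5-small")

    def test_prepare_seq2seq_batch_keys(self):
        batch = self.default_tok.prepare_seq2seq_batch(["A test sentence."], return_tensors="pt")
        self.assertIn("input_ids", batch)
```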
Codecov Report
@@            Coverage Diff             @@
##           master    #6122      +/-   ##
==========================================
+ Coverage   78.51%   78.59%   +0.08%
==========================================
  Files         146      146
  Lines       26326    26347      +21
==========================================
+ Hits        20669    20708      +39
+ Misses       5657     5639      -18
Continue to review full report at Codecov.
@sshleifer, @patrickvonplaten, all green :)
)" This reverts commit 3dfafe6.
This PR adds the prepare_seq2seq_batch method to T5Tokenizer, as per the proposal in #6080. cc @sshleifer
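As a quick illustration of the new API (a sketch; the checkpoint and example texts are arbitrary, and the key set matches the docstring discussed above):

```python
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
batch = tokenizer.prepare_seq2seq_batch(
    ["translate English to German: A long paragraph."],
    tgt_texts=["Ein langer Absatz."],
    return_tensors="pt",
)
# Expected keys: input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
print(sorted(batch.keys()))
```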