reflect max_new_tokens in Seq2SeqTrainer
#18786
Conversation
@ydshieh, could you take a look at this when you have some time please? Thanks a lot!
src/transformers/trainer_seq2seq.py (outdated)

```python
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
    gen_kwargs["max_length"] = self.model.config.max_length
prompt_seq_length = 1 if self.model.config.is_encoder_decoder else 0
```
It seems to me `Seq2SeqTrainer` will only be used for encoder-decoder models. If this is true, we shouldn't need `else 0`.

(If a decoder-only model is possible, we will have to get the actual length of the prompt from the inputs, instead of just setting it to `0`.)
I would like to hear from @patrickvonplaten, @patil-suraj and @sgugger on this.
+1. I think we can be confident that the Seq2Seq trainer only works for models that have `self.model.config.is_encoder_decoder = True`.
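The fallback discussed above can be sketched in isolation (the standalone function and its name are illustrative, not the trainer's actual code): `max_length` is defaulted from the model config only when the caller supplied neither `max_length` nor `max_new_tokens`, so a user-provided `max_new_tokens` is never silently overridden.

```python
def resolve_generation_kwargs(gen_kwargs, config_max_length):
    """Return a copy of gen_kwargs with max_length defaulted only if needed."""
    gen_kwargs = dict(gen_kwargs)  # avoid mutating the caller's dict
    if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
        gen_kwargs["max_length"] = config_max_length
    return gen_kwargs

# max_new_tokens set -> the config max_length must NOT be injected
print(resolve_generation_kwargs({"max_new_tokens": 32}, 20))  # {'max_new_tokens': 32}
# neither set -> fall back to the config value
print(resolve_generation_kwargs({}, 20))                      # {'max_length': 20}
```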
Hi, @kumapo I believe it also requires a change in `trainer.generate()`, right?

@ydshieh, yes. At the same time I believe `Seq2SeqTrainer` does as well.
Thanks for this PR! Left a comment and then we should be good to merge!
src/transformers/trainer_seq2seq.py (outdated)

```python
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
    gen_kwargs["max_length"] = self.model.config.max_length
prompt_seq_length = 1
```
It's never modified, so let's use 1 below instead of adding a new variable?
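To make the role of that literal concrete, here is a minimal sketch (the helper name and list-based padding are illustrative; the trainer works on tensors): for an encoder-decoder model, `generate()` prepends one decoder start token, so the expected output length is `max_new_tokens + 1`, which is where the constant `1` comes from.

```python
def pad_to_length(token_ids, target_length, pad_token_id):
    """Right-pad a list of token ids up to target_length."""
    return token_ids + [pad_token_id] * (target_length - len(token_ids))

max_new_tokens = 4
ids = [0, 5, 7]  # decoder start token followed by 2 generated tokens
# expected length = max_new_tokens + 1 (the prepended decoder start token)
print(pad_to_length(ids, max_new_tokens + 1, pad_token_id=1))  # [0, 5, 7, 1, 1]
```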
@sgugger, thank you for your feedback. I've updated the PR.
Thank you @kumapo for making the Seq2Seq trainer more robust!
It seems there is an issue with your CircleCI permissions; the tests won't run.
Good for me barring the tests!
@LysandreJik, I've done all the steps to refresh the CircleCI permissions.
Can you try pushing an empty commit on your branch to re-trigger the tests?

To pass the test, you can run `make style` and commit the change.
* reflect max_new_tokens in gen_kwargs to `trainer.generate()`
* reflect max_new_tokens in `Seq2SeqTrainer`
* remove unnecessary variable
* Trigger CI
* fix style
What does this PR do?
In most cases, `VisionEncoderDecoderModel`'s `max_length` is set implicitly. This leads to a problem when the model generates predictions given `max_new_tokens`. This PR makes `max_new_tokens` handled as expected in `Seq2SeqTrainer.prediction_step()` in that case.

Fixes #18785
P.S. I can reproduce the issue using `huggingface/transformers`; but running the same reproduction code with this PR applied, no exceptions are raised.
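The length bookkeeping this PR targets can be summarized with a small sketch (illustrative names, not the library code): when `max_new_tokens` is supplied, the expected prediction tensor length must be derived from it, rather than from `config.max_length`, whose implicit default caused the mismatch described above.

```python
def expected_length(gen_kwargs, config_max_length):
    """Length the padded prediction tensor should have (encoder-decoder case)."""
    if gen_kwargs.get("max_new_tokens") is not None:
        return gen_kwargs["max_new_tokens"] + 1  # + 1 for the decoder start token
    if gen_kwargs.get("max_length") is not None:
        return gen_kwargs["max_length"]
    return config_max_length  # the implicit default that caused the issue

print(expected_length({"max_new_tokens": 16}, 20))  # 17
print(expected_length({}, 20))                      # 20
```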