
reflect max_new_tokens in Seq2SeqTrainer #18786

Merged (5 commits) into huggingface:main on Sep 1, 2022
Conversation

@kumapo (Contributor) commented Aug 27, 2022

What does this PR do?

In most cases, VisionEncoderDecoderModel's max_length is set implicitly, which causes a problem when the model generates predictions with max_new_tokens.

This PR makes Seq2SeqTrainer.prediction_step() handle max_new_tokens as expected in that case.

Fixes #18785

P.S. I can reproduce the issue on huggingface/transformers, but with this PR applied, the same reproduction code raises no exceptions.
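For reference, a minimal sketch of the failing call (the names are illustrative and assume a configured Seq2SeqTrainer wrapping a VisionEncoderDecoderModel):

# Sketch of the reproduction (illustrative names):
predictions = trainer.predict(
    test_dataset,        # hypothetical evaluation dataset
    max_new_tokens=32,   # forwarded to model.generate() via gen_kwargs
)
# Before this PR, prediction_step() still fell back to model.config.max_length,
# so generate() could raise a ValueError; with it, max_new_tokens is respected.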

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@HuggingFaceDocBuilderDev commented Aug 27, 2022

The documentation is not available anymore as the PR was closed or merged.

@LysandreJik (Member) commented:

@ydshieh, could you take a look at this when you have some time please? Thanks a lot!

@ydshieh self-assigned this on Aug 30, 2022
)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
    gen_kwargs["max_length"] = self.model.config.max_length
prompt_seq_length = 1 if self.model.config.is_encoder_decoder else 0
Collaborator:

It seems to me Seq2SeqTrainer will only be used for encoder-decoder models. If this is true, we shouldn't need the else 0.

(If a decoder-only model is possible, we will have to get the actual length of the prompt from the inputs, instead of just setting it to 0.)
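A minimal sketch of that lookup (hypothetical, not part of this PR; assumes the batch dict inputs carries input_ids):

if self.model.config.is_encoder_decoder:
    prompt_seq_length = 1  # only the decoder start token precedes the new tokens
else:
    prompt_seq_length = inputs["input_ids"].shape[-1]  # generate() echoes the full prompt back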

Collaborator:

I would like to hear from @patrickvonplaten, @patil-suraj and @sgugger on this.

Contributor:

+1. I think we can be confident that the Seq2SeqTrainer only works for models that have self.model.config.is_encoder_decoder = True.

@ydshieh (Collaborator) commented Aug 30, 2022

@kumapo changed the title from "reflect max_new_tokens in gen_kwargs to trainer.generate()" to "reflect max_new_tokens in Seq2SeqTrainer" on Aug 31, 2022
@kumapo (Contributor, Author) commented Aug 31, 2022

@ydshieh, yes. At the same time, I believe Seq2SeqTrainer.evaluate() needs the same change; see the sketch below.
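Once evaluate() forwards gen_kwargs the same way, a call like this should work (a sketch; the dataset name and kwargs are illustrative):

metrics = trainer.evaluate(
    eval_dataset,
    max_new_tokens=32,  # instead of relying on model.config.max_length
    num_beams=4,
)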

@sgugger (Collaborator) left a comment:

Thanks for this PR! Left a comment and then we should be good to merge!

)
if gen_kwargs.get("max_length") is None and gen_kwargs.get("max_new_tokens") is None:
    gen_kwargs["max_length"] = self.model.config.max_length
prompt_seq_length = 1
Collaborator:

It's never modified, so let's use 1 below instead of adding a new variable?
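Roughly, the downstream padding logic then reads (a sketch, not the exact merged diff):

# Pad generated tokens up to the requested length; the +1 accounts for the
# decoder start token that generate() prepends for encoder-decoder models.
if gen_kwargs.get("max_length") is not None and generated_tokens.shape[-1] < gen_kwargs["max_length"]:
    generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
elif gen_kwargs.get("max_new_tokens") is not None and generated_tokens.shape[-1] < gen_kwargs["max_new_tokens"] + 1:
    generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_new_tokens"] + 1)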

@kumapo (Contributor, Author) commented Aug 31, 2022

@sgugger, thank you for your feedback. I've updated the PR.

@ydshieh (Collaborator) left a comment:

Thank you @kumapo for making the Seq2Seq trainer more robust!

@LysandreJik (Member) commented:

It seems there is an issue with your CircleCI permissions; the tests won't run.
Could you try refreshing your permissions as shown here?

@sgugger (Collaborator) left a comment:

Good for me barring the tests!

@kumapo (Contributor, Author) commented Sep 1, 2022

@LysandreJik, I've done all the steps to refresh the CircleCI permissions, but it seems nothing happens with the tests. Let me know if I missed something.

@sgugger (Collaborator) commented Sep 1, 2022

Can you try pushing an empty commit on your branch to re-trigger the tests?

git commit -m "Trigger CI" --allow-empty

@ydshieh (Collaborator) commented Sep 1, 2022

To pass the test, you can run

make style

and commit the change.

@sgugger merged commit ab663b2 into huggingface:main on Sep 1, 2022
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
* reflect max_new_tokens in gen_kwargs to `trainer.generate()`

* reflect max_new_tokens in `Seq2SeqTrainer`

* remove unnecessary variable

* Trigger CI

* fix style

Successfully merging this pull request may close these issues.

Raise ValueError if given max_new_tokens to Seq2SeqTrainer.predict()
6 participants