
Better filtering of the model outputs in Trainer #8633

Merged: 3 commits merged into master on Nov 19, 2020

Conversation

sgugger (Collaborator) commented on Nov 18, 2020

What does this PR do?

As discovered since merging #8530, the new model outputs sometimes lose their type and become regular dictionaries (e.g. when using NVIDIA Apex with the O2 optimization level). This means we can't index into them with integers, so some rework of the Trainer internals was necessary.

This PR:

  • fixes training by grabbing the loss with string indexing when the outputs are a dict, and with integer indexing otherwise (sketched below)
  • fixes evaluation the same way when grabbing the loss

but it also takes advantage of the new dict outputs to better filter the outputs at inference. We had several issues recently when using models that output past states (such as Reformer, XLNet, GPT-2) during evaluation in Trainer. This PR introduces a new API: a key in the model config lists output attributes to ignore during evaluation. Those outputs are then discarded from the predictions returned by Trainer.predict and from the predictions passed along to metric computation in Trainer.evaluate. Since a user might want to ignore more keys, or keep some of those keys, a new argument is added to both Trainer.predict and Trainer.evaluate to fully control which keys are ignored.

If the model outputs a tuple, all of this is ignored.
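
A minimal sketch of the loss-grabbing logic described above (illustrative only, not the exact Trainer code; `model` and `inputs` are placeholders):

```python
# Sketch of the fix described above: grab the loss by key when the model
# returns a dict, by position when it returns a plain tuple (as under Apex O2).
outputs = model(**inputs)
if isinstance(outputs, dict):
    loss = outputs["loss"]
else:
    loss = outputs[0]
```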

Fixes #8523 among others
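
For illustration, here is a hedged usage example of the new controls (assuming the new argument is named ignore_keys and that `trainer` / `test_dataset` already exist; exact signatures may differ):

```python
# Config-level default: output keys dropped from gathered predictions at inference.
model.config.keys_to_ignore_at_inference = ["past_key_values"]

# Per-call override: ignore additional output keys for this evaluation only.
metrics = trainer.evaluate(ignore_keys=["past_key_values", "hidden_states"])

# Same control when gathering predictions.
predictions = trainer.predict(test_dataset, ignore_keys=["past_key_values"])
```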

LysandreJik (Member) left a comment


This is a very welcome change imo, and the implementation is clean. Thank you for implementing the last test, I think it's great.

sgugger merged commit 4208f49 into master on Nov 19, 2020
sgugger deleted the trainer_outputs branch on Nov 19, 2020 at 15:43
@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
     """

     model_type = "marian"
+    keys_to_ignore_at_inference = ["past_key_values"]
Contributor commented:

It's a bit late now, but I'm not a huge fan of the name to be honest -> this seems to be very specific to training, but one might think now that past_key_values can never be passed during inference in general. Why not call it keys_to_ignore_at_training?

sgugger (Collaborator, Author) replied:

No, this is not for training, only for inference; during training we only get the loss in the outputs.
And these keys are not ignored when passing inputs to the model; they are ignored because they are not part of the logits/scores/predictions we want to gather. Maybe output_keys_to_ignore_at_inference is clearer?
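
For illustration, a small sketch of the distinction (hypothetical code, not the actual Trainer internals):

```python
# The keys are still produced by the model; they are only dropped from the
# outputs gathered as predictions, never removed from the model inputs.
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
if isinstance(outputs, dict):
    logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
    logits = outputs[1:]  # tuple outputs are left untouched
```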

Contributor replied:

I see! Yeah I think output_keys_to_ignore_at_inference would be a bit clearer to me :-)

sgugger (Collaborator, Author) replied:

See #8857

Merging this pull request closes: Reformer model crashes during causal LM evaluation