Better filtering of the model outputs in Trainer #8633
Conversation
This is a very welcome change imo, and the implementation is clean. Thank you for implementing the last test, I think it's great.
```diff
@@ -97,3 +97,4 @@ class MarianConfig(BartConfig):
     """

     model_type = "marian"
+    keys_to_ignore_at_inference = ["past_key_values"]
```
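To illustrate the idea behind this config attribute, here is a minimal sketch (not the actual `Trainer` code; the helper name `filter_outputs` is made up for this example) of how a key list like `keys_to_ignore_at_inference` could be used to drop entries such as `past_key_values` from a dict of model outputs before gathering predictions:

```python
def filter_outputs(outputs: dict, keys_to_ignore=("past_key_values",)) -> dict:
    """Keep only the output entries that should be gathered as predictions.

    `keys_to_ignore` plays the role of the config attribute added in this
    diff; entries like cached past states are dropped before gathering.
    """
    return {k: v for k, v in outputs.items() if k not in keys_to_ignore}


# Toy model outputs: logits we want, plus a cache entry we don't.
outputs = {"logits": [0.1, 0.9], "past_key_values": [("k", "v")]}
print(filter_outputs(outputs))  # {'logits': [0.1, 0.9]}
```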
It's a bit late now, but I'm not a huge fan of the name, to be honest: this seems to be very specific to training, but one might now think that `past_key_values` can never be passed during inference in general. Why not call it `keys_to_ignore_at_training`?
No, this is not for training, only for inference. During training we only get the loss in the outputs. And this is not about ignoring keys passed to the model, but about ignoring them because they are not part of the logits/scores/predictions we want to gather. Maybe `output_keys_to_ignore_at_inference` is clearer?
I see! Yeah, I think `output_keys_to_ignore_at_inference` would be a bit clearer to me :-)
See #8857
What does this PR do?
As discovered since merging #8530, sometimes (e.g. when using NVIDIA apex with the O2 optimization) the new model outputs lose their type and become regular dictionaries. This means we can't index into them with integers, and some rework in the internals of `Trainer` has become necessary; this PR does that rework.

It also takes advantage of the new dict outputs to better filter the outputs at inference. We had several issues recently when using models outputting past states (such as Reformer, XLNet, GPT-2) during evaluation in `Trainer`. This PR introduces a new API that looks at a possible key in the config of the model to get some attributes to ignore in the outputs during evaluation (those outputs are then discarded from the predictions returned by `Trainer.predict` or passed along to metric computation in `Trainer.evaluate`). Since a user might have use cases where they want to ignore more keys, or keep those keys, a new argument is added to both `Trainer.predict` and `Trainer.evaluate` to fully control the keys ignored in those dictionaries. If the model outputs a tuple, this is all ignored.
Fixes #8523 among others