better checking of data returned from training_step #1256
Comments
I would like to work on this.
@rotalex cool, looking forward to seeing a PR from you :]
@Borda given that this proposal is backwards compatible, i think we should get more core contributors to weigh in on the proposed design before moving forward and implementing it. one thing that is still giving me tension is the fact that there's a lot of overlap between `log` and `progress_bar`.
@jeremyjordan good point! we recently had an issue asking why there are two such keys.
`log` and `progress_bar` were separated a while back because people wanted to log different things that they didn't want in the progress bar.
i don't really know what `metrics` is.
the usage for these callback metrics, as far as i know, isn't documented anywhere. if you look in the documentation, however, you will see references to keys which are not included in the expected set of {`progress_bar`, `log`, `loss`, `hiddens`}.
now that i think about this further, a better solution might be to add more detail to the documentation about how we collect outputs from training steps and expose them at the end of epochs, to make this clearer. furthermore, we should document that if you're returning a torch tensor, we expect it to be a scalar value. the second question is whether or not we want to do more explicit validation of the data returned by the user. the motivation for this github issue is #1236, where we could have helped the user more quickly track down the source of an error. imagine if the error raised to the user had instead pointed directly at the malformed return value.
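For illustration, a minimal sketch of what such explicit validation could look like; the helper name `validate_training_step_output` and the exact messages are assumptions, not Lightning's actual API:

```python
import torch

# keys the trainer treats specially; everything else becomes a "callback metric"
EXPECTED_KEYS = {"loss", "log", "progress_bar", "hiddens"}

def validate_training_step_output(output):
    """Fail fast with a descriptive message instead of a cryptic downstream error."""
    if isinstance(output, torch.Tensor):
        output = {"loss": output}
    if not isinstance(output, dict):
        raise TypeError(
            f"training_step should return a dict or a scalar tensor, got {type(output).__name__}"
        )
    if "loss" not in output:
        raise KeyError("training_step output is missing the required 'loss' key")
    loss = output["loss"]
    if not isinstance(loss, torch.Tensor) or loss.dim() != 0:
        raise ValueError(
            "'loss' must be a scalar torch.Tensor, got "
            f"{tuple(loss.shape) if isinstance(loss, torch.Tensor) else type(loss).__name__}"
        )
    # everything outside the expected keys is currently treated as a callback metric
    callback_metrics = {k: v for k, v in output.items() if k not in EXPECTED_KEYS}
    return output, callback_metrics
```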
I would be in favour of a more rigid structure; also, it may be worth checking #1277.
Shouldn't we favor making the return type a strong type? I've always wondered why the step return type is not a dataclass or named tuple where loss is a required argument. We could keep the flexibility with some metadata dict argument.
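A sketch of what that could look like, assuming a hypothetical `StepResult` dataclass with `loss` required and a free-form `metadata` dict (field names here are illustrative, not an existing API):

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
import torch

@dataclass
class StepResult:
    loss: torch.Tensor                                           # required
    log: Dict[str, Any] = field(default_factory=dict)            # values sent to the logger
    progress_bar: Dict[str, Any] = field(default_factory=dict)   # values shown in the progress bar
    hiddens: Optional[torch.Tensor] = None                       # for truncated BPTT
    metadata: Dict[str, Any] = field(default_factory=dict)       # free-form extras / callback metrics
```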
i wouldn't mind stronger typing, but i don't want to start adding APIs people have to remember. @tullie @ashwinb @Darktex thoughts? A potential way to add structure is (something like this):

```python
def training_step(...):
    output = pl.StepResult(loss=loss, logs=logs, progress_bar=progress_bar)
    return output
```

I guess the only thing this helps with is that the user doesn't have to remember what the keys are in the dict?

Pro: Removes confusion about which keys do what in the return
Con: Adds an API users have to remember
My preference would be a StepResult class. I think it's the best way to document what the outputs should be. We can still support returning a dictionary and just build the StepResult from the dict on the trainer side anyway. It'd be great if we could create a unified way for callbacks to specify which arguments are required in the StepResult too. The `progress_bar`/`log` overlap that @jeremyjordan brought up isn't ideal; I'd love to hear how others think we should address it. The best I could come up with is specifying the desired keys for the log and progress bar somewhere in the LightningModule init or as a trainer callback argument. The user would then put all result values in the step result dictionary, and the specified keys would be routed to the respective outputs (log and/or progress bar); see the sketch below.
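A rough sketch of that routing idea, reusing the hypothetical `StepResult` dataclass sketched above (the function name and the user-specified key lists are assumptions):

```python
from typing import Any, Dict, Iterable

def step_result_from_dict(output: Dict[str, Any],
                          log_keys: Iterable[str],
                          progress_bar_keys: Iterable[str]) -> StepResult:
    """Build a StepResult from a plain dict, routing user-specified keys
    to the logger and/or the progress bar; everything else becomes metadata."""
    log_keys, progress_bar_keys = set(log_keys), set(progress_bar_keys)
    routed = {"loss"} | log_keys | progress_bar_keys
    return StepResult(
        loss=output["loss"],
        log={k: output[k] for k in log_keys if k in output},
        progress_bar={k: output[k] for k in progress_bar_keys if k in output},
        metadata={k: v for k, v in output.items() if k not in routed},
    )
```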
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
🚀 Feature
let's add more validation checks on what's returned from `training_step` and provide the user with useful error messages when they're not returning the right values.

Motivation
i feel like i've seen a lot of users confused about what they're supposed to return in `training_step` and `validation_step`. additionally, i don't think we document how we treat extra keys as "callback metrics" very well.

Pitch
what do you think about adding some structure and validation for `Trainer`'s `process_output` method? right now, we have expectations about a set of keys {`progress_bar`, `log`, `loss`, `hiddens`} and assume everything else is a callback metric. however, this is a silent assumption. we could instead enforce a more rigid structure:
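One possible shape, sketched with `typing.TypedDict` (the class name and the explicit `callback_metrics` field are assumptions, not part of the original proposal):

```python
from typing import Any, Dict, TypedDict
import torch

class StepOutput(TypedDict, total=False):
    loss: torch.Tensor                 # effectively required for training
    log: Dict[str, Any]                # sent to the logger
    progress_bar: Dict[str, Any]       # shown in the progress bar
    hiddens: torch.Tensor              # for truncated BPTT
    callback_metrics: Dict[str, Any]   # extras made explicit instead of silently assumed
```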
moreover, we can leverage `pydantic` to do validation automatically and provide useful error messages out of the box when data validation fails.

cc @PyTorchLightning/core-contributors
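To illustrate the pydantic idea, a minimal sketch using the v1-style API (the model name, field choices, and messages are assumptions):

```python
from typing import Any, Dict, Optional
import torch
from pydantic import BaseModel, validator  # pydantic v1-style API

class StepOutputModel(BaseModel):
    loss: torch.Tensor
    log: Dict[str, Any] = {}
    progress_bar: Dict[str, Any] = {}
    hiddens: Optional[torch.Tensor] = None

    class Config:
        arbitrary_types_allowed = True  # torch.Tensor is not a pydantic-native type

    @validator("loss")
    def loss_must_be_scalar(cls, value):
        if value.dim() != 0:
            raise ValueError(f"loss must be a scalar tensor, got shape {tuple(value.shape)}")
        return value

# a malformed return now fails with a readable message, e.g.:
# StepOutputModel(loss=torch.ones(2))  -> ValidationError: loss must be a scalar tensor, ...
```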
Alternatives
Do nothing, keep things as they are.
Additional context
This would be a backwards incompatible change.