Fix/ loading metrics and loss in load_from_checkpoint #1759
Conversation
fix: loss_fn and torch_metrics are properly restored when calling load_from_checkpoint()
Thanks @madtoinou for this, looks good :)
I would be interested to see if we can let PL handle the saving/loading of these parameters by adapting PLForecastingModule.on_save_checkpoint and PLForecastingModule.on_load_checkpoint.
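For context, a rough sketch of what that could look like — the two hook names exist on PLForecastingModule, but the attribute names and checkpoint keys below are illustrative assumptions, not darts' actual internals:

    from typing import Any, Dict

    import pytorch_lightning as pl


    class PLForecastingModule(pl.LightningModule):
        def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # stash the objects excluded from save_hyperparameters() so they
            # survive checkpointing (the "loss_fn"/"torch_metrics" keys and
            # attribute names are assumed for illustration)
            checkpoint["loss_fn"] = self.criterion
            checkpoint["torch_metrics"] = self.torch_metrics

        def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
            # restore them before training resumes
            self.criterion = checkpoint["loss_fn"]
            self.torch_metrics = checkpoint["torch_metrics"]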
Codecov Report

@@            Coverage Diff             @@
##           master    #1759      +/-   ##
==========================================
- Coverage   94.19%   94.06%   -0.13%
==========================================
  Files         125      125
  Lines       11505    11495      -10
==========================================
- Hits        10837    10813      -24
- Misses        668      682      +14

☔ View full report in Codecov by Sentry.
Awesome that it worked with the checkpointing :)
Actually, when you mentioned that we ignore loss_fn and torch_metrics when saving the hyperparameters, I tested whether we can achieve the same thing by removing the ignore, and it works :) I left a comment.
After this change we can merge 🚀
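The change being referred to, as a sketch — save_hyperparameters(ignore=...) is real Lightning API, while the surrounding constructor signature is an assumption:

    import pytorch_lightning as pl


    class PLForecastingModule(pl.LightningModule):
        def __init__(self, loss_fn=None, torch_metrics=None, **kwargs) -> None:
            super().__init__()
            # before: self.save_hyperparameters(ignore=["loss_fn", "torch_metrics"])
            # dropping the ignore lets Lightning pickle both objects into the
            # checkpoint's "hyper_parameters" entry and restore them automatically
            self.save_hyperparameters()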
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    # by default our models are initialized as float32. For other dtypes, we need
    # to cast to the correct precision before parameters are loaded by PyTorch-Lightning
    dtype = checkpoint["model_dtype"]
    self.to_dtype(dtype)

    # restoring attributes necessary to resume from training properly
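For reference, the counterpart hook that writes this key would look roughly like the following — only the "model_dtype" key is taken from the snippet above; the body is an assumption:

    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        # record the current parameter dtype so that on_load_checkpoint can cast
        # the freshly constructed (float32) module before the weights are loaded
        checkpoint["model_dtype"] = self.dtype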
Btw, I just saw that we don't load "train_sample_shape" from the checkpoint. I think we should add this here as well, right?
I checked, it's already loaded when calling load_weights_from_checkpoint(). My guess is that since it's one of the constructor arguments and does not require any processing, the de-serialization of the checkpoint by PyTorch Lightning does the job.
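One way to check this, assuming a standard Lightning checkpoint on disk (the path below is hypothetical): constructor arguments captured by save_hyperparameters() are serialized under the checkpoint's "hyper_parameters" key.

    import torch

    # hypothetical path to a trained model's checkpoint
    ckpt = torch.load("darts_logs/model/checkpoints/last.ckpt", map_location="cpu")
    # arguments captured by save_hyperparameters() live here, so Lightning can
    # rebuild the module with the same train_sample_shape without extra handling
    print(ckpt["hyper_parameters"].get("train_sample_shape"))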
Very nice, looks great! Thanks a lot @madtoinou 💯 🚀
* fix: loss_fn and torch_metrics are properly restored when calling load_from_checkpoint()
* fix: moved fix to the PL on_save/on_load methods instead of load_from_checkpoint()
* fix: address reviewer comments, loss and metrics objects are saved in the constructor
* update changelog

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Fixes #1758.
Summary
Since loss_fn and torch_metrics are not saved in PLForecastingModule checkpoints, they must be re-created using the model.model_params values so that training continues with the proper loss (and keeps reporting the desired torch metrics); a sketch of this restoration follows below.

Other Information

Added the corresponding unit tests.
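A minimal sketch of that re-creation, assuming "model" is a darts TorchForecastingModel returned by load_from_checkpoint() and "model.model" is its inner PLForecastingModule; the criterion/torch_metrics attribute names and the MSELoss fallback are illustrative assumptions:

    import torch

    # model_params holds the constructor arguments of the darts model
    params = model.model_params
    # fall back to a default loss when none was passed to the constructor
    loss_fn = params.get("loss_fn") or torch.nn.MSELoss()
    torch_metrics = params.get("torch_metrics")

    # attribute names on the inner Lightning module are assumptions
    model.model.criterion = loss_fn
    model.model.torch_metrics = torch_metrics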