Fix/load weights from ckpt with encoder or likelihood #1744
Conversation
…ssociated unittests
fix: updated the PR reference in the changelog
…omponents dimensions for both past and future covariates), extended the unitests
Extended the PR:
Thanks @madtoinou, those are nice fixes 🚀
I left a couple of suggestions, mainly regarding the __eq__
methods (simplification, and clarification).
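To make the discussion concrete, here is an illustrative sketch (not the darts implementation) of the kind of `__eq__` under review: one that compares instances attribute by attribute while skipping non-comparable attributes listed in an ignore set. All names here are hypothetical.

```python
# Illustrative sketch only (not the darts implementation): an __eq__ that
# compares numerical parameters and skips attributes named in `ignore_attrs`.
class Likelihood:
    ignore_attrs = ("sampler",)  # hypothetical non-comparable state (RNG, nn.Module, ...)

    def __init__(self, prior_mu=0.0, prior_sigma=1.0):
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self.sampler = object()  # differs per instance; must be ignored

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        keys = [k for k in vars(self) if k not in self.ignore_attrs]
        return all(getattr(self, k) == getattr(other, k) for k in keys)
```

With this shape, two likelihoods with the same numerical parameters compare equal even though their per-instance state differs.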
Codecov Report

```
@@            Coverage Diff             @@
##           master    #1744      +/-   ##
==========================================
- Coverage   94.17%   94.10%   -0.07%
==========================================
  Files         125      125
  Lines       11513    11567      +54
==========================================
+ Hits        10842    10885      +43
- Misses        671      682      +11
```
…_encoders returns them instead of changing model attributes
…wing reviewer comments
…he same encoders but different fittable transformer
Very nice, thanks :) 🚀
Soon ready to be merged!
I only had some minor suggestions, and it might be better to avoid the __eq__() and rather do a simple check for the same transformer class when comparing add_encoders.
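A hypothetical sketch of the "simple check" suggested here: rather than a full `__eq__` on the data transformers, only compare the transformer classes when matching two add_encoders configurations. The helper name is invented for illustration.

```python
# Hypothetical helper (not part of darts): compare only the class of the
# "transformer" entry of two add_encoders dicts.
def same_transformer_class(encoders_a: dict, encoders_b: dict) -> bool:
    t_a = encoders_a.get("transformer")
    t_b = encoders_b.get("transformer")
    if t_a is None or t_b is None:
        return t_a is t_b  # match only if both are missing
    return type(t_a) is type(t_b)
```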
```python
model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
model_same_likelihood.to_cpu()
model_same_likelihood.predict(n=4, series=self.series)
```
We can check that load_weights_from_checkpoint and load_weights have equal predictions.
The Likelihood objects do not have the same random_state (also mentioned in #1779), making the predictions different. Should I implement a method to set it before updating this test?
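A toy illustration of the random_state issue raised here: two otherwise-identical probabilistic models sample differently unless the RNG state is aligned before each prediction. `TinyLikelihood` is a stand-in, not a darts class.

```python
import random

# Stand-in for a probabilistic likelihood: sampling depends on global RNG state.
class TinyLikelihood:
    def sample(self, mu, sigma, n):
        return [random.gauss(mu, sigma) for _ in range(n)]

lik_a, lik_b = TinyLikelihood(), TinyLikelihood()

# Without re-seeding, the second call would consume a different RNG state and
# produce different samples; seeding before each call makes them comparable.
random.seed(0)
pred_a = lik_a.sample(0.0, 1.0, 4)
random.seed(0)
pred_b = lik_b.sample(0.0, 1.0, 4)
assert pred_a == pred_b  # identical once the RNG state matches
```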
see my suggestion below
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Thanks a lot, it's very nice 🚀
A last iteration with some minor suggestions and then we're good to go 👍
```python
model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
model_same_likelihood.to_cpu()
model_same_likelihood.predict(n=4, series=self.series)
```
see my suggestion below
Great job @madtoinou 💯 Thanks a lot 🚀
* fix: possible to load encoders from the checkpoint, including some sanity checks
* feat: adding corresponding unittests
* feat: adding eq operator to likelihood models and data_transformer, associated unittests
* fix: removed typo, updated changelog
* fix: remove typo in call to super().__eq__
* Update CHANGELOG.md; fix: updated the PR reference in the changelog
* fix: load encoders before weights so that they are included in the new model ckpt
* feat: adding multi steps pipeline unittest
* new encoders are instantiated if compatible with the old ones (same components dimensions for both past and future covariates), extended the unittests
* fix: adding encoding_n_component property to SequentialEncoders, load_encoders returns them instead of changing model attributes
* fix: simplifying equality operator between likelihood objects following reviewer comments
* doc: adding small comment about the ignore_attrs_equality attribute in Likelihood
* fix: improve equality operator between data transformers
* fix: equality ignore nn.Module
* test: adding test for loading weights with likelihood
* test: adding test when not loading encoders in a model which contains the same encoders but different fittable transformer
* test: improving test for save/load with likelihood
* Apply suggestions from code review (Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>)
* fix: addressing reviewer comments
* fix: changed order of exception to make them more informative, fixed a typo
* fix: typo
* fix: helper function to compare models encoders and transformers, added test for likelihood

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Fixes #1725
Summary
* Added a new argument load_encoders (default: True) to load_weights() and load_weights_from_checkpoint() so that the model can run inference directly after loading the weights with the encoders.
* If load_encoders=True, the model loading the weights must be instantiated without encoders or with the same encoders, otherwise an exception is raised.
* If load_encoders=False and the weights were trained with encoders, the model loading the weights must already be instantiated with encoders (identical to or different from the ones used during the first training, being careful about the dimensions), otherwise an exception is raised.
* Added an __eq__ operator to the LikelihoodModel (checking their numerical parameters only) and the DataTransformer, to prevent weights-loading failures occurring during the hyper-parameters sanity check.

Other Information

I don't know if it makes sense, but we could eventually exclude the likelihood attribute from the hyper-parameter check when loading the weights, in case some users want to fine-tune/retrain the model with a different likelihood.
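The loading contract described in the summary can be sketched with a toy stand-in (ToyModel is invented; the real API is darts' load_weights / load_weights_from_checkpoint): with load_encoders=True the target model must have no encoders or matching ones, and with load_encoders=False it must bring its own encoders if the checkpoint was trained with some.

```python
# Toy sketch of the loading contract, under the assumptions stated above.
class ToyModel:
    def __init__(self, add_encoders=None):
        self.add_encoders = add_encoders

    def load_weights(self, ckpt_encoders=None, load_encoders=True):
        if load_encoders:
            # target model must have no encoders, or exactly the checkpoint's
            if self.add_encoders is not None and self.add_encoders != ckpt_encoders:
                raise ValueError("add_encoders mismatch with checkpoint")
            self.add_encoders = ckpt_encoders
        else:
            # checkpoint trained with encoders: model must define its own
            if ckpt_encoders is not None and self.add_encoders is None:
                raise ValueError("model must be instantiated with encoders")
```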