
Fix/load weights from ckpt with encoder or likelihood #1744

Merged
merged 29 commits into master from fix/load-weights-ckpt on May 26, 2023

Conversation

madtoinou
Collaborator

Fixes #1725

Summary

  • Add a new argument load_encoders (default: True) to load_weights() and load_weights_from_checkpoint() so that the encoders are loaded along with the weights and the model can run inference directly.
    • If load_encoders=True, the model loading the weights must be instantiated either without encoders or with the same encoders, otherwise an exception is raised.
    • If load_encoders=False and the weights were trained with encoders, the model loading the weights must already be instantiated with encoders (identical to or different from the ones used during the initial training; the dimensions must match), otherwise an exception is raised.
  • Add an __eq__ operator to LikelihoodModel (comparing their numerical parameters only) and DataTransformer to prevent weights-loading failures during the hyper-parameters sanity check.
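The idea of an `__eq__` restricted to numerical parameters can be sketched as follows. This is a toy stand-in, not the actual Darts implementation; the class, its attributes, and the `ignore_attrs_equality` name (mentioned in a later commit message) are illustrative:

```python
class Likelihood:
    """Toy stand-in for a Darts likelihood model (illustrative only)."""

    # Attributes excluded from the equality check, so that non-numerical
    # state (e.g. torch modules, random generators) cannot break the
    # hyper-parameter sanity check when loading weights.
    ignore_attrs_equality = ("_rng",)

    def __init__(self, prior_mu=None, prior_sigma=None):
        self.prior_mu = prior_mu
        self.prior_sigma = prior_sigma
        self._rng = object()  # deliberately differs between instances

    def __eq__(self, other):
        if type(self) is not type(other):
            return False
        # Compare every attribute except the explicitly ignored ones.
        keys = set(self.__dict__) - set(self.ignore_attrs_equality)
        return all(getattr(self, k) == getattr(other, k) for k in keys)


a, b = Likelihood(prior_mu=0.0), Likelihood(prior_mu=0.0)
c = Likelihood(prior_mu=1.0)
print(a == b, a == c)  # True False
```

Two instances with identical numerical parameters compare equal even though their ignored attributes differ, which is exactly what the weights-loading sanity check needs.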

Other Information

I don't know if it makes sense, but we could eventually exclude the likelihood attribute from the hyper-parameter check when loading the weights, in case some users want to fine-tune/retrain the model with a different likelihood.

@madtoinou madtoinou requested review from hrzn and dennisbader as code owners May 4, 2023 10:56
@madtoinou
Collaborator Author

Extended the PR:

  • The new encoders are instantiated in load_weights_from_checkpoint() to allow direct inference if they don't need to be fitted (e.g., no transformer).
  • Added sanity checks on the number of components generated by the encoders, raising an informative exception if the new encoders don't match those of the checkpoint (in case the user wants to use different encoders during fine-tuning).
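The component-count sanity check described above could look roughly like this. The helper name and signature are hypothetical, made up for illustration only:

```python
def check_encoder_components(ckpt_past, ckpt_future, new_past, new_future):
    """Raise an informative exception when freshly instantiated encoders
    would generate a different number of covariate components than the
    encoders stored in the checkpoint.

    Hypothetical helper, not the actual Darts function.
    """
    if (new_past, new_future) != (ckpt_past, ckpt_future):
        raise ValueError(
            f"The encoders of the loaded model generate {new_past} past and "
            f"{new_future} future covariate components, but the checkpoint "
            f"was trained with {ckpt_past} past and {ckpt_future} future "
            "components. Instantiate the model with matching encoders."
        )


check_encoder_components(2, 1, 2, 1)  # matching dimensions: no error raised
```

Matching dimensions pass silently; a mismatch raises a ValueError that tells the user exactly which counts diverged, instead of failing later inside the network with a shape error.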

@dennisbader (Collaborator) left a comment

Thanks @madtoinou, those are nice fixes 🚀

I left a couple of suggestions, mainly regarding the __eq__ methods (simplification, and clarification).

Review suggestions (resolved) on:
  • darts/models/forecasting/torch_forecasting_model.py
  • darts/utils/likelihood_models.py
  • darts/dataprocessing/transformers/base_data_transformer.py
@codecov-commenter commented May 15, 2023

Codecov Report

Patch coverage: 91.54% and project coverage change: -0.07 ⚠️

Comparison is base (31da6d3) 94.17% compared to head (14c38ed) 94.10%.


Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1744      +/-   ##
==========================================
- Coverage   94.17%   94.10%   -0.07%     
==========================================
  Files         125      125              
  Lines       11513    11567      +54     
==========================================
+ Hits        10842    10885      +43     
- Misses        671      682      +11     
Impacted Files Coverage Δ
darts/dataprocessing/encoders/encoder_base.py 95.65% <75.00%> (-0.37%) ⬇️
darts/dataprocessing/encoders/encoders.py 97.87% <82.35%> (-1.00%) ⬇️
...arts/models/forecasting/torch_forecasting_model.py 90.46% <95.23%> (+0.09%) ⬆️
darts/utils/likelihood_models.py 97.85% <100.00%> (+1.81%) ⬆️

... and 10 files with indirect coverage changes

☔ View full report in Codecov by Sentry.

@dennisbader (Collaborator) left a comment

Very nice, thanks :) 🚀

Almost ready to merge!
I only had some minor suggestions; also, it might be better to avoid __eq__() and instead do a simple check comparing add_encoders for the same transformer class.
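The suggested class-based comparison could be sketched as follows. Both the helper and the stand-in Scaler class are hypothetical, not actual Darts code:

```python
class Scaler:
    """Stand-in for a Darts fittable transformer (illustrative only)."""


def same_add_encoders(enc_a: dict, enc_b: dict) -> bool:
    """Compare two `add_encoders` dicts without relying on __eq__ of the
    transformer objects: two transformers count as equal when they are
    instances of the same class. Hypothetical helper."""
    if set(enc_a) != set(enc_b):
        return False
    for key, val_a in enc_a.items():
        val_b = enc_b[key]
        if key == "transformer":
            # Avoid __eq__ on transformer objects: compare classes only.
            if type(val_a) is not type(val_b):
                return False
        elif val_a != val_b:
            return False
    return True


enc_1 = {"cyclic": {"future": ["month"]}, "transformer": Scaler()}
enc_2 = {"cyclic": {"future": ["month"]}, "transformer": Scaler()}
print(same_add_encoders(enc_1, enc_2))  # True: same transformer class
```

Two distinct Scaler instances compare equal here, sidestepping the need for a custom __eq__ on transformer objects while still catching a swap to a different transformer class.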

Suggestion on darts/dataprocessing/encoders/encoders.py (resolved)
Referenced test code:
    )
    model_same_likelihood.load_weights(model_path_manual, map_location="cpu")
    model_same_likelihood.to_cpu()
    model_same_likelihood.predict(n=4, series=self.series)
@dennisbader (Collaborator) commented:

We can check that load_weights_from_checkpoint and load_weights have equal predictions.

@madtoinou (Collaborator, Author) replied:

The Likelihood objects do not have the same random_state (also mentioned in #1779), making the predictions differ. Should I implement a method to set it before updating this test?

@dennisbader (Collaborator) replied:

see my suggestion below

Suggestion on darts/dataprocessing/transformers/base_data_transformer.py (resolved)
madtoinou and others added 2 commits May 22, 2023 08:49
Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
@dennisbader (Collaborator) left a comment

Thanks a lot, it's very nice 🚀
A last iteration with some minor suggestions and then we're good to go 👍

Suggestion on darts/models/forecasting/torch_forecasting_model.py (resolved)
@dennisbader (Collaborator) left a comment

Great job @madtoinou 💯 Thanks a lot 🚀

@dennisbader dennisbader merged commit 52f4004 into master May 26, 2023
@dennisbader dennisbader deleted the fix/load-weights-ckpt branch May 26, 2023 11:59
alexcolpitts96 pushed a commit to alexcolpitts96/darts that referenced this pull request May 31, 2023
* fix: possible to load encoders from the checkpoint, including some sanity checks

* feat: adding corresponding unittests

* feat: adding eq operator to likelihood models and data_transformer, associated unittests

* fix: removed typo, updated changelog

* fix: remove typo in call to super().__eq__

* Update CHANGELOG.md

fix: updated the PR reference in the changelog

* fix: load encoders before weights so that they are included in the new model ckpt

* feat: adding multi steps pipeline unittest

* new encoders are instantiated if compatible with the old ones (same component dimensions for both past and future covariates), extended the unittests

* fix: adding encoding_n_component property to SequentialEncoders, load_encoders returns them instead of changing model attributes

* fix: simplifying equality operator between likelihood objects following reviewer comments

* doc: adding small comment about the ignore_attrs_equality attribute in Likelihood

* fix: improve equality operator between data transformers

* fix: equality ignore nn.Module

* test: adding test for loading weights with likelihood

* test: adding test for not loading encoders into a model which contains the same encoders but a different fittable transformer

* test: improving test for save/load with likelihood

* Apply suggestions from code review

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>

* fix: addressing reviewer comments

* fix: changed order of exception to make them more informative, fixed a typo

* fix: typo

* fix: helper function to compare models encoders and transformers, added test for likelihood

---------

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
Successfully merging this pull request may close these issues.

[BUG] TorchForecastingModel load_weights_from_checkpoint fails to load
3 participants