Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/likelihood parameters prediction #1811

Merged
merged 55 commits into from
Jul 27, 2023
Merged

Conversation

madtoinou
Copy link
Collaborator

@madtoinou madtoinou commented Jun 2, 2023

Fixes #1735, fixes #1445.

Summary

Edit: Final version:
After fitting a regression or torch-based model with a likelihood, it's possible to pass the flag predict_likelihood_parameters=True to predict() and directly predict the distribution parameters instead of sampling (which remains the best way to simulate/visualize uncertainty).

Draft version:
The distribution parameters are returned by the Likelihood.sample() method (in addition of the sampled values), and processed in the same way as the target forecast values in predict()/get_batch_prediction() which remain necessary if the number of forecasted values is great (auto-regression). If the flag is set, the output of predict() contains only the distribution parameters (no target forecast)

Other Information

Based on @hrzn comment in #1445, a warning should probably be raised if the predict TimeSeries contain many component to avoid dimension explosion. Another solution would eventually be to restrict this feature only to a few Likelihoods.

Verifying that the model converge to the good parameters is too time consuming for the unittests but I did check manually for some distributions (Gaussian, Poisson, Quantile) and it looked rather good.

@madtoinou madtoinou marked this pull request as draft June 2, 2023 08:55
@codecov-commenter
Copy link

codecov-commenter commented Jun 5, 2023

Codecov Report

Patch coverage: 89.61% and project coverage change: -0.17% ⚠️

Comparison is base (933316b) 94.00% compared to head (822d54c) 93.84%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1811      +/-   ##
==========================================
- Coverage   94.00%   93.84%   -0.17%     
==========================================
  Files         126      126              
  Lines       11905    12162     +257     
==========================================
+ Hits        11191    11413     +222     
- Misses        714      749      +35     
Files Changed Coverage Δ
darts/models/forecasting/transformer_model.py 98.95% <ø> (ø)
darts/utils/timeseries_generation.py 96.15% <ø> (ø)
darts/utils/likelihood_models.py 95.55% <82.35%> (-2.31%) ⬇️
darts/models/forecasting/catboost_model.py 95.23% <83.33%> (-4.77%) ⬇️
darts/models/forecasting/regression_model.py 95.33% <84.72%> (-2.23%) ⬇️
darts/models/forecasting/forecasting_model.py 95.14% <90.00%> (-0.43%) ⬇️
darts/models/forecasting/ensemble_model.py 95.42% <94.59%> (-0.41%) ⬇️
darts/models/forecasting/baselines.py 96.15% <96.77%> (-0.15%) ⬇️
darts/ad/anomaly_model/forecasting_am.py 93.51% <100.00%> (ø)
darts/explainability/explainability.py 96.42% <100.00%> (ø)
... and 21 more

... and 4 files with indirect coverage changes

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@madtoinou madtoinou marked this pull request as ready for review June 5, 2023 08:07
Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice @madtoinou, this will be a great feature 🚀

From some points from our discussions:

  • enforce n <= output_chunk_length when likelihood_params=True
  • enforce num_samples == 1 when likelihood_params=True (I tried with DLinear, and it didn't raise an error with num_samples>1
  • avoid sampling when likelihood_params=True

darts/models/forecasting/forecasting_model.py Outdated Show resolved Hide resolved
darts/utils/likelihood_models.py Outdated Show resolved Hide resolved
darts/utils/likelihood_models.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/models/forecasting/pl_forecasting_module.py Outdated Show resolved Hide resolved
darts/utils/likelihood_models.py Outdated Show resolved Hide resolved
darts/models/forecasting/regression_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/regression_model.py Outdated Show resolved Hide resolved
@madtoinou madtoinou requested a review from dennisbader July 3, 2023 15:35
darts/models/forecasting/ensemble_model.py Show resolved Hide resolved
darts/models/forecasting/ensemble_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/forecasting_model.py Show resolved Hide resolved
darts/models/forecasting/forecasting_model.py Outdated Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's update the self.output_chunk_length of all TFMs to self._output_chunk_length as well, and then update the property accordingly

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated, I tried to keep self.output_chunk_length when possible (available after the model/pl module is instantiated)

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congrats @madtoinou for this great PR 💯
Ready to be merged 🚀

@dennisbader dennisbader merged commit 3c0603e into master Jul 27, 2023
@dennisbader dennisbader deleted the feat/likelihood_parameters branch July 27, 2023 19:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Archived in project
3 participants