Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/shifted output #2176

Merged
merged 46 commits into from
Feb 29, 2024
Merged

Feat/shifted output #2176

merged 46 commits into from
Feb 29, 2024

Conversation

dennisbader
Copy link
Collaborator

@dennisbader dennisbader commented Jan 20, 2024

Work in progress

Fixes #2139, fixes #2132

Summary

  • adds support for shifting the output chunk of global models with parameter output_chunk_shift
  • add support to regression models
  • fixes bug when using regression models with lags=None and future_covariates with positive only lags and future_covariates starting at or after the first predictable time step.
  • parametrize most of the tabularization unit tests
  • introduces new helper function darts.utils.utils.n_steps_between() to efficiently compute the number of time steps (periods) between two points/timestamps with a given frequency -> improves efficiency for regression model tabularization by avoiding pd.date_range().

@codecov-commenter
Copy link

codecov-commenter commented Jan 21, 2024

Codecov Report

Attention: Patch coverage is 97.18310% with 2 lines in your changes are missing coverage. Please review.

Project coverage is 93.96%. Comparing base (b9e6d8b) to head (bfade00).

Files Patch % Lines
darts/utils/utils.py 84.61% 2 Missing ⚠️

❗ Your organization needs to install the Codecov GitHub app to enable full functionality.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2176      +/-   ##
==========================================
+ Coverage   93.88%   93.96%   +0.08%     
==========================================
  Files         135      135              
  Lines       13467    13487      +20     
==========================================
+ Hits        12643    12673      +30     
+ Misses        824      814      -10     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@noahvand
Copy link

noahvand commented Jan 26, 2024

Tried this version, and the shift output function works well, but the RNN models cannot work for some reason,

where I got message below ```
WARNING:darts.models.forecasting.rnn_model:ignoring user defined output_chunk_length. RNNModel uses a fixed `output_chunk_length=1`.
WARNING:darts.models.forecasting.rnn_model:ignoring user defined `output_chunk_shift`. RNNModel uses a fixed `output_chunk_shift=0`.

@madtoinou
Copy link
Collaborator

Hi @noahvand,

This is intentional because by definition RNNModel enforces output_chunk_length=1 and the hidden state mechanism is not really compatible with "gaps" between input values and forecast: introducing such a shift would require the RNNModel to still forecast output_chunk_shift values before getting to the actual value of interest.

For the sake of uniform API, we could change this but RNNModel is kind of already in its own category anyway.

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some very minor comments for the tests.

Nice that this could be implemented so elegantly!

darts/models/forecasting/torch_forecasting_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/torch_forecasting_model.py Outdated Show resolved Hide resolved
darts/tests/models/forecasting/test_dlinear_nlinear.py Outdated Show resolved Hide resolved
darts/utils/data/inference_dataset.py Show resolved Hide resolved
Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM 🚀

"""Tests shifted output for shift smaller than, equal to, and larger than output_chunk_length.
RNNModel does not support shift output chunk.
"""
# model_cls = TFTModel
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# model_cls = TFTModel

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments, amazing PR @dennisbader 🚀

darts/tests/models/forecasting/test_regression_models.py Outdated Show resolved Hide resolved
darts/tests/models/forecasting/test_regression_models.py Outdated Show resolved Hide resolved
darts/tests/datasets/test_datasets.py Outdated Show resolved Hide resolved
darts/tests/models/forecasting/test_regression_models.py Outdated Show resolved Hide resolved
@dennisbader dennisbader merged commit ccd0d42 into master Feb 29, 2024
9 checks passed
@dennisbader dennisbader deleted the feat/shifted_output branch February 29, 2024 14:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
5 participants