Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1049 - add prior_scale and mode arguments to prophet model's add_seasonality #1829

Merged
merged 22 commits into from
Jul 18, 2023

Conversation

id5h
Copy link
Contributor

@id5h id5h commented Jun 14, 2023

Fixes #1049 .

Summary

When fitting a prophet model, the add_seasonality function will now properly add the prior_scale and mode to the custom seasonality, in accordance with the behaviour described in the docs.

Other Information

@id5h id5h requested review from hrzn and dennisbader as code owners June 14, 2023 06:59
@codecov-commenter
Copy link

codecov-commenter commented Jun 14, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.95%. Comparing base (a5560cc) to head (40e9c26).
Report is 227 commits behind head on master.

Additional details and impacted files
@@           Coverage Diff           @@
##           master    #1829   +/-   ##
=======================================
  Coverage   93.95%   93.95%           
=======================================
  Files         125      125           
  Lines       11773    11782    +9     
=======================================
+ Hits        11061    11070    +9     
  Misses        712      712           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.


🚨 Try these New Features:

Copy link
Collaborator

@madtoinou madtoinou left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @id5h,

Thank you for contributing to darts. Your PR solves the issue of the "not passed" prior_scale and mode arguments, however it does not fully address the linked issue as condition_name remain unsupported. Would it be possible to also cover it in this PR?

@id5h
Copy link
Contributor Author

id5h commented Jun 14, 2023

Hi @madtoinou ,

I pushed a simple update, where one can add conditional seasonalities as well.
I am not using the condition_name as an argument in the add_seasonality functionality. That is because in the FB prophet API that argument is used to identify the column that holds the mask - and that is not really necessary when using a univariate TimeSeries object.

@id5h id5h requested a review from madtoinou June 15, 2023 10:00
@id5h
Copy link
Contributor Author

id5h commented Jun 19, 2023

Hi @madtoinou ,

It seems like the failed check is due to an issue with codecov and not with the commited code?
Am I correct, or should I take any other action?

@madtoinou
Copy link
Collaborator

Hi @id5h,

Sorry, last week was quite intense. I hope to be able to find some time during the week to review the code and run some tests.

I just re-launched the tests, let's hope Codedev does not fail (it has nothing to do with your code).

@id5h
Copy link
Contributor Author

id5h commented Jun 19, 2023

Hi @id5h,

Sorry, last week was quite intense. I hope to be able to find some time during the week to review the code and run some tests.

I just re-launched the tests, let's hope Codedev does not fail (it has nothing to do with your code).

Alright. Thanks for the clarification and the prompt reply :)

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @id5h, thanks a lot for this and sorry for the waiting time 🚀
Looks good to me. I would only suggest to actually use condition_name similar to how Prophet does it.
At fit/predict time we can check that the user supplied the condition in the future_covariates series.
What do you think?

darts/models/forecasting/prophet_model.py Show resolved Hide resolved
@@ -318,13 +334,20 @@ def add_seasonality(
fourier_order: int,
prior_scale: Optional[float] = None,
mode: Optional[str] = None,
condition_func: Optional[types.FunctionType] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would suggest to use condition_name instead of a function and check at fit/predict time that the future_covariates series contains the boolean (binary) component component_name.

For this to work we also need to avoid adding the conditions as regressors in the following line:

self.model.add_regressor(covariate)

I did a quick test and it seems to work fine.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved in latest commit to use the condition_name and future_covariates instead of sending a function

@dennisbader
Copy link
Collaborator

dennisbader commented Jul 3, 2023

Also, could you add an entry to the unreleased section in CHANGELOG.md? Thanks!
And don't worry about the failing unit tests, we're working on this in another PR.

@id5h
Copy link
Contributor Author

id5h commented Jul 4, 2023

Hi @dennisbader , thanks for the review. I implemented the changes you requested and added a test for the new functionality.

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, thanks @id5h .

Just some minor comments and then we can merge 🚀

darts/models/forecasting/prophet_model.py Outdated Show resolved Hide resolved
darts/models/forecasting/prophet_model.py Outdated Show resolved Hide resolved
Comment on lines 249 to 255
raise_if(
future_covariates is None,
f"Condition name '{attributes['condition_name']}' is required by "
f"the custom seasonality '{seasonality_name}', but future_covariates is None. In addition, "
f"the model should be re-trained with future_covariates.",
logger,
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be dropped as it's enforced in the parent class that user must supply future covariates to predict after having fit the model with future covariates

Suggested change
raise_if(
future_covariates is None,
f"Condition name '{attributes['condition_name']}' is required by "
f"the custom seasonality '{seasonality_name}', but future_covariates is None. In addition, "
f"the model should be re-trained with future_covariates.",
logger,
)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will be raised in case a user, for whatever reason, fits a model, calls add_seasonality() after fitting and then predict(). Hence the "model should be retrained with future_covariates" message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The error is now raised according to the logic in _check_seasonality_conditions()

for seasonality_name, attributes in self._add_seasonalities.items():
condition_name = attributes["condition_name"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a private method _check_seasonality_conditions() that checks if all the conditions are in the future covariates and that the conditions are binary? The method can then return the condition columns for the downstream logic.
We use this in fit and predict so it helps to have the logic in a common method.

This should loop through all seasonalities first and capture all missing/invalid conditions before raising the error. It's easier for the user to see when he has multiple missing conditions

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense 👍
Done.


model.fit(ts, future_covariates=future_covariates)

forecast = model.predict(30, future_covariates=future_covariates)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also here, predicting 7 days should already be enough

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add tests that check that missing condition columns in future covariates and non-binary columns raise an error?

an example for this:

with pytest.raises(ValueError):
    model.fit(..., future_covariates=invalid_future_covariates)
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

CHANGELOG.md Outdated
@@ -29,6 +29,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed a bug when loading the weights of a `TorchForecastingModel` trained with encoders or a Likelihood. [#1744](https://github.com/unit8co/darts/pull/1744) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when using selected `target_components` with `ShapExplainer. [#1803](https://github.com/unit8co/darts/pull/#1803) by [Dennis Bader](https://github.com/dennisbader).
- Fixed `TimeSeries.__getitem__()` for series with a RangeIndex with start != 0 and freq != 1. [#1868](https://github.com/unit8co/darts/pull/#1868) by [Dennis Bader](https://github.com/dennisbader).
- Fixed an issue with `prophet_model.Prophet.add_seasonality()` to allow proper use of all passed parameters. [#1829](https://github.com/unit8co/darts/pull/#1829) by [Idan Shilon](https://github.com/id5h).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add this to the model improvements section.
"Prophet now supports conditional seasonalities, and properly handles all parameters passed to Prophet.add_seasonality() and model creation parameter add_seasonalities. ..."

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice, thanks @id5h 🚀
Last comment about slightly changing the error message and then we can merge 💯

darts/models/forecasting/prophet_model.py Outdated Show resolved Hide resolved
@id5h
Copy link
Contributor Author

id5h commented Jul 7, 2023

Great, thanks @dennisbader .
I will change to raise_log today most likely.

I have one suggestion before merging.
Because prophet accepts non-integer seasonalities, it would be nice to allow it also in darts.
I think the required change is minimal - the seasonal_periods type hint in add_seasonality() (line 369) and the "dtype" value in the seasonality_properties dictionary (line 469) need to be changed to Union[int, float].

What do you think about adding this to the current PR?

Co-authored-by: Dennis Bader <dennis.bader@gmx.ch>
@id5h
Copy link
Contributor Author

id5h commented Jul 7, 2023

The raise_log() is imported now.
I also pushed the changes to accept float seasonality periods, including a small modification to the test - I hope that's fine with you.

Copy link
Collaborator

@dennisbader dennisbader left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution @id5h 🚀 I applied the last suggestion. We're good to merge now 💯

darts/models/forecasting/prophet_model.py Outdated Show resolved Hide resolved
@dennisbader dennisbader merged commit 7f428dd into unit8co:master Jul 18, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

add_seasonality does not expose the condition_name parameter for the prophet model
5 participants