Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #1049 - add prior_scale and mode arguments to prophet model's add_seasonality #1829

Merged
merged 22 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c3ad7ce
Fix #1049 - add prior_scale and mode arguments to prophet model's add…
Idan-QL Jun 13, 2023
a6d53f0
Add option to treat seasonality as conditional
Idan-QL Jun 14, 2023
9d80f0b
Merge branch 'master' into fix/prophet-add-seasonality
dennisbader Jul 3, 2023
5713aba
Add seasonality conditions with a condition_name and future_covariates
Idan-QL Jul 4, 2023
cf55740
Add test for custom conditional seasonality
Idan-QL Jul 4, 2023
3d5fc4a
Add entry for pr #1829
Idan-QL Jul 4, 2023
0638f1d
Merge branch 'master' into fix/prophet-add-seasonality
id5h Jul 4, 2023
c6f9ee6
Merge branch 'master' into fix/prophet-add-seasonality
dennisbader Jul 5, 2023
801e80a
Update darts/models/forecasting/prophet_model.py
id5h Jul 5, 2023
5b0321c
Update darts/models/forecasting/prophet_model.py
id5h Jul 5, 2023
50a70d4
Validate seasonality considitions through a private method when calli…
Idan-QL Jul 5, 2023
0fe11d2
Reduce predict horizon to 7. Add tests for missing and invalid condit…
Idan-QL Jul 5, 2023
a96a15f
Move entry to models improvements section
Idan-QL Jul 5, 2023
c4cdbf2
Merge branch 'master' into fix/prophet-add-seasonality
id5h Jul 6, 2023
e9418a6
Update err msg in _check_seasonality_conditions
id5h Jul 7, 2023
9057345
Import raise_log. Initialize formatted str when necessary.
Idan-QL Jul 7, 2023
157f206
Accept float seasonalities as well. Update test
Idan-QL Jul 7, 2023
729abe2
Merge branch 'master' into fix/prophet-add-seasonality
id5h Jul 7, 2023
7c03f75
Fix dtype of seasonal_periods. Update docstrings.
Idan-QL Jul 10, 2023
6fbbd4b
Merge branch 'master' into fix/prophet-add-seasonality
id5h Jul 14, 2023
61abeea
update docstring
dennisbader Jul 17, 2023
40e9c26
Merge branch 'master' into fix/prophet-add-seasonality
dennisbader Jul 18, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Added support for `PathLike` to the `save()` and `load()` functions of all non-deep learning based models. [#1754](https://github.com/unit8co/darts/pull/1754) by [Simon Sudrich](https://github.com/sudrich).
- Improved efficiency of `historical_forecasts()` and `backtest()` for all models giving significant process time reduction for larger number of predict iterations and series. [#1801](https://github.com/unit8co/darts/pull/1801) by [Dennis Bader](https://github.com/dennisbader).
- Added model property `ForecastingModel.supports_multivariate` to indicate whether the model supports multivariate forecasting. [#1848](https://github.com/unit8co/darts/pull/1848) by [Felix Divo](https://github.com/felixdivo).
- `Prophet` now supports conditional seasonalities, and properly handles all parameters passed to `Prophet.add_seasonality()` and model creation parameter `add_seasonalities` [#1829](https://github.com/unit8co/darts/pull/#1829) by [Idan Shilon](https://github.com/id5h).

- Improvements to `EnsembleModel`:
- Model creation parameter `forecasting_models` now supports a mix of `LocalForecastingModel` and `GlobalForecastingModel` (single `TimeSeries` training/inference only, due to the local models). [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
- Future and past covariates can now be used even if `forecasting_models` have different covariates support. The covariates passed to `fit()`/`predict()` are used only by models that support it. [#1745](https://github.com/unit8co/darts/pull/1745) by [Antoine Madrona](https://github.com/madtoinou).
Expand Down
99 changes: 95 additions & 4 deletions darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,20 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non

# add user defined seasonalities (from model creation and/or pre-fit self.add_seasonalities())
interval_length = self._freq_to_days(series.freq_str)
conditional_seasonality_covariates = self._check_seasonality_conditions(
future_covariates=future_covariates
)
for seasonality_name, attributes in self._add_seasonalities.items():
self.model.add_seasonality(
name=seasonality_name,
period=attributes["seasonal_periods"] * interval_length,
fourier_order=attributes["fourier_order"],
prior_scale=attributes["prior_scale"],
id5h marked this conversation as resolved.
Show resolved Hide resolved
mode=attributes["mode"],
condition_name=attributes["condition_name"],
)

# add covariates
# add covariates as additional regressors
if future_covariates is not None:
fit_df = fit_df.merge(
future_covariates.pd_dataframe(),
Expand All @@ -197,7 +203,8 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
how="left",
)
for covariate in future_covariates.columns:
self.model.add_regressor(covariate)
if covariate not in conditional_seasonality_covariates:
self.model.add_regressor(covariate)

# add built-in country holidays
if self.country_holidays is not None:
Expand All @@ -220,6 +227,8 @@ def _predict(
verbose: bool = False,
) -> TimeSeries:

_ = self._check_seasonality_conditions(future_covariates=future_covariates)

super()._predict(n, future_covariates, num_samples)

predict_df = self._generate_predict_df(n=n, future_covariates=future_covariates)
Expand Down Expand Up @@ -267,6 +276,73 @@ def _generate_predict_df(
)
return predict_df

def _check_seasonality_conditions(
self, future_covariates: Optional[TimeSeries] = None
) -> List[str]:
"""
Checks if the conditions for custom conditional seasonalities are met. Each custom seasonality that has a
`condition_name` other than None is checked. If the `condition_name` is not a column in the `future_covariates`
or if the values in the column are not all True or False, an error is raised.
Returns a list of the `condition_name`s of the conditional seasonalities that have been checked.

Parameters
----------
future_covariates
optionally, a TimeSeries containing the future covariates and including the columns that are used as
conditions for the conditional seasonalities when necessary

Raises
------
ValueError
if a seasonality has a `condition_name` and a column named `condition_name` is missing in
the `future_covariates`

if a seasonality has a `condition_name` and the values in the corresponding column in `future_covariates`
are not binary values (True or False, 1 or 0)
"""

conditional_seasonality_covariates = []
invalid_conditional_seasonalities = []
if future_covariates is not None:
future_covariates_columns = future_covariates.columns
else:
future_covariates_columns = []

for seasonality_name, attributes in self._add_seasonalities.items():
condition_name = attributes["condition_name"]
if condition_name is not None:
if condition_name not in future_covariates_columns:
invalid_conditional_seasonalities.append(
(seasonality_name, condition_name, "column missing")
)
continue
if (
not future_covariates[condition_name]
.pd_series()
.isin([True, False])
.all()
):
invalid_conditional_seasonalities.append(
(seasonality_name, condition_name, "invalid values")
)
continue
conditional_seasonality_covariates.append(condition_name)

formatted_issues_str = ", ".join(
f"'{name}' (condition_name: '{cond}'; issue: {reason})"
for name, cond, reason in invalid_conditional_seasonalities
)
raise_if(
len(invalid_conditional_seasonalities) > 0,
f"The following seasonalities have invalid conditions: "
f"{formatted_issues_str}. "
f"Each conditional seasonality must be accompanied by a binary component/column in the future_covariates "
f"with the same name as the condition_name. These components must only contain "
f"True or False values (or 1 or 0).",
logger,
)
id5h marked this conversation as resolved.
Show resolved Hide resolved
return conditional_seasonality_covariates

@property
def supports_multivariate(self) -> bool:
return False
Expand Down Expand Up @@ -322,14 +398,22 @@ def add_seasonality(
fourier_order: int,
prior_scale: Optional[float] = None,
mode: Optional[str] = None,
condition_name: Optional[str] = None,
) -> None:
"""Adds a custom seasonality to the model that repeats after every n `seasonal_periods` timesteps.
An example for `seasonal_periods`: If you have hourly data (frequency='H') and your seasonal cycle repeats
after 48 hours -> `seasonal_periods=48`.

Apart from `seasonal_periods`, this is very similar to how you would call Facebook Prophet's
`add_seasonality()` method. For information about the parameters see:
`The Prophet source code <https://github.com/facebook/prophet/blob/master/python/prophet/forecaster.py>`_.
`add_seasonality()` method.

To add conditional seasonalities, provide `condition_name` here, and add a boolean (binary) component/column
named `condition_name` to the `future_covariates` series passed to `fit()` and `predict()`.

For information about the parameters see:
`The Prophet source code <https://github.com/facebook/prophet/blob/master/python/prophet/forecaster.py>`.
For more details on conditional seasonalities see:
https://facebook.github.io/prophet/docs/seasonality,_holiday_effects,_and_regressors.html#seasonalities-that-depend-on-other-factors

Parameters
----------
Expand All @@ -343,13 +427,18 @@ def add_seasonality(
optionally, a prior scale for this component
mode
optionally, 'additive' or 'multiplicative'
condition_name
optionally, the name of the condition on which the seasonality depends. If not `None`, expects a
`future_covariates` time series with a component/column named `condition_name` to be passed to `fit()`
and `predict()`.
"""
function_call = {
"name": name,
"seasonal_periods": seasonal_periods,
"fourier_order": fourier_order,
"prior_scale": prior_scale,
"mode": mode,
"condition_name": condition_name,
}
self._store_add_seasonality_call(seasonality_call=function_call)

Expand Down Expand Up @@ -381,6 +470,7 @@ def _store_add_seasonality_call(
"fourier_order": {"default": None, "dtype": int},
"prior_scale": {"default": None, "dtype": float},
"mode": {"default": None, "dtype": str},
"condition_name": {"default": None, "dtype": str},
}
seasonality_default = {
kw: seasonality_properties[kw]["default"] for kw in seasonality_properties
Expand Down Expand Up @@ -430,6 +520,7 @@ def _store_add_seasonality_call(
f'of type {[seasonality_properties[kw]["dtype"] for kw in invalid_types]}.',
logger,
)

self._add_seasonalities[seasonality_name] = add_seasonality_call

@staticmethod
Expand Down
61 changes: 60 additions & 1 deletion darts/tests/models/forecasting/test_prophet.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add tests that check that missing condition columns in future covariates and non-binary columns raise an error?

an example for this:

with pytest.raises(ValueError):
    model.fit(..., future_covariates=invalid_future_covariates)
...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import pandas as pd
import pytest

from darts import TimeSeries
from darts.logging import get_logger
Expand All @@ -25,7 +26,14 @@ def test_add_seasonality_calls(self):
"seasonal_periods": 24,
"fourier_order": 1,
}
kwargs_all = dict(kwargs_mandatory, **{"prior_scale": 1.0, "mode": "additive"})
kwargs_all = dict(
kwargs_mandatory,
**{
"prior_scale": 1.0,
"mode": "additive",
"condition_name": "custom_condition",
}
)
model1 = Prophet(add_seasonalities=kwargs_all)
model2 = Prophet()
model2.add_seasonality(**kwargs_all)
Expand Down Expand Up @@ -234,3 +242,54 @@ def helper_test_prophet_model(self, period, freq, compare_all_models=False):
for pred in compare_preds:
for val_i, pred_i in zip(val.univariate_values(), pred.univariate_values()):
self.assertAlmostEqual(val_i, pred_i, delta=0.1)

def test_conditional_seasonality(self):
"""
Test that conditional seasonality is correctly incorporated by the model
"""
duration = 395
horizon = 7
df = pd.DataFrame()
df["ds"] = pd.date_range(start="2022-01-02", periods=395)
df["y"] = [i + 10 * (i % 7 == 0) for i in range(duration)]
df["is_sunday"] = df["ds"].apply(lambda x: int(x.weekday() == 6))

ts = TimeSeries.from_dataframe(
df[:-horizon], time_col="ds", value_cols="y", freq="D"
)
future_covariates = TimeSeries.from_dataframe(
df, time_col="ds", value_cols=["is_sunday"], freq="D"
)
expected_result = TimeSeries.from_dataframe(
df[-horizon:], time_col="ds", value_cols="y", freq="D"
)

model = Prophet(seasonality_mode="additive")
model.add_seasonality(
name="weekly_sun",
seasonal_periods=7,
fourier_order=2,
condition_name="is_sunday",
)

model.fit(ts, future_covariates=future_covariates)

forecast = model.predict(horizon, future_covariates=future_covariates)

for val_i, pred_i in zip(
expected_result.univariate_values(), forecast.univariate_values()
):
self.assertAlmostEqual(val_i, pred_i, delta=0.1)

invalid_future_covariates = future_covariates.with_values(
np.reshape(np.random.randint(0, 3, duration), (-1, 1, 1)).astype("float")
)

with pytest.raises(ValueError):
model.fit(ts, future_covariates=invalid_future_covariates)

with pytest.raises(ValueError):
model.fit(
ts,
future_covariates=invalid_future_covariates.drop_columns("is_sunday"),
)