Skip to content

Commit

Permalink
Feat/statsforecasts ets (#1171)
Browse files Browse the repository at this point in the history
* update Croston to latest statsforecast

* add statsforecast ETS

* update statsforecast requirement

* update StatsForecastAutoARIMA

* proper calls to superclass methods

* add SF-ETS to models init file

* add SF ETS to tested models

* adjust settings for SF ETS test
  • Loading branch information
hrzn authored Aug 30, 2022
1 parent 4a27edd commit 2c49c27
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 56 deletions.
1 change: 1 addition & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from darts.models.forecasting.regression_ensemble_model import RegressionEnsembleModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.sf_ets import StatsForecastETS
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA
Expand Down
64 changes: 27 additions & 37 deletions darts/models/forecasting/croston.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
--------------
"""

import numpy as np
from numba.core import errors
from statsforecast.models import croston_classic, croston_optimized, croston_sba
from statsforecast.models import tsb as croston_tsb
from typing import Optional

from statsforecast.models import TSB as CrostonTSB
from statsforecast.models import CrostonClassic, CrostonOptimized, CrostonSBA

from darts.logging import raise_if, raise_if_not
from darts.models.forecasting.forecasting_model import ForecastingModel
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
from darts.timeseries import TimeSeries


class Croston(ForecastingModel):
class Croston(DualCovariatesForecastingModel):
def __init__(
self, version: str = "classic", alpha_d: float = None, alpha_p: float = None
):
Expand Down Expand Up @@ -56,62 +56,52 @@ def __init__(
)

if version == "classic":
self.method = croston_classic
self.model = CrostonClassic()
elif version == "optimized":
self.method = croston_optimized
self.model = CrostonOptimized()
elif version == "sba":
self.method = croston_sba
self.model = CrostonSBA()
else:
raise_if(
alpha_d is None or alpha_p is None,
'alpha_d and alpha_p must be specified when using "tsb".',
)
self.method = croston_tsb
self.alpha_d = alpha_d
self.alpha_p = alpha_p
self.model = CrostonTSB(alpha_d=self.alpha_d, alpha_p=self.alpha_p)

self.version = version

def __str__(self):
return "Croston"

def fit(self, series: TimeSeries):
super().fit(series)
def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)
series._assert_univariate()
series = self.training_series

if self.version == "tsb":
self.forecast_val = self.method(
series.values(copy=False),
h=1,
future_xreg=None,
alpha_d=self.alpha_d,
alpha_p=self.alpha_p,
)
elif self.version == "sba":
try:
self.forecast_val = self.method(
series.values(copy=False), h=1, future_xreg=None
)
except errors.TypingError:
raise_if(
True,
'"sba" version is not supported with this version of statsforecast.',
)
self.model.fit(
y=series.values(copy=False).flatten(),
X=future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None,
)

else:
self.forecast_val = self.method(
series.values(copy=False), h=1, future_xreg=None
)
return self

def predict(
def _predict(
self,
n: int,
future_covariates: Optional[TimeSeries] = None,
num_samples: int = 1,
):
super().predict(n, num_samples)
values = np.tile(self.forecast_val, n)
super()._predict(n, future_covariates, num_samples)
values = self.model.predict(
h=n,
X=future_covariates.values(copy=False).flatten()
if future_covariates is not None
else None,
)["mean"]
return self._build_forecast_series(values)

@property
Expand Down
2 changes: 1 addition & 1 deletion darts/models/forecasting/exponential_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def fit(self, series: TimeSeries):
seasonal_periods_param = 12

hw_model = hw.ExponentialSmoothing(
series.values(),
series.values(copy=False),
trend=self.trend if self.trend is None else self.trend.value,
damped_trend=self.damped,
seasonal=self.seasonal if self.seasonal is None else self.seasonal.value,
Expand Down
25 changes: 19 additions & 6 deletions darts/models/forecasting/sf_auto_arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Optional

import numpy as np
from statsforecast.arima import AutoARIMA as SFAutoARIMA
from statsforecast.models import AutoARIMA as SFAutoARIMA

from darts import TimeSeries
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel
Expand All @@ -23,12 +23,25 @@ def __init__(self, *autoarima_args, **autoarima_kwargs):
It is probabilistic, whereas :class:`AutoARIMA` is not.
We refer to the `statsforecast AutoARIMA documentation
<https://nixtla.github.io/statsforecast/models.html#arima-methods>`_
for the documentation of the arguments.
Parameters
----------
autoarima_args
Positional arguments for ``statsforecasts.arima.AutoARIMA``.
Positional arguments for ``statsforecasts.models.AutoARIMA``.
autoarima_kwargs
Keyword arguments for ``statsforecasts.arima.AutoARIMA``.
Keyword arguments for ``statsforecasts.models.AutoARIMA``.
Examples
--------
>>> from darts.models import StatsForecastAutoARIMA
>>> from darts.datasets import AirPassengersDataset
>>> series = AirPassengersDataset().load()
>>> model = StatsForecastAutoARIMA(season_length=12)
>>> model.fit(series[:-36])
>>> pred = model.predict(36, num_samples=100)
"""
super().__init__()
self.model = SFAutoARIMA(*autoarima_args, **autoarima_kwargs)
Expand Down Expand Up @@ -56,12 +69,12 @@ def _predict(
forecast_df = self.model.predict(
h=n,
X=future_covariates.values(copy=False) if future_covariates else None,
level=68, # ask one std for the confidence interval. Note, we're limited to int...
level=(68.27,), # ask one std for the confidence interval.
)

mu = forecast_df["mean"].values
mu = forecast_df["mean"]
if num_samples > 1:
std = forecast_df["hi_68%"].values - mu
std = forecast_df["hi-68.27"] - mu
samples = np.random.normal(loc=mu, scale=std, size=(num_samples, n)).T
samples = np.expand_dims(samples, axis=1)
else:
Expand Down
89 changes: 89 additions & 0 deletions darts/models/forecasting/sf_ets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
StatsForecastETS
-----------
"""

from typing import Optional

from statsforecast.models import ETS

from darts import TimeSeries
from darts.models.forecasting.forecasting_model import DualCovariatesForecastingModel


class StatsForecastETS(DualCovariatesForecastingModel):
def __init__(self, *ets_args, **ets_kwargs):
"""ETS based on `Statsforecasts package
<https://github.com/Nixtla/statsforecast>`_.
This implementation can perform faster than the :class:`ExponentialSmoothing` model,
but typically requires more time on the first call, because it relies
on Numba and jit compilation.
This model accepts the same arguments as the `statsforecast ETS
<https://nixtla.github.io/statsforecast/models.html#ets>`_. package.
Parameters
----------
season_length
Number of observations per cycle. Default: 1.
model
Three-character string identifying method using the framework
terminology of Hyndman et al. (2002). Possible values are:
* "A" or "M" for error state,
* "N", "A" or "Ad" for trend state,
* "N", "A" or "M" for season state.
For instance, "ANN" means additive error, no trend and no seasonality.
Furthermore, the character "Z" is a placeholder telling statsforecast
to search for the best model using AICs. Default: "ZZZ".
Examples
--------
>>> from darts.datasets import AirPassengersDataset
>>> from darts.models import StatsForecastETS
>>> series = AirPassengersDataset().load()
>>> model = StatsForecastETS(season_length=12, model="AZZ")
>>> model.fit(series[:-36])
>>> pred = model.predict(36)
"""
super().__init__()
self.model = ETS(*ets_args, **ets_kwargs)

def __str__(self):
return "ETS-Statsforecasts"

def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = None):
super()._fit(series, future_covariates)
series._assert_univariate()
series = self.training_series
self.model.fit(
series.values(copy=False).flatten(),
X=future_covariates.values(copy=False) if future_covariates else None,
)
return self

def _predict(
self,
n: int,
future_covariates: Optional[TimeSeries] = None,
num_samples: int = 1,
):
super()._predict(n, future_covariates, num_samples)
forecast_df = self.model.predict(
h=n,
X=future_covariates.values(copy=False) if future_covariates else None,
)

return self._build_forecast_series(forecast_df["mean"])

@property
def min_train_series_length(self) -> int:
return 10

def _supports_range_index(self) -> bool:
return True

def _is_probabilistic(self) -> bool:
return False
25 changes: 14 additions & 11 deletions darts/tests/models/forecasting/test_local_forecasting_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Prophet,
RandomForest,
StatsForecastAutoARIMA,
StatsForecastETS,
Theta,
)
from darts.models.forecasting.forecasting_model import (
Expand All @@ -41,7 +42,8 @@
(ExponentialSmoothing(), 5.6),
(ARIMA(12, 2, 1), 10),
(ARIMA(1, 1, 1), 40),
(StatsForecastAutoARIMA(period=12), 4.8),
(StatsForecastAutoARIMA(season_length=12), 4.8),
(StatsForecastETS(season_length=12, model="AAZ"), 4.0),
(Croston(version="classic"), 34),
(Croston(version="tsb", alpha_d=0.1, alpha_p=0.1), 34),
(Theta(), 11.3),
Expand All @@ -57,6 +59,10 @@
(KalmanForecaster(dim_x=3), 17.0),
(LinearRegressionModel(lags=12), 11.0),
(RandomForest(lags=12, n_estimators=5, max_depth=3), 17.0),
(Prophet(), 13.5),
(AutoARIMA(), 12.2),
(TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0),
(BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0),
]

# forecasting models with exogenous variables support
Expand All @@ -66,16 +72,13 @@
(KalmanForecaster(dim_x=30), 30.0),
]

dual_models = [ARIMA(), StatsForecastAutoARIMA(period=12)]


models.append((Prophet(), 13.5))
dual_models.append(Prophet())

models.append((AutoARIMA(), 12.2))
models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0))
models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0))
dual_models.append(AutoARIMA())
dual_models = [
ARIMA(),
StatsForecastAutoARIMA(season_length=12),
StatsForecastETS(season_length=12),
Prophet(),
AutoARIMA(),
]


class LocalForecastingModelsTestCase(DartsBaseTestClass):
Expand Down
2 changes: 1 addition & 1 deletion requirements/core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ prophet>=1.1
requests>=2.22.0
scikit-learn>=1.0.1
scipy>=1.3.2
statsforecast==0.6.0
statsforecast>=1.0.0
statsmodels>=0.13.0
tbats>=1.1.0
tqdm>=4.60.0
Expand Down

0 comments on commit 2c49c27

Please sign in to comment.