Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix historical forecasts retraining of TFMs #1465

Merged
merged 7 commits on Jan 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
@@ -691,17 +691,16 @@ def historical_forecasts(
retrain_func = _retrain_wrapper(
lambda counter: counter % int(retrain) == 0 if retrain else False
)

elif isinstance(retrain, Callable):
retrain_func = _retrain_wrapper(retrain)

else:
raise_log(
ValueError(
"`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
),
logger,
)

retrain_func_signature = tuple(
inspect.signature(retrain_func).parameters.keys()
)
@@ -728,7 +727,6 @@ def historical_forecasts(

forecasts_list = []
for idx, series_ in enumerate(outer_iterator):

past_covariates_ = past_covariates[idx] if past_covariates else None
future_covariates_ = future_covariates[idx] if future_covariates else None

@@ -765,15 +763,12 @@ def historical_forecasts(

# prepare the start parameter -> pd.Timestamp
if start is not None:

historical_forecasts_time_index = drop_before_index(
historical_forecasts_time_index,
series_.get_timestamp_at_point(start),
)

else:
if (retrain is not False) or (not self._fit_called):

if train_length:
historical_forecasts_time_index = drop_before_index(
historical_forecasts_time_index,
@@ -804,9 +799,9 @@ def historical_forecasts(
(not self._fit_called)
and (retrain is False)
and (not train_length),
" The model has not been fitted yet, and `start` and train_length are not specified. "
" The model is not retraining during the historical forecasts. Hence the "
"the first and only training would be done on 2 samples.",
"The model has not been fitted yet, and `start` and train_length are not specified. "
"The model is not retraining during the historical forecasts. Hence the "
"first and only training would be done on 2 samples.",
logger,
)

@@ -837,7 +832,6 @@ def historical_forecasts(

# iterate and forecast
for _counter, pred_time in enumerate(iterator):

# build the training series
if min_timestamp > series_.time_index[0]:
train_series = series_.drop_before(
@@ -866,13 +860,17 @@ def historical_forecasts(
if future_covariates_
else None,
):
self._fit_wrapper(
# avoid fitting the same model multiple times
model = self.untrained_model()
model._fit_wrapper(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
)
else:
model = self

forecast = self._predict_wrapper(
forecast = model._predict_wrapper(
n=forecast_horizon,
series=train_series,
past_covariates=past_covariates_,
@@ -901,7 +899,6 @@ def historical_forecasts(
hierarchy=series_.hierarchy,
)
)

else:
forecasts_list.append(forecasts)

@@ -1526,7 +1523,7 @@ def _extract_model_creation_params(self):
return model_params

def untrained_model(self):
return self.__class__(**self.model_params)
return self.__class__(**copy.deepcopy(self.model_params))

@property
def model_params(self) -> dict:
19 changes: 18 additions & 1 deletion darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
import pytest

from darts import TimeSeries
from darts.logging import get_logger
@@ -219,12 +220,28 @@ def test_backtest_forecasting(self):
self.assertEqual(pred.end_time(), linear_series.end_time())

# multivariate model + multivariate series
with self.assertRaises(ValueError):
# `historical_forecasts` doesn't overwrite the model object -> we can use different input dimensions
tcn_model.backtest(
linear_series_multi,
start=pd.Timestamp("20000125"),
forecast_horizon=3,
verbose=False,
retrain=False,
)

# univariate model
tcn_model = TCNModel(
input_chunk_length=12, output_chunk_length=1, batch_size=1, n_epochs=1
)
tcn_model.fit(linear_series, verbose=False)
# univariate fitted model + multivariate series
with pytest.raises(ValueError):
tcn_model.backtest(
linear_series_multi,
start=pd.Timestamp("20000125"),
forecast_horizon=3,
verbose=False,
retrain=False,
)

tcn_model = TCNModel(
7 changes: 7 additions & 0 deletions darts/tests/models/forecasting/test_ensemble_models.py
Original file line number Diff line number Diff line change
@@ -50,6 +50,13 @@ def test_untrained_models(self):
with self.assertRaises(ValueError):
NaiveEnsembleModel([model])

# an untrained ensemble model should also give untrained underlying models
model_ens = NaiveEnsembleModel([NaiveDrift()])
model_ens.fit(self.series1)
assert model_ens.models[0]._fit_called
new_model = model_ens.untrained_model()
assert not new_model.models[0]._fit_called

def test_input_models_local_models(self):
with self.assertRaises(ValueError):
NaiveEnsembleModel([])
390 changes: 237 additions & 153 deletions examples/00-quickstart.ipynb

Large diffs are not rendered by default.