Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix historical forecasts retraining of TFMs #1465

Merged
merged 7 commits into from
Jan 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,17 +691,16 @@ def historical_forecasts(
retrain_func = _retrain_wrapper(
lambda counter: counter % int(retrain) == 0 if retrain else False
)

elif isinstance(retrain, Callable):
retrain_func = _retrain_wrapper(retrain)

else:
raise_log(
ValueError(
"`retrain` argument must be either `bool`, positive `int` or `Callable` (returning `bool`)"
),
logger,
)

retrain_func_signature = tuple(
inspect.signature(retrain_func).parameters.keys()
)
Expand All @@ -728,7 +727,6 @@ def historical_forecasts(

forecasts_list = []
for idx, series_ in enumerate(outer_iterator):

past_covariates_ = past_covariates[idx] if past_covariates else None
future_covariates_ = future_covariates[idx] if future_covariates else None

Expand Down Expand Up @@ -765,15 +763,12 @@ def historical_forecasts(

# prepare the start parameter -> pd.Timestamp
if start is not None:

historical_forecasts_time_index = drop_before_index(
historical_forecasts_time_index,
series_.get_timestamp_at_point(start),
)

else:
if (retrain is not False) or (not self._fit_called):

if train_length:
historical_forecasts_time_index = drop_before_index(
historical_forecasts_time_index,
Expand Down Expand Up @@ -804,9 +799,9 @@ def historical_forecasts(
(not self._fit_called)
and (retrain is False)
and (not train_length),
" The model has not been fitted yet, and `start` and train_length are not specified. "
" The model is not retraining during the historical forecasts. Hence the "
"the first and only training would be done on 2 samples.",
"The model has not been fitted yet, and `start` and train_length are not specified. "
"The model is not retraining during the historical forecasts. Hence the "
"first and only training would be done on 2 samples.",
logger,
)

Expand Down Expand Up @@ -837,7 +832,6 @@ def historical_forecasts(

# iterate and forecast
for _counter, pred_time in enumerate(iterator):

# build the training series
if min_timestamp > series_.time_index[0]:
train_series = series_.drop_before(
Expand Down Expand Up @@ -866,13 +860,17 @@ def historical_forecasts(
if future_covariates_
else None,
):
self._fit_wrapper(
# avoid fitting the same model multiple times
model = self.untrained_model()
model._fit_wrapper(
series=train_series,
past_covariates=past_covariates_,
future_covariates=future_covariates_,
)
else:
model = self

forecast = self._predict_wrapper(
forecast = model._predict_wrapper(
n=forecast_horizon,
series=train_series,
past_covariates=past_covariates_,
Expand Down Expand Up @@ -901,7 +899,6 @@ def historical_forecasts(
hierarchy=series_.hierarchy,
)
)

else:
forecasts_list.append(forecasts)

Expand Down Expand Up @@ -1526,7 +1523,7 @@ def _extract_model_creation_params(self):
return model_params

def untrained_model(self):
return self.__class__(**self.model_params)
return self.__class__(**copy.deepcopy(self.model_params))

@property
def model_params(self) -> dict:
Expand Down
19 changes: 18 additions & 1 deletion darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import pandas as pd
import pytest

from darts import TimeSeries
from darts.logging import get_logger
Expand Down Expand Up @@ -219,12 +220,28 @@ def test_backtest_forecasting(self):
self.assertEqual(pred.end_time(), linear_series.end_time())

# multivariate model + multivariate series
with self.assertRaises(ValueError):
# historical forecasts doesn't overwrite model object -> we can use different input dimensions
tcn_model.backtest(
linear_series_multi,
start=pd.Timestamp("20000125"),
forecast_horizon=3,
verbose=False,
retrain=False,
)

# univariate model
tcn_model = TCNModel(
input_chunk_length=12, output_chunk_length=1, batch_size=1, n_epochs=1
)
tcn_model.fit(linear_series, verbose=False)
# univariate fitted model + multivariate series
with pytest.raises(ValueError):
tcn_model.backtest(
linear_series_multi,
start=pd.Timestamp("20000125"),
forecast_horizon=3,
verbose=False,
retrain=False,
)

tcn_model = TCNModel(
Expand Down
7 changes: 7 additions & 0 deletions darts/tests/models/forecasting/test_ensemble_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def test_untrained_models(self):
with self.assertRaises(ValueError):
NaiveEnsembleModel([model])

# an untrained ensemble model should also give untrained underlying models
model_ens = NaiveEnsembleModel([NaiveDrift()])
model_ens.fit(self.series1)
assert model_ens.models[0]._fit_called
new_model = model_ens.untrained_model()
assert not new_model.models[0]._fit_called

def test_input_models_local_models(self):
with self.assertRaises(ValueError):
NaiveEnsembleModel([])
Expand Down
Loading