Skip to content

Commit

Permalink
Fix/backtest multiple series (#1517)
Browse files Browse the repository at this point in the history
* add a unit test capturing a bug

* fix a bug in backtesting with multiple series
  • Loading branch information
hrzn authored Jan 26, 2023
1 parent 7753a6b commit 33d8a33
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
5 changes: 2 additions & 3 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,11 +1052,10 @@ def backtest(
errors = errors[0]
backtest_list.append(errors)
else:

errors = [
[metric_f(series, f) for metric_f in metric]
[metric_f(target_ts, f) for metric_f in metric]
if len(metric) > 1
else metric[0](series, f)
else metric[0](target_ts, f)
for f in forecasts[idx]
]

Expand Down
29 changes: 28 additions & 1 deletion darts/tests/models/forecasting/test_backtesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,17 @@
import pytest

from darts import TimeSeries
from darts.datasets import AirPassengersDataset, MonthlyMilkDataset
from darts.logging import get_logger
from darts.metrics import mape, r2_score
from darts.models import ARIMA, FFT, ExponentialSmoothing, NaiveDrift, Theta
from darts.models import (
ARIMA,
FFT,
ExponentialSmoothing,
NaiveDrift,
NaiveSeasonal,
Theta,
)
from darts.tests.base_test_class import DartsBaseTestClass
from darts.utils.timeseries_generation import gaussian_timeseries as gt
from darts.utils.timeseries_generation import linear_timeseries as lt
Expand Down Expand Up @@ -257,6 +265,25 @@ def test_backtest_forecasting(self):
self.assertEqual(pred.width, 2)
self.assertEqual(pred.end_time(), linear_series.end_time())

def test_backtest_multiple_series(self):
series = [AirPassengersDataset().load(), MonthlyMilkDataset().load()]
model = NaiveSeasonal(K=1)

error = model.backtest(
series,
train_length=30,
forecast_horizon=2,
stride=1,
retrain=True,
last_points_only=False,
verbose=False,
)

expected = [11.63104, 6.09458]
self.assertEqual(len(error), 2)
self.assertAlmostEqual(error[0], expected[0], places=4)
self.assertAlmostEqual(error[1], expected[1], places=4)

@unittest.skipUnless(TORCH_AVAILABLE, "requires torch")
def test_backtest_regression(self):
np.random.seed(4)
Expand Down

0 comments on commit 33d8a33

Please sign in to comment.