diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py index cc7d04c949..67ae8d58d9 100644 --- a/darts/models/forecasting/forecasting_model.py +++ b/darts/models/forecasting/forecasting_model.py @@ -1052,11 +1052,10 @@ def backtest( errors = errors[0] backtest_list.append(errors) else: - errors = [ - [metric_f(series, f) for metric_f in metric] + [metric_f(target_ts, f) for metric_f in metric] if len(metric) > 1 - else metric[0](series, f) + else metric[0](target_ts, f) for f in forecasts[idx] ] diff --git a/darts/tests/models/forecasting/test_backtesting.py b/darts/tests/models/forecasting/test_backtesting.py index b57781ffd4..56e2be42a1 100644 --- a/darts/tests/models/forecasting/test_backtesting.py +++ b/darts/tests/models/forecasting/test_backtesting.py @@ -7,9 +7,17 @@ import pytest from darts import TimeSeries +from darts.datasets import AirPassengersDataset, MonthlyMilkDataset from darts.logging import get_logger from darts.metrics import mape, r2_score -from darts.models import ARIMA, FFT, ExponentialSmoothing, NaiveDrift, Theta +from darts.models import ( + ARIMA, + FFT, + ExponentialSmoothing, + NaiveDrift, + NaiveSeasonal, + Theta, +) from darts.tests.base_test_class import DartsBaseTestClass from darts.utils.timeseries_generation import gaussian_timeseries as gt from darts.utils.timeseries_generation import linear_timeseries as lt @@ -257,6 +265,25 @@ def test_backtest_forecasting(self): self.assertEqual(pred.width, 2) self.assertEqual(pred.end_time(), linear_series.end_time()) + def test_backtest_multiple_series(self): + series = [AirPassengersDataset().load(), MonthlyMilkDataset().load()] + model = NaiveSeasonal(K=1) + + error = model.backtest( + series, + train_length=30, + forecast_horizon=2, + stride=1, + retrain=True, + last_points_only=False, + verbose=False, + ) + + expected = [11.63104, 6.09458] + self.assertEqual(len(error), 2) + self.assertAlmostEqual(error[0], expected[0], places=4) + self.assertAlmostEqual(error[1], expected[1], places=4) + @unittest.skipUnless(TORCH_AVAILABLE, "requires torch") def test_backtest_regression(self): np.random.seed(4)