diff --git a/darts/models/forecasting/forecasting_model.py b/darts/models/forecasting/forecasting_model.py
index 70319293a2..439e227207 100644
--- a/darts/models/forecasting/forecasting_model.py
+++ b/darts/models/forecasting/forecasting_model.py
@@ -548,7 +548,10 @@ def backtest(
         retrain: Union[bool, int, Callable[..., bool]] = True,
         overlap_end: bool = False,
         last_points_only: bool = False,
-        metric: Callable[[TimeSeries, TimeSeries], float] = metrics.mape,
+        metric: Union[
+            Callable[[TimeSeries, TimeSeries], float],
+            List[Callable[[TimeSeries, TimeSeries], float]],
+        ] = metrics.mape,
         reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
         verbose: bool = False,
     ) -> Union[float, List[float]]:
@@ -624,9 +627,12 @@
         last_points_only
             Whether to use the whole historical forecasts or only the last point of each forecast to compute the error
         metric
-            A function that takes two ``TimeSeries`` instances as inputs and returns an error value.
+            A function or a list of functions that take two ``TimeSeries`` instances as inputs and return an
+            error value.
         reduction
             A function used to combine the individual error scores obtained when `last_points_only` is set to False.
+            When several metric functions are provided, the reduction function will receive the argument `axis=0`
+            to obtain a single value for each metric function.
             If explicitly set to `None`, the method will return a list of the individual error scores instead.
             Set to ``np.mean`` by default.
         verbose
@@ -651,14 +657,26 @@
             verbose=verbose,
         )
 
+        if not isinstance(metric, list):
+            metric = [metric]
+
         if last_points_only:
-            return metric(series, forecasts)
+            errors = [metric_f(series, forecasts) for metric_f in metric]
 
-        errors = [metric(series, forecast) for forecast in forecasts]
-        if reduction is None:
-            return errors
+        else:
+            # metrics in columns, forecasts in rows
+            errors = [
+                [metric_f(series, forecast) for metric_f in metric]
+                for forecast in forecasts
+            ]
+            if reduction is not None:
+                # one value per metric
+                errors = reduction(np.array(errors), axis=0)
 
-        return reduction(np.array(errors))
+        if len(metric) > 1:
+            return errors
+        else:
+            return errors[0]
 
     @classmethod
     def gridsearch(
diff --git a/darts/tests/models/forecasting/test_backtesting.py b/darts/tests/models/forecasting/test_backtesting.py
index 25f3f92ad4..5d0ad37f7c 100644
--- a/darts/tests/models/forecasting/test_backtesting.py
+++ b/darts/tests/models/forecasting/test_backtesting.py
@@ -117,6 +117,16 @@ def test_backtest_forecasting(self):
         )
         self.assertEqual(score, 1.0)
 
+        # using several metric functions should not affect the backtest
+        score = NaiveDrift().backtest(
+            linear_series,
+            train_length=10000,
+            start=pd.Timestamp("20000201"),
+            forecast_horizon=3,
+            metric=[r2_score, mape],
+        )
+        np.testing.assert_almost_equal(score, np.array([1.0, 0.0]))
+
         # window of size 2 is too small for naive drift
         with self.assertRaises(ValueError):
             NaiveDrift().backtest(
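For reference, a minimal usage sketch of the list-of-metrics behaviour this patch adds, assuming the diff above is applied. The series construction, timestamps, and expected values are illustrative and mirror the added test rather than prescribing new API.

# Minimal usage sketch of the multi-metric backtest (assumes the patch above is applied).
# The series, timestamps, and expected values below are illustrative.
import numpy as np
import pandas as pd

from darts.metrics import mape, r2_score
from darts.models import NaiveDrift
from darts.utils.timeseries_generation import linear_timeseries

# A strictly positive, noise-free linear series that NaiveDrift can forecast exactly.
series = linear_timeseries(start_value=1, end_value=10, length=100)

# With a list of metrics, the default reduction (np.mean) is applied with axis=0,
# yielding one aggregated score per metric.
scores = NaiveDrift().backtest(
    series,
    start=pd.Timestamp("20000201"),
    forecast_horizon=3,
    metric=[r2_score, mape],
)
print(scores)  # e.g. array([1., 0.]) for a perfect drift extrapolation

# A single metric keeps the previous behaviour and returns a single float.
score = NaiveDrift().backtest(
    series,
    start=pd.Timestamp("20000201"),
    forecast_horizon=3,
    metric=mape,
)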