Skip to content

Commit

Permalink
change order of arguments for metric in gridsearch (#989)
Browse files Browse the repository at this point in the history
* change order of arguments for metric in gridsearch

* add doctring in gridsearch

* add line break to docstring of gridsearch

Co-authored-by: Grotehans <Clara.Grotehans@EXXETA.com>
Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
3 people authored Jun 21, 2022
1 parent a1328fa commit b530dcd
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,8 @@ def gridsearch(
If `True`, uses the comparison with the fitted values.
Raises an error if ``fitted_values`` is not an attribute of `model_class`.
metric
A function that takes two TimeSeries instances as inputs and returns a float error value.
A function that takes two TimeSeries instances as inputs (actual and prediction, in this order),
and returns a float error value.
reduction
A reduction function (mapping array to float) describing how to aggregate the errors obtained
on the different validation series when backtesting. By default it'll compute the mean of errors.
Expand Down Expand Up @@ -764,7 +765,7 @@ def _evaluate_combination(param_combination) -> float:
fitted_values = TimeSeries.from_times_and_values(
series.time_index, model.fitted_values
)
error = metric(fitted_values, series)
error = metric(series, fitted_values)
elif val_series is None: # expanding window mode
error = model.backtest(
series=series,
Expand All @@ -787,7 +788,7 @@ def _evaluate_combination(param_combination) -> float:
future_covariates,
num_samples=1,
)
error = metric(pred, val_series)
error = metric(val_series, pred)

return float(error)

Expand Down

0 comments on commit b530dcd

Please sign in to comment.