From c0df26d5a410470ab52726c3217a4b83a5ea1bf2 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 26 Jun 2023 10:40:29 +0200 Subject: [PATCH 1/5] make freq optional in seasonal naive --- src/gluonts/model/seasonal_naive/_predictor.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/gluonts/model/seasonal_naive/_predictor.py b/src/gluonts/model/seasonal_naive/_predictor.py index f8372a3923..c781697ec7 100644 --- a/src/gluonts/model/seasonal_naive/_predictor.py +++ b/src/gluonts/model/seasonal_naive/_predictor.py @@ -43,23 +43,25 @@ class SeasonalNaivePredictor(RepresentablePredictor): Parameters ---------- - freq - Frequency of the input data prediction_length - Number of time points to predict + Number of time points to predict. + freq + Frequency of the input data, used to infer ``season_length`` + in case it is not provided. season_length - Length of the seasonality pattern of the input data + Length of the seasonality pattern of the input data. If not + provided, it is inferred from ``freq``. imputation_method The imputation method to use in case of missing values. - Defaults to `LastValueImputation` which replaces each missing + Defaults to :py:class:`LastValueImputation` which replaces each missing value with the last value that was not missing. """ @validated() def __init__( self, - freq: str, prediction_length: int, + freq: Optional[str] = None, season_length: Optional[int] = None, imputation_method: Optional[ MissingValueImputation @@ -67,6 +69,9 @@ def __init__( ) -> None: super().__init__(prediction_length=prediction_length) + assert (freq is not None) or ( + season_length is not None + ), "You must provide one of `freq` or `season_length`" assert ( season_length is None or season_length > 0 ), "The value of `season_length` should be > 0" From 5aceb2aae0fa4f845f5a79956c9c0c0f715c5389 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 26 Jun 2023 10:50:04 +0200 Subject: [PATCH 2/5] fixup --- src/gluonts/model/seasonal_naive/_predictor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gluonts/model/seasonal_naive/_predictor.py b/src/gluonts/model/seasonal_naive/_predictor.py index c781697ec7..0da4b340cd 100644 --- a/src/gluonts/model/seasonal_naive/_predictor.py +++ b/src/gluonts/model/seasonal_naive/_predictor.py @@ -46,11 +46,11 @@ class SeasonalNaivePredictor(RepresentablePredictor): prediction_length Number of time points to predict. freq - Frequency of the input data, used to infer ``season_length`` + Sampling frequency of the input data, used to infer ``season_length`` in case it is not provided. season_length - Length of the seasonality pattern of the input data. If not - provided, it is inferred from ``freq``. + Seasonality used to make predictions. If not provided, it is inferred + from ``freq``. imputation_method The imputation method to use in case of missing values. Defaults to :py:class:`LastValueImputation` which replaces each missing From bae94440bd7094d096ab3b403b85c0af693b03df Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 26 Jun 2023 15:51:12 +0200 Subject: [PATCH 3/5] remove freq altogether --- .../model/seasonal_naive/_predictor.py | 23 ++++--------------- .../seasonal_naive/test_seasonal_naive.py | 10 +++----- test/model/test_evaluation.py | 4 ++-- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/src/gluonts/model/seasonal_naive/_predictor.py b/src/gluonts/model/seasonal_naive/_predictor.py index 0da4b340cd..d8304ed64f 100644 --- a/src/gluonts/model/seasonal_naive/_predictor.py +++ b/src/gluonts/model/seasonal_naive/_predictor.py @@ -25,7 +25,6 @@ LastValueImputation, MissingValueImputation, ) -from gluonts.time_feature import get_seasonality class SeasonalNaivePredictor(RepresentablePredictor): @@ -45,12 +44,8 @@ class SeasonalNaivePredictor(RepresentablePredictor): ---------- prediction_length Number of time points to predict. - freq - Sampling frequency of the input data, used to infer ``season_length`` - in case it is not provided. season_length - Seasonality used to make predictions. If not provided, it is inferred - from ``freq``. + Seasonality used to make predictions. imputation_method The imputation method to use in case of missing values. Defaults to :py:class:`LastValueImputation` which replaces each missing @@ -61,27 +56,17 @@ class SeasonalNaivePredictor(RepresentablePredictor): def __init__( self, prediction_length: int, - freq: Optional[str] = None, - season_length: Optional[int] = None, + season_length: int, imputation_method: Optional[ MissingValueImputation ] = LastValueImputation(), ) -> None: super().__init__(prediction_length=prediction_length) - assert (freq is not None) or ( - season_length is not None - ), "You must provide one of `freq` or `season_length`" - assert ( - season_length is None or season_length > 0 - ), "The value of `season_length` should be > 0" + assert season_length > 0, "The value of `season_length` should be > 0" self.prediction_length = prediction_length - self.season_length = ( - season_length - if season_length is not None - else get_seasonality(freq) - ) + self.season_length = season_length self.imputation_method = imputation_method def predict_item(self, item: DataEntry) -> Forecast: diff --git a/test/model/seasonal_naive/test_seasonal_naive.py b/test/model/seasonal_naive/test_seasonal_naive.py index 0d16385d8b..57589dd907 100644 --- a/test/model/seasonal_naive/test_seasonal_naive.py +++ b/test/model/seasonal_naive/test_seasonal_naive.py @@ -29,7 +29,6 @@ def get_prediction( imputation_method=LastValueImputation(), ): pred = SeasonalNaivePredictor( - freq=FREQ, prediction_length=prediction_length, season_length=season_length, imputation_method=imputation_method, @@ -68,12 +67,9 @@ def get_prediction( ([1, 2, 3], [1], 1, 3, LastValueImputation()), ([1, 2, 3], [1, 2], 2, 3, LastValueImputation()), ([1, 2, 3], [1, 2, 3], 3, 3, LastValueImputation()), - ([1, 1, 1], [1], 1, None, LastValueImputation()), - ([1, 1, 1], [1, 1], 2, None, LastValueImputation()), - ([1, 1, 1], [1, 1, 1], 3, None, LastValueImputation()), - ([1, 3, np.nan], [3], 1, None, LastValueImputation()), - ([1, 3, np.nan], [3, 3], 2, None, LastValueImputation()), - ([1, 3, np.nan], [3, 3, 3], 3, None, LastValueImputation()), + ([1, 3, np.nan], [3], 1, 1, LastValueImputation()), + ([1, 3, np.nan], [3, 3], 2, 1, LastValueImputation()), + ([1, 3, np.nan], [3, 3, 3], 3, 1, LastValueImputation()), ([1, 3, np.nan], [np.nan], 1, 1, LeavesMissingValues()), ([1, 3, np.nan], [np.nan] * 2, 2, 1, LeavesMissingValues()), ([1, 3, np.nan], [np.nan] * 3, 3, 1, LeavesMissingValues()), diff --git a/test/model/test_evaluation.py b/test/model/test_evaluation.py index 0ae5010bf5..41332c45da 100644 --- a/test/model/test_evaluation.py +++ b/test/model/test_evaluation.py @@ -211,7 +211,7 @@ def test_evaluate_model_vs_forecasts(): ) model = SeasonalNaivePredictor( - freq="D", prediction_length=3, season_length=1 + prediction_length=3, season_length=1 ) forecasts = list(model.predict(test_data.input)) @@ -247,7 +247,7 @@ def test_data_nan(): ) model = SeasonalNaivePredictor( - freq="D", prediction_length=3, season_length=1 + prediction_length=3, season_length=1 ) forecasts = list(model.predict(test_data.input)) From fd6b7762f6bf158269ef617d4ed224603bea3f4d Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 26 Jun 2023 15:54:27 +0200 Subject: [PATCH 4/5] fix black --- test/model/test_evaluation.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/test/model/test_evaluation.py b/test/model/test_evaluation.py index 41332c45da..76df49fd23 100644 --- a/test/model/test_evaluation.py +++ b/test/model/test_evaluation.py @@ -210,9 +210,7 @@ def test_evaluate_model_vs_forecasts(): prediction_length=3, windows=4 ) - model = SeasonalNaivePredictor( - prediction_length=3, season_length=1 - ) + model = SeasonalNaivePredictor(prediction_length=3, season_length=1) forecasts = list(model.predict(test_data.input)) @@ -246,9 +244,7 @@ def test_data_nan(): prediction_length=3, windows=4 ) - model = SeasonalNaivePredictor( - prediction_length=3, season_length=1 - ) + model = SeasonalNaivePredictor(prediction_length=3, season_length=1) forecasts = list(model.predict(test_data.input)) From 99e42e99e1d6ae35b50bc2ed9c8aac9e6b9fe43b Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Mon, 26 Jun 2023 17:15:26 +0200 Subject: [PATCH 5/5] fix tests --- ...t_metrics_compared_to_previous_approach.py | 3 +- test/ext/naive_2/test_predictors.py | 62 ++++++++++++++----- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/test/ev/test_metrics_compared_to_previous_approach.py b/test/ev/test_metrics_compared_to_previous_approach.py index 859b6bd9b1..11bd15770a 100644 --- a/test/ev/test_metrics_compared_to_previous_approach.py +++ b/test/ev/test_metrics_compared_to_previous_approach.py @@ -269,7 +269,8 @@ def test_against_former_evaluator(): ) predictor = SeasonalNaivePredictor( - prediction_length=prediction_length, freq=freq + prediction_length=prediction_length, + season_length=get_seasonality(freq), ) quantile_levels = (0.1, 0.5, 0.9) diff --git a/test/ext/naive_2/test_predictors.py b/test/ext/naive_2/test_predictors.py index 67e3b41fbe..4c0a9b2ddf 100644 --- a/test/ext/naive_2/test_predictors.py +++ b/test/ext/naive_2/test_predictors.py @@ -27,6 +27,7 @@ from gluonts.ext.naive_2 import Naive2Predictor from gluonts.model.predictor import Predictor from gluonts.model.seasonal_naive import SeasonalNaivePredictor +from gluonts.time_feature import get_seasonality def generate_random_dataset( @@ -49,17 +50,24 @@ def generate_random_dataset( @pytest.mark.parametrize( - "predictor_cls", [SeasonalNaivePredictor, Naive2Predictor] + "make_predictor", + [ + lambda freq: SeasonalNaivePredictor( + prediction_length=PREDICTION_LENGTH, + season_length=SEASON_LENGTH, + ), + lambda freq: Naive2Predictor( + freq=freq, + prediction_length=PREDICTION_LENGTH, + season_length=SEASON_LENGTH, + ), + ], ) @pytest.mark.parametrize( "freq", ["1min", "15min", "30min", "1H", "2H", "12H", "7D", "1W", "1M"] ) -def test_predictor(predictor_cls, freq: str): - predictor = predictor_cls( - freq=freq, - prediction_length=PREDICTION_LENGTH, - season_length=SEASON_LENGTH, - ) +def test_predictor(make_predictor, freq: str): + predictor = make_predictor(freq) dataset = list( generate_random_dataset( num_ts=NUM_TS, @@ -87,7 +95,7 @@ def test_predictor(predictor_cls, freq: str): assert forecast.start_date == forecast_start(data) # specifically for the seasonal naive we can test the supposed result directly - if predictor_cls == SeasonalNaivePredictor: + if isinstance(predictor, SeasonalNaivePredictor): assert np.allclose(forecast.samples[0], ref) @@ -115,11 +123,25 @@ def naive_2_predictor(): @flaky(max_runs=3, min_passes=1) @pytest.mark.parametrize( - "predictor_cls, parameters, accuracy", - [seasonal_naive_predictor() + (0.0,), naive_2_predictor() + (0.0,)], + "predictor, accuracy", + [ + ( + SeasonalNaivePredictor( + prediction_length=CONSTANT_DATASET_PREDICTION_LENGTH, + season_length=get_seasonality(CONSTANT_DATASET_FREQ), + ), + 0.0, + ), + ( + Naive2Predictor( + freq=CONSTANT_DATASET_FREQ, + prediction_length=CONSTANT_DATASET_PREDICTION_LENGTH, + ), + 0.0, + ), + ], ) -def test_accuracy(predictor_cls, parameters, accuracy): - predictor = predictor_cls(freq=CONSTANT_DATASET_FREQ, **parameters) +def test_accuracy(predictor, accuracy): agg_metrics, item_metrics = backtest_metrics( test_dataset=constant_test_ds, predictor=predictor, @@ -133,11 +155,19 @@ def test_accuracy(predictor_cls, parameters, accuracy): @pytest.mark.parametrize( - "predictor_cls, parameters", - [seasonal_naive_predictor(), naive_2_predictor()], + "predictor", + [ + SeasonalNaivePredictor( + prediction_length=CONSTANT_DATASET_PREDICTION_LENGTH, + season_length=get_seasonality(CONSTANT_DATASET_FREQ), + ), + Naive2Predictor( + freq=CONSTANT_DATASET_FREQ, + prediction_length=CONSTANT_DATASET_PREDICTION_LENGTH, + ), + ], ) -def test_seriali_predictors(predictor_cls, parameters): - predictor = predictor_cls(freq=CONSTANT_DATASET_FREQ, **parameters) +def test_seriali_predictors(predictor): with tempfile.TemporaryDirectory() as temp_dir: predictor.serialize(Path(temp_dir)) predictor_exp = Predictor.deserialize(Path(temp_dir))