class MovingAveragePredictor(RepresentablePredictor):
    """
    A :class:`Predictor` that predicts the moving average based on the
    last `context_length` elements of the input target.

    If `prediction_length` = 1, the output is the moving average
    based on the last `context_length` elements of the input target.

    If `prediction_length` > 1, the output is the moving average based on the
    last `context_length` elements of the input target, where
    previously calculated moving averages are appended at the end of the
    input target. Hence, for `prediction_length` larger than
    `context_length`, there will be cases where the moving average is
    calculated on top of previous moving averages.

    NaN entries inside a window are ignored (``np.nanmean``). A window that
    is empty or contains only NaN values yields NaN for that step.

    Parameters
    ----------
    context_length
        Length of the target context used to condition the predictions.
        If ``None`` (the default), each average is computed over the whole
        target together with all previously predicted values.
    prediction_length
        Length of the prediction horizon.
    freq
        Frequency of the predicted data.
    """

    @validated()
    def __init__(
        self,
        prediction_length: int,
        freq: str,
        context_length: Optional[int] = None,
    ) -> None:
        super().__init__(freq=freq, prediction_length=prediction_length)

        if context_length is not None:
            assert (
                context_length >= 1
            ), "The value of 'context_length' should be >= 1 or None"

        self.context_length = context_length

    def predict_item(self, item: DataEntry) -> SampleForecast:
        """
        Forecast `prediction_length` steps for a single data entry.

        Parameters
        ----------
        item
            Data entry; its ``"target"`` field must provide ``tolist()``.

        Returns
        -------
        SampleForecast
            A forecast holding a single sample path, i.e. ``samples`` has
            shape ``(1, prediction_length)``.
        """
        # Work on a plain list copy so predicted values can be appended
        # without touching the caller's data.
        history = item["target"].tolist()

        # Predictions are collected separately instead of being sliced off
        # the tail of `history`: a tail slice `[-n:]` with n == 0 would
        # return the entire history rather than an empty forecast.
        forecasts = []

        for _ in range(self.prediction_length):
            if self.context_length is None:
                window = history
            else:
                window = history[-self.context_length :]

            step = np.nanmean(window)

            # Feed the prediction back so later steps can average over it.
            history.append(step)
            forecasts.append(step)

        return SampleForecast(
            # Wrap in an outer list to add the leading "samples" axis.
            samples=np.array([forecasts]),
            start_date=forecast_start(item),
            freq=self.freq,
            item_id=item.get(FieldName.ITEM_ID),
        )
# Third-party imports
import numpy as np
import pytest

# First-party imports
from gluonts.dataset.common import ListDataset
from gluonts.model.trivial.mean import MovingAveragePredictor


def get_predictions(
    target, prediction_length=1, context_length=1, freq="D", start="2020"
):
    """
    Run a ``MovingAveragePredictor`` on a single series and return the
    ``mean`` of the resulting forecast.

    Parameters
    ----------
    target
        Observed values of the series (may be empty or contain NaN).
    prediction_length
        Number of steps to forecast.
    context_length
        Window size passed to the predictor; ``None`` is forwarded as is.
    freq
        Frequency string used for both the dataset and the predictor.
    start
        Start timestamp of the series.
    """
    mp = MovingAveragePredictor(
        prediction_length=prediction_length,
        context_length=context_length,
        freq=freq,
    )

    ds = ListDataset([{"target": target, "start": start}], freq=freq)
    item = next(iter(ds))
    predictions = mp.predict_item(item).mean

    return predictions


@pytest.mark.parametrize(
    "data, expected_output, prediction_length, context_length",
    [
        ([1, 1, 1], [1], 1, 1),
        ([1, 1, 1], [1, 1], 2, 1),
        ([1, 1, 1], [1, 1, 1], 3, 1),
        ([1, 1, 1], [1], 1, 2),
        ([1, 1, 1], [1, 1], 2, 2),
        ([1, 1, 1], [1, 1, 1], 3, 2),
        ([1, 1, 1], [1], 1, 3),
        ([1, 1, 1], [1, 1], 2, 3),
        ([1, 1, 1], [1, 1, 1], 3, 3),
        ([], [np.nan] * 1, 1, 1),
        ([], [np.nan] * 2, 2, 1),
        ([], [np.nan] * 3, 3, 1),
        ([np.nan], [np.nan] * 1, 1, 1),
        ([1, 3, np.nan], [2], 1, 3),
        ([1, 3, np.nan], [2, 2.5], 2, 3),
        ([1, 3, np.nan], [2, 2.5, 2.25], 3, 3),
        ([1, 2, 3], [3], 1, 1),
        ([1, 2, 3], [3, 3], 2, 1),
        ([1, 2, 3], [3, 3, 3], 3, 1),
        ([1, 2, 3], [2.5], 1, 2),
        ([1, 2, 3], [2.5, 2.75], 2, 2),
        ([1, 2, 3], [2.5, 2.75, 2.625], 3, 2),
        ([1, 2, 3], [2], 1, 3),
        ([1, 2, 3], [2, 7 / 3], 2, 3),
        ([1, 2, 3], [2, 7 / 3, 22 / 9], 3, 3),
        ([1, 1, 1], [1], 1, None),
        ([1, 1, 1], [1, 1], 2, None),
        ([1, 1, 1], [1, 1, 1], 3, None),
        ([1, 3, np.nan], [2], 1, None),
        ([1, 3, np.nan], [2, 2], 2, None),
        ([1, 3, np.nan], [2, 2, 2], 3, None),
    ],
)
def testing(data, expected_output, prediction_length, context_length):
    """
    Check the moving-average forecast against hand-computed values for
    constant, increasing, empty and NaN-containing series.
    """
    predictions = get_predictions(
        data,
        prediction_length=prediction_length,
        context_length=context_length,
    )

    # The expected values are computed means (e.g. 7 / 3, 22 / 9), so compare
    # with a floating-point tolerance rather than bit-for-bit equality.
    # NaN forecasts must still match NaN expectations.
    np.testing.assert_allclose(predictions, expected_output, equal_nan=True)