Skip to content

Commit

Permalink
Implemented moving average (awslabs#926)
Browse files Browse the repository at this point in the history
* Add moving average

* Add moving average

* Add moving average - added tests

* context_length now optional::forecast_start is now used::updated tests

* refinments added

* added assert + forecast_start updates

* forecast_start updates

* forecast_start updates + assert

Co-authored-by: Pedro Eduardo Mercado Lopez <pedroml@amazon.com>
  • Loading branch information
2 people authored and kashif committed Oct 10, 2020
1 parent 5a1b92a commit c14b0e6
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 3 deletions.
64 changes: 61 additions & 3 deletions src/gluonts/model/trivial/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import FallbackPredictor, RepresentablePredictor
from gluonts.model.trivial.constant import ConstantPredictor
from gluonts.support.pandas import frequency_add
from gluonts.support.pandas import forecast_start


class MeanPredictor(RepresentablePredictor, FallbackPredictor):
Expand Down Expand Up @@ -70,10 +70,68 @@ def predict_item(self, item: DataEntry) -> SampleForecast:
std = np.nanstd(target)
normal = np.random.standard_normal(self.shape)

start_date = frequency_add(item["start"], len(item["target"]))
return SampleForecast(
samples=std * normal + mean,
start_date=start_date,
start_date=forecast_start(item),
freq=self.freq,
item_id=item.get(FieldName.ITEM_ID),
)


class MovingAveragePredictor(RepresentablePredictor):
"""
A :class:`Predictor` that predicts the moving average based on the
last `context_length` elements of the input target.
If `prediction_length` = 1, the output is the moving average
based on the last `context_length` elements of the input target.
If `prediction_length` > 1, the output is the moving average based on the
last `context_length` elements of the input target, where
previously calculated moving averages are appended at the end of the input target.
Hence, for `prediction_length` larger than `context_length`, there will be
cases where the moving average is calculated on top of previous moving averages.
Parameters
----------
context_length
Length of the target context used to condition the predictions.
prediction_length
Length of the prediction horizon.
freq
Frequency of the predicted data.
"""

@validated()
def __init__(
self,
prediction_length: int,
freq: str,
context_length: Optional[int] = None,
) -> None:
super().__init__(freq=freq, prediction_length=prediction_length)

if context_length is not None:
assert (
context_length >= 1
), "The value of 'context_length' should be >= 1 or None"

self.context_length = context_length

def predict_item(self, item: DataEntry) -> SampleForecast:
target = item["target"].tolist()

for _ in range(self.prediction_length):
if self.context_length is not None:
window = target[-self.context_length :]
else:
window = target

target.append(np.nanmean(window))

return SampleForecast(
samples=np.array([target[-self.prediction_length :]]),
start_date=forecast_start(item),
freq=self.freq,
item_id=item.get(FieldName.ITEM_ID),
)
Expand Down
83 changes: 83 additions & 0 deletions test/model/trivial/test_moving_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

# First-party imports
from gluonts.dataset.common import ListDataset
from gluonts.model.trivial.mean import MovingAveragePredictor

# Third-party imports
import numpy as np
import pytest


def get_predictions(
target, prediction_length=1, context_length=1, freq="D", start="2020"
):
mp = MovingAveragePredictor(
prediction_length=prediction_length,
context_length=context_length,
freq=freq,
)

ds = ListDataset([{"target": target, "start": start}], freq=freq)
item = next(iter(ds))
predictions = mp.predict_item(item).mean

return predictions


@pytest.mark.parametrize(
"data, expected_output, prediction_length, context_length",
[
([1, 1, 1], [1], 1, 1),
([1, 1, 1], [1, 1], 2, 1),
([1, 1, 1], [1, 1, 1], 3, 1),
([1, 1, 1], [1], 1, 2),
([1, 1, 1], [1, 1], 2, 2),
([1, 1, 1], [1, 1, 1], 3, 2),
([1, 1, 1], [1], 1, 3),
([1, 1, 1], [1, 1], 2, 3),
([1, 1, 1], [1, 1, 1], 3, 3),
([], [np.nan] * 1, 1, 1),
([], [np.nan] * 2, 2, 1),
([], [np.nan] * 3, 3, 1),
([np.nan], [np.nan] * 1, 1, 1),
([1, 3, np.nan], [2], 1, 3),
([1, 3, np.nan], [2, 2.5], 2, 3),
([1, 3, np.nan], [2, 2.5, 2.25], 3, 3),
([1, 2, 3], [3], 1, 1),
([1, 2, 3], [3, 3], 2, 1),
([1, 2, 3], [3, 3, 3], 3, 1),
([1, 2, 3], [2.5], 1, 2),
([1, 2, 3], [2.5, 2.75], 2, 2),
([1, 2, 3], [2.5, 2.75, 2.625], 3, 2),
([1, 2, 3], [2], 1, 3),
([1, 2, 3], [2, 7 / 3], 2, 3),
([1, 2, 3], [2, 7 / 3, 22 / 9], 3, 3),
([1, 1, 1], [1], 1, None),
([1, 1, 1], [1, 1], 2, None),
([1, 1, 1], [1, 1, 1], 3, None),
([1, 3, np.nan], [2], 1, None),
([1, 3, np.nan], [2, 2], 2, None),
([1, 3, np.nan], [2, 2, 2], 3, None),
],
)
def testing(data, expected_output, prediction_length, context_length):

predictions = get_predictions(
data,
prediction_length=prediction_length,
context_length=context_length,
)

np.testing.assert_equal(predictions, expected_output)

0 comments on commit c14b0e6

Please sign in to comment.