Skip to content

Commit

Permalink
context_length now optional::forecast_start is now used::updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Eduardo Mercado Lopez committed Jul 27, 2020
1 parent 9258b5f commit 05c38ff
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 96 deletions.
6 changes: 3 additions & 3 deletions src/gluonts/model/trivial/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from gluonts.model.forecast import Forecast, SampleForecast
from gluonts.model.predictor import FallbackPredictor, RepresentablePredictor
from gluonts.model.trivial.constant import ConstantPredictor
from gluonts.support.pandas import frequency_add
from gluonts.support.pandas import frequency_add, forecast_start


class MeanPredictor(RepresentablePredictor, FallbackPredictor):
Expand Down Expand Up @@ -105,7 +105,7 @@ class MovingAveragePredictor(RepresentablePredictor):

@validated()
def __init__(
self, prediction_length: int, freq: str, context_length: int,
self, prediction_length: int, freq: str, context_length: int = 1,
) -> None:
super().__init__(freq=freq, prediction_length=prediction_length)

Expand All @@ -122,7 +122,7 @@ def predict_item(self, item: DataEntry) -> SampleForecast:
window = target[-self.context_length :]
target.append(np.nanmean(window))

start_date = frequency_add(item["start"], len(item["target"]))
start_date = forecast_start(item)
return SampleForecast(
samples=np.array([target[-self.prediction_length :]]),
start_date=start_date,
Expand Down
140 changes: 47 additions & 93 deletions test/model/trivial/test_moving_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,109 +17,63 @@

# Third-party imports
import numpy as np
import pytest


def set_data(target, freq):
"""
Sets test data in the right format
"""

start = "2020"
ds = ListDataset([{"target": target, "start": start}], freq=freq)
data = list(ds)

return data


def get_predictions(data, prediction_length, context_length, freq):
"""
Gets predictions based on moving average
"""

def get_predictions(
target, prediction_length=1, context_length=1, freq="D", start="2020"
):
mp = MovingAveragePredictor(
prediction_length=prediction_length,
context_length=context_length,
freq=freq,
)
predictions = mp.predict_item(data[0]).samples[0]

return predictions


def check_equality_constant_sequence(
predictions, constant_value, prediction_length
):
"""
Checks if prediction values coincide with expected values. This is for the case where the input is constant:
expected = [constant_value, constant_value, ..., constant_value]
"""

expected = [constant_value] * prediction_length

if np.isnan(constant_value):
return list(np.isnan(predictions)) == list(np.isnan(expected))
else:
return list(predictions) == expected


def run_evaluations(
data,
freq,
constant_value,
context_length_values=range(1, 10),
prediction_length_values=range(1, 10),
):
"""
Executes generic tests based on settings provided by input parameters
Performs asserts on the output and shape of output.
"""

for context_length in context_length_values:
for prediction_length in prediction_length_values:
predictions = get_predictions(
data, prediction_length, context_length, freq
)
assert check_equality_constant_sequence(
predictions, constant_value, prediction_length
)
assert predictions.shape == (prediction_length,)


def test_constant_sequence():
constant_value = 1
target_length = 3
target = [constant_value] * target_length # [1, 1, 1]
freq = "D"
data = set_data(target, freq)

run_evaluations(data, freq, constant_value)


def test_length_one_sequence():
constant_value = 1
# target_length = 1
target = [constant_value]
freq = "D"
data = set_data(target, freq)

run_evaluations(data, freq, constant_value)

ds = ListDataset([{"target": target, "start": start}], freq=freq)
item = next(iter(ds))
predictions = mp.predict_item(item).samples[0]

def test_empty_sequence():
constant_value = np.nan
target_length = 0
target = []
freq = "D"
data = set_data(target, freq)
return predictions

run_evaluations(data, freq, constant_value)

@pytest.mark.parametrize(
"data, expected_output, prediction_length, context_length",
[
([1, 1, 1], [1], 1, 1),
([1, 1, 1], [1, 1], 2, 1),
([1, 1, 1], [1, 1, 1], 3, 1),
([1, 1, 1], [1], 1, 2),
([1, 1, 1], [1, 1], 2, 2),
([1, 1, 1], [1, 1, 1], 3, 2),
([1, 1, 1], [1], 1, 3),
([1, 1, 1], [1, 1], 2, 3),
([1, 1, 1], [1, 1, 1], 3, 3),
([], [np.nan] * 1, 1, 1),
([], [np.nan] * 2, 2, 1),
([], [np.nan] * 3, 3, 1),
([np.nan], [np.nan] * 1, 1, 1),
([1, 3, np.nan], [2], 1, 3),
([1, 3, np.nan], [2, 2.5], 2, 3),
([1, 3, np.nan], [2, 2.5, 2.25], 3, 3),
([1, 2, 3], [3], 1, 1),
([1, 2, 3], [3, 3], 2, 1),
([1, 2, 3], [3, 3, 3], 3, 1),
([1, 2, 3], [2.5], 1, 2),
([1, 2, 3], [2.5, 2.75], 2, 2),
([1, 2, 3], [2.5, 2.75, 2.625], 3, 2),
([1, 2, 3], [2], 1, 3),
([1, 2, 3], [2, 7 / 3], 2, 3),
([1, 2, 3], [2, 7 / 3, 22 / 9], 3, 3),
],
)
def testing(data, expected_output, prediction_length, context_length):

predictions = get_predictions(
data,
prediction_length=prediction_length,
context_length=context_length,
)

def test_nan_sequence():
constant_value = np.nan
target_length = 3
target = [constant_value] * target_length # [1, 1, 1]
freq = "D"
data = set_data(target, freq)
np.testing.assert_equal(predictions, expected_output)

run_evaluations(data, freq, constant_value)
np.testing.assert_equal(predictions.shape, (prediction_length,))

0 comments on commit 05c38ff

Please sign in to comment.