Skip to content

Commit

Permalink
Fix r_forecast wrapper to shift start date when truncating time ser…
Browse files Browse the repository at this point in the history
…ies (awslabs#2216)
  • Loading branch information
abdulfatir authored and lostella committed Aug 26, 2022
1 parent c45b248 commit 851fe05
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/gluonts/model/r_forecast/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def predict(

for data in dataset:
if self.trunc_length:
shift_by = max(data["target"].shape[0] - self.trunc_length, 0)
data["start"] = data["start"] + shift_by
data["target"] = data["target"][-self.trunc_length :]

params = self.params.copy()
Expand Down
12 changes: 11 additions & 1 deletion test/model/r_forecast/test_r_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from gluonts.core import serde
from gluonts.dataset.repository import datasets
from gluonts.dataset.util import forecast_start
from gluonts.dataset.util import forecast_start, to_pandas
from gluonts.evaluation import Evaluator, backtest_metrics
from gluonts.model.forecast import SampleForecast, QuantileForecast
from gluonts.model.r_forecast import (
Expand Down Expand Up @@ -98,6 +98,16 @@ def test_forecasts(method_name):
assert agg_metrics["NRMSE"] < TOLERANCE
assert agg_metrics["RMSE"] < TOLERANCE

trunc_length = prediction_length

predictor = RForecastPredictor(**params, trunc_length=trunc_length)
predictions = list(predictor.predict(train_dataset))

assert all(
prediction.start_date == to_pandas(data).index[-1] + 1
for data, prediction in zip(train_dataset, predictions)
)


def test_r_predictor_serialization():
predictor = RForecastPredictor(freq="1D", prediction_length=3)
Expand Down

0 comments on commit 851fe05

Please sign in to comment.