diff --git a/src/gluonts/model/r_forecast/_predictor.py b/src/gluonts/model/r_forecast/_predictor.py index dd385b5a24..50ce2c4b7a 100644 --- a/src/gluonts/model/r_forecast/_predictor.py +++ b/src/gluonts/model/r_forecast/_predictor.py @@ -249,6 +249,8 @@ def predict( for data in dataset: if self.trunc_length: + shift_by = max(data["target"].shape[0] - self.trunc_length, 0) + data["start"] = data["start"] + shift_by data["target"] = data["target"][-self.trunc_length :] params = self.params.copy() diff --git a/test/model/r_forecast/test_r_predictor.py b/test/model/r_forecast/test_r_predictor.py index 49879f9622..6af26013b8 100644 --- a/test/model/r_forecast/test_r_predictor.py +++ b/test/model/r_forecast/test_r_predictor.py @@ -15,7 +15,7 @@ from gluonts.core import serde from gluonts.dataset.repository import datasets -from gluonts.dataset.util import forecast_start +from gluonts.dataset.util import forecast_start, to_pandas from gluonts.evaluation import Evaluator, backtest_metrics from gluonts.model.forecast import SampleForecast, QuantileForecast from gluonts.model.r_forecast import ( @@ -98,6 +98,16 @@ def test_forecasts(method_name): assert agg_metrics["NRMSE"] < TOLERANCE assert agg_metrics["RMSE"] < TOLERANCE + trunc_length = prediction_length + + predictor = RForecastPredictor(**params, trunc_length=trunc_length) + predictions = list(predictor.predict(train_dataset)) + + assert all( + prediction.start_date == to_pandas(data).index[-1] + 1 + for data, prediction in zip(train_dataset, predictions) + ) + def test_r_predictor_serialization(): predictor = RForecastPredictor(freq="1D", prediction_length=3)