From 32d5df52202debff7c04c6afa88ab54095b2e696 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 16 Aug 2023 14:14:03 +0200 Subject: [PATCH] Clean up RepresentablePredictor (#2967) --- src/gluonts/model/predictor.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/gluonts/model/predictor.py b/src/gluonts/model/predictor.py index 69ff23ed84..e0d490c851 100644 --- a/src/gluonts/model/predictor.py +++ b/src/gluonts/model/predictor.py @@ -26,7 +26,7 @@ import gluonts from gluonts.core import fqname_for -from gluonts.core.component import equals, from_hyperparameters, validated +from gluonts.core.component import equals, from_hyperparameters from gluonts.core.serde import dump_json, load_json from gluonts.dataset.common import DataEntry, Dataset from gluonts.model.forecast import Forecast @@ -135,23 +135,18 @@ def from_inputs(cls, train_iter, **params): class RepresentablePredictor(Predictor): """ - An abstract predictor that can be subclassed by models that are not based - on Gluon. Subclasses should have @validated() constructors. - (De)serialization and value equality are all implemented on top of the. + An abstract predictor that can be subclassed by framework-specific models. + Subclasses should have ``@validated()`` constructors: + (de)serialization and equality test are all implemented on top of its logic. - @validated() logic. Parameters ---------- prediction_length Prediction horizon. + lead_time + Prediction lead time. """ - @validated() - def __init__(self, prediction_length: int, lead_time: int = 0) -> None: - super().__init__( - lead_time=lead_time, prediction_length=prediction_length - ) - def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]: for item in dataset: yield self.predict_item(item)