Skip to content

Commit

Permalink
Clean up RepresentablePredictor (#2967)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Aug 16, 2023
1 parent 0ffcaa6 commit 32d5df5
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import gluonts
from gluonts.core import fqname_for
from gluonts.core.component import equals, from_hyperparameters, validated
from gluonts.core.component import equals, from_hyperparameters
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import DataEntry, Dataset
from gluonts.model.forecast import Forecast
Expand Down Expand Up @@ -135,23 +135,18 @@ def from_inputs(cls, train_iter, **params):

class RepresentablePredictor(Predictor):
"""
An abstract predictor that can be subclassed by models that are not based
on Gluon. Subclasses should have @validated() constructors.
(De)serialization and value equality are all implemented on top of the.
An abstract predictor that can be subclassed by framework-specific models.
Subclasses should have ``@validated()`` constructors:
(de)serialization and equality test are all implemented on top of its logic.
@validated() logic.
Parameters
----------
prediction_length
Prediction horizon.
lead_time
Prediction lead time.
"""

@validated()
def __init__(self, prediction_length: int, lead_time: int = 0) -> None:
super().__init__(
lead_time=lead_time, prediction_length=prediction_length
)

def predict(self, dataset: Dataset, **kwargs) -> Iterator[Forecast]:
for item in dataset:
yield self.predict_item(item)
Expand Down

0 comments on commit 32d5df5

Please sign in to comment.