Skip to content

Commit

Permalink
Rotbaum: Add item-id to forecast. (#3049)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jasper authored Nov 10, 2023
1 parent efb5ee8 commit ef98208
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ def __init__(
featurized_data: List,
start_date: pd.Period,
prediction_length: int,
item_id: Optional[str] = None,
):
self.models = models
self.featurized_data = featurized_data
self.start_date = start_date
self.prediction_length = prediction_length
self.item_id = None
self.item_id = item_id
self.lead_time = None

def quantile(self, q: float) -> np.ndarray: # type: ignore
Expand Down Expand Up @@ -336,6 +337,7 @@ def predict( # type: ignore
[featurized_data],
start_date=forecast_start(ts),
prediction_length=self.prediction_length,
item_id=ts.get("item_id"),
)

def explain(
Expand Down

0 comments on commit ef98208

Please sign in to comment.