Skip to content

Commit

Permalink
fix item_id field in provided datasets (#566)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored Jan 21, 2020
1 parent 824159c commit 1171f9e
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/gluonts/dataset/repository/_gp_copula_2019.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def save_dataset(dataset_path: Path, ds_info: GPCopulaDataset):
# Handles adding categorical features of rolling
# evaluation dates
cat=[cat - ds_info.num_series * (cat // ds_info.num_series)],
item_id=cat,
)
for cat, data_entry in enumerate(dataset)
],
Expand Down
2 changes: 2 additions & 0 deletions src/gluonts/dataset/repository/_lstnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def generate_lstnet_dataset(dataset_path: Path, dataset_name: str):
target_values=sliced_ts.values,
start=sliced_ts.index[0],
cat=[cat],
item_id=cat,
)
)

Expand All @@ -192,6 +193,7 @@ def generate_lstnet_dataset(dataset_path: Path, dataset_name: str):
target_values=sliced_ts.values,
start=sliced_ts.index[0],
cat=[cat],
item_id=cat,
)
)

Expand Down
14 changes: 12 additions & 2 deletions src/gluonts/dataset/repository/_m4.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,15 +72,25 @@ def generate_m4_dataset(
save_to_file(
train_file,
[
to_dict(target_values=target, start=mock_start_dataset, cat=[cat])
to_dict(
target_values=target,
start=mock_start_dataset,
cat=[cat],
item_id=cat,
)
for cat, target in enumerate(train_target_values)
],
)

save_to_file(
test_file,
[
to_dict(target_values=target, start=mock_start_dataset, cat=[cat])
to_dict(
target_values=target,
start=mock_start_dataset,
cat=[cat],
item_id=cat,
)
for cat, target in enumerate(test_target_values)
],
)
10 changes: 8 additions & 2 deletions src/gluonts/dataset/repository/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@
import json
import os
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Any

import numpy as np


def to_dict(
target_values: np.ndarray, start: str, cat: Optional[List[int]] = None
target_values: np.ndarray,
start: str,
cat: Optional[List[int]] = None,
item_id: Optional[Any] = None,
):
def serialize(x):
if np.isnan(x):
Expand All @@ -37,6 +40,9 @@ def serialize(x):
if cat is not None:
res["feat_static_cat"] = cat

if item_id is not None:
res["item_id"] = item_id

return res


Expand Down

0 comments on commit 1171f9e

Please sign in to comment.