Skip to content

Commit

Permalink
Merge branch 'master' of github.com:awslabs/gluon-ts
Browse files Browse the repository at this point in the history
  • Loading branch information
Danielle Robinson committed Sep 16, 2019
2 parents 9307e4c + 4479857 commit 4884489
Show file tree
Hide file tree
Showing 32 changed files with 599 additions and 89 deletions.
10 changes: 10 additions & 0 deletions src/gluonts/core/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,16 @@ def encode_np_dtype(v: np.dtype) -> Any:
}


@encode.register(mx.nd.NDArray)
def encode_mx_ndarray(v: mx.nd.NDArray) -> Any:
return {
"__kind__": kind_inst,
"class": "mxnet.nd.array",
"args": encode([v.asnumpy().tolist()]),
"kwargs": {"dtype": encode(v.dtype)},
}


def decode(r: Any) -> Any:
"""
Decodes a value from an intermediate representation `r`.
Expand Down
16 changes: 9 additions & 7 deletions src/gluonts/dataset/artificial/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def generate_ts(
ts_data = dict(
start=self.start,
target=target,
item=str(i),
item_id=str(i),
feat_static_cat=[i],
feat_static_real=[i],
)
Expand Down Expand Up @@ -508,7 +508,7 @@ def train(self) -> List[DataEntry]:
dict(
start=ts["start"],
target=ts["target"][: -self.prediction_length],
item=ts["item"],
item_id=ts["item_id"],
)
for ts in self.make_timeseries()
]
Expand Down Expand Up @@ -609,7 +609,7 @@ def sigmoid(x: np.ndarray) -> np.ndarray:
dict(
start=pd.Timestamp(start, freq=self.freq_str),
target=np.array(v),
item=i,
item_id=i,
)
)
return res
Expand Down Expand Up @@ -681,7 +681,9 @@ def trim_ts_item_end(x: DataEntry, length: int) -> DataEntry:
the last prediction_length time points from the target and dynamic
features."""
y = dict(
item=x["item"], start=x["start"], target=x["target"][:-length]
item_id=x["item_id"],
start=x["start"],
target=x["target"][:-length],
)

if "feat_dynamic_cat" in x:
Expand All @@ -703,7 +705,7 @@ def trim_ts_item_front(x: DataEntry, length: int) -> DataEntry:
assert length <= len(x["target"])

y = dict(
item=x["item"],
item_id=x["item_id"],
start=x["start"] + length * x["start"].freq,
target=x["target"][length:],
)
Expand Down Expand Up @@ -799,7 +801,7 @@ def constant_dataset() -> Tuple[DatasetInfo, Dataset, Dataset]:
train_ds = ListDataset(
data_iter=[
{
"item": str(i),
"item_id": str(i),
"start": start_date,
"target": [float(i)] * 24,
"feat_static_cat": [i],
Expand All @@ -813,7 +815,7 @@ def constant_dataset() -> Tuple[DatasetInfo, Dataset, Dataset]:
test_ds = ListDataset(
data_iter=[
{
"item": str(i),
"item_id": str(i),
"start": start_date,
"target": [float(i)] * 30,
"feat_static_cat": [i],
Expand Down
Loading

0 comments on commit 4884489

Please sign in to comment.