Backports for v0.10.5 #2252

Merged
merged 8 commits into from
Aug 26, 2022
32 changes: 16 additions & 16 deletions docs/getting_started/models.md
@@ -51,26 +51,26 @@ NPTS | Local | Uni

 <!-- Links to code -->

-[DeepAR_mx]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/deepar/_estimator.py
+[DeepAR_mx]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/deepar/_estimator.py
 [DeepAR_torch]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/torch/model/deepar/estimator.py
-[DeepState]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/deepstate/_estimator.py
-[DeepFactor]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/deep_factor/_estimator.py
-[DeepRenewal]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/renewal/_estimator.py
-[GP]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/gp_forecaster/_estimator.py
-[MQDNN]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/seq2seq/_mq_dnn_estimator.py
-[NBeats]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/n_beats/_estimator.py
+[DeepState]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/deepstate/_estimator.py
+[DeepFactor]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/deep_factor/_estimator.py
+[DeepRenewal]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/renewal/_estimator.py
+[GP]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/gp_forecaster/_estimator.py
+[MQDNN]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/seq2seq/_mq_dnn_estimator.py
+[NBeats]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/n_beats/_estimator.py
 [Rotbaum]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/rotbaum/_estimator.py
-[SAN]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/san/_estimator.py
-[TFT]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/tft/_estimator.py
-[Transformer]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/transformer/_estimator.py
-[WaveNet]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/wavenet/_estimator.py
-[SFF_mx]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/simple_feedforward/_estimator.py
+[SAN]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/san/_estimator.py
+[TFT]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/tft/_estimator.py
+[Transformer]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/transformer/_estimator.py
+[WaveNet]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/wavenet/_estimator.py
+[SFF_mx]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/simple_feedforward/_estimator.py
 [SFF_torch]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/torch/model/simple_feedforward/estimator.py
-[DeepVAR]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/deepvar/_estimator.py
+[DeepVAR]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/deepvar/_estimator.py
 [DeepVARHierarchical]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/deepvar_hierarchical/_estimator.py
-[GPVAR]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/gpvar/_estimator.py
-[LSTNet]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/lstnet/_estimator.py
-[DeepTPP]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/tpp/deeptpp/_estimator.py
+[GPVAR]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/gpvar/_estimator.py
+[LSTNet]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/lstnet/_estimator.py
+[DeepTPP]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/mx/model/tpp/deeptpp/_estimator.py
 [RForecast]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/r_forecast/_predictor.py
 [Prophet]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/prophet/_predictor.py
 [NaiveSeasonal]: https://github.com/awslabs/gluon-ts/blob/dev/src/gluonts/model/seasonal_naive/_predictor.py
74 changes: 20 additions & 54 deletions src/gluonts/dataset/arrow/dec.py
@@ -12,71 +12,37 @@
 # permissions and limitations under the License.

 from dataclasses import dataclass
-from functools import singledispatch
-from typing import Dict
+from typing import List, Tuple

 import numpy as np
-import pyarrow as pa
-
-
-@singledispatch
-def _arrow_to_py(scalar):
-    """Convert arrow scalar value to python value."""
-
-    raise NotImplementedError(scalar, scalar.__class__)
-
-
-@_arrow_to_py.register(pa.Scalar)
-def _arrow_to_py_scalar(scalar: pa.Scalar):
-    return scalar.as_py()
-
-
-@_arrow_to_py.register(pa.ListScalar)
-def _arrow_to_py_list_scalar(scalar: pa.ListScalar):
-    arr = scalar.values.to_numpy(zero_copy_only=False)
-
-    if arr.dtype == object:
-        arr = np.array(list(arr))
-
-    return arr


 @dataclass
 class ArrowDecoder:
-    columns: Dict[str, int]
-    ndarray_columns: Dict[str, int]
+    reshape_columns: List[Tuple[str, str]]

     @classmethod
     def from_schema(cls, schema):
-        columns = {}
-        ndarray_columns = {}
-
-        for idx, column in enumerate(schema):
-            if column.name.endswith("._np_shape"):
-                ndarray_columns[(column.name.rsplit(".", 1)[0])] = idx
-            else:
-                columns[column.name] = idx
+        return cls(
+            [
+                (column.name[: -len("._np_shape")], column.name)
+                for column in schema
+                if column.name.endswith("._np_shape")
+            ]
+        )

-        return cls(columns, ndarray_columns)
-
-    def decode(self, batch, row_number):
-        for row in self.decode_batch(batch.slice(row_number, row_number + 1)):
-            return row
+    def decode(self, batch, row_number: int):
+        yield from self.decode_batch(batch.slice(row_number, row_number + 1))

     def decode_batch(self, batch):
-        rows = zip(*batch)
-
-        for raw_row in rows:
-            row = {}
-            for column_name, column_idx in self.columns.items():
-                value = _arrow_to_py(raw_row[column_idx])
-
-                shape_idx = self.ndarray_columns.get(column_name)
-
-                if shape_idx is not None:
-                    shape = _arrow_to_py(raw_row[shape_idx])
-                    value = value.reshape(shape)
-
-                row[column_name] = value
+        for row in batch.to_pandas().to_dict("records"):
+            for column_name, shape_column in self.reshape_columns:
+                row[column_name] = row[column_name].reshape(
+                    row.pop(shape_column)
+                )
+
+            for name, value in row.items():
+                if type(value) == np.ndarray and value.dtype == object:
+                    row[name] = np.stack(value)

             yield row
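
The decoder change above pairs every `<name>._np_shape` column with its data column and restores the array shape while dropping the shape column from the row. The following is a minimal, dependency-free sketch of that idea only; it uses plain lists instead of Arrow batches and NumPy arrays, and the helper names `find_reshape_columns`, `reshape_flat` and `decode_row` are hypothetical, not part of GluonTS.

```python
from typing import Dict, List, Tuple

SHAPE_SUFFIX = "._np_shape"


def find_reshape_columns(column_names: List[str]) -> List[Tuple[str, str]]:
    """Pair each data column with its companion shape column."""
    return [
        (name[: -len(SHAPE_SUFFIX)], name)
        for name in column_names
        if name.endswith(SHAPE_SUFFIX)
    ]


def reshape_flat(values: list, shape: List[int]) -> list:
    """Rebuild a nested list of the given shape from a flat, row-major list."""
    if len(shape) <= 1:
        return list(values)
    if shape[0] == 0:
        return []
    step = len(values) // shape[0]
    return [
        reshape_flat(values[i * step : (i + 1) * step], shape[1:])
        for i in range(shape[0])
    ]


def decode_row(row: Dict[str, list], reshape_columns) -> Dict[str, list]:
    """Restore shapes and remove the shape columns from a single record."""
    row = dict(row)  # copy so the caller's record is not mutated
    for column_name, shape_column in reshape_columns:
        row[column_name] = reshape_flat(row[column_name], row.pop(shape_column))
    return row


pairs = find_reshape_columns(["start", "target", "target._np_shape"])
record = {"start": "2022-08-26", "target": [1, 2, 3, 4, 5, 6], "target._np_shape": [2, 3]}
decoded = decode_row(record, pairs)
```

Computing the pairs once from the schema keeps the per-row work to a lookup and a reshape, which appears to be the motivation for replacing the per-scalar dispatch.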
1 change: 1 addition & 0 deletions src/gluonts/dataset/field_names.py
@@ -19,6 +19,7 @@ class FieldName:
     """

     ITEM_ID = "item_id"
+    INFO = "info"

     START = "start"
     TARGET = "target"
11 changes: 7 additions & 4 deletions src/gluonts/model/forecast.py
@@ -636,11 +636,14 @@ def __repr__(self):
             ]
         )

-    def to_quantile_forecast(
-        self, quantiles: List[Union[float, str]]
-    ) -> "QuantileForecast":
+    def to_quantile_forecast(self, quantiles: List[str]) -> "QuantileForecast":
         return QuantileForecast(
-            forecast_arrays=np.array([self.quantile(q) for q in quantiles]),
+            forecast_arrays=np.array(
+                [
+                    self.quantile(q) if q != "mean" else self.mean()
+                    for q in quantiles
+                ]
+            ),
             start_date=self.start_date,
             forecast_keys=quantiles,
             item_id=self.item_id,
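
The `to_quantile_forecast` change treats the key `"mean"` specially instead of passing it to the quantile lookup. A rough standalone sketch of that dispatch, assuming forecast samples are a plain list of floats and using a simple linear-interpolation quantile (the functions `sample_quantile` and `summarize` are illustrative names, not GluonTS API):

```python
from statistics import fmean
from typing import Dict, List, Sequence


def sample_quantile(samples: Sequence[float], level: float) -> float:
    """Empirical quantile by linear interpolation between order statistics."""
    ordered = sorted(samples)
    position = level * (len(ordered) - 1)
    lower = int(position)
    upper = min(lower + 1, len(ordered) - 1)
    weight = position - lower
    return ordered[lower] * (1 - weight) + ordered[upper] * weight


def summarize(samples: Sequence[float], keys: List[str]) -> Dict[str, float]:
    """Map each key to a statistic: "mean" is the average, anything else a quantile level."""
    return {
        key: fmean(samples) if key == "mean" else sample_quantile(samples, float(key))
        for key in keys
    }


summary = summarize([1.0, 2.0, 3.0, 10.0], ["0.5", "mean"])
```

With skewed samples such as these, the `"mean"` entry and the `"0.5"` entry differ, which is why the two keys cannot share one code path.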
2 changes: 2 additions & 0 deletions src/gluonts/model/r_forecast/_predictor.py
@@ -249,6 +249,8 @@ def predict(

         for data in dataset:
             if self.trunc_length:
+                shift_by = max(data["target"].shape[0] - self.trunc_length, 0)
+                data["start"] = data["start"] + shift_by
                 data["target"] = data["target"][-self.trunc_length :]

             params = self.params.copy()
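
The fix above moves the series start forward by the number of observations that truncation drops, so that forecasts still begin right after the last observation. A small sketch of the same arithmetic, assuming a daily frequency and a `datetime.date` start in place of the period object used by the predictor (`truncate_series` is a hypothetical helper):

```python
from datetime import date, timedelta
from typing import List, Tuple


def truncate_series(
    start: date, target: List[float], trunc_length: int
) -> Tuple[date, List[float]]:
    """Keep the last `trunc_length` values and shift `start` to match."""
    if not trunc_length:
        return start, target
    # Number of leading observations that get dropped (never negative).
    shift_by = max(len(target) - trunc_length, 0)
    return start + timedelta(days=shift_by), target[-trunc_length:]


new_start, new_target = truncate_series(
    date(2022, 8, 1), [float(i) for i in range(1, 11)], 3
)
```

The `max(..., 0)` guard matters for series shorter than `trunc_length`: nothing is dropped, so the start must not move.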
10 changes: 9 additions & 1 deletion src/gluonts/model/simple_feedforward/_estimator.py
@@ -193,7 +193,15 @@ def __init__(
     # transformation that includes time features, age feature, observed values
     # indicator, ...
     def create_transformation(self) -> Transformation:
-        return AddObservedValuesIndicator(
+        return SelectFields(
+            [
+                FieldName.ITEM_ID,
+                FieldName.INFO,
+                FieldName.START,
+                FieldName.TARGET,
+            ],
+            allow_missing=True,
+        ) + AddObservedValuesIndicator(
             target_field=FieldName.TARGET,
             output_field=FieldName.OBSERVED_VALUES,
             dtype=self.dtype,
4 changes: 3 additions & 1 deletion src/gluonts/mx/distribution/piecewise_linear.py
@@ -181,7 +181,9 @@ def crps(self, x: Tensor) -> Tensor:
             a_tilde.expand_dims(axis=-1), knot_positions
         )

-        knots_cubed = F.broadcast_power(self.knot_positions, F.ones(1) * 3.0)
+        knots_cubed = F.power(
+            knot_positions, F.ones_like(knot_positions) * 3.0
+        )

         coeff = (
             (1.0 - knots_cubed) / 3.0
12 changes: 11 additions & 1 deletion src/gluonts/mx/trainer/_base.py
@@ -324,9 +324,12 @@ def loop( # todo call run epoch
                     )

                 batch_iter = itertools.islice(batch_iter, num_batches_to_use)
-
                 it = tqdm(batch_iter, total=num_batches_to_use)
+                any_batches = False
+
                 for batch_no, batch in enumerate(it, start=1):
+                    any_batches = True
+
                     # `batch` here is expected to be a dictionary whose fields
                     # should correspond 1-to-1 with the network inputs
                     # see below how `batch.values()` is fed into the network
@@ -421,6 +424,13 @@ def loop( # todo call run epoch
                         break
                 it.close()

+                if not any_batches:
+                    raise GluonTSDataError(
+                        "No training data batch could be constructed; "
+                        "this usually indicates that the training dataset "
+                        "is empty, or consists of too short series."
+                    )
+
                 # mark epoch end time and log time cost of current epoch
                 if not self.halt:
                     toc = time.time()
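
The trainer change uses a flag to detect that the batch iterator produced nothing, since an empty `for` loop would otherwise finish silently. A minimal sketch of the pattern, with integers standing in for batches and `ValueError` standing in for `GluonTSDataError` (both substitutions are assumptions made to keep the example self-contained):

```python
import itertools
from typing import Iterable


def run_epoch(batch_iter: Iterable[int], num_batches_to_use: int) -> int:
    """Consume up to `num_batches_to_use` batches and fail loudly if there were none."""
    any_batches = False
    total = 0

    for batch in itertools.islice(batch_iter, num_batches_to_use):
        any_batches = True
        total += batch  # placeholder for the forward/backward pass

    if not any_batches:
        raise ValueError(
            "No training data batch could be constructed; "
            "this usually indicates that the training dataset "
            "is empty, or consists of too short series."
        )
    return total


epoch_total = run_epoch(iter([1, 2, 3]), 2)
```

A flag is used rather than `len()` because the batch source is a lazy iterator whose length is not known up front.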
10 changes: 9 additions & 1 deletion src/gluonts/torch/model/simple_feedforward/estimator.py
@@ -139,7 +139,15 @@ def __init__(
         )

     def create_transformation(self) -> Transformation:
-        return AddObservedValuesIndicator(
+        return SelectFields(
+            [
+                FieldName.ITEM_ID,
+                FieldName.INFO,
+                FieldName.START,
+                FieldName.TARGET,
+            ],
+            allow_missing=True,
+        ) + AddObservedValuesIndicator(
             target_field=FieldName.TARGET,
             output_field=FieldName.OBSERVED_VALUES,
         )
9 changes: 8 additions & 1 deletion src/gluonts/transform/field.py
@@ -121,11 +121,18 @@ class SelectFields(MapTransformation):
     ----------
     input_fields
         List of fields to keep.
+    allow_missing
+        If ``True``, skip any missing field. Default: ``False``.
     """

     @validated()
-    def __init__(self, input_fields: List[str]) -> None:
+    def __init__(
+        self, input_fields: List[str], allow_missing: bool = False
+    ) -> None:
         self.input_fields = input_fields
+        self.allow_missing = allow_missing

     def map_transform(self, data: DataEntry, is_train: bool) -> DataEntry:
+        if self.allow_missing:
+            return {f: data[f] for f in self.input_fields if f in data}
         return {f: data[f] for f in self.input_fields}
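
The effect of `allow_missing` can be illustrated on plain dictionaries. This sketch reproduces only the selection logic of `map_transform` as a free function; the name `select_fields` and the sample entry are made up for the example:

```python
from typing import Any, Dict, List


def select_fields(
    data: Dict[str, Any], input_fields: List[str], allow_missing: bool = False
) -> Dict[str, Any]:
    """Keep only `input_fields`; optionally tolerate fields absent from `data`."""
    if allow_missing:
        return {f: data[f] for f in input_fields if f in data}
    return {f: data[f] for f in input_fields}


entry = {"start": "2022-08-26", "target": [1.0, 2.0], "feat_static_cat": [3]}
wanted = ["item_id", "info", "start", "target"]

# With allow_missing=True the absent "item_id" and "info" fields are skipped
# and the extra "feat_static_cat" field is dropped.
kept = select_fields(entry, wanted, allow_missing=True)
```

With the default `allow_missing=False`, the same call raises `KeyError` for the first absent field, which is the pre-existing behavior.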
1 change: 1 addition & 0 deletions test/dataset/test_arrow.py
@@ -64,6 +64,7 @@ def make_data(n: int):
 @pytest.mark.parametrize("flatten_arrays", [True, False])
 def test_arrow(writer, flatten_arrays):
     data = make_data(5)
+    writer.flatten_arrays = flatten_arrays

     with tempfile.TemporaryDirectory() as path:
         path = Path(path, "data.arrow")
12 changes: 10 additions & 2 deletions test/model/deepar/test_deepar_smoke.py
@@ -17,6 +17,7 @@
 import pytest

 from gluonts.model.deepar import DeepAREstimator
+from gluonts.mx.distribution import PiecewiseLinearOutput, StudentTOutput
 from gluonts.mx.trainer import Trainer

 from gluonts.testutil.dummy_datasets import make_dummy_datasets_with_features
@@ -106,11 +107,18 @@
         ),
     ],
 )
+@pytest.mark.parametrize(
+    "distr_output", [StudentTOutput(), PiecewiseLinearOutput(num_pieces=5)]
+)
 @pytest.mark.parametrize("dtype", [np.float32, np.float64])
 @pytest.mark.parametrize("impute_missing_values", [False, True])
-def test_deepar_smoke(estimator, datasets, dtype, impute_missing_values):
+def test_deepar_smoke(
+    distr_output, estimator, datasets, dtype, impute_missing_values
+):
     estimator = estimator(
-        dtype=dtype, impute_missing_values=impute_missing_values
+        distr_output=distr_output,
+        dtype=dtype,
+        impute_missing_values=impute_missing_values,
     )
     dataset_train, dataset_test = datasets
     predictor = estimator.train(dataset_train)
12 changes: 11 additions & 1 deletion test/model/r_forecast/test_r_predictor.py
@@ -15,7 +15,7 @@

 from gluonts.core import serde
 from gluonts.dataset.repository import datasets
-from gluonts.dataset.util import forecast_start
+from gluonts.dataset.util import forecast_start, to_pandas
 from gluonts.evaluation import Evaluator, backtest_metrics
 from gluonts.model.forecast import SampleForecast, QuantileForecast
 from gluonts.model.r_forecast import (
@@ -98,6 +98,16 @@ def test_forecasts(method_name):
     assert agg_metrics["NRMSE"] < TOLERANCE
     assert agg_metrics["RMSE"] < TOLERANCE

+    trunc_length = prediction_length
+
+    predictor = RForecastPredictor(**params, trunc_length=trunc_length)
+    predictions = list(predictor.predict(train_dataset))
+
+    assert all(
+        prediction.start_date == to_pandas(data).index[-1] + 1
+        for data, prediction in zip(train_dataset, predictions)
+    )
+

 def test_r_predictor_serialization():
     predictor = RForecastPredictor(freq="1D", prediction_length=3)
File renamed without changes.
30 changes: 30 additions & 0 deletions test/mx/test_no_batches.py
@@ -0,0 +1,30 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import pytest
+
+from gluonts.exceptions import GluonTSDataError
+from gluonts.model.deepar import DeepAREstimator
+from gluonts.mx.trainer import Trainer
+
+
+@pytest.mark.parametrize("dataset", [[]])
+def test_deepar_no_batches(dataset):
+    estimator = DeepAREstimator(
+        prediction_length=10,
+        freq="H",
+        trainer=Trainer(epochs=1, num_batches_per_epoch=1),
+    )
+
+    with pytest.raises(GluonTSDataError):
+        estimator.train(dataset)