Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mqcnn rts #668

Merged
merged 56 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
a104dea
more features for MQDNN, and Refactoring, remove of ts-fields from da…
lovvge Feb 11, 2020
56b150b
fix the future target calculation
lovvge Feb 18, 2020
356efaa
Added derive_auto_fields method.
Feb 25, 2020
eefce49
add use_dynamic_feat option
lovvge Feb 25, 2020
67d94eb
Added checks for dyn features.
Feb 25, 2020
2e8dc06
Fix from_hyperparameters for GluonEstimator.
Feb 25, 2020
70f3a25
enable date and age features, and rts
lovvge Feb 25, 2020
4aaf099
Merge branch 'master' into mqcnn-rts
Mar 27, 2020
1f66ac1
Fixup.
Feb 26, 2020
de08785
xx
Mar 31, 2020
2a6d737
Fixup.
Mar 31, 2020
592eb67
Another fixup.
Apr 1, 2020
9a1c1ac
Merge branch 'master' into mqcnn-rts
Apr 8, 2020
d19f5b5
Merge branch 'master' into mqcnn-rts
AaronSpieler Apr 14, 2020
db57304
Fixing formatting and tests.
Apr 14, 2020
ecf31e4
A lot of TODOs and comments added.
Apr 15, 2020
dab925c
Merge from production.
Apr 16, 2020
923510b
Fixing mq_dnn single quantile error and type errors.
Apr 16, 2020
e553095
Refactoring dnn_estimator file.
Apr 17, 2020
1ba5871
Adding additional tests, minor bugfix.
Apr 17, 2020
8ba66e0
Major refactoring that allows for disabling inputs at will. All tests…
Apr 17, 2020
ae683df
Removed print
Apr 17, 2020
bd733d8
Ensuring backward compatibility, some refactoring.
Apr 17, 2020
b17c663
Mainly argument refactoring, but also some legibility refactoring.
Apr 17, 2020
5b2cfc5
Merge branch 'master' into mqcnn-rts
AaronSpieler Apr 17, 2020
62395f6
Added use_feat_static_cat support and observed_values support.
Apr 20, 2020
cf30266
Merge branch 'master' into mqcnn-rts
AaronSpieler Apr 20, 2020
8d7b87d
Minor refactoring.
Apr 20, 2020
c8221a0
Merge branch 'mqcnn-rts' of https://github.com/lovvge/gluon-ts into m…
Apr 20, 2020
2be272b
Merge branch 'master' into mqcnn-rts
Apr 21, 2020
a3bf607
Addressing Jaspers Review
Apr 21, 2020
11209c3
Merge branch 'mqcnn-rts' of https://github.com/lovvge/gluon-ts into m…
Apr 21, 2020
ed594ac
Update src/gluonts/model/estimator.py
jaheba Apr 22, 2020
a25cd6a
Merge branch 'master' into mqcnn-rts
AaronSpieler Apr 22, 2020
edc17fb
Backwards compatibility and minor fixes.
Apr 23, 2020
3d8c73e
Merge branch 'mqcnn-rts' of https://github.com/lovvge/gluon-ts into m…
Apr 23, 2020
32271e7
Improvements to model throughput.
Apr 30, 2020
7920a03
allow decoding features
lovvge May 2, 2020
23566bd
Merge branch 'master' into mqcnn-rts
AaronSpieler May 4, 2020
fdb011d
Temporarily added unconditional caching.
May 4, 2020
b280046
Enabled multiprocessing by default.
May 4, 2020
2ed3c19
Standardized comments.
May 4, 2020
1c57e45
Small bug fixes.
May 4, 2020
4fce760
making caching and multiprocessing always on a local change
May 5, 2020
12673bc
mend
May 5, 2020
a020a47
Backwards compatibility fix.
May 8, 2020
e978781
Removing deepstate noise.
May 8, 2020
af5eb10
Removing deepstate noise.
May 8, 2020
e3ad554
Adjusting read speed baseline for windows.
May 8, 2020
a7973a4
Added dynamic input to MQCNN decoder.
May 12, 2020
63b2565
Added toggle option for dynamic future feat.
May 12, 2020
18e45bc
Changed default of future dynamic to disabled.
May 12, 2020
c636cd8
Turning user specified arguments into implications.
May 13, 2020
442edf5
Adding documentation for MQCNN parameters, removing non gluonts code.
May 18, 2020
d5699fa
Removing backwards compatibility.
May 18, 2020
29064f0
Merge branch 'master' into mqcnn-rts
lostella May 18, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/gluonts/block/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def hybrid_forward(
pass


# TODO: add support for static variables
class ForkingMLPDecoder(Seq2SeqDecoder):
"""
Multilayer perceptron decoder for sequence-to-sequence models.
Expand Down Expand Up @@ -104,6 +105,7 @@ def __init__(
)
self.model.add(layer)

# TODO: add support for static input
def hybrid_forward(
self, F, dynamic_input: Tensor, static_input: Tensor = None
) -> Tensor:
Expand Down
36 changes: 19 additions & 17 deletions src/gluonts/block/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ class Seq2SeqEncoder(nn.HybridBlock):
a dynamic latent code with the same length as the `target` sequence.
"""

@validated()
def __init__(self, **kwargs):
super().__init__(**kwargs)

# noinspection PyMethodOverriding
def hybrid_forward(
self,
Expand Down Expand Up @@ -77,9 +73,12 @@ def hybrid_forward(
"""
raise NotImplementedError

@staticmethod
def _assemble_inputs(
F, target: Tensor, static_features: Tensor, dynamic_features: Tensor
self,
F,
target: Tensor,
static_features: Tensor,
dynamic_features: Tensor,
) -> Tensor:
"""
Assemble features from target, static features, and the dynamic
Expand All @@ -93,7 +92,7 @@ def _assemble_inputs(

target
target time series,
shape (batch_size, sequence_length)
shape (batch_size, sequence_length, 1)

static_features
static features,
Expand All @@ -111,7 +110,6 @@ def _assemble_inputs(
num_static_features + num_dynamic_features + 1)

"""
target = target.expand_dims(axis=-1) # (N, T, 1)

helper_ones = F.ones_like(target) # Ones of (N, T, 1)
tiled_static_features = F.batch_dot(
Expand Down Expand Up @@ -156,7 +154,8 @@ def __init__(
kernel_size_seq: List[int],
channels_seq: List[int],
use_residual: bool = False,
use_covariates: bool = False,
use_static_feat: bool = False,
use_dynamic_feat: bool = False,
**kwargs,
) -> None:
assert all(
Expand All @@ -172,7 +171,8 @@ def __init__(
super().__init__(**kwargs)

self.use_residual = use_residual
self.use_covariates = use_covariates
self.use_static_feat = use_static_feat
self.use_dynamic_feat = use_dynamic_feat
self.cnn = nn.HybridSequential()

it = zip(channels_seq, kernel_size_seq, dilation_seq)
Expand Down Expand Up @@ -203,7 +203,7 @@ def hybrid_forward(

target
target time series,
shape (batch_size, sequence_length)
shape (batch_size, sequence_length, 1)

static_features
static features,
Expand All @@ -224,13 +224,15 @@ def hybrid_forward(
shape (batch_size, sequence_length, num_dynamic_features)
"""

if self.use_covariates:
inputs = Seq2SeqEncoder._assemble_inputs(
if self.use_dynamic_feat and self.use_static_feat:
inputs = self._assemble_inputs(
F,
target=target,
static_features=static_features,
dynamic_features=dynamic_features,
)
elif self.use_dynamic_feat:
inputs = F.concat(target, dynamic_features, dim=2) # (N, T, C)
else:
inputs = target
AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved

Expand Down Expand Up @@ -302,7 +304,7 @@ def hybrid_forward(

target
target time series,
shape (batch_size, sequence_length)
shape (batch_size, sequence_length, 1)

static_features
static features,
Expand Down Expand Up @@ -380,7 +382,7 @@ def hybrid_forward(
shape (batch_size, sequence_length, num_dynamic_features)
"""

inputs = Seq2SeqEncoder._assemble_inputs(
inputs = self._assemble_inputs(
F, target, static_features, dynamic_features
)
static_code = self.model(inputs)
Expand Down Expand Up @@ -442,7 +444,7 @@ def hybrid_forward(

target
target time series,
shape (batch_size, sequence_length)
shape (batch_size, sequence_length, 1)

static_features
static features,
Expand All @@ -462,7 +464,7 @@ def hybrid_forward(
dynamic code,
shape (batch_size, sequence_length, num_dynamic_features)
"""
inputs = Seq2SeqEncoder._assemble_inputs(
inputs = self._assemble_inputs(
F, target, static_features, dynamic_features
)
dynamic_code = self.rnn(inputs)
Expand Down
9 changes: 6 additions & 3 deletions src/gluonts/block/quantile_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,12 @@ def hybrid_forward(
Tensor
weighted sum of the quantile losses, shape N1 x N1 x ... Nk
"""
y_pred_all = F.split(
y_pred, axis=-1, num_outputs=self.num_quantiles, squeeze_axis=1
)
if self.num_quantiles > 1:
y_pred_all = F.split(
y_pred, axis=-1, num_outputs=self.num_quantiles, squeeze_axis=1
)
else:
y_pred_all = [F.squeeze(y_pred, axis=-1)]

qt_loss = []
for i, y_pred_q in enumerate(y_pred_all):
Expand Down
15 changes: 13 additions & 2 deletions src/gluonts/model/deepar/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# First-party imports
from gluonts.core.component import DType, validated
from gluonts.dataset.field_names import FieldName
from gluonts.dataset.stat import calculate_dataset_statistics
from gluonts.distribution import DistributionOutput, StudentTOutput
from gluonts.model.estimator import GluonEstimator
from gluonts.model.predictor import Predictor, RepresentableBlockPredictor
Expand Down Expand Up @@ -146,8 +147,8 @@ def __init__(
assert num_layers > 0, "The value of `num_layers` should be > 0"
assert num_cells > 0, "The value of `num_cells` should be > 0"
assert dropout_rate >= 0, "The value of `dropout_rate` should be >= 0"
assert (cardinality is not None and use_feat_static_cat) or (
cardinality is None and not use_feat_static_cat
assert (cardinality and use_feat_static_cat) or (
not (cardinality or use_feat_static_cat)
AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved
), "You should set `cardinality` if and only if `use_feat_static_cat=True`"
assert cardinality is None or all(
[c > 0 for c in cardinality]
Expand Down Expand Up @@ -197,6 +198,16 @@ def __init__(

self.num_parallel_samples = num_parallel_samples

@classmethod
def derive_auto_fields(cls, train_iter):
    """Infer data-dependent hyperparameters from the training data.

    Collects dataset statistics from ``train_iter`` and translates them
    into the constructor flags that depend on the data: whether dynamic
    real features are present, whether static categorical features are
    present, and the cardinality of each static categorical feature.
    """
    dataset_stats = calculate_dataset_statistics(train_iter)
    cardinality = [
        len(categories) for categories in dataset_stats.feat_static_cat
    ]
    return {
        "use_feat_dynamic_real": dataset_stats.num_feat_dynamic_real > 0,
        "use_feat_static_cat": len(cardinality) > 0,
        "cardinality": cardinality,
    }

AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved
AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved
def create_transformation(self) -> Transformation:
remove_field_names = [FieldName.FEAT_DYNAMIC_CAT]
if not self.use_feat_static_real:
Expand Down
22 changes: 20 additions & 2 deletions src/gluonts/model/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# permissions and limitations under the License.

# Standard library imports
from typing import NamedTuple, Optional
from typing import NamedTuple, Optional, Iterator

# Third-party imports
import numpy as np
Expand Down Expand Up @@ -69,6 +69,21 @@ def train(
def from_hyperparameters(cls, **hyperparameters):
return from_hyperparameters(cls, **hyperparameters)

@classmethod
def derive_auto_fields(cls, train_iter):
    """Hook for subclasses: return hyperparameters inferred from the
    training data. The base implementation infers nothing.
    """
    return dict()

@classmethod
def from_inputs(cls, train_iter, **params):
    """Construct an estimator from training data plus explicit parameters.

    Data-dependent hyperparameters (typically ``use_feat_dynamic_real``,
    ``use_feat_static_cat`` and ``cardinality``) are inferred from
    ``train_iter`` via ``derive_auto_fields``; ``params`` supplies the
    explicitly chosen ones.
    """
    # auto_params usually include `use_feat_dynamic_real`,
    # `use_feat_static_cat` and `cardinality`
    auto_params = cls.derive_auto_fields(train_iter)
    # Merge the dicts so that explicitly passed `params` take precedence
    # over inferred `auto_params`. Passing `**auto_params, **params`
    # directly would raise TypeError ("multiple values for keyword
    # argument") whenever a caller deliberately sets a field that is also
    # auto-derived.
    # NOTE(review): if an explicit value contradicts what the data
    # supports (e.g. `use_feat_dynamic_real=True` with no such features),
    # downstream training may still fail — confirm desired precedence.
    return cls.from_hyperparameters(**{**auto_params, **params})
AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved


class DummyEstimator(Estimator):
"""
Expand Down Expand Up @@ -126,7 +141,10 @@ def from_hyperparameters(cls, **hyperparameters) -> "GluonEstimator":
)

try:
trainer = from_hyperparameters(Trainer, **hyperparameters)
trainer = hyperparameters.get("trainer")
if not isinstance(trainer, Trainer):
trainer = from_hyperparameters(Trainer, **hyperparameters)

AaronSpieler marked this conversation as resolved.
Show resolved Hide resolved
return cls(
**Model(**{**hyperparameters, "trainer": trainer}).__dict__
)
Expand Down
9 changes: 9 additions & 0 deletions src/gluonts/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,15 @@ def deserialize(
def from_hyperparameters(cls, **hyperparameters):
return from_hyperparameters(cls, **hyperparameters)

@classmethod
def derive_auto_fields(cls, train_iter):
    """Hook for subclasses: return hyperparameters inferred from the
    training data. The base implementation infers nothing.
    """
    return dict()

@classmethod
def from_inputs(cls, train_iter, **params):
    """Construct a predictor from training data plus explicit parameters.

    Hyperparameters derivable from the data come from
    ``derive_auto_fields``; ``params`` supplies the explicitly chosen ones.
    """
    auto_params = cls.derive_auto_fields(train_iter)
    # Merge the dicts so that explicitly passed `params` override inferred
    # `auto_params`. Passing `**auto_params, **params` directly would raise
    # TypeError ("multiple values for keyword argument") whenever a caller
    # deliberately sets a field that is also auto-derived.
    return cls.from_hyperparameters(**{**auto_params, **params})


class RepresentablePredictor(Predictor):
"""
Expand Down
Loading