Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix scaling for MQ-(C|R)NN when distribution outputs are used #1070

Merged
merged 3 commits into from
Oct 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions src/gluonts/model/seq2seq/_forking_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,14 @@
from gluonts.model.estimator import GluonEstimator
from gluonts.model.forecast import Quantile
from gluonts.model.forecast_generator import QuantileForecastGenerator
from gluonts.mx.model.forecast_generator import DistributionForecastGenerator
from gluonts.model.predictor import Predictor
from gluonts.mx.model.predictor import RepresentableBlockPredictor

from gluonts.mx.block.decoder import Seq2SeqDecoder
from gluonts.mx.block.enc2dec import FutureFeatIntegratorEnc2Dec
from gluonts.mx.block.encoder import Seq2SeqEncoder
from gluonts.mx.block.quantile_output import QuantileOutput
from gluonts.mx.distribution import DistributionOutput
from gluonts.mx.model.forecast_generator import DistributionForecastGenerator
from gluonts.mx.model.predictor import RepresentableBlockPredictor
from gluonts.mx.trainer import Trainer
from gluonts.support.util import copy_parameters
from gluonts.time_feature import time_features_from_frequency_str
Expand Down Expand Up @@ -125,13 +124,15 @@ class ForkingSeq2SeqEstimator(GluonEstimator):
trainer
trainer (default: Trainer())
scaling
Whether to automatically scale the target values. (default: False)
Whether to automatically scale the target values. (default: False if quantile_output is used, True otherwise)
scaling_decoder_dynamic_feature
Whether to automatically scale the dynamic features for the decoder. (default: False)
dtype
(default: np.float32)
num_forking
Decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-C(R)NN
Decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-C(R)NN.
max_ts_len
The length of the longest time series in the dataset, used to bound the context_length.
"""

@validated()
Expand All @@ -154,7 +155,7 @@ def __init__(
enable_encoder_dynamic_feature: bool = True,
enable_decoder_dynamic_feature: bool = True,
trainer: Trainer = Trainer(),
scaling: bool = False,
scaling: Optional[bool] = None,
scaling_decoder_dynamic_feature: bool = False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to turn on scaling of the dynamic features or this should be unrelated to the distribution change right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is unrelated maybe, we could do it but maybe as a separate story

dtype: DType = np.float32,
num_forking: Optional[int] = None,
Expand Down Expand Up @@ -221,7 +222,9 @@ def __init__(
)
self.enable_encoder_dynamic_feature = enable_encoder_dynamic_feature
self.enable_decoder_dynamic_feature = enable_decoder_dynamic_feature
self.scaling = scaling
self.scaling = (
scaling if scaling is not None else (quantile_output is None)
)
self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature
self.dtype = dtype

Expand Down Expand Up @@ -409,6 +412,7 @@ def create_training_network(self) -> ForkingSeq2SeqNetworkBase:
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
scaling=self.scaling,
scaling_decoder_dynamic_feature=self.scaling_decoder_dynamic_feature,
dtype=self.dtype,
)

Expand Down Expand Up @@ -443,6 +447,7 @@ def create_predictor(
cardinality=self.cardinality,
embedding_dimension=self.embedding_dimension,
scaling=self.scaling,
scaling_decoder_dynamic_feature=self.scaling_decoder_dynamic_feature,
dtype=self.dtype,
)

Expand Down
31 changes: 13 additions & 18 deletions src/gluonts/model/seq2seq/_forking_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,18 @@ class ForkingSeq2SeqNetworkBase(gluon.HybridBlock):
distribution output
context_length: int,
length of the encoding sequence.
num_forking: int,
decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-C(R)NN.
cardinality: List[int],
number of values of each categorical feature.
embedding_dimension: List[int],
dimension of the embeddings for categorical features.
scaling
Whether to automatically scale the target values. (default: False)
Whether to automatically scale the target values. (default: True)
scaling_decoder_dynamic_feature
Whether to automatically scale the dynamic features for the decoder. (default: False)
dtype
(default: np.float32)
num_forking: int,
decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-C(R)NN.
kwargs: dict
dictionary of Gluon HybridBlock parameters
"""
Expand All @@ -77,7 +77,7 @@ def __init__(
embedding_dimension: List[int],
distr_output: Optional[DistributionOutput] = None,
quantile_output: Optional[QuantileOutput] = None,
scaling: bool = False,
scaling: bool = True,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we update the default for scaling back to False here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right!

scaling_decoder_dynamic_feature: bool = False,
dtype: DType = np.float32,
num_forking: Optional[int] = None,
Expand All @@ -94,25 +94,20 @@ def __init__(
self.quantile_output = quantile_output
self.scaling = scaling
self.scaling_decoder_dynamic_feature = scaling_decoder_dynamic_feature
self.scaling_decoder_dynamic_feature_axis = 1
self.dtype = dtype
self.num_forking = (
num_forking if num_forking is not None else context_length
)

if self.scaling:
self.scaler = MeanScaler(keepdims=True)
self.scaler = MeanScaler()
else:
self.scaler = NOPScaler(keepdims=True)
self.scaler = NOPScaler()

if self.scaling_decoder_dynamic_feature:
self.scaler_decoder_dynamic_feature = MeanScaler(
keepdims=True, axis=self.scaling_decoder_dynamic_feature_axis
)
self.scaler_decoder_dynamic_feature = MeanScaler(axis=1)
else:
self.scaler_decoder_dynamic_feature = NOPScaler(
keepdims=True, axis=self.scaling_decoder_dynamic_feature_axis
)
self.scaler_decoder_dynamic_feature = NOPScaler(axis=1)

with self.name_scope():
if self.quantile_output:
Expand Down Expand Up @@ -167,9 +162,7 @@ def get_decoder_network_output(

# in addition to embedding features, use the log scale as it can help prediction too
# (batch_size, num_feat_static = sum(embedding_dimension) + 1)
feat_static_real = F.concat(
embedded_cat, F.log(scale.squeeze(axis=1)), dim=1
)
feat_static_real = F.concat(embedded_cat, F.log(scale), dim=1)

# Passing past_observed_values as a feature would allow the network to
# make that distinction and possibly ignore the masked values.
Expand Down Expand Up @@ -266,7 +259,9 @@ def hybrid_forward(
else:
assert self.distr_output is not None
distr_args = self.distr_args_proj(dec_output)
distr = self.distr_output.distribution(distr_args, scale=scale)
distr = self.distr_output.distribution(
distr_args, scale=scale.expand_dims(axis=1)
)
loss = distr.loss(future_target)

# mask the loss based on observed indicator
Expand Down Expand Up @@ -361,7 +356,7 @@ def hybrid_forward(
-------
distr_args: the parameters of distribution
loc: an array of zeros with the same shape as scale
scale:
scale:
"""

dec_output, scale = self.get_decoder_network_output(
Expand Down
12 changes: 7 additions & 5 deletions src/gluonts/model/seq2seq/_mq_dnn_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,13 @@ class MQCNNEstimator(ForkingSeq2SeqEstimator):
trainer
The GluonTS trainer to use for training. (default: Trainer())
scaling
Whether to automatically scale the target values. (default: False)
Whether to automatically scale the target values. (default: False if quantile_output is used, True otherwise)
scaling_decoder_dynamic_feature
Whether to automatically scale the dynamic features for the decoder. (default: False)
num_forking
Decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-CNN
Decides how much forking to do in the decoder. 1 reduces to seq2seq and enc_len reduces to MQ-CNN.
max_ts_len
The length of the longest time series in the dataset, used to bound the context_length.
"""

@validated()
Expand All @@ -138,7 +140,7 @@ def __init__(
quantiles: Optional[List[float]] = None,
distr_output: Optional[DistributionOutput] = None,
trainer: Trainer = Trainer(),
scaling: bool = False,
scaling: Optional[bool] = None,
scaling_decoder_dynamic_feature: bool = False,
num_forking: Optional[int] = None,
max_ts_len: Optional[int] = None,
Expand Down Expand Up @@ -315,11 +317,11 @@ def __init__(
prediction_length: int,
freq: str,
context_length: Optional[int] = None,
decoder_mlp_dim_seq: List[int] = None,
decoder_mlp_dim_seq: Optional[List[int]] = None,
trainer: Trainer = Trainer(),
quantiles: Optional[List[float]] = None,
distr_output: Optional[DistributionOutput] = None,
scaling: bool = False,
scaling: Optional[bool] = None,
scaling_decoder_dynamic_feature: bool = False,
num_forking: Optional[int] = None,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion test/model/seq2seq/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_accuracy(
)

accuracy_test(
Estimator, hyperparameters, accuracy=0.20 if quantiles else 0.50
Estimator, hyperparameters, accuracy=0.20 if quantiles else 0.70
)


Expand Down