From d84130803dfdf013b082799a6ac82bd56fc62556 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 1 Sep 2020 10:08:01 +0200 Subject: [PATCH 1/2] fix docstring, enable distr_output in MQRNN --- src/gluonts/model/seq2seq/_forking_network.py | 6 ++++-- src/gluonts/model/seq2seq/_mq_dnn_estimator.py | 14 ++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/gluonts/model/seq2seq/_forking_network.py b/src/gluonts/model/seq2seq/_forking_network.py index ed76e0745f..7a37b59688 100644 --- a/src/gluonts/model/seq2seq/_forking_network.py +++ b/src/gluonts/model/seq2seq/_forking_network.py @@ -44,8 +44,10 @@ class ForkingSeq2SeqNetworkBase(gluon.HybridBlock): encoder to decoder mapping block. decoder: Seq2SeqDecoder decoder block. - output - An instance of DistributionOutput or QuantileOutput to use + quantile_output + quantile output + distr_output + distribution output context_length: int, length of the encoding sequence. cardinality: List[int], diff --git a/src/gluonts/model/seq2seq/_mq_dnn_estimator.py b/src/gluonts/model/seq2seq/_mq_dnn_estimator.py index dd2c3c48ce..e3d037f5a8 100644 --- a/src/gluonts/model/seq2seq/_mq_dnn_estimator.py +++ b/src/gluonts/model/seq2seq/_mq_dnn_estimator.py @@ -101,6 +101,9 @@ class MQCNNEstimator(ForkingSeq2SeqEstimator): Optimizing for more quantiles than are of direct interest to you can result in improved performance due to a regularizing effect. (default: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]) + distr_output + DistributionOutput to use. Only one between `quantile` and `distr_output` + can be set. (Default: None) trainer The GluonTS trainer to use for training. (default: Trainer()) scaling @@ -175,7 +178,6 @@ def __init__( if (quantiles is not None) or (distr_output is not None) else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] ) - self.distr_output = distr_output assert ( len(self.channels_seq) @@ -308,7 +310,8 @@ def __init__( context_length: Optional[int] = None, decoder_mlp_dim_seq: List[int] = None, trainer: Trainer = Trainer(), - quantiles: List[float] = None, + quantiles: Optional[List[float]] = None, + distr_output: Optional[DistributionOutput] = None, scaling: bool = False, scaling_decoder_dynamic_feature: bool = False, ) -> None: @@ -328,7 +331,7 @@ def __init__( ) self.quantiles = ( quantiles - if quantiles is not None + if (quantiles is not None) or (distr_output is not None) else [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] ) @@ -351,12 +354,15 @@ def __init__( prefix="decoder_", ) - quantile_output = QuantileOutput(self.quantiles) + quantile_output = ( + QuantileOutput(self.quantiles) if self.quantiles else None + ) super().__init__( encoder=encoder, decoder=decoder, quantile_output=quantile_output, + distr_output=distr_output, freq=freq, prediction_length=prediction_length, context_length=context_length, From 24e6c6f9e977f2a3a0795afc5658b14996f80868 Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Tue, 1 Sep 2020 11:36:53 +0200 Subject: [PATCH 2/2] bump threshold --- test/model/seq2seq/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model/seq2seq/test_model.py b/test/model/seq2seq/test_model.py index 0764cfc1de..f6949d125a 100644 --- a/test/model/seq2seq/test_model.py +++ b/test/model/seq2seq/test_model.py @@ -63,7 +63,7 @@ def test_accuracy( ) accuracy_test( - Estimator, hyperparameters, accuracy=0.20 if quantiles else 0.40 + Estimator, hyperparameters, accuracy=0.20 if quantiles else 0.50 )