From 8f750f00bb804bad33809f3121ec143204d06c5e Mon Sep 17 00:00:00 2001 From: Lorenzo Stella Date: Wed, 13 Nov 2019 13:18:39 +0100 Subject: [PATCH] fix deepstate serialization, add tests --- src/gluonts/distribution/lds.py | 1 + src/gluonts/model/deepstate/_estimator.py | 2 +- test/model/test_models.py | 30 ++++++++++++++++++++++- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/gluonts/distribution/lds.py b/src/gluonts/distribution/lds.py index e643fb34d0..18ff66b923 100644 --- a/src/gluonts/distribution/lds.py +++ b/src/gluonts/distribution/lds.py @@ -28,6 +28,7 @@ class ParameterBounds: + @validated() def __init__(self, lower, upper) -> None: assert ( lower <= upper diff --git a/src/gluonts/model/deepstate/_estimator.py b/src/gluonts/model/deepstate/_estimator.py index 4472296b4e..af79e008d3 100644 --- a/src/gluonts/model/deepstate/_estimator.py +++ b/src/gluonts/model/deepstate/_estimator.py @@ -129,7 +129,7 @@ class DeepStateEstimator(GluonEstimator): (default: False) use_feat_static_cat Whether to use the ``feat_static_cat`` field from the data - (default: False) + (default: True) embedding_dimension Dimension of the embeddings for categorical features (default: [min(50, (cat+1)//2) for cat in cardinality]) diff --git a/test/model/test_models.py b/test/model/test_models.py index 7948503824..31c111a3ca 100644 --- a/test/model/test_models.py +++ b/test/model/test_models.py @@ -26,6 +26,7 @@ from gluonts.evaluation.backtest import backtest_metrics from gluonts.model.deepar import DeepAREstimator from gluonts.model.deep_factor import DeepFactorEstimator +from gluonts.model.deepstate import DeepStateEstimator from gluonts.model.gp_forecaster import GaussianProcessEstimator from gluonts.model.npts import NPTSEstimator from gluonts.model.predictor import Predictor @@ -229,6 +230,28 @@ def seasonal_estimator(): return SeasonalNaiveEstimator, dict(prediction_length=prediction_length) +def deepstate_estimator(hybridize: bool = False, batches_per_epoch=1): + return ( + DeepStateEstimator, + dict( + ctx="cpu", + epochs=epochs, + learning_rate=1e-2, + hybridize=hybridize, + num_cells=2, + num_layers=1, + prediction_length=prediction_length, + context_length=2, + past_length=prediction_length, + cardinality=[1], + use_feat_static_cat=False, + num_batches_per_epoch=batches_per_epoch, + use_symbol_block_predictor=False, + num_parallel_samples=2, + ), + ) + + @flaky(max_runs=3, min_passes=1) @pytest.mark.timeout(10) # DeepAR occasionally fails with the 5 second timeout @pytest.mark.parametrize( @@ -252,7 +275,11 @@ def seasonal_estimator(): + (0.2,), ] ] - + [npts_estimator() + (0.0,), seasonal_estimator() + (0.0,)], + + [ + npts_estimator() + (0.0,), + seasonal_estimator() + (0.0,), + deepstate_estimator(hybridize=False, batches_per_epoch=100) + (0.5,), + ], ) def test_accuracy(Estimator, hyperparameters, accuracy): estimator = Estimator.from_hyperparameters(freq=freq, **hyperparameters) @@ -296,6 +323,7 @@ def test_repr(Estimator, hyperparameters): mqrnn_estimator(), gp_estimator(), transformer_estimator(), + deepstate_estimator(), ], ) def test_serialize(Estimator, hyperparameters):