Skip to content

Commit

Permalink
fix deepstate serialization, add tests (awslabs#445)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored and Ayed committed Nov 29, 2019
1 parent 59cad51 commit f8b3ef0
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
1 change: 1 addition & 0 deletions src/gluonts/distribution/lds.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@


class ParameterBounds:
@validated()
def __init__(self, lower, upper) -> None:
assert (
lower <= upper
Expand Down
2 changes: 1 addition & 1 deletion src/gluonts/model/deepstate/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class DeepStateEstimator(GluonEstimator):
(default: False)
use_feat_static_cat
Whether to use the ``feat_static_cat`` field from the data
(default: False)
(default: True)
embedding_dimension
Dimension of the embeddings for categorical features
(default: [min(50, (cat+1)//2) for cat in cardinality])
Expand Down
30 changes: 29 additions & 1 deletion test/model/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from gluonts.evaluation.backtest import backtest_metrics
from gluonts.model.deepar import DeepAREstimator
from gluonts.model.deep_factor import DeepFactorEstimator
from gluonts.model.deepstate import DeepStateEstimator
from gluonts.model.gp_forecaster import GaussianProcessEstimator
from gluonts.model.npts import NPTSEstimator
from gluonts.model.predictor import Predictor
Expand Down Expand Up @@ -229,6 +230,28 @@ def seasonal_estimator():
return SeasonalNaiveEstimator, dict(prediction_length=prediction_length)


def deepstate_estimator(hybridize: bool = False, batches_per_epoch=1):
return (
DeepStateEstimator,
dict(
ctx="cpu",
epochs=epochs,
learning_rate=1e-2,
hybridize=hybridize,
num_cells=2,
num_layers=1,
prediction_length=prediction_length,
context_length=2,
past_length=prediction_length,
cardinality=[1],
use_feat_static_cat=False,
num_batches_per_epoch=batches_per_epoch,
use_symbol_block_predictor=False,
num_parallel_samples=2,
),
)


@flaky(max_runs=3, min_passes=1)
@pytest.mark.timeout(10) # DeepAR occasionally fails with the 5 second timeout
@pytest.mark.parametrize(
Expand All @@ -252,7 +275,11 @@ def seasonal_estimator():
+ (0.2,),
]
]
+ [npts_estimator() + (0.0,), seasonal_estimator() + (0.0,)],
+ [
npts_estimator() + (0.0,),
seasonal_estimator() + (0.0,),
deepstate_estimator(hybridize=False, batches_per_epoch=100) + (0.5,),
],
)
def test_accuracy(Estimator, hyperparameters, accuracy):
estimator = Estimator.from_hyperparameters(freq=freq, **hyperparameters)
Expand Down Expand Up @@ -296,6 +323,7 @@ def test_repr(Estimator, hyperparameters):
mqrnn_estimator(),
gp_estimator(),
transformer_estimator(),
deepstate_estimator(),
],
)
def test_serialize(Estimator, hyperparameters):
Expand Down

0 comments on commit f8b3ef0

Please sign in to comment.