diff --git a/src/gluonts/model/deepstate/issm.py b/src/gluonts/model/deepstate/issm.py index 8135e365b3..f7c967920f 100644 --- a/src/gluonts/model/deepstate/issm.py +++ b/src/gluonts/model/deepstate/issm.py @@ -201,8 +201,7 @@ def transition_coeff( F = getF(seasonal_indicators) _transition_coeff = ( - # TODO: this won't work if F == mx.sym - F.array([[1, 1], [0, 1]]) + (F.diag(F.ones(shape=(2,)), k=0) + F.diag(F.ones(shape=(1,)), k=1)) .expand_dims(axis=0) .expand_dims(axis=0) ) diff --git a/test/model/test_deepstate_smoke.py b/test/model/test_deepstate_smoke.py index 1f5b09adf5..4e25b5859a 100644 --- a/test/model/test_deepstate_smoke.py +++ b/test/model/test_deepstate_smoke.py @@ -26,8 +26,11 @@ common_estimator_hps = dict( freq="D", prediction_length=3, - trainer=Trainer(epochs=3, num_batches_per_epoch=2, batch_size=1), + trainer=Trainer( + epochs=3, num_batches_per_epoch=2, batch_size=1, hybridize=True + ), past_length=10, + add_trend=True, )