
Can't save DeepState Model; RuntimeError: Cannot serialize type gluonts.distribution.lds.ParameterBounds. #443

Closed
wxie9 opened this issue Nov 13, 2019 · 3 comments · Fixed by #445
Labels: bug (Something isn't working)


wxie9 commented Nov 13, 2019

Hi, I'm trying to save a trained DeepState model with predictor.serialize(), but the call fails with the RuntimeError below. The code and the full traceback are included.

Code:

import pprint

from gluonts.dataset.repository.datasets import get_dataset, dataset_recipes
from gluonts.evaluation import Evaluator
from gluonts.evaluation.backtest import make_evaluation_predictions
from gluonts.model.deepstate import DeepStateEstimator

from pathlib import Path
from gluonts.model.predictor import Predictor

from gluonts.trainer import Trainer


if __name__ == "__main__":

    print(f"datasets available: {dataset_recipes.keys()}")

    # we pick m4_hourly as it only contains a few hundred time series
    dataset = get_dataset("m4_hourly", regenerate=False)

    estimator = DeepStateEstimator(
        prediction_length=dataset.metadata.prediction_length,
        freq=dataset.metadata.freq,
        # a very short training run, just enough to reproduce the error
        trainer=Trainer(epochs=1, num_batches_per_epoch=3),
        # no dynamic real or static categorical features are used
        use_feat_dynamic_real=False,
        use_feat_static_cat=False,
        cardinality=[0],
    )

    predictor = estimator.train(dataset.train)
    # this call raises the RuntimeError shown below
    predictor.serialize(Path("/tmp/"))
    forecast_it, ts_it = make_evaluation_predictions(
        dataset.test, predictor=predictor, num_samples=100
    )

    agg_metrics, item_metrics = Evaluator()(
        ts_it, forecast_it, num_series=len(dataset.test)
    )

    pprint.pprint(agg_metrics)

Error Message:

Traceback (most recent call last):
  File "/Users/xue.w/Desktop/project_code/deep_time_prediction/gluon_test/M4_TEST.py", line 48, in <module>
    predictor.serialize(Path(f"/tmp/"))
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/model/predictor.py", line 331, in serialize
    self.serialize_prediction_net(path)
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/model/predictor.py", line 482, in serialize_prediction_net
    export_repr_block(self.prediction_net, path, "prediction_net")
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/support/util.py", line 283, in export_repr_block
    print(dump_json(rb), file=fp)
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/core/serde.py", line 134, in dump_json
    return json.dumps(encode(o), indent=indent, sort_keys=True)
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/functools.py", line 840, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/core/serde.py", line 417, in encode
    "kwargs": encode(kwargs),
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/functools.py", line 840, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/core/serde.py", line 402, in encode
    return {k: encode(v) for k, v in v.items()}
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/core/serde.py", line 402, in <dictcomp>
    return {k: encode(v) for k, v in v.items()}
  File "/usr/local/Cellar/python/3.7.4_1/Frameworks/Python.framework/Versions/3.7/lib/python3.7/functools.py", line 840, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
  File "/Users/xue.w/.virtualenvs/myvenv/lib/python3.7/site-packages/gluonts/core/serde.py", line 420, in encode
    raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__)))
RuntimeError: Cannot serialize type gluonts.distribution.lds.ParameterBounds. See the documentation of the `encode` and
`validate` functions at  http://gluon-ts.mxnet.io/api/gluonts/gluonts.html
and the Python documentation of the `__getnewargs_ex__` magic method at
https://docs.python.org/3/library/pickle.html#object.__getnewargs_ex__
for more information how to make this type serializable.
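The traceback shows the failure path: dump_json calls encode on the prediction network, encode recurses into the network's constructor "kwargs", and it raises once it reaches a ParameterBounds value for which it has no encoding rule. The following is a minimal, self-contained sketch of that dispatch pattern, assuming a simplified encoder modeled on the frames above (functools.singledispatch plus a __getnewargs_ex__ check, the hook the error message points to). It is not the actual gluonts.core.serde code, and PlainBounds, HookedBounds and ToyNet are hypothetical stand-ins.

```python
import functools
import json
from typing import Any


@functools.singledispatch
def encode(v: Any) -> Any:
    """Fallback for arbitrary objects: they are only encodable if they define
    __getnewargs_ex__, i.e. report the (args, kwargs) needed to rebuild them."""
    if hasattr(v, "__getnewargs_ex__"):
        args, kwargs = v.__getnewargs_ex__()
        return {
            "__kind__": "instance",
            "class": type(v).__name__,
            "args": encode(list(args)),
            "kwargs": encode(kwargs),  # recursion into constructor kwargs
        }
    raise RuntimeError(f"Cannot serialize type {type(v).__name__}.")


@encode.register(bool)
@encode.register(int)
@encode.register(float)
@encode.register(str)
@encode.register(type(None))
def _encode_primitive(v: Any) -> Any:
    # JSON-native values are passed through unchanged
    return v


@encode.register(list)
def _encode_list(v: list) -> list:
    return [encode(x) for x in v]


@encode.register(dict)
def _encode_dict(v: dict) -> dict:
    return {k: encode(x) for k, x in v.items()}


class PlainBounds:
    """Hypothetical stand-in for ParameterBounds without a serialization hook."""

    def __init__(self, lower: float, upper: float) -> None:
        self.lower = lower
        self.upper = upper


class HookedBounds(PlainBounds):
    """The same class, but it reports how to rebuild itself."""

    def __getnewargs_ex__(self):
        return (), {"lower": self.lower, "upper": self.upper}


class ToyNet:
    """Hypothetical network whose constructor kwargs include a bounds object."""

    def __init__(self, noise_bounds: PlainBounds) -> None:
        self.noise_bounds = noise_bounds

    def __getnewargs_ex__(self):
        return (), {"noise_bounds": self.noise_bounds}


print(json.dumps(encode(ToyNet(HookedBounds(1e-6, 1.0))), sort_keys=True))

try:
    encode(ToyNet(PlainBounds(1e-6, 1.0)))
except RuntimeError as err:
    print(err)  # Cannot serialize type PlainBounds.
```

In this sketch the network itself encodes fine; only the nested bounds object without a hook breaks the whole dump, which matches the shape of the traceback.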
wxie9 added the question (Further information is requested) label Nov 13, 2019
lostella (Contributor) commented Nov 13, 2019

@qcw thanks for spotting this!

TODOs here (as far as I can tell):

  • add @validated() to ParameterBounds
  • add a serialization/deserialization test case for DeepState (there apparently isn't one yet)
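The first TODO makes ParameterBounds record its constructor arguments so the serializer can rebuild it, and the second asks for a round-trip test. The sketch below illustrates both ideas without depending on GluonTS. It assumes the fix takes the form of decorating the constructor; record_init_args is a hypothetical, simplified stand-in for gluonts.core.component.validated (the real decorator also performs type validation), and encode, decode, ToyNetwork and CLASSES are illustrative names, not GluonTS API.

```python
import functools
import inspect
import json
from typing import Any, Callable, Dict, Type


def record_init_args(init: Callable) -> Callable:
    """Hypothetical stand-in for @validated(): store the constructor
    arguments on the instance so the object can be rebuilt from them."""
    signature = inspect.signature(init)

    @functools.wraps(init)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        bound = signature.bind(self, *args, **kwargs)
        bound.apply_defaults()
        self.__init_args__ = {
            name: value for name, value in bound.arguments.items() if name != "self"
        }
        init(self, *args, **kwargs)

    return wrapper


class ParameterBounds:
    """Stand-in for gluonts.distribution.lds.ParameterBounds with the decorator added."""

    @record_init_args
    def __init__(self, lower: float, upper: float) -> None:
        assert lower <= upper, "lower bound must not exceed upper bound"
        self.lower = lower
        self.upper = upper


class ToyNetwork:
    """Hypothetical container that takes a bounds object as a constructor argument."""

    @record_init_args
    def __init__(self, num_layers: int, noise_bounds: ParameterBounds) -> None:
        self.num_layers = num_layers
        self.noise_bounds = noise_bounds


# Classes the decoder is allowed to instantiate, keyed by class name.
CLASSES: Dict[str, Type] = {"ParameterBounds": ParameterBounds, "ToyNetwork": ToyNetwork}


def encode(v: Any) -> Any:
    """Recursively turn primitives, lists, dicts and recorded objects into JSON data."""
    if v is None or isinstance(v, (bool, int, float, str)):
        return v
    if isinstance(v, list):
        return [encode(x) for x in v]
    if isinstance(v, dict):
        return {k: encode(x) for k, x in v.items()}
    init_args = getattr(v, "__init_args__", None)
    if init_args is None:
        raise RuntimeError(f"Cannot serialize type {type(v).__name__}.")
    return {"__kind__": "instance", "class": type(v).__name__, "kwargs": encode(init_args)}


def decode(v: Any) -> Any:
    """Inverse of encode: rebuild objects by calling their constructors."""
    if isinstance(v, list):
        return [decode(x) for x in v]
    if isinstance(v, dict):
        if v.get("__kind__") == "instance":
            return CLASSES[v["class"]](**decode(v["kwargs"]))
        return {k: decode(x) for k, x in v.items()}
    return v


def test_serde_round_trip() -> None:
    """The kind of round-trip test the second TODO item asks for."""
    net = ToyNetwork(2, ParameterBounds(1e-6, 1.0))
    restored = decode(json.loads(json.dumps(encode(net))))
    assert isinstance(restored.noise_bounds, ParameterBounds)
    assert restored.num_layers == 2
    assert (restored.noise_bounds.lower, restored.noise_bounds.upper) == (1e-6, 1.0)


test_serde_round_trip()
```

Going through an actual JSON string in the test, rather than comparing encoded dictionaries, is what catches nested objects that have no encoding rule.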

lostella added the bug (Something isn't working) label and removed the question (Further information is requested) label Nov 13, 2019
lostella (Contributor) commented:

@qcw This should be fixed now (see #445); thanks again for reporting the issue!

wxie9 (Author) commented Nov 13, 2019

Thank you so much for the quick response!! It helps a lot. :)
