Skip to content

Commit

Permalink
Rotbaum: turn to json-based serialization (#3176)
Browse files Browse the repository at this point in the history
*Issue #, if available:*

*Description of changes:*


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
lostella authored May 24, 2024
1 parent 8c194c7 commit 38b0c64
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/gluonts/ext/rotbaum/_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import concurrent.futures
import logging
import pickle
from itertools import chain
from typing import Iterator, List, Optional, Any, Dict
from toolz import first
Expand All @@ -24,6 +23,7 @@
from itertools import compress

from gluonts.core.component import validated
from gluonts.core.serde import dump_json, load_json
from gluonts.dataset.common import Dataset
from gluonts.dataset.util import forecast_start
from gluonts.model.forecast import Forecast
Expand Down Expand Up @@ -355,8 +355,8 @@ class name, version information and constructor arguments.
generated when pickling the TreePredictor.
"""
super().serialize(path)
with (path / "predictor.pkl").open("wb") as f:
pickle.dump(self.model_list, f)
with (path / "model_list.json").open("w") as fp:
print(dump_json(self.model_list), file=fp)

@classmethod
def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":
Expand All @@ -369,8 +369,8 @@ def deserialize(cls, path: Path, **kwargs: Any) -> "TreePredictor":

predictor = super().deserialize(path)
assert isinstance(predictor, cls)
with (path / "predictor.pkl").open("rb") as f:
predictor.model_list = pickle.load(f)
with (path / "model_list.json").open("r") as fp:
predictor.model_list = load_json(fp.read())
return predictor

def explain(
Expand Down

0 comments on commit 38b0c64

Please sign in to comment.