Simplify PyTorchPredictor serde (#2965)
lostella authored Aug 17, 2023
1 parent 838ba43 commit 43e78d3
Showing 12 changed files with 24 additions and 82 deletions.
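In summary: PyTorchPredictor now derives from RepresentablePredictor and marks its constructor with @validated(), so the generic dump_json / load_json machinery handles the constructor arguments (and equality), replacing the hand-written __eq__, serialize, and deserialize code; only the network weights are still stored separately via torch.save / torch.load of the state dict. In addition, device is passed around as a plain string ("cpu" / "cuda") instead of a torch.device throughout the torch estimators and tests.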
1 change: 0 additions & 1 deletion src/gluonts/model/predictor.py
@@ -162,7 +162,6 @@ def __eq__(self, that):
return equals(self, that)

def serialize(self, path: Path) -> None:
-# call Predictor.serialize() in order to serialize the class name
super().serialize(path)
with (path / "predictor.json").open("w") as fp:
print(dump_json(self), file=fp)
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/d_linear/estimator.py
@@ -247,7 +247,5 @@ def create_predictor(
),
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
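The same device change repeats across the torch estimators below. Dropping the torch.device(...) wrapper is safe because PyTorch placement APIs accept plain device strings; a minimal illustration (not part of the commit) of that interchangeability:

```python
import torch

# Strings and torch.device objects are interchangeable for placement APIs,
# which is why the estimators can now pass "cuda"/"cpu" straight through.
device = "cuda" if torch.cuda.is_available() else "cpu"

net = torch.nn.Linear(4, 2)
net.to(device)  # equivalent to net.to(torch.device(device))

x = torch.zeros(1, 4, device=device)
assert x.device == next(net.parameters()).device
```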
2 changes: 1 addition & 1 deletion src/gluonts/torch/model/deep_npts/_estimator.py
@@ -399,7 +399,7 @@ def train_model(
return best_net

def get_predictor(
-self, net: torch.nn.Module, device=torch.device("cpu")
+self, net: torch.nn.Module, device="cpu"
) -> PyTorchPredictor:
pred_net_multi_step = DeepNPTSMultiStepNetwork(
net=net, prediction_length=self.prediction_length
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/deepar/estimator.py
@@ -420,7 +420,5 @@ def create_predictor(
prediction_net=module,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/lag_tst/estimator.py
@@ -275,7 +275,5 @@ def create_predictor(
),
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/patch_tst/estimator.py
@@ -279,7 +279,5 @@ def create_predictor(
),
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
69 changes: 11 additions & 58 deletions src/gluonts/torch/model/predictor.py
@@ -12,13 +12,13 @@
# permissions and limitations under the License.

from pathlib import Path
-from typing import Iterator, List, Optional
+from typing import Iterator, List, Optional, Union

import numpy as np
import torch
import torch.nn as nn

-from gluonts.core.serde import dump_json, load_json
+from gluonts.core.component import validated
from gluonts.dataset.common import Dataset
from gluonts.dataset.loader import InferenceDataLoader
from gluonts.model.forecast import Forecast
@@ -27,9 +27,8 @@
SampleForecastGenerator,
predict_to_numpy,
)
-from gluonts.model.predictor import OutputTransform, Predictor
+from gluonts.model.predictor import OutputTransform, RepresentablePredictor
from gluonts.torch.batchify import batchify
-from gluonts.torch.component import equals
from gluonts.transform import Transformation, SelectFields


@@ -38,7 +37,8 @@ def _(prediction_net: nn.Module, kwargs) -> np.ndarray:
return prediction_net(**kwargs).cpu().numpy()


-class PyTorchPredictor(Predictor):
+class PyTorchPredictor(RepresentablePredictor):
+@validated()
def __init__(
self,
input_names: List[str],
@@ -49,7 +49,7 @@ def __init__(
forecast_generator: ForecastGenerator = SampleForecastGenerator(),
output_transform: Optional[OutputTransform] = None,
lead_time: int = 0,
-device: Optional[torch.device] = torch.device("cpu"),
+device: Optional[str] = "cpu",
) -> None:
super().__init__(prediction_length, lead_time=lead_time)
self.input_names = input_names
@@ -94,71 +94,24 @@ def predict(
num_samples=num_samples,
)

-def __eq__(self, that):
-if type(self) != type(that):
-return False
-
-if not equals(self.input_transform, that.input_transform):
-return False
-
-return equals(
-self.prediction_net.state_dict(),
-that.prediction_net.state_dict(),
-)

def serialize(self, path: Path) -> None:
super().serialize(path)

-# serialize network
-with (path / "prediction_net.json").open("w") as fp:
-print(dump_json(self.prediction_net), file=fp)
torch.save(
self.prediction_net.state_dict(), path / "prediction_net_state"
)

-# serialize transformation chain
-with (path / "input_transform.json").open("w") as fp:
-print(dump_json(self.input_transform), file=fp)
-
-# FIXME: also needs to serialize the output_transform
-
-# serialize all remaining constructor parameters
-with (path / "parameters.json").open("w") as fp:
-parameters = dict(
-batch_size=self.batch_size,
-prediction_length=self.prediction_length,
-lead_time=self.lead_time,
-forecast_generator=self.forecast_generator,
-input_names=self.input_names,
-)
-print(dump_json(parameters), file=fp)

@classmethod
def deserialize(
-cls, path: Path, device: Optional[torch.device] = None
+cls, path: Path, device: Optional[Union[str, torch.device]] = None
) -> "PyTorchPredictor":
+predictor = super().deserialize(path)

if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

-# deserialize constructor parameters
-with (path / "parameters.json").open("r") as fp:
-parameters = load_json(fp.read())
-
-# deserialize transformation chain
-with (path / "input_transform.json").open("r") as fp:
-transformation = load_json(fp.read())
-
-# deserialize network
-with (path / "prediction_net.json").open("r") as fp:
-prediction_net = load_json(fp.read())
-prediction_net.load_state_dict(
+predictor.prediction_net.load_state_dict(
torch.load(path / "prediction_net_state", map_location=device)
)

parameters["device"] = device

-return PyTorchPredictor(
-input_transform=transformation,
-prediction_net=prediction_net,
-**parameters,
-)
+return predictor
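With the custom serde gone, a round trip relies on the base-class machinery plus the separately stored state dict. A usage sketch mirroring the assertions added to the tests further down; `predictor` stands for an already-trained PyTorchPredictor (e.g. the result of `estimator.train(...)`), which is an assumption of the example:

```python
import tempfile
from pathlib import Path

from gluonts.model.predictor import Predictor
from gluonts.torch.model.predictor import PyTorchPredictor

# `predictor` is assumed to be an existing, trained PyTorchPredictor instance.
with tempfile.TemporaryDirectory() as temp_dir:
    predictor.serialize(Path(temp_dir))

    # Generic deserialization: constructor arguments come back through the
    # validated/RepresentablePredictor path, weights through the state dict.
    predictor_copy = Predictor.deserialize(Path(temp_dir))
    assert predictor == predictor_copy

    # To control placement explicitly, call the class method directly;
    # device=None picks "cuda" when available, otherwise "cpu".
    predictor_cpu = PyTorchPredictor.deserialize(Path(temp_dir), device="cpu")
```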
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/simple_feedforward/estimator.py
@@ -247,7 +247,5 @@ def create_predictor(
),
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/tft/estimator.py
@@ -400,9 +400,7 @@ def create_predictor(
prediction_net=module,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
forecast_generator=QuantileForecastGenerator(
quantiles=[str(q) for q in self.quantiles]
),
4 changes: 1 addition & 3 deletions src/gluonts/torch/model/wavenet/estimator.py
@@ -460,7 +460,5 @@ def create_predictor(
prediction_net=module,
batch_size=self.batch_size,
prediction_length=self.prediction_length,
-device=torch.device(
-"cuda" if torch.cuda.is_available() else "cpu"
-),
+device="cuda" if torch.cuda.is_available() else "cpu",
)
4 changes: 4 additions & 0 deletions test/torch/model/test_estimators.py
@@ -140,6 +140,8 @@ def test_estimator_constant_dataset(
predictor.serialize(Path(td))
predictor_copy = Predictor.deserialize(Path(td))

+assert predictor == predictor_copy

forecasts = predictor_copy.predict(constant.test)

for f in islice(forecasts, 5):
@@ -315,6 +317,8 @@ def test_estimator_with_features(estimator_constructor):
predictor.serialize(Path(td))
predictor_copy = Predictor.deserialize(Path(td))

+assert predictor == predictor_copy

forecasts = predictor_copy.predict(prediction_dataset)

for f in islice(forecasts, 5):
2 changes: 1 addition & 1 deletion test/torch/model/test_torch_predictor.py
@@ -69,7 +69,7 @@ def test_pytorch_predictor_serde():
prediction_net=pred_net,
batch_size=16,
input_transform=transformation,
device=torch.device("cpu"),
device="cpu",
)

with tempfile.TemporaryDirectory() as temp_dir:
