diff --git a/src/gluonts/torch/model/predictor.py b/src/gluonts/torch/model/predictor.py
index bd128a6a7e..c5afd399fd 100644
--- a/src/gluonts/torch/model/predictor.py
+++ b/src/gluonts/torch/model/predictor.py
@@ -83,3 +83,70 @@ def predict(
             output_transform=self.output_transform,
             num_samples=num_samples,
         )
+
+    def __eq__(self, that):
+        if type(self) != type(that):
+            return False
+
+        # TODO: also consider equality of the pipelines
+        # if not equals(self.input_transform, that.input_transform):
+        #     return False
+
+        return equals(
+            self.prediction_net.state_dict(),
+            that.prediction_net.state_dict(),
+        )
+
+    def serialize(self, path: Path) -> None:
+        super().serialize(path)
+
+        # serialize network
+        with (path / f"prediction_net.json").open("w") as fp:
+            print(dump_json(self.prediction_net), file=fp)
+        torch.save(
+            self.prediction_net.state_dict(), path / "prediction_net_state"
+        )
+
+        # serialize transformation chain
+        with (path / "input_transform.json").open("w") as fp:
+            print(dump_json(self.input_transform), file=fp)
+
+        # FIXME: also needs to serialize the output_transform
+
+        # serialize all remaining constructor parameters
+        with (path / "parameters.json").open("w") as fp:
+            parameters = dict(
+                batch_size=self.batch_size,
+                prediction_length=self.prediction_length,
+                freq=self.freq,
+                forecast_generator=self.forecast_generator,
+                input_names=self.input_names,
+            )
+            print(dump_json(parameters), file=fp)
+
+    @classmethod
+    def deserialize(
+        cls, path: Path, device: Optional[torch.device] = None
+    ) -> "PyTorchPredictor":
+        # deserialize constructor parameters
+        with (path / "parameters.json").open("r") as fp:
+            parameters = load_json(fp.read())
+
+        # deserialize transformation chain
+        with (path / "input_transform.json").open("r") as fp:
+            transformation = load_json(fp.read())
+
+        # deserialize network
+        with (path / f"prediction_net.json").open("r") as fp:
+            prediction_net = load_json(fp.read())
+        prediction_net.load_state_dict(
+            torch.load(path / "prediction_net_state")
+        )
+
+        parameters["device"] = device
+
+        return PyTorchPredictor(
+            input_transform=transformation,
+            prediction_net=prediction_net,
+            **parameters,
+        )
diff --git a/test/torch/model/test_torch_predictor.py b/test/torch/model/test_torch_predictor.py
new file mode 100644
index 0000000000..96c934a4ff
--- /dev/null
+++ b/test/torch/model/test_torch_predictor.py
@@ -0,0 +1,81 @@
+# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+
+from gluonts.core.component import validated
+from gluonts.dataset.field_names import FieldName
+from gluonts.model.predictor import Predictor
+from gluonts.torch.model.predictor import PyTorchPredictor
+from gluonts.transform import ExpectedNumInstanceSampler, InstanceSplitter
+
+
+class RandomNetwork(nn.Module):
+    @validated()
+    def __init__(
+        self,
+        prediction_length: int,
+        context_length: int,
+    ) -> None:
+        super().__init__()
+        assert prediction_length > 0
+        assert context_length > 0
+        self.prediction_length = prediction_length
+        self.context_length = context_length
+        self.net = nn.Linear(context_length, prediction_length)
+        torch.nn.init.uniform_(self.net.weight, -1.0, 1.0)
+
+    def forward(self, context):
+        assert context.shape[-1] == self.context_length
+        out = self.net(context)
+        return out.unsqueeze(1)
+
+
+def test_pytorch_predictor_serde():
+    context_length = 20
+    prediction_length = 5
+
+    transformation = InstanceSplitter(
+        target_field=FieldName.TARGET,
+        is_pad_field=FieldName.IS_PAD,
+        start_field=FieldName.START,
+        forecast_start_field=FieldName.FORECAST_START,
+        train_sampler=ExpectedNumInstanceSampler(num_instances=1),
+        past_length=context_length,
+        future_length=prediction_length,
+    )
+
+    pred_net = RandomNetwork(
+        prediction_length=prediction_length, context_length=context_length
+    )
+
+    predictor = PyTorchPredictor(
+        prediction_length=prediction_length,
+        freq="1H",
+        input_names=["past_target"],
+        prediction_net=pred_net,
+        batch_size=16,
+        input_transform=transformation,
+        device=torch.device("cpu"),
+    )
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        predictor.serialize(Path(temp_dir))
+        predictor_exp = Predictor.deserialize(Path(temp_dir))
+        assert predictor == predictor_exp