Skip to content

Commit

Permalink
fix pytorch predictor serde (#1194)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella authored and Eduard Uffelmann committed Dec 8, 2020
1 parent 71af81f commit 728da97
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 0 deletions.
67 changes: 67 additions & 0 deletions src/gluonts/torch/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,70 @@ def predict(
output_transform=self.output_transform,
num_samples=num_samples,
)

def __eq__(self, that):
if type(self) != type(that):
return False

# TODO: also consider equality of the pipelines
# if not equals(self.input_transform, that.input_transform):
# return False

return equals(
self.prediction_net.state_dict(),
that.prediction_net.state_dict(),
)

def serialize(self, path: Path) -> None:
super().serialize(path)

# serialize network
with (path / f"prediction_net.json").open("w") as fp:
print(dump_json(self.prediction_net), file=fp)
torch.save(
self.prediction_net.state_dict(), path / "prediction_net_state"
)

# serialize transformation chain
with (path / "input_transform.json").open("w") as fp:
print(dump_json(self.input_transform), file=fp)

# FIXME: also needs to serialize the output_transform

# serialize all remaining constructor parameters
with (path / "parameters.json").open("w") as fp:
parameters = dict(
batch_size=self.batch_size,
prediction_length=self.prediction_length,
freq=self.freq,
forecast_generator=self.forecast_generator,
input_names=self.input_names,
)
print(dump_json(parameters), file=fp)

@classmethod
def deserialize(
cls, path: Path, device: Optional[torch.device] = None
) -> "PyTorchPredictor":
# deserialize constructor parameters
with (path / "parameters.json").open("r") as fp:
parameters = load_json(fp.read())

# deserialize transformation chain
with (path / "input_transform.json").open("r") as fp:
transformation = load_json(fp.read())

# deserialize network
with (path / f"prediction_net.json").open("r") as fp:
prediction_net = load_json(fp.read())
prediction_net.load_state_dict(
torch.load(path / "prediction_net_state")
)

parameters["device"] = device

return PyTorchPredictor(
input_transform=transformation,
prediction_net=prediction_net,
**parameters,
)
81 changes: 81 additions & 0 deletions test/torch/model/test_torch_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

import tempfile
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from gluonts.core.component import validated
from gluonts.dataset.field_names import FieldName
from gluonts.model.predictor import Predictor
from gluonts.torch.model.predictor import PyTorchPredictor
from gluonts.transform import ExpectedNumInstanceSampler, InstanceSplitter


class RandomNetwork(nn.Module):
@validated()
def __init__(
self,
prediction_length: int,
context_length: int,
) -> None:
super().__init__()
assert prediction_length > 0
assert context_length > 0
self.prediction_length = prediction_length
self.context_length = context_length
self.net = nn.Linear(context_length, prediction_length)
torch.nn.init.uniform_(self.net.weight, -1.0, 1.0)

def forward(self, context):
assert context.shape[-1] == self.context_length
out = self.net(context)
return out.unsqueeze(1)


def test_pytorch_predictor_serde():
context_length = 20
prediction_length = 5

transformation = InstanceSplitter(
target_field=FieldName.TARGET,
is_pad_field=FieldName.IS_PAD,
start_field=FieldName.START,
forecast_start_field=FieldName.FORECAST_START,
train_sampler=ExpectedNumInstanceSampler(num_instances=1),
past_length=context_length,
future_length=prediction_length,
)

pred_net = RandomNetwork(
prediction_length=prediction_length, context_length=context_length
)

predictor = PyTorchPredictor(
prediction_length=prediction_length,
freq="1H",
input_names=["past_target"],
prediction_net=pred_net,
batch_size=16,
input_transform=transformation,
device=torch.device("cpu"),
)

with tempfile.TemporaryDirectory() as temp_dir:
predictor.serialize(Path(temp_dir))
predictor_exp = Predictor.deserialize(Path(temp_dir))
assert predictor == predictor_exp

0 comments on commit 728da97

Please sign in to comment.