Skip to content

Adding Torchscript utility functions #3138

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Dec 21, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from .deviceconfig import (
USE_COMPILED,
IgniteInfo,
get_config_values,
get_gpu_info,
get_optional_config_values,
get_system_info,
print_config,
print_debug_info,
Expand Down
1 change: 1 addition & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .synthetic import create_test_image_2d, create_test_image_3d
from .test_time_augmentation import TestTimeAugmentation
from .thread_buffer import ThreadBuffer, ThreadDataLoader
from .torchscript_utils import load_net_with_metadata, save_net_with_metadata
from .utils import (
compute_importance_map,
compute_shape_offset,
Expand Down
149 changes: 149 additions & 0 deletions monai/data/torchscript_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import json
import os
from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union

import torch

from monai.config import get_config_values
from monai.utils import JITMetadataKeys
from monai.utils.module import pytorch_after

METADATA_FILENAME = "metadata.json"


def save_net_with_metadata(
jit_obj: torch.nn.Module,
filename_prefix_or_stream: Union[str, IO[Any]],
include_config_vals: bool = True,
append_timestamp: bool = False,
meta_values: Optional[Mapping[str, Any]] = None,
more_extra_files: Optional[Mapping[str, bytes]] = None,
) -> None:
"""
Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata
included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used
here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be
compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will
include information about the network applicable to some use case, such as describing the input and output format,
a network name and version, a plain language description of what the network does, and other relevant scientific
information. Clients can use this information to determine automatically how to use the network, and users can
read what the network does and keep track of versions.

Examples::

net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2]))

meta = {
"name": "Test UNet",
"used_for": "demonstration purposes",
"input_dims": 2,
"output_dims": 2
}

# save the Torchscript bundle with the above dictionary stored as an extra file
save_net_with_metadata(m, "test", meta_values=meta)

# load the network back, `loaded_meta` has same data as `meta` plus version information
loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt")


Args:
jit_obj: object to save, should be generated by `script` or `trace`.
filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`.
include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata.
append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension.
meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`.
more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`.
"""

now = datetime.datetime.now()
metadict = {}

if include_config_vals:
metadict.update(get_config_values())
metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat()

if meta_values is not None:
metadict.update(meta_values)

json_data = json.dumps(metadict)

# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object
if pytorch_after(1, 7):
extra_files = {METADATA_FILENAME: json_data.encode()}

if more_extra_files is not None:
extra_files.update(more_extra_files)
else:
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined]
extra_files[METADATA_FILENAME] = json_data.encode()

if more_extra_files is not None:
for k, v in more_extra_files.items():
extra_files[k] = v

if isinstance(filename_prefix_or_stream, str):
filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream)
if ext == "":
ext = ".pt"

if append_timestamp:
filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}")
else:
filename_prefix_or_stream = filename_no_ext + ext

torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files)


def load_net_with_metadata(
filename_prefix_or_stream: Union[str, IO[Any]],
map_location: Optional[torch.device] = None,
more_extra_files: Sequence[str] = (),
) -> Tuple[torch.nn.Module, dict, dict]:
"""
Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata
back to a dict object. This will produce an empty dict if the metadata file is not present.

Args:
filename_prefix_or_stream: filename or file-like stream object.
map_location: network map location as in `torch.jit.load`.
more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`.
Returns:
Triple containing loaded object, metadata dict, and extra files dict containing other file data if present
"""
# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object
if pytorch_after(1, 7):
extra_files = {f: "" for f in more_extra_files}
extra_files[METADATA_FILENAME] = ""
else:
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined]
extra_files[METADATA_FILENAME] = ""

for f in more_extra_files:
extra_files[f] = ""

jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files)

extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap

if METADATA_FILENAME in extra_files:
json_data = extra_files[METADATA_FILENAME]
del extra_files[METADATA_FILENAME]
else:
json_data = "{}"

json_data_dict = json.loads(json_data)

return jit_obj, json_data_dict, extra_files
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
GridSamplePadMode,
InterpolateMode,
InverseKeys,
JITMetadataKeys,
LossReduction,
Method,
MetricReduction,
Expand Down
12 changes: 12 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,15 @@ class TransformBackends(Enum):

TORCH = "torch"
NUMPY = "numpy"


class JITMetadataKeys(Enum):
"""
Keys stored in the metadata file for saved Torchscript models. Some of these are generated by the routines
and others are optionally provided by users.
"""

NAME = "name"
TIMESTAMP = "timestamp"
VERSION = "version"
DESCRIPTION = "description"
112 changes: 112 additions & 0 deletions tests/test_torchscript_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import torch

from monai.config import get_config_values
from monai.data import load_net_with_metadata, save_net_with_metadata
from monai.utils import JITMetadataKeys
from monai.utils.module import pytorch_after


class TestModule(torch.nn.Module):
def forward(self, x):
return x + 10


class TestTorchscript(unittest.TestCase):
def test_save_net_with_metadata(self):
"""Save a network without metadata to a file."""
m = torch.jit.script(TestModule())

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test")

self.assertTrue(os.path.isfile(f"{tempdir}/test.pt"))

def test_save_net_with_metadata_ext(self):
"""Save a network without metadata to a file."""
m = torch.jit.script(TestModule())

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test.zip")

self.assertTrue(os.path.isfile(f"{tempdir}/test.zip"))

def test_save_net_with_metadata_with_extra(self):
"""Save a network with simple metadata to a file."""
m = torch.jit.script(TestModule())

test_metadata = {"foo": [1, 2], "bar": "string"}

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata)

self.assertTrue(os.path.isfile(f"{tempdir}/test.pt"))

def test_load_net_with_metadata(self):
"""Save then load a network with no metadata or other extra files."""
m = torch.jit.script(TestModule())

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test")
_, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt")

del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be

self.assertEqual(meta, get_config_values())
self.assertEqual(extra_files, {})

def test_load_net_with_metadata_with_extra(self):
"""Save then load a network with basic metadata."""
m = torch.jit.script(TestModule())

test_metadata = {"foo": [1, 2], "bar": "string"}

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata)
_, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt")

del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be

test_compare = get_config_values()
test_compare.update(test_metadata)

self.assertEqual(meta, test_compare)
self.assertEqual(extra_files, {})

def test_save_load_more_extra_files(self):
"""Save then load extra file data from a torchscript file."""
m = torch.jit.script(TestModule())

test_metadata = {"foo": [1, 2], "bar": "string"}

more_extra_files = {"test.txt": b"This is test data"}

with tempfile.TemporaryDirectory() as tempdir:
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files)

self.assertTrue(os.path.isfile(f"{tempdir}/test.pt"))

_, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",))

if pytorch_after(1, 7):
self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"])
else:
self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"])


if __name__ == "__main__":
unittest.main()