-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Adding Torchscript utility functions #3138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
d39efef
Adding Torchscript utility functions
ericspod 2b59bfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] aa5b29d
[MONAI] python code formatting
monai-bot 8d1a32d
Adding Torchscript utility functions
ericspod f7744db
Merge branch 'dev' into torchscript_metadata
ericspod 3dfd1f8
Added test for extra files
ericspod be984dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1c0f3c4
Update
ericspod 1c2f620
Update
ericspod dd8c033
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 0e78ad6
Update
ericspod ae80786
Update
ericspod b994604
Updates
ericspod 123c585
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 7818227
Updates
ericspod 470b556
Merge branch 'dev' into torchscript_metadata
Nic-Ma c2ec69c
Merge branch 'dev' into torchscript_metadata
ericspod b9e30ab
Updates
ericspod File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import datetime | ||
import json | ||
import os | ||
from typing import IO, Any, Mapping, Optional, Sequence, Tuple, Union | ||
|
||
import torch | ||
|
||
from monai.config import get_config_values | ||
from monai.utils import JITMetadataKeys | ||
from monai.utils.module import pytorch_after | ||
|
||
METADATA_FILENAME = "metadata.json" | ||
|
||
|
||
def save_net_with_metadata( | ||
jit_obj: torch.nn.Module, | ||
filename_prefix_or_stream: Union[str, IO[Any]], | ||
include_config_vals: bool = True, | ||
append_timestamp: bool = False, | ||
meta_values: Optional[Mapping[str, Any]] = None, | ||
more_extra_files: Optional[Mapping[str, bytes]] = None, | ||
Nic-Ma marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> None: | ||
""" | ||
Save the JIT object (script or trace produced object) `jit_obj` to the given file or stream with metadata | ||
included as a JSON file. The Torchscript format is a zip file which can contain extra file data which is used | ||
here as a mechanism for storing metadata about the network being saved. The data in `meta_values` should be | ||
compatible with conversion to JSON using the standard library function `dumps`. The intent is this metadata will | ||
include information about the network applicable to some use case, such as describing the input and output format, | ||
a network name and version, a plain language description of what the network does, and other relevant scientific | ||
information. Clients can use this information to determine automatically how to use the network, and users can | ||
read what the network does and keep track of versions. | ||
|
||
Examples:: | ||
|
||
net = torch.jit.script(monai.networks.nets.UNet(2, 1, 1, [8, 16], [2])) | ||
|
||
meta = { | ||
MMelQin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"name": "Test UNet", | ||
"used_for": "demonstration purposes", | ||
"input_dims": 2, | ||
"output_dims": 2 | ||
} | ||
|
||
# save the Torchscript bundle with the above dictionary stored as an extra file | ||
save_net_with_metadata(m, "test", meta_values=meta) | ||
|
||
# load the network back, `loaded_meta` has same data as `meta` plus version information | ||
loaded_net, loaded_meta, _ = load_net_with_metadata("test.pt") | ||
|
||
|
||
Args: | ||
jit_obj: object to save, should be generated by `script` or `trace`. | ||
filename_prefix_or_stream: filename or file-like stream object, if filename has no extension it becomes `.pt`. | ||
include_config_vals: if True, MONAI, Pytorch, and Numpy versions are included in metadata. | ||
append_timestamp: if True, a timestamp for "now" is appended to the file's name before the extension. | ||
meta_values: metadata values to store with the object, not limited just to keys in `JITMetadataKeys`. | ||
more_extra_files: other extra file data items to include in bundle, see `_extra_files` of `torch.jit.save`. | ||
""" | ||
|
||
now = datetime.datetime.now() | ||
metadict = {} | ||
|
||
if include_config_vals: | ||
metadict.update(get_config_values()) | ||
metadict[JITMetadataKeys.TIMESTAMP.value] = now.astimezone().isoformat() | ||
|
||
if meta_values is not None: | ||
metadict.update(meta_values) | ||
|
||
json_data = json.dumps(metadict) | ||
|
||
# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object | ||
if pytorch_after(1, 7): | ||
ericspod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
extra_files = {METADATA_FILENAME: json_data.encode()} | ||
|
||
if more_extra_files is not None: | ||
extra_files.update(more_extra_files) | ||
else: | ||
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] | ||
extra_files[METADATA_FILENAME] = json_data.encode() | ||
|
||
if more_extra_files is not None: | ||
for k, v in more_extra_files.items(): | ||
extra_files[k] = v | ||
|
||
if isinstance(filename_prefix_or_stream, str): | ||
filename_no_ext, ext = os.path.splitext(filename_prefix_or_stream) | ||
if ext == "": | ||
ext = ".pt" | ||
|
||
if append_timestamp: | ||
filename_prefix_or_stream = now.strftime(f"{filename_no_ext}_%Y%m%d%H%M%S{ext}") | ||
else: | ||
filename_prefix_or_stream = filename_no_ext + ext | ||
|
||
torch.jit.save(jit_obj, filename_prefix_or_stream, extra_files) | ||
|
||
|
||
def load_net_with_metadata( | ||
filename_prefix_or_stream: Union[str, IO[Any]], | ||
map_location: Optional[torch.device] = None, | ||
more_extra_files: Sequence[str] = (), | ||
) -> Tuple[torch.nn.Module, dict, dict]: | ||
""" | ||
Load the module object from the given Torchscript filename or stream, and convert the stored JSON metadata | ||
back to a dict object. This will produce an empty dict if the metadata file is not present. | ||
|
||
Args: | ||
filename_prefix_or_stream: filename or file-like stream object. | ||
map_location: network map location as in `torch.jit.load`. | ||
more_extra_files: other extra file data names to load from bundle, see `_extra_files` of `torch.jit.load`. | ||
Returns: | ||
Triple containing loaded object, metadata dict, and extra files dict containing other file data if present | ||
""" | ||
# Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object | ||
if pytorch_after(1, 7): | ||
ericspod marked this conversation as resolved.
Show resolved
Hide resolved
|
||
extra_files = {f: "" for f in more_extra_files} | ||
extra_files[METADATA_FILENAME] = "" | ||
else: | ||
extra_files = torch._C.ExtraFilesMap() # type:ignore[attr-defined] | ||
extra_files[METADATA_FILENAME] = "" | ||
|
||
for f in more_extra_files: | ||
extra_files[f] = "" | ||
|
||
jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files) | ||
|
||
extra_files = dict(extra_files.items()) # compatibility with ExtraFilesMap | ||
|
||
if METADATA_FILENAME in extra_files: | ||
json_data = extra_files[METADATA_FILENAME] | ||
del extra_files[METADATA_FILENAME] | ||
else: | ||
json_data = "{}" | ||
|
||
json_data_dict = json.loads(json_data) | ||
|
||
return jit_obj, json_data_dict, extra_files |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
# Copyright 2020 - 2021 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import tempfile | ||
import unittest | ||
|
||
import torch | ||
|
||
from monai.config import get_config_values | ||
from monai.data import load_net_with_metadata, save_net_with_metadata | ||
from monai.utils import JITMetadataKeys | ||
from monai.utils.module import pytorch_after | ||
|
||
|
||
class TestModule(torch.nn.Module): | ||
def forward(self, x): | ||
return x + 10 | ||
|
||
|
||
class TestTorchscript(unittest.TestCase): | ||
def test_save_net_with_metadata(self): | ||
"""Save a network without metadata to a file.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test") | ||
|
||
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) | ||
|
||
def test_save_net_with_metadata_ext(self): | ||
"""Save a network without metadata to a file.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test.zip") | ||
|
||
self.assertTrue(os.path.isfile(f"{tempdir}/test.zip")) | ||
|
||
def test_save_net_with_metadata_with_extra(self): | ||
"""Save a network with simple metadata to a file.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
test_metadata = {"foo": [1, 2], "bar": "string"} | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) | ||
|
||
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) | ||
|
||
def test_load_net_with_metadata(self): | ||
"""Save then load a network with no metadata or other extra files.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test") | ||
_, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") | ||
|
||
del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be | ||
|
||
self.assertEqual(meta, get_config_values()) | ||
self.assertEqual(extra_files, {}) | ||
|
||
def test_load_net_with_metadata_with_extra(self): | ||
"""Save then load a network with basic metadata.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
test_metadata = {"foo": [1, 2], "bar": "string"} | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata) | ||
_, meta, extra_files = load_net_with_metadata(f"{tempdir}/test.pt") | ||
|
||
del meta[JITMetadataKeys.TIMESTAMP.value] # no way of knowing precisely what this value would be | ||
|
||
test_compare = get_config_values() | ||
test_compare.update(test_metadata) | ||
|
||
self.assertEqual(meta, test_compare) | ||
self.assertEqual(extra_files, {}) | ||
|
||
def test_save_load_more_extra_files(self): | ||
"""Save then load extra file data from a torchscript file.""" | ||
m = torch.jit.script(TestModule()) | ||
|
||
test_metadata = {"foo": [1, 2], "bar": "string"} | ||
|
||
more_extra_files = {"test.txt": b"This is test data"} | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
save_net_with_metadata(m, f"{tempdir}/test", meta_values=test_metadata, more_extra_files=more_extra_files) | ||
|
||
self.assertTrue(os.path.isfile(f"{tempdir}/test.pt")) | ||
|
||
_, _, loaded_extra_files = load_net_with_metadata(f"{tempdir}/test.pt", more_extra_files=("test.txt",)) | ||
|
||
if pytorch_after(1, 7): | ||
self.assertEqual(more_extra_files["test.txt"], loaded_extra_files["test.txt"]) | ||
else: | ||
self.assertEqual(more_extra_files["test.txt"].decode(), loaded_extra_files["test.txt"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.