onnx export to support older pytorch with example_outputs argument #6309

Merged
32 changes: 24 additions & 8 deletions monai/networks/utils.py
@@ -584,8 +584,10 @@ def convert_to_onnx(
         inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.
         input_names: optional input names of the ONNX model.
         output_names: optional output names of the ONNX model.
-        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and <= 16, for more
-            details: https://github.com/onnx/onnx/blob/main/docs/Operators.md.
+        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed
+            the latest opset version supported by PyTorch, for more details:
+            https://github.com/onnx/onnx/blob/main/docs/Operators.md and
+            https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py
         dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). If set to None,
             the exported model will have the shapes of all input and output tensors set to match given
             ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.
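
To see the opset ceilings the new docstring points at, one can query the installed onnx package and, where available, PyTorch's private constants module. The snippet below is only a hedged illustration: torch.onnx._constants is private and its attribute names vary between releases, so the lookup is defensive rather than a stable API.

# Illustrative check of the opset ceilings referenced in the docstring above.
# torch.onnx._constants is private and its attribute names differ by PyTorch
# release, so the lookup is intentionally defensive (assumption, not a stable API).
import onnx.defs

print("opset of the installed onnx package:", onnx.defs.onnx_opset_version())

try:
    from torch.onnx import _constants

    max_opset = getattr(_constants, "ONNX_MAX_OPSET", None) or getattr(_constants, "onnx_main_opset", None)
    print("max opset this PyTorch build can export (if exposed):", max_opset)
except ImportError:
    print("torch.onnx._constants is not importable in this PyTorch version")
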
@@ -603,31 +605,45 @@ def convert_to_onnx(
     """
     model.eval()
     with torch.no_grad():
+        torch_versioned_kwargs = {}
         if use_trace:
-            script_module = torch.jit.trace(model, example_inputs=inputs)
+            # let torch.onnx.export trace the model.
+            mode_to_export = model
         else:
-            script_module = torch.jit.script(model, **kwargs)
+            if not pytorch_after(1, 10):
+                if "example_outputs" not in kwargs:
+                    # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
+                    raise TypeError(
+                        "example_outputs is required in scripting mode before PyTorch 1.10. "
+                        "Please provide example outputs or use trace mode to export the onnx model."
+                    )
+                torch_versioned_kwargs["example_outputs"] = kwargs["example_outputs"]
+                del kwargs["example_outputs"]
+            mode_to_export = torch.jit.script(model, **kwargs)
+
         if filename is None:
             f = io.BytesIO()
             torch.onnx.export(
-                script_module,
-                inputs,
+                mode_to_export,
+                tuple(inputs),
                 f=f,
                 input_names=input_names,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=opset_version,
+                **torch_versioned_kwargs,
             )
             onnx_model = onnx.load_model_from_string(f.getvalue())
         else:
             torch.onnx.export(
-                script_module,
-                inputs,
+                mode_to_export,
+                tuple(inputs),
                 f=filename,
                 input_names=input_names,
                 output_names=output_names,
                 dynamic_axes=dynamic_axes,
                 opset_version=opset_version,
+                **torch_versioned_kwargs,
             )
             onnx_model = onnx.load(filename)
+
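Putting the new branch to use, the following is a minimal sketch (not part of the diff) of calling convert_to_onnx in scripting mode on both sides of the PyTorch 1.10 boundary; the UNet configuration mirrors the unit test below, and default values for the remaining arguments are assumed.

import torch

from monai.networks import convert_to_onnx
from monai.networks.nets import UNet
from monai.utils.module import pytorch_after

# toy network and input, mirroring the unit test below
model = UNet(spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0)
inputs = [torch.randn((16, 1, 32, 32), requires_grad=False)]

if pytorch_after(1, 10):
    # PyTorch >= 1.10: scripting mode needs no extra arguments
    onnx_model = convert_to_onnx(model=model, inputs=inputs, use_trace=False)
else:
    # older PyTorch: example_outputs is forwarded to torch.onnx.export
    onnx_model = convert_to_onnx(
        model=model,
        inputs=inputs,
        use_trace=False,
        example_outputs=[torch.randn((16, 3, 32, 32), requires_grad=False)],
    )
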
48 changes: 33 additions & 15 deletions tests/test_convert_to_onnx.py
@@ -19,42 +19,60 @@
 
 from monai.networks import convert_to_onnx
 from monai.networks.nets import SegResNet, UNet
+from monai.utils.module import pytorch_after
 from tests.utils import SkipIfBeforePyTorchVersion, SkipIfNoModule, optional_import
 
 if torch.cuda.is_available():
     TORCH_DEVICE_OPTIONS = ["cpu", "cuda"]
 else:
     TORCH_DEVICE_OPTIONS = ["cpu"]
-TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False]))
+TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
 TESTS_ORT = list(itertools.product(TORCH_DEVICE_OPTIONS, [True]))
 
 onnx, _ = optional_import("onnx")
 
 
 @SkipIfNoModule("onnx")
-@SkipIfBeforePyTorchVersion((1, 10))
+@SkipIfBeforePyTorchVersion((1, 9))
 class TestConvertToOnnx(unittest.TestCase):
     @parameterized.expand(TESTS)
-    def test_unet(self, device, use_ort):
+    def test_unet(self, device, use_trace, use_ort):
         if use_ort:
             _, has_onnxruntime = optional_import("onnxruntime")
             if not has_onnxruntime:
                 self.skipTest("onnxruntime is not installed probably due to python version >= 3.11.")
         model = UNet(
             spatial_dims=2, in_channels=1, out_channels=3, channels=(16, 32, 64), strides=(2, 2), num_res_units=0
         )
-        onnx_model = convert_to_onnx(
-            model=model,
-            inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
-            input_names=["x"],
-            output_names=["y"],
-            verify=True,
-            device=device,
-            use_ort=use_ort,
-            use_trace=True,
-            rtol=1e-3,
-            atol=1e-4,
-        )
+        if pytorch_after(1, 10) or use_trace:
+            onnx_model = convert_to_onnx(
+                model=model,
+                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
+                input_names=["x"],
+                output_names=["y"],
+                verify=True,
+                device=device,
+                use_ort=use_ort,
+                use_trace=use_trace,
+                rtol=1e-3,
+                atol=1e-4,
+            )
+        else:
+            # https://github.com/pytorch/pytorch/blob/release/1.9/torch/onnx/__init__.py#L182
+            # example_outputs is required in scripting mode before PyTorch 1.10
+            onnx_model = convert_to_onnx(
+                model=model,
+                inputs=[torch.randn((16, 1, 32, 32), requires_grad=False)],
+                input_names=["x"],
+                output_names=["y"],
+                example_outputs=[torch.randn((16, 3, 32, 32), requires_grad=False)],
+                verify=True,
+                device=device,
+                use_ort=use_ort,
+                use_trace=use_trace,
+                rtol=1e-3,
+                atol=1e-4,
+            )
         self.assertTrue(isinstance(onnx_model, onnx.ModelProto))
 
     @parameterized.expand(TESTS_ORT)
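
For reference, a small illustration (not part of the PR) of what the expanded TESTS matrix looks like on a CPU-only machine; each tuple feeds test_unet as (device, use_trace, use_ort).

# Expansion of the parameterized test matrix on a CPU-only machine.
import itertools

TORCH_DEVICE_OPTIONS = ["cpu"]
TESTS = list(itertools.product(TORCH_DEVICE_OPTIONS, [True, False], [True, False]))
print(TESTS)
# [('cpu', True, True), ('cpu', True, False), ('cpu', False, True), ('cpu', False, False)]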