pytorch2onnx.py
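# Convert a PyTorch detection model to ONNX, with optional graph simplification
# (onnxsim) and numerical verification against the original PyTorch model.
# Illustrative invocation (paths below are placeholders, not taken from this repository):
#   python pytorch2onnx.py --model-config path/to/config.py --checkpoint path/to/model.pth \
#       --save-file model.onnx --shape 1333 800 --simplify --verify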
import argparse
import os
import sys
import warnings
from typing import Dict, List, Tuple

import numpy as np
import onnx
import torch
from torch import Tensor

sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from models.detectors.base_detector import BaseDetector, EvalResize
from util import utils
from util.lazy_load import Config


class ONNXDetector:
    """Thin wrapper that runs an exported ONNX detector through onnxruntime with IO binding."""

    def __init__(self, onnx_file):
        import onnxruntime

        self.session = onnxruntime.InferenceSession(
            onnx_file, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
        )
        self.io_binding = self.session.io_binding()
        self.is_cuda_available = onnxruntime.get_device() == "GPU"

    def __call__(self, images: List[Tensor], targets: List[Dict] = None):
        if targets is not None:
            warnings.warn("Currently ONNXDetector only supports inference, targets will be ignored")
        assert len(images) == 1, "Currently ONNXDetector only supports batch_size=1 for inference"
        assert images[0].ndim == 3, "Each image must have three dimensions (C, H, W)"
        if isinstance(images, (List, Tuple)):
            images = torch.stack(images)

        # set io binding for inputs/outputs
        device_type = images.device.type if self.is_cuda_available else "cpu"
        if not self.is_cuda_available:
            images = images.cpu()
        self.io_binding.bind_input(
            name="images",
            device_type=device_type,
            device_id=0,
            element_type=np.float32,
            shape=images.shape,
            buffer_ptr=images.data_ptr(),
        )
        for output in self.session.get_outputs():
            self.io_binding.bind_output(output.name)

        # run session to get outputs
        self.session.run_with_iobinding(self.io_binding)
        detections = self.io_binding.copy_outputs_to_cpu()
        return detections
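
    # Minimal usage sketch (file name and image size below are illustrative
    # assumptions, not taken from this script):
    #
    #   detector = ONNXDetector("model.onnx")
    #   image = torch.randn(3, 800, 800)           # a single C, H, W image
    #   scores, labels, boxes = detector([image])  # outputs follow the exported output_names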


def parse_args():
    parser = argparse.ArgumentParser(description="Convert a PyTorch model to an ONNX model")
    # model parameters
    parser.add_argument("--model-config", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--shape", type=int, nargs="+", default=(1333, 800))
    # save parameters
    parser.add_argument("--save-file", type=str, required=True)
    # onnx parameters
    parser.add_argument("--opset-version", type=int, default=17)
    # Note: type=bool would treat any non-empty string (including "False") as True,
    # so BooleanOptionalAction is used: pass --dynamic-export or --no-dynamic-export.
    parser.add_argument("--dynamic-export", action=argparse.BooleanOptionalAction, default=True)
    parser.add_argument("--simplify", action="store_true")
    parser.add_argument("--verify", action="store_true")
    args = parser.parse_args()
    return args


def set_antialias_to_false(model: BaseDetector):
    # Disable antialiasing in the eval-time resize; antialiased interpolation can be
    # problematic to export to ONNX, and this keeps the traced graph consistent.
    for transform in model.eval_transform:
        if isinstance(transform, EvalResize):
            transform.antialias = False


def pytorch2onnx():
    # get args from parser
    args = parse_args()
    model = Config(args.model_config).model
    set_antialias_to_false(model)
    model.eval()
    if args.checkpoint:
        checkpoint = torch.load(args.checkpoint, map_location="cpu")
        utils.load_state_dict(model, checkpoint["model"] if "model" in checkpoint else checkpoint)

    # build a dummy input and (optionally) mark batch/height/width as dynamic
    image = torch.randn(1, 3, args.shape[0], args.shape[1])
    if args.dynamic_export:
        dynamic_axes = {
            "images": {
                0: "batch",
                2: "height",
                3: "width",
            },
        }
    else:
        dynamic_axes = None
    torch.onnx.export(
        model=model,
        args=image,
        f=args.save_file,
        input_names=["images"],
        output_names=["scores", "labels", "boxes"],
        dynamic_axes=dynamic_axes,
        opset_version=args.opset_version,
    )
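    # With dynamic_axes enabled, the exported graph accepts other input sizes at
    # inference time. A minimal onnxruntime sketch (sizes below are illustrative
    # assumptions, not taken from this script):
    #
    #   import onnxruntime
    #   sess = onnxruntime.InferenceSession(args.save_file)
    #   outs = sess.run(None, {"images": np.random.randn(1, 3, 640, 480).astype(np.float32)})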

    if args.simplify:
        import onnxsim

        model_ops, check_ok = onnxsim.simplify(args.save_file)
        if check_ok:
            onnx.save(model_ops, args.save_file)
            print(f"Successfully simplified ONNX model: {args.save_file}")
        else:
            warnings.warn("Failed to simplify ONNX model.")
    print(f"Successfully exported ONNX model: {args.save_file}")

    if args.verify:
        # check the exported model with the onnx checker
        onnx_model = onnx.load(args.save_file)
        onnx.checker.check_model(onnx_model)

        # compare onnx results against pytorch results
        onnx_model = ONNXDetector(args.save_file)
        onnx_results = onnx_model(image)
        with torch.no_grad():
            pytorch_results = list(model(image)[0].values())
        err_msg = "The numerical values are different between PyTorch and ONNX, "
        err_msg += "but it does not necessarily mean the exported ONNX model is problematic."
        for onnx_res, pytorch_res in zip(onnx_results, pytorch_results):
            np.testing.assert_allclose(
                onnx_res, pytorch_res, rtol=1e-3, atol=1e-5, err_msg=err_msg
            )
        print("The numerical values are the same between PyTorch and ONNX")


if __name__ == "__main__":
    pytorch2onnx()