
Commit a81b1bb

drcut authored and xvjiarui committed

add pytorch2onnx part (open-mmlab#12)

* add pytorch2onnx part
* Update according to the latest mmcv
* add docstring
* update docs
* update docs

Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>

1 parent eabbccf · commit a81b1bb

File tree

4 files changed: +227 −9 lines changed

* docs/getting_started.md
* mmseg/models/segmentors/encoder_decoder.py
* setup.cfg
* tools/pytorch2onnx.py

docs/getting_started.md (+15)

````diff
@@ -332,3 +332,18 @@ python tools/publish_model.py work_dirs/pspnet/latest.pth psp_r50_hszhao_200ep.p
 ```
 
 The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pth`.
+
+### Convert to ONNX (experimental)
+
+We provide a script to convert a model to [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized with tools such as [Netron](https://github.com/lutzroeder/netron). We also support comparing output results between the PyTorch and ONNX models.
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```
+
+**Note**: This tool is still experimental. Some custom operators are not supported for now.
+
+## Tutorials
+
+Currently, we provide four tutorials for users to [add a new dataset](tutorials/new_dataset.md), [design a data pipeline](tutorials/data_pipeline.md), [add new modules](tutorials/new_modules.md), and [use training tricks](tutorials/training_tricks.md).
+We also provide a full description of the [config system](config.md).
````
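After conversion, the exported file can be loaded like any other ONNX model. A minimal sketch of running it with `onnxruntime` (the file name and input shape mirror the script's defaults and are otherwise assumptions):

```python
import numpy as np
import onnxruntime as rt

# 'tmp.onnx' and the 256x256 shape match the script defaults (assumptions).
sess = rt.InferenceSession('tmp.onnx')
input_name = sess.get_inputs()[0].name  # the single non-initializer input
dummy = np.random.rand(1, 3, 256, 256).astype(np.float32)
seg_pred = sess.run(None, {input_name: dummy})[0]  # per-pixel class indices
print(seg_pred.shape)
```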

mmseg/models/segmentors/encoder_decoder.py (+13 −8)

```diff
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -171,6 +172,8 @@ def slide_inference(self, img, img_meta, rescale):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
+        assert h_crop <= h_img and w_crop <= w_img, (
+            'crop size should not be greater than image size')
         num_classes = self.num_classes
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
@@ -185,14 +188,15 @@ def slide_inference(self, img, img_meta, rescale):
                 y1 = max(y2 - h_crop, 0)
                 x1 = max(x2 - w_crop, 0)
                 crop_img = img[:, :, y1:y2, x1:x2]
-                pad_img = crop_img.new_zeros(
-                    (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
-                pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
-                pad_seg_logit = self.encode_decode(pad_img, img_meta)
-                preds[:, :, y1:y2,
-                      x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
+        # We want to regard count_mat as a constant while exporting to ONNX
+        count_mat = torch.from_numpy(count_mat.detach().numpy())
         preds = preds / count_mat
         if rescale:
             preds = resize(
@@ -201,7 +205,6 @@ def slide_inference(self, img, img_meta, rescale):
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
-
         return preds
 
     def whole_inference(self, img, img_meta, rescale):
```
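The switch from slicing into a preallocated `pad_img` to `preds += F.pad(...)` avoids in-place slice assignment, which the ONNX tracer handles poorly; padding each crop's logits back to full size and adding is numerically equivalent. Likewise, round-tripping `count_mat` through NumPy freezes it into a constant in the exported graph. A small self-contained check of the padding equivalence (shapes are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Toy full-image size and crop window (assumptions for illustration).
H, W, y1, y2, x1, x2 = 8, 8, 2, 6, 1, 5
crop_seg_logit = torch.rand(1, 3, y2 - y1, x2 - x1)

# Old style: in-place accumulation into a slice of the full map.
preds_slice = torch.zeros(1, 3, H, W)
preds_slice[:, :, y1:y2, x1:x2] += crop_seg_logit

# New style: pad the crop to full size (left, right, top, bottom), then add.
preds_pad = torch.zeros(1, 3, H, W)
preds_pad += F.pad(crop_seg_logit, (x1, W - x2, y1, H - y2))

assert torch.allclose(preds_slice, preds_pad)
```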
```diff
@@ -243,8 +246,8 @@ def inference(self, img, img_meta, rescale):
             seg_logit = self.whole_inference(img, img_meta, rescale)
         output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
-        flip_direction = img_meta[0]['flip_direction']
         if flip:
+            flip_direction = img_meta[0]['flip_direction']
             assert flip_direction in ['horizontal', 'vertical']
             if flip_direction == 'horizontal':
                 output = output.flip(dims=(3, ))
@@ -257,6 +260,8 @@ def simple_test(self, img, img_meta, rescale=True):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
```
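The `torch.onnx.is_in_onnx_export()` guard lets one `forward()` path serve both eager evaluation and tracing: during export, `simple_test` returns the prediction as a tensor instead of dropping to NumPy. A toy illustration of the same pattern (not part of this commit; the module and file name are made up):

```python
import torch


class Toy(torch.nn.Module):

    def forward(self, x):
        seg_pred = x.argmax(dim=1)
        if torch.onnx.is_in_onnx_export():
            return seg_pred  # stay in tensor land so tracing can continue
        return seg_pred.cpu().numpy()  # eager-mode convenience path


torch.onnx.export(
    Toy(), (torch.rand(1, 4, 8, 8), ), 'toy.onnx', opset_version=11)
```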

setup.cfg (+1 −1)

```diff
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
```

tools/pytorch2onnx.py (+198, new file)

```python
import argparse
from functools import partial

import mmcv
import numpy as np
import onnxruntime as rt
import torch
import torch._C
import torch.serialization
from mmcv.onnx import register_extra_symbolics
from mmcv.runner import load_checkpoint

from mmseg.models import build_segmentor

torch.manual_seed(3)


def _convert_batchnorm(module):
    """Recursively replace SyncBatchNorm layers with BatchNorm2d so the
    model can be traced on CPU."""
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
    for name, child in module.named_children():
        module_output.add_module(name, _convert_batchnorm(child))
    del module
    return module_output


def _demo_mm_inputs(input_shape, num_classes):
    """Create a superset of inputs needed to run test or train batches.

    Args:
        input_shape (tuple): Input batch dimensions.
        num_classes (int): Number of semantic classes.
    """
    (N, C, H, W) = input_shape
    rng = np.random.RandomState(0)
    imgs = rng.rand(*input_shape)
    segs = rng.randint(
        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
    img_metas = [{
        'img_shape': (H, W, C),
        'ori_shape': (H, W, C),
        'pad_shape': (H, W, C),
        'filename': '<demo>.png',
        'scale_factor': 1.0,
        'flip': False,
    } for _ in range(N)]
    mm_inputs = {
        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
        'img_metas': img_metas,
        'gt_semantic_seg': torch.LongTensor(segs)
    }
    return mm_inputs


def pytorch2onnx(model,
                 input_shape,
                 opset_version=11,
                 show=False,
                 output_file='tmp.onnx',
                 verify=False):
    """Export a PyTorch model to ONNX and optionally verify that the outputs
    of the PyTorch and ONNX models match.

    Args:
        model (nn.Module): PyTorch model to export.
        input_shape (tuple): Shape used to construct the dummy input that
            drives the export.
        opset_version (int): The ONNX opset version. Default: 11.
        show (bool): Whether to print the computation graph. Default: False.
        output_file (str): Path where the output ONNX model is stored.
            Default: `tmp.onnx`.
        verify (bool): Whether to compare the outputs between PyTorch and
            ONNX. Default: False.
    """
    model.cpu().eval()

    num_classes = model.decode_head.num_classes

    mm_inputs = _demo_mm_inputs(input_shape, num_classes)

    imgs = mm_inputs.pop('imgs')
    img_metas = mm_inputs.pop('img_metas')

    img_list = [img[None, :] for img in imgs]
    img_meta_list = [[img_meta] for img_meta in img_metas]

    # replace the original forward function
    origin_forward = model.forward
    model.forward = partial(
        model.forward, img_metas=img_meta_list, return_loss=False)

    register_extra_symbolics(opset_version)
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list, ),
            output_file,
            export_params=True,
            keep_initializers_as_inputs=True,
            verbose=show,
            opset_version=opset_version)
        print(f'Successfully exported ONNX model: {output_file}')
    model.forward = origin_forward

    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)

        # check the numerical value
        # get pytorch output
        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]

        # get onnx output
        input_all = [node.name for node in onnx_model.graph.input]
        input_initializer = [
            node.name for node in onnx_model.graph.initializer
        ]
        net_feed_input = list(set(input_all) - set(input_initializer))
        assert len(net_feed_input) == 1
        sess = rt.InferenceSession(output_file)
        onnx_result = sess.run(
            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
        if not np.allclose(pytorch_result, onnx_result):
            raise ValueError(
                'The outputs are different between PyTorch and ONNX')
        print('The outputs are the same between PyTorch and ONNX')


def parse_args():
    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
    parser.add_argument('--show', action='store_true', help='show onnx graph')
    parser.add_argument(
        '--verify', action='store_true', help='verify the onnx model')
    parser.add_argument('--output-file', type=str, default='tmp.onnx')
    parser.add_argument('--opset-version', type=int, default=11)
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[256, 256],
        help='input image size')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (1, 3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (
            1,
            3,
        ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = mmcv.Config.fromfile(args.config)
    cfg.model.pretrained = None

    # build the model and load the checkpoint
    segmentor = build_segmentor(
        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
    # convert SyncBN to BN
    segmentor = _convert_batchnorm(segmentor)

    if args.checkpoint:
        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')

    # convert the model to an ONNX file
    pytorch2onnx(
        segmentor,
        input_shape,
        opset_version=args.opset_version,
        show=args.show,
        output_file=args.output_file,
        verify=args.verify)
```
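As a quick sanity check, `_convert_batchnorm` can be exercised on a toy module (an illustration assuming the definitions above; not part of the commit):

```python
import torch

m = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.SyncBatchNorm(8),
)
m = _convert_batchnorm(m)  # SyncBN layers become plain BatchNorm2d
assert isinstance(m[1], torch.nn.BatchNorm2d)
```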
