
Commit 805eda8

Author: q.yao
add onnx to tensorrt tools (#542)
1 parent 584f5a7 commit 805eda8

File tree: 2 files changed (+316 −1 lines)

docs/useful_tools.md (+41 −1)

@@ -90,7 +90,7 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.

#### Usage

-```python
+```bash
python tools/ort_test.py \
    ${CONFIG_FILE} \
    ${ONNX_FILE} \

@@ -164,6 +164,46 @@ Examples:

    --shape 512 1024
```

### Convert to TensorRT (experimental)

A script to convert an [ONNX](https://github.com/onnx/onnx) model to the [TensorRT](https://developer.nvidia.com/tensorrt) format.

Prerequisites

- Install `mmcv-full` with ONNXRuntime custom ops and TensorRT plugins, following [ONNXRuntime in mmcv](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) and [TensorRT plugin in mmcv](https://github.com/open-mmlab/mmcv/blob/master/docs/tensorrt_plugin.md).
- Use [pytorch2onnx](#convert-to-onnx-experimental) to convert the model from PyTorch to ONNX.

Usage

```bash
python ${MMSEG_PATH}/tools/onnx2tensorrt.py \
    ${CFG_PATH} \
    ${ONNX_PATH} \
    --trt-file ${OUTPUT_TRT_PATH} \
    --min-shape ${MIN_SHAPE} \
    --max-shape ${MAX_SHAPE} \
    --input-img ${INPUT_IMG} \
    --show \
    --verify
```

Description of all arguments

- `config` : Config file of the model.
- `model` : Path to the input ONNX model.
- `--trt-file` : Path to the output TensorRT engine.
- `--max-shape` : Maximum shape of the model input.
- `--min-shape` : Minimum shape of the model input.
- `--fp16` : Enable fp16 model conversion.
- `--workspace-size` : Max workspace size in GiB.
- `--input-img` : Image used for visualization.
- `--show` : Show the visualized results.
- `--dataset` : Palette provider, `CityscapesDataset` by default.
- `--verify` : Verify the outputs of ONNXRuntime and TensorRT.
- `--verbose` : Whether to print verbose logging messages while creating the TensorRT engine. Defaults to False.

**Note**: Only tested with the `whole` inference mode.
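For example, a minimal sketch of a run on a hypothetical Cityscapes model exported at 512×1024 (the config, ONNX, and engine paths below are placeholders). Using the same value for `--min-shape` and `--max-shape` builds the engine for a single static input size, which is the simplest configuration to verify:

```bash
python tools/onnx2tensorrt.py \
    configs/pspnet/pspnet_r50-d8_512x1024_40k_cityscapes.py \
    pspnet.onnx \
    --trt-file pspnet.trt \
    --min-shape 1 3 512 1024 \
    --max-shape 1 3 512 1024 \
    --input-img demo/demo.png \
    --verify
```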
## Miscellaneous

### Print the entire config

tools/onnx2tensorrt.py

(new file, +275 lines)
import argparse
import os
import os.path as osp
from typing import Iterable, Optional, Union

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
                           save_trt_engine)

from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose


def get_GiB(x: int):
    """Return x GiB."""
    return x * (1 << 30)

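# Prepare a single image with the model's test pipeline. When an explicit
# shape is given, the resize step is pinned to that scale with
# keep_ratio=False so the tensor matches the engine's expected input size.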
def _prepare_input_img(img_path: str,
                       test_pipeline: Iterable[dict],
                       shape: Optional[Iterable] = None,
                       rescale_shape: Optional[Iterable] = None) -> dict:
    # build the data pipeline
    if shape is not None:
        test_pipeline[1]['img_scale'] = (shape[1], shape[0])
    test_pipeline[1]['transforms'][0]['keep_ratio'] = False
    test_pipeline = [LoadImage()] + test_pipeline[1:]
    test_pipeline = Compose(test_pipeline)
    # prepare data
    data = dict(img=img_path)
    data = test_pipeline(data)
    imgs = data['img']
    img_metas = [i.data for i in data['img_metas']]

    if rescale_shape is not None:
        for img_meta in img_metas:
            img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

    mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

    return mm_inputs

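# Rebuild each image meta so it describes exactly the tensor fed to the
# exported model: shapes copied from the first meta, flip disabled, and
# scale_factor expanded to the 4-tuple (w, h, w, h) format mmseg expects.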
def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
    # update img and its meta list
    N = img_list[0].size(0)
    img_meta = img_meta_list[0][0]
    img_shape = img_meta['img_shape']
    ori_shape = img_meta['ori_shape']
    pad_shape = img_meta['pad_shape']
    new_img_meta_list = [[{
        'img_shape': img_shape,
        'ori_shape': ori_shape,
        'pad_shape': pad_shape,
        'filename': img_meta['filename'],
        'scale_factor':
        (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
        'flip': False,
    } for _ in range(N)]]

    return img_list, new_img_meta_list

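# Local pyplot-based viewer: overlays the predicted segmentation map on the
# input image using the dataset palette, so the ONNXRuntime and TensorRT
# results can be compared visually.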
def show_result_pyplot(img: Union[str, np.ndarray],
                       result: np.ndarray,
                       palette: Optional[Iterable] = None,
                       fig_size: Iterable[int] = (15, 10),
                       opacity: float = 0.5,
                       title: str = '',
                       block: bool = True):
    img = mmcv.imread(img)
    img = img.copy()
    seg = result[0]
    seg = mmcv.imresize(seg, img.shape[:2][::-1])
    palette = np.array(palette)
    assert palette.shape[1] == 3
    assert len(palette.shape) == 2
    assert 0 < opacity <= 1.0
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = img * (1 - opacity) + color_seg * opacity
    img = img.astype(np.uint8)

    plt.figure(figsize=fig_size)
    plt.imshow(mmcv.bgr2rgb(img))
    plt.title(title)
    plt.tight_layout()
    plt.show(block=block)

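# Core conversion routine: builds a TensorRT engine from the ONNX file with a
# [min, opt, max] shape profile, saves it, and optionally checks that the
# engine reproduces the ONNXRuntime outputs on a sample image.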
def onnx2tensorrt(onnx_file: str,
                  trt_file: str,
                  config: dict,
                  input_config: dict,
                  fp16: bool = False,
                  verify: bool = False,
                  show: bool = False,
                  dataset: str = 'CityscapesDataset',
                  workspace_size: int = 1,
                  verbose: bool = False):
    import tensorrt as trt
    min_shape = input_config['min_shape']
    max_shape = input_config['max_shape']
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
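    # Note that min_shape is passed twice: it serves both as the minimum and
    # as the "optimal" shape that TensorRT tunes its kernels for.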
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_file,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        inputs = _prepare_input_img(
            input_config['input_path'],
            config.data.test.pipeline,
            shape=min_shape[2:])

        imgs = inputs['imgs']
        img_metas = inputs['img_metas']
        img_list = [img[None, :] for img in imgs]
        img_meta_list = [[img_meta] for img_meta in img_metas]
        # update img_meta
        img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

        if max_shape[0] > 1:
            # concatenate flipped images for batch test
            flip_img_list = [_.flip(-1) for _ in img_list]
            img_list = [
                torch.cat((ori_img, flip_img), 0)
                for ori_img, flip_img in zip(img_list, flip_img_list)
            ]

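        # Run the same preprocessed input through both backends. ONNXRuntime
        # is pinned to the CPU provider (registering mmcv's custom-op library
        # when it is available) and serves as the reference output.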
        # Get results from ONNXRuntime
        ort_custom_op_path = get_onnxruntime_op_path()
        session_options = ort.SessionOptions()
        if osp.exists(ort_custom_op_path):
            session_options.register_custom_ops_library(ort_custom_op_path)
        sess = ort.InferenceSession(onnx_file, session_options)
        sess.set_providers(['CPUExecutionProvider'], [{}])  # use cpu mode
        onnx_output = sess.run(['output'],
                               {'input': img_list[0].detach().numpy()})[0][0]

        # Get results from TensorRT
        trt_model = TRTWraper(trt_file, ['input'], ['output'])
        with torch.no_grad():
            trt_outputs = trt_model(
                {'input': img_list[0].contiguous().cuda()})
        trt_output = trt_outputs['output'][0].cpu().detach().numpy()

        if show:
            dataset = DATASETS.get(dataset)
            assert dataset is not None
            palette = dataset.PALETTE

            show_result_pyplot(
                input_config['input_path'],
                (onnx_output[0].astype(np.uint8), ),
                palette=palette,
                title='ONNXRuntime',
                block=False)
            show_result_pyplot(
                input_config['input_path'],
                (trt_output[0].astype(np.uint8), ),
                palette=palette,
                title='TensorRT')

        np.testing.assert_allclose(
            onnx_output, trt_output, rtol=1e-03, atol=1e-05)
        print('TensorRT and ONNXRuntime outputs are all close.')

def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMSegmentation models from ONNX to TensorRT')
    parser.add_argument('config', help='Config file of the model')
    parser.add_argument('model', help='Path to the input ONNX model')
    parser.add_argument(
        '--trt-file', type=str, help='Path to the output TensorRT engine')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Maximum shape of model input.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs=4,
        default=[1, 3, 400, 600],
        help='Minimum shape of model input.')
    parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB')
    parser.add_argument(
        '--input-img', type=str, default='', help='Image for test')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show output results')
    parser.add_argument(
        '--dataset',
        type=str,
        default='CityscapesDataset',
        help='Dataset name')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to print verbose logging messages while creating the '
        'TensorRT engine.')
    args = parser.parse_args()
    return args

if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()

    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')

    # check arguments
    assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
    assert osp.exists(args.model), \
        'ONNX model {} not found.'.format(args.model)
    assert args.workspace_size >= 0, 'Workspace size should be non-negative.'
    assert DATASETS.get(args.dataset) is not None, \
        'Dataset {} not found.'.format(args.dataset)
    for max_value, min_value in zip(args.max_shape, args.min_shape):
        assert max_value >= min_value, \
            'max_shape should not be less than min_shape'

    input_config = {
        'min_shape': args.min_shape,
        'max_shape': args.max_shape,
        'input_path': args.input_img
    }

    cfg = mmcv.Config.fromfile(args.config)
    onnx2tensorrt(
        args.model,
        args.trt_file,
        cfg,
        input_config,
        fp16=args.fp16,
        verify=args.verify,
        show=args.show,
        dataset=args.dataset,
        workspace_size=args.workspace_size,
        verbose=args.verbose)
