Skip to content

Commit

Permalink
add onnx to tensorrt tools (open-mmlab#542)
Browse files Browse the repository at this point in the history
  • Loading branch information
q.yao authored May 12, 2021
1 parent 1052f8d commit 5182fa1
Show file tree
Hide file tree
Showing 2 changed files with 316 additions and 1 deletion.
42 changes: 41 additions & 1 deletion docs/useful_tools.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ We provide `tools/ort_test.py` to evaluate ONNX model with ONNXRuntime backend.

#### Usage

```python
```bash
python tools/ort_test.py \
${CONFIG_FILE} \
${ONNX_FILE} \
Expand Down Expand Up @@ -164,6 +164,46 @@ Examples:
--shape 512 1024
```

### Convert to TensorRT (experimental)

A script to convert [ONNX](https://github.com/onnx/onnx) model to [TensorRT](https://developer.nvidia.com/tensorrt) format.

Prerequisite

- install `mmcv-full` with ONNXRuntime custom ops and TensorRT plugins follow [ONNXRuntime in mmcv](https://mmcv.readthedocs.io/en/latest/onnxruntime_op.html) and [TensorRT plugin in mmcv](https://github.com/open-mmlab/mmcv/blob/master/docs/tensorrt_plugin.md).
- Use [pytorch2onnx](#convert-to-onnx-experimental) to convert the model from PyTorch to ONNX.

Usage

```bash
python ${MMSEG_PATH}/tools/onnx2tensorrt.py \
${CFG_PATH} \
${ONNX_PATH} \
--trt-file ${OUTPUT_TRT_PATH} \
--min-shape ${MIN_SHAPE} \
--max-shape ${MAX_SHAPE} \
--input-img ${INPUT_IMG} \
--show \
--verify
```

Description of all arguments

- `config` : Config file of the model.
- `model` : Path to the input ONNX model.
- `--trt-file` : Path to the output TensorRT engine.
- `--max-shape` : Maximum shape of model input.
- `--min-shape` : Minimum shape of model input.
- `--fp16` : Enable fp16 model conversion.
- `--workspace-size` : Max workspace size in GiB.
- `--input-img` : Image for visualize.
- `--show` : Enable result visualize.
- `--dataset` : Palette provider, `CityscapesDataset` as default.
- `--verify` : Verify the outputs of ONNXRuntime and TensorRT.
- `--verbose` : Whether to verbose logging messages while creating TensorRT engine. Defaults to False.

**Note**: Only tested on whole mode.

## Miscellaneous

### Print the entire config
Expand Down
275 changes: 275 additions & 0 deletions tools/onnx2tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import argparse
import os
import os.path as osp
from typing import Iterable, Optional, Union

import matplotlib.pyplot as plt
import mmcv
import numpy as np
import onnxruntime as ort
import torch
from mmcv.ops import get_onnxruntime_op_path
from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt,
save_trt_engine)

from mmseg.apis.inference import LoadImage
from mmseg.datasets import DATASETS
from mmseg.datasets.pipelines import Compose


def get_GiB(x: int):
"""return x GiB."""
return x * (1 << 30)


def _prepare_input_img(img_path: str,
test_pipeline: Iterable[dict],
shape: Optional[Iterable] = None,
rescale_shape: Optional[Iterable] = None) -> dict:
# build the data pipeline
if shape is not None:
test_pipeline[1]['img_scale'] = (shape[1], shape[0])
test_pipeline[1]['transforms'][0]['keep_ratio'] = False
test_pipeline = [LoadImage()] + test_pipeline[1:]
test_pipeline = Compose(test_pipeline)
# prepare data
data = dict(img=img_path)
data = test_pipeline(data)
imgs = data['img']
img_metas = [i.data for i in data['img_metas']]

if rescale_shape is not None:
for img_meta in img_metas:
img_meta['ori_shape'] = tuple(rescale_shape) + (3, )

mm_inputs = {'imgs': imgs, 'img_metas': img_metas}

return mm_inputs


def _update_input_img(img_list: Iterable, img_meta_list: Iterable):
# update img and its meta list
N = img_list[0].size(0)
img_meta = img_meta_list[0][0]
img_shape = img_meta['img_shape']
ori_shape = img_meta['ori_shape']
pad_shape = img_meta['pad_shape']
new_img_meta_list = [[{
'img_shape':
img_shape,
'ori_shape':
ori_shape,
'pad_shape':
pad_shape,
'filename':
img_meta['filename'],
'scale_factor':
(img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2,
'flip':
False,
} for _ in range(N)]]

return img_list, new_img_meta_list


def show_result_pyplot(img: Union[str, np.ndarray],
result: np.ndarray,
palette: Optional[Iterable] = None,
fig_size: Iterable[int] = (15, 10),
opacity: float = 0.5,
title: str = '',
block: bool = True):
img = mmcv.imread(img)
img = img.copy()
seg = result[0]
seg = mmcv.imresize(seg, img.shape[:2][::-1])
palette = np.array(palette)
assert palette.shape[1] == 3
assert len(palette.shape) == 2
assert 0 < opacity <= 1.0
color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
for label, color in enumerate(palette):
color_seg[seg == label, :] = color
# convert to BGR
color_seg = color_seg[..., ::-1]

img = img * (1 - opacity) + color_seg * opacity
img = img.astype(np.uint8)

plt.figure(figsize=fig_size)
plt.imshow(mmcv.bgr2rgb(img))
plt.title(title)
plt.tight_layout()
plt.show(block=block)


def onnx2tensorrt(onnx_file: str,
trt_file: str,
config: dict,
input_config: dict,
fp16: bool = False,
verify: bool = False,
show: bool = False,
dataset: str = 'CityscapesDataset',
workspace_size: int = 1,
verbose: bool = False):
import tensorrt as trt
min_shape = input_config['min_shape']
max_shape = input_config['max_shape']
# create trt engine and wraper
opt_shape_dict = {'input': [min_shape, min_shape, max_shape]}
max_workspace_size = get_GiB(workspace_size)
trt_engine = onnx2trt(
onnx_file,
opt_shape_dict,
log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
fp16_mode=fp16,
max_workspace_size=max_workspace_size)
save_dir, _ = osp.split(trt_file)
if save_dir:
os.makedirs(save_dir, exist_ok=True)
save_trt_engine(trt_engine, trt_file)
print(f'Successfully created TensorRT engine: {trt_file}')

if verify:
inputs = _prepare_input_img(
input_config['input_path'],
config.data.test.pipeline,
shape=min_shape[2:])

imgs = inputs['imgs']
img_metas = inputs['img_metas']
img_list = [img[None, :] for img in imgs]
img_meta_list = [[img_meta] for img_meta in img_metas]
# update img_meta
img_list, img_meta_list = _update_input_img(img_list, img_meta_list)

if max_shape[0] > 1:
# concate flip image for batch test
flip_img_list = [_.flip(-1) for _ in img_list]
img_list = [
torch.cat((ori_img, flip_img), 0)
for ori_img, flip_img in zip(img_list, flip_img_list)
]

# Get results from ONNXRuntime
ort_custom_op_path = get_onnxruntime_op_path()
session_options = ort.SessionOptions()
if osp.exists(ort_custom_op_path):
session_options.register_custom_ops_library(ort_custom_op_path)
sess = ort.InferenceSession(onnx_file, session_options)
sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode
onnx_output = sess.run(['output'],
{'input': img_list[0].detach().numpy()})[0][0]

# Get results from TensorRT
trt_model = TRTWraper(trt_file, ['input'], ['output'])
with torch.no_grad():
trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()})
trt_output = trt_outputs['output'][0].cpu().detach().numpy()

if show:
dataset = DATASETS.get(dataset)
assert dataset is not None
palette = dataset.PALETTE

show_result_pyplot(
input_config['input_path'],
(onnx_output[0].astype(np.uint8), ),
palette=palette,
title='ONNXRuntime',
block=False)
show_result_pyplot(
input_config['input_path'], (trt_output[0].astype(np.uint8), ),
palette=palette,
title='TensorRT')

np.testing.assert_allclose(
onnx_output, trt_output, rtol=1e-03, atol=1e-05)
print('TensorRT and ONNXRuntime output all close.')


def parse_args():
parser = argparse.ArgumentParser(
description='Convert MMSegmentation models from ONNX to TensorRT')
parser.add_argument('config', help='Config file of the model')
parser.add_argument('model', help='Path to the input ONNX model')
parser.add_argument(
'--trt-file', type=str, help='Path to the output TensorRT engine')
parser.add_argument(
'--max-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Maximum shape of model input.')
parser.add_argument(
'--min-shape',
type=int,
nargs=4,
default=[1, 3, 400, 600],
help='Minimum shape of model input.')
parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode')
parser.add_argument(
'--workspace-size',
type=int,
default=1,
help='Max workspace size in GiB')
parser.add_argument(
'--input-img', type=str, default='', help='Image for test')
parser.add_argument(
'--show', action='store_true', help='Whether to show output results')
parser.add_argument(
'--dataset',
type=str,
default='CityscapesDataset',
help='Dataset name')
parser.add_argument(
'--verify',
action='store_true',
help='Verify the outputs of ONNXRuntime and TensorRT')
parser.add_argument(
'--verbose',
action='store_true',
help='Whether to verbose logging messages while creating \
TensorRT engine.')
args = parser.parse_args()
return args


if __name__ == '__main__':

assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
args = parse_args()

if not args.input_img:
args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png')

# check arguments
assert osp.exists(args.config), 'Config {} not found.'.format(args.config)
assert osp.exists(args.model), \
'ONNX model {} not found.'.format(args.model)
assert args.workspace_size >= 0, 'Workspace size less than 0.'
assert DATASETS.get(args.dataset) is not None, \
'Dataset {} does not found.'.format(args.dataset)
for max_value, min_value in zip(args.max_shape, args.min_shape):
assert max_value >= min_value, \
'max_shape sould be larger than min shape'

input_config = {
'min_shape': args.min_shape,
'max_shape': args.max_shape,
'input_path': args.input_img
}

cfg = mmcv.Config.fromfile(args.config)
onnx2tensorrt(
args.model,
args.trt_file,
cfg,
input_config,
fp16=args.fp16,
verify=args.verify,
show=args.show,
dataset=args.dataset,
workspace_size=args.workspace_size,
verbose=args.verbose)

0 comments on commit 5182fa1

Please sign in to comment.