|
| 1 | +import argparse |
| 2 | +import os |
| 3 | +import os.path as osp |
| 4 | +from typing import Iterable, Optional, Union |
| 5 | + |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import mmcv |
| 8 | +import numpy as np |
| 9 | +import onnxruntime as ort |
| 10 | +import torch |
| 11 | +from mmcv.ops import get_onnxruntime_op_path |
| 12 | +from mmcv.tensorrt import (TRTWraper, is_tensorrt_plugin_loaded, onnx2trt, |
| 13 | + save_trt_engine) |
| 14 | + |
| 15 | +from mmseg.apis.inference import LoadImage |
| 16 | +from mmseg.datasets import DATASETS |
| 17 | +from mmseg.datasets.pipelines import Compose |
| 18 | + |
| 19 | + |
| 20 | +def get_GiB(x: int): |
| 21 | + """return x GiB.""" |
| 22 | + return x * (1 << 30) |
| 23 | + |
| 24 | + |
| 25 | +def _prepare_input_img(img_path: str, |
| 26 | + test_pipeline: Iterable[dict], |
| 27 | + shape: Optional[Iterable] = None, |
| 28 | + rescale_shape: Optional[Iterable] = None) -> dict: |
| 29 | + # build the data pipeline |
| 30 | + if shape is not None: |
| 31 | + test_pipeline[1]['img_scale'] = (shape[1], shape[0]) |
| 32 | + test_pipeline[1]['transforms'][0]['keep_ratio'] = False |
| 33 | + test_pipeline = [LoadImage()] + test_pipeline[1:] |
| 34 | + test_pipeline = Compose(test_pipeline) |
| 35 | + # prepare data |
| 36 | + data = dict(img=img_path) |
| 37 | + data = test_pipeline(data) |
| 38 | + imgs = data['img'] |
| 39 | + img_metas = [i.data for i in data['img_metas']] |
| 40 | + |
| 41 | + if rescale_shape is not None: |
| 42 | + for img_meta in img_metas: |
| 43 | + img_meta['ori_shape'] = tuple(rescale_shape) + (3, ) |
| 44 | + |
| 45 | + mm_inputs = {'imgs': imgs, 'img_metas': img_metas} |
| 46 | + |
| 47 | + return mm_inputs |
| 48 | + |
| 49 | + |
| 50 | +def _update_input_img(img_list: Iterable, img_meta_list: Iterable): |
| 51 | + # update img and its meta list |
| 52 | + N = img_list[0].size(0) |
| 53 | + img_meta = img_meta_list[0][0] |
| 54 | + img_shape = img_meta['img_shape'] |
| 55 | + ori_shape = img_meta['ori_shape'] |
| 56 | + pad_shape = img_meta['pad_shape'] |
| 57 | + new_img_meta_list = [[{ |
| 58 | + 'img_shape': |
| 59 | + img_shape, |
| 60 | + 'ori_shape': |
| 61 | + ori_shape, |
| 62 | + 'pad_shape': |
| 63 | + pad_shape, |
| 64 | + 'filename': |
| 65 | + img_meta['filename'], |
| 66 | + 'scale_factor': |
| 67 | + (img_shape[1] / ori_shape[1], img_shape[0] / ori_shape[0]) * 2, |
| 68 | + 'flip': |
| 69 | + False, |
| 70 | + } for _ in range(N)]] |
| 71 | + |
| 72 | + return img_list, new_img_meta_list |
| 73 | + |
| 74 | + |
| 75 | +def show_result_pyplot(img: Union[str, np.ndarray], |
| 76 | + result: np.ndarray, |
| 77 | + palette: Optional[Iterable] = None, |
| 78 | + fig_size: Iterable[int] = (15, 10), |
| 79 | + opacity: float = 0.5, |
| 80 | + title: str = '', |
| 81 | + block: bool = True): |
| 82 | + img = mmcv.imread(img) |
| 83 | + img = img.copy() |
| 84 | + seg = result[0] |
| 85 | + seg = mmcv.imresize(seg, img.shape[:2][::-1]) |
| 86 | + palette = np.array(palette) |
| 87 | + assert palette.shape[1] == 3 |
| 88 | + assert len(palette.shape) == 2 |
| 89 | + assert 0 < opacity <= 1.0 |
| 90 | + color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) |
| 91 | + for label, color in enumerate(palette): |
| 92 | + color_seg[seg == label, :] = color |
| 93 | + # convert to BGR |
| 94 | + color_seg = color_seg[..., ::-1] |
| 95 | + |
| 96 | + img = img * (1 - opacity) + color_seg * opacity |
| 97 | + img = img.astype(np.uint8) |
| 98 | + |
| 99 | + plt.figure(figsize=fig_size) |
| 100 | + plt.imshow(mmcv.bgr2rgb(img)) |
| 101 | + plt.title(title) |
| 102 | + plt.tight_layout() |
| 103 | + plt.show(block=block) |
| 104 | + |
| 105 | + |
| 106 | +def onnx2tensorrt(onnx_file: str, |
| 107 | + trt_file: str, |
| 108 | + config: dict, |
| 109 | + input_config: dict, |
| 110 | + fp16: bool = False, |
| 111 | + verify: bool = False, |
| 112 | + show: bool = False, |
| 113 | + dataset: str = 'CityscapesDataset', |
| 114 | + workspace_size: int = 1, |
| 115 | + verbose: bool = False): |
| 116 | + import tensorrt as trt |
| 117 | + min_shape = input_config['min_shape'] |
| 118 | + max_shape = input_config['max_shape'] |
| 119 | + # create trt engine and wraper |
| 120 | + opt_shape_dict = {'input': [min_shape, min_shape, max_shape]} |
| 121 | + max_workspace_size = get_GiB(workspace_size) |
| 122 | + trt_engine = onnx2trt( |
| 123 | + onnx_file, |
| 124 | + opt_shape_dict, |
| 125 | + log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR, |
| 126 | + fp16_mode=fp16, |
| 127 | + max_workspace_size=max_workspace_size) |
| 128 | + save_dir, _ = osp.split(trt_file) |
| 129 | + if save_dir: |
| 130 | + os.makedirs(save_dir, exist_ok=True) |
| 131 | + save_trt_engine(trt_engine, trt_file) |
| 132 | + print(f'Successfully created TensorRT engine: {trt_file}') |
| 133 | + |
| 134 | + if verify: |
| 135 | + inputs = _prepare_input_img( |
| 136 | + input_config['input_path'], |
| 137 | + config.data.test.pipeline, |
| 138 | + shape=min_shape[2:]) |
| 139 | + |
| 140 | + imgs = inputs['imgs'] |
| 141 | + img_metas = inputs['img_metas'] |
| 142 | + img_list = [img[None, :] for img in imgs] |
| 143 | + img_meta_list = [[img_meta] for img_meta in img_metas] |
| 144 | + # update img_meta |
| 145 | + img_list, img_meta_list = _update_input_img(img_list, img_meta_list) |
| 146 | + |
| 147 | + if max_shape[0] > 1: |
| 148 | + # concate flip image for batch test |
| 149 | + flip_img_list = [_.flip(-1) for _ in img_list] |
| 150 | + img_list = [ |
| 151 | + torch.cat((ori_img, flip_img), 0) |
| 152 | + for ori_img, flip_img in zip(img_list, flip_img_list) |
| 153 | + ] |
| 154 | + |
| 155 | + # Get results from ONNXRuntime |
| 156 | + ort_custom_op_path = get_onnxruntime_op_path() |
| 157 | + session_options = ort.SessionOptions() |
| 158 | + if osp.exists(ort_custom_op_path): |
| 159 | + session_options.register_custom_ops_library(ort_custom_op_path) |
| 160 | + sess = ort.InferenceSession(onnx_file, session_options) |
| 161 | + sess.set_providers(['CPUExecutionProvider'], [{}]) # use cpu mode |
| 162 | + onnx_output = sess.run(['output'], |
| 163 | + {'input': img_list[0].detach().numpy()})[0][0] |
| 164 | + |
| 165 | + # Get results from TensorRT |
| 166 | + trt_model = TRTWraper(trt_file, ['input'], ['output']) |
| 167 | + with torch.no_grad(): |
| 168 | + trt_outputs = trt_model({'input': img_list[0].contiguous().cuda()}) |
| 169 | + trt_output = trt_outputs['output'][0].cpu().detach().numpy() |
| 170 | + |
| 171 | + if show: |
| 172 | + dataset = DATASETS.get(dataset) |
| 173 | + assert dataset is not None |
| 174 | + palette = dataset.PALETTE |
| 175 | + |
| 176 | + show_result_pyplot( |
| 177 | + input_config['input_path'], |
| 178 | + (onnx_output[0].astype(np.uint8), ), |
| 179 | + palette=palette, |
| 180 | + title='ONNXRuntime', |
| 181 | + block=False) |
| 182 | + show_result_pyplot( |
| 183 | + input_config['input_path'], (trt_output[0].astype(np.uint8), ), |
| 184 | + palette=palette, |
| 185 | + title='TensorRT') |
| 186 | + |
| 187 | + np.testing.assert_allclose( |
| 188 | + onnx_output, trt_output, rtol=1e-03, atol=1e-05) |
| 189 | + print('TensorRT and ONNXRuntime output all close.') |
| 190 | + |
| 191 | + |
| 192 | +def parse_args(): |
| 193 | + parser = argparse.ArgumentParser( |
| 194 | + description='Convert MMSegmentation models from ONNX to TensorRT') |
| 195 | + parser.add_argument('config', help='Config file of the model') |
| 196 | + parser.add_argument('model', help='Path to the input ONNX model') |
| 197 | + parser.add_argument( |
| 198 | + '--trt-file', type=str, help='Path to the output TensorRT engine') |
| 199 | + parser.add_argument( |
| 200 | + '--max-shape', |
| 201 | + type=int, |
| 202 | + nargs=4, |
| 203 | + default=[1, 3, 400, 600], |
| 204 | + help='Maximum shape of model input.') |
| 205 | + parser.add_argument( |
| 206 | + '--min-shape', |
| 207 | + type=int, |
| 208 | + nargs=4, |
| 209 | + default=[1, 3, 400, 600], |
| 210 | + help='Minimum shape of model input.') |
| 211 | + parser.add_argument('--fp16', action='store_true', help='Enable fp16 mode') |
| 212 | + parser.add_argument( |
| 213 | + '--workspace-size', |
| 214 | + type=int, |
| 215 | + default=1, |
| 216 | + help='Max workspace size in GiB') |
| 217 | + parser.add_argument( |
| 218 | + '--input-img', type=str, default='', help='Image for test') |
| 219 | + parser.add_argument( |
| 220 | + '--show', action='store_true', help='Whether to show output results') |
| 221 | + parser.add_argument( |
| 222 | + '--dataset', |
| 223 | + type=str, |
| 224 | + default='CityscapesDataset', |
| 225 | + help='Dataset name') |
| 226 | + parser.add_argument( |
| 227 | + '--verify', |
| 228 | + action='store_true', |
| 229 | + help='Verify the outputs of ONNXRuntime and TensorRT') |
| 230 | + parser.add_argument( |
| 231 | + '--verbose', |
| 232 | + action='store_true', |
| 233 | + help='Whether to verbose logging messages while creating \ |
| 234 | + TensorRT engine.') |
| 235 | + args = parser.parse_args() |
| 236 | + return args |
| 237 | + |
| 238 | + |
| 239 | +if __name__ == '__main__': |
| 240 | + |
| 241 | + assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.' |
| 242 | + args = parse_args() |
| 243 | + |
| 244 | + if not args.input_img: |
| 245 | + args.input_img = osp.join(osp.dirname(__file__), '../demo/demo.png') |
| 246 | + |
| 247 | + # check arguments |
| 248 | + assert osp.exists(args.config), 'Config {} not found.'.format(args.config) |
| 249 | + assert osp.exists(args.model), \ |
| 250 | + 'ONNX model {} not found.'.format(args.model) |
| 251 | + assert args.workspace_size >= 0, 'Workspace size less than 0.' |
| 252 | + assert DATASETS.get(args.dataset) is not None, \ |
| 253 | + 'Dataset {} does not found.'.format(args.dataset) |
| 254 | + for max_value, min_value in zip(args.max_shape, args.min_shape): |
| 255 | + assert max_value >= min_value, \ |
| 256 | + 'max_shape sould be larger than min shape' |
| 257 | + |
| 258 | + input_config = { |
| 259 | + 'min_shape': args.min_shape, |
| 260 | + 'max_shape': args.max_shape, |
| 261 | + 'input_path': args.input_img |
| 262 | + } |
| 263 | + |
| 264 | + cfg = mmcv.Config.fromfile(args.config) |
| 265 | + onnx2tensorrt( |
| 266 | + args.model, |
| 267 | + args.trt_file, |
| 268 | + cfg, |
| 269 | + input_config, |
| 270 | + fp16=args.fp16, |
| 271 | + verify=args.verify, |
| 272 | + show=args.show, |
| 273 | + dataset=args.dataset, |
| 274 | + workspace_size=args.workspace_size, |
| 275 | + verbose=args.verbose) |
0 commit comments