From bc2dc1277a90c64b49cb8274cd71821734702d88 Mon Sep 17 00:00:00 2001
From: "q.yao"
Date: Tue, 13 Apr 2021 02:54:59 +0800
Subject: [PATCH] add dynamic export and visualize to pytorch2onnx (#463)

* add dynamic export and visualize to pytorch2onnx

* update document

* fix lint

* fix dynamic error and add visualization

* fix lint

* update docstring

* update doc

* Update help info for --show

Co-authored-by: Jerry Jiarui XU

* fix lint

Co-authored-by: maningsheng
Co-authored-by: Jerry Jiarui XU
---
 docs/useful_tools.md                       |  26 ++-
 mmseg/apis/inference.py                    |  12 +-
 mmseg/models/segmentors/encoder_decoder.py |   7 +-
 mmseg/ops/wrappers.py                      |   3 -
 tools/pytorch2onnx.py                      | 180 ++++++++++++++++++---
 5 files changed, 202 insertions(+), 26 deletions(-)

diff --git a/docs/useful_tools.md b/docs/useful_tools.md
index 7b2e3fde1e..8286af83e5 100644
--- a/docs/useful_tools.md
+++ b/docs/useful_tools.md
@@ -46,10 +46,32 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt
 
 We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model.
 
-```shell
-python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify]
+```bash
+python tools/pytorch2onnx.py \
+    ${CONFIG_FILE} \
+    --checkpoint ${CHECKPOINT_FILE} \
+    --output-file ${ONNX_FILE} \
+    --input-img ${INPUT_IMG} \
+    --shape ${INPUT_SHAPE} \
+    --show \
+    --verify \
+    --dynamic-export \
+    --cfg-options \
+      model.test_cfg.mode="whole"
 ```
+
+Description of arguments:
+
+- `config`: The path of a model config file.
+- `--checkpoint`: The path of a model checkpoint file.
+- `--output-file`: The path of the output ONNX model. If not specified, it will be set to `tmp.onnx`.
+- `--input-img`: The path of an input image used for conversion and visualization.
+- `--shape`: The height and width of the input tensor to the model. If not specified, it will be set to `256 256`.
+- `--show`: Determines whether to print the architecture of the exported model and, together with `--verify`, show the segmentation results. If not specified, it will be set to `False`.
+- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`.
+- `--dynamic-export`: Determines whether to export the ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`.
+- `--cfg-options`: Override some settings in the used config file; key-value pairs in `xxx=yyy` format will be merged into the config.
 
 **Note**: This tool is still experimental. Some customized operators are not supported for now.
 
 ## Miscellaneous
diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py
index 9052cdd32a..bf875cb262 100644
--- a/mmseg/apis/inference.py
+++ b/mmseg/apis/inference.py
@@ -103,7 +103,9 @@ def show_result_pyplot(model,
                        result,
                        palette=None,
                        fig_size=(15, 10),
-                       opacity=0.5):
+                       opacity=0.5,
+                       title='',
+                       block=True):
     """Visualize the segmentation results on the image.
 
     Args:
@@ -117,6 +119,10 @@ def show_result_pyplot(model,
         opacity(float): Opacity of painted segmentation map.
             Default 0.5.
             Must be in (0, 1] range.
+        title (str): The title of pyplot figure.
+            Default is ''.
+        block (bool): Whether to block the pyplot figure.
+            Default is True.
""" if hasattr(model, 'module'): model = model.module @@ -124,4 +130,6 @@ def show_result_pyplot(model, img, result, palette=palette, show=False, opacity=opacity) plt.figure(figsize=fig_size) plt.imshow(mmcv.bgr2rgb(img)) - plt.show() + plt.title(title) + plt.tight_layout() + plt.show(block=block) diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 2284906e3f..b2d067dcbe 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -216,9 +216,14 @@ def whole_inference(self, img, img_meta, rescale): seg_logit = self.encode_decode(img, img_meta) if rescale: + # support dynamic shape for onnx + if torch.onnx.is_in_onnx_export(): + size = img.shape[2:] + else: + size = img_meta[0]['ori_shape'][:2] seg_logit = resize( seg_logit, - size=img_meta[0]['ori_shape'][:2], + size=size, mode='bilinear', align_corners=self.align_corners, warning=False) diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py index a6d755273d..0ed9a0cb8d 100644 --- a/mmseg/ops/wrappers.py +++ b/mmseg/ops/wrappers.py @@ -1,6 +1,5 @@ import warnings -import torch import torch.nn as nn import torch.nn.functional as F @@ -24,8 +23,6 @@ def resize(input, 'the output would more aligned if ' f'input size {(input_h, input_w)} is `x+1` and ' f'out size {(output_h, output_w)} is `nx+1`') - if isinstance(size, torch.Size): - size = tuple(int(x) for x in size) return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 2ec9feb59a..71f1bb7227 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -7,10 +7,14 @@ import torch import torch._C import torch.serialization +from mmcv import DictAction from mmcv.onnx import register_extra_symbolics from mmcv.runner import load_checkpoint from torch import nn +from mmseg.apis import show_result_pyplot +from mmseg.apis.inference import LoadImage +from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor torch.manual_seed(3) @@ -67,25 +71,61 @@ def _demo_mm_inputs(input_shape, num_classes): return mm_inputs +def _prepare_input_img(img_path, test_pipeline, shape=None): + # build the data pipeline + if shape is not None: + test_pipeline[1]['img_scale'] = shape + test_pipeline[1]['transforms'][0]['keep_ratio'] = False + test_pipeline = [LoadImage()] + test_pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img_path) + data = test_pipeline(data) + imgs = data['img'] + img_metas = [i.data for i in data['img_metas']] + + mm_inputs = {'imgs': imgs, 'img_metas': img_metas} + + return mm_inputs + + +def _update_input_img(img_list, img_meta_list): + # update img and its meta list + N, C, H, W = img_list[0].shape + img_meta = img_meta_list[0][0] + new_img_meta_list = [[{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': img_meta['filename'], + 'scale_factor': 1., + 'flip': False, + } for _ in range(N)]] + + return img_list, new_img_meta_list + + def pytorch2onnx(model, - input_shape, + mm_inputs, opset_version=11, show=False, output_file='tmp.onnx', - verify=False): + verify=False, + dynamic_export=False): """Export Pytorch model to ONNX model and verify the outputs are same between Pytorch and ONNX. Args: model (nn.Module): Pytorch model we want to export. - input_shape (tuple): Use this input shape to construct - the corresponding dummy input and execute the model. 
+        mm_inputs (dict): Contains the input tensors and img_metas
+            information.
         opset_version (int): The onnx op version. Default: 11.
         show (bool): Whether print the computation graph. Default: False.
         output_file (string): The path to where we store the output ONNX model.
             Default: `tmp.onnx`.
         verify (bool): Whether compare the outputs between Pytorch and ONNX.
             Default: False.
+        dynamic_export (bool): Whether to export ONNX with dynamic axes.
+            Default: False.
     """
     model.cpu().eval()
 
@@ -94,28 +134,45 @@ def pytorch2onnx(model,
     else:
         num_classes = model.decode_head.num_classes
 
-    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
-
     imgs = mm_inputs.pop('imgs')
     img_metas = mm_inputs.pop('img_metas')
+    ori_shape = img_metas[0]['ori_shape']
 
     img_list = [img[None, :] for img in imgs]
     img_meta_list = [[img_meta] for img_meta in img_metas]
+    img_list, img_meta_list = _update_input_img(img_list, img_meta_list)
 
     # replace original forward function
     origin_forward = model.forward
     model.forward = partial(
         model.forward, img_metas=img_meta_list, return_loss=False)
+    dynamic_axes = None
+    if dynamic_export:
+        dynamic_axes = {
+            'input': {
+                0: 'batch',
+                2: 'height',
+                3: 'width'
+            },
+            'output': {
+                1: 'batch',
+                2: 'height',
+                3: 'width'
+            }
+        }
 
     register_extra_symbolics(opset_version)
     with torch.no_grad():
         torch.onnx.export(
             model, (img_list, ),
             output_file,
+            input_names=['input'],
+            output_names=['output'],
             export_params=True,
-            keep_initializers_as_inputs=True,
+            keep_initializers_as_inputs=False,
             verbose=show,
-            opset_version=opset_version)
+            opset_version=opset_version,
+            dynamic_axes=dynamic_axes)
         print(f'Successfully exported ONNX model: {output_file}')
     model.forward = origin_forward
 
@@ -125,9 +182,28 @@ def pytorch2onnx(model,
         onnx_model = onnx.load(output_file)
         onnx.checker.check_model(onnx_model)
 
+        if dynamic_export:
+            # scale image for dynamic shape test
+            img_list = [
+                nn.functional.interpolate(_, scale_factor=1.5)
+                for _ in img_list
+            ]
+            # concatenate a flipped image for batch test
+            flip_img_list = [_.flip(-1) for _ in img_list]
+            img_list = [
+                torch.cat((ori_img, flip_img), 0)
+                for ori_img, flip_img in zip(img_list, flip_img_list)
+            ]
+
+            # update img_meta
+            img_list, img_meta_list = _update_input_img(
+                img_list, img_meta_list)
+
         # check the numerical value
         # get pytorch output
-        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
+        with torch.no_grad():
+            pytorch_result = model(img_list, img_meta_list, return_loss=False)
+            pytorch_result = np.stack(pytorch_result, 0)
 
         # get onnx output
         input_all = [node.name for node in onnx_model.graph.input]
@@ -138,10 +214,42 @@ def pytorch2onnx(model,
         assert (len(net_feed_input) == 1)
         sess = rt.InferenceSession(output_file)
         onnx_result = sess.run(
-            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
-        if not np.allclose(pytorch_result, onnx_result):
-            raise ValueError(
-                'The outputs are different between Pytorch and ONNX')
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0]
+        # show segmentation results
+        if show:
+            import cv2
+            import os.path as osp
+            img = img_meta_list[0][0]['filename']
+            if not osp.exists(img):
+                img = imgs[0][:3, ...].permute(1, 2, 0) * 255
+                img = img.detach().numpy().astype(np.uint8)
+            # resize onnx_result to ori_shape
+            onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8),
+                                      (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (onnx_result_, ),
+                palette=model.PALETTE,
+                block=False,
+                title='ONNXRuntime',
+                opacity=0.5)
+
+            # resize pytorch_result to ori_shape
+            pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8),
+                                         (ori_shape[1], ori_shape[0]))
+            show_result_pyplot(
+                model,
+                img, (pytorch_result_, ),
+                title='PyTorch',
+                palette=model.PALETTE,
+                opacity=0.5)
+        # compare results
+        np.testing.assert_allclose(
+            pytorch_result.astype(np.float32) / num_classes,
+            onnx_result.astype(np.float32) / num_classes,
+            rtol=1e-5,
+            atol=1e-5,
+            err_msg='The outputs are different between Pytorch and ONNX')
         print('The outputs are same between Pytorch and ONNX')
 
 
@@ -149,7 +257,12 @@ def parse_args():
     parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
     parser.add_argument('config', help='test config file path')
    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
-    parser.add_argument('--show', action='store_true', help='show onnx graph')
+    parser.add_argument(
+        '--input-img', type=str, help='Image for input', default=None)
+    parser.add_argument(
+        '--show',
+        action='store_true',
+        help='show onnx graph and segmentation results')
     parser.add_argument(
         '--verify', action='store_true', help='verify the onnx model')
     parser.add_argument('--output-file', type=str, default='tmp.onnx')
@@ -160,6 +273,20 @@ def parse_args():
         nargs='+',
         default=[256, 256],
         help='input image size')
+    parser.add_argument(
+        '--cfg-options',
+        nargs='+',
+        action=DictAction,
+        help='Override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file. If the value '
+        'to be overwritten is a list, it should be like key="[a,b]" or '
+        'key=a,b. It also allows nested list/tuple values, e.g. '
+        'key="[(a,b),(c,d)]". Note that the quotation marks are necessary '
+        'and that no white space is allowed.')
+    parser.add_argument(
+        '--dynamic-export',
+        action='store_true',
+        help='Whether to export ONNX with dynamic axes.')
     args = parser.parse_args()
     return args
 
@@ -178,6 +305,8 @@ def parse_args():
         raise ValueError('invalid input shape')
 
     cfg = mmcv.Config.fromfile(args.config)
+    if args.cfg_options is not None:
+        cfg.merge_from_dict(args.cfg_options)
     cfg.model.pretrained = None
 
     # build the model and load checkpoint
@@ -188,13 +317,28 @@ def parse_args():
     segmentor = _convert_batchnorm(segmentor)
 
     if args.checkpoint:
-        load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
+        checkpoint = load_checkpoint(
+            segmentor, args.checkpoint, map_location='cpu')
+        segmentor.CLASSES = checkpoint['meta']['CLASSES']
+        segmentor.PALETTE = checkpoint['meta']['PALETTE']
+
+    # read the input image or create a dummy input
+    if args.input_img is not None:
+        mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline,
+                                       (input_shape[3], input_shape[2]))
+    else:
+        if isinstance(segmentor.decode_head, nn.ModuleList):
+            num_classes = segmentor.decode_head[-1].num_classes
+        else:
+            num_classes = segmentor.decode_head.num_classes
+        mm_inputs = _demo_mm_inputs(input_shape, num_classes)
 
-    # conver model to onnx file
+    # convert model to onnx file
     pytorch2onnx(
        segmentor,
-        input_shape,
+        mm_inputs,
         opset_version=args.opset_version,
         show=args.show,
         output_file=args.output_file,
-        verify=args.verify,
+        verify=args.verify,
+        dynamic_export=args.dynamic_export)
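
After applying this patch, a quick way to confirm that `--dynamic-export` really produced a shape-agnostic graph is to run the exported file through ONNX Runtime at two different input sizes. The following is a minimal sketch, not part of the patch, assuming the script's defaults: the model was exported to `tmp.onnx` with `--dynamic-export`, so the graph I/O names are `input` and `output` as set above.

import numpy as np
import onnxruntime as rt

# one session serves every input size once the graph has dynamic axes
sess = rt.InferenceSession('tmp.onnx')

for h, w in [(256, 256), (384, 512)]:
    # random NCHW float32 batch; a real check would first apply the test
    # pipeline's resize/normalize steps to an actual image
    dummy = np.random.rand(1, 3, h, w).astype(np.float32)
    # with a statically exported graph, the second shape would raise a
    # shape-mismatch error here
    seg = sess.run(['output'], {'input': dummy})[0]
    print(f'input {(h, w)} -> output shape {seg.shape}')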