From da131b75f379a9a8352a487cb133688c5acd90b4 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 9 Apr 2021 11:23:32 +0800 Subject: [PATCH 1/9] add dynamic export and visualize to pytorch2onnx --- mmseg/models/segmentors/encoder_decoder.py | 6 +- mmseg/ops/wrappers.py | 5 +- tools/pytorch2onnx.py | 144 +++++++++++++++++++-- 3 files changed, 137 insertions(+), 18 deletions(-) diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 2284906e3f..43a0be40bc 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -216,9 +216,13 @@ def whole_inference(self, img, img_meta, rescale): seg_logit = self.encode_decode(img, img_meta) if rescale: + if torch.onnx.is_in_onnx_export(): + new_size = img.shape[2:] + else: + new_size = img_meta[0]['ori_shape'][:2] seg_logit = resize( seg_logit, - size=img_meta[0]['ori_shape'][:2], + size=new_size, mode='bilinear', align_corners=self.align_corners, warning=False) diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py index a6d755273d..365948dbc5 100644 --- a/mmseg/ops/wrappers.py +++ b/mmseg/ops/wrappers.py @@ -1,6 +1,5 @@ import warnings -import torch import torch.nn as nn import torch.nn.functional as F @@ -24,8 +23,8 @@ def resize(input, 'the output would more aligned if ' f'input size {(input_h, input_w)} is `x+1` and ' f'out size {(output_h, output_w)} is `nx+1`') - if isinstance(size, torch.Size): - size = tuple(int(x) for x in size) + # if isinstance(size, torch.Size): + # size = tuple(int(x) for x in size) return F.interpolate(input, size, scale_factor, mode, align_corners) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 2ec9feb59a..3653f2258f 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -7,10 +7,14 @@ import torch import torch._C import torch.serialization +from mmcv import DictAction from mmcv.onnx import register_extra_symbolics from mmcv.runner import load_checkpoint from torch import nn +from mmseg.apis import show_result_pyplot +from mmseg.apis.inference import LoadImage +from mmseg.datasets.pipelines import Compose from mmseg.models import build_segmentor torch.manual_seed(3) @@ -67,19 +71,54 @@ def _demo_mm_inputs(input_shape, num_classes): return mm_inputs +def _prepare_input_img(img_path, test_pipeline, shape=None): + # build the data pipeline + if shape is not None: + test_pipeline[1]['img_scale'] = shape + test_pipeline = [LoadImage()] + test_pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img_path) + data = test_pipeline(data) + imgs = data['img'] + img_metas = [i.data for i in data['img_metas']] + + mm_inputs = {'imgs': imgs, 'img_metas': img_metas} + + return mm_inputs + + +def _update_input_img(img_list, img_meta_list): + N, C, H, W = img_list[0].shape + img_meta = img_meta_list[0][0] + new_img_meta_list = [[ + { + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': img_meta['filename'], + # 'scale_factor': img_meta['scale_factor'] * 0 + 1, + 'scale_factor': 1., + 'flip': False, + } for _ in range(N) + ]] + + return img_list, new_img_meta_list + + def pytorch2onnx(model, - input_shape, + mm_inputs, opset_version=11, show=False, output_file='tmp.onnx', - verify=False): + verify=False, + dynamic_export=None): """Export Pytorch model to ONNX model and verify the outputs are same between Pytorch and ONNX. Args: model (nn.Module): Pytorch model we want to export. - input_shape (tuple): Use this input shape to construct - the corresponding dummy input and execute the model. + mm_inputs (dict): Contain the input tensors and img_metas infomation. opset_version (int): The onnx op version. Default: 11. show (bool): Whether print the computation graph. Default: False. output_file (string): The path to where we store the output ONNX model. @@ -94,28 +133,33 @@ def pytorch2onnx(model, else: num_classes = model.decode_head.num_classes - mm_inputs = _demo_mm_inputs(input_shape, num_classes) - imgs = mm_inputs.pop('imgs') img_metas = mm_inputs.pop('img_metas') + ori_shape = img_metas[0]['ori_shape'] img_list = [img[None, :] for img in imgs] img_meta_list = [[img_meta] for img_meta in img_metas] + img_list, img_meta_list = _update_input_img(img_list, img_meta_list) # replace original forward function origin_forward = model.forward model.forward = partial( model.forward, img_metas=img_meta_list, return_loss=False) + if dynamic_export: + dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}} + register_extra_symbolics(opset_version) with torch.no_grad(): torch.onnx.export( model, (img_list, ), output_file, + input_names=['input'], export_params=True, - keep_initializers_as_inputs=True, + keep_initializers_as_inputs=False, verbose=show, - opset_version=opset_version) + opset_version=opset_version, + dynamic_axes=dynamic_axes) print(f'Successfully exported ONNX model: {output_file}') model.forward = origin_forward @@ -125,9 +169,28 @@ def pytorch2onnx(model, onnx_model = onnx.load(output_file) onnx.checker.check_model(onnx_model) + if dynamic_export: + # scale image for dynamic shape test + img_list = [ + nn.functional.interpolate(_, scale_factor=1.5) + for _ in img_list + ] + # concate flip image for batch test + flip_img_list = [_.flip(-1) for _ in img_list] + img_list = [ + torch.cat((ori_img, flip_img), 0) + for ori_img, flip_img in zip(img_list, flip_img_list) + ] + + # update img_meta + img_list, img_meta_list = _update_input_img( + img_list, img_meta_list) + # check the numerical value # get pytorch output - pytorch_result = model(img_list, img_meta_list, return_loss=False)[0] + with torch.no_grad(): + pytorch_result = model(img_list, img_meta_list, return_loss=False) + pytorch_result = np.stack(pytorch_result, 0) # get onnx output input_all = [node.name for node in onnx_model.graph.input] @@ -138,17 +201,39 @@ def pytorch2onnx(model, assert (len(net_feed_input) == 1) sess = rt.InferenceSession(output_file) onnx_result = sess.run( - None, {net_feed_input[0]: img_list[0].detach().numpy()})[0] - if not np.allclose(pytorch_result, onnx_result): + None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0] + if not np.allclose( + pytorch_result / num_classes, + onnx_result / num_classes, + rtol=1e-5, + atol=1): raise ValueError( 'The outputs are different between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX') + if show: + import cv2 + import os.path as osp + img = img_meta_list[0][0]['filename'] + if not osp.exists(img): + img = imgs[0][:3, ...].permute(1, 2, 0) * 255 + img = img.detach().numpy().astype(np.uint8) + # resize onnx_result to ori_shape + onnx_result = cv2.resize(onnx_result[0].astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, (onnx_result, ), + palette=model.PALETTE, + opacity=0.5) + def parse_args(): parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX') parser.add_argument('config', help='test config file path') parser.add_argument('--checkpoint', help='checkpoint file', default=None) + parser.add_argument( + '--input-img', type=str, help='Images for input', default=None) parser.add_argument('--show', action='store_true', help='show onnx graph') parser.add_argument( '--verify', action='store_true', help='verify the onnx model') @@ -160,6 +245,20 @@ def parse_args(): nargs='+', default=[256, 256], help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--dynamic-export', + action='store_true', + help='Wether to export onnx with dynamic axis.') args = parser.parse_args() return args @@ -178,6 +277,8 @@ def parse_args(): raise ValueError('invalid input shape') cfg = mmcv.Config.fromfile(args.config) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) cfg.model.pretrained = None # build the model and load checkpoint @@ -188,13 +289,28 @@ def parse_args(): segmentor = _convert_batchnorm(segmentor) if args.checkpoint: - load_checkpoint(segmentor, args.checkpoint, map_location='cpu') + checkpoint = load_checkpoint( + segmentor, args.checkpoint, map_location='cpu') + segmentor.CLASSES = checkpoint['meta']['CLASSES'] + segmentor.PALETTE = checkpoint['meta']['PALETTE'] + + # read input or create dummpy input + if args.input_img is not None: + mm_inputs = _prepare_input_img(args.input_img, cfg.data.test.pipeline, + (input_shape[3], input_shape[2])) + else: + if isinstance(segmentor.decode_head, nn.ModuleList): + num_classes = segmentor.decode_head[-1].num_classes + else: + num_classes = segmentor.decode_head.num_classes + mm_inputs = _demo_mm_inputs(input_shape, num_classes) # conver model to onnx file pytorch2onnx( segmentor, - input_shape, + mm_inputs, opset_version=args.opset_version, show=args.show, output_file=args.output_file, - verify=args.verify) + verify=args.verify, + dynamic_export=args.dynamic_export) From 78d4b94b4561f4c6760a4fdcb41f1f304bac1c62 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 9 Apr 2021 13:31:03 +0800 Subject: [PATCH 2/9] update document --- docs/useful_tools.md | 25 +++++++++++++++++++++++-- mmseg/ops/wrappers.py | 2 -- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index 7b2e3fde1e..89665c3419 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -46,10 +46,31 @@ The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pt We provide a script to convert model to [ONNX](https://github.com/onnx/onnx) format. The converted model could be visualized by tools like [Netron](https://github.com/lutzroeder/netron). Besides, we also support comparing the output results between Pytorch and ONNX model. -```shell -python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE} --verify] +```bash +python tools/pytorch2onnx.py \ + ${CONFIG_FILE} \ + --checkpoint ${CHECKPOINT_FILE} \ + --output-file ${ONNX_FILE} \ + --input-img ${INPUT_IMG} \ + --shape ${INPUT_SHAPE} \ + --show \ + --verify \ + --dynamic-export \ + --cfg-options \ + model.test_cfg.mode="whole" ``` +Description of arguments: +- `config` : The path of a model config file. +- `--checkpoint` : The path of a model checkpoint file. +- `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`. +- `--input-img` : The path of an input image for conversion and visualize. +- `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`. +- `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`. +- `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`. +- `dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`. +- `cfg-options`:Update config options. + **Note**: This tool is still experimental. Some customized operators are not supported for now. ## Miscellaneous diff --git a/mmseg/ops/wrappers.py b/mmseg/ops/wrappers.py index 365948dbc5..0ed9a0cb8d 100644 --- a/mmseg/ops/wrappers.py +++ b/mmseg/ops/wrappers.py @@ -23,8 +23,6 @@ def resize(input, 'the output would more aligned if ' f'input size {(input_h, input_w)} is `x+1` and ' f'out size {(output_h, output_w)} is `nx+1`') - # if isinstance(size, torch.Size): - # size = tuple(int(x) for x in size) return F.interpolate(input, size, scale_factor, mode, align_corners) From 652d54ad0b98070e26e7b654aae2cd52a10dbb3f Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Fri, 9 Apr 2021 13:32:22 +0800 Subject: [PATCH 3/9] fix lint --- docs/useful_tools.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index 89665c3419..e79f5d2b76 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -61,6 +61,7 @@ python tools/pytorch2onnx.py \ ``` Description of arguments: + - `config` : The path of a model config file. - `--checkpoint` : The path of a model checkpoint file. - `--output-file`: The path of output ONNX model. If not specified, it will be set to `tmp.onnx`. From a22214ef70149a6a3463a6e70de3fcbee36f9918 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 9 Apr 2021 19:47:57 +0800 Subject: [PATCH 4/9] fix dynamic error and add visualization --- mmseg/apis/inference.py | 12 +++++++++-- tools/pytorch2onnx.py | 44 +++++++++++++++++++++++++++++------------ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index 9052cdd32a..d8b8c5d854 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -103,7 +103,9 @@ def show_result_pyplot(model, result, palette=None, fig_size=(15, 10), - opacity=0.5): + opacity=0.5, + title='', + block=True): """Visualize the segmentation results on the image. Args: @@ -117,6 +119,10 @@ def show_result_pyplot(model, opacity(float): Opacity of painted segmentation map. Default 0.5. Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + block (bool): Whether to block the pyplot figure. + Default is False. """ if hasattr(model, 'module'): model = model.module @@ -124,4 +130,6 @@ def show_result_pyplot(model, img, result, palette=palette, show=False, opacity=opacity) plt.figure(figsize=fig_size) plt.imshow(mmcv.bgr2rgb(img)) - plt.show() + plt.title(title) + plt.tight_layout() + plt.show(block=block) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 3653f2258f..f163f70c99 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -75,6 +75,7 @@ def _prepare_input_img(img_path, test_pipeline, shape=None): # build the data pipeline if shape is not None: test_pipeline[1]['img_scale'] = shape + test_pipeline[1]['transforms'][0]['keep_ratio'] = False test_pipeline = [LoadImage()] + test_pipeline[1:] test_pipeline = Compose(test_pipeline) # prepare data @@ -145,9 +146,10 @@ def pytorch2onnx(model, origin_forward = model.forward model.forward = partial( model.forward, img_metas=img_meta_list, return_loss=False) - + dynamic_axes = None if dynamic_export: - dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}} + dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}, + 'output': {1: 'batch', 2: 'height', 3: 'width'}} register_extra_symbolics(opset_version) with torch.no_grad(): @@ -155,6 +157,7 @@ def pytorch2onnx(model, model, (img_list, ), output_file, input_names=['input'], + output_names=['output'], export_params=True, keep_initializers_as_inputs=False, verbose=show, @@ -202,15 +205,7 @@ def pytorch2onnx(model, sess = rt.InferenceSession(output_file) onnx_result = sess.run( None, {net_feed_input[0]: img_list[0].detach().numpy()})[0][0] - if not np.allclose( - pytorch_result / num_classes, - onnx_result / num_classes, - rtol=1e-5, - atol=1): - raise ValueError( - 'The outputs are different between Pytorch and ONNX') - print('The outputs are same between Pytorch and ONNX') - + # show segmentation results if show: import cv2 import os.path as osp @@ -219,13 +214,36 @@ def pytorch2onnx(model, img = imgs[0][:3, ...].permute(1, 2, 0) * 255 img = img.detach().numpy().astype(np.uint8) # resize onnx_result to ori_shape - onnx_result = cv2.resize(onnx_result[0].astype(np.uint8), + onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8), + (ori_shape[1], ori_shape[0])) + show_result_pyplot( + model, + img, + (onnx_result_, ), + palette=model.PALETTE, + block=False, + title='ONNXRuntime', + opacity=0.5) + + # resize pytorch_result to ori_shape + pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8), (ori_shape[1], ori_shape[0])) show_result_pyplot( model, - img, (onnx_result, ), + img, + (pytorch_result_, ), + title='PyTorch', palette=model.PALETTE, opacity=0.5) + # compare results + np.testing.assert_allclose( + pytorch_result.astype(np.float32)/num_classes, + onnx_result.astype(np.float32)/num_classes, + rtol=1e-5, + atol=1e-5, + err_msg='The outputs are different between Pytorch and ONNX' + ) + print('The outputs are same between Pytorch and ONNX') def parse_args(): From 6f272b4d972b76c12e601ee8c232a29e272c346f Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 9 Apr 2021 20:10:01 +0800 Subject: [PATCH 5/9] fix lint --- mmseg/models/segmentors/encoder_decoder.py | 1 + tools/pytorch2onnx.py | 58 ++++++++++++---------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 43a0be40bc..16593a25a0 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -216,6 +216,7 @@ def whole_inference(self, img, img_meta, rescale): seg_logit = self.encode_decode(img, img_meta) if rescale: + # support dynamic shape for onnx if torch.onnx.is_in_onnx_export(): new_size = img.shape[2:] else: diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index f163f70c99..eed9269e66 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -92,17 +92,14 @@ def _prepare_input_img(img_path, test_pipeline, shape=None): def _update_input_img(img_list, img_meta_list): N, C, H, W = img_list[0].shape img_meta = img_meta_list[0][0] - new_img_meta_list = [[ - { - 'img_shape': (H, W, C), - 'ori_shape': (H, W, C), - 'pad_shape': (H, W, C), - 'filename': img_meta['filename'], - # 'scale_factor': img_meta['scale_factor'] * 0 + 1, - 'scale_factor': 1., - 'flip': False, - } for _ in range(N) - ]] + new_img_meta_list = [[{ + 'img_shape': (H, W, C), + 'ori_shape': (H, W, C), + 'pad_shape': (H, W, C), + 'filename': img_meta['filename'], + 'scale_factor': 1., + 'flip': False, + } for _ in range(N)]] return img_list, new_img_meta_list @@ -113,7 +110,7 @@ def pytorch2onnx(model, show=False, output_file='tmp.onnx', verify=False, - dynamic_export=None): + dynamic_export=False): """Export Pytorch model to ONNX model and verify the outputs are same between Pytorch and ONNX. @@ -126,6 +123,8 @@ def pytorch2onnx(model, Default: `tmp.onnx`. verify (bool): Whether compare the outputs between Pytorch and ONNX. Default: False. + dynamic_export (bool): Whether to export ONNX with dynamic axis. + Default: False. """ model.cpu().eval() @@ -148,8 +147,18 @@ def pytorch2onnx(model, model.forward, img_metas=img_meta_list, return_loss=False) dynamic_axes = None if dynamic_export: - dynamic_axes = {'input': {0: 'batch', 2: 'height', 3: 'width'}, - 'output': {1: 'batch', 2: 'height', 3: 'width'}} + dynamic_axes = { + 'input': { + 0: 'batch', + 2: 'height', + 3: 'width' + }, + 'output': { + 1: 'batch', + 2: 'height', + 3: 'width' + } + } register_extra_symbolics(opset_version) with torch.no_grad(): @@ -215,11 +224,10 @@ def pytorch2onnx(model, img = img.detach().numpy().astype(np.uint8) # resize onnx_result to ori_shape onnx_result_ = cv2.resize(onnx_result[0].astype(np.uint8), - (ori_shape[1], ori_shape[0])) + (ori_shape[1], ori_shape[0])) show_result_pyplot( model, - img, - (onnx_result_, ), + img, (onnx_result_, ), palette=model.PALETTE, block=False, title='ONNXRuntime', @@ -227,22 +235,20 @@ def pytorch2onnx(model, # resize pytorch_result to ori_shape pytorch_result_ = cv2.resize(pytorch_result[0].astype(np.uint8), - (ori_shape[1], ori_shape[0])) + (ori_shape[1], ori_shape[0])) show_result_pyplot( model, - img, - (pytorch_result_, ), + img, (pytorch_result_, ), title='PyTorch', palette=model.PALETTE, opacity=0.5) # compare results np.testing.assert_allclose( - pytorch_result.astype(np.float32)/num_classes, - onnx_result.astype(np.float32)/num_classes, - rtol=1e-5, - atol=1e-5, - err_msg='The outputs are different between Pytorch and ONNX' - ) + pytorch_result.astype(np.float32) / num_classes, + onnx_result.astype(np.float32) / num_classes, + rtol=1e-5, + atol=1e-5, + err_msg='The outputs are different between Pytorch and ONNX') print('The outputs are same between Pytorch and ONNX') From d0df8ed5759792ede30f5d55af72e336dfbabf91 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 12 Apr 2021 10:58:59 +0800 Subject: [PATCH 6/9] update docstring --- mmseg/apis/inference.py | 2 +- mmseg/models/segmentors/encoder_decoder.py | 6 +++--- tools/pytorch2onnx.py | 9 +++++---- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/mmseg/apis/inference.py b/mmseg/apis/inference.py index d8b8c5d854..bf875cb262 100644 --- a/mmseg/apis/inference.py +++ b/mmseg/apis/inference.py @@ -122,7 +122,7 @@ def show_result_pyplot(model, title (str): The title of pyplot figure. Default is ''. block (bool): Whether to block the pyplot figure. - Default is False. + Default is True. """ if hasattr(model, 'module'): model = model.module diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 16593a25a0..b2d067dcbe 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -218,12 +218,12 @@ def whole_inference(self, img, img_meta, rescale): if rescale: # support dynamic shape for onnx if torch.onnx.is_in_onnx_export(): - new_size = img.shape[2:] + size = img.shape[2:] else: - new_size = img_meta[0]['ori_shape'][:2] + size = img_meta[0]['ori_shape'][:2] seg_logit = resize( seg_logit, - size=new_size, + size=size, mode='bilinear', align_corners=self.align_corners, warning=False) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index eed9269e66..bda8567db1 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -90,6 +90,7 @@ def _prepare_input_img(img_path, test_pipeline, shape=None): def _update_input_img(img_list, img_meta_list): + # update img and its meta list N, C, H, W = img_list[0].shape img_meta = img_meta_list[0][0] new_img_meta_list = [[{ @@ -116,7 +117,7 @@ def pytorch2onnx(model, Args: model (nn.Module): Pytorch model we want to export. - mm_inputs (dict): Contain the input tensors and img_metas infomation. + mm_inputs (dict): Contain the input tensors and img_metas information. opset_version (int): The onnx op version. Default: 11. show (bool): Whether print the computation graph. Default: False. output_file (string): The path to where we store the output ONNX model. @@ -273,7 +274,7 @@ def parse_args(): '--cfg-options', nargs='+', action=DictAction, - help='override some settings in the used config, the key-value pair ' + help='Override some settings in the used config, the key-value pair ' 'in xxx=yyy format will be merged into config file. If the value to ' 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' @@ -282,7 +283,7 @@ def parse_args(): parser.add_argument( '--dynamic-export', action='store_true', - help='Wether to export onnx with dynamic axis.') + help='Whether to export onnx with dynamic axis.') args = parser.parse_args() return args @@ -329,7 +330,7 @@ def parse_args(): num_classes = segmentor.decode_head.num_classes mm_inputs = _demo_mm_inputs(input_shape, num_classes) - # conver model to onnx file + # convert model to onnx file pytorch2onnx( segmentor, mm_inputs, From cad73f5373ed0c5c89a8c9b8ac2d02e132ce19ee Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 12 Apr 2021 11:05:29 +0800 Subject: [PATCH 7/9] update doc --- docs/useful_tools.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/useful_tools.md b/docs/useful_tools.md index e79f5d2b76..8286af83e5 100644 --- a/docs/useful_tools.md +++ b/docs/useful_tools.md @@ -69,8 +69,8 @@ Description of arguments: - `--shape`: The height and width of input tensor to the model. If not specified, it will be set to `256 256`. - `--show`: Determines whether to print the architecture of the exported model. If not specified, it will be set to `False`. - `--verify`: Determines whether to verify the correctness of an exported model. If not specified, it will be set to `False`. -- `dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`. -- `cfg-options`:Update config options. +- `--dynamic-export`: Determines whether to export ONNX model with dynamic input and output shapes. If not specified, it will be set to `False`. +- `--cfg-options`:Update config options. **Note**: This tool is still experimental. Some customized operators are not supported for now. From 64757fcb83b36e4a69fc98b98214659a539bef59 Mon Sep 17 00:00:00 2001 From: RunningLeon Date: Mon, 12 Apr 2021 11:07:30 +0800 Subject: [PATCH 8/9] Update help info for --show Co-authored-by: Jerry Jiarui XU --- tools/pytorch2onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index bda8567db1..578b48fa8c 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -259,7 +259,7 @@ def parse_args(): parser.add_argument('--checkpoint', help='checkpoint file', default=None) parser.add_argument( '--input-img', type=str, help='Images for input', default=None) - parser.add_argument('--show', action='store_true', help='show onnx graph') + parser.add_argument('--show', action='store_true', help='show onnx graph and segmentation results') parser.add_argument( '--verify', action='store_true', help='verify the onnx model') parser.add_argument('--output-file', type=str, default='tmp.onnx') From 5ff17c707b126ea8f4396a7c6eed5eb6822c2f7e Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 12 Apr 2021 21:18:26 +0800 Subject: [PATCH 9/9] fix lint --- tools/pytorch2onnx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py index 578b48fa8c..71f1bb7227 100644 --- a/tools/pytorch2onnx.py +++ b/tools/pytorch2onnx.py @@ -259,7 +259,10 @@ def parse_args(): parser.add_argument('--checkpoint', help='checkpoint file', default=None) parser.add_argument( '--input-img', type=str, help='Images for input', default=None) - parser.add_argument('--show', action='store_true', help='show onnx graph and segmentation results') + parser.add_argument( + '--show', + action='store_true', + help='show onnx graph and segmentation results') parser.add_argument( '--verify', action='store_true', help='verify the onnx model') parser.add_argument('--output-file', type=str, default='tmp.onnx')