diff --git a/docs/getting_started.md b/docs/getting_started.md
index 9140435cf2..bf4980218e 100644
--- a/docs/getting_started.md
+++ b/docs/getting_started.md
@@ -332,3 +332,18 @@ python tools/publish_model.py work_dirs/pspnet/latest.pth psp_r50_hszhao_200ep.p
 ```
 
 The final output filename will be `psp_r50_512x1024_40ki_cityscapes-{hash id}.pth`.
+
+### Convert to ONNX (experimental)
+
+We provide a script to convert models to the [ONNX](https://github.com/onnx/onnx) format. The converted model can be visualized with tools such as [Netron](https://github.com/lutzroeder/netron). We also support comparing the outputs of the PyTorch and ONNX models.
+
+```shell
+python tools/pytorch2onnx.py ${CONFIG_FILE} --checkpoint ${CHECKPOINT_FILE} --output-file ${ONNX_FILE} [--shape ${INPUT_SHAPE}] [--verify]
+```
+
+**Note**: This tool is still experimental. Some custom operators are not supported yet.
+
+## Tutorials
+
+Currently, we provide four tutorials: [adding a new dataset](tutorials/new_dataset.md), [designing data pipelines](tutorials/data_pipeline.md), [adding new modules](tutorials/new_modules.md), and [using training tricks](tutorials/training_tricks.md).
+We also provide a full description of the [config system](config.md).
diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py
index d3ce17adbb..d1709e0ca3 100644
--- a/mmseg/models/segmentors/encoder_decoder.py
+++ b/mmseg/models/segmentors/encoder_decoder.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -171,6 +172,8 @@ def slide_inference(self, img, img_meta, rescale):
         h_stride, w_stride = self.test_cfg.stride
         h_crop, w_crop = self.test_cfg.crop_size
         batch_size, _, h_img, w_img = img.size()
+        assert h_crop <= h_img and w_crop <= w_img, (
+            'crop size should not be greater than image size')
         num_classes = self.num_classes
         h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
         w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
@@ -185,14 +188,15 @@ def slide_inference(self, img, img_meta, rescale):
                 y1 = max(y2 - h_crop, 0)
                 x1 = max(x2 - w_crop, 0)
                 crop_img = img[:, :, y1:y2, x1:x2]
-                pad_img = crop_img.new_zeros(
-                    (crop_img.size(0), crop_img.size(1), h_crop, w_crop))
-                pad_img[:, :, :y2 - y1, :x2 - x1] = crop_img
-                pad_seg_logit = self.encode_decode(pad_img, img_meta)
-                preds[:, :, y1:y2,
-                      x1:x2] += pad_seg_logit[:, :, :y2 - y1, :x2 - x1]
+                crop_seg_logit = self.encode_decode(crop_img, img_meta)
+                preds += F.pad(crop_seg_logit,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+
                 count_mat[:, :, y1:y2, x1:x2] += 1
         assert (count_mat == 0).sum() == 0
+        # We want to regard count_mat as a constant while exporting to ONNX
+        count_mat = torch.from_numpy(count_mat.detach().numpy())
         preds = preds / count_mat
         if rescale:
             preds = resize(
@@ -201,7 +205,6 @@ def slide_inference(self, img, img_meta, rescale):
                 mode='bilinear',
                 align_corners=self.align_corners,
                 warning=False)
-
         return preds
 
     def whole_inference(self, img, img_meta, rescale):
@@ -243,8 +246,8 @@ def inference(self, img, img_meta, rescale):
             seg_logit = self.whole_inference(img, img_meta, rescale)
         output = F.softmax(seg_logit, dim=1)
         flip = img_meta[0]['flip']
-        flip_direction = img_meta[0]['flip_direction']
         if flip:
+            flip_direction = img_meta[0]['flip_direction']
             assert flip_direction in ['horizontal', 'vertical']
             if flip_direction == 'horizontal':
                 output = output.flip(dims=(3, ))
@@ -257,6 +260,8 @@ def simple_test(self, img, img_meta, rescale=True):
         """Simple test with single image."""
         seg_logit = self.inference(img, img_meta, rescale)
         seg_pred = seg_logit.argmax(dim=1)
+        if torch.onnx.is_in_onnx_export():
+            return seg_pred
         seg_pred = seg_pred.cpu().numpy()
         # unravel batch dim
         seg_pred = list(seg_pred)
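The rewritten `slide_inference` above accumulates each window's logits by zero-padding them out to the full prediction size instead of assigning into a slice of `preds`. A minimal, self-contained sketch of that accumulation step (the tensor shapes and window coordinates below are made up for illustration and are not taken from the patch):

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: a 1x19x8x8 logit accumulator and one 4x4 window of logits.
preds = torch.zeros((1, 19, 8, 8))
crop_seg_logit = torch.ones((1, 19, 4, 4))
y1, x1, y2, x2 = 2, 2, 6, 6  # window position inside the 8x8 prediction

# F.pad takes (left, right, top, bottom) for the last two dimensions, so the
# padded crop lines up with its window before the addition. This is equivalent
# to `preds[:, :, y1:y2, x1:x2] += crop_seg_logit`, but avoids the in-place
# slice assignment that the old code relied on.
preds += F.pad(crop_seg_logit,
               (x1, preds.shape[3] - x2, y1, preds.shape[2] - y2))
```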
diff --git a/setup.cfg b/setup.cfg
index 2102a8ca60..9721e1c5c3 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -8,6 +8,6 @@ line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmseg
-known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,pytablewriter,pytest,scipy,torch,torchvision
+known_third_party = PIL,cityscapesscripts,cv2,matplotlib,mmcv,numpy,onnxruntime,pytablewriter,pytest,scipy,torch,torchvision
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
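Inside `tools/pytorch2onnx.py` (added below), `model.forward` is temporarily rebound with `functools.partial` so that `torch.onnx.export` only has to supply the image tensor; the image metadata and `return_loss=False` are baked in, and the original forward is restored afterwards. A toy sketch of that pattern (the `ToySegmentor` module here is hypothetical and not part of the patch):

```python
from functools import partial

import torch


class ToySegmentor(torch.nn.Module):
    """Stand-in for a segmentor whose forward dispatches on return_loss."""

    def forward(self, img, img_metas=None, return_loss=True):
        return img.mean()


model = ToySegmentor()
origin_forward = model.forward
# Bind the extra arguments so the exporter only needs to pass the image.
model.forward = partial(model.forward, img_metas=[[{}]], return_loss=False)
out = model(torch.zeros(1, 3, 4, 4))
model.forward = origin_forward  # restore the original forward afterwards
```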
diff --git a/tools/pytorch2onnx.py b/tools/pytorch2onnx.py
new file mode 100644
index 0000000000..df84eeb911
--- /dev/null
+++ b/tools/pytorch2onnx.py
@@ -0,0 +1,198 @@
+import argparse
+from functools import partial
+
+import mmcv
+import numpy as np
+import onnxruntime as rt
+import torch
+import torch._C
+import torch.serialization
+from mmcv.onnx import register_extra_symbolics
+from mmcv.runner import load_checkpoint
+
+from mmseg.models import build_segmentor
+
+torch.manual_seed(3)
+
+
+def _convert_batchnorm(module):
+    module_output = module
+    if isinstance(module, torch.nn.SyncBatchNorm):
+        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
+                                             module.momentum, module.affine,
+                                             module.track_running_stats)
+        if module.affine:
+            module_output.weight.data = module.weight.data.clone().detach()
+            module_output.bias.data = module.bias.data.clone().detach()
+            # keep requires_grad unchanged
+            module_output.weight.requires_grad = module.weight.requires_grad
+            module_output.bias.requires_grad = module.bias.requires_grad
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+    for name, child in module.named_children():
+        module_output.add_module(name, _convert_batchnorm(child))
+    del module
+    return module_output
+
+
+def _demo_mm_inputs(input_shape, num_classes):
+    """Create a superset of inputs needed to run test or train batches.
+
+    Args:
+        input_shape (tuple):
+            input batch dimensions
+        num_classes (int):
+            number of semantic classes
+    """
+    (N, C, H, W) = input_shape
+    rng = np.random.RandomState(0)
+    imgs = rng.rand(*input_shape)
+    segs = rng.randint(
+        low=0, high=num_classes - 1, size=(N, 1, H, W)).astype(np.uint8)
+    img_metas = [{
+        'img_shape': (H, W, C),
+        'ori_shape': (H, W, C),
+        'pad_shape': (H, W, C),
+        'filename': '.png',
+        'scale_factor': 1.0,
+        'flip': False,
+    } for _ in range(N)]
+    mm_inputs = {
+        'imgs': torch.FloatTensor(imgs).requires_grad_(True),
+        'img_metas': img_metas,
+        'gt_semantic_seg': torch.LongTensor(segs)
+    }
+    return mm_inputs
+
+
+def pytorch2onnx(model,
+                 input_shape,
+                 opset_version=11,
+                 show=False,
+                 output_file='tmp.onnx',
+                 verify=False):
+    """Export a PyTorch model to ONNX and optionally verify that the outputs
+    of the PyTorch and ONNX models match.
+
+    Args:
+        model (nn.Module): The PyTorch model to export.
+        input_shape (tuple): Use this input shape to construct
+            the corresponding dummy input and execute the model.
+        opset_version (int): The ONNX op set version. Default: 11.
+        show (bool): Whether to print the computation graph. Default: False.
+        output_file (str): The path where the output ONNX model is stored.
+            Default: `tmp.onnx`.
+        verify (bool): Whether to compare the outputs of the PyTorch and ONNX
+            models. Default: False.
+    """
+    model.cpu().eval()
+
+    num_classes = model.decode_head.num_classes
+
+    mm_inputs = _demo_mm_inputs(input_shape, num_classes)
+
+    imgs = mm_inputs.pop('imgs')
+    img_metas = mm_inputs.pop('img_metas')
+
+    img_list = [img[None, :] for img in imgs]
+    img_meta_list = [[img_meta] for img_meta in img_metas]
+
+    # replace the original forward function
+    origin_forward = model.forward
+    model.forward = partial(
+        model.forward, img_metas=img_meta_list, return_loss=False)
+
+    register_extra_symbolics(opset_version)
+    with torch.no_grad():
+        torch.onnx.export(
+            model, (img_list, ),
+            output_file,
+            export_params=True,
+            keep_initializers_as_inputs=True,
+            verbose=show,
+            opset_version=opset_version)
+        print(f'Successfully exported ONNX model: {output_file}')
+    model.forward = origin_forward
+
+    if verify:
+        # check the exported model with the ONNX checker
+        import onnx
+        onnx_model = onnx.load(output_file)
+        onnx.checker.check_model(onnx_model)
+
+        # check the numerical values
+        # get the pytorch output
+        pytorch_result = model(img_list, img_meta_list, return_loss=False)[0]
+
+        # get the onnx output
+        input_all = [node.name for node in onnx_model.graph.input]
+        input_initializer = [
+            node.name for node in onnx_model.graph.initializer
+        ]
+        net_feed_input = list(set(input_all) - set(input_initializer))
+        assert (len(net_feed_input) == 1)
+        sess = rt.InferenceSession(output_file)
+        onnx_result = sess.run(
+            None, {net_feed_input[0]: img_list[0].detach().numpy()})[0]
+        if not np.allclose(pytorch_result, onnx_result):
+            raise ValueError(
+                'The outputs are different between PyTorch and ONNX')
+        print('The outputs are the same between PyTorch and ONNX')
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert MMSeg to ONNX')
+    parser.add_argument('config', help='test config file path')
+    parser.add_argument('--checkpoint', help='checkpoint file', default=None)
+    parser.add_argument('--show', action='store_true', help='show onnx graph')
+    parser.add_argument(
+        '--verify', action='store_true', help='verify the onnx model')
+    parser.add_argument('--output-file', type=str, default='tmp.onnx')
+    parser.add_argument('--opset-version', type=int, default=11)
+    parser.add_argument(
+        '--shape',
+        type=int,
+        nargs='+',
+        default=[256, 256],
+        help='input image size')
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    if len(args.shape) == 1:
+        input_shape = (1, 3, args.shape[0], args.shape[0])
+    elif len(args.shape) == 2:
+        input_shape = (
+            1,
+            3,
+        ) + tuple(args.shape)
+    else:
+        raise ValueError('invalid input shape')
+
+    cfg = mmcv.Config.fromfile(args.config)
+    cfg.model.pretrained = None
+
+    # build the model and load the checkpoint
+    segmentor = build_segmentor(
+        cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+    # convert SyncBN to BN
+    segmentor = _convert_batchnorm(segmentor)
+
+    num_classes = segmentor.decode_head.num_classes
+
+    if args.checkpoint:
+        checkpoint = load_checkpoint(
+            segmentor, args.checkpoint, map_location='cpu')
+
+    # convert the model to an ONNX file
+    pytorch2onnx(
+        segmentor,
+        input_shape,
+        opset_version=args.opset_version,
+        show=args.show,
+        output_file=args.output_file,
+        verify=args.verify)
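Once the script has produced an ONNX file, it can also be run on its own with `onnxruntime`, mirroring the verification branch above. A minimal sketch, assuming an exported `tmp.onnx` and the default static input shape of 1x3x256x256 used at export time:

```python
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession('tmp.onnx')
input_name = sess.get_inputs()[0].name

# The input must match the static shape and dtype the model was exported with.
img = np.random.rand(1, 3, 256, 256).astype(np.float32)

# The exported graph ends at an argmax over classes (see simple_test above),
# so the result is a map of per-pixel class indices.
seg_pred = sess.run(None, {input_name: img})[0]
print(seg_pred.shape)
```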