From 10f99bcf0a7e76913ad6674ef4b97d953fc77a90 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Thu, 13 May 2021 17:20:39 +0800 Subject: [PATCH 1/8] evaluate trt models --- docs/tutorials/onnx2tensorrt.md | 30 ++++-- docs/tutorials/pytorch2onnx.md | 21 ++-- mmdet/core/export/model_wrappers.py | 142 +++++++++++++++++++--------- tools/deployment/onnx2tensorrt.py | 22 ++++- tools/deployment/test.py | 25 ++++- 5 files changed, 175 insertions(+), 65 deletions(-) diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md index 455f2de80b0..ded83145364 100644 --- a/docs/tutorials/onnx2tensorrt.md +++ b/docs/tutorials/onnx2tensorrt.md @@ -6,6 +6,7 @@ - [How to convert models from ONNX to TensorRT](#how-to-convert-models-from-onnx-to-tensorrt) - [Prerequisite](#prerequisite) - [Usage](#usage) + - [How to evaluate the exported models](#how-to-evaluate-the-exported-models) - [List of supported models convertable to TensorRT](#list-of-supported-models-convertable-to-tensorrt) - [Reminders](#reminders) - [FAQs](#faqs) @@ -28,6 +29,7 @@ python tools/deployment/onnx2tensorrt.py \ --trt-file ${TRT_FILE} \ --input-img ${INPUT_IMAGE_PATH} \ --shape ${IMAGE_SHAPE} \ + --max-shape ${MAX_IMAGE_SHAPE} \ --mean ${IMAGE_MEAN} \ --std ${IMAGE_STD} \ --dataset ${DATASET_NAME} \ @@ -42,6 +44,7 @@ Description of all arguments: - `--trt-file`: The Path of output TensorRT engine file. If not specified, it will be set to `tmp.trt`. - `--input-img` : The path of an input image for tracing and conversion. By default, it will be set to `demo/demo.jpg`. - `--shape`: The height and width of model input. If not specified, it will be set to `400 600`. +- `--max-shape`: The maximum height and width of model input. If not specified, it will be set to the same as `--shape`. - `--mean` : Three mean values for the input image. If not specified, it will be set to `123.675 116.28 103.53`. - `--std` : Three std values for the input image. If not specified, it will be set to `58.395 57.12 57.375`. - `--dataset` : The dataset name for the input model. If not specified, it will be set to `coco`. @@ -65,19 +68,28 @@ python tools/deployment/onnx2tensorrt.py \ --verify \ ``` +## How to evaluate the exported models + +We prepare a tool `tools/deplopyment/test.py` to evaluate TensorRT models. + +Please refer to following links for more information. + +- [how-to-evaluate-the-exported-models](pytorch2onnx.md#how-to-evaluate-the-exported-models) +- [results-and-models](pytorch2onnx.md#results-and-models) + ## List of supported models convertable to TensorRT The table below lists the models that are guaranteed to be convertable to TensorRT. -| Model | Config | Status | -| :----------: | :--------------------------------------------------: | :----: | -| SSD | `configs/ssd/ssd300_coco.py` | Y | -| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | -| FCOS | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py` | Y | -| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | -| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | -| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | -| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | +| Model | Config | Dynamic Shape | Batch Inference | Note | +| :----------: | :--------------------------------------------------: | :-----------: | :-------------: | :---: | +| SSD | `configs/ssd/ssd300_coco.py` | Y | Y | | +| FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | Y | | +| FCOS | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py` | N | Y | | +| YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | Y | | +| RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | | +| Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | | +| Mask R-CNN | `configs/mask_rcnn/mask_rcnn_r50_fpn_1x_coco.py` | Y | Y | | Notes: diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 0b993455476..8c09b1b0037 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -7,7 +7,7 @@ - [Prerequisite](#prerequisite) - [Usage](#usage) - [Description of all arguments](#description-of-all-arguments) - - [How to evaluate ONNX models with ONNX Runtime](#how-to-evaluate-onnx-models-with-onnx-runtime) + - [How to evaluate the exported models](#how-to-evaluate-the-exported-models) - [Prerequisite](#prerequisite-1) - [Usage](#usage-1) - [Description of all arguments](#description-of-all-arguments-1) @@ -90,9 +90,9 @@ python tools/deployment/pytorch2onnx.py \ model.test_cfg.deploy_nms_pre=300 \ ``` -## How to evaluate ONNX models with ONNX Runtime +## How to evaluate the exported models -We prepare a tool `tools/deplopyment/test.py` to evaluate ONNX models with ONNX Runtime backend. +We prepare a tool `tools/deplopyment/test.py` to evaluate ONNX models with ONNXRuntime and TensorRT. ### Prerequisite @@ -107,7 +107,7 @@ We prepare a tool `tools/deplopyment/test.py` to evaluate ONNX models with ONNX ```bash python tools/deployment/test.py \ ${CONFIG_FILE} \ - ${ONNX_FILE} \ + ${MODEL_FILE} \ --out ${OUTPUT_FILE} \ --format-only ${FORMAT_ONLY} \ --eval ${EVALUATION_METRICS} \ @@ -120,7 +120,7 @@ python tools/deployment/test.py \ ### Description of all arguments - `config`: The path of a model config file. -- `model`: The path of a ONNX model file. +- `model`: The path of an input model file and it should have extension of `.onnx` or `.trt` . - `--out`: The path of output result file in pickle format. - `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. - `--eval`: Evaluation metrics, which depends on the dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC. @@ -138,6 +138,7 @@ python tools/deployment/test.py \ Metric PyTorch ONNX Runtime + TensorRT FCOS @@ -145,6 +146,7 @@ python tools/deployment/test.py \ Box AP 36.6 36.5 + / FSAF @@ -152,6 +154,7 @@ python tools/deployment/test.py \ Box AP 36.0 36.0 + 35.9 RetinaNet @@ -159,6 +162,7 @@ python tools/deployment/test.py \ Box AP 36.5 36.4 + 36.3 SSD @@ -166,6 +170,7 @@ python tools/deployment/test.py \ Box AP 25.6 25.6 + 25.6 YOLOv3 @@ -173,6 +178,7 @@ python tools/deployment/test.py \ Box AP 33.5 33.5 + 33.5 Faster R-CNN @@ -180,6 +186,7 @@ python tools/deployment/test.py \ Box AP 37.4 37.4 + 37.0 Mask R-CNN @@ -187,11 +194,13 @@ python tools/deployment/test.py \ Box AP 38.2 38.1 + / Mask AP 34.7 33.7 + / @@ -199,7 +208,7 @@ Notes: - All ONNX models are evaluated with dynamic shape on coco dataset and images are preprocessed according to the original config file. -- Mask AP of Mask R-CNN drops by 1% for ONNXRuntime. The main reason is that the predicted masks are directly interpolated to original image in PyTorch, while they are at first interpolated to the preprocessed input image of the model and then to original image in ONNXRuntime. +- Mask AP of Mask R-CNN drops by 1% for ONNXRuntime. The main reason is that the predicted masks are directly interpolated to original image in PyTorch, while they are at first interpolated to the preprocessed input image of the model and then to original image in other backend. ## List of supported models exportable to ONNX diff --git a/mmdet/core/export/model_wrappers.py b/mmdet/core/export/model_wrappers.py index e9988ba4e1b..7eb11f3fbc0 100644 --- a/mmdet/core/export/model_wrappers.py +++ b/mmdet/core/export/model_wrappers.py @@ -9,11 +9,82 @@ from mmdet.models import BaseDetector -class ONNXRuntimeDetector(BaseDetector): +class DeployBaseDetector(BaseDetector): + """DeployBaseDetector.""" + + def __init__(self, class_names, device_id): + super(DeployBaseDetector, self).__init__() + self.CLASSES = class_names + self.device_id = device_id + + def simple_test(self, img, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def aug_test(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def extract_feat(self, imgs): + raise NotImplementedError('This method is not implemented.') + + def forward_train(self, imgs, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def val_step(self, data, optimizer): + raise NotImplementedError('This method is not implemented.') + + def train_step(self, data, optimizer): + raise NotImplementedError('This method is not implemented.') + + def aforward_test(self, *, img, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def async_simple_test(self, img, img_metas, **kwargs): + raise NotImplementedError('This method is not implemented.') + + def forward(self, img, img_metas, return_loss=True, **kwargs): + outputs = self.forward_test(img, img_metas, **kwargs) + batch_dets, batch_labels = outputs[:2] + batch_masks = outputs[2] if len(outputs) == 3 else None + batch_size = img[0].shape[0] + img_metas = img_metas[0] + results = [] + rescale = kwargs.get('rescale', True) + for i in range(batch_size): + dets, labels = batch_dets[i], batch_labels[i] + if rescale: + scale_factor = img_metas[i]['scale_factor'] + dets[:, :4] /= scale_factor + dets_results = bbox2result(dets, labels, len(self.CLASSES)) + if batch_masks is not None: + masks = batch_masks[i] + img_h, img_w = img_metas[i]['img_shape'][:2] + ori_h, ori_w = img_metas[i]['ori_shape'][:2] + masks = masks[:, :img_h, :img_w] + if rescale: + mask_dtype = masks.dtype + masks = masks.astype(np.float32) + masks = torch.from_numpy(masks) + masks = torch.nn.functional.interpolate( + masks.unsqueeze(0), size=(ori_h, ori_w)) + masks = masks.squeeze(0).detach().numpy() + # convert mask to range(0,1) + if mask_dtype != np.bool: + masks /= 255 + masks = masks >= 0.5 + segms_results = [[] for _ in range(len(self.CLASSES))] + for j in range(len(dets)): + segms_results[labels[j]].append(masks[j]) + results.append((dets_results, segms_results)) + else: + results.append(dets_results) + return results + + +class ONNXRuntimeDetector(DeployBaseDetector): """Wrapper for detector's inference with ONNXRuntime.""" def __init__(self, onnx_file, class_names, device_id): - super(ONNXRuntimeDetector, self).__init__() + super(ONNXRuntimeDetector, self).__init__(class_names, device_id) # get the custom op path ort_custom_op_path = '' try: @@ -37,25 +108,12 @@ def __init__(self, onnx_file, class_names, device_id): sess.set_providers(providers, options) self.sess = sess - self.CLASSES = class_names - self.device_id = device_id self.io_binding = sess.io_binding() self.output_names = [_.name for _ in sess.get_outputs()] self.is_cuda_available = is_cuda_available - def simple_test(self, img, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') - - def aug_test(self, imgs, img_metas, **kwargs): - raise NotImplementedError('This method is not implemented.') - - def extract_feat(self, imgs): - raise NotImplementedError('This method is not implemented.') - def forward_test(self, imgs, img_metas, **kwargs): input_data = imgs[0] - img_metas = img_metas[0] - batch_size = input_data.shape[0] # set io binding for inputs/outputs device_type = 'cuda' if self.is_cuda_available else 'cpu' if not self.is_cuda_available: @@ -73,34 +131,28 @@ def forward_test(self, imgs, img_metas, **kwargs): # run session to get outputs self.sess.run_with_iobinding(self.io_binding) ort_outputs = self.io_binding.copy_outputs_to_cpu() - batch_dets, batch_labels = ort_outputs[:2] - batch_masks = ort_outputs[2] if len(ort_outputs) == 3 else None + return ort_outputs - results = [] - for i in range(batch_size): - scale_factor = img_metas[i]['scale_factor'] - dets, labels = batch_dets[i], batch_labels[i] - dets[:, :4] /= scale_factor - dets_results = bbox2result(dets, labels, len(self.CLASSES)) - if batch_masks is not None: - masks = batch_masks[i] - img_h, img_w = img_metas[i]['img_shape'][:2] - ori_h, ori_w = img_metas[i]['ori_shape'][:2] - masks = masks[:, :img_h, :img_w] - mask_dtype = masks.dtype - masks = masks.astype(np.float32) - masks = torch.from_numpy(masks) - masks = torch.nn.functional.interpolate( - masks.unsqueeze(0), size=(ori_h, ori_w)) - masks = masks.squeeze(0).detach().numpy() - # convert mask to range(0,1) - if mask_dtype != np.bool: - masks /= 255 - masks = masks >= 0.5 - segms_results = [[] for _ in range(len(self.CLASSES))] - for j in range(len(dets)): - segms_results[labels[j]].append(masks[j]) - results.append((dets_results, segms_results)) - else: - results.append(dets_results) - return results + +class TensorRTDetector(DeployBaseDetector): + """Wrapper for detector's inference with TensorRT.""" + + def __init__(self, engine_file, class_names, device_id, output_names): + super(TensorRTDetector, self).__init__(class_names, device_id) + try: + from mmcv.tensorrt import TRTWraper + except (ImportError, ModuleNotFoundError): + raise RuntimeError( + 'Please install TensorRT: https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv' # noqa + ) + + self.output_names = output_names + self.model = TRTWraper(engine_file, ['input'], output_names) + + def forward_test(self, imgs, img_metas, **kwargs): + input_data = imgs[0] + with torch.cuda.device(self.device_id), torch.no_grad(): + outputs = self.model({'input': input_data}) + outputs = [outputs[name] for name in self.output_names] + outputs = [out.detach().cpu().numpy() for out in outputs] + return outputs diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py index e2ecbda452d..52ad852e5e9 100644 --- a/tools/deployment/onnx2tensorrt.py +++ b/tools/deployment/onnx2tensorrt.py @@ -31,8 +31,9 @@ def onnx2tensorrt(onnx_file, import tensorrt as trt onnx_model = onnx.load(onnx_file) input_shape = input_config['input_shape'] + max_shape = input_config['max_shape'] # create trt engine and wraper - opt_shape_dict = {'input': [input_shape, input_shape, input_shape]} + opt_shape_dict = {'input': [input_shape, max_shape, max_shape]} max_workspace_size = get_GiB(workspace_size) trt_engine = onnx2trt( onnx_model, @@ -148,6 +149,12 @@ def parse_args(): nargs='+', default=[400, 600], help='Input size of the model') + parser.add_argument( + '--max-shape', + type=int, + nargs='+', + default=None, + help='Maximum input size of the model in TensorRT') parser.add_argument( '--mean', type=float, @@ -184,6 +191,16 @@ def parse_args(): else: raise ValueError('invalid input shape') + if not args.max_shape: + max_shape = input_shape + else: + if len(args.max_shape) == 1: + max_shape = (1, 3, args.max_shape[0], args.max_shape[0]) + elif len(args.max_shape) == 2: + max_shape = (1, 3) + tuple(args.max_shape) + else: + raise ValueError('invalid input max_shape') + assert len(args.mean) == 3 assert len(args.std) == 3 @@ -191,7 +208,8 @@ def parse_args(): input_config = { 'input_shape': input_shape, 'input_path': args.input_img, - 'normalize_cfg': normalize_cfg + 'normalize_cfg': normalize_cfg, + 'max_shape': max_shape } # Create TensorRT engine diff --git a/tools/deployment/test.py b/tools/deployment/test.py index 770589d68a6..b10d583870c 100644 --- a/tools/deployment/test.py +++ b/tools/deployment/test.py @@ -1,11 +1,11 @@ import argparse +import os import mmcv from mmcv import Config, DictAction from mmcv.parallel import MMDataParallel from mmdet.apis import single_gpu_test -from mmdet.core.export.model_wrappers import ONNXRuntimeDetector from mmdet.datasets import (build_dataloader, build_dataset, replace_ImageToTensor) @@ -103,8 +103,27 @@ def main(): dist=False, shuffle=False) - model = ONNXRuntimeDetector( - args.model, class_names=dataset.CLASSES, device_id=0) + _, file_ext = os.path.splitext(args.model) + assert file_ext in ['.onnx', '.trt'] + + if file_ext == '.onnx': + from mmdet.core.export.model_wrappers import ONNXRuntimeDetector + model = ONNXRuntimeDetector( + args.model, class_names=dataset.CLASSES, device_id=0) + elif file_ext == '.trt': + from mmdet.core.export.model_wrappers import TensorRTDetector + output_names = ['dets', 'labels'] + if len(cfg.evaluation['metric']) == 2: + output_names.append('masks') + model = TensorRTDetector( + args.model, + class_names=dataset.CLASSES, + device_id=0, + output_names=output_names) + else: + raise ValueError( + f'The extension of input model file should be `.onnx` or `.trt`, but given: {file_ext}' # noqa + ) model = MMDataParallel(model, device_ids=[0]) outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, From 55b8b8d839044fcfcd1c60b429b7f601d0f49dac Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 14 May 2021 16:09:32 +0800 Subject: [PATCH 2/8] update version of onnx --- docs/tutorials/onnx2tensorrt.md | 2 +- docs/tutorials/pytorch2onnx.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md index ded83145364..7f65feb8c73 100644 --- a/docs/tutorials/onnx2tensorrt.md +++ b/docs/tutorials/onnx2tensorrt.md @@ -93,7 +93,7 @@ The table below lists the models that are guaranteed to be convertable to Tensor Notes: -- *All models above are tested with Pytorch==1.6.0 and TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0* +- *All models above are tested with Pytorch==1.6.0, onnx==1.7.0 and TensorRT-7.2.1.6.Ubuntu-16.04.x86_64-gnu.cuda-10.2.cudnn8.0* ## Reminders diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 8c09b1b0037..671df17ab99 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -226,7 +226,7 @@ The table below lists the models that are guaranteed to be exportable to ONNX an Notes: -- *All models above are tested with Pytorch==1.6.0 and onnxruntime==1.5.1* +- *All models above are tested with Pytorch==1.6.0, onnx==1.7.0 and onnxruntime==1.5.1* - If the deployed backend platform is TensorRT, please add environment variables before running the file: From 0132b69672fc319508532d47f58a107a4750b96b Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 17 May 2021 16:04:02 +0800 Subject: [PATCH 3/8] update maskrcnn results --- docs/tutorials/pytorch2onnx.md | 4 ++-- mmdet/core/export/model_wrappers.py | 5 +---- mmdet/models/roi_heads/mask_heads/fcn_mask_head.py | 8 -------- tools/deployment/onnx2tensorrt.py | 3 +++ tools/deployment/pytorch2onnx.py | 2 ++ 5 files changed, 8 insertions(+), 14 deletions(-) diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 671df17ab99..111793df27a 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -194,13 +194,13 @@ python tools/deployment/test.py \ Box AP 38.2 38.1 - / + 37.7 Mask AP 34.7 33.7 - / + 33.3 diff --git a/mmdet/core/export/model_wrappers.py b/mmdet/core/export/model_wrappers.py index 7eb11f3fbc0..8cdb19e222b 100644 --- a/mmdet/core/export/model_wrappers.py +++ b/mmdet/core/export/model_wrappers.py @@ -61,15 +61,12 @@ def forward(self, img, img_metas, return_loss=True, **kwargs): ori_h, ori_w = img_metas[i]['ori_shape'][:2] masks = masks[:, :img_h, :img_w] if rescale: - mask_dtype = masks.dtype masks = masks.astype(np.float32) masks = torch.from_numpy(masks) masks = torch.nn.functional.interpolate( masks.unsqueeze(0), size=(ori_h, ori_w)) masks = masks.squeeze(0).detach().numpy() - # convert mask to range(0,1) - if mask_dtype != np.bool: - masks /= 255 + if masks.dtype != np.bool: masks = masks >= 0.5 segms_results = [[] for _ in range(len(self.CLASSES))] for j in range(len(dets)): diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py index 4204b682902..5b45f8b8e0e 100644 --- a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py +++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py @@ -1,5 +1,3 @@ -import os - import numpy as np import torch import torch.nn as nn @@ -265,12 +263,6 @@ class label c. mask_pred, bboxes, img_h, img_w, skip_empty=False) if threshold >= 0: masks = (masks >= threshold).to(dtype=torch.bool) - else: - # TensorRT backend does not have data type of uint8 - is_trt_backend = os.environ.get( - 'ONNX_BACKEND') == 'MMCVTensorRT' - target_dtype = torch.int32 if is_trt_backend else torch.uint8 - masks = (masks * 255).to(dtype=target_dtype) return masks N = len(mask_pred) diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py index 52ad852e5e9..1cfcd3249e0 100644 --- a/tools/deployment/onnx2tensorrt.py +++ b/tools/deployment/onnx2tensorrt.py @@ -85,6 +85,9 @@ def onnx2tensorrt(onnx_file, output shapes: {trt_shapes}') trt_masks = trt_outputs[2] if with_mask else None + if trt_masks is not None and trt_masks.dtype != np.bool: + trt_masks = trt_masks >= 0.5 + ort_masks = ort_masks >= 0.5 # Show detection outputs if show: CLASSES = get_classes(dataset) diff --git a/tools/deployment/pytorch2onnx.py b/tools/deployment/pytorch2onnx.py index 343402b517d..938158b53f6 100644 --- a/tools/deployment/pytorch2onnx.py +++ b/tools/deployment/pytorch2onnx.py @@ -149,6 +149,8 @@ def pytorch2onnx(config_path, onnx_results = bbox2result(ort_dets, ort_labels, num_classes) if model.with_mask: segm_results = onnx_outputs[2] + if segm_results.dtype != np.bool: + segm_results = (segm_results * 255).astype(np.uint8) cls_segms = [[] for _ in range(num_classes)] for i in range(ort_dets.shape[0]): cls_segms[ort_labels[i]].append(segm_results[i]) From d889a12b54613ac39c569c475a0084908a4185a9 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Tue, 18 May 2021 14:22:04 +0800 Subject: [PATCH 4/8] add backend argument --- docs/tutorials/pytorch2onnx.md | 6 +++++- mmdet/core/export/model_wrappers.py | 3 ++- tools/deployment/test.py | 13 +++++++------ 3 files changed, 14 insertions(+), 8 deletions(-) diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 111793df27a..e7396340a5c 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -102,6 +102,8 @@ We prepare a tool `tools/deplopyment/test.py` to evaluate ONNX models with ONNXR pip install onnx onnxruntime-gpu ``` +- Install TensorRT by referring to [how-to-build-tensorrt-plugins-in-mmcv](https://mmcv.readthedocs.io/en/latest/tensorrt_plugin.html#how-to-build-tensorrt-plugins-in-mmcv)(optional) + ### Usage ```bash @@ -109,6 +111,7 @@ python tools/deployment/test.py \ ${CONFIG_FILE} \ ${MODEL_FILE} \ --out ${OUTPUT_FILE} \ + --backend ${BACKEND} \ --format-only ${FORMAT_ONLY} \ --eval ${EVALUATION_METRICS} \ --show-dir ${SHOW_DIRECTORY} \ @@ -120,8 +123,9 @@ python tools/deployment/test.py \ ### Description of all arguments - `config`: The path of a model config file. -- `model`: The path of an input model file and it should have extension of `.onnx` or `.trt` . +- `model`: The path of an input model file. - `--out`: The path of output result file in pickle format. +- `--backend`: Backend for input model to run and should be `onnxruntime` or `tensorrt`. - `--format-only` : Format the output results without perform evaluation. It is useful when you want to format the result to a specific format and submit it to the test server. If not specified, it will be set to `False`. - `--eval`: Evaluation metrics, which depends on the dataset, e.g., "bbox", "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC. - `--show-dir`: Directory where painted images will be saved diff --git a/mmdet/core/export/model_wrappers.py b/mmdet/core/export/model_wrappers.py index 8cdb19e222b..844c73f7bd0 100644 --- a/mmdet/core/export/model_wrappers.py +++ b/mmdet/core/export/model_wrappers.py @@ -2,7 +2,6 @@ import warnings import numpy as np -import onnxruntime as ort import torch from mmdet.core import bbox2result @@ -82,6 +81,8 @@ class ONNXRuntimeDetector(DeployBaseDetector): def __init__(self, onnx_file, class_names, device_id): super(ONNXRuntimeDetector, self).__init__(class_names, device_id) + import onnxruntime as ort + # get the custom op path ort_custom_op_path = '' try: diff --git a/tools/deployment/test.py b/tools/deployment/test.py index b10d583870c..d8f739cb6cb 100644 --- a/tools/deployment/test.py +++ b/tools/deployment/test.py @@ -1,5 +1,4 @@ import argparse -import os import mmcv from mmcv import Config, DictAction @@ -22,6 +21,11 @@ def parse_args(): help='Format the output results without perform evaluation. It is' 'useful when you want to format the result to a specific format and ' 'submit it to the test server') + parser.add_argument( + '--backend', + required=True, + choices=['onnxruntime', 'tensorrt'], + help='Backend for input model to run. ') parser.add_argument( '--eval', type=str, @@ -103,14 +107,11 @@ def main(): dist=False, shuffle=False) - _, file_ext = os.path.splitext(args.model) - assert file_ext in ['.onnx', '.trt'] - - if file_ext == '.onnx': + if args.backend == 'onnxruntime': from mmdet.core.export.model_wrappers import ONNXRuntimeDetector model = ONNXRuntimeDetector( args.model, class_names=dataset.CLASSES, device_id=0) - elif file_ext == '.trt': + elif args.backend == 'tensorrt': from mmdet.core.export.model_wrappers import TensorRTDetector output_names = ['dets', 'labels'] if len(cfg.evaluation['metric']) == 2: From 1ca37c556d7ff6754cca3cedfbe6d5c4c609e33d Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 21 May 2021 10:20:34 +0800 Subject: [PATCH 5/8] update fcos results --- docs/tutorials/pytorch2onnx.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index e7396340a5c..82bcd08603f 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -150,7 +150,7 @@ python tools/deployment/test.py \ Box AP 36.6 36.5 - / + 36.3 FSAF From 46dcf1bcce7d6e62e44e4a5cb522bfcd6dc3c252 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Fri, 21 May 2021 18:33:20 +0800 Subject: [PATCH 6/8] update --- docs/tutorials/onnx2tensorrt.md | 2 +- tools/deployment/onnx2tensorrt.py | 2 +- tools/deployment/test.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/tutorials/onnx2tensorrt.md b/docs/tutorials/onnx2tensorrt.md index 7f65feb8c73..0831b4aa78b 100644 --- a/docs/tutorials/onnx2tensorrt.md +++ b/docs/tutorials/onnx2tensorrt.md @@ -85,7 +85,7 @@ The table below lists the models that are guaranteed to be convertable to Tensor | :----------: | :--------------------------------------------------: | :-----------: | :-------------: | :---: | | SSD | `configs/ssd/ssd300_coco.py` | Y | Y | | | FSAF | `configs/fsaf/fsaf_r50_fpn_1x_coco.py` | Y | Y | | -| FCOS | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py` | N | Y | | +| FCOS | `configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py` | Y | Y | | | YOLOv3 | `configs/yolo/yolov3_d53_mstrain-608_273e_coco.py` | Y | Y | | | RetinaNet | `configs/retinanet/retinanet_r50_fpn_1x_coco.py` | Y | Y | | | Faster R-CNN | `configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py` | Y | Y | | diff --git a/tools/deployment/onnx2tensorrt.py b/tools/deployment/onnx2tensorrt.py index 1cfcd3249e0..85dbcd49216 100644 --- a/tools/deployment/onnx2tensorrt.py +++ b/tools/deployment/onnx2tensorrt.py @@ -33,7 +33,7 @@ def onnx2tensorrt(onnx_file, input_shape = input_config['input_shape'] max_shape = input_config['max_shape'] # create trt engine and wraper - opt_shape_dict = {'input': [input_shape, max_shape, max_shape]} + opt_shape_dict = {'input': [input_shape, input_shape, max_shape]} max_workspace_size = get_GiB(workspace_size) trt_engine = onnx2trt( onnx_model, diff --git a/tools/deployment/test.py b/tools/deployment/test.py index d8f739cb6cb..c694f9404a4 100644 --- a/tools/deployment/test.py +++ b/tools/deployment/test.py @@ -121,10 +121,6 @@ def main(): class_names=dataset.CLASSES, device_id=0, output_names=output_names) - else: - raise ValueError( - f'The extension of input model file should be `.onnx` or `.trt`, but given: {file_ext}' # noqa - ) model = MMDataParallel(model, device_ids=[0]) outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, From 8ed6eec193d498999f81a38adf8387d050d7dfc5 Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 24 May 2021 10:04:57 +0800 Subject: [PATCH 7/8] fix bug --- mmdet/models/dense_heads/rpn_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py index 60360b7e97e..7113113ffea 100644 --- a/mmdet/models/dense_heads/rpn_head.py +++ b/mmdet/models/dense_heads/rpn_head.py @@ -218,7 +218,7 @@ def _get_bboxes(self, from mmdet.core.export import add_dummy_nms_for_onnx batch_mlvl_scores = batch_mlvl_scores.unsqueeze(2) score_threshold = cfg.nms.get('score_thr', 0.0) - nms_pre = cfg.get('deploy_nms_pre', cfg.max_per_img) + nms_pre = cfg.get('deploy_nms_pre', -1) dets, _ = add_dummy_nms_for_onnx(batch_mlvl_proposals, batch_mlvl_scores, cfg.max_per_img, From 3af71bdf01e95a8a2b91cb1bea675c39f16d3d1f Mon Sep 17 00:00:00 2001 From: maningsheng Date: Mon, 24 May 2021 16:03:23 +0800 Subject: [PATCH 8/8] update doc --- docs/tutorials/pytorch2onnx.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/tutorials/pytorch2onnx.md b/docs/tutorials/pytorch2onnx.md index 82bcd08603f..e834a2574f1 100644 --- a/docs/tutorials/pytorch2onnx.md +++ b/docs/tutorials/pytorch2onnx.md @@ -85,9 +85,7 @@ python tools/deployment/pytorch2onnx.py \ --verify \ --dynamic-export \ --cfg-options \ - model.test_cfg.nms_pre=200 \ - model.test_cfg.max_per_img=200 \ - model.test_cfg.deploy_nms_pre=300 \ + model.test_cfg.deploy_nms_pre=-1 \ ``` ## How to evaluate the exported models