From ba5351e20948591effc75e644014923282c1fa10 Mon Sep 17 00:00:00 2001 From: Richard-mei Date: Mon, 28 Feb 2022 16:28:00 +0800 Subject: [PATCH] add gfl_trt (#124) * add gfl_trt * add gfl_head.py * add batch_integral * lint code * add gfl unit test * fix unit test * add gfl benchmark * fix unit test bug * Update gfl_head.py * Update __init__.py remove '**_forward_single' * fix lint error and ut error * fix docs and benchmark Co-authored-by: VVsssssk --- docs/en/benchmark.md | 14 ++ docs/en/codebases/mmdet.md | 1 + docs/en/supported_models.md | 1 + .../mmdet/models/dense_heads/__init__.py | 4 +- .../mmdet/models/dense_heads/gfl_head.py | 185 ++++++++++++++++++ .../test_mmdet/test_mmdet_models.py | 107 ++++++++++ 6 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py diff --git a/docs/en/benchmark.md b/docs/en/benchmark.md index 8a5035de5e..41c748c9d2 100644 --- a/docs/en/benchmark.md +++ b/docs/en/benchmark.md @@ -996,6 +996,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut - $MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_caffe_fpn_1x_coco.py + + GFL + Object Detection + COCO2017 + box AP + 40.2 + - + 40.2 + 40.2 + 40.0 + - + - + $MMDET_DIR/configs/gfl/gfl_r50_fpn_1x_coco.py + Mask R-CNN Instance Segmentation diff --git a/docs/en/codebases/mmdet.md b/docs/en/codebases/mmdet.md index e5b4f5409d..f03bf7c60f 100644 --- a/docs/en/codebases/mmdet.md +++ b/docs/en/codebases/mmdet.md @@ -22,6 +22,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/ | Cascade R-CNN | ObjectDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | | Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) | +| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | | Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) | diff --git a/docs/en/supported_models.md b/docs/en/supported_models.md index 8f13fe93df..2c865bfa20 100644 --- a/docs/en/supported_models.md +++ b/docs/en/supported_models.md @@ -14,6 +14,7 @@ The table below lists the models that are guaranteed to be exportable to other b | SSD[*](#note) | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) | | FoveaBox | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) | | ATSS | MMDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) | +| GFL | MMDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) | | Cascade R-CNN | MMDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | Cascade Mask R-CNN | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) | | VFNet | MMDetection | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) | diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py index 080fe6c7ba..9043d44264 100644 --- a/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py +++ b/mmdeploy/codebase/mmdet/models/dense_heads/__init__.py @@ -2,6 +2,7 @@ from .base_dense_head import (base_dense_head__get_bbox, base_dense_head__get_bboxes__ncnn) from .fovea_head import fovea_head__get_bboxes +from .gfl_head import gfl_head__get_bbox from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn from .ssd_head import ssd_head__get_bboxes__ncnn from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn @@ -12,5 +13,6 @@ 'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn', 'yolox_head__get_bboxes', 'base_dense_head__get_bbox', 'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn', - 'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn' + 'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn', + 'gfl_head__get_bbox' ] diff --git a/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py new file mode 100644 index 0000000000..8dba8b5666 --- /dev/null +++ b/mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py @@ -0,0 +1,185 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn.functional as F + +from mmdeploy.codebase.mmdet import (get_post_processing_params, + multiclass_nms, pad_with_value) +from mmdeploy.core import FUNCTION_REWRITER +from mmdeploy.utils import Backend, get_backend, is_dynamic_shape + + +@FUNCTION_REWRITER.register_rewriter( + func_name='mmdet.models.dense_heads.gfl_head.' + 'GFLHead.get_bboxes') +def gfl_head__get_bbox(ctx, + self, + cls_scores, + bbox_preds, + score_factors=None, + img_metas=None, + cfg=None, + rescale=False, + with_nms=True, + **kwargs): + """Rewrite `get_bboxes` of `GFLHead` for default backend. + + Rewrite this function to deploy model, transform network output for a + batch into bbox predictions. + + Args: + ctx (ContextCaller): The context with additional information. + self: The instance of the original class. + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], Optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Default None. + img_metas (list[dict], Optional): Image meta info. Default None. + cfg (mmcv.Config, Optional): Test / postprocessing configuration, + if None, test_cfg would be used. Default None. + rescale (bool): If True, return boxes in original image space. + Default False. + with_nms (bool): If True, do nms before return boxes. + Default True. + + Returns: + If with_nms == True: + tuple[Tensor, Tensor]: tuple[Tensor, Tensor]: (dets, labels), + `dets` of shape [N, num_det, 5] and `labels` of shape + [N, num_det]. + Else: + tuple[Tensor, Tensor, Tensor]: batch_mlvl_bboxes, + batch_mlvl_scores, batch_mlvl_centerness + """ + deploy_cfg = ctx.cfg + is_dynamic_flag = is_dynamic_shape(deploy_cfg) + backend = get_backend(deploy_cfg) + num_levels = len(cls_scores) + + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device) + + mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)] + mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)] + if score_factors is None: + with_score_factors = False + mlvl_score_factor = [None for _ in range(num_levels)] + else: + with_score_factors = True + mlvl_score_factor = [ + score_factors[i].detach() for i in range(num_levels) + ] + mlvl_score_factors = [] + assert img_metas is not None + img_shape = img_metas[0]['img_shape'] + + assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors) + batch_size = cls_scores[0].shape[0] + cfg = self.test_cfg + pre_topk = cfg.get('nms_pre', -1) + + mlvl_valid_bboxes = [] + mlvl_valid_scores = [] + mlvl_valid_priors = [] + + for cls_score, bbox_pred, score_factors, priors, stride in zip( + mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors, + self.prior_generator.strides): + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + assert stride[0] == stride[1] + + scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1, + self.cls_out_channels) + if self.use_sigmoid_cls: + scores = scores.sigmoid() + nms_pre_score = scores + else: + scores = scores.softmax(-1) + nms_pre_score = scores + if with_score_factors: + score_factors = score_factors.permute(0, 2, 3, + 1).reshape(batch_size, + -1).sigmoid() + score_factors = score_factors.unsqueeze(2) + bbox_pred = batched_integral(self.integral, + bbox_pred.permute(0, 2, 3, 1)) * stride[0] + if not is_dynamic_flag: + priors = priors.data + priors = priors.expand(batch_size, -1, priors.size(-1)) + if pre_topk > 0: + if with_score_factors: + nms_pre_score = nms_pre_score * score_factors + if backend == Backend.TENSORRT: + priors = pad_with_value(priors, 1, pre_topk) + bbox_pred = pad_with_value(bbox_pred, 1, pre_topk) + scores = pad_with_value(scores, 1, pre_topk, 0.) + nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.) + if with_score_factors: + score_factors = pad_with_value(score_factors, 1, pre_topk, + 0.) + + # Get maximum scores for foreground classes. + if self.use_sigmoid_cls: + max_scores, _ = nms_pre_score.max(-1) + else: + max_scores, _ = nms_pre_score[..., :-1].max(-1) + _, topk_inds = max_scores.topk(pre_topk) + batch_inds = torch.arange( + batch_size, + device=bbox_pred.device).view(-1, 1).expand_as(topk_inds) + priors = priors[batch_inds, topk_inds, :] + bbox_pred = bbox_pred[batch_inds, topk_inds, :] + scores = scores[batch_inds, topk_inds, :] + if with_score_factors: + score_factors = score_factors[batch_inds, topk_inds, :] + + mlvl_valid_bboxes.append(bbox_pred) + mlvl_valid_scores.append(scores) + priors = self.anchor_center(priors) + mlvl_valid_priors.append(priors) + if with_score_factors: + mlvl_score_factors.append(score_factors) + + batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1) + batch_scores = torch.cat(mlvl_valid_scores, dim=1) + batch_priors = torch.cat(mlvl_valid_priors, dim=1) + batch_bboxes = self.bbox_coder.decode( + batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape) + if with_score_factors: + batch_score_factors = torch.cat(mlvl_score_factors, dim=1) + + if not self.use_sigmoid_cls: + batch_scores = batch_scores[..., :self.num_classes] + + if with_score_factors: + batch_scores = batch_scores * batch_score_factors + if not with_nms: + return batch_bboxes, batch_scores + post_params = get_post_processing_params(deploy_cfg) + max_output_boxes_per_class = post_params.max_output_boxes_per_class + iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold) + score_threshold = cfg.get('score_thr', post_params.score_threshold) + pre_top_k = post_params.pre_top_k + keep_top_k = cfg.get('max_per_img', post_params.keep_top_k) + return multiclass_nms( + batch_bboxes, + batch_scores, + max_output_boxes_per_class, + iou_threshold=iou_threshold, + score_threshold=score_threshold, + pre_top_k=pre_top_k, + keep_top_k=keep_top_k) + + +def batched_integral(intergral, x): + batch_size = x.size(0) + x = F.softmax(x.reshape(batch_size, -1, intergral.reg_max + 1), dim=2) + x = F.linear(x, + intergral.project.type_as(x).unsqueeze(0)).reshape( + batch_size, -1, 4) + return x diff --git a/tests/test_codebase/test_mmdet/test_mmdet_models.py b/tests/test_codebase/test_mmdet/test_mmdet_models.py index 9cfe6e83d9..192bf72198 100644 --- a/tests/test_codebase/test_mmdet/test_mmdet_models.py +++ b/tests/test_codebase/test_mmdet/test_mmdet_models.py @@ -157,6 +157,31 @@ def get_single_roi_extractor(): return model +def get_gfl_head_model(): + test_cfg = mmcv.Config( + dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100)) + anchor_generator = dict( + type='AnchorGenerator', + scales_per_octave=1, + octave_base_scale=8, + ratios=[1.0], + strides=[8, 16, 32, 64, 128]) + from mmdet.models.dense_heads import GFLHead + model = GFLHead( + num_classes=3, + in_channels=256, + reg_max=3, + test_cfg=test_cfg, + anchor_generator=anchor_generator) + model.requires_grad_(False) + return model + + def test_focus_forward_ncnn(): backend_type = Backend.NCNN check_backend(backend_type) @@ -349,6 +374,88 @@ def test_get_bboxes_of_rpn_head(backend_type: Backend): assert rewrite_outputs is not None +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_get_bboxes_of_gfl_head(backend_type): + check_backend(backend_type) + head = get_gfl_head_model() + head.cpu().eval() + s = 4 + img_metas = [{ + 'scale_factor': np.ones(4), + 'pad_shape': (s, s, 3), + 'img_shape': (s, s, 3) + }] + output_names = ['dets'] + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(output_names=output_names, input_shape=None), + codebase_config=dict( + type='mmdet', + task='ObjectDetection', + model_type='ncnn_end2end', + post_processing=dict( + score_threshold=0.05, + iou_threshold=0.5, + max_output_boxes_per_class=200, + pre_top_k=5000, + keep_top_k=100, + background_label_id=-1, + )))) + + seed_everything(1234) + cls_score = [ + torch.rand(1, 3, pow(2, i), pow(2, i)) for i in range(5, 0, -1) + ] + seed_everything(5678) + bboxes = [torch.rand(1, 16, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + + # to get outputs of onnx model after rewrite + img_metas[0]['img_shape'] = torch.Tensor([s, s]) + wrapped_model = WrapModel( + head, 'get_bboxes', img_metas=img_metas, with_nms=True) + rewrite_inputs = { + 'cls_scores': cls_score, + 'bbox_preds': bboxes, + } + # do not run with ncnn backend + run_with_backend = False if backend_type in [Backend.NCNN] else True + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg, + run_with_backend=run_with_backend) + assert rewrite_outputs is not None + + +@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME]) +def test_forward_of_gfl_head(backend_type): + check_backend(backend_type) + head = get_gfl_head_model() + head.cpu().eval() + deploy_cfg = mmcv.Config( + dict( + backend_config=dict(type=backend_type.value), + onnx_config=dict(input_shape=None))) + feats = [torch.rand(1, 256, pow(2, i), pow(2, i)) for i in range(5, 0, -1)] + model_outputs = [head.forward(feats)] + wrapped_model = WrapModel(head, 'forward') + rewrite_inputs = { + 'feats': feats, + } + rewrite_outputs, is_backend_output = get_rewrite_outputs( + wrapped_model=wrapped_model, + model_inputs=rewrite_inputs, + deploy_cfg=deploy_cfg) + model_outputs[0] = [*model_outputs[0][0], *model_outputs[0][1]] + for model_output, rewrite_output in zip(model_outputs[0], + rewrite_outputs[0]): + model_output = model_output.squeeze().cpu().numpy() + rewrite_output = rewrite_output.squeeze() + assert np.allclose( + model_output, rewrite_output, rtol=1e-03, atol=1e-05) + + def _replace_r50_with_r18(model): """Replace ResNet50 with ResNet18 in config.""" model = copy.deepcopy(model)