Skip to content

Commit

Permalink
add gfl_trt (#124)
Browse files Browse the repository at this point in the history
* add gfl_trt

* add gfl_head.py

* add batch_integral

* lint code

* add gfl unit test

* fix unit test

* add gfl benchmark

* fix unit test bug

* Update gfl_head.py

* Update __init__.py

remove '**_forward_single'

* fix lint error and ut error

* fix docs and benchmark

Co-authored-by: VVsssssk <shenkun@pjlab.org.cn>
  • Loading branch information
Richard-mei and VVsssssk authored Feb 28, 2022
1 parent e89becd commit ba5351e
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 1 deletion.
14 changes: 14 additions & 0 deletions docs/en/benchmark.md
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,20 @@ Users can directly test the performance through [how_to_evaluate_a_model.md](tut
<td align="center">-</td>
<td>$MMDET_DIR/configs/cascade_rcnn/cascade_rcnn_r50_caffe_fpn_1x_coco.py</td>
</tr>
<tr>
<td align="center">GFL</td>
<td align="center">Object Detection</td>
<td align="center">COCO2017</td>
<td align="center">box AP</td>
<td align="center">40.2</td>
<td align="center">-</td>
<td align="center">40.2</td>
<td align="center">40.2</td>
<td align="center">40.0</td>
<td align="center">-</td>
<td align="center">-</td>
<td>$MMDET_DIR/configs/gfl/gfl_r50_fpn_1x_coco.py</td>
</tr>
<tr>
<td align="center" rowspan="2">Mask R-CNN</td>
<td align="center" rowspan="2">Instance Segmentation</td>
Expand Down
1 change: 1 addition & 0 deletions docs/en/codebases/mmdet.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Please refer to [get_started.md](https://github.com/open-mmlab/mmdetection/blob/
| Cascade R-CNN | ObjectDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Faster R-CNN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| Faster R-CNN + DCN | ObjectDetection | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/faster_rcnn) |
| GFL | ObjectDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
| Cascade Mask R-CNN | InstanceSegmentation | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Mask R-CNN | InstanceSegmentation | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/mask_rcnn) |

Expand Down
1 change: 1 addition & 0 deletions docs/en/supported_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| SSD[*](#note) | MMDetection | Y | Y | Y | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/ssd) |
| FoveaBox | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/foveabox) |
| ATSS | MMDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/atss) |
| GFL | MMDetection | Y | Y | N | ? | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/gfl) |
| Cascade R-CNN | MMDetection | Y | Y | N | Y | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| Cascade Mask R-CNN | MMDetection | Y | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/cascade_rcnn) |
| VFNet | MMDetection | N | N | N | N | Y | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/vfnet) |
Expand Down
4 changes: 3 additions & 1 deletion mmdeploy/codebase/mmdet/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .base_dense_head import (base_dense_head__get_bbox,
base_dense_head__get_bboxes__ncnn)
from .fovea_head import fovea_head__get_bboxes
from .gfl_head import gfl_head__get_bbox
from .rpn_head import rpn_head__get_bboxes, rpn_head__get_bboxes__ncnn
from .ssd_head import ssd_head__get_bboxes__ncnn
from .yolo_head import yolov3_head__get_bboxes, yolov3_head__get_bboxes__ncnn
Expand All @@ -12,5 +13,6 @@
'yolov3_head__get_bboxes', 'yolov3_head__get_bboxes__ncnn',
'yolox_head__get_bboxes', 'base_dense_head__get_bbox',
'fovea_head__get_bboxes', 'base_dense_head__get_bboxes__ncnn',
'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn'
'ssd_head__get_bboxes__ncnn', 'yolox_head__get_bboxes__ncnn',
'gfl_head__get_bbox'
]
185 changes: 185 additions & 0 deletions mmdeploy/codebase/mmdet/models/dense_heads/gfl_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
    func_name='mmdet.models.dense_heads.gfl_head.'
    'GFLHead.get_bboxes')
def gfl_head__get_bbox(ctx,
                       self,
                       cls_scores,
                       bbox_preds,
                       score_factors=None,
                       img_metas=None,
                       cfg=None,
                       rescale=False,
                       with_nms=True,
                       **kwargs):
    """Rewrite `get_bboxes` of `GFLHead` for default backend.

    Rewrite this function to deploy model, transform network output for a
    batch into bbox predictions.

    Args:
        ctx (ContextCaller): The context with additional information.
        self: The instance of the original class.
        cls_scores (list[Tensor]): Classification scores for all
            scale levels, each is a 4D-tensor, has shape
            (batch_size, num_priors * num_classes, H, W).
        bbox_preds (list[Tensor]): Box distribution logits for all
            scale levels, each is a 4D-tensor. They are converted to
            distances by `batched_integral` and scaled by the level stride.
        score_factors (list[Tensor], Optional): Score factor for
            all scale level, each is a 4D-tensor, has shape
            (batch_size, num_priors * 1, H, W). Default None.
        img_metas (list[dict], Optional): Image meta info. Must not be
            None; only ``img_metas[0]['img_shape']`` is read.
        cfg (mmcv.Config, Optional): Test / postprocessing configuration,
            if None, test_cfg would be used. Default None.
        rescale (bool): Accepted for signature compatibility with the
            original method; this rewrite does not read it.
            Default False.
        with_nms (bool): If True, do nms before return boxes.
            Default True.

    Returns:
        If with_nms == True:
            tuple[Tensor, Tensor]: (dets, labels),
            `dets` of shape [N, num_det, 5] and `labels` of shape
            [N, num_det].
        Else:
            tuple[Tensor, Tensor]: (batch_bboxes, batch_scores), the
            decoded boxes and the (score-factor weighted) class scores
            concatenated over all levels.
    """
    deploy_cfg = ctx.cfg
    is_dynamic_flag = is_dynamic_shape(deploy_cfg)
    backend = get_backend(deploy_cfg)
    num_levels = len(cls_scores)

    featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
    mlvl_priors = self.prior_generator.grid_priors(
        featmap_sizes, dtype=bbox_preds[0].dtype, device=bbox_preds[0].device)

    mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
    mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
    if score_factors is None:
        with_score_factors = False
        mlvl_score_factor = [None] * num_levels
    else:
        with_score_factors = True
        mlvl_score_factor = [
            score_factors[i].detach() for i in range(num_levels)
        ]
    mlvl_score_factors = []
    assert img_metas is not None
    img_shape = img_metas[0]['img_shape']

    assert len(cls_scores) == len(bbox_preds) == len(mlvl_priors)
    batch_size = cls_scores[0].shape[0]
    # Honour an explicitly supplied cfg; fall back to the head's test_cfg
    # only when the caller did not pass one (as the docstring promises).
    cfg = self.test_cfg if cfg is None else cfg
    pre_topk = cfg.get('nms_pre', -1)

    mlvl_valid_bboxes = []
    mlvl_valid_scores = []
    mlvl_valid_priors = []

    # `score_factor` (singular) is the per-level tensor; it deliberately
    # does not reuse the name of the `score_factors` argument.
    for cls_score, bbox_pred, score_factor, priors, stride in zip(
            mlvl_cls_scores, mlvl_bbox_preds, mlvl_score_factor, mlvl_priors,
            self.prior_generator.strides):
        assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
        assert stride[0] == stride[1]

        # (N, C, H, W) -> (N, H * W * num_priors, cls_out_channels)
        scores = cls_score.permute(0, 2, 3, 1).reshape(batch_size, -1,
                                                       self.cls_out_channels)
        if self.use_sigmoid_cls:
            scores = scores.sigmoid()
        else:
            scores = scores.softmax(-1)
        # Ranking score used only for the pre-NMS top-k selection.
        nms_pre_score = scores
        if with_score_factors:
            score_factor = score_factor.permute(0, 2, 3,
                                                1).reshape(batch_size,
                                                           -1).sigmoid()
            score_factor = score_factor.unsqueeze(2)
        # Turn the distribution logits into distances, then scale them
        # by the stride of this level.
        bbox_pred = batched_integral(self.integral,
                                     bbox_pred.permute(0, 2, 3, 1)) * stride[0]
        if not is_dynamic_flag:
            # NOTE(review): presumably detaches the priors so they are
            # treated as constants for static-shape export -- confirm.
            priors = priors.data
        priors = priors.expand(batch_size, -1, priors.size(-1))
        if pre_topk > 0:
            if with_score_factors:
                nms_pre_score = nms_pre_score * score_factor
            if backend == Backend.TENSORRT:
                # NOTE(review): presumably pads dim 1 so that
                # `topk(pre_topk)` below is valid when a level has fewer
                # candidates than `pre_topk` -- confirm against
                # `pad_with_value`.
                priors = pad_with_value(priors, 1, pre_topk)
                bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
                scores = pad_with_value(scores, 1, pre_topk, 0.)
                nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
                if with_score_factors:
                    score_factor = pad_with_value(score_factor, 1, pre_topk,
                                                  0.)

            # Get maximum scores for foreground classes.
            if self.use_sigmoid_cls:
                max_scores, _ = nms_pre_score.max(-1)
            else:
                max_scores, _ = nms_pre_score[..., :-1].max(-1)
            _, topk_inds = max_scores.topk(pre_topk)
            batch_inds = torch.arange(
                batch_size,
                device=bbox_pred.device).view(-1, 1).expand_as(topk_inds)
            priors = priors[batch_inds, topk_inds, :]
            bbox_pred = bbox_pred[batch_inds, topk_inds, :]
            scores = scores[batch_inds, topk_inds, :]
            if with_score_factors:
                score_factor = score_factor[batch_inds, topk_inds, :]

        mlvl_valid_bboxes.append(bbox_pred)
        mlvl_valid_scores.append(scores)
        priors = self.anchor_center(priors)
        mlvl_valid_priors.append(priors)
        if with_score_factors:
            mlvl_score_factors.append(score_factor)

    batch_mlvl_bboxes_pred = torch.cat(mlvl_valid_bboxes, dim=1)
    batch_scores = torch.cat(mlvl_valid_scores, dim=1)
    batch_priors = torch.cat(mlvl_valid_priors, dim=1)
    batch_bboxes = self.bbox_coder.decode(
        batch_priors, batch_mlvl_bboxes_pred, max_shape=img_shape)
    if with_score_factors:
        batch_score_factors = torch.cat(mlvl_score_factors, dim=1)

    if not self.use_sigmoid_cls:
        # Keep only the first `num_classes` channels of the softmax output.
        batch_scores = batch_scores[..., :self.num_classes]

    if with_score_factors:
        batch_scores = batch_scores * batch_score_factors
    if not with_nms:
        return batch_bboxes, batch_scores
    # Values from the model's test cfg take precedence; the deploy config's
    # post-processing params are the fallback.
    post_params = get_post_processing_params(deploy_cfg)
    max_output_boxes_per_class = post_params.max_output_boxes_per_class
    iou_threshold = cfg.nms.get('iou_threshold', post_params.iou_threshold)
    score_threshold = cfg.get('score_thr', post_params.score_threshold)
    pre_top_k = post_params.pre_top_k
    keep_top_k = cfg.get('max_per_img', post_params.keep_top_k)
    return multiclass_nms(
        batch_bboxes,
        batch_scores,
        max_output_boxes_per_class,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        pre_top_k=pre_top_k,
        keep_top_k=keep_top_k)


def batched_integral(intergral, x):
    """Apply the integral (expectation) layer to a whole batch at once.

    Args:
        intergral: Object exposing ``reg_max`` (int) and ``project``
            (a tensor used as the weights of the expectation).
        x (Tensor): Distribution logits whose first dimension is the
            batch; the remaining elements are regrouped into rows of
            ``reg_max + 1`` values.

    Returns:
        Tensor: Tensor of shape (batch_size, -1, 4) holding, for every
        group of ``reg_max + 1`` logits, the softmax-weighted sum of
        ``project``.
    """
    num_bins = intergral.reg_max + 1
    num_imgs = x.size(0)
    # Softmax over the bins of each discretised distribution.
    probs = F.softmax(x.reshape(num_imgs, -1, num_bins), dim=2)
    # A single-row linear layer computes the expectation of every row.
    weight = intergral.project.type_as(probs).unsqueeze(0)
    expectation = F.linear(probs, weight)
    return expectation.reshape(num_imgs, -1, 4)
107 changes: 107 additions & 0 deletions tests/test_codebase/test_mmdet/test_mmdet_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,31 @@ def get_single_roi_extractor():
return model


def get_gfl_head_model():
    """Build a small ``GFLHead`` with gradients disabled for unit tests."""
    from mmdet.models.dense_heads import GFLHead

    # Post-processing settings the head uses at test time.
    nms_cfg = dict(type='nms', iou_threshold=0.6)
    head_test_cfg = mmcv.Config(
        dict(
            nms_pre=1000,
            min_bbox_size=0,
            score_thr=0.05,
            nms=nms_cfg,
            max_per_img=100))

    # One anchor per location on each of the five strides.
    anchor_cfg = dict(
        type='AnchorGenerator',
        scales_per_octave=1,
        octave_base_scale=8,
        ratios=[1.0],
        strides=[8, 16, 32, 64, 128])

    head = GFLHead(
        num_classes=3,
        in_channels=256,
        reg_max=3,
        test_cfg=head_test_cfg,
        anchor_generator=anchor_cfg)
    head.requires_grad_(False)
    return head


def test_focus_forward_ncnn():
backend_type = Backend.NCNN
check_backend(backend_type)
Expand Down Expand Up @@ -349,6 +374,88 @@ def test_get_bboxes_of_rpn_head(backend_type: Backend):
assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_get_bboxes_of_gfl_head(backend_type):
    """Check that the rewritten ``GFLHead.get_bboxes`` produces outputs."""
    check_backend(backend_type)
    gfl_head = get_gfl_head_model()
    gfl_head.cpu().eval()

    side = 4
    # `img_shape` is given as a tensor for the rewritten function.
    img_metas = [{
        'scale_factor': np.ones(4),
        'pad_shape': (side, side, 3),
        'img_shape': torch.Tensor([side, side])
    }]

    post_processing = dict(
        score_threshold=0.05,
        iou_threshold=0.5,
        max_output_boxes_per_class=200,
        pre_top_k=5000,
        keep_top_k=100,
        background_label_id=-1,
    )
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(output_names=['dets'], input_shape=None),
            codebase_config=dict(
                type='mmdet',
                task='ObjectDetection',
                model_type='ncnn_end2end',
                post_processing=post_processing)))

    # Feature-map sides 32, 16, 8, 4, 2 -- one per level.
    level_sizes = [2**i for i in range(5, 0, -1)]
    seed_everything(1234)
    cls_score = [torch.rand(1, 3, size, size) for size in level_sizes]
    seed_everything(5678)
    bboxes = [torch.rand(1, 16, size, size) for size in level_sizes]

    # to get outputs of onnx model after rewrite
    wrapped_model = WrapModel(
        gfl_head, 'get_bboxes', img_metas=img_metas, with_nms=True)
    rewrite_inputs = dict(cls_scores=cls_score, bbox_preds=bboxes)
    # do not run with ncnn backend
    run_with_backend = backend_type not in [Backend.NCNN]
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=rewrite_inputs,
        deploy_cfg=deploy_cfg,
        run_with_backend=run_with_backend)
    assert rewrite_outputs is not None


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_forward_of_gfl_head(backend_type):
    """Compare ``GFLHead.forward`` with its rewritten/exported outputs."""
    check_backend(backend_type)
    gfl_head = get_gfl_head_model()
    gfl_head.cpu().eval()
    deploy_cfg = mmcv.Config(
        dict(
            backend_config=dict(type=backend_type.value),
            onnx_config=dict(input_shape=None)))

    # One 256-channel feature map per level, sides 32, 16, 8, 4, 2.
    feats = [torch.rand(1, 256, 2**i, 2**i) for i in range(5, 0, -1)]

    # Reference: flatten the first two forward outputs into a single list.
    head_outputs = gfl_head.forward(feats)
    expected_outputs = [*head_outputs[0], *head_outputs[1]]

    wrapped_model = WrapModel(gfl_head, 'forward')
    rewrite_outputs, _ = get_rewrite_outputs(
        wrapped_model=wrapped_model,
        model_inputs=dict(feats=feats),
        deploy_cfg=deploy_cfg)

    for expected, actual in zip(expected_outputs, rewrite_outputs[0]):
        expected = expected.squeeze().cpu().numpy()
        actual = actual.squeeze()
        assert np.allclose(expected, actual, rtol=1e-03, atol=1e-05)


def _replace_r50_with_r18(model):
"""Replace ResNet50 with ResNet18 in config."""
model = copy.deepcopy(model)
Expand Down

0 comments on commit ba5351e

Please sign in to comment.