Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Update pad logic in detection heads #168

Merged
merged 5 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions mmdeploy/codebase/mmdet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .core import * # noqa: F401,F403
from .deploy import (MMDetection, ObjectDetection, clip_bboxes,
get_post_processing_params, pad_with_value)
get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)
from .models import * # noqa: F401,F403

__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
]
5 changes: 3 additions & 2 deletions mmdeploy/codebase/mmdet/deploy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmdetection import MMDetection
from .object_detection import ObjectDetection
from .utils import clip_bboxes, get_post_processing_params, pad_with_value
from .utils import (clip_bboxes, get_post_processing_params, pad_with_value,
pad_with_value_if_necessary)

__all__ = [
'get_post_processing_params', 'clip_bboxes', 'pad_with_value',
'MMDetection', 'ObjectDetection'
'pad_with_value_if_necessary', 'MMDetection', 'ObjectDetection'
]
61 changes: 60 additions & 1 deletion mmdeploy/codebase/mmdet/deploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.core.rewriters.rewriter_utils import LibVersionChecker
from mmdeploy.utils import load_config
from mmdeploy.utils import Backend, load_config


def get_post_processing_params(deploy_cfg: Union[str, mmcv.Config]):
Expand Down Expand Up @@ -127,3 +127,62 @@ def pad_with_value(x: Tensor,
x_pad = x_pad.repeat(*repeat_size)
x = torch.cat([x, x_pad], dim=pad_dim)
return x


def pad_with_value_if_necessary(x: Tensor,
                                pad_dim: int,
                                pad_size: int,
                                pad_value: Optional[Any] = None):
    """Public entry point for padding that only happens when required.

    All arguments are forwarded to ``__pad_with_value_if_necessary``. That
    function returns ``x`` untouched by default, and a rewriter registered
    in this module for the TensorRT backend replaces it with a call to
    ``pad_with_value``.

    Args:
        x (Tensor): Tensor that may need padding.
        pad_dim (int): Dimension along which padding is applied.
        pad_size (int): Size to pad to.
        pad_value (Any): Value used to fill the padded region. Defaults to
            `None`.

    Returns:
        Tensor: Whatever the (possibly rewritten) private function returns.
    """
    padded = __pad_with_value_if_necessary(
        x, pad_dim, pad_size=pad_size, pad_value=pad_value)
    return padded


def __pad_with_value_if_necessary(x: Tensor,
                                  pad_dim: int,
                                  pad_size: int,
                                  pad_value: Optional[Any] = None):
    """Default implementation of the conditional pad: a no-op.

    The padding arguments are accepted only so that the signature matches
    the backend-specific rewrite registered against this function; none of
    them is read here.

    Args:
        x (Tensor): Input tensor.
        pad_dim (int): Dimension along which padding would be applied.
        pad_size (int): Size that padding would reach.
        pad_value (Any): Fill value for padding. Defaults to `None`.

    Returns:
        Tensor: The input ``x`` itself, unchanged.
    """
    return x


@FUNCTION_REWRITER.register_rewriter(
    'mmdeploy.codebase.mmdet.deploy.utils.__pad_with_value_if_necessary',
    backend=Backend.TENSORRT.value)
def __pad_with_value_if_necessary__tensorrt(ctx,
                                            x: Tensor,
                                            pad_dim: int,
                                            pad_size: int,
                                            pad_value: Optional[Any] = None):
    """TensorRT rewrite of ``__pad_with_value_if_necessary``.

    Registered for the TensorRT backend only. Instead of returning the input
    untouched, it delegates to ``pad_with_value`` so the tensor is actually
    padded.

    Args:
        ctx: First argument of the rewrite; not used in this function
            (presumably the rewriter context - confirm against
            FUNCTION_REWRITER).
        x (Tensor): Input tensor.
        pad_dim (int): Dimension along which to pad.
        pad_size (int): Size to pad to.
        pad_value (Any): Fill value for padding. Defaults to `None`.

    Returns:
        Tensor: Result of ``pad_with_value``.
    """
    padded = pad_with_value(
        x, pad_dim, pad_size=pad_size, pad_value=pad_value)
    return padded
26 changes: 12 additions & 14 deletions mmdeploy/codebase/mmdet/models/dense_heads/base_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from mmdet.core.bbox.transforms import distance2bbox

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.codebase.mmdet.core.ops import ncnn_detection_output_forward
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -60,7 +61,6 @@ def base_dense_head__get_bbox(ctx,
"""
deploy_cfg = ctx.cfg
is_dynamic_flag = is_dynamic_shape(deploy_cfg)
backend = get_backend(deploy_cfg)
num_levels = len(cls_scores)

featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
Expand Down Expand Up @@ -98,10 +98,8 @@ def base_dense_head__get_bbox(ctx,
self.cls_out_channels)
if self.use_sigmoid_cls:
scores = scores.sigmoid()
nms_pre_score = scores
else:
scores = scores.softmax(-1)
nms_pre_score = scores
if with_score_factors:
score_factors = score_factors.permute(0, 2, 3,
1).reshape(batch_size,
Expand All @@ -112,16 +110,16 @@ def base_dense_head__get_bbox(ctx,
priors = priors.data
priors = priors.expand(batch_size, -1, priors.size(-1))
if pre_topk > 0:
priors = pad_with_value_if_necessary(priors, 1, pre_topk)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value_if_necessary(
score_factors, 1, pre_topk, 0.)

nms_pre_score = scores
if with_score_factors:
nms_pre_score = nms_pre_score * score_factors
if backend == Backend.TENSORRT:
priors = pad_with_value(priors, 1, pre_topk)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
scores = pad_with_value(scores, 1, pre_topk, 0.)
nms_pre_score = pad_with_value(nms_pre_score, 1, pre_topk, 0.)
if with_score_factors:
score_factors = pad_with_value(score_factors, 1, pre_topk,
0.)

# Get maximum scores for foreground classes.
if self.use_sigmoid_cls:
Expand Down Expand Up @@ -180,7 +178,7 @@ def base_dense_head__get_bbox(ctx,
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.base_dense_head.BaseDenseHead'
'.get_bboxes',
backend='ncnn')
backend=Backend.NCNN.value)
def base_dense_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
Expand Down
15 changes: 7 additions & 8 deletions mmdeploy/codebase/mmdet/models/dense_heads/rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -95,13 +96,11 @@ def rpn_head__get_bboxes(ctx,

anchors = anchors.expand_as(bbox_pred)

backend = get_backend(deploy_cfg)
# topk in TensorRT does not support shape < k,
# so concatenate zeros to enable topk.
if backend == Backend.TENSORRT:
scores = pad_with_value(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
anchors = pad_with_value(anchors, 1, pre_topk)
scores = pad_with_value_if_necessary(scores, 1, pre_topk, 0.)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
anchors = pad_with_value_if_necessary(anchors, 1, pre_topk)

if pre_topk > 0:
_, topk_inds = scores.squeeze(2).topk(pre_topk)
Expand Down Expand Up @@ -145,7 +144,7 @@ def rpn_head__get_bboxes(ctx,


@FUNCTION_REWRITER.register_rewriter(
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend='ncnn')
'mmdet.models.dense_heads.RPNHead.get_bboxes', backend=Backend.NCNN.value)
def rpn_head__get_bboxes__ncnn(ctx,
self,
cls_scores,
Expand Down
16 changes: 8 additions & 8 deletions mmdeploy/codebase/mmdet/models/dense_heads/yolo_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import torch

from mmdeploy.codebase.mmdet import (get_post_processing_params,
multiclass_nms, pad_with_value)
multiclass_nms,
pad_with_value_if_necessary)
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils import Backend, get_backend, is_dynamic_shape
from mmdeploy.utils import Backend, is_dynamic_shape


@FUNCTION_REWRITER.register_rewriter(
Expand Down Expand Up @@ -90,13 +91,11 @@ def yolov3_head__get_bboxes(ctx,
conf_pred = torch.sigmoid(pred_map[..., 4])
cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
batch_size, -1, self.num_classes) # Cls pred one-hot.
backend = get_backend(ctx.cfg)
# topk in TensorRT does not support shape < k,
# so concatenate zeros to enable topk.
if backend == Backend.TENSORRT:
bbox_pred = pad_with_value(bbox_pred, 1, pre_topk)
conf_pred = pad_with_value(conf_pred, 1, pre_topk, 0.)
cls_pred = pad_with_value(cls_pred, 1, pre_topk, 0.)
bbox_pred = pad_with_value_if_necessary(bbox_pred, 1, pre_topk)
conf_pred = pad_with_value_if_necessary(conf_pred, 1, pre_topk, 0.)
cls_pred = pad_with_value_if_necessary(cls_pred, 1, pre_topk, 0.)

if pre_topk > 0:
_, topk_inds = conf_pred.topk(pre_topk)
Expand Down Expand Up @@ -161,7 +160,8 @@ def yolov3_head__get_bboxes(ctx,


@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes', backend='ncnn')
func_name='mmdet.models.dense_heads.YOLOV3Head.get_bboxes',
backend=Backend.NCNN.value)
def yolov3_head__get_bboxes__ncnn(ctx,
self,
pred_maps,
Expand Down
12 changes: 11 additions & 1 deletion tests/test_codebase/test_mmdet/test_mmdet_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from mmdeploy.codebase import import_codebase
from mmdeploy.codebase.mmdet import (clip_bboxes, get_post_processing_params,
pad_with_value)
pad_with_value,
pad_with_value_if_necessary)
from mmdeploy.utils import Codebase

import_codebase(Codebase.MMDET)
Expand All @@ -29,6 +30,15 @@ def test_pad_with_value():
assert np.allclose(padded_x.sum(), x.sum(), rtol=1e-03, atol=1e-05)


def test_pad_with_value_if_necessary():
    """Check that the call leaves the tensor's shape and sum unchanged."""
    tolerance = dict(rtol=1e-03, atol=1e-05)
    x = torch.rand(3, 2)
    result = pad_with_value_if_necessary(
        x, pad_dim=1, pad_size=4, pad_value=0)
    # The shape is expected to stay (3, 2) even though pad_size=4 was passed.
    assert np.allclose(result.shape, torch.Size([3, 2]), **tolerance)
    # The element sum must match the input's sum.
    assert np.allclose(result.sum(), x.sum(), **tolerance)


config_with_mmdet_params = mmcv.Config(
dict(
codebase_config=dict(
Expand Down