From 47964603379d0613606448936838bc8f0f0e15cb Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 29 Nov 2022 18:41:32 +0800 Subject: [PATCH 01/17] [Feature] Support RTMDet instance segmentation model. --- .../rtmdet/rtmdet-ins_l_8xb32-300e_coco.py | 102 ++ .../rtmdet/rtmdet-ins_m_8xb32-300e_coco.py | 6 + .../rtmdet/rtmdet-ins_s_8xb32-300e_coco.py | 84 ++ .../rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py | 50 + configs/rtmdet/rtmdet_l_8xb32-300e_coco.py | 4 +- mmdet/datasets/transforms/transforms.py | 46 +- mmdet/models/dense_heads/__init__.py | 3 +- mmdet/models/dense_heads/rtmdet_head.py | 9 +- mmdet/models/dense_heads/rtmdet_ins_head.py | 955 ++++++++++++++++++ .../assigners/dynamic_soft_label_assigner.py | 16 +- .../models/task_modules/samplers/__init__.py | 4 +- .../samplers/mask_box_pseudo_sampler.py | 140 +++ 12 files changed, 1409 insertions(+), 10 deletions(-) create mode 100644 configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py create mode 100644 configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py create mode 100644 configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py create mode 100644 configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py create mode 100644 mmdet/models/dense_heads/rtmdet_ins_head.py create mode 100644 mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py diff --git a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py new file mode 100644 index 00000000000..404061f042f --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py @@ -0,0 +1,102 @@ +_base_ = './rtmdet_l_8xb32-300e_coco.py' +model = dict( + bbox_head=dict( + _delete_=True, + type='RTMDetInsSepBNHead', + num_classes=80, + in_channels=256, + stacked_convs=2, + share_conv=True, + pred_kernel_size=1, + feat_channels=256, + act_cfg=dict(type='SiLU'), + norm_cfg=dict(type='SyncBN', requires_grad=True), + anchor_generator=dict( + type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]), + bbox_coder=dict(type='DistancePointBBoxCoder'), + loss_cls=dict( + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='GIoULoss', loss_weight=2.0), + loss_mask=dict( + type='DiceLoss', loss_weight=1.0, eps=5e-6, reduction='mean')), + test_cfg=dict(mask_thr_binary=0.5), +) + +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict( + type='RandomCrop', + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='PackDetInputs') +] + +train_dataloader = dict(pin_memory=True, dataset=dict(pipeline=train_pipeline)) + +train_pipeline_stage2 = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict( + type='RandomResize', + scale=(640, 640), + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict( + type='RandomCrop', + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type='PackDetInputs') +] +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=280, + switch_pipeline=train_pipeline_stage2) +] + +val_evaluator = dict(metric=['bbox', 'segm']) +test_evaluator = val_evaluator diff --git a/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py new file mode 100644 index 00000000000..66da9148775 --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py @@ -0,0 +1,6 @@ +_base_ = './rtmdet-ins_l_8xb32-300e_coco.py' + +model = dict( + backbone=dict(deepen_factor=0.67, widen_factor=0.75), + neck=dict(in_channels=[192, 384, 768], out_channels=192, num_csp_blocks=2), + bbox_head=dict(in_channels=192, feat_channels=192)) diff --git a/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py new file mode 100644 index 00000000000..7785f2ff208 --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py @@ -0,0 +1,84 @@ +_base_ = './rtmdet-ins_l_8xb32-300e_coco.py' +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-s_imagenet_600e.pth' # noqa +model = dict( + backbone=dict( + deepen_factor=0.33, + widen_factor=0.5, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict(in_channels=[128, 256, 512], out_channels=128, num_csp_blocks=1), + bbox_head=dict(in_channels=128, feat_channels=128)) + +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict(type='CachedMosaic', img_scale=(640, 640), pad_val=114.0), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict( + type='RandomCrop', + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='PackDetInputs') +] + +train_pipeline_stage2 = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict( + type='RandomResize', + scale=(640, 640), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict( + type='RandomCrop', + crop_size=(640, 640), + recompute_bbox=True, + allow_negative_crop=True), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict(type='PackDetInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=280, + switch_pipeline=train_pipeline_stage2) +] diff --git a/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py new file mode 100644 index 00000000000..33b62878027 --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py @@ -0,0 +1,50 @@ +_base_ = './rtmdet-ins_s_8xb32-300e_coco.py' + +checkpoint = 'https://download.openmmlab.com/mmdetection/v3.0/rtmdet/cspnext_rsb_pretrain/cspnext-tiny_imagenet_600e.pth' # noqa + +model = dict( + backbone=dict( + deepen_factor=0.167, + widen_factor=0.375, + init_cfg=dict( + type='Pretrained', prefix='backbone.', checkpoint=checkpoint)), + neck=dict(in_channels=[96, 192, 384], out_channels=96, num_csp_blocks=1), + bbox_head=dict(in_channels=96, feat_channels=96)) + +train_pipeline = [ + dict( + type='LoadImageFromFile', + file_client_args={{_base_.file_client_args}}), + dict( + type='LoadAnnotations', + with_bbox=True, + with_mask=True, + poly2mask=False), + dict( + type='CachedMosaic', + img_scale=(640, 640), + pad_val=114.0, + max_cached_images=20, + random_pop=False), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.5, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(640, 640)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(640, 640), + ratio_range=(1.0, 1.0), + max_cached_images=10, + random_pop=False, + pad_val=(114, 114, 114), + prob=0.5), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1)), + dict(type='PackDetInputs') +] + +train_dataloader = dict(dataset=dict(pipeline=train_pipeline)) diff --git a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py index 33ccb839c6c..85c66130178 100644 --- a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py @@ -115,8 +115,8 @@ ] train_dataloader = dict( - batch_size=32, - num_workers=10, + batch_size=2, + num_workers=1, batch_sampler=None, pin_memory=True, dataset=dict(pipeline=train_pipeline)) diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index 7ae88dcc568..791f36caffb 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -3244,6 +3244,9 @@ def transform(self, results: dict) -> dict: mosaic_bboxes = [] mosaic_bboxes_labels = [] mosaic_ignore_flags = [] + mosaic_masks = [] + with_mask = True if 'gt_masks' in results else False + if len(results['img'].shape) == 3: mosaic_img = np.full( (int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3), @@ -3298,6 +3301,20 @@ def transform(self, results: dict) -> dict: mosaic_bboxes.append(gt_bboxes_i) mosaic_bboxes_labels.append(gt_bboxes_labels_i) mosaic_ignore_flags.append(gt_ignore_flags_i) + if with_mask and results_patch.get('gt_masks', None) is not None: + gt_masks_i = results_patch['gt_masks'] + gt_masks_i = gt_masks_i.rescale(float(scale_ratio_i)) + gt_masks_i = gt_masks_i.translate( + out_shape=(int(self.img_scale[0] * 2), + int(self.img_scale[1] * 2)), + offset=padw, + direction='horizontal') + gt_masks_i = gt_masks_i.translate( + out_shape=(int(self.img_scale[0] * 2), + int(self.img_scale[1] * 2)), + offset=padh, + direction='vertical') + mosaic_masks.append(gt_masks_i) mosaic_bboxes = mosaic_bboxes[0].cat(mosaic_bboxes, 0) mosaic_bboxes_labels = np.concatenate(mosaic_bboxes_labels, 0) @@ -3317,6 +3334,10 @@ def transform(self, results: dict) -> dict: results['gt_bboxes'] = mosaic_bboxes results['gt_bboxes_labels'] = mosaic_bboxes_labels results['gt_ignore_flags'] = mosaic_ignore_flags + + if with_mask: + mosaic_masks = mosaic_masks[0].cat(mosaic_masks) + results['gt_masks'] = mosaic_masks[inside_inds] return results def __repr__(self): @@ -3481,6 +3502,7 @@ def transform(self, results: dict) -> dict: return results retrieve_img = retrieve_results['img'] + with_mask = True if 'gt_masks' in results else False jit_factor = random.uniform(*self.ratio_range) is_filp = random.uniform(0, 1) > self.flip_ratio @@ -3532,16 +3554,32 @@ def transform(self, results: dict) -> dict: # 6. adjust bbox retrieve_gt_bboxes = retrieve_results['gt_bboxes'] retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio]) + if with_mask: + retrieve_gt_masks: BitmapMasks = retrieve_results[ + 'gt_masks'].rescale(scale_ratio) + if self.bbox_clip_border: retrieve_gt_bboxes.clip_([origin_h, origin_w]) if is_filp: retrieve_gt_bboxes.flip_([origin_h, origin_w], direction='horizontal') + if with_mask: + retrieve_gt_masks = retrieve_gt_masks.flip() # 7. filter cp_retrieve_gt_bboxes = retrieve_gt_bboxes.clone() cp_retrieve_gt_bboxes.translate_([-x_offset, -y_offset]) + if with_mask: + retrieve_gt_masks = retrieve_gt_masks.translate( + out_shape=(target_h, target_w), + offset=-x_offset, + direction='horizontal') + retrieve_gt_masks = retrieve_gt_masks.translate( + out_shape=(target_h, target_w), + offset=-y_offset, + direction='vertical') + if self.bbox_clip_border: cp_retrieve_gt_bboxes.clip_([target_h, target_w]) @@ -3558,19 +3596,25 @@ def transform(self, results: dict) -> dict: (results['gt_bboxes_labels'], retrieve_gt_bboxes_labels), axis=0) mixup_gt_ignore_flags = np.concatenate( (results['gt_ignore_flags'], retrieve_gt_ignore_flags), axis=0) + if with_mask: + mixup_gt_masks = retrieve_gt_masks.cat( + [results['gt_masks'], retrieve_gt_masks]) # remove outside bbox inside_inds = mixup_gt_bboxes.is_inside([target_h, target_w]).numpy() mixup_gt_bboxes = mixup_gt_bboxes[inside_inds] mixup_gt_bboxes_labels = mixup_gt_bboxes_labels[inside_inds] mixup_gt_ignore_flags = mixup_gt_ignore_flags[inside_inds] + if with_mask: + mixup_gt_masks = mixup_gt_masks[inside_inds] results['img'] = mixup_img.astype(np.uint8) results['img_shape'] = mixup_img.shape results['gt_bboxes'] = mixup_gt_bboxes results['gt_bboxes_labels'] = mixup_gt_bboxes_labels results['gt_ignore_flags'] = mixup_gt_ignore_flags - + if with_mask: + results['gt_masks'] = mixup_gt_masks return results def __repr__(self): diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py index 0ab3ba2018e..469f5cc69d8 100644 --- a/mmdet/models/dense_heads/__init__.py +++ b/mmdet/models/dense_heads/__init__.py @@ -34,6 +34,7 @@ from .retina_sepbn_head import RetinaSepBNHead from .rpn_head import RPNHead from .rtmdet_head import RTMDetHead, RTMDetSepBNHead +from .rtmdet_ins_head import RTMDetInsHead, RTMDetInsSepBNHead from .sabl_retina_head import SABLRetinaHead from .solo_head import DecoupledSOLOHead, DecoupledSOLOLightHead, SOLOHead from .solov2_head import SOLOV2Head @@ -58,5 +59,5 @@ 'DecoupledSOLOHead', 'DecoupledSOLOLightHead', 'SOLOV2Head', 'LADHead', 'TOODHead', 'MaskFormerHead', 'Mask2FormerHead', 'DDODHead', 'CenterNetUpdateHead', 'RTMDetHead', 'RTMDetSepBNHead', 'CondInstBboxHead', - 'CondInstMaskHead' + 'CondInstMaskHead', 'RTMDetInsHead', 'RTMDetInsSepBNHead' ] diff --git a/mmdet/models/dense_heads/rtmdet_head.py b/mmdet/models/dense_heads/rtmdet_head.py index 42c15c1f6dd..3c53b68669d 100644 --- a/mmdet/models/dense_heads/rtmdet_head.py +++ b/mmdet/models/dense_heads/rtmdet_head.py @@ -266,7 +266,7 @@ def loss_by_feat(self, batch_img_metas, batch_gt_instances_ignore=batch_gt_instances_ignore) (anchor_list, labels_list, label_weights_list, bbox_targets_list, - assign_metrics_list) = cls_reg_targets + assign_metrics_list, sampling_results_list) = cls_reg_targets losses_cls, losses_bbox,\ cls_avg_factors, bbox_avg_factors = multi_apply( @@ -353,7 +353,7 @@ def get_targets(self, batch_gt_instances_ignore = [None] * num_imgs # anchor_list: list(b * [-1, 4]) (all_anchors, all_labels, all_label_weights, all_bbox_targets, - all_assign_metrics) = multi_apply( + all_assign_metrics, sampling_results_list) = multi_apply( self._get_targets_single, cls_scores.detach(), bbox_preds.detach(), @@ -378,7 +378,7 @@ def get_targets(self, num_level_anchors) return (anchors_list, labels_list, label_weights_list, - bbox_targets_list, assign_metrics_list) + bbox_targets_list, assign_metrics_list, sampling_results_list) def _get_targets_single(self, cls_scores: Tensor, @@ -486,7 +486,8 @@ def _get_targets_single(self, bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags) assign_metrics = unmap(assign_metrics, num_total_anchors, inside_flags) - return (anchors, labels, label_weights, bbox_targets, assign_metrics) + return (anchors, labels, label_weights, bbox_targets, assign_metrics, + sampling_result) def get_anchors(self, featmap_sizes: List[tuple], diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py new file mode 100644 index 00000000000..b6f66024b92 --- /dev/null +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -0,0 +1,955 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule, is_norm +from mmcv.ops import batched_nms +from mmengine.model import (BaseModule, bias_init_with_prob, constant_init, + normal_init) +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.models.utils import (filter_scores_and_topk, multi_apply, + select_single_mlvl, sigmoid_geometric_mean) +from mmdet.registry import MODELS +from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor, + get_box_wh, scale_boxes) +from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean +from ..task_modules.samplers import MaskBoxPseudoSampler +from .rtmdet_head import RTMDetHead + + +@MODELS.register_module() +class RTMDetInsHead(RTMDetHead): + """Detection Head of RTMDet-Ins. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + with_objectness (bool): Whether to add an objectness branch. + Defaults to True. + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. + Default: dict(type='ReLU') + """ + + def __init__(self, + *args, + num_prototypes: int = 8, + dyconv_channels: int = 8, + num_dyconvs: int = 3, + mask_loss_stride: int = 4, + use_condinst_coord=True, + loss_mask=dict( + type='DiceLoss', + loss_weight=1.0, + eps=5e-6, + reduction='mean'), + **kwargs) -> None: + self.num_prototypes = num_prototypes + self.num_dyconvs = num_dyconvs + self.dyconv_channels = dyconv_channels + self.mask_loss_stride = mask_loss_stride + self.use_condinst_coord = use_condinst_coord + super().__init__(*args, **kwargs) + self.loss_mask = MODELS.build(loss_mask) + self.sampler = MaskBoxPseudoSampler() + + def _init_layers(self): + """Initialize layers of the head.""" + super()._init_layers() + self.kernel_convs = nn.ModuleList() + # calculate num dynamic parameters + weight_nums, bias_nums = [], [] + for i in range(self.num_dyconvs): + if i == 0: + weight_nums.append( + (self.num_prototypes + 2) * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + elif i == self.num_dyconvs - 1: + weight_nums.append(self.dyconv_channels) + bias_nums.append(1) + else: + weight_nums.append(self.dyconv_channels * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + self.kernel_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + pred_pad_size = self.pred_kernel_size // 2 + self.rtm_kernel = nn.Conv2d( + self.feat_channels, + self.num_gen_params, + self.pred_kernel_size, + padding=pred_pad_size) + self.mask_head = MaskFeatModule( + in_channels=self.in_channels, + feat_channels=self.feat_channels, + stacked_convs=4, + num_levels=len(self.prior_generator.strides), + num_prototypes=self.num_prototypes, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + - cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + """ + mask_feat = self.mask_head(feats) + + cls_scores = [] + bbox_preds = [] + kernel_preds = [] + for idx, (x, scale, stride) in enumerate( + zip(feats, self.scales, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + kernel_feat = x + + for cls_layer in self.cls_convs: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls(cls_feat) + + for kernel_layer in self.kernel_convs: + kernel_feat = kernel_layer(kernel_feat) + kernel_pred = self.rtm_kernel(kernel_feat) + + for reg_layer in self.reg_convs: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj(reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + + reg_dist = scale(self.rtm_reg(reg_feat)) * stride[0] + + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + kernel_preds.append(kernel_pred) + return tuple(cls_scores), tuple(bbox_preds), tuple( + kernel_preds), mask_feat + + def predict_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + kernel_preds: List[Tensor], + mask_feat: Tensor, + score_factors: Optional[List[Tensor]] = None, + batch_img_metas: Optional[List[dict]] = None, + cfg: Optional[ConfigType] = None, + rescale: bool = False, + with_nms: bool = True) -> InstanceList: + """Transform a batch of output features extracted from the head into + bbox results. + + Note: When score_factors is not None, the cls_scores are + usually multiplied by it then obtain the real score used in NMS, + such as CenterNess in FCOS, IoU branch in ATSS. + + Args: + cls_scores (list[Tensor]): Classification scores for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * num_classes, H, W). + bbox_preds (list[Tensor]): Box energies / deltas for all + scale levels, each is a 4D-tensor, has shape + (batch_size, num_priors * 4, H, W). + score_factors (list[Tensor], optional): Score factor for + all scale level, each is a 4D-tensor, has shape + (batch_size, num_priors * 1, H, W). Defaults to None. + batch_img_metas (list[dict], Optional): Batch image meta info. + Defaults to None. + cfg (ConfigDict, optional): Test / postprocessing + configuration, if None, test_cfg would be used. + Defaults to None. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + list[:obj:`InstanceData`]: Object detection results of each image + after the post process. Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert len(cls_scores) == len(bbox_preds) + + if score_factors is None: + # e.g. Retina, FreeAnchor, Foveabox, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, AutoAssign, etc. + with_score_factors = True + assert len(cls_scores) == len(score_factors) + + num_levels = len(cls_scores) + + featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)] + mlvl_priors = self.prior_generator.grid_priors( + featmap_sizes, + dtype=cls_scores[0].dtype, + device=cls_scores[0].device, + with_stride=True) + + result_list = [] + + for img_id in range(len(batch_img_metas)): + img_meta = batch_img_metas[img_id] + cls_score_list = select_single_mlvl( + cls_scores, img_id, detach=True) + bbox_pred_list = select_single_mlvl( + bbox_preds, img_id, detach=True) + kernel_pred_list = select_single_mlvl( + kernel_preds, img_id, detach=True) + if with_score_factors: + score_factor_list = select_single_mlvl( + score_factors, img_id, detach=True) + else: + score_factor_list = [None for _ in range(num_levels)] + + results = self._predict_by_feat_single( + cls_score_list=cls_score_list, + bbox_pred_list=bbox_pred_list, + kernel_pred_list=kernel_pred_list, + mask_feat=mask_feat[img_id], + score_factor_list=score_factor_list, + mlvl_priors=mlvl_priors, + img_meta=img_meta, + cfg=cfg, + rescale=rescale, + with_nms=with_nms) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score_list: List[Tensor], + bbox_pred_list: List[Tensor], + kernel_pred_list, + mask_feat, + score_factor_list: List[Tensor], + mlvl_priors: List[Tensor], + img_meta: dict, + cfg: ConfigType, + rescale: bool = False, + with_nms: bool = True) -> InstanceData: + """Transform a single image's features extracted from the head into + bbox results. + + Args: + cls_score_list (list[Tensor]): Box scores from all scale + levels of a single image, each item has shape + (num_priors * num_classes, H, W). + bbox_pred_list (list[Tensor]): Box energies / deltas from + all scale levels of a single image, each item has shape + (num_priors * 4, H, W). + score_factor_list (list[Tensor]): Score factor from all scale + levels of a single image, each item has shape + (num_priors * 1, H, W). + mlvl_priors (list[Tensor]): Each element in the list is + the priors of a single level in feature pyramid. In all + anchor-based methods, it has shape (num_priors, 4). In + all anchor-free methods, it has shape (num_priors, 2) + when `with_stride=True`, otherwise it still has shape + (num_priors, 4). + img_meta (dict): Image meta info. + cfg (mmengine.Config): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Defaults to False. + with_nms (bool): If True, do nms before return boxes. + Defaults to True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + if score_factor_list[0] is None: + # e.g. Retina, FreeAnchor, etc. + with_score_factors = False + else: + # e.g. FCOS, PAA, ATSS, etc. + with_score_factors = True + + cfg = self.test_cfg if cfg is None else cfg + cfg = copy.deepcopy(cfg) + img_shape = img_meta['img_shape'] + nms_pre = cfg.get('nms_pre', -1) + + mlvl_bbox_preds = [] + mlvl_kernels = [] + mlvl_valid_priors = [] + mlvl_scores = [] + mlvl_labels = [] + if with_score_factors: + mlvl_score_factors = [] + else: + mlvl_score_factors = None + + for level_idx, (cls_score, bbox_pred, kernel_pred, + score_factor, priors) in \ + enumerate(zip(cls_score_list, bbox_pred_list, kernel_pred_list, + score_factor_list, mlvl_priors)): + + assert cls_score.size()[-2:] == bbox_pred.size()[-2:] + + dim = self.bbox_coder.encode_size + bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, dim) + if with_score_factors: + score_factor = score_factor.permute(1, 2, + 0).reshape(-1).sigmoid() + cls_score = cls_score.permute(1, 2, + 0).reshape(-1, self.cls_out_channels) + kernel_pred = kernel_pred.permute(1, 2, 0).reshape( + -1, self.num_gen_params) + if self.use_sigmoid_cls: + scores = cls_score.sigmoid() + else: + # remind that we set FG labels to [0, num_class-1] + # since mmdet v2.0 + # BG cat_id: num_class + scores = cls_score.softmax(-1)[:, :-1] + + # After https://github.com/open-mmlab/mmdetection/pull/6268/, + # this operation keeps fewer bboxes under the same `nms_pre`. + # There is no difference in performance for most models. If you + # find a slight drop in performance, you can set a larger + # `nms_pre` than before. + score_thr = cfg.get('score_thr', 0) + + results = filter_scores_and_topk( + scores, score_thr, nms_pre, + dict( + bbox_pred=bbox_pred, + priors=priors, + kernel_pred=kernel_pred)) + scores, labels, keep_idxs, filtered_results = results + + bbox_pred = filtered_results['bbox_pred'] + priors = filtered_results['priors'] + kernel_pred = filtered_results['kernel_pred'] + + if with_score_factors: + score_factor = score_factor[keep_idxs] + + mlvl_bbox_preds.append(bbox_pred) + mlvl_valid_priors.append(priors) + mlvl_scores.append(scores) + mlvl_labels.append(labels) + mlvl_kernels.append(kernel_pred) + + if with_score_factors: + mlvl_score_factors.append(score_factor) + + bbox_pred = torch.cat(mlvl_bbox_preds) + priors = cat_boxes(mlvl_valid_priors) + bboxes = self.bbox_coder.decode( + priors[..., :2], bbox_pred, max_shape=img_shape) + + results = InstanceData() + results.bboxes = bboxes + results.priors = priors + results.scores = torch.cat(mlvl_scores) + results.labels = torch.cat(mlvl_labels) + results.kernels = torch.cat(mlvl_kernels) + if with_score_factors: + results.score_factors = torch.cat(mlvl_score_factors) + + return self._bbox_post_process( + results=results, + mask_feat=mask_feat, + cfg=cfg, + rescale=rescale, + with_nms=with_nms, + img_meta=img_meta) + + def _bbox_post_process(self, + results: InstanceData, + mask_feat, + cfg: ConfigType, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None) -> InstanceData: + """bbox post-processing method. + + The boxes would be rescaled to the original image scale and do + the nms operation. Usually `with_nms` is False is used for aug test. + + Args: + results (:obj:`InstaceData`): Detection instance results, + each item has shape (num_bboxes, ). + cfg (ConfigDict): Test / postprocessing configuration, + if None, test_cfg would be used. + rescale (bool): If True, return boxes in original image space. + Default to False. + with_nms (bool): If True, do nms before return boxes. + Default to True. + img_meta (dict, optional): Image meta info. Defaults to None. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + stride = self.prior_generator.strides[0][0] + if rescale: + assert img_meta.get('scale_factor') is not None + scale_factor = [1 / s for s in img_meta['scale_factor']] + results.bboxes = scale_boxes(results.bboxes, scale_factor) + + if hasattr(results, 'score_factors'): + # TODO: Add sqrt operation in order to be consistent with + # the paper. + score_factors = results.pop('score_factors') + results.scores = results.scores * score_factors + + # filter small size bboxes + if cfg.get('min_bbox_size', -1) >= 0: + w, h = get_box_wh(results.bboxes) + valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size) + if not valid_mask.all(): + results = results[valid_mask] + + # TODO: deal with `with_nms` and `nms_cfg=None` in test_cfg + assert with_nms, 'with_nms must be True for RTMDet-Ins' + if results.bboxes.numel() > 0: + bboxes = get_box_tensor(results.bboxes) + det_bboxes, keep_idxs = batched_nms(bboxes, results.scores, + results.labels, cfg.nms) + results = results[keep_idxs] + # some nms would reweight the score, such as softnms + results.scores = det_bboxes[:, -1] + results = results[:cfg.max_per_img] + + # process masks + h, w = img_meta['img_shape'][:2] + + mask_logits = self._mask_predict_by_feat_single( + mask_feat, results.kernels, results.priors) + + mask_logits = F.interpolate( + mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear') + # print('upsampled croped shape ', mask_logits.shape) + if rescale: + ori_h, ori_w = img_meta['ori_shape'][:2] + # print('scale_factor: ',scale_factor) + mask_logits = F.interpolate( + mask_logits, + size=[ + math.ceil(mask_logits.shape[-2] * scale_factor[0]), + math.ceil(mask_logits.shape[-1] * scale_factor[1]) + ], + mode='bilinear', + align_corners=False)[..., :ori_h, :ori_w] + # print('rescale shape ', mask_logits.shape, ori_h, ori_w) + masks = mask_logits.sigmoid().squeeze(0) + masks = masks > cfg.mask_thr_binary + results.masks = masks + else: + h, w = img_meta['ori_shape'][:2] if rescale else img_meta[ + 'img_shape'][:2] + results.masks = torch.zeros( + size=(results.bboxes.shape[0], h, w), + dtype=torch.bool, + device=results.bboxes.device) + + return results + + def parse_dynamic_params(self, flatten_kernels): + """split kernel head prediction to conv weight and bias.""" + n_inst = flatten_kernels.size(0) + n_layers = len(self.weight_nums) + params_splits = list( + torch.split_with_sizes( + flatten_kernels, self.weight_nums + self.bias_nums, dim=1)) + weight_splits = params_splits[:n_layers] + bias_splits = params_splits[n_layers:] + for i in range(n_layers): + if i < n_layers - 1: + weight_splits[i] = weight_splits[i].reshape( + n_inst * self.dyconv_channels, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(n_inst * + self.dyconv_channels) + else: + weight_splits[i] = weight_splits[i].reshape(n_inst, -1, 1, 1) + bias_splits[i] = bias_splits[i].reshape(n_inst) + + return weight_splits, bias_splits + + def _mask_predict_by_feat_single(self, mask_feat, kernels, priors): + num_inst = priors.shape[0] + h, w = mask_feat.size()[-2:] + if num_inst < 1: + return torch.empty( + size=(num_inst, h, w), + dtype=mask_feat.dtype, + device=mask_feat.device) + if len(mask_feat.shape) < 4: + mask_feat.unsqueeze(0) + + coord = self.prior_generator.single_level_grid_priors( + (h, w), level_idx=0).reshape(1, -1, 2) + num_inst = priors.shape[0] + points = priors[:, :2].reshape(-1, 1, 2) + strides = priors[:, 2:].reshape(-1, 1, 2) + relative_coord = (points - coord).permute(0, 2, 1) / ( + strides[..., 0].reshape(-1, 1, 1) * 8) + relative_coord = relative_coord.reshape(num_inst, 2, h, w) + + mask_feat = torch.cat( + [relative_coord, + mask_feat.repeat(num_inst, 1, 1, 1)], dim=1) + weights, biases = self.parse_dynamic_params(kernels) + + n_layers = len(weights) + x = mask_feat.reshape(1, -1, h, w) + for i, (weight, bias) in enumerate(zip(weights, biases)): + x = F.conv2d( + x, weight, bias=bias, stride=1, padding=0, groups=num_inst) + if i < n_layers - 1: + x = F.relu(x) + x = x.reshape(num_inst, h, w) + return x + + def loss_mask_by_feat(self, mask_feats, flatten_kernels, + assign_metrics_list, sampling_results_list): + # import pdb; pdb.set_trace() + batch_pos_mask_logits = [] + pos_gt_masks = [] + for idx, (mask_feat, kernels, sampling_results) in enumerate( + zip(mask_feats, flatten_kernels, sampling_results_list)): + pos_priors = sampling_results.pos_priors + pos_inds = sampling_results.pos_inds + pos_kernels = kernels[pos_inds] # n_pos, num_gen_params + pos_mask_logits = self._mask_predict_by_feat_single( + mask_feat, pos_kernels, pos_priors) + + batch_pos_mask_logits.append(pos_mask_logits) + pos_gt_masks.append(sampling_results.pos_gt_masks) + + pos_gt_masks = torch.cat(pos_gt_masks, 0) + batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0) + + # avg_factor + num_pos = batch_pos_mask_logits.shape[0] + num_pos = reduce_mean(mask_feats.new_tensor([num_pos + ])).clamp_(min=1).item() + + if batch_pos_mask_logits.shape[0] == 0: + return mask_feats.sum() * 0, mask_feats.sum() * 0 + + scale = self.prior_generator.strides[0][0] // self.mask_loss_stride + # upsample pred masks + batch_pos_mask_logits = F.interpolate( + batch_pos_mask_logits.unsqueeze(0), + scale_factor=scale, + mode='bilinear', + align_corners=False).squeeze(0) + # downsample gt masks + pos_gt_masks = pos_gt_masks[:, self.mask_loss_stride // + 2::self.mask_loss_stride, + self.mask_loss_stride // + 2::self.mask_loss_stride] + + loss_mask = self.loss_mask( + batch_pos_mask_logits, + pos_gt_masks, + weight=None, + avg_factor=num_pos) + + return loss_mask + + def loss_by_feat(self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + kernel_preds: List[Tensor], + mask_feat: Tensor, + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None): + """Compute losses of the head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Decoded box for each scale + level with shape (N, num_anchors * 4, H, W) in + [tl_x, tl_y, br_x, br_y] format. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + num_imgs = len(batch_img_metas) + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + flatten_cls_scores = torch.cat([ + cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.cls_out_channels) + for cls_score in cls_scores + ], 1) + flatten_kernels = torch.cat([ + kernel_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, + self.num_gen_params) + for kernel_pred in kernel_preds + ], 1) + decoded_bboxes = [] + for anchor, bbox_pred in zip(anchor_list[0], bbox_preds): + anchor = anchor.reshape(-1, 4) + bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4) + bbox_pred = distance2bbox(anchor, bbox_pred) + decoded_bboxes.append(bbox_pred) + + flatten_bboxes = torch.cat(decoded_bboxes, 1) + for gt_instances in batch_gt_instances: + gt_instances.masks = gt_instances.masks.to_tensor( + dtype=torch.bool, device=device) + + cls_reg_targets = self.get_targets( + flatten_cls_scores, + flatten_bboxes, + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + assign_metrics_list, sampling_results_list) = cls_reg_targets + + losses_cls, losses_bbox,\ + cls_avg_factors, bbox_avg_factors = multi_apply( + self.loss_by_feat_single, + cls_scores, + decoded_bboxes, + labels_list, + label_weights_list, + bbox_targets_list, + assign_metrics_list, + self.prior_generator.strides) + + cls_avg_factor = reduce_mean(sum(cls_avg_factors)).clamp_(min=1).item() + losses_cls = list(map(lambda x: x / cls_avg_factor, losses_cls)) + + bbox_avg_factor = reduce_mean( + sum(bbox_avg_factors)).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + + loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, + assign_metrics_list, + sampling_results_list) + loss = dict( + loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask) + return loss + + +class MaskFeatModule(BaseModule): + + def __init__(self, + in_channels, + feat_channels=256, + stacked_convs=4, + num_levels=3, + num_prototypes=8, + act_cfg=dict(type='SiLU'), + norm_cfg=dict(type='BN')): + super().__init__(init_cfg=None) + self.num_levels = num_levels + self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1) + convs = [] + for i in range(stacked_convs): + in_c = in_channels if i == 0 else feat_channels + convs.append( + ConvModule( + in_c, + feat_channels, + 3, + padding=1, + act_cfg=act_cfg, + norm_cfg=norm_cfg)) + self.stacked_convs = nn.Sequential(*convs) + self.projection = nn.Conv2d( + feat_channels, num_prototypes, kernel_size=1) + + def forward(self, features): + # multi-level feature fusion + fusion_feats = [features[0]] + size = features[0].shape[-2:] + for i in range(1, self.num_levels): + f = F.interpolate(features[i], size=size, mode='bilinear') + fusion_feats.append(f) + fusion_feats = torch.cat(fusion_feats, dim=1) + fusion_feats = self.fusion_conv(fusion_feats) + # pred mask feats + mask_features = self.stacked_convs(fusion_feats) + mask_features = self.projection(mask_features) + return mask_features + + +@MODELS.register_module() +class RTMDetInsSepBNHead(RTMDetInsHead): + """Detection Head of RTMDet-ins-seg with sep-bn layers. + + Args: + num_classes (int): Number of categories excluding the background + category. + in_channels (int): Number of channels in the input feature map. + with_objectness (bool): Whether to add an objectness branch. + Defaults to True. + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. + Default: dict(type='ReLU') + """ + + def __init__(self, + num_classes: int, + in_channels: int, + share_conv=True, + with_objectness=False, + norm_cfg=dict(type='BN', requires_grad=True), + pred_kernel_size=1, + **kwargs) -> None: + self.share_conv = share_conv + super().__init__( + num_classes, + in_channels, + norm_cfg=norm_cfg, + pred_kernel_size=pred_kernel_size, + with_objectness=with_objectness, + **kwargs) + + def _init_layers(self): + """Initialize layers of the head.""" + self.cls_convs = nn.ModuleList() + self.reg_convs = nn.ModuleList() + self.kernel_convs = nn.ModuleList() + + self.rtm_cls = nn.ModuleList() + self.rtm_reg = nn.ModuleList() + self.rtm_kernel = nn.ModuleList() + self.rtm_obj = nn.ModuleList() + + # calculate num dynamic parameters + weight_nums, bias_nums = [], [] + for i in range(self.num_dyconvs): + if i == 0: + weight_nums.append( + (self.num_prototypes + 2) * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + elif i == self.num_dyconvs - 1: + weight_nums.append(self.dyconv_channels) + bias_nums.append(1) + else: + weight_nums.append(self.dyconv_channels * self.dyconv_channels) + bias_nums.append(self.dyconv_channels) + self.weight_nums = weight_nums + self.bias_nums = bias_nums + self.num_gen_params = sum(weight_nums) + sum(bias_nums) + pred_pad_size = self.pred_kernel_size // 2 + + for n in range(len(self.prior_generator.strides)): + cls_convs = nn.ModuleList() + reg_convs = nn.ModuleList() + kernel_convs = nn.ModuleList() + for i in range(self.stacked_convs): + chn = self.in_channels if i == 0 else self.feat_channels + cls_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + reg_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + kernel_convs.append( + ConvModule( + chn, + self.feat_channels, + 3, + stride=1, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg)) + self.cls_convs.append(cls_convs) + self.reg_convs.append(cls_convs) + self.kernel_convs.append(kernel_convs) + + self.rtm_cls.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * self.cls_out_channels, + self.pred_kernel_size, + padding=pred_pad_size)) + self.rtm_reg.append( + nn.Conv2d( + self.feat_channels, + self.num_base_priors * 4, + self.pred_kernel_size, + padding=pred_pad_size)) + self.rtm_kernel.append( + nn.Conv2d( + self.feat_channels, + self.num_gen_params, + self.pred_kernel_size, + padding=pred_pad_size)) + if self.with_objectness: + self.rtm_obj.append( + nn.Conv2d( + self.feat_channels, + 1, + self.pred_kernel_size, + padding=pred_pad_size)) + + if self.share_conv: + for n in range(len(self.prior_generator.strides)): + for i in range(self.stacked_convs): + self.cls_convs[n][i].conv = self.cls_convs[0][i].conv + self.reg_convs[n][i].conv = self.reg_convs[0][i].conv + + self.mask_head = MaskFeatModule( + in_channels=self.in_channels, + feat_channels=self.feat_channels, + stacked_convs=4, + num_levels=len(self.prior_generator.strides), + num_prototypes=self.num_prototypes, + act_cfg=self.act_cfg, + norm_cfg=self.norm_cfg) + + def init_weights(self) -> None: + """Initialize weights of the head.""" + for m in self.modules(): + if isinstance(m, nn.Conv2d): + normal_init(m, mean=0, std=0.01) + if is_norm(m): + constant_init(m, 1) + bias_cls = bias_init_with_prob(0.01) + for rtm_cls, rtm_reg, rtm_kernel in zip(self.rtm_cls, self.rtm_reg, + self.rtm_kernel): + normal_init(rtm_cls, std=0.01, bias=bias_cls) + normal_init(rtm_reg, std=0.01, bias=1) + if self.with_objectness: + for rtm_obj in self.rtm_obj: + normal_init(rtm_obj, std=0.01, bias=bias_cls) + + def forward(self, feats: Tuple[Tensor, ...]) -> tuple: + """Forward features from the upstream network. + + Args: + feats (tuple[Tensor]): Features from the upstream network, each is + a 4D-tensor. + + Returns: + tuple: Usually a tuple of classification scores and bbox prediction + - cls_scores (list[Tensor]): Classification scores for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * num_classes. + - bbox_preds (list[Tensor]): Box energies / deltas for all scale + levels, each is a 4D-tensor, the channels number is + num_base_priors * 4. + """ + mask_feat = self.mask_head(feats) + + cls_scores = [] + bbox_preds = [] + kernel_preds = [] + for idx, (x, stride) in enumerate( + zip(feats, self.prior_generator.strides)): + cls_feat = x + reg_feat = x + kernel_feat = x + + for cls_layer in self.cls_convs[idx]: + cls_feat = cls_layer(cls_feat) + cls_score = self.rtm_cls[idx](cls_feat) + + for kernel_layer in self.kernel_convs[idx]: + kernel_feat = kernel_layer(kernel_feat) + kernel_pred = self.rtm_kernel[idx](kernel_feat) + + for reg_layer in self.reg_convs[idx]: + reg_feat = reg_layer(reg_feat) + + if self.with_objectness: + objectness = self.rtm_obj[idx](reg_feat) + cls_score = inverse_sigmoid( + sigmoid_geometric_mean(cls_score, objectness)) + + reg_dist = F.relu(self.rtm_reg[idx](reg_feat)) * stride[0] + + cls_scores.append(cls_score) + bbox_preds.append(reg_dist) + kernel_preds.append(kernel_pred) + return tuple(cls_scores), tuple(bbox_preds), tuple( + kernel_preds), mask_feat diff --git a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py index 00276e05b80..0e6651b8c01 100644 --- a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py +++ b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py @@ -16,6 +16,17 @@ EPS = 1.0e-7 +def center_of_mass(masks: Tensor, eps=1e-6): + n, h, w = masks.shape + grid_h = torch.arange(h, device=masks.device)[:, None] + grid_w = torch.arange(w, device=masks.device) + normalizer = masks.sum(dim=(1, 2)).float().clamp(min=eps) + center_y = (masks * grid_h).sum(dim=(1, 2)) / normalizer + center_x = (masks * grid_w).sum(dim=(1, 2)) / normalizer + center = torch.cat([center_x[:, None], center_y[:, None]], dim=1) + return center + + @TASK_UTILS.register_module() class DynamicSoftLabelAssigner(BaseAssigner): """Computes matching between predictions and ground truth with dynamic soft @@ -118,7 +129,10 @@ def assign(self, dtype=torch.long) return AssignResult( num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) - if isinstance(gt_bboxes, BaseBoxes): + if hasattr(gt_instances, 'masks'): + gt_center = center_of_mass(gt_instances.masks, eps=EPS) + # print(gt_center, (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2.0) + elif isinstance(gt_bboxes, BaseBoxes): gt_center = gt_bboxes.centers else: # Tensor boxes will be treated as horizontal boxes by defaults diff --git a/mmdet/models/task_modules/samplers/__init__.py b/mmdet/models/task_modules/samplers/__init__.py index 3782eb898cf..87dc91540b8 100644 --- a/mmdet/models/task_modules/samplers/__init__.py +++ b/mmdet/models/task_modules/samplers/__init__.py @@ -3,6 +3,8 @@ from .combined_sampler import CombinedSampler from .instance_balanced_pos_sampler import InstanceBalancedPosSampler from .iou_balanced_neg_sampler import IoUBalancedNegSampler +from .mask_box_pseudo_sampler import (MaskBoxPseudoSampler, + MaskBoxSamplingResult) from .mask_pseudo_sampler import MaskPseudoSampler from .mask_sampling_result import MaskSamplingResult from .multi_instance_random_sampler import MultiInsRandomSampler @@ -18,5 +20,5 @@ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler', 'MaskSamplingResult', 'MultiInstanceSamplingResult', - 'MultiInsRandomSampler' + 'MultiInsRandomSampler', 'MaskBoxSamplingResult', 'MaskBoxPseudoSampler' ] diff --git a/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py b/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py new file mode 100644 index 00000000000..27fa267cb29 --- /dev/null +++ b/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py @@ -0,0 +1,140 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.registry import TASK_UTILS +from mmdet.structures.bbox import BaseBoxes +from ..assigners.assign_result import AssignResult +from .base_sampler import BaseSampler +from .sampling_result import SamplingResult + + +# TODO: replace these sampler after refactor +@TASK_UTILS.register_module() +class MaskBoxPseudoSampler(BaseSampler): + """A pseudo sampler that does not do sampling actually.""" + + def __init__(self, **kwargs): + pass + + def _sample_pos(self, **kwargs): + """Sample positive samples.""" + raise NotImplementedError + + def _sample_neg(self, **kwargs): + """Sample negative samples.""" + raise NotImplementedError + + def sample(self, assign_result: AssignResult, pred_instances: InstanceData, + gt_instances: InstanceData, *args, **kwargs): + """Directly returns the positive and negative indices of samples. + + Args: + assign_result (:obj:`AssignResult`): Mask assigning results. + pred_instances (:obj:`InstanceData`): Instances of model + predictions. It includes ``scores`` and ``masks`` predicted + by the model. + gt_instances (:obj:`InstanceData`): Ground truth of instance + annotations. It usually includes ``labels`` and ``masks`` + attributes. + + Returns: + :obj:`SamplingResult`: sampler results + """ + gt_bboxes = gt_instances.bboxes + priors = pred_instances.priors + + pred_masks = pred_instances.bboxes + gt_masks = gt_instances.masks + pos_inds = torch.nonzero( + assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() + neg_inds = torch.nonzero( + assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() + gt_flags = pred_masks.new_zeros(pred_masks.shape[0], dtype=torch.uint8) + sampling_result = MaskBoxSamplingResult( + pos_inds=pos_inds, + neg_inds=neg_inds, + priors=priors, + masks=pred_masks, + gt_bboxes=gt_bboxes, + gt_masks=gt_masks, + assign_result=assign_result, + gt_flags=gt_flags, + avg_factor_with_neg=False) + return sampling_result + + +class MaskBoxSamplingResult(SamplingResult): + """Mask sampling result.""" + + def __init__(self, + pos_inds: Tensor, + neg_inds: Tensor, + priors: Tensor, + masks: Tensor, + gt_bboxes: Tensor, + gt_masks: Tensor, + assign_result: AssignResult, + gt_flags: Tensor, + avg_factor_with_neg: bool = True) -> None: + self.pos_inds = pos_inds + self.neg_inds = neg_inds + self.num_pos = max(pos_inds.numel(), 1) + self.num_neg = max(neg_inds.numel(), 1) + self.avg_factor = self.num_pos + self.num_neg \ + if avg_factor_with_neg else self.num_pos + + self.pos_masks = masks[pos_inds] + self.neg_masks = masks[neg_inds] + self.pos_is_gt = gt_flags[pos_inds] + + self.num_gts = gt_masks.shape[0] + self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 + + self.pos_priors = priors[pos_inds] + self.neg_priors = priors[neg_inds] + + self.pos_gt_labels = assign_result.labels[pos_inds] + box_dim = gt_bboxes.box_dim if isinstance(gt_bboxes, BaseBoxes) else 4 + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, box_dim) + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()] + + if gt_masks.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_masks = torch.empty_like(gt_masks) + else: + self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] + + @property + def masks(self) -> Tensor: + """torch.Tensor: concatenated positive and negative masks.""" + return torch.cat([self.pos_masks, self.neg_masks]) + + def __nice__(self) -> str: + data = self.info.copy() + data['pos_masks'] = data.pop('pos_masks').shape + data['neg_masks'] = data.pop('neg_masks').shape + parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self) -> dict: + """Returns a dictionary of info about the object.""" + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_masks': self.pos_masks, + 'neg_masks': self.neg_masks, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } From bc2203fee6078edf9a26bd88c793578647a51475 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 29 Nov 2022 20:03:11 +0800 Subject: [PATCH 02/17] remove sampler --- mmdet/models/dense_heads/rtmdet_ins_head.py | 23 +-- .../models/task_modules/samplers/__init__.py | 4 +- .../samplers/mask_box_pseudo_sampler.py | 140 ------------------ 3 files changed, 14 insertions(+), 153 deletions(-) delete mode 100644 mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py index b6f66024b92..df4f8272342 100644 --- a/mmdet/models/dense_heads/rtmdet_ins_head.py +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -20,7 +20,6 @@ from mmdet.structures.bbox import (cat_boxes, distance2bbox, get_box_tensor, get_box_wh, scale_boxes) from mmdet.utils import ConfigType, InstanceList, OptInstanceList, reduce_mean -from ..task_modules.samplers import MaskBoxPseudoSampler from .rtmdet_head import RTMDetHead @@ -58,7 +57,6 @@ def __init__(self, self.use_condinst_coord = use_condinst_coord super().__init__(*args, **kwargs) self.loss_mask = MODELS.build(loss_mask) - self.sampler = MaskBoxPseudoSampler() def _init_layers(self): """Initialize layers of the head.""" @@ -561,20 +559,25 @@ def _mask_predict_by_feat_single(self, mask_feat, kernels, priors): return x def loss_mask_by_feat(self, mask_feats, flatten_kernels, - assign_metrics_list, sampling_results_list): - # import pdb; pdb.set_trace() + sampling_results_list, batch_gt_instances): batch_pos_mask_logits = [] pos_gt_masks = [] - for idx, (mask_feat, kernels, sampling_results) in enumerate( - zip(mask_feats, flatten_kernels, sampling_results_list)): + for idx, (mask_feat, kernels, sampling_results, + gt_instances) in enumerate( + zip(mask_feats, flatten_kernels, sampling_results_list, + batch_gt_instances)): pos_priors = sampling_results.pos_priors pos_inds = sampling_results.pos_inds pos_kernels = kernels[pos_inds] # n_pos, num_gen_params pos_mask_logits = self._mask_predict_by_feat_single( mask_feat, pos_kernels, pos_priors) - + if gt_instances.masks.numel() == 0: + gt_masks = torch.empty_like(gt_instances.masks) + else: + gt_masks = gt_instances.masks[ + sampling_results.pos_assigned_gt_inds, :] batch_pos_mask_logits.append(pos_mask_logits) - pos_gt_masks.append(sampling_results.pos_gt_masks) + pos_gt_masks.append(gt_masks) pos_gt_masks = torch.cat(pos_gt_masks, 0) batch_pos_mask_logits = torch.cat(batch_pos_mask_logits, 0) @@ -696,8 +699,8 @@ def loss_by_feat(self, losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) loss_mask = self.loss_mask_by_feat(mask_feat, flatten_kernels, - assign_metrics_list, - sampling_results_list) + sampling_results_list, + batch_gt_instances) loss = dict( loss_cls=losses_cls, loss_bbox=losses_bbox, loss_mask=loss_mask) return loss diff --git a/mmdet/models/task_modules/samplers/__init__.py b/mmdet/models/task_modules/samplers/__init__.py index 87dc91540b8..3782eb898cf 100644 --- a/mmdet/models/task_modules/samplers/__init__.py +++ b/mmdet/models/task_modules/samplers/__init__.py @@ -3,8 +3,6 @@ from .combined_sampler import CombinedSampler from .instance_balanced_pos_sampler import InstanceBalancedPosSampler from .iou_balanced_neg_sampler import IoUBalancedNegSampler -from .mask_box_pseudo_sampler import (MaskBoxPseudoSampler, - MaskBoxSamplingResult) from .mask_pseudo_sampler import MaskPseudoSampler from .mask_sampling_result import MaskSamplingResult from .multi_instance_random_sampler import MultiInsRandomSampler @@ -20,5 +18,5 @@ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'MaskPseudoSampler', 'MaskSamplingResult', 'MultiInstanceSamplingResult', - 'MultiInsRandomSampler', 'MaskBoxSamplingResult', 'MaskBoxPseudoSampler' + 'MultiInsRandomSampler' ] diff --git a/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py b/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py deleted file mode 100644 index 27fa267cb29..00000000000 --- a/mmdet/models/task_modules/samplers/mask_box_pseudo_sampler.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from mmengine.structures import InstanceData -from torch import Tensor - -from mmdet.registry import TASK_UTILS -from mmdet.structures.bbox import BaseBoxes -from ..assigners.assign_result import AssignResult -from .base_sampler import BaseSampler -from .sampling_result import SamplingResult - - -# TODO: replace these sampler after refactor -@TASK_UTILS.register_module() -class MaskBoxPseudoSampler(BaseSampler): - """A pseudo sampler that does not do sampling actually.""" - - def __init__(self, **kwargs): - pass - - def _sample_pos(self, **kwargs): - """Sample positive samples.""" - raise NotImplementedError - - def _sample_neg(self, **kwargs): - """Sample negative samples.""" - raise NotImplementedError - - def sample(self, assign_result: AssignResult, pred_instances: InstanceData, - gt_instances: InstanceData, *args, **kwargs): - """Directly returns the positive and negative indices of samples. - - Args: - assign_result (:obj:`AssignResult`): Mask assigning results. - pred_instances (:obj:`InstanceData`): Instances of model - predictions. It includes ``scores`` and ``masks`` predicted - by the model. - gt_instances (:obj:`InstanceData`): Ground truth of instance - annotations. It usually includes ``labels`` and ``masks`` - attributes. - - Returns: - :obj:`SamplingResult`: sampler results - """ - gt_bboxes = gt_instances.bboxes - priors = pred_instances.priors - - pred_masks = pred_instances.bboxes - gt_masks = gt_instances.masks - pos_inds = torch.nonzero( - assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique() - neg_inds = torch.nonzero( - assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique() - gt_flags = pred_masks.new_zeros(pred_masks.shape[0], dtype=torch.uint8) - sampling_result = MaskBoxSamplingResult( - pos_inds=pos_inds, - neg_inds=neg_inds, - priors=priors, - masks=pred_masks, - gt_bboxes=gt_bboxes, - gt_masks=gt_masks, - assign_result=assign_result, - gt_flags=gt_flags, - avg_factor_with_neg=False) - return sampling_result - - -class MaskBoxSamplingResult(SamplingResult): - """Mask sampling result.""" - - def __init__(self, - pos_inds: Tensor, - neg_inds: Tensor, - priors: Tensor, - masks: Tensor, - gt_bboxes: Tensor, - gt_masks: Tensor, - assign_result: AssignResult, - gt_flags: Tensor, - avg_factor_with_neg: bool = True) -> None: - self.pos_inds = pos_inds - self.neg_inds = neg_inds - self.num_pos = max(pos_inds.numel(), 1) - self.num_neg = max(neg_inds.numel(), 1) - self.avg_factor = self.num_pos + self.num_neg \ - if avg_factor_with_neg else self.num_pos - - self.pos_masks = masks[pos_inds] - self.neg_masks = masks[neg_inds] - self.pos_is_gt = gt_flags[pos_inds] - - self.num_gts = gt_masks.shape[0] - self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 - - self.pos_priors = priors[pos_inds] - self.neg_priors = priors[neg_inds] - - self.pos_gt_labels = assign_result.labels[pos_inds] - box_dim = gt_bboxes.box_dim if isinstance(gt_bboxes, BaseBoxes) else 4 - if gt_bboxes.numel() == 0: - # hack for index error case - assert self.pos_assigned_gt_inds.numel() == 0 - self.pos_gt_bboxes = gt_bboxes.view(-1, box_dim) - else: - if len(gt_bboxes.shape) < 2: - gt_bboxes = gt_bboxes.view(-1, box_dim) - self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds.long()] - - if gt_masks.numel() == 0: - # hack for index error case - assert self.pos_assigned_gt_inds.numel() == 0 - self.pos_gt_masks = torch.empty_like(gt_masks) - else: - self.pos_gt_masks = gt_masks[self.pos_assigned_gt_inds, :] - - @property - def masks(self) -> Tensor: - """torch.Tensor: concatenated positive and negative masks.""" - return torch.cat([self.pos_masks, self.neg_masks]) - - def __nice__(self) -> str: - data = self.info.copy() - data['pos_masks'] = data.pop('pos_masks').shape - data['neg_masks'] = data.pop('neg_masks').shape - parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())] - body = ' ' + ',\n '.join(parts) - return '{\n' + body + '\n}' - - @property - def info(self) -> dict: - """Returns a dictionary of info about the object.""" - return { - 'pos_inds': self.pos_inds, - 'neg_inds': self.neg_inds, - 'pos_masks': self.pos_masks, - 'neg_masks': self.neg_masks, - 'pos_is_gt': self.pos_is_gt, - 'num_gts': self.num_gts, - 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, - } From 091d681b1b47133ecde1f12e7d8def7136e53f52 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 30 Nov 2022 14:56:10 +0800 Subject: [PATCH 03/17] fix batch --- configs/rtmdet/rtmdet_l_8xb32-300e_coco.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py index 85c66130178..33ccb839c6c 100644 --- a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py @@ -115,8 +115,8 @@ ] train_dataloader = dict( - batch_size=2, - num_workers=1, + batch_size=32, + num_workers=10, batch_sampler=None, pin_memory=True, dataset=dict(pipeline=train_pipeline)) From 4b7c5c455eb41530715e0a303feb7fb2d8620077 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 30 Nov 2022 15:13:43 +0800 Subject: [PATCH 04/17] fix act --- configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py | 4 ++-- configs/rtmdet/rtmdet_l_8xb32-300e_coco.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py index 404061f042f..783bc84184e 100644 --- a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py @@ -9,7 +9,7 @@ share_conv=True, pred_kernel_size=1, feat_channels=256, - act_cfg=dict(type='SiLU'), + act_cfg=dict(type='SiLU', inplace=True), norm_cfg=dict(type='SyncBN', requires_grad=True), anchor_generator=dict( type='MlvlPointGenerator', offset=0, strides=[8, 16, 32]), @@ -21,7 +21,7 @@ loss_weight=1.0), loss_bbox=dict(type='GIoULoss', loss_weight=2.0), loss_mask=dict( - type='DiceLoss', loss_weight=1.0, eps=5e-6, reduction='mean')), + type='DiceLoss', loss_weight=2.0, eps=5e-6, reduction='mean')), test_cfg=dict(mask_thr_binary=0.5), ) diff --git a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py index 33ccb839c6c..6b2ce84e006 100644 --- a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py @@ -18,7 +18,7 @@ widen_factor=1, channel_attention=True, norm_cfg=dict(type='SyncBN'), - act_cfg=dict(type='SiLU')), + act_cfg=dict(type='SiLU', inplace=True)), neck=dict( type='CSPNeXtPAFPN', in_channels=[256, 512, 1024], @@ -26,7 +26,7 @@ num_csp_blocks=3, expand_ratio=0.5, norm_cfg=dict(type='SyncBN'), - act_cfg=dict(type='SiLU')), + act_cfg=dict(type='SiLU', inplace=True)), bbox_head=dict( type='RTMDetSepBNHead', num_classes=80, @@ -47,7 +47,7 @@ share_conv=True, pred_kernel_size=1, norm_cfg=dict(type='SyncBN'), - act_cfg=dict(type='SiLU')), + act_cfg=dict(type='SiLU', inplace=True)), train_cfg=dict( assigner=dict(type='DynamicSoftLabelAssigner', topk=13), allowed_border=-1, From 2f0de6eac6c102b4c1d1a163946afbee43029723 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Wed, 7 Dec 2022 19:58:34 +0800 Subject: [PATCH 05/17] fix act --- configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py | 7 +++++++ mmdet/models/dense_heads/rtmdet_ins_head.py | 2 -- 2 files changed, 7 insertions(+), 2 deletions(-) create mode 100644 configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py diff --git a/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py new file mode 100644 index 00000000000..fc92c817e6b --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py @@ -0,0 +1,7 @@ +_base_ = './rtmdet-ins_l_8xb32-300e_coco.py' + +model = dict( + backbone=dict(deepen_factor=1.33, widen_factor=1.25), + neck=dict( + in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), + bbox_head=dict(in_channels=320, feat_channels=320)) diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py index df4f8272342..d0d51a9fb65 100644 --- a/mmdet/models/dense_heads/rtmdet_ins_head.py +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -43,7 +43,6 @@ def __init__(self, dyconv_channels: int = 8, num_dyconvs: int = 3, mask_loss_stride: int = 4, - use_condinst_coord=True, loss_mask=dict( type='DiceLoss', loss_weight=1.0, @@ -54,7 +53,6 @@ def __init__(self, self.num_dyconvs = num_dyconvs self.dyconv_channels = dyconv_channels self.mask_loss_stride = mask_loss_stride - self.use_condinst_coord = use_condinst_coord super().__init__(*args, **kwargs) self.loss_mask = MODELS.build(loss_mask) From 193d205be361b55b3aa0a1ecc8584a01e98c1e27 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Thu, 15 Dec 2022 18:54:51 +0800 Subject: [PATCH 06/17] [Enhance] Improve RTMDet AP with YOLO test config. --- configs/rtmdet/README.md | 10 +++++----- configs/rtmdet/rtmdet_l_8xb32-300e_coco.py | 11 +++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index f677baa5b0a..289393ff543 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -14,11 +14,11 @@ Our tech-report will be released soon. | Backbone | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | | :---------: | :--: | :----: | :-------: | :------: | :------------------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| RTMDet-tiny | 640 | 40.9 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) | -| RTMDet-s | 640 | 44.5 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) | -| RTMDet-m | 640 | 49.1 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) | -| RTMDet-l | 640 | 51.3 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) | -| RTMDet-x | 640 | 52.6 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) | +| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) | +| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) | +| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) | +| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) | +| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) | **Note**: diff --git a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py index 33ccb839c6c..bbc2d804235 100644 --- a/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet_l_8xb32-300e_coco.py @@ -54,11 +54,11 @@ pos_weight=-1, debug=False), test_cfg=dict( - nms_pre=1000, + nms_pre=30000, min_bbox_size=0, - score_thr=0.05, - nms=dict(type='nms', iou_threshold=0.6), - max_per_img=100), + score_thr=0.001, + nms=dict(type='nms', iou_threshold=0.65), + max_per_img=300), ) train_pipeline = [ @@ -134,6 +134,9 @@ val_interval=interval, dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)]) +val_evaluator = dict(proposal_nums=(100, 1, 10)) +test_evaluator = val_evaluator + # optimizer optim_wrapper = dict( _delete_=True, From f63eb189e6a32e9032760a3904c0b69138c77236 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 16 Dec 2022 15:30:31 +0800 Subject: [PATCH 07/17] update --- configs/rtmdet/README.md | 36 ++++++++++++++++--- .../rtmdet/rtmdet-ins_l_8xb32-300e_coco.py | 8 ++++- .../rtmdet/rtmdet-ins_x_8xb16-300e_coco.py | 31 ++++++++++++++++ .../rtmdet/rtmdet-ins_x_8xb32-300e_coco.py | 7 ---- 4 files changed, 69 insertions(+), 13 deletions(-) create mode 100644 configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py delete mode 100644 configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index 289393ff543..924fa2e10bf 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -1,18 +1,20 @@ -# RTMDet +# RTMDet: An Empirical Study of Designing Real-Time Object Detectors ## Abstract -Our tech-report will be released soon. +In this paper, we aim to design an efficient real-time object detector that exceeds the YOLO series and is easily extensible for many object recognition tasks such as instance segmentation and rotated object detection. To obtain a more efficient model architecture, we explore an architecture that has compatible capacities in the backbone and neck, constructed by a basic building block that consists of large-kernel depth-wise convolutions. We further introduce soft labels when calculating matching costs in the dynamic label assignment to improve accuracy. Together with better training techniques, the resulting object detector, named RTMDet, achieves 52.8% AP on COCO with 300+ FPS on an NVIDIA 3090 GPU, outperforming the current mainstream industrial detectors. RTMDet achieves the best parameter-accuracy trade-off with tiny/small/medium/large/extra-large model sizes for various application scenarios, and obtains new state-of-the-art performance on real-time instance segmentation and rotated object detection. We hope the experimental results can provide new insights into designing versatile real-time object detectors for many object recognition tasks.
- +
## Results and Models -| Backbone | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | +## Object Detection + +| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | | :---------: | :--: | :----: | :-------: | :------: | :------------------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) | | RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) | @@ -22,4 +24,28 @@ Our tech-report will be released soon. **Note**: -1. The inference speed is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS. +1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS. + +## Instance Segmentation + +| Model | size | box AP | mask AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | +| :-------------: | :--: | :----: | :-----: | :-------: | :------: | :------------------: | :--------------------------------------------: | :----------------------: | +| RTMDet-Ins-tiny | 640 | 40.5 | 35.4 | 5.6 | 11.8 | 1.70 | [config](./rtmdet-ins_tiny_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | +| RTMDet-Ins-s | 640 | 44.0 | 38.7 | 10.18 | 21.5 | 1.93 | [config](./rtmdet-ins_s_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | +| RTMDet-Ins-m | 640 | 48.8 | 42.1 | 27.58 | 54.13 | 2.69 | [config](./rtmdet-ins_m_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | +| RTMDet-Ins-l | 640 | 51.2 | 43.7 | 57.37 | 106.56 | 3.68 | [config](./rtmdet-ins_l_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | +| RTMDet-Ins-x | 640 | 52.4 | 44.6 | 102.7 | 182.7 | 5.31 | [config](./rtmdet-ins_x_8xb16-300e_coco.py) | [model](<>) \| [log](<>) | + +**Note**: + +1. The inference speed of RTMDet-Ins is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1. Top 100 masks are kept and the post process latency is included. + +## Rotated Object Detection + +Models and configs of RTMDet-R are available in [MMRotate](https://github.com/open-mmlab/mmrotate/tree/1.x/configs/rotated_rtmdet) + +## Visualization + +
+ +
diff --git a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py index 783bc84184e..1ecacab8044 100644 --- a/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py +++ b/configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py @@ -22,7 +22,13 @@ loss_bbox=dict(type='GIoULoss', loss_weight=2.0), loss_mask=dict( type='DiceLoss', loss_weight=2.0, eps=5e-6, reduction='mean')), - test_cfg=dict(mask_thr_binary=0.5), + test_cfg=dict( + nms_pre=1000, + min_bbox_size=0, + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100, + mask_thr_binary=0.5), ) train_pipeline = [ diff --git a/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py b/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py new file mode 100644 index 00000000000..daaa640edac --- /dev/null +++ b/configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py @@ -0,0 +1,31 @@ +_base_ = './rtmdet-ins_l_8xb32-300e_coco.py' + +model = dict( + backbone=dict(deepen_factor=1.33, widen_factor=1.25), + neck=dict( + in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), + bbox_head=dict(in_channels=320, feat_channels=320)) + +base_lr = 0.002 + +# optimizer +optim_wrapper = dict(optimizer=dict(lr=base_lr)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 150 to 300 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=_base_.max_epochs // 2, + end=_base_.max_epochs, + T_max=_base_.max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] diff --git a/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py b/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py deleted file mode 100644 index fc92c817e6b..00000000000 --- a/configs/rtmdet/rtmdet-ins_x_8xb32-300e_coco.py +++ /dev/null @@ -1,7 +0,0 @@ -_base_ = './rtmdet-ins_l_8xb32-300e_coco.py' - -model = dict( - backbone=dict(deepen_factor=1.33, widen_factor=1.25), - neck=dict( - in_channels=[320, 640, 1280], out_channels=320, num_csp_blocks=4), - bbox_head=dict(in_channels=320, feat_channels=320)) From fd887eea3e7d95faba2a441cf17e746feab15597 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 16 Dec 2022 16:11:40 +0800 Subject: [PATCH 08/17] update docstring --- mmdet/models/dense_heads/rtmdet_ins_head.py | 123 +++++++++++++++----- 1 file changed, 93 insertions(+), 30 deletions(-) diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py index d0d51a9fb65..c5213a50899 100644 --- a/mmdet/models/dense_heads/rtmdet_ins_head.py +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -28,13 +28,13 @@ class RTMDetInsHead(RTMDetHead): """Detection Head of RTMDet-Ins. Args: - num_classes (int): Number of categories excluding the background - category. - in_channels (int): Number of channels in the input feature map. - with_objectness (bool): Whether to add an objectness branch. - Defaults to True. - act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. - Default: dict(type='ReLU') + num_prototypes (int): Number of mask prototype features extracted + from the mask head. + dyconv_channels (int): Channel of the dynamic conv layers. + num_dyconvs (int): Number of the dynamic convolution layers. + mask_loss_stride (int): Down sample stride of the masks for loss + computation. + loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss. """ def __init__(self, @@ -56,7 +56,7 @@ def __init__(self, super().__init__(*args, **kwargs) self.loss_mask = MODELS.build(loss_mask) - def _init_layers(self): + def _init_layers(self) -> None: """Initialize layers of the head.""" super()._init_layers() self.kernel_convs = nn.ModuleList() @@ -119,6 +119,11 @@ def forward(self, feats: Tuple[Tensor, ...]) -> tuple: - bbox_preds (list[Tensor]): Box energies / deltas for all scale levels, each is a 4D-tensor, the channels number is num_base_priors * 4. + - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale + levels, each is a 4D-tensor, the channels number is + num_gen_params. + - mask_feat (Tensor): Output feature of the mask head. Each is a + 4D-tensor, the channels number is num_prototypes. """ mask_feat = self.mask_head(feats) @@ -202,6 +207,7 @@ def predict_by_feat(self, (num_instances, ). - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). """ assert len(cls_scores) == len(bbox_preds) @@ -301,6 +307,7 @@ def _predict_by_feat_single(self, (num_instances, ). - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). """ if score_factor_list[0] is None: # e.g. Retina, FreeAnchor, etc. @@ -435,6 +442,7 @@ def _bbox_post_process(self, (num_instances, ). - bboxes (Tensor): Has a shape (num_instances, 4), the last dimension 4 arrange as (x1, y1, x2, y2). + - masks (Tensor): Has a shape (num_instances, h, w). """ stride = self.prior_generator.strides[0][0] if rescale: @@ -467,17 +475,13 @@ def _bbox_post_process(self, results = results[:cfg.max_per_img] # process masks - h, w = img_meta['img_shape'][:2] - mask_logits = self._mask_predict_by_feat_single( mask_feat, results.kernels, results.priors) mask_logits = F.interpolate( mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear') - # print('upsampled croped shape ', mask_logits.shape) if rescale: ori_h, ori_w = img_meta['ori_shape'][:2] - # print('scale_factor: ',scale_factor) mask_logits = F.interpolate( mask_logits, size=[ @@ -486,7 +490,6 @@ def _bbox_post_process(self, ], mode='bilinear', align_corners=False)[..., :ori_h, :ori_w] - # print('rescale shape ', mask_logits.shape, ori_h, ori_w) masks = mask_logits.sigmoid().squeeze(0) masks = masks > cfg.mask_thr_binary results.masks = masks @@ -521,7 +524,21 @@ def parse_dynamic_params(self, flatten_kernels): return weight_splits, bias_splits - def _mask_predict_by_feat_single(self, mask_feat, kernels, priors): + def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor, + priors: Tensor) -> Tensor: + """Generate mask logits from mask features with dynamic convs. + + Args: + mask_feat (Tensor): Mask prototype features. + Has shape (num_prototypes, H, W). + kernels (Tensor): Kernel parameters for each instance. + Has shape (num_instance, num_params) + priors (Tensor): Center priors for each instance. + Has shape (num_instance, 4). + Returns: + Tensor: Instance segmentation masks for each instance. + Has shape (num_instance, H, W). + """ num_inst = priors.shape[0] h, w = mask_feat.size()[-2:] if num_inst < 1: @@ -556,8 +573,25 @@ def _mask_predict_by_feat_single(self, mask_feat, kernels, priors): x = x.reshape(num_inst, h, w) return x - def loss_mask_by_feat(self, mask_feats, flatten_kernels, - sampling_results_list, batch_gt_instances): + def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, + sampling_results_list: list, + batch_gt_instances: InstanceList): + """Compute instance segmentation loss. + + Args: + mask_feats (list[Tensor]): Mask prototype features extracted from + the mask head. Has shape (N, num_prototypes, H, W) + flatten_kernels (list[Tensor]): Kernels of the dynamic conv layers. + Has shape (N, num_instances, num_params) + sampling_results_list (list[:obj:`SamplingResults`]) Batch of + assignment results. + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ batch_pos_mask_logits = [] pos_gt_masks = [] for idx, (mask_feat, kernels, sampling_results, @@ -705,15 +739,34 @@ def loss_by_feat(self, class MaskFeatModule(BaseModule): + """Mask feature head used in RTMDet-Ins. - def __init__(self, - in_channels, - feat_channels=256, - stacked_convs=4, - num_levels=3, - num_prototypes=8, - act_cfg=dict(type='SiLU'), - norm_cfg=dict(type='BN')): + Args: + in_channels (int): Number of channels in the input feature map. + feat_channels (int): Number of hidden channels of the mask feature + map branch. + num_levels (int): The starting feature map level from RPN that + will be used to predict the mask feature map. + num_prototypes (int): Number of output channel of the mask feature + map branch. This is the channel count of the mask + feature map that to be dynamically convolved with the predicted + kernel. + stacked_convs (int): Number of convs in mask feature branch. + act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. + Default: dict(type='ReLU') + norm_cfg (dict): Config dict for normalization layer. Default: None. + """ + + def __init__( + self, + in_channels: int, + feat_channels: int = 256, + stacked_convs: int = 4, + num_levels: int = 3, + num_prototypes: int = 8, + act_cfg: ConfigType = dict(type='ReLU'), + norm_cfg: ConfigType = dict(type='BN') + ) -> None: super().__init__(init_cfg=None) self.num_levels = num_levels self.fusion_conv = nn.Conv2d(num_levels * in_channels, in_channels, 1) @@ -732,7 +785,7 @@ def __init__(self, self.projection = nn.Conv2d( feat_channels, num_prototypes, kernel_size=1) - def forward(self, features): + def forward(self, features: Tuple[Tensor, ...]) -> Tensor: # multi-level feature fusion fusion_feats = [features[0]] size = features[0].shape[-2:] @@ -749,16 +802,19 @@ def forward(self, features): @MODELS.register_module() class RTMDetInsSepBNHead(RTMDetInsHead): - """Detection Head of RTMDet-ins-seg with sep-bn layers. + """Detection Head of RTMDet-Ins with sep-bn layers. Args: num_classes (int): Number of categories excluding the background category. in_channels (int): Number of channels in the input feature map. - with_objectness (bool): Whether to add an objectness branch. + share_conv (bool): Whether to share conv layers between stages. Defaults to True. - act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. - Default: dict(type='ReLU') + norm_cfg (:obj:`ConfigDict` or dict)): Config dict for normalization + layer. Defaults to dict(type='BN'). + act_cfg (:obj:`ConfigDict` or dict)): Config dict for activation layer. + Defaults to dict(type='SiLU', inplace=True). + pred_kernel_size (int): Kernel size of prediction layer. Defaults to 1. """ def __init__(self, @@ -767,6 +823,7 @@ def __init__(self, share_conv=True, with_objectness=False, norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='SiLU', inplace=True), pred_kernel_size=1, **kwargs) -> None: self.share_conv = share_conv @@ -774,11 +831,12 @@ def __init__(self, num_classes, in_channels, norm_cfg=norm_cfg, + act_cfg=act_cfg, pred_kernel_size=pred_kernel_size, with_objectness=with_objectness, **kwargs) - def _init_layers(self): + def _init_layers(self) -> None: """Initialize layers of the head.""" self.cls_convs = nn.ModuleList() self.reg_convs = nn.ModuleList() @@ -919,6 +977,11 @@ def forward(self, feats: Tuple[Tensor, ...]) -> tuple: - bbox_preds (list[Tensor]): Box energies / deltas for all scale levels, each is a 4D-tensor, the channels number is num_base_priors * 4. + - kernel_preds (list[Tensor]): Dynamic conv kernels for all scale + levels, each is a 4D-tensor, the channels number is + num_gen_params. + - mask_feat (Tensor): Output feature of the mask head. Each is a + 4D-tensor, the channels number is num_prototypes. """ mask_feat = self.mask_head(feats) From d2c4f8bb574d5c0674d707dada44c9cdd6ce210b Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 16 Dec 2022 16:15:17 +0800 Subject: [PATCH 09/17] clean code --- .../models/task_modules/assigners/dynamic_soft_label_assigner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py index 0e6651b8c01..d24b4006ad2 100644 --- a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py +++ b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py @@ -131,7 +131,6 @@ def assign(self, num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels) if hasattr(gt_instances, 'masks'): gt_center = center_of_mass(gt_instances.masks, eps=EPS) - # print(gt_center, (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2.0) elif isinstance(gt_bboxes, BaseBoxes): gt_center = gt_bboxes.centers else: From 5ce5e09b510e8ad8c04a08f2254cfbc57cfe375c Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Sat, 17 Dec 2022 16:47:31 +0800 Subject: [PATCH 10/17] update --- mmdet/datasets/transforms/transforms.py | 4 +- mmdet/models/dense_heads/rtmdet_ins_head.py | 73 +++++++++++-------- .../assigners/dynamic_soft_label_assigner.py | 11 ++- 3 files changed, 56 insertions(+), 32 deletions(-) diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py index 791f36caffb..c9e95bd7476 100644 --- a/mmdet/datasets/transforms/transforms.py +++ b/mmdet/datasets/transforms/transforms.py @@ -3555,8 +3555,8 @@ def transform(self, results: dict) -> dict: retrieve_gt_bboxes = retrieve_results['gt_bboxes'] retrieve_gt_bboxes.rescale_([scale_ratio, scale_ratio]) if with_mask: - retrieve_gt_masks: BitmapMasks = retrieve_results[ - 'gt_masks'].rescale(scale_ratio) + retrieve_gt_masks = retrieve_results['gt_masks'].rescale( + scale_ratio) if self.bbox_clip_border: retrieve_gt_bboxes.clip_([origin_h, origin_w]) diff --git a/mmdet/models/dense_heads/rtmdet_ins_head.py b/mmdet/models/dense_heads/rtmdet_ins_head.py index c5213a50899..e355bdb79f8 100644 --- a/mmdet/models/dense_heads/rtmdet_ins_head.py +++ b/mmdet/models/dense_heads/rtmdet_ins_head.py @@ -29,11 +29,13 @@ class RTMDetInsHead(RTMDetHead): Args: num_prototypes (int): Number of mask prototype features extracted - from the mask head. + from the mask head. Defaults to 8. dyconv_channels (int): Channel of the dynamic conv layers. + Defaults to 8. num_dyconvs (int): Number of the dynamic convolution layers. + Defaults to 3. mask_loss_stride (int): Down sample stride of the masks for loss - computation. + computation. Defaults to 4. loss_mask (:obj:`ConfigDict` or dict): Config dict for mask loss. """ @@ -45,7 +47,7 @@ def __init__(self, mask_loss_stride: int = 4, loss_mask=dict( type='DiceLoss', - loss_weight=1.0, + loss_weight=2.0, eps=5e-6, reduction='mean'), **kwargs) -> None: @@ -59,20 +61,22 @@ def __init__(self, def _init_layers(self) -> None: """Initialize layers of the head.""" super()._init_layers() + # a branch to predict kernels of dynamic convs self.kernel_convs = nn.ModuleList() # calculate num dynamic parameters weight_nums, bias_nums = [], [] for i in range(self.num_dyconvs): if i == 0: weight_nums.append( + # mask prototype and coordinate features (self.num_prototypes + 2) * self.dyconv_channels) - bias_nums.append(self.dyconv_channels) + bias_nums.append(self.dyconv_channels * 1) elif i == self.num_dyconvs - 1: - weight_nums.append(self.dyconv_channels) + weight_nums.append(self.dyconv_channels * 1) bias_nums.append(1) else: weight_nums.append(self.dyconv_channels * self.dyconv_channels) - bias_nums.append(self.dyconv_channels) + bias_nums.append(self.dyconv_channels * 1) self.weight_nums = weight_nums self.bias_nums = bias_nums self.num_gen_params = sum(weight_nums) + sum(bias_nums) @@ -184,6 +188,11 @@ def predict_by_feat(self, bbox_preds (list[Tensor]): Box energies / deltas for all scale levels, each is a 4D-tensor, has shape (batch_size, num_priors * 4, H, W). + kernel_preds (list[Tensor]): Kernel predictions of dynamic + convs for all scale levels, each is a 4D-tensor, has shape + (batch_size, num_params, H, W). + mask_feat (Tensor): Mask prototype features extracted from the + mask head, has shape (batch_size, num_prototypes, H, W). score_factors (list[Tensor], optional): Score factor for all scale level, each is a 4D-tensor, has shape (batch_size, num_priors * 1, H, W). Defaults to None. @@ -261,8 +270,8 @@ def predict_by_feat(self, def _predict_by_feat_single(self, cls_score_list: List[Tensor], bbox_pred_list: List[Tensor], - kernel_pred_list, - mask_feat, + kernel_pred_list: List[Tensor], + mask_feat: Tensor, score_factor_list: List[Tensor], mlvl_priors: List[Tensor], img_meta: dict, @@ -270,7 +279,7 @@ def _predict_by_feat_single(self, rescale: bool = False, with_nms: bool = True) -> InstanceData: """Transform a single image's features extracted from the head into - bbox results. + bbox and mask results. Args: cls_score_list (list[Tensor]): Box scores from all scale @@ -279,6 +288,11 @@ def _predict_by_feat_single(self, bbox_pred_list (list[Tensor]): Box energies / deltas from all scale levels of a single image, each item has shape (num_priors * 4, H, W). + kernel_preds (list[Tensor]): Kernel predictions of dynamic + convs for all scale levels of a single image, each is a + 4D-tensor, has shape (num_params, H, W). + mask_feat (Tensor): Mask prototype features of a single image + extracted from the mask head, has shape (num_prototypes, H, W). score_factor_list (list[Tensor]): Score factor from all scale levels of a single image, each item has shape (num_priors * 1, H, W). @@ -400,7 +414,7 @@ def _predict_by_feat_single(self, if with_score_factors: results.score_factors = torch.cat(mlvl_score_factors) - return self._bbox_post_process( + return self._bbox_mask_post_process( results=results, mask_feat=mask_feat, cfg=cfg, @@ -408,14 +422,15 @@ def _predict_by_feat_single(self, with_nms=with_nms, img_meta=img_meta) - def _bbox_post_process(self, - results: InstanceData, - mask_feat, - cfg: ConfigType, - rescale: bool = False, - with_nms: bool = True, - img_meta: Optional[dict] = None) -> InstanceData: - """bbox post-processing method. + def _bbox_mask_post_process( + self, + results: InstanceData, + mask_feat, + cfg: ConfigType, + rescale: bool = False, + with_nms: bool = True, + img_meta: Optional[dict] = None) -> InstanceData: + """bbox and mask post-processing method. The boxes would be rescaled to the original image scale and do the nms operation. Usually `with_nms` is False is used for aug test. @@ -503,7 +518,7 @@ def _bbox_post_process(self, return results - def parse_dynamic_params(self, flatten_kernels): + def parse_dynamic_params(self, flatten_kernels: Tensor) -> tuple: """split kernel head prediction to conv weight and bias.""" n_inst = flatten_kernels.size(0) n_layers = len(self.weight_nums) @@ -575,7 +590,7 @@ def _mask_predict_by_feat_single(self, mask_feat: Tensor, kernels: Tensor, def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, sampling_results_list: list, - batch_gt_instances: InstanceList): + batch_gt_instances: InstanceList) -> Tensor: """Compute instance segmentation loss. Args: @@ -590,7 +605,7 @@ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, attributes. Returns: - dict[str, Tensor]: A dictionary of loss components. + Tensor: The mask loss tensor. """ batch_pos_mask_logits = [] pos_gt_masks = [] @@ -620,7 +635,7 @@ def loss_mask_by_feat(self, mask_feats: Tensor, flatten_kernels: Tensor, ])).clamp_(min=1).item() if batch_pos_mask_logits.shape[0] == 0: - return mask_feats.sum() * 0, mask_feats.sum() * 0 + return mask_feats.sum() * 0 scale = self.prior_generator.strides[0][0] // self.mask_loss_stride # upsample pred masks @@ -753,7 +768,7 @@ class MaskFeatModule(BaseModule): kernel. stacked_convs (int): Number of convs in mask feature branch. act_cfg (:obj:`ConfigDict` or dict): Config dict for activation layer. - Default: dict(type='ReLU') + Default: dict(type='ReLU', inplace=True) norm_cfg (dict): Config dict for normalization layer. Default: None. """ @@ -764,7 +779,7 @@ def __init__( stacked_convs: int = 4, num_levels: int = 3, num_prototypes: int = 8, - act_cfg: ConfigType = dict(type='ReLU'), + act_cfg: ConfigType = dict(type='ReLU', inplace=True), norm_cfg: ConfigType = dict(type='BN') ) -> None: super().__init__(init_cfg=None) @@ -820,11 +835,11 @@ class RTMDetInsSepBNHead(RTMDetInsHead): def __init__(self, num_classes: int, in_channels: int, - share_conv=True, - with_objectness=False, - norm_cfg=dict(type='BN', requires_grad=True), - act_cfg=dict(type='SiLU', inplace=True), - pred_kernel_size=1, + share_conv: bool = True, + with_objectness: bool = False, + norm_cfg: ConfigType = dict(type='BN', requires_grad=True), + act_cfg: ConfigType = dict(type='SiLU', inplace=True), + pred_kernel_size: int = 1, **kwargs) -> None: self.share_conv = share_conv super().__init__( diff --git a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py index d24b4006ad2..3fc7af39b22 100644 --- a/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py +++ b/mmdet/models/task_modules/assigners/dynamic_soft_label_assigner.py @@ -16,7 +16,16 @@ EPS = 1.0e-7 -def center_of_mass(masks: Tensor, eps=1e-6): +def center_of_mass(masks: Tensor, eps: float = 1e-7) -> Tensor: + """Compute the masks center of mass. + + Args: + masks: Mask tensor, has shape (num_masks, H, W). + eps: a small number to avoid normalizer to be zero. + Defaults to 1e-7. + Returns: + Tensor: The masks center of mass. Has shape (num_masks, 2). + """ n, h, w = masks.shape grid_h = torch.arange(h, device=masks.device)[:, None] grid_w = torch.arange(w, device=masks.device) From 3785af95e2a5364475732ca9bb36ed4d48a9fdc2 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Fri, 16 Dec 2022 17:46:18 +0800 Subject: [PATCH 11/17] update readme --- configs/rtmdet/README.md | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index 924fa2e10bf..00d28c30d20 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -7,7 +7,7 @@ In this paper, we aim to design an efficient real-time object detector that exceeds the YOLO series and is easily extensible for many object recognition tasks such as instance segmentation and rotated object detection. To obtain a more efficient model architecture, we explore an architecture that has compatible capacities in the backbone and neck, constructed by a basic building block that consists of large-kernel depth-wise convolutions. We further introduce soft labels when calculating matching costs in the dynamic label assignment to improve accuracy. Together with better training techniques, the resulting object detector, named RTMDet, achieves 52.8% AP on COCO with 300+ FPS on an NVIDIA 3090 GPU, outperforming the current mainstream industrial detectors. RTMDet achieves the best parameter-accuracy trade-off with tiny/small/medium/large/extra-large model sizes for various application scenarios, and obtains new state-of-the-art performance on real-time instance segmentation and rotated object detection. We hope the experimental results can provide new insights into designing versatile real-time object detectors for many object recognition tasks.
- +
## Results and Models @@ -25,6 +25,7 @@ In this paper, we aim to design an efficient real-time object detector that exce **Note**: 1. The inference speed of RTMDet is measured on an NVIDIA 3090 GPU with TensorRT 8.4.3, cuDNN 8.2.0, FP16, batch size=1, and without NMS. +2. For a fair comparison, the config of bbox postprocessing is changed to be consistent with YOLOv5/6/7 after [PR#9494](https://github.com/open-mmlab/mmdetection/pull/9494), bringing about 0.1~0.3% AP improvement. ## Instance Segmentation @@ -44,6 +45,19 @@ In this paper, we aim to design an efficient real-time object detector that exce Models and configs of RTMDet-R are available in [MMRotate](https://github.com/open-mmlab/mmrotate/tree/1.x/configs/rotated_rtmdet) +## Citation + +```latex +@misc{lyu2022rtmdet, + title={RTMDet: An Empirical Study of Designing Real-Time Object Detectors}, + author={Chengqi Lyu and Wenwei Zhang and Haian Huang and Yue Zhou and Yudong Wang and Yanyi Liu and Shilong Zhang and Kai Chen}, + year={2022}, + eprint={2212.07784}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + ## Visualization
From 88c251d26aef53664e34beba4ef4bd050a712660 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 11:02:29 +0800 Subject: [PATCH 12/17] update readme --- configs/rtmdet/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index 00d28c30d20..9450577142b 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -1,5 +1,9 @@ # RTMDet: An Empirical Study of Designing Real-Time Object Detectors +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) + ## Abstract From 3b1410aff6f08f4e8f1afb76ae44f73ff467f022 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 13:14:37 +0800 Subject: [PATCH 13/17] release model weights --- configs/rtmdet/README.md | 14 +++--- configs/rtmdet/metafile.yml | 85 +++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 7 deletions(-) diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index 9450577142b..b7f7ccf3188 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -33,13 +33,13 @@ In this paper, we aim to design an efficient real-time object detector that exce ## Instance Segmentation -| Model | size | box AP | mask AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | -| :-------------: | :--: | :----: | :-----: | :-------: | :------: | :------------------: | :--------------------------------------------: | :----------------------: | -| RTMDet-Ins-tiny | 640 | 40.5 | 35.4 | 5.6 | 11.8 | 1.70 | [config](./rtmdet-ins_tiny_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | -| RTMDet-Ins-s | 640 | 44.0 | 38.7 | 10.18 | 21.5 | 1.93 | [config](./rtmdet-ins_s_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | -| RTMDet-Ins-m | 640 | 48.8 | 42.1 | 27.58 | 54.13 | 2.69 | [config](./rtmdet-ins_m_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | -| RTMDet-Ins-l | 640 | 51.2 | 43.7 | 57.37 | 106.56 | 3.68 | [config](./rtmdet-ins_l_8xb32-300e_coco.py) | [model](<>) \| [log](<>) | -| RTMDet-Ins-x | 640 | 52.4 | 44.6 | 102.7 | 182.7 | 5.31 | [config](./rtmdet-ins_x_8xb16-300e_coco.py) | [model](<>) \| [log](<>) | +| Model | size | box AP | mask AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | +| :-------------: | :--: | :----: | :-----: | :-------: | :------: | :------------------: | :--------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| RTMDet-Ins-tiny | 640 | 40.5 | 35.4 | 5.6 | 11.8 | 1.70 | [config](./rtmdet-ins_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727.log.json) | +| RTMDet-Ins-s | 640 | 44.0 | 38.7 | 10.18 | 21.5 | 1.93 | [config](./rtmdet-ins_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604.log.json) | +| RTMDet-Ins-m | 640 | 48.8 | 42.1 | 27.58 | 54.13 | 2.69 | [config](./rtmdet-ins_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039-6eba602e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039.log.json) | +| RTMDet-Ins-l | 640 | 51.2 | 43.7 | 57.37 | 106.56 | 3.68 | [config](./rtmdet-ins_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237.log.json) | +| RTMDet-Ins-x | 640 | 52.4 | 44.6 | 102.7 | 182.7 | 5.31 | [config](./rtmdet-ins_x_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_x_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313-33d4595b.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313.log.json) | **Note**: diff --git a/configs/rtmdet/metafile.yml b/configs/rtmdet/metafile.yml index 0d854191934..9c0487f3ff1 100644 --- a/configs/rtmdet/metafile.yml +++ b/configs/rtmdet/metafile.yml @@ -79,3 +79,88 @@ Models: Metrics: box AP: 52.6 Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth + + - Name: rtmdet-ins_tiny_8xb32-300e_coco + In Collection: RTMDet + Config: configs/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco.py + Metadata: + Training Memory (GB): 18.4 + Epochs: 300 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 40.5 + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 35.4 + Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth + + - Name: rtmdet-ins_s_8xb32-300e_coco + In Collection: RTMDet + Config: configs/rtmdet/rtmdet-ins_s_8xb32-300e_coco.py + Metadata: + Training Memory (GB): 27.6 + Epochs: 300 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 44.0 + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 38.7 + Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_s_8xb32-300e_coco/rtmdet-ins_s_8xb32-300e_coco_20221121_212604-fdc5d7ec.pth + + - Name: rtmdet-ins_m_8xb32-300e_coco + In Collection: RTMDet + Config: configs/rtmdet/rtmdet-ins_m_8xb32-300e_coco.py + Metadata: + Training Memory (GB): 42.5 + Epochs: 300 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 48.8 + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 42.1 + Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_m_8xb32-300e_coco/rtmdet-ins_m_8xb32-300e_coco_20221123_001039-6eba602e.pth + + - Name: rtmdet-ins_l_8xb32-300e_coco + In Collection: RTMDet + Config: configs/rtmdet/rtmdet-ins_l_8xb32-300e_coco.py + Metadata: + Training Memory (GB): 59.8 + Epochs: 300 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 51.2 + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 43.7 + Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_l_8xb32-300e_coco/rtmdet-ins_l_8xb32-300e_coco_20221124_103237-78d1d652.pth + + - Name: rtmdet-ins_x_8xb16-300e_coco + In Collection: RTMDet + Config: configs/rtmdet/rtmdet-ins_x_8xb16-300e_coco.py + Metadata: + Training Memory (GB): 33.7 + Epochs: 300 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 52.4 + - Task: Instance Segmentation + Dataset: COCO + Metrics: + mask AP: 44.6 + Weights: https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_x_8xb16-300e_coco/rtmdet-ins_x_8xb16-300e_coco_20221124_111313-33d4595b.pth From 7c74f9edd0c25bf230e944e5558e16cd9c1b605c Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 13:23:31 +0800 Subject: [PATCH 14/17] update readme --- configs/rtmdet/README.md | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index b7f7ccf3188..1c06812a748 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -33,6 +33,10 @@ In this paper, we aim to design an efficient real-time object detector that exce ## Instance Segmentation +RTMDet-Ins is the state-of-the-art real-time instance segmentation on coco dataset: + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) + | Model | size | box AP | mask AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms) | Config | Download | | :-------------: | :--: | :----: | :-----: | :-------: | :------: | :------------------: | :--------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | RTMDet-Ins-tiny | 640 | 40.5 | 35.4 | 5.6 | 11.8 | 1.70 | [config](./rtmdet-ins_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727-ec670f7e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet-ins_tiny_8xb32-300e_coco/rtmdet-ins_tiny_8xb32-300e_coco_20221130_151727.log.json) | @@ -47,6 +51,16 @@ In this paper, we aim to design an efficient real-time object detector that exce ## Rotated Object Detection +RTMDet-R achieves state-of-the-art on various remote sensing datasets + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/one-stage-anchor-free-oriented-object-1)](https://paperswithcode.com/sota/one-stage-anchor-free-oriented-object-1?p=rtmdet-an-empirical-study-of-designing-real) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/one-stage-anchor-free-oriented-object-3)](https://paperswithcode.com/sota/one-stage-anchor-free-oriented-object-3?p=rtmdet-an-empirical-study-of-designing-real) + Models and configs of RTMDet-R are available in [MMRotate](https://github.com/open-mmlab/mmrotate/tree/1.x/configs/rotated_rtmdet) ## Citation From ed853b82e2669657c929a2da8f4a9376e8389a2d Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 14:08:50 +0800 Subject: [PATCH 15/17] update readme --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3014050d6b8..4900ce8e515 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,4 +18,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA From d012ba1729857f0609a6cd478f9184081903071d Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 14:10:52 +0800 Subject: [PATCH 16/17] update readme --- README.md | 18 ++++++++++++++++++ README_zh-CN.md | 18 ++++++++++++++++++ setup.cfg | 2 +- 3 files changed, 37 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 39688919559..17d38313d59 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,24 @@ Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmla ## What's New +### Highlight + +We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes, but also becomes the state-of-the-art on instance segmentation and rotated object detection tasks. Detailed explanation of its methodologies is available in our [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models can be found [here](configs/rtmdet). + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) + +| Task | Dataset | AP | FPS(TRT FP16 BS1 3090) | +| ------------------------ | ------- | ------------------------------------ | ---------------------- | +| Object Detection | COCO | 52.8 | 322 | +| Instance Segmentation | COCO | 44.6 | 188 | +| Rotated Object Detection | DOTA | 78.9(single-scale)/81.3(multi-scale) | 121 | + +
+ +
+ **v3.0.0rc4** was released in 25/11/2022: - Support [CondInst](https://arxiv.org/abs/2003.05664) diff --git a/README_zh-CN.md b/README_zh-CN.md index 4255bfae257..fb53f2da004 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -74,6 +74,24 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope ## 最新进展 +### 亮点 + +我们很高兴向大家介绍我们在实时目标识别任务方面的最新成果 RTMDet,包含了一系列的全卷积单阶段检测模型。 RTMDet 不仅在从tiny到extra-large尺寸的目标检测模型上上实现了最佳的参数量和精度的平衡,而且在实时实例分割和旋转目标检测任务上取得了最先进的成果。 我们的[技术报告](https://arxiv.org/abs/2212.07784)中提供了其方法的详细说明。 预训练模型可以在[这里](configs/rtmdet)找到。 + +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) +[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-hrsc2016)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-hrsc2016?p=rtmdet-an-empirical-study-of-designing-real) + +| Task | Dataset | AP | FPS(TRT FP16 BS1 3090) | +| ------------------------ | ------- | ------------------------------------ | ---------------------- | +| Object Detection | COCO | 52.8 | 322 | +| Instance Segmentation | COCO | 44.6 | 188 | +| Rotated Object Detection | DOTA | 78.9(single-scale)/81.3(multi-scale) | 121 | + +
+ +
+ **v3.0.0rc4** 版本已经在 2022.11.25 发布: - 支持了 [CondInst](https://arxiv.org/abs/2003.05664) diff --git a/setup.cfg b/setup.cfg index 4900ce8e515..70dd621c8f5 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,4 +18,4 @@ SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true [codespell] skip = *.ipynb quiet-level = 3 -ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA +ignore-words-list = patten,nd,ty,mot,hist,formating,winn,gool,datas,wan,confids,TOOD,tood,ba,warmup,nam,DOTA,dota From d223850e8dbbf86c1ec5ad45033140e14b1549f4 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 19 Dec 2022 15:32:47 +0800 Subject: [PATCH 17/17] update readme --- README.md | 5 ++++- README_zh-CN.md | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 17d38313d59..0ce31cf8820 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,7 @@ The master branch works with **PyTorch 1.6+**. - **State of the art** The toolbox stems from the codebase developed by the *MMDet* team, who won [COCO Detection Challenge](http://cocodataset.org/#detection-leaderboard) in 2018, and we keep pushing it forward. + The newly released [RTMDet](configs/rtmdet) also obtains new state-of-the-art results on real-time instance segmentation and rotated object detection tasks and the best parameter-accuracy trade-off on object detection. @@ -77,7 +78,7 @@ Apart from MMDetection, we also released [MMEngine](https://github.com/open-mmla ### Highlight -We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes, but also becomes the state-of-the-art on instance segmentation and rotated object detection tasks. Detailed explanation of its methodologies is available in our [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models can be found [here](configs/rtmdet). +We are excited to announce our latest work on real-time object recognition tasks, **RTMDet**, a family of fully convolutional single-stage detectors. RTMDet not only achieves the best parameter-accuracy trade-off on object detection from tiny to extra-large model sizes but also obtains new state-of-the-art performance on instance segmentation and rotated object detection tasks. Details can be found in the [technical report](https://arxiv.org/abs/2212.07784). Pre-trained models are [here](configs/rtmdet). [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) @@ -205,6 +206,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • Deformable DETR (ICLR'2021)
  • TOOD (ICCV'2021)
  • DDOD (ACM MM'2021)
  • +
  • RTMDet (ArXiv'2022)
  • @@ -224,6 +226,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • Mask2Former (ArXiv'2021)
  • CondInst (ECCV 2020)
  • SparseInst (CVPR 2022)
  • +
  • RTMDet (ArXiv'2022)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index fb53f2da004..a8359cfcefc 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -67,6 +67,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope - **性能高** MMDetection 这个算法库源自于 COCO 2018 目标检测竞赛的冠军团队 *MMDet* 团队开发的代码,我们在之后持续进行了改进和提升。 + 新发布的 [RTMDet](configs/rtmdet) 还在实时实例分割和旋转目标检测任务中取得了最先进的成果,同时也在目标检测模型中取得了最佳的的参数量和精度平衡。 @@ -76,7 +77,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope ### 亮点 -我们很高兴向大家介绍我们在实时目标识别任务方面的最新成果 RTMDet,包含了一系列的全卷积单阶段检测模型。 RTMDet 不仅在从tiny到extra-large尺寸的目标检测模型上上实现了最佳的参数量和精度的平衡,而且在实时实例分割和旋转目标检测任务上取得了最先进的成果。 我们的[技术报告](https://arxiv.org/abs/2212.07784)中提供了其方法的详细说明。 预训练模型可以在[这里](configs/rtmdet)找到。 +我们很高兴向大家介绍我们在实时目标识别任务方面的最新成果 RTMDet,包含了一系列的全卷积单阶段检测模型。 RTMDet 不仅在从 tiny 到 extra-large 尺寸的目标检测模型上上实现了最佳的参数量和精度的平衡,而且在实时实例分割和旋转目标检测任务上取得了最先进的成果。 更多细节请参阅[技术报告](https://arxiv.org/abs/2212.07784)。 预训练模型可以在[这里](configs/rtmdet)找到。 [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/real-time-instance-segmentation-on-mscoco)](https://paperswithcode.com/sota/real-time-instance-segmentation-on-mscoco?p=rtmdet-an-empirical-study-of-designing-real) [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/rtmdet-an-empirical-study-of-designing-real/object-detection-in-aerial-images-on-dota-1)](https://paperswithcode.com/sota/object-detection-in-aerial-images-on-dota-1?p=rtmdet-an-empirical-study-of-designing-real) @@ -206,6 +207,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
  • Deformable DETR (ICLR'2021)
  • TOOD (ICCV'2021)
  • DDOD (ACM MM'2021)
  • +
  • RTMDet (ArXiv'2022)
  • @@ -225,6 +227,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
  • Mask2Former (ArXiv'2021)
  • CondInst (ECCV 2020)
  • SparseInst (CVPR 2022)
  • +
  • RTMDet (ArXiv'2022)