CenterNet(Objects as points) (#4602)
* flip test ok

* clean commit files

* rename

* merge_master and update

* fix wrong merge

* add dla34 base

* add dla34

* fix details

* fix head and add config

* update init and config

* update config

* flip test

* Fix error merge

* clean model code

* refactor loss

* rename loss_heatmap to loss_center_heatmap

* Fix avg_factor

* update dlanet

* add find_unused_parameters

* Fix warmup_ratio

* add unittest

* add centernet head docstr.

* reduce workers_per_gpu to 4

* fix name error

* recode post process

* fix flip test

* fix flip test error

* Refactor post process

* Remove DLANet

* Remove resnet_ct.py

* Add README.md

* Add centernet_resnet18_140e_coco.py

* Update README.md

* Fix comment and docstr

* Fix unittest

* Fix name and docstr

* rename neck

* fix comments

* fix comments

* Stabilize the training process

* use default conv init weights

* fp16 enabled

* update readme and add comment

* add comment

Co-authored-by: hhaAndroid <1286304229@qq.com>
kellenf and hhaAndroid authored May 21, 2021
1 parent e8b9dc7 commit 3d91b8b
Showing 16 changed files with 1,011 additions and 98 deletions.
30 changes: 30 additions & 0 deletions configs/centernet/README.md
@@ -0,0 +1,30 @@
# CenterNet

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{zhou2019objects,
  title={Objects as Points},
  author={Zhou, Xingyi and Wang, Dequan and Kr{\"a}henb{\"u}hl, Philipp},
  journal={arXiv preprint arXiv:1904.07850},
  year={2019}
}
```

## Results and models

| Backbone | DCN | Mem (GB) | Box AP | Flip box AP| Config | Download |
| :-------------: | :--------: |:----------------: | :------: | :------------: | :----: | :----: |
| ResNet-18 | N | 3.45 | 26.0 | 27.4 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/centernet/centernet_resnet18_140e_coco.py) | [model](http://download.openmmlab.com/mmdetection/v2.0/centernet/centernet_resnet18_140e_coco/centernet_resnet18_140e_coco_20210519_092334-eafe8ccd.pth) &#124; [log](http://download.openmmlab.com/mmdetection/v2.0/centernet/centernet_resnet18_140e_coco/centernet_resnet18_140e_coco_20210519_092334.log.json) |
| ResNet-18 | Y | 3.47 | 29.5 | 31.0 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/centernet/centernet_resnet18_dcnv2_140e_coco.py) | [model](http://download.openmmlab.com/mmdetection/v2.0/centernet/centernet_resnet18_dcnv2_140e_coco/centernet_resnet18_dcnv2_140e_coco_20210520_101209-da388ba2.pth) &#124; [log](http://download.openmmlab.com/mmdetection/v2.0/centernet/centernet_resnet18_dcnv2_140e_coco/centernet_resnet18_dcnv2_140e_coco_20210520_101209.log.json) |

Note:

- Flip box AP is measured with single-scale testing and `flip=True`.
- Because of the heavy data augmentation, performance is unstable and may fluctuate by about 0.4 mAP. For ResNet-18-DCNv2, an mAP of 29.4 to 29.8 is acceptable.
- Unlike the original source code, and following [CenterNet-Better](https://github.com/FateScript/CenterNet-better), we make the following changes:
  - Fix the incorrect image mean and standard deviation used for normalization so that they are compatible with the pre-trained backbone.
  - Use SGD rather than the Adam optimizer, and add warmup and gradient clipping.
  - Use DistributedDataParallel, like other models in MMDetection, rather than DataParallel.
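
The normalization change can be sanity-checked with a little arithmetic. The sketch below assumes the pre-trained torchvision backbone expects the commonly used ImageNet statistics on a 0-1 scale (that assumption is not stated in this commit), and `normalize_pixel` is a hypothetical helper that only illustrates the per-channel formula:

```python
import math

# Values taken from img_norm_cfg in the config (0-255 scale, RGB order).
MEAN = [123.675, 116.28, 103.53]
STD = [58.395, 57.12, 57.375]

# Assumption: ImageNet statistics on a 0-1 scale, as commonly paired with
# torchvision pre-trained weights.
IMAGENET_MEAN_01 = [0.485, 0.456, 0.406]
IMAGENET_STD_01 = [0.229, 0.224, 0.225]


def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Hypothetical helper: per-channel (value - mean) / std."""
    return [(v - m) / s for v, m, s in zip(rgb, mean, std)]


# The config values are the 0-1 statistics multiplied by 255.
mean_matches = all(
    math.isclose(m, 255 * x) for m, x in zip(MEAN, IMAGENET_MEAN_01))
std_matches = all(
    math.isclose(s, 255 * x) for s, x in zip(STD, IMAGENET_STD_01))
print(mean_matches, std_matches)

# A pixel equal to the mean maps to zero in every channel.
print(normalize_pixel(MEAN))
```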
3 changes: 3 additions & 0 deletions configs/centernet/centernet_resnet18_140e_coco.py
@@ -0,0 +1,3 @@
_base_ = './centernet_resnet18_dcnv2_140e_coco.py'

model = dict(neck=dict(use_dcn=False))
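
The non-DCN config above only overrides one key of the inherited `model` dict. A minimal sketch of how such an override plausibly combines with the base config, assuming a plain recursive dict merge; `merge_cfg` is a hypothetical stand-in written for illustration, not the actual config loader:

```python
import copy


def merge_cfg(base, override):
    """Hypothetical simplified merge: nested dicts are merged key by key,
    any other value in `override` replaces the one in `base`."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_cfg(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Abbreviated version of the DCNv2 base config's model dict.
base_model = dict(
    type='CenterNet',
    neck=dict(type='CTResNetNeck', in_channel=512, use_dcn=True))

# The override from centernet_resnet18_140e_coco.py.
child_model = dict(neck=dict(use_dcn=False))

merged = merge_cfg(base_model, child_model)
print(merged['neck'])
```

Only `use_dcn` changes; the other neck settings are inherited from the base, and the base dict itself is left unmodified.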
110 changes: 110 additions & 0 deletions configs/centernet/centernet_resnet18_dcnv2_140e_coco.py
@@ -0,0 +1,110 @@
_base_ = [
    '../_base_/datasets/coco_detection.py',
    '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]

model = dict(
    type='CenterNet',
    pretrained='torchvision://resnet18',
    backbone=dict(
        type='ResNet', depth=18, norm_eval=False, norm_cfg=dict(type='BN')),
    neck=dict(
        type='CTResNetNeck',
        in_channel=512,
        num_deconv_filters=(256, 128, 64),
        num_deconv_kernels=(4, 4, 4),
        use_dcn=True),
    bbox_head=dict(
        type='CenterNetHead',
        num_classes=80,
        in_channel=64,
        feat_channel=64,
        loss_center_heatmap=dict(type='GaussianFocalLoss', loss_weight=1.0),
        loss_wh=dict(type='L1Loss', loss_weight=0.1),
        loss_offset=dict(type='L1Loss', loss_weight=1.0)),
    train_cfg=None,
    test_cfg=dict(topk=100, local_maximum_kernel=3, max_per_img=100))

# We fixed the incorrect img_norm_cfg problem in the source code.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True, color_type='color'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(
        type='RandomCenterCropPad',
        crop_size=(512, 512),
        ratios=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3),
        mean=[0, 0, 0],
        std=[1, 1, 1],
        to_rgb=True,
        test_pad_mode=None),
    dict(type='Resize', img_scale=(512, 512), keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(
        type='MultiScaleFlipAug',
        scale_factor=1.0,
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(
                type='RandomCenterCropPad',
                ratios=None,
                border=None,
                mean=[0, 0, 0],
                std=[1, 1, 1],
                to_rgb=True,
                test_mode=True,
                test_pad_mode=['logical_or', 31],
                test_pad_add_pix=1),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='DefaultFormatBundle'),
            dict(
                type='Collect',
                meta_keys=('filename', 'ori_shape', 'img_shape', 'pad_shape',
                           'scale_factor', 'flip', 'flip_direction',
                           'img_norm_cfg', 'border'),
                keys=['img'])
        ])
]
data = dict(
    samples_per_gpu=16,
    workers_per_gpu=4,
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))

# optimizer
# Following the default settings of modern detectors, we use SGD with its
# default settings, which performs better than the Adam optimizer used in
# the source code. With Adam and lr=5e-4, the mAP is 29.1.
optimizer_config = dict(
    _delete_=True, grad_clip=dict(max_norm=35, norm_type=2))

# learning policy
# Based on the default settings of modern detectors, we added warmup settings.
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=1000,
    warmup_ratio=1.0 / 1000,
    step=[90, 120])
runner = dict(max_epochs=140)

# Avoid evaluation and saving weights too frequently
evaluation = dict(interval=5, metric='bbox')
checkpoint_config = dict(interval=5)
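
To make the learning policy above concrete, here is a rough sketch of the learning rate it would produce. It assumes a base learning rate of 0.02 and a decay factor of 0.1 (neither value appears in this diff; both are assumed to come from the inherited schedule defaults), a linear warmup that interpolates from `warmup_ratio * lr` up to `lr`, and 0-indexed epochs. `lr_at` is a hypothetical helper, not a library function:

```python
def lr_at(epoch, cur_iter, base_lr=0.02, warmup_iters=1000,
          warmup_ratio=1.0 / 1000, steps=(90, 120), gamma=0.1):
    """Hypothetical sketch of a step schedule with linear warmup.

    base_lr and gamma are assumptions; the remaining defaults mirror
    lr_config in the config above.
    """
    # Step decay: multiply by gamma once for every milestone already reached.
    num_decays = sum(1 for milestone in steps if epoch >= milestone)
    regular_lr = base_lr * gamma ** num_decays
    # Linear warmup over the first warmup_iters iterations of training.
    if cur_iter < warmup_iters:
        k = (1 - cur_iter / warmup_iters) * (1 - warmup_ratio)
        return regular_lr * (1 - k)
    return regular_lr


for epoch, cur_iter in [(0, 0), (0, 500), (0, 1000), (90, 600000),
                        (120, 800000)]:
    print(epoch, cur_iter, lr_at(epoch, cur_iter))
```

Under these assumptions the rate starts at one thousandth of the base value, reaches the base value after 1000 iterations, and drops by a factor of 10 at epochs 90 and 120 of the 140-epoch run.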
4 changes: 2 additions & 2 deletions mmdet/core/utils/__init__.py
@@ -1,7 +1,7 @@
 from .dist_utils import DistOptimizerHook, allreduce_grads, reduce_mean
-from .misc import mask2ndarray, multi_apply, unmap
+from .misc import flip_tensor, mask2ndarray, multi_apply, unmap
 
 __all__ = [
     'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
-    'unmap', 'mask2ndarray'
+    'unmap', 'mask2ndarray', 'flip_tensor'
 ]
23 changes: 23 additions & 0 deletions mmdet/core/utils/misc.py
@@ -59,3 +59,26 @@ def mask2ndarray(mask):
     elif not isinstance(mask, np.ndarray):
         raise TypeError(f'Unsupported {type(mask)} data type')
     return mask
+
+
+def flip_tensor(src_tensor, flip_direction):
+    """Flip a tensor based on flip_direction.
+
+    Args:
+        src_tensor (Tensor): Input feature map, shape (B, C, H, W).
+        flip_direction (str): The flipping direction. Options are
+            'horizontal', 'vertical', 'diagonal'.
+
+    Returns:
+        out_tensor (Tensor): Flipped tensor.
+    """
+    assert src_tensor.ndim == 4
+    valid_directions = ['horizontal', 'vertical', 'diagonal']
+    assert flip_direction in valid_directions
+    if flip_direction == 'horizontal':
+        out_tensor = torch.flip(src_tensor, [3])
+    elif flip_direction == 'vertical':
+        out_tensor = torch.flip(src_tensor, [2])
+    else:
+        out_tensor = torch.flip(src_tensor, [2, 3])
+    return out_tensor
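
The three directions handled by `flip_tensor` can be illustrated without any tensor library. The sketch below is a pure-Python analogue on a single H x W nested list; `flip_2d` is a hypothetical helper written only for illustration, whereas the function in the diff operates on 4-D tensors via `torch.flip`:

```python
def flip_2d(grid, flip_direction):
    """Hypothetical analogue of flip_tensor for one H x W nested list."""
    assert flip_direction in ('horizontal', 'vertical', 'diagonal')
    if flip_direction == 'horizontal':
        # Reverse the width axis (dim 3 in a B, C, H, W tensor).
        return [row[::-1] for row in grid]
    if flip_direction == 'vertical':
        # Reverse the height axis (dim 2 in a B, C, H, W tensor).
        return [row[:] for row in grid[::-1]]
    # 'diagonal': reverse both height and width.
    return [row[::-1] for row in grid[::-1]]


grid = [[1, 2, 3],
        [4, 5, 6]]

print(flip_2d(grid, 'horizontal'))  # [[3, 2, 1], [6, 5, 4]]
print(flip_2d(grid, 'vertical'))    # [[4, 5, 6], [1, 2, 3]]
print(flip_2d(grid, 'diagonal'))    # [[6, 5, 4], [3, 2, 1]]
```

Applying the same flip twice restores the original layout, which is the property that lets a prediction made on a flipped input be mapped back to the original image coordinates.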
8 changes: 6 additions & 2 deletions mmdet/datasets/pipelines/transforms.py
@@ -1566,6 +1566,7 @@ class RandomCenterCropPad:
             - 'logical_or': final_shape = input_shape | padding_shape_value
             - 'size_divisor': final_shape = int(
               ceil(input_shape / padding_shape_value) * padding_shape_value)
+        test_pad_add_pix (int): Extra padding pixel in test mode. Default 0.
         bbox_clip_border (bool, optional): Whether clip the objects outside
             the border of the image. Defaults to True.
     """
@@ -1579,6 +1580,7 @@ def __init__(self,
                  to_rgb=None,
                  test_mode=False,
                  test_pad_mode=('logical_or', 127),
+                 test_pad_add_pix=0,
                  bbox_clip_border=True):
         if test_mode:
             assert crop_size is None, 'crop_size must be None in test mode'
@@ -1612,6 +1614,7 @@ def __init__(self,
         self.std = std
         self.test_mode = test_mode
         self.test_pad_mode = test_pad_mode
+        self.test_pad_add_pix = test_pad_add_pix
         self.bbox_clip_border = bbox_clip_border
 
     def _get_border(self, border, size):
@@ -1783,8 +1786,9 @@ def _test_aug(self, results):
         h, w, c = img.shape
         results['img_shape'] = img.shape
         if self.test_pad_mode[0] in ['logical_or']:
-            target_h = h | self.test_pad_mode[1]
-            target_w = w | self.test_pad_mode[1]
+            # self.test_pad_add_pix is only used for centernet
+            target_h = (h | self.test_pad_mode[1]) + self.test_pad_add_pix
+            target_w = (w | self.test_pad_mode[1]) + self.test_pad_add_pix
         elif self.test_pad_mode[0] in ['size_divisor']:
             divisor = self.test_pad_mode[1]
             target_h = int(np.ceil(h / divisor)) * divisor
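
A small worked example of the `logical_or` branch with the values used in the CenterNet test pipeline (`test_pad_mode=['logical_or', 31]`, `test_pad_add_pix=1`). `padded_size` is a hypothetical helper that isolates the arithmetic; the reading that the extra pixel is there to make the padded size a multiple of 32 is an interpretation, not something the commit states:

```python
def padded_size(size, pad_value=31, add_pix=1):
    """Hypothetical helper mirroring (size | pad_value) + add_pix."""
    return (size | pad_value) + add_pix


for size in (427, 480, 500, 512, 640):
    target = padded_size(size)
    # OR-ing with 31 sets the five lowest bits, so adding 1 clears them and
    # the result is always a multiple of 32.
    print(size, size | 31, target, target % 32 == 0)
```

With `add_pix=0` the targets would be 447, 511, 511, 543 and 671, none of which is divisible by 32. Note that an input that is already a multiple of 32, such as 512, is still padded up to the next multiple (544).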
4 changes: 3 additions & 1 deletion mmdet/models/dense_heads/__init__.py
@@ -3,6 +3,7 @@
 from .atss_head import ATSSHead
 from .autoassign_head import AutoAssignHead
 from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
+from .centernet_head import CenterNetHead
 from .centripetal_head import CentripetalHead
 from .corner_head import CornerHead
 from .deformable_detr_head import DeformableDETRHead
@@ -41,5 +42,6 @@
     'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
     'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'StageCascadeRPNHead',
     'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead', 'CascadeRPNHead',
-    'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead'
+    'AutoAssignHead', 'DETRHead', 'YOLOFHead', 'DeformableDETRHead',
+    'CenterNetHead'
 ]