Merge branch 'upstream/dev' into maskformer
chhluo committed Feb 22, 2022
2 parents 6a6bd29 + 0b205c6 commit bd61a83
Showing 19 changed files with 622 additions and 23 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -234,6 +234,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
<li><a href="configs/carafe">CARAFE (ICCV'2019)</a></li>
<li><a href="configs/fpg">FPG (ArXiv'2020)</a></li>
<li><a href="configs/groie">GRoIE (ICPR'2020)</a></li>
<li><a href="configs/dyhead">DyHead (CVPR'2021)</a></li>
</ul>
</td>
<td>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -233,6 +233,7 @@ MMDetection is an open-source object detection toolbox based on PyTorch. It is a part of the [Ope
<li><a href="configs/carafe">CARAFE (ICCV'2019)</a></li>
<li><a href="configs/fpg">FPG (ArXiv'2020)</a></li>
<li><a href="configs/groie">GRoIE (ICPR'2020)</a></li>
<li><a href="configs/dyhead">DyHead (CVPR'2021)</a></li>
</ul>
</td>
<td>
46 changes: 46 additions & 0 deletions configs/dyhead/README.md
@@ -0,0 +1,46 @@
# DyHead

> [Dynamic Head: Unifying Object Detection Heads with Attentions](https://arxiv.org/abs/2106.08322)
<!-- [ALGORITHM] -->

## Abstract

The complex nature of combining localization and classification in object detection has resulted in the flourishing development of methods. Previous works tried to improve the performance in various object detection heads but failed to present a unified view. In this paper, we present a novel dynamic head framework to unify object detection heads with attentions. By coherently combining multiple self-attention mechanisms between feature levels for scale-awareness, among spatial locations for spatial-awareness, and within output channels for task-awareness, the proposed approach significantly improves the representation ability of object detection heads without any computational overhead. Further experiments demonstrate the effectiveness and efficiency of the proposed dynamic head on the COCO benchmark. With a standard ResNeXt-101-DCN backbone, we largely improve the performance over popular object detectors and achieve a new state-of-the-art at 54.0 AP. Furthermore, with the latest transformer backbone and extra data, we can push the current best COCO result to a new record at 60.6 AP.

<div align=center>
<img src="https://user-images.githubusercontent.com/42844407/149169448-fcafb6d0-b866-41cc-9422-94de9f1e1761.png" height="300"/>
</div>
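
The three attentions can be illustrated with a deliberately simplified, self-contained PyTorch sketch. This is an illustration only: the real module lives in `mmdet/models/necks/dyhead.py`, and here a plain 3x3 convolution stands in for the deformable convolution of the spatial-aware attention, while SE-style channel gating stands in for the DyReLU-based task-aware attention.

```python
# Simplified sketch of one DyHead block (illustration only, not the
# MMDetection implementation).
import torch
import torch.nn.functional as F
from torch import nn


class SimplifiedDyHeadBlock(nn.Module):

    def __init__(self, channels):
        super().__init__()
        # scale-aware attention: a single learned weight per feature level
        self.scale_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, 1, 1))
        # spatial-aware attention: a plain 3x3 conv stands in for the
        # deformable convolution used in the paper
        self.spatial_conv = nn.Conv2d(channels, channels, 3, padding=1)
        # task-aware attention: SE-style channel gating stands in for DyReLU
        self.task_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), nn.Conv2d(channels, channels, 1),
            nn.Sigmoid())

    def forward(self, feats):
        """feats: list of per-level tensors with identical channel counts."""
        outs = []
        for i, feat in enumerate(feats):
            # gather the current level and its neighbours, resized to match
            levels = [feat]
            if i > 0:
                levels.append(
                    F.interpolate(feats[i - 1], size=feat.shape[-2:]))
            if i < len(feats) - 1:
                levels.append(
                    F.interpolate(feats[i + 1], size=feat.shape[-2:]))
            # scale-aware: fuse the levels with learned hard-sigmoid weights
            fused = sum(
                F.hardsigmoid(self.scale_attn(lvl)) * lvl
                for lvl in levels) / len(levels)
            # spatial-aware, then task-aware
            mid = self.spatial_conv(fused)
            outs.append(mid * self.task_attn(mid))
        return outs


if __name__ == '__main__':
    block = SimplifiedDyHeadBlock(256)
    feats = [torch.randn(1, 256, s, s) for s in (32, 16, 8)]
    print([o.shape for o in block(feats)])
```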

## Results and Models

| Method | Backbone | Style | Setting | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
|:------:|:--------:|:-------:|:------------:|:-------:|:--------:|:--------------:|:------:|:------:|:--------:|
| ATSS | R-50 | caffe | reproduction | 1x | 5.4 | 13.2 | 42.5 | [config](./atss_r50_caffe_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939.log.json) |
| ATSS | R-50 | pytorch | simple | 1x | 4.9 | 13.7 | 43.3 | [config](./atss_r50_fpn_dyhead_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314.log.json) |

- We trained the above models with 4 GPUs and 4 `samples_per_gpu`.
- The `reproduction` setting aims to reproduce the official implementation based on Detectron2.
- The `simple` setting serves as a minimum example of using DyHead in MMDetection (a minimal config sketch follows this list). Specifically,
  - it adds `DyHead` to the `neck`, after `FPN`;
  - it sets `stacked_convs=0` in the `bbox_head`.
- The `simple` setting achieves higher AP than the original implementation; we have not conducted an ablation study between the two settings.
- `dict(type='Pad', size_divisor=128)` may further improve AP by encouraging spatial alignment across pyramid levels, although the large padding reduces efficiency.
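
For reference, here is a minimal sketch of the `simple` setting expressed as a delta on the stock ATSS baseline. It assumes the usual `configs/atss/atss_r50_fpn_1x_coco.py` base config; the complete, authoritative config is added by this commit as `configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py`.

```python
# Sketch: turn the stock ATSS baseline into the `simple` DyHead setting.
_base_ = '../atss/atss_r50_fpn_1x_coco.py'
model = dict(
    neck=[
        dict(
            type='FPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            start_level=1,
            add_extra_convs='on_output',
            num_outs=5),
        # DyHead is appended to the neck, after FPN
        dict(type='DyHead', in_channels=256, out_channels=256, num_blocks=6)
    ],
    # all shared convs move into DyHead, so the bbox head keeps none
    bbox_head=dict(stacked_convs=0))
```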

## Relation to Other Methods

- DyHead can be regarded as an improved [SEPC](https://arxiv.org/abs/2005.03101) with [DyReLU modules](https://arxiv.org/abs/2003.10027) and simplified [SE blocks](https://arxiv.org/abs/1709.01507).
- Xiyang Dai et al., the authors of DyHead, also adopt it in [Dynamic DETR](https://openaccess.thecvf.com/content/ICCV2021/html/Dai_Dynamic_DETR_End-to-End_Object_Detection_With_Dynamic_Attention_ICCV_2021_paper.html).
  The description of the Dynamic Encoder in Sec. 3.2 of that paper helps in understanding DyHead.

## Citation

```latex
@inproceedings{DyHead_CVPR2021,
author = {Dai, Xiyang and Chen, Yinpeng and Xiao, Bin and Chen, Dongdong and Liu, Mengchen and Yuan, Lu and Zhang, Lei},
title = {Dynamic Head: Unifying Object Detection Heads With Attentions},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2021}
}
```
112 changes: 112 additions & 0 deletions configs/dyhead/atss_r50_caffe_fpn_dyhead_1x_coco.py
@@ -0,0 +1,112 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
dict(
type='DyHead',
in_channels=256,
out_channels=256,
num_blocks=6,
# disable zero_init_offset to follow official implementation
zero_init_offset=False)
],
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
pred_kernel_size=1, # follow DyHead official implementation
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128],
center_offset=0.5), # follow DyHead official implementation
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)

# use caffe img_norm, size_divisor=128, pillow resize
img_norm_cfg = dict(
mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=(1333, 800),
keep_ratio=True,
backend='pillow'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(1333, 800),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True, backend='pillow'),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
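
The effect of `dict(type='Pad', size_divisor=128)` in the pipeline above is easy to verify; the following is a sketch of the arithmetic only, not MMDetection code.

```python
import math

def padded_shape(h, w, divisor=128):
    """Round each side up to the next multiple of `divisor`."""
    return math.ceil(h / divisor) * divisor, math.ceil(w / divisor) * divisor

# (800, 1333) is padded to (896, 1408), so each of the five pyramid levels
# (strides 8 through 128) keeps an integral spatial size.
print(padded_shape(800, 1333))
```
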
65 changes: 65 additions & 0 deletions configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py
@@ -0,0 +1,65 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
]
model = dict(
type='ATSS',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=[
dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
add_extra_convs='on_output',
num_outs=5),
dict(type='DyHead', in_channels=256, out_channels=256, num_blocks=6)
],
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128]),
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
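
As a usage note, a trained model built from this config can be sanity-checked with MMDetection's high-level inference API. This is a sketch: it assumes an mmdet 2.x installation, a CUDA device, and the demo image shipped with the repository; the checkpoint URL is the one from the README table above.

```python
from mmdet.apis import inference_detector, init_detector

config = 'configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py'
# checkpoint from the README table above
checkpoint = ('https://download.openmmlab.com/mmdetection/v2.0/dyhead/'
              'atss_r50_fpn_dyhead_4x4_1x_coco/'
              'atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth')
model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')  # per-class bbox arrays
```
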
63 changes: 63 additions & 0 deletions configs/dyhead/metafile.yml
@@ -0,0 +1,63 @@
Collections:
- Name: DyHead
Metadata:
Training Data: COCO
Training Techniques:
- SGD with Momentum
- Weight Decay
Training Resources: 4x T4 GPUs
Architecture:
- ATSS
- DyHead
- FPN
- ResNet
- Deformable Convolution
- Pyramid Convolution
Paper:
URL: https://arxiv.org/abs/2106.08322
Title: 'Dynamic Head: Unifying Object Detection Heads with Attentions'
README: configs/dyhead/README.md
Code:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.22.0/mmdet/models/necks/dyhead.py#L130
Version: v2.22.0

Models:
- Name: atss_r50_caffe_fpn_dyhead_1x_coco
In Collection: DyHead
Config: configs/dyhead/atss_r50_caffe_fpn_dyhead_1x_coco.py
Metadata:
Training Memory (GB): 5.4
inference time (ms/im):
- value: 75.7
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (800, 1333)
Epochs: 12
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 42.5
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_for_reproduction_1x_coco/atss_r50_fpn_dyhead_for_reproduction_4x4_1x_coco_20220107_213939-162888e6.pth

- Name: atss_r50_fpn_dyhead_1x_coco
In Collection: DyHead
Config: configs/dyhead/atss_r50_fpn_dyhead_1x_coco.py
Metadata:
Training Memory (GB): 4.9
inference time (ms/im):
- value: 73.1
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (800, 1333)
Epochs: 12
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 43.3
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth
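
As a quick consistency check (editorial, not part of the metafile), the latencies above match the fps figures quoted in `configs/dyhead/README.md`:

```python
# ms/im from this metafile vs. fps in configs/dyhead/README.md
for name, ms in [('atss_r50_caffe_fpn_dyhead_1x_coco', 75.7),
                 ('atss_r50_fpn_dyhead_1x_coco', 73.1)]:
    print(f'{name}: {1000 / ms:.1f} fps')  # 13.2 and 13.7
```
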
11 changes: 7 additions & 4 deletions docs/en/get_started.md
@@ -141,7 +141,7 @@ Or you can still install MMDetection manually:
# for LVIS dataset
pip install git+https://github.com/lvis-dataset/lvis-api.git
# for albumentations
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install -r requirements/albu.txt
```

**Note:**
@@ -155,9 +155,12 @@ you can install it before installing MMCV.
c. Some dependencies are optional. Simply running `pip install -v -e .` will
only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`.

d. If you would like to use `albumentations`, we suggest using
`pip install albumentations>=0.3.2 --no-binary imgaug,albumentations`. If you simply use
`pip install albumentations>=0.3.2`, it will install `opencv-python-headless` simultaneously (even though you have already installed `opencv-python`). We should not allow `opencv-python` and `opencv-python-headless` installed at the same time, because it might cause unexpected issues. Please refer to [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.
d. If you would like to use `albumentations`, we suggest installing it with `pip install -r requirements/albu.txt` or
`pip install -U albumentations --no-binary qudida,albumentations`. If you simply run `pip install albumentations>=0.3.2`,
it will also install `opencv-python-headless` (even if you have already
installed `opencv-python`). We recommend checking the environment after installing `albumentations` to
ensure that `opencv-python` and `opencv-python-headless` are not installed at the same time, because having both may cause unexpected issues. Please refer
to the [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.
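
One hypothetical way to perform that environment check is sketched below; this is not an official MMDetection or albumentations helper, just a standard-library snippet.

```python
# check_opencv.py - a hypothetical helper: warn if both OpenCV wheels coexist.
from importlib.metadata import distributions

installed = {dist.metadata['Name'].lower()
             for dist in distributions() if dist.metadata['Name']}
if {'opencv-python', 'opencv-python-headless'} <= installed:
    print('Both opencv-python and opencv-python-headless are installed; '
          'uninstall one of them to avoid unexpected issues.')
else:
    print('OpenCV installation looks consistent.')
```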

### Install without GPU support

4 changes: 2 additions & 2 deletions docs/zh_cn/get_started.md
@@ -149,7 +149,7 @@ MIM can automatically install OpenMMLab projects and their dependencies.
# install the LVIS dataset dependencies
pip install git+https://github.com/lvis-dataset/lvis-api.git
# install the albumentations dependencies
pip install albumentations>=0.3.2 --no-binary imgaug,albumentations
pip install -r requirements/albu.txt
```

**Note:**
@@ -160,7 +160,7 @@ MIM can automatically install OpenMMLab projects and their dependencies.

(3) Some dependencies are optional. Running `pip install -v -e .` installs only the minimum runtime requirements. To use optional dependencies such as `albumentations` and `imagecorruptions`, install them manually with `pip install -r requirements/optional.txt`, or specify the desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`); valid keys for the extras field are `all`, `tests`, `build`, and `optional`.

(4) If you would like to use `albumentations`, we suggest installing it with `pip install albumentations>=0.3.2 --no-binary imgaug,albumentations`. If you simply run `pip install albumentations>=0.3.2`, it will also install `opencv-python-headless` (even if `opencv-python` is already installed). `opencv-python` and `opencv-python-headless` must not be installed at the same time, because this may cause problems. Please refer to the [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.
(4) If you would like to use `albumentations`, we suggest installing it with `pip install -r requirements/albu.txt` or `pip install -U albumentations --no-binary qudida,albumentations`. If you simply run `pip install albumentations>=0.3.2`, it will also install `opencv-python-headless` (even if `opencv-python` is already installed). We recommend checking the environment after installing `albumentations` to ensure that `opencv-python` and `opencv-python-headless` are not installed at the same time, because having both may cause problems. Please refer to the [official documentation](https://albumentations.ai/docs/getting_started/installation/#note-on-opencv-dependencies) for more details.

### 只在 CPU 安装
