[Doc] Add Doc of Detection Transformers #9534

Status: Open. This pull request wants to merge 26 commits into base branch `dev-3.x`.
Commits (26):
- `97e1dd9` [Refactor]: Refactor DETR and Deformable DETR (#8763) (Li-Qingyun, Oct 20, 2022)
- `b257a29` Add unitests for detr 3.x (#9089) (LYMDLUT, Oct 24, 2022)
- `f874d5c` [Refactor] Change to BNC data flow for all DETRs (#9460) (KeiChiTse, Dec 11, 2022)
- `899d4b8` Refactor detr 3.x conditional detr (#9405) (LYMDLUT, Dec 15, 2022)
- `6e7f1e6` [Feat] Add DINO on MMDetection 3.x (#9149) (Li-Qingyun, Dec 19, 2022)
- `d170810` [Refactor]: Refactor DAB-DETR in MMDetection 3.x (#9252) (KeiChiTse, Dec 22, 2022)
- `80ce8d2` doc: add detection_transformer.md (Li-Qingyun, Dec 26, 2022)
- `fef2947` modified by the pre-commit hook (Li-Qingyun, Dec 26, 2022)
- `ca14980` modify the first paragraph and section 1 (Li-Qingyun, Dec 27, 2022)
- `0a18641` modify the first section and add an outline (Li-Qingyun, Dec 29, 2022)
- `7e0e960` add temp zh doc and temp outline (Li-Qingyun, Dec 31, 2022)
- `47705c3` add feature descriptions to CN doc (Li-Qingyun, Dec 31, 2022)
- `ef95dd5` complement feature descriptions to CN doc (Li-Qingyun, Jan 3, 2023)
- `6ce2be0` complement positional encoding and set prediction to CN doc (Li-Qingyun, Jan 3, 2023)
- `d192ad3` add "implement a DETR" to CN doc (Li-Qingyun, Jan 3, 2023)
- `07b1456` update `Positional embedding of DETRs` in EN doc (Li-Qingyun, Jan 12, 2023)
- `2c4bca6` update `Object detection paradigm of set prediction` in EN doc (Li-Qingyun, Jan 12, 2023)
- `8d5b9ec` update `Appointment`-`Parameter names` in EN doc (Li-Qingyun, Jan 12, 2023)
- `d79db48` update `Customize a DETR` in EN doc (Li-Qingyun, Jan 12, 2023)
- `0302723` update `Customize a DETR` in CN doc (Li-Qingyun, Jan 13, 2023)
- `81a8239` delete unnecessary basic knowledge in `Positional embedding of DETRs`… (Li-Qingyun, Jan 13, 2023)
- `e1f84e2` add fig of DETR_mlvl_feats2seq.png in both doc (Li-Qingyun, Jan 13, 2023)
- `58a630b` Add "unified data flow" doc (KeiChiTse, Jan 13, 2023)
- `ba45638` Update docs/en/advanced_guides/detection_transformer.md as RangeKing (Li-Qingyun, Jan 16, 2023)
- `2c0587a` add `Compatibility` to EN doc (Li-Qingyun, Jan 17, 2023)
- `77cdbd9` update `Compatibility` of EN doc (Li-Qingyun, Jan 18, 2023)
39 changes: 39 additions & 0 deletions configs/conditional_detr/README.md
@@ -0,0 +1,39 @@
# Conditional DETR

> [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)

<!-- [ALGORITHM] -->

## Abstract

The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we handle the critical issue, slow training convergence, and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by that the cross-attention in DETR relies highly on the content embeddings and that the spatial embeddings make minor contributions, increasing the need for high-quality content embeddings and thus increasing the training difficulty.

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/blob/main/.github/attention-maps.png?raw=true"/>
</div>

Our conditional DETR learns a conditional spatial query from the decoder embedding for decoder multi-head cross-attention. The benefit is that through the conditional spatial query, each cross-attention head is able to attend to a band containing a distinct region, e.g., one object extremity or a region inside the object box (Figure 1). This narrows down the spatial range for localizing the distinct regions for object classification and box regression, thus relaxing the dependence on the content embeddings and easing the training. Empirical results show that conditional DETR converges 6.7x faster for the backbones R50 and R101 and 10x faster for stronger backbones DC5-R50 and DC5-R101.
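
As an aside, a minimal PyTorch sketch of the conditional spatial query idea follows. It is illustrative only and not the PR's implementation: the MLP, names, and shapes are assumptions, and the real module learns the transformation jointly with the decoder.

```python
import math

import torch
from torch import nn


def sine_embed(ref_points: torch.Tensor, num_feats: int = 128) -> torch.Tensor:
    """Sinusoidal embedding of normalized 2D reference points.

    ref_points: (num_queries, 2) in [0, 1] -> (num_queries, 2 * num_feats).
    """
    dim_t = torch.arange(num_feats // 2, dtype=torch.float32)
    dim_t = 10000 ** (2 * dim_t / num_feats)
    pos = ref_points[..., None] * 2 * math.pi / dim_t  # (nq, 2, num_feats/2)
    pos = torch.cat((pos.sin(), pos.cos()), dim=-1)    # (nq, 2, num_feats)
    return pos.flatten(1)                              # (nq, 2 * num_feats)


class ConditionalSpatialQuery(nn.Module):
    """Content embedding -> transformation T; spatial query = T * PE(ref)."""

    def __init__(self, embed_dims: int = 256):
        super().__init__()
        self.content_proj = nn.Sequential(
            nn.Linear(embed_dims, embed_dims), nn.ReLU(),
            nn.Linear(embed_dims, embed_dims))

    def forward(self, content: torch.Tensor,
                ref_points: torch.Tensor) -> torch.Tensor:
        # content: (num_queries, C); ref_points: (num_queries, 2)
        t = self.content_proj(content)  # content-dependent scaling of the PE
        return t * sine_embed(ref_points, num_feats=content.size(-1) // 2)


# Example: 300 queries, 256-d embeddings -> (300, 256) spatial queries.
spatial_q = ConditionalSpatialQuery()(torch.randn(300, 256), torch.rand(300, 2))
```

In the paper's design, this spatial query is concatenated with the content query in cross-attention rather than added to it, which is what lets each head attend to a distinct spatial band.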

<div align=center>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/conditional-detr.png" width="48%"/>
<img src="https://github.com/Atten4Vis/ConditionalDETR/raw/main/.github/convergence-curve.png" width="48%"/>
</div>

## Results and Models

We provide the config files and models for Conditional DETR: [Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152).

| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :--------------: | :-----: | :------: | :------------: | :----: | :-----------------------------------------------: | :----------------------------------: |
| R-50 | Conditional DETR | 50e | 7.9 | | 40.9 | [config](./conditional_detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |
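
As a quick usage sketch (paths here are assumptions, and the released checkpoint is still marked TODO above), the config can be loaded and tweaked with the standard MMEngine config machinery:

```python
from mmengine.config import Config

# Path assumes the working directory is the repository root.
cfg = Config.fromfile(
    'configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py')

# Example override: fewer object queries for a quick experiment.
cfg.model.num_queries = 100
```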

## Citation

```latex
@inproceedings{meng2021-CondDETR,
  title     = {Conditional DETR for Fast Training Convergence},
  author    = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong},
  booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
  year      = {2021}
}
```
42 changes: 42 additions & 0 deletions configs/conditional_detr/conditional_detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,42 @@
_base_ = ['../detr/detr_r50_8xb2-150e_coco.py']
model = dict(
    type='ConditionalDETR',
    num_queries=300,
    decoder=dict(
        num_layers=6,
        layer_cfg=dict(
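            # `_delete_=True` below discards the attention settings inherited
            # from the DETR base config, so only the keys given here apply;
            # `cross_attn` selects self- vs. cross-attention behavior.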
            self_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=False),
            cross_attn_cfg=dict(
                _delete_=True,
                embed_dims=256,
                num_heads=8,
                attn_drop=0.1,
                cross_attn=True))),
    bbox_head=dict(
        type='ConditionalDETRHead',
        loss_cls=dict(
            _delete_=True,
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])))

# learning policy
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=50, val_interval=1)

param_scheduler = [dict(type='MultiStepLR', end=50, milestones=[40])]
40 changes: 40 additions & 0 deletions configs/dab_detr/README.md
@@ -0,0 +1,40 @@
# DAB-DETR

> [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329)

<!-- [ALGORITHM] -->

## Abstract

We present in this paper a novel query formulation using dynamic anchor boxes for DETR (DEtection TRansformer) and offer a deeper understanding of the role of queries in DETR. This new formulation directly uses box coordinates as queries in Transformer decoders and dynamically updates them layer-by-layer. Using box coordinates not only helps using explicit positional priors to improve the query-to-feature similarity and eliminate the slow training convergence issue in DETR, but also allows us to modulate the positional attention map using the box width and height information. Such a design makes it clear that queries in DETR can be implemented as performing soft ROI pooling layer-by-layer in a cascade manner. As a result, it leads to the best performance on MS-COCO benchmark among the DETR-like detection models under the same setting, e.g., AP 45.7% using ResNet50-DC5 as backbone trained in 50 epochs. We also conducted extensive experiments to confirm our analysis and verify the effectiveness of our methods.
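
A minimal sketch of the core mechanism, layer-by-layer anchor refinement in inverse-sigmoid space, is shown below. It is illustrative only (names and shapes are assumptions, and the offsets are random stand-ins for the decoder's MLP predictions):

```python
import torch


def refine_anchors(anchors: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """One decoder layer's anchor update.

    anchors: (num_queries, 4) boxes (cx, cy, w, h), all in [0, 1].
    deltas:  (num_queries, 4) offsets predicted from decoder embeddings.
    """
    eps = 1e-5
    logits = torch.logit(anchors.clamp(eps, 1 - eps))  # inverse sigmoid
    return (logits + deltas).sigmoid()                 # updated boxes


# Dynamic anchor boxes refined across a 6-layer decoder.
anchors = torch.rand(300, 4)
for _ in range(6):
    deltas = torch.randn(300, 4) * 0.01  # stand-in for the per-layer MLP
    anchors = refine_anchors(anchors, deltas)
```

The box width and height additionally modulate the positional attention map, which is what the paper describes as soft ROI pooling performed layer by layer in a cascade manner.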

<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/arch.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/model.png?raw=true"/>
</div>
<div align=center>
<img src="https://github.com/IDEA-Research/DAB-DETR/blob/main/figure/results.png?raw=true"/>
</div>

## Results and Models

We provide the config files and models for DAB-DETR: [DAB-DETR: Dynamic Anchor Boxes are Better Queries for DETR](https://arxiv.org/abs/2201.12329).

| Backbone | Model | Lr schd | Mem (GB) | Inf time (fps) | box AP | Config | Download |
| :------: | :------: | :-----: | :------: | :------------: | :----: | :---------------------------------------: | :----------------------------------: |
| R-50 | DAB-DETR | 50e | 6.4 | | 42.3 | [config](./dab-detr_r50_8xb2-50e_coco.py) | \[model\](# TODO) \| \[log\](# TODO) |
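
For completeness, a short inference sketch using the high-level MMDetection API (the checkpoint path is a placeholder, since the released weights are still marked TODO above):

```python
from mmdet.apis import inference_detector, init_detector

config = 'configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py'
# Placeholder path; substitute the released checkpoint once available.
checkpoint = 'work_dirs/dab-detr_r50_8xb2-50e_coco/epoch_50.pth'

model = init_detector(config, checkpoint, device='cuda:0')
result = inference_detector(model, 'demo/demo.jpg')
```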

## Citation

```latex
@inproceedings{
liu2022dabdetr,
title={{DAB}-{DETR}: Dynamic Anchor Boxes are Better Queries for {DETR}},
author={Shilong Liu and Feng Li and Hao Zhang and Xiao Yang and Xianbiao Qi and Hang Su and Jun Zhu and Lei Zhang},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=oMI9PjOb9Jl}
}
```
162 changes: 162 additions & 0 deletions configs/dab_detr/dab-detr_r50_8xb2-50e_coco.py
@@ -0,0 +1,162 @@
_base_ = [
    '../_base_/datasets/coco_detection.py', '../_base_/default_runtime.py'
]
model = dict(
    type='DABDETR',
    num_queries=300,
    with_random_refpoints=False,
    num_patterns=0,
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True,
        pad_size_divisor=1),
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(3, ),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='ChannelMapper',
        in_channels=[2048],
        kernel_size=1,
        out_channels=256,
        act_cfg=None,
        norm_cfg=None,
        num_outs=1),
    encoder=dict(
        num_layers=6,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256, num_heads=8, dropout=0., batch_first=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.,
                act_cfg=dict(type='PReLU')))),
    decoder=dict(
        num_layers=6,
        query_dim=4,
        query_scale_type='cond_elewise',
        with_modulated_hw_attn=True,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                attn_drop=0.,
                proj_drop=0.,
                cross_attn=False),
            cross_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                attn_drop=0.,
                proj_drop=0.,
                cross_attn=True),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                num_fcs=2,
                ffn_drop=0.,
                act_cfg=dict(type='PReLU'))),
        return_intermediate=True),
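    # NOTE: temperature=20 is a DAB-DETR design choice; DETR's sinusoidal
    # encoding conventionally uses a temperature of 10000.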
    positional_encoding_cfg=dict(
        num_feats=128, temperature=20, normalize=True),
    bbox_head=dict(
        type='DABDETRHead',
        num_classes=80,
        embed_dims=256,
        loss_cls=dict(
            type='FocalLoss',
            use_sigmoid=True,
            gamma=2.0,
            alpha=0.25,
            loss_weight=1.0),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='HungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2., eps=1e-8),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0)
            ])),
    test_cfg=dict(max_per_img=300))

# train_pipeline, NOTE the img_scale and the Pad's size_divisor are different
# from the default setting in mmdet.
train_pipeline = [
    dict(
        type='LoadImageFromFile',
        file_client_args={{_base_.file_client_args}}),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(type='RandomFlip', prob=0.5),
    dict(
        type='RandomChoice',
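        # Pick one branch per image: plain multi-scale resizing, or
        # resize -> random crop -> resize (the second list below).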
        transforms=[[
            dict(
                type='RandomChoiceResize',
                scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
                        (608, 1333), (640, 1333), (672, 1333), (704, 1333),
                        (736, 1333), (768, 1333), (800, 1333)],
                keep_ratio=True)
        ],
                    [
                        dict(
                            type='RandomChoiceResize',
                            scales=[(400, 1333), (500, 1333), (600, 1333)],
                            keep_ratio=True),
                        dict(
                            type='RandomCrop',
                            crop_type='absolute_range',
                            crop_size=(384, 600),
                            allow_negative_crop=True),
                        dict(
                            type='RandomChoiceResize',
                            scales=[(480, 1333), (512, 1333), (544, 1333),
                                    (576, 1333), (608, 1333), (640, 1333),
                                    (672, 1333), (704, 1333), (736, 1333),
                                    (768, 1333), (800, 1333)],
                            keep_ratio=True)
                    ]]),
    dict(type='PackDetInputs')
]
train_dataloader = dict(dataset=dict(pipeline=train_pipeline))

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
    clip_grad=dict(max_norm=0.1, norm_type=2),
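    # Use a 10x smaller learning rate for the backbone, a common practice
    # for DETR-family detectors.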
    paramwise_cfg=dict(
        custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))

# learning policy
max_epochs = 50
train_cfg = dict(
    type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[40],
        gamma=0.1)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16, enable=False)
68 changes: 28 additions & 40 deletions configs/deformable_detr/deformable-detr_r50_16xb2-50e_coco.py
@@ -3,6 +3,10 @@
 ]
 model = dict(
     type='DeformableDETR',
+    num_queries=300,
+    num_feature_levels=4,
+    with_box_refine=False,
+    as_two_stage=False,
     data_preprocessor=dict(
         type='DetDataPreprocessor',
         mean=[123.675, 116.28, 103.53],
@@ -27,50 +31,34 @@
         act_cfg=None,
         norm_cfg=dict(type='GN', num_groups=32),
         num_outs=4),
+    encoder=dict(  # DeformableDetrTransformerEncoder
+        num_layers=6,
+        layer_cfg=dict(  # DeformableDetrTransformerEncoderLayer
+            self_attn_cfg=dict(  # MultiScaleDeformableAttention
+                embed_dims=256,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
+    decoder=dict(  # DeformableDetrTransformerDecoder
+        num_layers=6,
+        return_intermediate=True,
+        layer_cfg=dict(  # DeformableDetrTransformerDecoderLayer
+            self_attn_cfg=dict(  # MultiheadAttention
+                embed_dims=256,
+                num_heads=8,
+                dropout=0.1,
+                batch_first=True),
+            cross_attn_cfg=dict(  # MultiScaleDeformableAttention
+                embed_dims=256,
+                batch_first=True),
+            ffn_cfg=dict(
+                embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
+        post_norm_cfg=None),
+    positional_encoding_cfg=dict(num_feats=128, normalize=True, offset=-0.5),
     bbox_head=dict(
         type='DeformableDETRHead',
-        num_query=300,
         num_classes=80,
-        in_channels=2048,
         sync_cls_avg_factor=True,
-        as_two_stage=False,
-        transformer=dict(
-            type='DeformableDetrTransformer',
-            encoder=dict(
-                type='DetrTransformerEncoder',
-                num_layers=6,
-                transformerlayers=dict(
-                    type='BaseTransformerLayer',
-                    attn_cfgs=dict(
-                        type='MultiScaleDeformableAttention', embed_dims=256),
-                    feedforward_channels=1024,
-                    ffn_dropout=0.1,
-                    operation_order=('self_attn', 'norm', 'ffn', 'norm'))),
-            decoder=dict(
-                type='DeformableDetrTransformerDecoder',
-                num_layers=6,
-                return_intermediate=True,
-                transformerlayers=dict(
-                    type='DetrTransformerDecoderLayer',
-                    attn_cfgs=[
-                        dict(
-                            type='MultiheadAttention',
-                            embed_dims=256,
-                            num_heads=8,
-                            dropout=0.1),
-                        dict(
-                            type='MultiScaleDeformableAttention',
-                            embed_dims=256)
-                    ],
-                    feedforward_channels=1024,
-                    ffn_dropout=0.1,
-                    operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
-                                     'ffn', 'norm')))),
-        positional_encoding=dict(
-            type='SinePositionalEncoding',
-            num_feats=128,
-            normalize=True,
-            offset=-0.5),
         loss_cls=dict(
             type='FocalLoss',
             use_sigmoid=True,
Changes to the Deformable DETR refine config (file name not shown in this view; it inherits from `deformable-detr_r50_16xb2-50e_coco.py`):

@@ -1,2 +1,2 @@
 _base_ = 'deformable-detr_r50_16xb2-50e_coco.py'
-model = dict(bbox_head=dict(with_box_refine=True))
+model = dict(with_box_refine=True)
Changes to the Deformable DETR refine + two-stage config (file name not shown in this view; it inherits from `deformable-detr_refine_r50_16xb2-50e_coco.py`):

@@ -1,2 +1,2 @@
 _base_ = 'deformable-detr_refine_r50_16xb2-50e_coco.py'
-model = dict(bbox_head=dict(as_two_stage=True))
+model = dict(as_two_stage=True)
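
These two one-line diffs reflect the broader 3.x refactor seen throughout this PR: options such as `with_box_refine` and `as_two_stage`, together with the `encoder`, `decoder`, and positional encoding configs, now live on the detector itself rather than inside `bbox_head`. A schematic before/after comparison (keys abridged, for illustration only):

```python
# MMDetection 2.x style (before the refactor): transformer components are
# nested inside the bbox_head config.
old_style = dict(
    bbox_head=dict(
        type='DeformableDETRHead',
        as_two_stage=False,
        transformer=dict(encoder=..., decoder=...),
        positional_encoding=dict(type='SinePositionalEncoding')))

# MMDetection 3.x style (after): the detector owns these components directly,
# so variants only need to flip a top-level flag.
new_style = dict(
    type='DeformableDETR',
    with_box_refine=False,
    as_two_stage=False,
    encoder=dict(num_layers=6),
    decoder=dict(num_layers=6, return_intermediate=True),
    positional_encoding_cfg=dict(num_feats=128),
    bbox_head=dict(type='DeformableDETRHead'))
```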
2 changes: 1 addition & 1 deletion configs/detr/detr_r18_8xb2-500e_coco.py
@@ -4,4 +4,4 @@
     backbone=dict(
         depth=18,
         init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18')),
-    neck=dict(in_channels=[64, 128, 256, 512]))
+    neck=dict(in_channels=[512]))
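
The `in_channels` change follows from DETR consuming only the last backbone stage (`out_indices=(3, )` in the base config): with a single input feature map, the neck takes a single-entry channel list, and for ResNet-18 that stage has 512 channels. A small illustrative sketch (channel counts are the standard torchvision values, not taken from this PR):

```python
# Output channels of the four ResNet stages (standard torchvision models).
resnet_stage_channels = {
    'resnet18': [64, 128, 256, 512],
    'resnet50': [256, 512, 1024, 2048],
}

# DETR's neck only sees the final stage, hence a one-element list:
in_channels = [resnet_stage_channels['resnet18'][-1]]  # == [512]
```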