-
Notifications
You must be signed in to change notification settings - Fork 2
/
align_detr-900q_4scale_1x_r50.py
174 lines (167 loc) · 5.73 KB
/
align_detr-900q_4scale_1x_r50.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Base config files listed for this experiment (dataset and runtime).
_base_ = [
    './_base_/datasets/data_re_aug_coco_detection.py',
    './_base_/default_runtime.py',
]

# Fixed seed value for this run.
randomness = dict(seed=681328528)
# Model configuration dict (detector type 'HPRAlignDETR').
model = dict(
    type='HPRAlignDETR',
    num_queries=900,
    ckpt_backbone=False,
    ckpt_neck=False,
    aux_weights=[0.5, 0.5],
    with_box_refine=True,
    as_two_stage=True,
    use_dn=True,
    # Input pre-processing: a multi-branch wrapper around the detection
    # data preprocessor.
    data_preprocessor=dict(
        type='MultiBranchDataPreprocessor',
        data_preprocessor=dict(
            type='DetDataPreprocessor',
            mean=[123.675, 116.28, 103.53],
            std=[58.395, 57.12, 57.375],
            bgr_to_rgb=True,
            pad_size_divisor=1,
        ),
    ),
    # ResNet-50 backbone initialised from a local checkpoint file.
    backbone=dict(
        type='ResNet',
        depth=50,
        num_stages=4,
        out_indices=(1, 2, 3),
        frozen_stages=0,
        norm_cfg=dict(type='FrozenBN', requires_grad=False),
        norm_eval=True,
        style='pytorch',
        init_cfg=dict(
            type='Pretrained',
            checkpoint='backbone_pth/backbone.pth',
        ),
    ),
    # Neck: three backbone inputs, four 256-channel outputs.
    neck=dict(
        type='ChannelMapper',
        in_channels=[512, 1024, 2048],
        kernel_size=1,
        out_channels=256,
        # AlignDETR: Add conv bias.
        bias=True,
        act_cfg=None,
        norm_cfg=dict(type='GN', num_groups=32),
        num_outs=4,
    ),
    # Transformer encoder settings.
    encoder=dict(
        num_layers=6,
        num_cp=0,
        layer_cfg=dict(
            self_attn_cfg=dict(
                embed_dims=256,
                num_levels=4,
                dropout=0.0,  # 0.1 for DeformDETR
            ),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,  # 1024 for DeformDETR
                ffn_drop=0.0,  # 0.1 for DeformDETR
            ),
        ),
    ),
    # Transformer decoder settings, including the RoI extractor and the
    # per-layer dynamic-conv / regional cross-attention sub-configs.
    decoder=dict(
        num_layers=6,
        return_intermediate=True,
        bbox_roi_extractor=dict(
            type='SingleRoIExtractor',
            finest_scale=56,
            roi_layer=dict(
                type='RoIAlign',
                output_size=7,
                sampling_ratio=2,
            ),
            out_channels=256,
            featmap_strides=[8, 16, 32, 64],
        ),
        layer_cfg=dict(
            merge_method='learnable_channel_aware',
            initial_weights=[1, 1, 1],
            merge_dropout=0.0,
            dy_conv_cfg=dict(
                in_channels=256,
                feat_channels=64,
                out_channels=256,
                input_feat_shape=7,
                act_cfg=dict(type='ReLU', inplace=True),
                norm_cfg=dict(type='LN'),
            ),
            regional_ca_cfg=dict(
                sample_num=5,
                embed_dims=256,
                num_heads=8,
                use_key_pos=True,
                positional_encoding=dict(
                    num_feats=128,
                    normalize=True,
                    offset=0.0,
                    temperature=20,
                ),
                attn_drop=0.0,
                proj_drop=0.0,
                dropout_layer=dict(type='Dropout', drop_prob=0.0),
                init_cfg=None,
                batch_first=True,
                norm_cfg=dict(type='LN'),
                act_cfg=dict(type='ReLU', inplace=True),
            ),
            self_attn_cfg=dict(
                embed_dims=256,
                num_heads=8,
                dropout=0.0,  # 0.1 for DeformDETR
            ),
            cross_attn_cfg=dict(
                embed_dims=256,
                num_levels=4,
                dropout=0.0,  # 0.1 for DeformDETR
            ),
            ffn_cfg=dict(
                embed_dims=256,
                feedforward_channels=2048,
                ffn_drop=0.0,
            ),
        ),
        post_norm_cfg=None,
    ),
    positional_encoding=dict(
        num_feats=128,
        normalize=True,
        offset=-0.5,
        temperature=10000,
    ),
    # Detection head and its loss terms.
    bbox_head=dict(
        type='HPRAlignDETRHead',
        all_layers_num_gt_repeat=[2, 2, 2, 2, 2, 1, 2],
        alpha=0.25,
        gamma=2.0,
        tau=1.5,
        num_classes=80,
        sync_cls_avg_factor=True,
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            loss_weight=1.0,  # 2.0 in DeformDETR
        ),
        loss_bbox=dict(type='L1Loss', loss_weight=5.0),
        loss_iou=dict(type='GIoULoss', loss_weight=2.0),
    ),
    # Denoising-query settings.
    # TODO: Move to model.train_cfg ?
    dn_cfg=dict(
        label_noise_scale=0.5,
        box_noise_scale=1.0,  # 0.4 for DN-DETR
        group_cfg=dict(
            dynamic=True,
            num_groups=None,
            num_dn_queries=100,  # TODO: half num_dn_queries
        ),
    ),
    # training and testing settings
    train_cfg=dict(
        assigner=dict(
            type='MixedHungarianAssigner',
            match_costs=[
                dict(type='FocalLossCost', weight=2.0),
                dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
                dict(type='IoUCost', iou_mode='giou', weight=2.0),
            ],
        ),
    ),
    test_cfg=dict(max_per_img=300),  # 100 for DeformDETR
)
# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW',
        lr=0.0001,  # 0.0002 for DeformDETR
        weight_decay=0.0001,
    ),
    clip_grad=dict(max_norm=0.1, norm_type=2),
    # Parameters matching 'backbone' get a 0.1 learning-rate multiplier.
    # custom_keys contains sampling_offsets and reference_points in
    # DeformDETR.
    paramwise_cfg=dict(
        custom_keys={
            'backbone': dict(lr_mult=0.1),
        },
    ),
)
# learning policy
max_epochs = 12

# Epoch-based training loop; `val_interval=1` with `max_epochs` total epochs.
train_cfg = dict(
    type='EpochBasedTrainLoop',
    max_epochs=max_epochs,
    val_interval=1,
)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# Single multi-step schedule: factor 0.1 at milestone epoch 11.
param_scheduler = [
    dict(
        type='MultiStepLR',
        begin=0,
        end=max_epochs,
        by_epoch=True,
        milestones=[11],
        gamma=0.1,
    ),
]
# Visualisation backends: local storage plus TensorBoard.
vis_backends = [
    dict(type='LocalVisBackend'),
    dict(type='TensorboardVisBackend'),
]
# Visualizer config that uses the backend list defined above.
visualizer = dict(
    type='DetLocalVisualizer',
    vis_backends=vis_backends,
    name='visualizer',
)
# NOTE: `auto_scale_lr` is used for automatic learning-rate scaling;
# users should NOT change its values.
# base_batch_size = (8 GPUs) x (2 samples per GPU) = 16
auto_scale_lr = dict(
    base_batch_size=16,
)