-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeformable-detr_r50_16xb2-100e_aodraw_srgb.py
114 lines (110 loc) · 3.69 KB
/
deformable-detr_r50_16xb2-100e_aodraw_srgb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
_base_ = [
'../_base_/datasets/aodraw_detection_srgb.py', '../_base_/default_runtime.py'
]
model = dict(
type='DeformableDETR',
num_queries=300,
num_feature_levels=4,
with_box_refine=False,
as_two_stage=False,
data_preprocessor=dict(
type='DetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=1),
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch',
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='ChannelMapper',
in_channels=[512, 1024, 2048],
kernel_size=1,
out_channels=256,
act_cfg=None,
norm_cfg=dict(type='GN', num_groups=32),
num_outs=4),
encoder=dict( # DeformableDetrTransformerEncoder
num_layers=6,
layer_cfg=dict( # DeformableDetrTransformerEncoderLayer
self_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
batch_first=True),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1))),
decoder=dict( # DeformableDetrTransformerDecoder
num_layers=6,
return_intermediate=True,
layer_cfg=dict( # DeformableDetrTransformerDecoderLayer
self_attn_cfg=dict( # MultiheadAttention
embed_dims=256,
num_heads=8,
dropout=0.1,
batch_first=True),
cross_attn_cfg=dict( # MultiScaleDeformableAttention
embed_dims=256,
batch_first=True),
ffn_cfg=dict(
embed_dims=256, feedforward_channels=1024, ffn_drop=0.1)),
post_norm_cfg=None),
positional_encoding=dict(num_feats=128, normalize=True, offset=-0.5),
bbox_head=dict(
type='DeformableDETRHead',
num_classes=62,
sync_cls_avg_factor=True,
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=2.0),
loss_bbox=dict(type='L1Loss', loss_weight=5.0),
loss_iou=dict(type='GIoULoss', loss_weight=2.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(
type='HungarianAssigner',
match_costs=[
dict(type='FocalLossCost', weight=2.0),
dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
dict(type='IoUCost', iou_mode='giou', weight=2.0)
])),
test_cfg=dict(max_per_img=100))
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0001),
clip_grad=dict(max_norm=0.1, norm_type=2),
paramwise_cfg=dict(
custom_keys={
'backbone': dict(lr_mult=0.1),
'sampling_offsets': dict(lr_mult=0.1),
'reference_points': dict(lr_mult=0.1)
}))
# learning policy
max_epochs = 25
train_cfg = dict(
type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=max_epochs,
by_epoch=True,
milestones=[20],
gamma=0.1)
]
train_dataloader = dict(batch_size=2)
# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (2 samples per GPU)
auto_scale_lr = dict(base_batch_size=16)