
[Feature] Official implementation of SETR #531

Merged (88 commits, Jun 23, 2021). Changes shown are from 50 of the 88 commits.

Commits
da9573b
Adjust vision transformer backbone architectures;
Apr 28, 2021
f9c8420
Merge Master
Apr 28, 2021
2f580e9
Fix some parameters loss bug;
Apr 28, 2021
e1d59cd
* Store intermediate token features and impose no processes on them;
Apr 29, 2021
e4a11e7
Fix some doc error
Apr 29, 2021
d5644c1
Add an arg for VisionTransformer backbone to control if input class to…
Apr 29, 2021
70cde52
Add stochastic depth decay rule for DropPath;
Apr 29, 2021
97af059
* Fix output bug when input_cls_token=False;
Apr 29, 2021
16e21ca
Re-implement of SETR
Apr 29, 2021
53e7e80
* Modify some docs of heads of SETR;
Apr 29, 2021
995a728
* Modify some arg of setr heads;
Apr 30, 2021
7ff896d
Merge branch 'master' into setr
May 2, 2021
b3d5258
Merge Master
May 2, 2021
8d18d86
* Add 768x768 cityscapes dataset config;
May 2, 2021
2c93446
* Fix the low code coverage of unit test about heads of setr;
May 3, 2021
efe6913
* Add pascal context dataset & ade20k dataset config;
May 3, 2021
4b2fd5e
Modify folder structure.
May 3, 2021
0ae504d
add setr
CuttlefishXuan May 3, 2021
1377131
modify vit
CuttlefishXuan May 3, 2021
d47e10c
Fix the test_cfg arg position;
May 5, 2021
27d1479
Fix some learning schedule bug;
May 5, 2021
f70b315
optimize setr code
May 5, 2021
be0d2fb
Add arg: final_reshape to control if converting output feature inform…
May 6, 2021
49130ce
Fix the default value of final_reshape;
May 6, 2021
b83992d
Merge branch 'vit_final_reshape' into setr
May 6, 2021
f7052f9
Modify arg: final_reshape to arg: out_shape;
May 6, 2021
c5858f5
Fix some unit test bug;
May 6, 2021
c8ea16a
Merge branch 'vit_final_reshape' into setr
May 6, 2021
0599c71
Add MLA neck;
May 6, 2021
7eda50b
Merge pr #526
May 6, 2021
0d92194
Remove some redundant files.
May 6, 2021
7040840
* Fix the code style bug;
May 6, 2021
851366a
Ignoring CityscapesCoarseDataset and MapillaryDataset.
May 6, 2021
c424a35
Fix the activation function loss bug;
May 7, 2021
a40ed61
Fix the img_size bug of SETR_PUP_ADE20K
May 8, 2021
ad2ca50
Merge Master
May 8, 2021
8629d60
Merge Master
May 10, 2021
4f24d57
Merge branch 'setr_official' of github.com:sennnnn/mmsegmentation int…
May 10, 2021
5ab2b35
* Fix the lint bug of transformers.py;
May 10, 2021
a70621c
Convert vit of setr out shape from NLC to NCHW.
May 10, 2021
4161634
* Modify Resize action of data pipeline;
May 10, 2021
a5f8c1f
Remove arg: find_unused_parameters which is False by default.
May 11, 2021
0636d9e
Error auxiliary head of PUP deit
May 11, 2021
125c1ee
Remove the minimal restrict of slide inference.
May 11, 2021
760d0c5
Modify doc string of Resize
May 11, 2021
45f3df3
Separate this part of code to a new PR #544
May 11, 2021
16c5fe5
* Remove some redundant code;
May 11, 2021
81410a1
Fix the tuple in_channels of mla_deit.
May 11, 2021
1d92146
Merge branch 'master' into setr_official
May 17, 2021
cdc6d30
Modify code style
May 17, 2021
3d39112
Modify implementation of SETR Heads
May 17, 2021
9b80384
non-square input support for setr heads
May 17, 2021
03eb097
Modify config argument for above commits
May 17, 2021
cb50538
Remove norm_layer argument of SETRMLAHead
May 17, 2021
2115d66
Add mla_align_corners for MLAModule interpolate
May 17, 2021
9ec4c9e
[Refactor]Refactor of SETRMLAHead
May 20, 2021
104ff0c
[Refactor]MLA Neck
May 20, 2021
d2b0107
Fix config bug
May 20, 2021
949cb65
[Refactor]SETR Naive Head and SETR PUP Head
May 21, 2021
3962975
[Fix]Fix the lack of arg: act_cfg and arg: norm_cfg
May 21, 2021
3d1bef5
Merge branch 'master' into setr_official
May 21, 2021
90b9b3d
Fix config error
May 21, 2021
5a3b376
Refactor of SETR MLA, Naive, PUP heads.
May 24, 2021
8f7c141
Modify some attribute name of SETR Heads.
May 24, 2021
8bfc651
Merge Master
Jun 18, 2021
c9c4284
Modify setr configs to adapt new vit code.
Jun 18, 2021
f45ca2d
Fix trunc_normal_ bug
Jun 18, 2021
e8fd36b
Parameters init adjustment.
Jun 18, 2021
1090fb3
Remove redundant doc string of SETRUPHead
Jun 18, 2021
39c6070
Fix pretrained bug
Jun 18, 2021
07ffcd2
[Fix] Fix vit init bug
Jun 18, 2021
4c22dff
Add some vit unit tests
Jun 18, 2021
bc0bcdd
Modify module import
Jun 18, 2021
e33a1c5
Remove norm from PatchEmbed
Jun 19, 2021
16f1fab
Fix pretrain weights bug
Jun 19, 2021
9acbf5a
Modify pretrained judge
Jun 19, 2021
1f4b5b8
Update vit init
Jun 19, 2021
bb294e0
Fix some gradient backward bugs.
Jun 19, 2021
bf71f60
Add some unit tests to improve code cov
Jun 19, 2021
37e65b3
Merge branch 'vit_init_refactor' into setr_official
Jun 19, 2021
a19b5d8
Fix init_weights of setr up head
Jun 19, 2021
a8f609f
Merge master
Jun 21, 2021
8409843
Add DropPath in FFN
Jun 21, 2021
07e38fa
Finish benchmark of SETR
Jun 22, 2021
399e053
Remove DropPath implementation and use DropPath from mmcv.
Jun 22, 2021
5ff95f1
Modify out_indices arg
Jun 22, 2021
23f734a
Fix out_indices bug.
Jun 23, 2021
82d4455
Remove cityscapes base dataset config.
Jun 23, 2021
35 changes: 35 additions & 0 deletions configs/_base_/datasets/cityscapes_768x768.py
@@ -0,0 +1,35 @@
_base_ = './cityscapes.py'
Collaborator: We may also benchmark the default cityscapes.py config.
Collaborator: We may first benchmark 2-4 configs.
Contributor: Already added.

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (768, 768)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2049, 1025),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))
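The `_base_ = './cityscapes.py'` line above pulls in the stock Cityscapes config and lets this file override only the pipeline keys. A minimal pure-Python sketch of that recursive dict-merge semantics (mimicking, not reproducing, `mmcv.Config`):

```python
# Minimal sketch of mmcv-style _base_ config merging: child keys override
# base keys, and nested dicts are merged recursively. Illustrative only --
# the real behavior lives in mmcv.Config.
def merge_cfg(base, child):
    out = dict(base)
    for key, val in child.items():
        if isinstance(val, dict) and isinstance(out.get(key), dict):
            out[key] = merge_cfg(out[key], val)  # recurse into nested dicts
        else:
            out[key] = val                       # child wins outright
    return out

base = dict(data=dict(samples_per_gpu=2, train=dict(pipeline=['old'])))
child = dict(data=dict(train=dict(pipeline=['new'])))
merged = merge_cfg(base, child)
print(merged['data'])  # {'samples_per_gpu': 2, 'train': {'pipeline': ['new']}}
```

Keys absent from the child (here `samples_per_gpu`) survive from the base, which is why this 768x768 variant only needs to redefine the pipelines.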
90 changes: 90 additions & 0 deletions configs/_base_/models/setr_mla.py
@@ -0,0 +1,90 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dim=1024,
depth=24,
num_heads=16,
out_indices=(5, 11, 17, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
out_shape='NCHW',
with_cls_token=False,
interpolate_mode='bilinear',
),
neck=dict(
type='MLA',
in_channels=[1024, 1024, 1024, 1024],
out_channels=256,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU'),
),
decode_head=dict(
type='SETRMLAHead',
in_channels=(1024, 1024, 1024, 1024),
channels=512,
in_index=(0, 1, 2, 3),
img_size=(768, 768),
mla_channels=256,
mlahead_channels=128,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=0,
img_size=(768, 768),
mla_channels=256,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=1,
img_size=(768, 768),
mla_channels=256,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=2,
img_size=(768, 768),
mla_channels=256,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=3,
img_size=(768, 768),
mla_channels=256,
num_classes=19,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
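As a sanity check on the shapes flowing through the MLA variant above (illustrative bookkeeping only, assuming the standard ViT patch-grid layout; not mmseg code): with `img_size=768` and `patch_size=16` the token grid is 48x48, each of the four tapped layers in `out_indices` yields a 1024-channel map, and the MLA neck projects each down to 256 channels.

```python
# Shape bookkeeping for setr_mla.py (a sketch, not the mmseg implementation).
img_size, patch_size, embed_dim = 768, 16, 1024
grid = img_size // patch_size                      # 48x48 token grid
# Four feature maps tapped at out_indices=(5, 11, 17, 23), all embed_dim-channel:
backbone_outs = [(embed_dim, grid, grid)] * 4
# The MLA neck projects every level to out_channels=256, spatial size unchanged:
neck_outs = [(256, h, w) for (_, h, w) in backbone_outs]
print(grid, neck_outs[0])  # 48 (256, 48, 48)
```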
90 changes: 90 additions & 0 deletions configs/_base_/models/setr_naive.py
@@ -0,0 +1,90 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dim=1024,
depth=24,
num_heads=16,
out_indices=(9, 14, 19, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bilinear',
),
decode_head=dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=3,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=1,
conv3x3_conv1x1=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=0,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=1,
conv3x3_conv1x1=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=1,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=1,
conv3x3_conv1x1=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=2,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=1,
conv3x3_conv1x1=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
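The `out_shape='NCHW'` and `with_cls_token=True` settings above imply the backbone's token sequence is reshaped back onto the spatial grid, with the class token dropped first. A pure-Python illustration of that NLC-to-spatial reshape (a sketch of the idea, not the mmseg implementation):

```python
# Sketch of the NLC -> NCHW reshape: drop the class token, then lay the
# remaining L = H*W patch tokens back onto the (H, W) grid in row-major order.
def nlc_to_grid(tokens, hw, has_cls_token=True):
    if has_cls_token:
        tokens = tokens[1:]          # discard the leading class token
    h, w = hw
    assert len(tokens) == h * w      # sequence length must match the grid
    return [[tokens[r * w + c] for c in range(w)] for r in range(h)]

seq = [f't{i}' for i in range(5)]    # 1 cls token + a 2x2 patch grid
rows = nlc_to_grid(seq, (2, 2))
print(rows)  # [['t1', 't2'], ['t3', 't4']]
```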
106 changes: 106 additions & 0 deletions configs/_base_/models/setr_pup.py
@@ -0,0 +1,106 @@
# model settings
backbone_norm_cfg = dict(type='LN', eps=1e-6, requires_grad=True)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=\
'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', # noqa
backbone=dict(
type='VisionTransformer',
img_size=(768, 768),
patch_size=16,
in_channels=3,
embed_dim=1024,
depth=24,
num_heads=16,
out_indices=(9, 14, 19, 23),
drop_rate=0.1,
norm_cfg=backbone_norm_cfg,
out_shape='NCHW',
with_cls_token=True,
interpolate_mode='bilinear',
),
decode_head=dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=3,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=4,
up_mode='bilinear',
num_up_layer=4,
conv3x3_conv1x1=True,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=0,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=2,
conv3x3_conv1x1=True,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=1,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=2,
conv3x3_conv1x1=True,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=2,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=2,
conv3x3_conv1x1=True,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRUPHead',
in_channels=1024,
channels=512,
in_index=3,
img_size=(768, 768),
embed_dim=1024,
num_classes=19,
norm_cfg=norm_cfg,
num_convs=2,
up_mode='bilinear',
num_up_layer=2,
conv3x3_conv1x1=True,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
train_cfg=dict(),
test_cfg=dict(mode='whole'))
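The arithmetic behind the PUP ("progressive upsampling") decode head above: starting from the 1/16-resolution token grid, `num_up_layer=4` stages of 2x bilinear upsampling recover full resolution, since 2**4 equals `patch_size=16`. A quick check of that claim:

```python
# Why num_up_layer=4 in the decode head: four 2x upsamples undo the
# patch_size=16 downsampling of the ViT patch embedding.
img_size, patch_size, num_up_layer = 768, 16, 4
size = img_size // patch_size        # 48, the 1/16-resolution token grid
for _ in range(num_up_layer):
    size *= 2                        # one conv + 2x bilinear upsample per stage
print(size)  # 768: back to the input resolution
```

The auxiliary heads use `num_up_layer=2`, so their outputs stay at 1/4 resolution and are interpolated to full size by the loss computation.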
3 changes: 3 additions & 0 deletions configs/setr/setr_mla_480x480_80k_pascal_context_bs_16.py
@@ -0,0 +1,3 @@
_base_ = ['./setr_mla_480x480_80k_pascal_context_bs_8.py']
Collaborator: We may add this in a future PR.


data = dict(samples_per_gpu=2)
62 changes: 62 additions & 0 deletions configs/setr/setr_mla_480x480_80k_pascal_context_bs_8.py
@@ -0,0 +1,62 @@
_base_ = [
'../_base_/models/setr_mla.py', '../_base_/datasets/pascal_context.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
backbone=dict(img_size=(480, 480), drop_rate=0),
decode_head=dict(img_size=(480, 480), num_classes=60),
auxiliary_head=[
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=0,
img_size=(480, 480),
mla_channels=256,
num_classes=60,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=1,
img_size=(480, 480),
mla_channels=256,
num_classes=60,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=2,
img_size=(480, 480),
mla_channels=256,
num_classes=60,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
dict(
type='SETRMLAAUXHead',
in_channels=256,
channels=512,
in_index=3,
img_size=(480, 480),
mla_channels=256,
            num_classes=60,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))
],
test_cfg=dict(mode='slide', crop_size=(480, 480), stride=(320, 320)))

optimizer = dict(
lr=0.001,
weight_decay=0.0,
paramwise_cfg=dict(custom_keys={'head': dict(lr_mult=10.)}))

data = dict(samples_per_gpu=1)
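The `paramwise_cfg` above gives every parameter whose name matches the key `'head'` a 10x learning-rate multiplier. A simplified sketch of that substring matching (illustrative only; the real logic lives in mmcv's optimizer constructor, and the parameter names here are hypothetical examples):

```python
# Sketch of custom_keys={'head': dict(lr_mult=10.)}: parameters whose name
# contains 'head' train at 10x the base lr, everything else at the base lr.
def param_lr(name, custom_keys, base_lr):
    for key, cfg in custom_keys.items():
        if key in name:                      # simple substring match
            return base_lr * cfg['lr_mult']
    return base_lr

base_lr = 0.001
custom_keys = {'head': dict(lr_mult=10.)}
print(f"{param_lr('decode_head.conv_seg.weight', custom_keys, base_lr):g}")    # 0.01
print(f"{param_lr('backbone.patch_embed.proj.weight', custom_keys, base_lr):g}")  # 0.001
```

Boosting only the randomly initialized heads while the ImageNet-pretrained backbone trains at the base lr is a common fine-tuning recipe.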