add configs for vit backbone plus decode_heads #520

Merged 39 commits on Jul 1, 2021.

Commits (39)
- 1c66609 add config (Apr 26, 2021)
- a439a40 add cityscapes config (Apr 26, 2021)
- b7aa7ee add default value to docstring (Apr 26, 2021)
- 4e8383c Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation… (Apr 26, 2021)
- e7d6243 fix lint (Apr 26, 2021)
- 0ce8958 add deit-s and deit-b (Apr 27, 2021)
- 6b40465 add readme (Apr 28, 2021)
- 6d0ab21 add eps at norm_cfg (Apr 29, 2021)
- c9c6596 Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation… (May 6, 2021)
- f7b8c18 add drop_path_rate experiment (May 6, 2021)
- d605af6 add deit case at init_weight (May 6, 2021)
- 425cac7 add upernet result (May 11, 2021)
- dd6856e update result and add upernet 160k config (May 12, 2021)
- bd86b64 update upernet result and fix settings (May 17, 2021)
- 3f8db05 Update iters number (May 17, 2021)
- 4d77761 update result and delete some configs (May 18, 2021)
- 656c167 fix import error (May 18, 2021)
- 7d13836 fix drop_path_rate (May 18, 2021)
- 8931d7f update result and restore config (May 19, 2021)
- 2219cf7 update benchmark result (May 21, 2021)
- 7f1866e remove cityscapes exp (May 22, 2021)
- 845f8f5 merge upstream master (May 22, 2021)
- 69cb384 remove neck (May 24, 2021)
- cebbf6f neck exp (May 26, 2021)
- 4c0e952 Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation… (Jun 21, 2021)
- a90ed4a Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation… (Jun 21, 2021)
- ebc0531 add more configs (Jun 21, 2021)
- e7a6637 Merge branch 'vit_plus_heads' of https://github.com/xiexinch/mmsegmen… (Jun 21, 2021)
- 319f56e fix init error (Jun 21, 2021)
- c4aa7b6 fix ffn setting (Jun 21, 2021)
- 8457e67 update result (Jun 22, 2021)
- b0f9e29 update results (Jun 23, 2021)
- 74424aa update result (Jun 25, 2021)
- 8c11f62 update results and fill table (Jun 28, 2021)
- 637356d Merge branch 'master' of https://github.com/open-mmlab/mmsegmentation… (Jun 28, 2021)
- 31ba78d delete or rename configs (Jun 28, 2021)
- 53ae35c fix link delimiter (Jun 28, 2021)
- 451ae11 rename configs and fix link (Jun 29, 2021)
- 8db5cdd rename neck to mln (Jul 1, 2021)
58 changes: 58 additions & 0 deletions configs/_base_/models/upernet_vit-b16_ln_mln.py
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',  # noqa
    backbone=dict(
        type='VisionTransformer',
        img_size=(512, 512),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(2, 5, 8, 11),
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        with_cls_token=True,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        out_shape='NCHW',
        interpolate_mode='bicubic'),
    neck=dict(
        type='MultiLevelNeck',
        in_channels=[768, 768, 768, 768],
        out_channels=768,
        scales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=512,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=3,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))  # yapf: disable
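
For orientation, a minimal sketch of instantiating this base model through the registry; it assumes mmsegmentation v0.x with mmcv-full installed and execution from the repo root, and is not part of the PR itself:

```python
# Minimal sketch, assuming mmsegmentation v0.x (with mmcv-full) and
# execution from the repo root so the relative config path resolves.
# Building the model also downloads the pretrained ViT weights above.
from mmcv import Config
from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/_base_/models/upernet_vit-b16_ln_mln.py')
model = build_segmentor(cfg.model)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters')
```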
32 changes: 32 additions & 0 deletions configs/vit/README.md
# Vision Transformer

## Introduction

<!-- [ALGORITHM] -->

```latex
@article{dosovitskiy2020,
  title={An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
  author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and Uszkoreit, Jakob and Houlsby, Neil},
  journal={arXiv preprint arXiv:2010.11929},
  year={2020}
}
```

## Results and models

### ADE20K

| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| UPerNet | ViT-B + MLN | 512x512 | 80000 | 9.20 | 6.94 | 47.71 | 49.51 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/upernet_vit-b16_mln_512x512_80k_ade20k-0403cee1.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_80k_ade20k/20210624_130547.log.json) |
| UPerNet | ViT-B + MLN | 512x512 | 160000 | 9.20 | 7.58 | 46.75 | 48.46 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_160k_ade20k/upernet_vit-b16_mln_512x512_160k_ade20k-852fa768.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_mln_512x512_160k_ade20k/20210623_192432.log.json) |
| UPerNet | ViT-B + LN + MLN | 512x512 | 160000 | 9.21 | 6.82 | 47.73 | 49.95 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k/upernet_vit-b16_ln_mln_512x512_160k_ade20k-f444c077.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k/20210621_172828.log.json) |
| UPerNet | DeiT-S | 512x512 | 80000 | 4.68 | 29.85 | 42.96 | 43.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_80k_ade20k/upernet_deit-s16_512x512_80k_ade20k-afc93ec2.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_80k_ade20k/20210624_095228.log.json) |
| UPerNet | DeiT-S | 512x512 | 160000 | 4.68 | 29.19 | 42.87 | 43.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_160k_ade20k/upernet_deit-s16_512x512_160k_ade20k-5110d916.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_512x512_160k_ade20k/20210621_160903.log.json) |
| UPerNet | DeiT-S + MLN | 512x512 | 160000 | 5.69 | 11.18 | 43.82 | 45.07 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_mln_512x512_160k_ade20k/upernet_deit-s16_mln_512x512_160k_ade20k-fb9a5dfb.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_mln_512x512_160k_ade20k/20210621_161021.log.json) |
| UPerNet | DeiT-S + LN + MLN | 512x512 | 160000 | 5.69 | 12.39 | 43.52 | 45.01 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k/upernet_deit-s16_ln_mln_512x512_160k_ade20k-c0cd652f.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k/20210621_161021.log.json) |
| UPerNet | DeiT-B | 512x512 | 80000 | 7.75 | 9.69 | 45.24 | 46.73 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_512x512_80k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_80k_ade20k/upernet_deit-b16_512x512_80k_ade20k-1e090789.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_80k_ade20k/20210624_130529.log.json) |
| UPerNet | DeiT-B | 512x512 | 160000 | 7.75 | 10.39 | 45.36 | 47.16 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_160k_ade20k/upernet_deit-b16_512x512_160k_ade20k-828705d7.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_512x512_160k_ade20k/20210621_180100.log.json) |
| UPerNet | DeiT-B + MLN | 512x512 | 160000 | 9.21 | 7.78 | 45.46 | 47.16 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_mln_512x512_160k_ade20k/upernet_deit-b16_mln_512x512_160k_ade20k-4e1450f3.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_mln_512x512_160k_ade20k/20210621_191949.log.json) |
| UPerNet | DeiT-B + LN + MLN | 512x512 | 160000 | 9.21 | 7.75 | 45.37 | 47.23 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k/upernet_deit-b16_ln_mln_512x512_160k_ade20k-8a959c14.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k/20210623_153535.log.json) |
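
To try one of these checkpoints, here is a minimal inference sketch against the mmseg v0.x high-level API; the local checkpoint file name and `demo.png` are assumptions, not shipped with this PR:

```python
# Minimal sketch, assuming mmsegmentation v0.x. The checkpoint is the
# 80k ViT-B + MLN model from the table above, downloaded locally, and
# 'demo.png' is any test image you supply.
from mmseg.apis import inference_segmentor, init_segmentor

config_file = 'configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py'
checkpoint_file = 'upernet_vit-b16_mln_512x512_80k_ade20k-0403cee1.pth'

model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
result = inference_segmentor(model, 'demo.png')  # per-pixel ADE20K class ids
```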
6 changes: 6 additions & 0 deletions configs/vit/upernet_deit-b16_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1),
    neck=None)  # yapf: disable
6 changes: 6 additions & 0 deletions configs/vit/upernet_deit-b16_512x512_80k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1),
    neck=None)  # yapf: disable
5 changes: 5 additions & 0 deletions configs/vit/upernet_deit-b16_ln_mln_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1, final_norm=True))  # yapf: disable
5 changes: 5 additions & 0 deletions configs/vit/upernet_deit-b16_mln_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',  # noqa
    backbone=dict(drop_path_rate=0.1))  # yapf: disable
8 changes: 8 additions & 0 deletions configs/vit/upernet_deit-s16_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=None,
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
8 changes: 8 additions & 0 deletions configs/vit/upernet_deit-s16_512x512_80k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_80k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=None,
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
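
The DeiT variants above rely on mmcv's `_base_` inheritance: each child config swaps the pretrained checkpoint, narrows the embedding width, and removes or reshapes the neck (setting `neck=None` drops the MLN entirely). A minimal sketch of inspecting the merged result, assuming mmsegmentation v0.x and execution from the repo root:

```python
# Minimal sketch, assuming mmseg v0.x and execution from the repo root.
from mmcv import Config

cfg = Config.fromfile('configs/vit/upernet_deit-s16_512x512_80k_ade20k.py')
print(cfg.model.backbone.embed_dims)       # 384, overriding the ViT-B base's 768
print(cfg.model.neck)                      # None: the MLN neck is removed
print(cfg.model.decode_head.in_channels)   # [384, 384, 384, 384]
```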
12 changes: 12 additions & 0 deletions configs/vit/upernet_deit-s16_ln_mln_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(
        num_heads=6,
        embed_dims=384,
        drop_path_rate=0.1,
        final_norm=True),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
8 changes: 8 additions & 0 deletions configs/vit/upernet_deit-s16_mln_512x512_160k_ade20k.py
_base_ = './upernet_vit-b16_mln_512x512_160k_ade20k.py'

model = dict(
    pretrained='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth',  # noqa
    backbone=dict(num_heads=6, embed_dims=384, drop_path_rate=0.1),
    decode_head=dict(num_classes=150, in_channels=[384, 384, 384, 384]),
    neck=dict(in_channels=[384, 384, 384, 384], out_channels=384),
    auxiliary_head=dict(num_classes=150, in_channels=384))  # yapf: disable
38 changes: 38 additions & 0 deletions configs/vit/upernet_vit-b16_ln_mln_512x512_160k_ade20k.py
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]

model = dict(
    backbone=dict(drop_path_rate=0.1, final_norm=True),
    decode_head=dict(num_classes=150),
    auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
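
The `custom_keys` entries above are matched against parameter names by substring, which is how the position embedding, class token, and all layer norms end up with zero weight decay. A self-contained sketch of that matching rule (the parameter names below are hypothetical, for illustration only):

```python
# Illustrative sketch of paramwise_cfg.custom_keys semantics, assuming
# mmcv's DefaultOptimizerConstructor: a key matches any parameter whose
# name contains it as a substring, and decay_mult scales weight_decay.
def effective_weight_decay(param_name: str, base_wd: float = 0.01) -> float:
    custom_keys = {'pos_embed': 0.0, 'cls_token': 0.0, 'norm': 0.0}
    for key, decay_mult in custom_keys.items():
        if key in param_name:
            return base_wd * decay_mult
    return base_wd

# Hypothetical parameter names, not taken from this PR:
assert effective_weight_decay('backbone.pos_embed') == 0.0
assert effective_weight_decay('decode_head.conv_seg.weight') == 0.01
```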
36 changes: 36 additions & 0 deletions configs/vit/upernet_vit-b16_mln_512x512_160k_ade20k.py
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_160k.py'
]

model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
36 changes: 36 additions & 0 deletions configs/vit/upernet_vit-b16_mln_512x512_80k_ade20k.py
_base_ = [
    '../_base_/models/upernet_vit-b16_ln_mln.py',
    '../_base_/datasets/ade20k.py', '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_80k.py'
]

model = dict(
    decode_head=dict(num_classes=150), auxiliary_head=dict(num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=0.00006,
    betas=(0.9, 0.999),
    weight_decay=0.01,
    paramwise_cfg=dict(
        custom_keys={
            'pos_embed': dict(decay_mult=0.),
            'cls_token': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
11 changes: 9 additions & 2 deletions mmseg/models/necks/multilevel_neck.py
@@ -1,6 +1,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
-from mmcv.cnn import ConvModule
+from mmcv.cnn import ConvModule, xavier_init

 from ..builder import NECKS

@@ -13,7 +13,8 @@ class MultiLevelNeck(nn.Module):
     Args:
         in_channels (List[int]): Number of input channels per scale.
         out_channels (int): Number of output channels (used at each scale).
-        scales (List[int]): Scale factors for each input feature map.
+        scales (List[float]): Scale factors for each input feature map.
+            Default: [0.5, 1, 2, 4]
         norm_cfg (dict): Config dict for normalization layer. Default: None.
         act_cfg (dict): Config dict for activation layer in ConvModule.
             Default: None.
@@ -52,6 +53,12 @@ def __init__(self,
                 norm_cfg=norm_cfg,
                 act_cfg=act_cfg))

+    # default init_weights for conv(msra) and norm in ConvModule
+    def init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                xavier_init(m, distribution='uniform')
+
> **Review comment (Collaborator):** Why use xavier_init for Conv2d? kaiming_init is used for ConvModule in MMCV.
>
> **Reply (Collaborator):** Followed FPN.
     def forward(self, inputs):
         assert len(inputs) == len(self.in_channels)
         inputs = [
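
The rest of `forward` is collapsed in the diff view; what it does with the configured `scales` is rescale each input feature map before the per-level convs. A standalone sketch of just that rescaling for the ViT-B settings used in these configs (the shapes and interpolation mode are illustrative assumptions, not code from this PR):

```python
# Illustrative sketch of the rescaling MultiLevelNeck performs, assuming
# a 512x512 crop: every ViT stage emits a 1/16-resolution map (32x32),
# and scales [4, 2, 1, 0.5] turn the four maps into a 1/4..1/32 pyramid
# for UPerHead. The lateral/output ConvModules are omitted here.
import torch
import torch.nn.functional as F

feats = [torch.randn(1, 768, 32, 32) for _ in range(4)]
scales = [4, 2, 1, 0.5]
pyramid = [
    F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=False)
    for x, s in zip(feats, scales)
]
print([tuple(p.shape[2:]) for p in pyramid])
# [(128, 128), (64, 64), (32, 32), (16, 16)]
```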
3 changes: 3 additions & 0 deletions tests/test_models/test_necks/test_multilevel_neck.py
def test_multilevel_neck():

    # Test init_weights
    MultiLevelNeck([266], 256).init_weights()

    # Test multi feature maps
    in_channels = [256, 512, 1024, 2048]
    inputs = [torch.randn(1, c, 14, 14) for i, c in enumerate(in_channels)]