[Feature] Add BEiT backbone #1404

Merged: 47 commits, Mar 30, 2022

Commits:
- 0d05d63 [Feature] Add BEiT backbone (linfangjian01, Mar 21, 2022)
- d226d41 fix (linfangjian01, Mar 21, 2022)
- e6c03bc fix (linfangjian01, Mar 21, 2022)
- 9e3f5b6 fix (linfangjian01, Mar 21, 2022)
- ba98cc3 fix (linfangjian01, Mar 21, 2022)
- a208802 Merge branch 'open-mmlab:master' into addbeit (linfangjian01, Mar 22, 2022)
- c0462b6 add readme (linfangjian01, Mar 22, 2022)
- c2c03e5 fix (linfangjian01, Mar 23, 2022)
- a4fff29 fix (linfangjian01, Mar 23, 2022)
- bdf5f77 fix (linfangjian01, Mar 23, 2022)
- ac7c52a fix (linfangjian01, Mar 24, 2022)
- 720285e fix (linfangjian01, Mar 25, 2022)
- 22d864e add link (linfangjian01, Mar 25, 2022)
- 0864be3 fix memory (linfangjian01, Mar 25, 2022)
- 5b1b7b7 fix (linfangjian01, Mar 25, 2022)
- e55c3b1 fix (linfangjian01, Mar 25, 2022)
- ca488f4 fix (linfangjian01, Mar 25, 2022)
- b9cb639 fix (linfangjian01, Mar 25, 2022)
- 468feb6 fix (linfangjian01, Mar 25, 2022)
- 2dac77a fix (linfangjian01, Mar 25, 2022)
- 2158425 fix (linfangjian01, Mar 26, 2022)
- 28c15ed fix (linfangjian01, Mar 26, 2022)
- 574b66f fix (linfangjian01, Mar 26, 2022)
- 0744f63 fix (linfangjian01, Mar 26, 2022)
- 56a9d00 fix (linfangjian01, Mar 27, 2022)
- 1cfee52 fix (linfangjian01, Mar 27, 2022)
- ba9b840 fix (linfangjian01, Mar 28, 2022)
- 8230e4d fix (linfangjian01, Mar 28, 2022)
- ef1a0e6 fix (linfangjian01, Mar 28, 2022)
- 95151e8 fix (linfangjian01, Mar 28, 2022)
- decf9d2 fix (linfangjian01, Mar 28, 2022)
- 4aafb7e fix test_beit.py (linfangjian01, Mar 28, 2022)
- 4ca514d fix (linfangjian01, Mar 28, 2022)
- 34339fa fix (linfangjian01, Mar 28, 2022)
- f7dc33e fix (linfangjian01, Mar 28, 2022)
- 2e5a973 fix (linfangjian01, Mar 29, 2022)
- 55d7ef2 Merge branch 'open-mmlab:master' into addbeit (linfangjian01, Mar 29, 2022)
- 7ca0f7e fix (linfangjian01, Mar 29, 2022)
- 19b5e28 fix (linfangjian01, Mar 29, 2022)
- 783dcfb fix (linfangjian01, Mar 29, 2022)
- b556b69 fix (linfangjian01, Mar 29, 2022)
- aab4063 fix (linfangjian01, Mar 29, 2022)
- f161d0d fix (linfangjian01, Mar 29, 2022)
- c48c8f0 fix (linfangjian01, Mar 29, 2022)
- 9fa14a3 fix (linfangjian01, Mar 29, 2022)
- 7bb7dd9 fix (linfangjian01, Mar 29, 2022)
- d080d84 fix (linfangjian01, Mar 30, 2022)
1 change: 1 addition & 0 deletions README.md
@@ -85,6 +85,7 @@ Supported backbones:
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [BEiT (ICLR'2022)](configs/beit)

Supported methods:

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -84,6 +84,7 @@ MMSegmentation is an open source semantic segmentation toolbox based on PyTorch. It is part of the O…
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [BEiT (ICLR'2022)](configs/beit)

Supported methods:

50 changes: 50 additions & 0 deletions configs/_base_/models/upernet_beit.py
@@ -0,0 +1,50 @@
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained=None,
    backbone=dict(
        type='BEiT',
        img_size=(640, 640),
        patch_size=16,
        in_channels=3,
        embed_dims=768,
        num_layers=12,
        num_heads=12,
        mlp_ratio=4,
        out_indices=(3, 5, 7, 11),
        qv_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_cfg=dict(type='LN', eps=1e-6),
        act_cfg=dict(type='GELU'),
        norm_eval=False,
        init_values=0.1),
    neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        type='UPerHead',
        in_channels=[768, 768, 768, 768],
        in_index=[0, 1, 2, 3],
        pool_scales=(1, 2, 3, 6),
        channels=768,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
    auxiliary_head=dict(
        type='FCNHead',
        in_channels=768,
        in_index=2,
        channels=256,
        num_convs=1,
        concat_input=False,
        dropout_ratio=0.1,
        num_classes=150,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))
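As a quick local sanity check, this base config can be instantiated with the standard MMSegmentation builders. A minimal sketch, assuming MMSegmentation v0.x with this PR applied (`forward_dummy` is the dummy-forward helper used for FLOPs counting; SyncBN is swapped for plain BN since it needs a distributed launch):

```python
import torch
from mmcv import Config
from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/_base_/models/upernet_beit.py')
# SyncBN requires a distributed process group; use BN for a local check.
for head in ('decode_head', 'auxiliary_head'):
    cfg.model[head]['norm_cfg'] = dict(type='BN', requires_grad=True)

model = build_segmentor(cfg.model)
model.init_weights()
model.eval()

# BEiT's relative position embedding assumes square inputs, so the dummy
# input matches the 640x640 img_size declared in the backbone.
with torch.no_grad():
    seg_logits = model.forward_dummy(torch.randn(1, 3, 640, 640))
print(seg_logits.shape)  # expected: torch.Size([1, 150, 640, 640])
```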
84 changes: 84 additions & 0 deletions configs/beit/README.md
@@ -0,0 +1,84 @@
# BEiT

[BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)

## Introduction

<!-- [BACKBONE] -->

<a href="https://github.com/microsoft/unilm/tree/master/beit">Official Repo</a>

<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.23.0/mmseg/models/backbones/beit.py#1404">Code Snippet</a>

## Abstract

<!-- [ABSTRACT] -->

We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e., image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and feed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%). The code and pretrained models are available at [this https URL](https://github.com/microsoft/unilm/tree/master/beit).

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/93248678/160155758-781c9a45-b1d7-4530-9015-88eca6645006.png" width="70%"/>
</div>

## Citation

```bibtex
@inproceedings{beit,
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=p-BhZSz59o4}
}
```

## Usage

To use pre-trained models from other repositories, their checkpoint keys must first be converted.

We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the keys of models from [the official repo](https://github.com/microsoft/unilm/tree/master/beit/semantic_segmentation) to MMSegmentation style.

```shell
python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
```

E.g.

```shell
python tools/model_converters/beit2mmseg.py https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth pretrain/beit_base_patch16_224_pt22k_ft22k.pth
```

This script converts the model from `PRETRAIN_PATH` and stores the converted model in `STORE_PATH`.

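The converter is essentially a key-renaming pass from the official checkpoint layout to the MMSegmentation layout. A minimal sketch of the idea (the renames below are an illustrative subset and the output filename is hypothetical; the authoritative mapping is in `tools/model_converters/beit2mmseg.py`):

```python
from collections import OrderedDict

import torch


def convert_beit(state_dict):
    """Rename official BEiT keys to MMSegmentation style (illustrative)."""
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_k = k
        if k.startswith('patch_embed'):
            # The patch embedding's conv is called `projection` in mmseg.
            new_k = k.replace('patch_embed.proj', 'patch_embed.projection')
        elif k.startswith('blocks'):
            # Transformer blocks live under `layers` with mmcv FFN naming.
            new_k = (k.replace('blocks', 'layers')
                      .replace('norm1', 'ln1')
                      .replace('norm2', 'ln2')
                      .replace('mlp.fc1', 'ffn.layers.0.0')
                      .replace('mlp.fc2', 'ffn.layers.1'))
        new_state_dict[new_k] = v
    return new_state_dict


ckpt = torch.load('pretrain/beit_base_patch16_224_pt22k_ft22k.pth',
                  map_location='cpu')
torch.save(convert_beit(ckpt.get('model', ckpt)),
           'pretrain/beit_base_patch16_224_pt22k_ft22k_mmseg.pth')
```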
In our default setting, the pretrained models correspond to the original models as follows:

| pretrained models | original models |
| ----------------- | --------------- |
| BEiT_base.pth | ['BEiT_base'](https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth) |
| BEiT_large.pth | ['BEiT_large'](https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth) |

> Reviewer comment (Contributor): We would provide mmseg-style pretrained models for better usage experience.

Verify the single-scale results of the model:

```shell
sh tools/dist_test.sh \
configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py \
upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```

Since the relative position embedding requires the input height and width to be equal, sliding-window inference is adopted for multi-scale testing, and `min_size=640` is set so that the shortest edge is at least 640. Multi-scale inference is therefore run with a dedicated config instead of the `--aug-test` flag:

```shell
sh tools/dist_test.sh \
configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py \
upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
```
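For reference, `mode='slide'` averages overlapping window predictions. A simplified sketch of MMSegmentation's `EncoderDecoder.slide_inference` with the 640x640 crop and (426, 426) stride from the configs above (`crop_logits_fn` stands in for the per-crop forward pass and is a placeholder name):

```python
import torch


def slide_inference(crop_logits_fn, img, num_classes=150,
                    crop=(640, 640), stride=(426, 426)):
    """Sliding-window inference: square crops keep BEiT's relative
    position embedding valid regardless of the image's aspect ratio."""
    h_crop, w_crop = crop
    h_stride, w_stride = stride
    _, _, h_img, w_img = img.shape
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = img.new_zeros((1, num_classes, h_img, w_img))
    count = img.new_zeros((1, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            # Clamp each window to the image border, keeping its full size.
            y2 = min(h_idx * h_stride + h_crop, h_img)
            x2 = min(w_idx * w_stride + w_crop, w_img)
            y1, x1 = max(y2 - h_crop, 0), max(x2 - w_crop, 0)
            preds[:, :, y1:y2, x1:x2] += crop_logits_fn(img[:, :, y1:y2, x1:x2])
            count[:, :, y1:y2, x1:x2] += 1
    return preds / count  # overlapping logits are averaged
```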

## Results and models

### ADE20K

| Method  | Backbone | Crop Size | pretrain     | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU  | mIoU(ms+flip) | config | download |
| ------- | -------- | --------- | ------------ | ----------------- | ---------- | ------- | -------- | -------------- | ----- | ------------- | ------ | -------- |
| UperNet | BEiT-B   | 640x640   | ImageNet-22K | 224x224           | 16         | 160000  | 15.88    | 2.00           | 53.08 | 53.84         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k.log.json) |
| UperNet | BEiT-L   | 640x640   | ImageNet-22K | 224x224           | 8          | 320000  | 22.64    | 0.96           | 56.33 | 56.84         | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.log.json) |
45 changes: 45 additions & 0 deletions configs/beit/beit.yml
@@ -0,0 +1,45 @@
Models:
- Name: upernet_beit-base_8x2_640x640_160k_ade20k
  In Collection: UperNet
  Metadata:
    backbone: BEiT-B
    crop size: (640,640)
    lr schd: 160000
    inference time (ms/im):
    - value: 500.0
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP32
      resolution: (640,640)
    Training Memory (GB): 15.88
  Results:
  - Task: Semantic Segmentation
    Dataset: ADE20K
    Metrics:
      mIoU: 53.08
      mIoU(ms+flip): 53.84
  Config: configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth
- Name: upernet_beit-large_fp16_8x1_640x640_160k_ade20k
  In Collection: UperNet
  Metadata:
    backbone: BEiT-L
    crop size: (640,640)
    lr schd: 320000
    inference time (ms/im):
    - value: 1041.67
      hardware: V100
      backend: PyTorch
      batch size: 1
      mode: FP16
      resolution: (640,640)
    Training Memory (GB): 22.64
  Results:
  - Task: Semantic Segmentation
    Dataset: ADE20K
    Metrics:
      mIoU: 56.33
      mIoU(ms+flip): 56.84
  Config: configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth
24 changes: 24 additions & 0 deletions configs/beit/upernet_beit-base_640x640_160k_ade20k_ms.py
@@ -0,0 +1,24 @@
_base_ = './upernet_beit-base_8x2_640x640_160k_ade20k.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True, min_size=640),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline),
    samples_per_gpu=2)
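For reference, `MultiScaleFlipAug` multiplies `img_scale` by each ratio, so the pipeline above tests six scales (each then resized with `keep_ratio=True`, with `min_size=640` keeping the shortest edge at 640 or above). A quick check of the resulting scales, assuming the usual `int(scale * ratio)` arithmetic:

```python
# Effective test scales produced by img_scale=(2560, 640) and the ratios above.
img_scale = (2560, 640)
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
scales = [(int(img_scale[0] * r), int(img_scale[1] * r)) for r in img_ratios]
print(scales)
# [(1280, 320), (1920, 480), (2560, 640), (3200, 800), (3840, 960), (4480, 1120)]
```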
30 changes: 30 additions & 0 deletions configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
@@ -0,0 +1,30 @@
_base_ = [
    '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]

model = dict(
    pretrained='pretrain/beit_base_patch16_224_pt22k_ft22k.pth',
    test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=3e-5,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

# By default, models are trained on 8 GPUs with 2 images per GPU
data = dict(samples_per_gpu=2)
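`LayerDecayOptimizerConstructor` (added under `mmseg/core/` in this PR) gives each parameter group a learning rate that decays geometrically with depth: embeddings get the smallest rate, the decode head the full rate. A minimal sketch of the decay rule, assuming the usual layer-id convention (0 for patch/position embeddings, `i + 1` for transformer block `i`, `num_layers + 1` for everything else); see `mmseg/core/layer_decay_optimizer_constructor.py` for the exact implementation:

```python
def lr_scale(layer_id, num_layers=12, layer_decay_rate=0.9):
    """LR multiplier for a parameter group under layer-wise decay."""
    return layer_decay_rate ** (num_layers + 1 - layer_id)


base_lr = 3e-5  # the AdamW lr from the config above
print(base_lr * lr_scale(0))   # embeddings:  ~7.6e-06 (0.9 ** 13)
print(base_lr * lr_scale(12))  # last block:  ~2.7e-05 (0.9 ** 1)
print(base_lr * lr_scale(13))  # decode head:  3.0e-05 (0.9 ** 0)
```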
22 changes: 22 additions & 0 deletions configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py
@@ -0,0 +1,22 @@
_base_ = './upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py'

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=True,
        transforms=[
            dict(type='Resize', keep_ratio=True, min_size=640),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
data = dict(
    val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))
47 changes: 47 additions & 0 deletions configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
@@ -0,0 +1,47 @@
_base_ = [
    '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_320k.py'
]

model = dict(
    pretrained='pretrain/beit_large_patch16_224_pt22k_ft22k.pth',
    backbone=dict(
        type='BEiT',
        embed_dims=1024,
        num_layers=24,
        num_heads=16,
        mlp_ratio=4,
        qv_bias=True,
        init_values=1e-6,
        drop_path_rate=0.2,
        out_indices=[7, 11, 15, 23]),
    neck=dict(embed_dim=1024, rescales=[4, 2, 1, 0.5]),
    decode_head=dict(
        in_channels=[1024, 1024, 1024, 1024], num_classes=150, channels=1024),
    auxiliary_head=dict(in_channels=1024, num_classes=150),
    test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))

optimizer = dict(
    _delete_=True,
    type='AdamW',
    lr=2e-5,
    betas=(0.9, 0.999),
    weight_decay=0.05,
    constructor='LayerDecayOptimizerConstructor',
    paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))

lr_config = dict(
    _delete_=True,
    policy='poly',
    warmup='linear',
    warmup_iters=3000,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)

data = dict(samples_per_gpu=1)
optimizer_config = dict(
    type='GradientCumulativeFp16OptimizerHook', cumulative_iters=2)

fp16 = dict()
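`GradientCumulativeFp16OptimizerHook` combines mixed-precision training with gradient accumulation, so 8 GPUs x 1 image x `cumulative_iters=2` gives the same effective batch size of 16 as the BEiT-B setting. A minimal plain-PyTorch sketch of the same idea, assuming a CUDA device and `torch.cuda.amp` (the tiny model and data are stand-ins, not the segmentor):

```python
import torch
from torch.cuda.amp import GradScaler, autocast

model = torch.nn.Linear(8, 2).cuda()  # stand-in for the segmentor
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = torch.nn.CrossEntropyLoss()
loader = [(torch.randn(4, 8).cuda(), torch.randint(0, 2, (4,)).cuda())
          for _ in range(4)]

scaler = GradScaler()
cumulative_iters = 2  # accumulate 2 iterations -> doubles effective batch

for i, (x, y) in enumerate(loader):
    with autocast():  # run the forward pass in FP16 where safe
        loss = criterion(model(x), y) / cumulative_iters
    scaler.scale(loss).backward()  # scaled backward avoids FP16 underflow
    if (i + 1) % cumulative_iters == 0:
        scaler.step(optimizer)  # unscales gradients, then steps
        scaler.update()
        optimizer.zero_grad()
```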
2 changes: 2 additions & 0 deletions mmseg/core/__init__.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import *  # noqa: F401, F403
from .layer_decay_optimizer_constructor import \
    LayerDecayOptimizerConstructor  # noqa: F401
from .seg import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403