diff --git a/README.md b/README.md
index 74b0ce0b56..8ed4acac1f 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,7 @@ Supported backbones:
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
+- [x] [BEiT (ICLR'2022)](configs/beit)
Supported methods:
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 6adea211ff..acaa12e489 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -84,6 +84,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [Swin Transformer (ICCV'2021)](configs/swin)
- [x] [Twins (NeurIPS'2021)](configs/twins)
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
+- [x] [BEiT (ICLR'2022)](configs/beit)
已支持的算法:
diff --git a/configs/_base_/models/upernet_beit.py b/configs/_base_/models/upernet_beit.py
new file mode 100644
index 0000000000..9c5bfa3310
--- /dev/null
+++ b/configs/_base_/models/upernet_beit.py
@@ -0,0 +1,50 @@
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ pretrained=None,
+ backbone=dict(
+ type='BEiT',
+ img_size=(640, 640),
+ patch_size=16,
+ in_channels=3,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=(3, 5, 7, 11),
+ qv_bias=True,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_cfg=dict(type='LN', eps=1e-6),
+ act_cfg=dict(type='GELU'),
+ norm_eval=False,
+ init_values=0.1),
+ neck=dict(type='Feature2Pyramid', embed_dim=768, rescales=[4, 2, 1, 0.5]),
+ decode_head=dict(
+ type='UPerHead',
+ in_channels=[768, 768, 768, 768],
+ in_index=[0, 1, 2, 3],
+ pool_scales=(1, 2, 3, 6),
+ channels=768,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ auxiliary_head=dict(
+ type='FCNHead',
+ in_channels=768,
+ in_index=2,
+ channels=256,
+ num_convs=1,
+ concat_input=False,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/beit/README.md b/configs/beit/README.md
new file mode 100644
index 0000000000..31bf285356
--- /dev/null
+++ b/configs/beit/README.md
@@ -0,0 +1,84 @@
+# BEiT
+
+[BEiT: BERT Pre-Training of Image Transformers](https://arxiv.org/abs/2106.08254)
+
+## Introduction
+
+
+
+Official Repo
+
+Code Snippet
+
+## Abstract
+
+
+
+We introduce a self-supervised vision representation model BEiT, which stands for Bidirectional Encoder representation from Image Transformers. Following BERT developed in the natural language processing area, we propose a masked image modeling task to pretrain vision Transformers. Specifically, each image has two views in our pre-training, i.e, image patches (such as 16x16 pixels), and visual tokens (i.e., discrete tokens). We first "tokenize" the original image into visual tokens. Then we randomly mask some image patches and fed them into the backbone Transformer. The pre-training objective is to recover the original visual tokens based on the corrupted image patches. After pre-training BEiT, we directly fine-tune the model parameters on downstream tasks by appending task layers upon the pretrained encoder. Experimental results on image classification and semantic segmentation show that our model achieves competitive results with previous pre-training methods. For example, base-size BEiT achieves 83.2% top-1 accuracy on ImageNet-1K, significantly outperforming from-scratch DeiT training (81.8%) with the same setup. Moreover, large-size BEiT obtains 86.3% only using ImageNet-1K, even outperforming ViT-L with supervised pre-training on ImageNet-22K (85.2%). The code and pretrained models are available at [this https URL](https://github.com/microsoft/unilm/tree/master/beit).
+
+
+
+
+
+
+## Citation
+
+```bibtex
+@inproceedings{beit,
+ title={{BEiT}: {BERT} Pre-Training of Image Transformers},
+ author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
+ booktitle={International Conference on Learning Representations},
+ year={2022},
+ url={https://openreview.net/forum?id=p-BhZSz59o4}
+}
+```
+
+## Usage
+
+To use other repositories' pre-trained models, it is necessary to convert keys.
+
+We provide a script [`beit2mmseg.py`](../../tools/model_converters/beit2mmseg.py) in the tools directory to convert the key of models from [the official repo](https://github.com/microsoft/unilm/tree/master/beit/semantic_segmentation) to MMSegmentation style.
+
+```shell
+python tools/model_converters/beit2mmseg.py ${PRETRAIN_PATH} ${STORE_PATH}
+```
+
+E.g.
+
+```shell
+python tools/model_converters/beit2mmseg.py https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth pretrain/beit_base_patch16_224_pt22k_ft22k.pth
+```
+
+This script convert model from `PRETRAIN_PATH` and store the converted model in `STORE_PATH`.
+
+In our default setting, pretrained models could be defined below:
+
+ | pretrained models | original models |
+ | ------ | -------- |
+ |BEiT_base.pth | ['BEiT_base'](https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth) |
+ |BEiT_large.pth | ['BEiT_large'](https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth) |
+
+Verify the single-scale results of the model:
+
+```shell
+sh tools/dist_test.sh \
+configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py \
+upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
+```
+
+Since relative position embedding requires the input length and width to be equal, the sliding window is adopted for multi-scale inference. So we set min_size=640, that is, the shortest edge is 640. So the multi-scale inference of config is performed separately, instead of '--aug-test'. For multi-scale inference:
+
+```shell
+sh tools/dist_test.sh \
+configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py \
+upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth $GPUS --eval mIoU
+```
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | pretrain | pretrain img size | Batch Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------ | -------- | --------- | ---------- | ------- | -------- | --- | --- | -------------- | ----- | ------------: | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| UperNet | BEiT-B | 640x640 | ImageNet-22K | 224x224 | 16 | 160000 | 15.88 | 2.00 | 53.08 | 53.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k.log.json) |
+| UperNet | BEiT-L | 640x640 | ImageNet-22K | 224x224 | 8 | 320000 | 22.64 | 0.96 | 56.33 | 56.84 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.log.json) |
diff --git a/configs/beit/beit.yml b/configs/beit/beit.yml
new file mode 100644
index 0000000000..6f3cee3edd
--- /dev/null
+++ b/configs/beit/beit.yml
@@ -0,0 +1,45 @@
+Models:
+- Name: upernet_beit-base_8x2_640x640_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: BEiT-B
+ crop size: (640,640)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 500.0
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (640,640)
+ Training Memory (GB): 15.88
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 53.08
+ mIoU(ms+flip): 53.84
+ Config: configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-base_8x2_640x640_160k_ade20k/upernet_beit-base_8x2_640x640_160k_ade20k-eead221d.pth
+- Name: upernet_beit-large_fp16_8x1_640x640_160k_ade20k
+ In Collection: UperNet
+ Metadata:
+ backbone: BEiT-L
+ crop size: (640,640)
+ lr schd: 320000
+ inference time (ms/im):
+ - value: 1041.67
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP16
+ resolution: (640,640)
+ Training Memory (GB): 22.64
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 56.33
+ mIoU(ms+flip): 56.84
+ Config: configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k/upernet_beit-large_fp16_8x1_640x640_160k_ade20k-8fc0dd5d.pth
diff --git a/configs/beit/upernet_beit-base_640x640_160k_ade20k_ms.py b/configs/beit/upernet_beit-base_640x640_160k_ade20k_ms.py
new file mode 100644
index 0000000000..f764c92c11
--- /dev/null
+++ b/configs/beit/upernet_beit-base_640x640_160k_ade20k_ms.py
@@ -0,0 +1,24 @@
+_base_ = './upernet_beit-base_8x2_640x640_160k_ade20k.py'
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2560, 640),
+ img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True, min_size=640),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ val=dict(pipeline=test_pipeline),
+ test=dict(pipeline=test_pipeline),
+ samples_per_gpu=2)
diff --git a/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py b/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..b36adc3c0d
--- /dev/null
+++ b/configs/beit/upernet_beit-base_8x2_640x640_160k_ade20k.py
@@ -0,0 +1,30 @@
+_base_ = [
+ '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+
+model = dict(
+ pretrained='pretrain/beit_base_patch16_224_pt22k_ft22k.pth',
+ test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))
+
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=3e-5,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ constructor='LayerDecayOptimizerConstructor',
+ paramwise_cfg=dict(num_layers=12, layer_decay_rate=0.9))
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=1500,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+# By default, models are trained on 8 GPUs with 2 images per GPU
+data = dict(samples_per_gpu=2)
diff --git a/configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py b/configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py
new file mode 100644
index 0000000000..fd4d9477d4
--- /dev/null
+++ b/configs/beit/upernet_beit-large_fp16_640x640_160k_ade20k_ms.py
@@ -0,0 +1,22 @@
+_base_ = './upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py'
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(2560, 640),
+ img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True, min_size=640),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline))
diff --git a/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py b/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
new file mode 100644
index 0000000000..e6247b7352
--- /dev/null
+++ b/configs/beit/upernet_beit-large_fp16_8x1_640x640_160k_ade20k.py
@@ -0,0 +1,47 @@
+_base_ = [
+ '../_base_/models/upernet_beit.py', '../_base_/datasets/ade20k_640x640.py',
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_320k.py'
+]
+
+model = dict(
+ pretrained='pretrain/beit_large_patch16_224_pt22k_ft22k.pth',
+ backbone=dict(
+ type='BEiT',
+ embed_dims=1024,
+ num_layers=24,
+ num_heads=16,
+ mlp_ratio=4,
+ qv_bias=True,
+ init_values=1e-6,
+ drop_path_rate=0.2,
+ out_indices=[7, 11, 15, 23]),
+ neck=dict(embed_dim=1024, rescales=[4, 2, 1, 0.5]),
+ decode_head=dict(
+ in_channels=[1024, 1024, 1024, 1024], num_classes=150, channels=1024),
+ auxiliary_head=dict(in_channels=1024, num_classes=150),
+ test_cfg=dict(mode='slide', crop_size=(640, 640), stride=(426, 426)))
+
+optimizer = dict(
+ _delete_=True,
+ type='AdamW',
+ lr=2e-5,
+ betas=(0.9, 0.999),
+ weight_decay=0.05,
+ constructor='LayerDecayOptimizerConstructor',
+ paramwise_cfg=dict(num_layers=24, layer_decay_rate=0.95))
+
+lr_config = dict(
+ _delete_=True,
+ policy='poly',
+ warmup='linear',
+ warmup_iters=3000,
+ warmup_ratio=1e-6,
+ power=1.0,
+ min_lr=0.0,
+ by_epoch=False)
+
+data = dict(samples_per_gpu=1)
+optimizer_config = dict(
+ type='GradientCumulativeFp16OptimizerHook', cumulative_iters=2)
+
+fp16 = dict()
diff --git a/mmseg/core/__init__.py b/mmseg/core/__init__.py
index 402278618e..c60b48c0c6 100644
--- a/mmseg/core/__init__.py
+++ b/mmseg/core/__init__.py
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import * # noqa: F401, F403
+from .layer_decay_optimizer_constructor import \
+ LayerDecayOptimizerConstructor # noqa: F401
from .seg import * # noqa: F401, F403
from .utils import * # noqa: F401, F403
diff --git a/mmseg/core/layer_decay_optimizer_constructor.py b/mmseg/core/layer_decay_optimizer_constructor.py
new file mode 100644
index 0000000000..30a09ba08e
--- /dev/null
+++ b/mmseg/core/layer_decay_optimizer_constructor.py
@@ -0,0 +1,87 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv.runner import (OPTIMIZER_BUILDERS, DefaultOptimizerConstructor,
+ get_dist_info)
+
+from mmseg.utils import get_root_logger
+
+
+def get_num_layer_for_vit(var_name, num_max_layer):
+ """Get the layer id to set the different learning rates.
+
+ Args:
+ var_name (str): The key of the model.
+ num_max_layer (int): Maximum number of backbone layers.
+ Returns:
+ layer id (int): Returns the layer id of the key.
+ """
+
+ if var_name in ('backbone.cls_token', 'backbone.mask_token',
+ 'backbone.pos_embed'):
+ return 0
+ elif var_name.startswith('backbone.patch_embed'):
+ return 0
+ elif var_name.startswith('backbone.layers'):
+ layer_id = int(var_name.split('.')[2])
+ return layer_id + 1
+ else:
+ return num_max_layer - 1
+
+
+@OPTIMIZER_BUILDERS.register_module()
+class LayerDecayOptimizerConstructor(DefaultOptimizerConstructor):
+ """Different learning rates are set for different layers of backbone."""
+
+ def add_params(self, params, module):
+ """Add all parameters of module to the params list.
+
+ The parameters of the given module will be added to the list of param
+ groups, with specific rules defined by paramwise_cfg.
+ Args:
+ params (list[dict]): A list of param groups, it will be modified
+ in place.
+ module (nn.Module): The module to be added.
+ """
+ parameter_groups = {}
+ logger = get_root_logger()
+ logger.info(self.paramwise_cfg)
+ num_layers = self.paramwise_cfg.get('num_layers') + 2
+ layer_decay_rate = self.paramwise_cfg.get('layer_decay_rate')
+ logger.info(f'Build LayerDecayOptimizerConstructor '
+ f'{layer_decay_rate} - {num_layers}')
+ weight_decay = self.base_wd
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith('.bias') or name in (
+ 'pos_embed', 'cls_token'):
+ group_name = 'no_decay'
+ this_weight_decay = 0.
+ else:
+ group_name = 'decay'
+ this_weight_decay = weight_decay
+ layer_id = get_num_layer_for_vit(name, num_layers)
+ group_name = f'layer_{layer_id}_{group_name}'
+ if group_name not in parameter_groups:
+ scale = layer_decay_rate**(num_layers - layer_id - 1)
+ parameter_groups[group_name] = {
+ 'weight_decay': this_weight_decay,
+ 'params': [],
+ 'param_names': [],
+ 'lr_scale': scale,
+ 'group_name': group_name,
+ 'lr': scale * self.base_lr
+ }
+ parameter_groups[group_name]['params'].append(param)
+ parameter_groups[group_name]['param_names'].append(name)
+ rank, _ = get_dist_info()
+ if rank == 0:
+ to_display = {}
+ for key in parameter_groups:
+ to_display[key] = {
+ 'param_names': parameter_groups[key]['param_names'],
+ 'lr_scale': parameter_groups[key]['lr_scale'],
+ 'lr': parameter_groups[key]['lr'],
+ 'weight_decay': parameter_groups[key]['weight_decay']
+ }
+ logger.info(f'Param groups ={to_display}')
+ params.extend(parameter_groups.values())
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 434378e993..1ede4874da 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
@@ -24,5 +25,5 @@
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
- 'SVT', 'STDCNet', 'STDCContextPathNet'
+ 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT'
]
diff --git a/mmseg/models/backbones/beit.py b/mmseg/models/backbones/beit.py
new file mode 100644
index 0000000000..26be3156fe
--- /dev/null
+++ b/mmseg/models/backbones/beit.py
@@ -0,0 +1,532 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import build_norm_layer
+from mmcv.cnn.bricks.drop import build_dropout
+from mmcv.cnn.bricks.transformer import FFN
+from mmcv.cnn.utils.weight_init import (constant_init, kaiming_init,
+ trunc_normal_)
+from mmcv.runner import BaseModule, ModuleList, _load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.utils import _pair as to_2tuple
+
+from mmseg.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import PatchEmbed
+
+try:
+ from scipy import interpolate
+except ImportError:
+ interpolate = None
+
+
+class BEiTAttention(BaseModule):
+ """Window based multi-head self-attention (W-MSA) module with relative
+ position bias.
+
+ Args:
+ embed_dims (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (tuple[int]): The height and width of the window.
+ qv_bias (bool): If True, add a learnable bias to q, v.
+ Default: True.
+ qk_scale (float | None, optional): Override default qk scale of
+ head_dim ** -0.5 if set. Default: None.
+ attn_drop_rate (float): Dropout ratio of attention weight.
+ Default: 0.0
+ proj_drop_rate (float): Dropout ratio of output. Default: 0.
+ init_cfg (dict | None, optional): The Config for initialization.
+ Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ window_size,
+ qv_bias=True,
+ qk_scale=None,
+ attn_drop_rate=0.,
+ proj_drop_rate=0.,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ head_embed_dims = embed_dims // num_heads
+ self.scale = qk_scale or head_embed_dims**-0.5
+ if qv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(embed_dims))
+ self.v_bias = nn.Parameter(torch.zeros(embed_dims))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+
+ self.window_size = window_size
+ # cls to token & token 2 cls & cls to cls
+ self.num_relative_distance = (2 * window_size[0] -
+ 1) * (2 * window_size[1] - 1) + 3
+ # relative_position_bias_table shape is (2*Wh-1 * 2*Ww-1 + 3, nH)
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros(self.num_relative_distance, num_heads))
+
+ # get pair-wise relative position index for
+ # each token inside the window
+ coords_h = torch.arange(window_size[0])
+ coords_w = torch.arange(window_size[1])
+ # coords shape is (2, Wh, Ww)
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
+ # coords_flatten shape is (2, Wh*Ww)
+ coords_flatten = torch.flatten(coords, 1)
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :])
+ # relative_coords shape is (Wh*Ww, Wh*Ww, 2)
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+ # shift to start from 0
+ relative_coords[:, :, 0] += window_size[0] - 1
+ relative_coords[:, :, 1] += window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
+ relative_position_index = torch.zeros(
+ size=(window_size[0] * window_size[1] + 1, ) * 2,
+ dtype=relative_coords.dtype)
+ # relative_position_index shape is (Wh*Ww, Wh*Ww)
+ relative_position_index[1:, 1:] = relative_coords.sum(-1)
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
+ relative_position_index[0, 0] = self.num_relative_distance - 1
+
+ self.register_buffer('relative_position_index',
+ relative_position_index)
+ self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=False)
+ self.attn_drop = nn.Dropout(attn_drop_rate)
+ self.proj = nn.Linear(embed_dims, embed_dims)
+ self.proj_drop = nn.Dropout(proj_drop_rate)
+
+ def init_weights(self):
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
+
+ def forward(self, x):
+ """
+ Args:
+ x (tensor): input features with shape of (num_windows*B, N, C).
+ """
+ B, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ k_bias = torch.zeros_like(self.v_bias, requires_grad=False)
+ qkv_bias = torch.cat((self.q_bias, k_bias, self.v_bias))
+
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2]
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+ if self.relative_position_bias_table is not None:
+ Wh = self.window_size[0]
+ Ww = self.window_size[1]
+ relative_position_bias = self.relative_position_bias_table[
+ self.relative_position_index.view(-1)].view(
+ Wh * Ww + 1, Wh * Ww + 1, -1)
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class TransformerEncoderLayer(BaseModule):
+ """Implements one encoder layer in Vision Transformer.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default: 0.0.
+ drop_path_rate (float): Stochastic depth rate. Default 0.0.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Default: 2.
+ qv_bias (bool): Enable bias for qv if True. Default: True
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN').
+ window_size (tuple[int], optional): The height and width of the window.
+ Default: None.
+ init_values (float, optional): Initialize the values of BEiTAttention
+ and FFN with learnable scaling. Default: None.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ num_fcs=2,
+ qv_bias=True,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='LN'),
+ window_size=None,
+ init_values=None):
+ super(TransformerEncoderLayer, self).__init__()
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, embed_dims, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.attn = BEiTAttention(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ window_size=window_size,
+ qv_bias=qv_bias,
+ qk_scale=None,
+ attn_drop_rate=attn_drop_rate,
+ proj_drop_rate=0.,
+ init_cfg=None)
+ self.ffn = FFN(
+ embed_dims=embed_dims,
+ feedforward_channels=feedforward_channels,
+ num_fcs=num_fcs,
+ ffn_drop=0.,
+ dropout_layer=None,
+ act_cfg=act_cfg,
+ add_identity=False)
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, embed_dims, postfix=2)
+ self.add_module(self.norm2_name, norm2)
+ # NOTE: drop path for stochastic depth, we shall see if
+ # this is better than dropout here
+ dropout_layer = dict(type='DropPath', drop_prob=drop_path_rate)
+ self.drop_path = build_dropout(
+ dropout_layer) if dropout_layer else nn.Identity()
+ self.gamma_1 = nn.Parameter(
+ init_values * torch.ones((embed_dims)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(
+ init_values * torch.ones((embed_dims)), requires_grad=True)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.gamma_2 * self.ffn(self.norm2(x)))
+ return x
+
+
+@BACKBONES.register_module()
+class BEiT(BaseModule):
+ """BERT Pre-Training of Image Transformers.
+
+ Args:
+ img_size (int | tuple): Input image size. Default: 224.
+ patch_size (int): The patch size. Default: 16.
+ in_channels (int): Number of input channels. Default: 3.
+ embed_dims (int): Embedding dimension. Default: 768.
+ num_layers (int): Depth of transformer. Default: 12.
+ num_heads (int): Number of attention heads. Default: 12.
+ mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
+ Default: 4.
+ out_indices (list | tuple | int): Output from which stages.
+ Default: -1.
+ qv_bias (bool): Enable bias for qv if True. Default: True.
+ attn_drop_rate (float): The drop out rate for attention layer.
+ Default 0.0
+ drop_path_rate (float): Stochastic depth rate. Default 0.0.
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='LN')
+ act_cfg (dict): The activation config for FFNs.
+ Default: dict(type='GELU').
+ patch_norm (bool): Whether to add a norm in PatchEmbed Block.
+ Default: False.
+ final_norm (bool): Whether to add a additional layer to normalize
+ final feature map. Default: False.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Default: 2.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only. Default: False.
+ pretrained (str, optional): Model pretrained path. Default: None.
+ init_values (float): Initialize the values of BEiTAttention and FFN
+ with learnable scaling.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=16,
+ in_channels=3,
+ embed_dims=768,
+ num_layers=12,
+ num_heads=12,
+ mlp_ratio=4,
+ out_indices=-1,
+ qv_bias=True,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_cfg=dict(type='LN'),
+ act_cfg=dict(type='GELU'),
+ patch_norm=False,
+ final_norm=False,
+ num_fcs=2,
+ norm_eval=False,
+ pretrained=None,
+ init_values=0.1,
+ init_cfg=None):
+ super(BEiT, self).__init__(init_cfg=init_cfg)
+ if isinstance(img_size, int):
+ img_size = to_2tuple(img_size)
+ elif isinstance(img_size, tuple):
+ if len(img_size) == 1:
+ img_size = to_2tuple(img_size[0])
+ assert len(img_size) == 2, \
+ f'The size of image should have length 1 or 2, ' \
+ f'but got {len(img_size)}'
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be set at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is not None:
+ raise TypeError('pretrained must be a str or None')
+
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.norm_eval = norm_eval
+ self.pretrained = pretrained
+
+ self.patch_embed = PatchEmbed(
+ in_channels=in_channels,
+ embed_dims=embed_dims,
+ conv_type='Conv2d',
+ kernel_size=patch_size,
+ stride=patch_size,
+ padding=0,
+ norm_cfg=norm_cfg if patch_norm else None,
+ init_cfg=None)
+
+ window_size = (img_size[0] // patch_size, img_size[1] // patch_size)
+ self.patch_shape = window_size
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
+
+ if isinstance(out_indices, int):
+ if out_indices == -1:
+ out_indices = num_layers - 1
+ self.out_indices = [out_indices]
+ elif isinstance(out_indices, list) or isinstance(out_indices, tuple):
+ self.out_indices = out_indices
+ else:
+ raise TypeError('out_indices must be type of int, list or tuple')
+
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, num_layers)]
+ self.layers = ModuleList()
+ for i in range(num_layers):
+ self.layers.append(
+ TransformerEncoderLayer(
+ embed_dims=embed_dims,
+ num_heads=num_heads,
+ feedforward_channels=mlp_ratio * embed_dims,
+ attn_drop_rate=attn_drop_rate,
+ drop_path_rate=dpr[i],
+ num_fcs=num_fcs,
+ qv_bias=qv_bias,
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg,
+ window_size=window_size,
+ init_values=init_values))
+
+ self.final_norm = final_norm
+ if final_norm:
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, embed_dims, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+
+ @property
+ def norm1(self):
+ return getattr(self, self.norm1_name)
+
+ def _geometric_sequence_interpolation(self, src_size, dst_size, sequence,
+ num):
+ """Get new sequence via geometric sequence interpolation.
+
+ Args:
+ src_size (int): Pos_embedding size in pre-trained model.
+ dst_size (int): Pos_embedding size in the current model.
+ sequence (tensor): The relative position bias of the pretrain
+ model after removing the extra tokens.
+ num (int): Number of attention heads.
+ Returns:
+ new_sequence (tensor): Geometric sequence interpolate the
+ pre-trained relative position bias to the size of
+ the current model.
+ """
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r**n) / (1.0 - r)
+
+ # Here is a binary function.
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+ # The position of each interpolated point is determined
+ # by the ratio obtained by dichotomy.
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q**(i + 1)
+ r_ids = [-_ for _ in reversed(dis)]
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+ # Interpolation functions are being executed and called.
+ new_sequence = []
+ for i in range(num):
+ z = sequence[:, i].view(src_size, src_size).float().numpy()
+ f = interpolate.interp2d(x, y, z, kind='cubic')
+ new_sequence.append(
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(sequence))
+ new_sequence = torch.cat(new_sequence, dim=-1)
+ return new_sequence
+
+ def resize_rel_pos_embed(self, checkpoint):
+ """Resize relative pos_embed weights.
+
+ This function is modified from
+ https://github.com/microsoft/unilm/blob/master/beit/semantic_segmentation/mmcv_custom/checkpoint.py. # noqa: E501
+ Copyright (c) Microsoft Corporation
+ Licensed under the MIT License
+
+ Args:
+ checkpoint (dict): Key and value of the pretrain model.
+ Returns:
+ state_dict (dict): Interpolate the relative pos_embed weights
+ in the pre-train model to the current model size.
+ """
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ else:
+ state_dict = checkpoint
+
+ all_keys = list(state_dict.keys())
+ for key in all_keys:
+ if 'relative_position_index' in key:
+ state_dict.pop(key)
+ # In order to keep the center of pos_bias as consistent as
+ # possible after interpolation, and vice versa in the edge
+ # area, the geometric sequence interpolation method is adopted.
+ if 'relative_position_bias_table' in key:
+ rel_pos_bias = state_dict[key]
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
+ dst_num_pos, _ = self.state_dict()[key].size()
+ dst_patch_shape = self.patch_shape
+ if dst_patch_shape[0] != dst_patch_shape[1]:
+ raise NotImplementedError()
+ # Count the number of extra tokens.
+ num_extra_tokens = dst_num_pos - (
+ dst_patch_shape[0] * 2 - 1) * (
+ dst_patch_shape[1] * 2 - 1)
+ src_size = int((src_num_pos - num_extra_tokens)**0.5)
+ dst_size = int((dst_num_pos - num_extra_tokens)**0.5)
+ if src_size != dst_size:
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
+ new_rel_pos_bias = self._geometric_sequence_interpolation(
+ src_size, dst_size, rel_pos_bias, num_attn_heads)
+ new_rel_pos_bias = torch.cat(
+ (new_rel_pos_bias, extra_tokens), dim=0)
+ state_dict[key] = new_rel_pos_bias
+
+ return state_dict
+
+ def init_weights(self):
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ self.apply(_init_weights)
+
+ if (isinstance(self.init_cfg, dict)
+ and self.init_cfg.get('type') == 'Pretrained'):
+ logger = get_root_logger()
+ checkpoint = _load_checkpoint(
+ self.init_cfg['checkpoint'], logger=logger, map_location='cpu')
+ state_dict = self.resize_rel_pos_embed(checkpoint)
+ self.load_state_dict(state_dict, False)
+ elif self.init_cfg is not None:
+ super(BEiT, self).init_weights()
+ else:
+ # We only implement the 'jax_impl' initialization implemented at
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py#L353 # noqa: E501
+ # Copyright 2019 Ross Wightman
+ # Licensed under the Apache License, Version 2.0 (the "License")
+ trunc_normal_(self.cls_token, std=.02)
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if m.bias is not None:
+ if 'ffn' in n:
+ nn.init.normal_(m.bias, mean=0., std=1e-6)
+ else:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ kaiming_init(m, mode='fan_in', bias=0.)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm, nn.LayerNorm)):
+ constant_init(m, val=1.0, bias=0.)
+
+ def forward(self, inputs):
+ B = inputs.shape[0]
+
+ x, hw_shape = self.patch_embed(inputs)
+
+ # stole cls_tokens impl from Phil Wang, thanks
+ cls_tokens = self.cls_token.expand(B, -1, -1)
+ x = torch.cat((cls_tokens, x), dim=1)
+
+ outs = []
+ for i, layer in enumerate(self.layers):
+ x = layer(x)
+ if i == len(self.layers) - 1:
+ if self.final_norm:
+ x = self.norm1(x)
+ if i in self.out_indices:
+ # Remove class token and reshape token for decoder head
+ out = x[:, 1:]
+ B, _, C = out.shape
+ out = out.reshape(B, hw_shape[0], hw_shape[1],
+ C).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(BEiT, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.LayerNorm):
+ m.eval()
diff --git a/mmseg/models/necks/__init__.py b/mmseg/models/necks/__init__.py
index aba73f165b..ff03186a92 100644
--- a/mmseg/models/necks/__init__.py
+++ b/mmseg/models/necks/__init__.py
@@ -1,8 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .featurepyramid import Feature2Pyramid
from .fpn import FPN
from .ic_neck import ICNeck
from .jpu import JPU
from .mla_neck import MLANeck
from .multilevel_neck import MultiLevelNeck
-__all__ = ['FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU']
+__all__ = [
+ 'FPN', 'MultiLevelNeck', 'MLANeck', 'ICNeck', 'JPU', 'Feature2Pyramid'
+]
diff --git a/mmseg/models/necks/featurepyramid.py b/mmseg/models/necks/featurepyramid.py
new file mode 100644
index 0000000000..82a00ceb1c
--- /dev/null
+++ b/mmseg/models/necks/featurepyramid.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import build_norm_layer
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class Feature2Pyramid(nn.Module):
+ """Feature2Pyramid.
+
+ A neck structure connect ViT backbone and decoder_heads.
+
+ Args:
+ embed_dims (int): Embedding dimension.
+ rescales (list[float]): Different sampling multiples were
+ used to obtain pyramid features. Default: [4, 2, 1, 0.5].
+ norm_cfg (dict): Config dict for normalization layer.
+ Default: dict(type='SyncBN', requires_grad=True).
+ """
+
+ def __init__(self,
+ embed_dim,
+ rescales=[4, 2, 1, 0.5],
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
+ super(Feature2Pyramid, self).__init__()
+ self.rescales = rescales
+ self.upsample_4x = None
+ for k in self.rescales:
+ if k == 4:
+ self.upsample_4x = nn.Sequential(
+ nn.ConvTranspose2d(
+ embed_dim, embed_dim, kernel_size=2, stride=2),
+ build_norm_layer(norm_cfg, embed_dim)[1],
+ nn.GELU(),
+ nn.ConvTranspose2d(
+ embed_dim, embed_dim, kernel_size=2, stride=2),
+ )
+ elif k == 2:
+ self.upsample_2x = nn.Sequential(
+ nn.ConvTranspose2d(
+ embed_dim, embed_dim, kernel_size=2, stride=2))
+ elif k == 1:
+ self.identity = nn.Identity()
+ elif k == 0.5:
+ self.downsample_2x = nn.MaxPool2d(kernel_size=2, stride=2)
+ elif k == 0.25:
+ self.downsample_4x = nn.MaxPool2d(kernel_size=4, stride=4)
+ else:
+ raise KeyError(f'invalid {k} for feature2pyramid')
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.rescales)
+ outputs = []
+ if self.upsample_4x is not None:
+ ops = [
+ self.upsample_4x, self.upsample_2x, self.identity,
+ self.downsample_2x
+ ]
+ else:
+ ops = [
+ self.upsample_2x, self.identity, self.downsample_2x,
+ self.downsample_4x
+ ]
+ for i in range(len(inputs)):
+ outputs.append(ops[i](inputs[i]))
+ return tuple(outputs)
diff --git a/model-index.yml b/model-index.yml
index 235ad7f6e7..d8e9516bf4 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -1,6 +1,7 @@
Import:
- configs/ann/ann.yml
- configs/apcnet/apcnet.yml
+- configs/beit/beit.yml
- configs/bisenetv1/bisenetv1.yml
- configs/bisenetv2/bisenetv2.yml
- configs/ccnet/ccnet.yml
diff --git a/tests/test_core/test_layer_decay_optimizer_constructor.py b/tests/test_core/test_layer_decay_optimizer_constructor.py
new file mode 100644
index 0000000000..f595d31331
--- /dev/null
+++ b/tests/test_core/test_layer_decay_optimizer_constructor.py
@@ -0,0 +1,70 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+
+from mmseg.core.layer_decay_optimizer_constructor import \
+ LayerDecayOptimizerConstructor
+
+layer_wise_gt_lst = [{
+ 'weight_decay': 0.0,
+ 'lr_scale': 16
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 8
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 8
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 4
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 4
+}, {
+ 'weight_decay': 0.05,
+ 'lr_scale': 2
+}, {
+ 'weight_decay': 0.0,
+ 'lr_scale': 2
+}]
+
+
+class BEiTExampleModel(nn.Module):
+
+ def __init__(self, depth):
+ super().__init__()
+ self.backbone = nn.ModuleList()
+
+ # add some variables to meet unit test coverate rate
+ self.backbone.cls_token = nn.Parameter(torch.ones(1))
+ self.backbone.patch_embed = nn.Parameter(torch.ones(1))
+ self.backbone.layers = nn.ModuleList()
+ for _ in range(depth):
+ layer = nn.Conv2d(3, 3, 1)
+ self.backbone.layers.append(layer)
+
+
+def check_beit_adamw_optimizer(optimizer, gt_lst):
+ assert isinstance(optimizer, torch.optim.AdamW)
+ assert optimizer.defaults['lr'] == 1
+ assert optimizer.defaults['weight_decay'] == 0.05
+ param_groups = optimizer.param_groups
+ # 1 layer (cls_token and patch_embed) + 3 layers * 2 (w, b) = 7 layers
+ assert len(param_groups) == 7
+ for i, param_dict in enumerate(param_groups):
+ assert param_dict['weight_decay'] == gt_lst[i]['weight_decay']
+ assert param_dict['lr_scale'] == gt_lst[i]['lr_scale']
+ assert param_dict['lr_scale'] == param_dict['lr']
+
+
+def test_beit_layer_decay_optimizer_constructor():
+
+ # paramwise_cfg with ConvNeXtExampleModel
+ model = BEiTExampleModel(depth=3)
+ optimizer_cfg = dict(
+ type='AdamW', lr=1, betas=(0.9, 0.999), weight_decay=0.05)
+ paramwise_cfg = dict(num_layers=3, layer_decay_rate=2)
+ optim_constructor = LayerDecayOptimizerConstructor(optimizer_cfg,
+ paramwise_cfg)
+ optimizer = optim_constructor(model)
+ check_beit_adamw_optimizer(optimizer, layer_wise_gt_lst)
diff --git a/tests/test_models/test_backbones/test_beit.py b/tests/test_models/test_backbones/test_beit.py
new file mode 100644
index 0000000000..cf3960894d
--- /dev/null
+++ b/tests/test_models/test_backbones/test_beit.py
@@ -0,0 +1,182 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models.backbones.beit import BEiT
+from .utils import check_norm_state
+
+
+def test_beit_backbone():
+ with pytest.raises(TypeError):
+ # pretrained must be a string path
+ model = BEiT()
+ model.init_weights(pretrained=0)
+
+ with pytest.raises(TypeError):
+ # img_size must be int or tuple
+ model = BEiT(img_size=512.0)
+
+ with pytest.raises(TypeError):
+ # out_indices must be int ,list or tuple
+ model = BEiT(out_indices=1.)
+
+ with pytest.raises(AssertionError):
+ # The length of img_size tuple must be lower than 3.
+ BEiT(img_size=(224, 224, 224))
+
+ with pytest.raises(TypeError):
+ # Pretrained must be None or Str.
+ BEiT(pretrained=123)
+
+ # Test img_size isinstance tuple
+ imgs = torch.randn(1, 3, 224, 224)
+ model = BEiT(img_size=(224, ))
+ model.init_weights()
+ model(imgs)
+
+ # Test img_size isinstance tuple
+ imgs = torch.randn(1, 3, 224, 224)
+ model = BEiT(img_size=(224, 224))
+ model(imgs)
+
+ # Test norm_eval = True
+ model = BEiT(norm_eval=True)
+ model.train()
+
+ # Test BEiT backbone with input size of 224 and patch size of 16
+ model = BEiT()
+ model.init_weights()
+ model.train()
+
+ # Test qv_bias
+ model = BEiT(qv_bias=False)
+ model.train()
+
+ # Test out_indices = list
+ model = BEiT(out_indices=[2, 4, 8, 12])
+ model.train()
+
+ assert check_norm_state(model.modules(), True)
+
+ # Test image size = (224, 224)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test BEiT backbone with input size of 256 and patch size of 16
+ model = BEiT(img_size=(256, 256))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 256, 256)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 16, 16)
+
+ # Test BEiT backbone with input size of 32 and patch size of 16
+ model = BEiT(img_size=(32, 32))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 32, 32)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 2, 2)
+
+ # Test unbalanced size input image
+ model = BEiT(img_size=(112, 224))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 112, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 7, 14)
+
+ # Test irregular input image
+ model = BEiT(img_size=(234, 345))
+ model.init_weights()
+ model.train()
+ imgs = torch.randn(1, 3, 234, 345)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 21)
+
+ # Test init_values=0
+ model = BEiT(init_values=0)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test final norm
+ model = BEiT(final_norm=True)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+ # Test patch norm
+ model = BEiT(patch_norm=True)
+ imgs = torch.randn(1, 3, 224, 224)
+ feat = model(imgs)
+ assert feat[-1].shape == (1, 768, 14, 14)
+
+
+def test_beit_init():
+ path = 'PATH_THAT_DO_NOT_EXIST'
+ # Test all combinations of pretrained and init_cfg
+ # pretrained=None, init_cfg=None
+ model = BEiT(pretrained=None, init_cfg=None)
+ assert model.init_cfg is None
+ model.init_weights()
+
+ # pretrained=None
+ # init_cfg loads pretrain from an non-existent file
+ model = BEiT(
+ pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
+ assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+ # Test loading a checkpoint from an non-existent file
+ with pytest.raises(OSError):
+ model.init_weights()
+
+ # test resize_rel_pos_embed
+ value = torch.randn(732, 16)
+ ckpt = {
+ 'state_dict': {
+ 'layers.0.attn.relative_position_index': 0,
+ 'layers.0.attn.relative_position_bias_table': value
+ }
+ }
+ model = BEiT(img_size=(512, 512))
+ with pytest.raises(AttributeError):
+ model.resize_rel_pos_embed(ckpt)
+
+ # pretrained=None
+ # init_cfg=123, whose type is unsupported
+ model = BEiT(pretrained=None, init_cfg=123)
+ with pytest.raises(TypeError):
+ model.init_weights()
+
+ # pretrained loads pretrain from an non-existent file
+ # init_cfg=None
+ model = BEiT(pretrained=path, init_cfg=None)
+ assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
+ # Test loading a checkpoint from an non-existent file
+ with pytest.raises(OSError):
+ model.init_weights()
+
+ # pretrained loads pretrain from an non-existent file
+ # init_cfg loads pretrain from an non-existent file
+ with pytest.raises(AssertionError):
+ model = BEiT(
+ pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
+ with pytest.raises(AssertionError):
+ model = BEiT(pretrained=path, init_cfg=123)
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg=None
+ with pytest.raises(TypeError):
+ model = BEiT(pretrained=123, init_cfg=None)
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg loads pretrain from an non-existent file
+ with pytest.raises(AssertionError):
+ model = BEiT(
+ pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))
+
+ # pretrain=123, whose type is unsupported
+ # init_cfg=123, whose type is unsupported
+ with pytest.raises(AssertionError):
+ model = BEiT(pretrained=123, init_cfg=123)
diff --git a/tests/test_models/test_necks/test_feature2pyramid.py b/tests/test_models/test_necks/test_feature2pyramid.py
new file mode 100644
index 0000000000..44fd02c489
--- /dev/null
+++ b/tests/test_models/test_necks/test_feature2pyramid.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+
+from mmseg.models import Feature2Pyramid
+
+
+def test_feature2pyramid():
+ # test
+ rescales = [4, 2, 1, 0.5]
+ embed_dim = 64
+ inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
+
+ fpn = Feature2Pyramid(
+ embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
+ outputs = fpn(inputs)
+ assert outputs[0].shape == torch.Size([1, 64, 128, 128])
+ assert outputs[1].shape == torch.Size([1, 64, 64, 64])
+ assert outputs[2].shape == torch.Size([1, 64, 32, 32])
+ assert outputs[3].shape == torch.Size([1, 64, 16, 16])
+
+ # test rescales = [2, 1, 0.5, 0.25]
+ rescales = [2, 1, 0.5, 0.25]
+ inputs = [torch.randn(1, embed_dim, 32, 32) for i in range(len(rescales))]
+
+ fpn = Feature2Pyramid(
+ embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
+ outputs = fpn(inputs)
+ assert outputs[0].shape == torch.Size([1, 64, 64, 64])
+ assert outputs[1].shape == torch.Size([1, 64, 32, 32])
+ assert outputs[2].shape == torch.Size([1, 64, 16, 16])
+ assert outputs[3].shape == torch.Size([1, 64, 8, 8])
+
+ # test rescales = [4, 2, 0.25, 0]
+ rescales = [4, 2, 0.25, 0]
+ with pytest.raises(KeyError):
+ fpn = Feature2Pyramid(
+ embed_dim, rescales, norm_cfg=dict(type='BN', requires_grad=True))
diff --git a/tools/model_converters/beit2mmseg.py b/tools/model_converters/beit2mmseg.py
new file mode 100644
index 0000000000..d23cfdb0b3
--- /dev/null
+++ b/tools/model_converters/beit2mmseg.py
@@ -0,0 +1,56 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import os.path as osp
+from collections import OrderedDict
+
+import mmcv
+import torch
+from mmcv.runner import CheckpointLoader
+
+
+def convert_beit(ckpt):
+ new_ckpt = OrderedDict()
+
+ for k, v in ckpt.items():
+ if k.startswith('patch_embed'):
+ new_key = k.replace('patch_embed.proj', 'patch_embed.projection')
+ new_ckpt[new_key] = v
+ if k.startswith('blocks'):
+ new_key = k.replace('blocks', 'layers')
+ if 'norm' in new_key:
+ new_key = new_key.replace('norm', 'ln')
+ elif 'mlp.fc1' in new_key:
+ new_key = new_key.replace('mlp.fc1', 'ffn.layers.0.0')
+ elif 'mlp.fc2' in new_key:
+ new_key = new_key.replace('mlp.fc2', 'ffn.layers.1')
+ new_ckpt[new_key] = v
+ else:
+ new_key = k
+ new_ckpt[new_key] = v
+
+ return new_ckpt
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Convert keys in official pretrained beit models to'
+ 'MMSegmentation style.')
+ parser.add_argument('src', help='src model path or url')
+ # The dst path must be a full path of the new checkpoint.
+ parser.add_argument('dst', help='save path')
+ args = parser.parse_args()
+
+ checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ weight = convert_beit(state_dict)
+ mmcv.mkdir_or_exist(osp.dirname(args.dst))
+ torch.save(weight, args.dst)
+
+
+if __name__ == '__main__':
+ main()