diff --git a/README.md b/README.md
index 69f751b9e5..a17557cb81 100644
--- a/README.md
+++ b/README.md
@@ -117,6 +117,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
+- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 709e6ef195..e09b515ab5 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -98,6 +98,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [ConvNeXt (CVPR'2022)](configs/convnext)
- [x] [MAE (CVPR'2022)](configs/mae)
- [x] [PoolFormer (CVPR'2022)](configs/poolformer)
+- [x] [SegNeXt (NeurIPS'2022)](configs/segnext)
diff --git a/configs/segnext/README.md b/configs/segnext/README.md
new file mode 100644
index 0000000000..315f4e23e8
--- /dev/null
+++ b/configs/segnext/README.md
@@ -0,0 +1,63 @@
+# SegNeXt
+
+> [SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575)
+
+## Introduction
+
+
+
+[Official Repo](https://github.com/Visual-Attention-Network/SegNeXt)
+
+[Code Snippet](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328)
+
+## Abstract
+
+
+
+We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch).
+
+
+
+
+
+
+
+## Results and models
+
+### ADE20K
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
+| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) |
+| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) |
+| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) |
+| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) |
+
+Note:
+
+- When we integrated SegNeXt into MMSegmentation, we renamed some layers to make them more precise and concise without changing the model architecture. As a result, the keys of the pre-trained weights differ from those of the [original weights](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/). We have converted the checkpoints and uploaded them; their URLs can be found in the config files and they can be used directly for training.
+
+- The total batch size is 16. We trained SegNeXt on a single GPU because its performance degrades significantly when `SyncBN` (mainly in the `OverlapPatchEmbed` modules of `MSCAN`) is used with PyTorch 1.9.
+
+- There may be subtle differences at test time because the Non-negative Matrix Factorization (NMF) in `LightHamHead` is initialized randomly. To control this randomness, set the random seed when testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tools/test.py) like:
+
+```python
+def main():
+    from mmengine.runner import set_random_seed
+ random_seed = xxx # set random seed recorded in training log
+ set_random_seed(random_seed, deterministic=False)
+ ...
+```
+
+- The model performance is sensitive to the random seed; please refer to the log file for the exact seed setting. If you choose a different seed, the results may differ from those in the table. Taking SegNeXt-L as an example, its mIoU ranges from 49.60 to 51.0.
+
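+As a quick check after training or downloading a checkpoint, the minimal sketch below runs single-image inference through the MMSegmentation Python API (`mmseg.apis`). The config and checkpoint are the MSCAN-T entries from the table above; `demo/demo.png` is only a placeholder image path:
+
+```python
+from mmseg.apis import inference_model, init_model
+
+config_file = 'configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth'
+
+# Build the model, load the converted weights and run inference on one image.
+model = init_model(config_file, checkpoint_file, device='cuda:0')
+result = inference_model(model, 'demo/demo.png')
+```
+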
+## Citation
+
+```bibtex
+@article{guo2022segnext,
+ title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
+ author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
+ journal={arXiv preprint arXiv:2209.08575},
+ year={2022}
+}
+```
diff --git a/configs/segnext/segnext.yml b/configs/segnext/segnext.yml
new file mode 100644
index 0000000000..3bcdea8403
--- /dev/null
+++ b/configs/segnext/segnext.yml
@@ -0,0 +1,103 @@
+Collections:
+- Name: SegNeXt
+ Metadata:
+ Training Data:
+ - ADE20K
+ Paper:
+ URL: https://arxiv.org/abs/2209.08575
+ Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
+ README: configs/segnext/README.md
+ Code:
+ URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328
+ Version: dev-1.x
+ Converted From:
+ Code: https://github.com/visual-attention-network/segnext
+Models:
+- Name: segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512
+ In Collection: SegNeXt
+ Metadata:
+ backbone: MSCAN-T
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 19.09
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 17.88
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 41.5
+ mIoU(ms+flip): 42.59
+ Config: configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
+- Name: segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512
+ In Collection: SegNeXt
+ Metadata:
+ backbone: MSCAN-S
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 23.66
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 21.47
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 44.16
+ mIoU(ms+flip): 45.81
+ Config: configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
+- Name: segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512
+ In Collection: SegNeXt
+ Metadata:
+ backbone: MSCAN-B
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 28.45
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 31.03
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 48.03
+ mIoU(ms+flip): 49.68
+ Config: configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth
+- Name: segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512
+ In Collection: SegNeXt
+ Metadata:
+ backbone: MSCAN-L
+ crop size: (512,512)
+ lr schd: 160000
+ inference time (ms/im):
+ - value: 43.65
+ hardware: V100
+ backend: PyTorch
+ batch size: 1
+ mode: FP32
+ resolution: (512,512)
+ Training Memory (GB): 43.32
+ Results:
+ - Task: Semantic Segmentation
+ Dataset: ADE20K
+ Metrics:
+ mIoU: 50.99
+ mIoU(ms+flip): 52.1
+ Config: configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
+ Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth
diff --git a/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..000f448483
--- /dev/null
+++ b/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py
@@ -0,0 +1,28 @@
+_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
+
+# model settings
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth' # noqa
+ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[3, 3, 12, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
+ drop_path_rate=0.1,
+ norm_cfg=dict(type='BN', requires_grad=True)),
+ decode_head=dict(
+ type='LightHamHead',
+ in_channels=[128, 320, 512],
+ in_index=[1, 2, 3],
+ channels=512,
+ ham_channels=512,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=ham_norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..212d0a8557
--- /dev/null
+++ b/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py
@@ -0,0 +1,27 @@
+_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
+# model settings
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_l_20230227-cef260d4.pth' # noqa
+ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[3, 5, 27, 3],
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
+ drop_path_rate=0.3,
+ norm_cfg=dict(type='BN', requires_grad=True)),
+ decode_head=dict(
+ type='LightHamHead',
+ in_channels=[128, 320, 512],
+ in_index=[1, 2, 3],
+ channels=1024,
+ ham_channels=1024,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=ham_norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..9a90779a60
--- /dev/null
+++ b/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
@@ -0,0 +1,27 @@
+_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py'
+# model settings
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth' # noqa
+ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ backbone=dict(
+ embed_dims=[64, 128, 320, 512],
+ depths=[2, 2, 4, 2],
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
+ norm_cfg=dict(type='BN', requires_grad=True)),
+ decode_head=dict(
+ type='LightHamHead',
+ in_channels=[128, 320, 512],
+ in_index=[1, 2, 3],
+ channels=256,
+ ham_channels=256,
+ ham_kwargs=dict(MD_R=16),
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=ham_norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
diff --git a/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
new file mode 100644
index 0000000000..c8d6da85ff
--- /dev/null
+++ b/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
@@ -0,0 +1,84 @@
+_base_ = [
+ '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py',
+ '../_base_/datasets/ade20k.py'
+]
+# model settings
+checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa
+ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
+crop_size = (512, 512)
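+# The data preprocessor normalizes images and, at test time, pads them so
+# both spatial sides are divisible by 32 (the overall stride of MSCAN).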
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255,
+ size=(512, 512),
+ test_cfg=dict(size_divisor=32))
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ pretrained=None,
+ backbone=dict(
+ type='MSCAN',
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
+ embed_dims=[32, 64, 160, 256],
+ mlp_ratios=[8, 8, 4, 4],
+ drop_rate=0.0,
+ drop_path_rate=0.1,
+ depths=[3, 3, 5, 2],
+ attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+ attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='BN', requires_grad=True)),
+ decode_head=dict(
+ type='LightHamHead',
+ in_channels=[64, 160, 256],
+ in_index=[1, 2, 3],
+ channels=256,
+ ham_channels=256,
+ dropout_ratio=0.1,
+ num_classes=150,
+ norm_cfg=ham_norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ ham_kwargs=dict(
+ MD_S=1,
+ MD_R=16,
+ train_steps=6,
+ eval_steps=7,
+ inv_t=100,
+ rand_init=True)),
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+# dataset settings
+train_dataloader = dict(batch_size=16)
+
+# optimizer
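+# AdamW with no weight decay on norm layers or `pos_block` parameters,
+# and a 10x learning rate multiplier on the decode head.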
+optim_wrapper = dict(
+ _delete_=True,
+ type='OptimWrapper',
+ optimizer=dict(
+ type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
+ paramwise_cfg=dict(
+ custom_keys={
+ 'pos_block': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.),
+ 'head': dict(lr_mult=10.)
+ }))
+
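+# Linear warm-up for the first 1500 iterations, then polynomial decay
+# (power=1.0) towards 0 until iteration 160k.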
+param_scheduler = [
+ dict(
+ type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
+ dict(
+ type='PolyLR',
+ power=1.0,
+ begin=1500,
+ end=160000,
+ eta_min=0.0,
+ by_epoch=False,
+ )
+]
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 909b54f3ec..e3107306ea 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -11,6 +11,7 @@
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
+from .mscan import MSCAN
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
@@ -27,5 +28,5 @@
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
- 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
+ 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
]
diff --git a/mmseg/models/backbones/mscan.py b/mmseg/models/backbones/mscan.py
new file mode 100644
index 0000000000..7150cb7a1c
--- /dev/null
+++ b/mmseg/models/backbones/mscan.py
@@ -0,0 +1,467 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Originally from https://github.com/visual-attention-network/segnext
+# Licensed under the Apache License, Version 2.0 (the "License")
+import math
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import build_activation_layer, build_norm_layer
+from mmcv.cnn.bricks import DropPath
+from mmengine.model import BaseModule
+from mmengine.model.weight_init import (constant_init, normal_init,
+ trunc_normal_init)
+
+from mmseg.registry import MODELS
+
+
+class Mlp(BaseModule):
+ """Multi Layer Perceptron (MLP) Module.
+
+ Args:
+ in_features (int): The dimension of input features.
+ hidden_features (int): The dimension of hidden features.
+ Defaults: None.
+ out_features (int): The dimension of output features.
+ Defaults: None.
+ act_cfg (dict): Config dict for activation layer in block.
+ Default: dict(type='GELU').
+        drop (float): Dropout rate of the MLP block.
+ Defaults: 0.0.
+ """
+
+ def __init__(self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_cfg=dict(type='GELU'),
+ drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.dwconv = nn.Conv2d(
+ hidden_features,
+ hidden_features,
+ 3,
+ 1,
+ 1,
+ bias=True,
+ groups=hidden_features)
+ self.act = build_activation_layer(act_cfg)
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """Forward function."""
+
+ x = self.fc1(x)
+
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+
+ return x
+
+
+class StemConv(BaseModule):
+ """Stem Block at the beginning of Semantic Branch.
+
+ Args:
+ in_channels (int): The dimension of input channels.
+ out_channels (int): The dimension of output channels.
+ act_cfg (dict): Config dict for activation layer in block.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults: dict(type='SyncBN', requires_grad=True).
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
+ super().__init__()
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels,
+ out_channels // 2,
+ kernel_size=(3, 3),
+ stride=(2, 2),
+ padding=(1, 1)),
+ build_norm_layer(norm_cfg, out_channels // 2)[1],
+ build_activation_layer(act_cfg),
+ nn.Conv2d(
+ out_channels // 2,
+ out_channels,
+ kernel_size=(3, 3),
+ stride=(2, 2),
+ padding=(1, 1)),
+ build_norm_layer(norm_cfg, out_channels)[1],
+ )
+
+ def forward(self, x):
+ """Forward function."""
+
+ x = self.proj(x)
+ _, _, H, W = x.size()
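+        # Flatten the feature map (B, C, H, W) into tokens (B, H*W, C), the
+        # format expected by MSCABlock.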
+ x = x.flatten(2).transpose(1, 2)
+ return x, H, W
+
+
+class MSCAAttention(BaseModule):
+ """Attention Module in Multi-Scale Convolutional Attention Module (MSCA).
+
+ Args:
+ channels (int): The dimension of channels.
+        kernel_sizes (list): The kernel sizes of the attention
+            branches. Defaults: [5, [1, 7], [1, 11], [1, 21]].
+        paddings (list): The corresponding padding values of the
+            attention branches.
+            Defaults: [2, [0, 3], [0, 5], [0, 10]].
+ """
+
+ def __init__(self,
+ channels,
+ kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+ paddings=[2, [0, 3], [0, 5], [0, 10]]):
+ super().__init__()
+ self.conv0 = nn.Conv2d(
+ channels,
+ channels,
+ kernel_size=kernel_sizes[0],
+ padding=paddings[0],
+ groups=channels)
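+        # Each remaining branch approximates a large k x k depth-wise
+        # convolution with a pair of 1 x k and k x 1 depth-wise strip convs.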
+ for i, (kernel_size,
+ padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])):
+ kernel_size_ = [kernel_size, kernel_size[::-1]]
+ padding_ = [padding, padding[::-1]]
+ conv_name = [f'conv{i}_1', f'conv{i}_2']
+ for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_,
+ conv_name):
+ self.add_module(
+ i_conv,
+ nn.Conv2d(
+ channels,
+ channels,
+ tuple(i_kernel),
+ padding=i_pad,
+ groups=channels))
+ self.conv3 = nn.Conv2d(channels, channels, 1)
+
+ def forward(self, x):
+ """Forward function."""
+
+ u = x.clone()
+
+ attn = self.conv0(x)
+
+ # Multi-Scale Feature extraction
+ attn_0 = self.conv0_1(attn)
+ attn_0 = self.conv0_2(attn_0)
+
+ attn_1 = self.conv1_1(attn)
+ attn_1 = self.conv1_2(attn_1)
+
+ attn_2 = self.conv2_1(attn)
+ attn_2 = self.conv2_2(attn_2)
+
+ attn = attn + attn_0 + attn_1 + attn_2
+ # Channel Mixing
+ attn = self.conv3(attn)
+
+ # Convolutional Attention
+ x = attn * u
+
+ return x
+
+
+class MSCASpatialAttention(BaseModule):
+ """Spatial Attention Module in Multi-Scale Convolutional Attention Module
+ (MSCA).
+
+ Args:
+ in_channels (int): The dimension of channels.
+        attention_kernel_sizes (list): The kernel sizes of the attention
+            branches. Defaults: [5, [1, 7], [1, 11], [1, 21]].
+        attention_kernel_paddings (list): The corresponding padding values
+            of the attention branches.
+            Defaults: [2, [0, 3], [0, 5], [0, 10]].
+ act_cfg (dict): Config dict for activation layer in block.
+ Default: dict(type='GELU').
+ """
+
+ def __init__(self,
+ in_channels,
+ attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+ attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+ act_cfg=dict(type='GELU')):
+ super().__init__()
+ self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
+ self.activation = build_activation_layer(act_cfg)
+ self.spatial_gating_unit = MSCAAttention(in_channels,
+ attention_kernel_sizes,
+ attention_kernel_paddings)
+ self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
+
+ def forward(self, x):
+ """Forward function."""
+
+        shortcut = x.clone()
+ x = self.proj_1(x)
+ x = self.activation(x)
+ x = self.spatial_gating_unit(x)
+ x = self.proj_2(x)
+        x = x + shortcut
+ return x
+
+
+class MSCABlock(BaseModule):
+ """Basic Multi-Scale Convolutional Attention Block. It leverage the large-
+    attention. In each branch, it uses two depth-wise strip convolutions to
+    approximate standard depth-wise convolutions with large kernels. The
+    kernel size for each branch is set to 7, 11, and 21, respectively.
+
+ Args:
+ channels (int): The dimension of channels.
+        attention_kernel_sizes (list): The kernel sizes of the attention
+            branches. Defaults: [5, [1, 7], [1, 11], [1, 21]].
+        attention_kernel_paddings (list): The corresponding padding values
+            of the attention branches.
+            Defaults: [2, [0, 3], [0, 5], [0, 10]].
+        mlp_ratio (float): Ratio of the hidden feature dimension to the
+            input dimension in the MLP layer. Defaults: 4.0.
+        drop (float): Dropout rate of the MLP block.
+            Defaults: 0.0.
+ drop_path (float): The ratio of drop paths.
+ Defaults: 0.0.
+ act_cfg (dict): Config dict for activation layer in block.
+ Default: dict(type='GELU').
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults: dict(type='SyncBN', requires_grad=True).
+ """
+
+ def __init__(self,
+ channels,
+ attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+ attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+ mlp_ratio=4.,
+ drop=0.,
+ drop_path=0.,
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
+ super().__init__()
+ self.norm1 = build_norm_layer(norm_cfg, channels)[1]
+ self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
+ attention_kernel_paddings, act_cfg)
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = build_norm_layer(norm_cfg, channels)[1]
+ mlp_hidden_channels = int(channels * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=channels,
+ hidden_features=mlp_hidden_channels,
+ act_cfg=act_cfg,
+ drop=drop)
+ layer_scale_init_value = 1e-2
+ self.layer_scale_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones(channels), requires_grad=True)
+ self.layer_scale_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones(channels), requires_grad=True)
+
+ def forward(self, x, H, W):
+ """Forward function."""
+
+ B, N, C = x.shape
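+        # Reshape tokens (B, N, C) to a feature map (B, C, H, W) for the
+        # convolutional attention and MLP, then flatten back to tokens.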
+ x = x.permute(0, 2, 1).view(B, C, H, W)
+ x = x + self.drop_path(
+ self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
+ self.attn(self.norm1(x)))
+ x = x + self.drop_path(
+ self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
+ self.mlp(self.norm2(x)))
+ x = x.view(B, C, N).permute(0, 2, 1)
+ return x
+
+
+class OverlapPatchEmbed(BaseModule):
+ """Image to Patch Embedding.
+
+ Args:
+ patch_size (int): The patch size.
+ Defaults: 7.
+ stride (int): Stride of the convolutional layer.
+ Default: 4.
+ in_channels (int): The number of input channels.
+ Defaults: 3.
+        embed_dim (int): The embedding dimension.
+ Defaults: 768.
+ norm_cfg (dict): Config dict for normalization layer.
+ Defaults: dict(type='SyncBN', requires_grad=True).
+ """
+
+ def __init__(self,
+ patch_size=7,
+ stride=4,
+ in_channels=3,
+ embed_dim=768,
+ norm_cfg=dict(type='SyncBN', requires_grad=True)):
+ super().__init__()
+
+ self.proj = nn.Conv2d(
+ in_channels,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=patch_size // 2)
+ self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
+
+ def forward(self, x):
+ """Forward function."""
+
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = self.norm(x)
+
+ x = x.flatten(2).transpose(1, 2)
+
+ return x, H, W
+
+
+@MODELS.register_module()
+class MSCAN(BaseModule):
+ """SegNeXt Multi-Scale Convolutional Attention Network (MCSAN) backbone.
+
+ This backbone is the implementation of `SegNeXt: Rethinking
+ Convolutional Attention Design for Semantic
+ Segmentation `_.
+ Inspiration from https://github.com/visual-attention-network/segnext.
+
+ Args:
+ in_channels (int): The number of input channels. Defaults: 3.
+ embed_dims (list[int]): Embedding dimension.
+ Defaults: [64, 128, 256, 512].
+ mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
+ Defaults: [4, 4, 4, 4].
+ drop_rate (float): Dropout rate. Defaults: 0.
+ drop_path_rate (float): Stochastic depth rate. Defaults: 0.
+        depths (list[int]): Depths of each MSCAN stage.
+ Default: [3, 4, 6, 3].
+ num_stages (int): MSCAN stages. Default: 4.
+ attention_kernel_sizes (list): Size of attention kernel in
+ Attention Module (Figure 2(b) of original paper).
+ Defaults: [5, [1, 7], [1, 11], [1, 21]].
+ attention_kernel_paddings (list): Size of attention paddings
+ in Attention Module (Figure 2(b) of original paper).
+ Defaults: [2, [0, 3], [0, 5], [0, 10]].
+ norm_cfg (dict): Config of norm layers.
+ Defaults: dict(type='SyncBN', requires_grad=True).
+ pretrained (str, optional): model pretrained path.
+ Default: None.
+ init_cfg (dict or list[dict], optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels=3,
+ embed_dims=[64, 128, 256, 512],
+ mlp_ratios=[4, 4, 4, 4],
+ drop_rate=0.,
+ drop_path_rate=0.,
+ depths=[3, 4, 6, 3],
+ num_stages=4,
+ attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+ attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+ act_cfg=dict(type='GELU'),
+ norm_cfg=dict(type='SyncBN', requires_grad=True),
+ pretrained=None,
+ init_cfg=None):
+ super().__init__(init_cfg=init_cfg)
+
+ assert not (init_cfg and pretrained), \
+ 'init_cfg and pretrained cannot be set at the same time'
+ if isinstance(pretrained, str):
+ warnings.warn('DeprecationWarning: pretrained is deprecated, '
+ 'please use "init_cfg" instead')
+ self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+ elif pretrained is not None:
+ raise TypeError('pretrained must be a str or None')
+
+ self.depths = depths
+ self.num_stages = num_stages
+
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+ cur = 0
+
+ for i in range(num_stages):
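+            # Each stage: a downsampling layer (StemConv for the first stage,
+            # OverlapPatchEmbed afterwards), depths[i] MSCA blocks, and a
+            # LayerNorm applied to the flattened tokens.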
+ if i == 0:
+                patch_embed = StemConv(
+                    in_channels, embed_dims[0], norm_cfg=norm_cfg)
+ else:
+ patch_embed = OverlapPatchEmbed(
+ patch_size=7 if i == 0 else 3,
+ stride=4 if i == 0 else 2,
+ in_channels=in_channels if i == 0 else embed_dims[i - 1],
+ embed_dim=embed_dims[i],
+ norm_cfg=norm_cfg)
+
+ block = nn.ModuleList([
+ MSCABlock(
+ channels=embed_dims[i],
+ attention_kernel_sizes=attention_kernel_sizes,
+ attention_kernel_paddings=attention_kernel_paddings,
+ mlp_ratio=mlp_ratios[i],
+ drop=drop_rate,
+ drop_path=dpr[cur + j],
+ act_cfg=act_cfg,
+ norm_cfg=norm_cfg) for j in range(depths[i])
+ ])
+ norm = nn.LayerNorm(embed_dims[i])
+ cur += depths[i]
+
+ setattr(self, f'patch_embed{i + 1}', patch_embed)
+ setattr(self, f'block{i + 1}', block)
+ setattr(self, f'norm{i + 1}', norm)
+
+ def init_weights(self):
+ """Initialize modules of MSCAN."""
+
+ if self.init_cfg is None:
+ for m in self.modules():
+ if isinstance(m, nn.Linear):
+ trunc_normal_init(m, std=.02, bias=0.)
+ elif isinstance(m, nn.LayerNorm):
+ constant_init(m, val=1.0, bias=0.)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[
+ 1] * m.out_channels
+ fan_out //= m.groups
+ normal_init(
+ m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0)
+ else:
+ super().init_weights()
+
+ def forward(self, x):
+ """Forward function."""
+
+ B = x.shape[0]
+ outs = []
+
+ for i in range(self.num_stages):
+ patch_embed = getattr(self, f'patch_embed{i + 1}')
+ block = getattr(self, f'block{i + 1}')
+ norm = getattr(self, f'norm{i + 1}')
+ x, H, W = patch_embed(x)
+ for blk in block:
+ x = blk(x, H, W)
+ x = norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ return outs
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index e6eeafc248..18235456bc 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -12,6 +12,7 @@
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
+from .ham_head import LightHamHead
from .isa_head import ISAHead
from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator
from .lraspp_head import LRASPPHead
@@ -40,5 +41,5 @@
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
- 'PIDHead'
+ 'LightHamHead', 'PIDHead'
]
diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py
new file mode 100644
index 0000000000..d80025f77d
--- /dev/null
+++ b/mmseg/models/decode_heads/ham_head.py
@@ -0,0 +1,257 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Originally from https://github.com/visual-attention-network/segnext
+# Licensed under the Apache License, Version 2.0 (the "License")
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from mmseg.registry import MODELS
+from ..utils import resize
+from .decode_head import BaseDecodeHead
+
+
+class Matrix_Decomposition_2D_Base(nn.Module):
+ """Base class of 2D Matrix Decomposition.
+
+ Args:
+        MD_S (int): The number of spatial coefficients in
+            Matrix Decomposition; it determines the latent
+            dimension D of the decomposition (D = C // MD_S).
+            Defaults: 1.
+ MD_R (int): The number of latent dimension R in
+ Matrix Decomposition. Defaults: 64.
+ train_steps (int): The number of iteration steps in
+ Multiplicative Update (MU) rule to solve Non-negative
+ Matrix Factorization (NMF) in training. Defaults: 6.
+ eval_steps (int): The number of iteration steps in
+ Multiplicative Update (MU) rule to solve Non-negative
+ Matrix Factorization (NMF) in evaluation. Defaults: 7.
+        inv_t (int): Inverse temperature that scales the coefficients
+            before the softmax. Defaults: 100.
+ rand_init (bool): Whether to initialize randomly.
+ Defaults: True.
+ """
+
+ def __init__(self,
+ MD_S=1,
+ MD_R=64,
+ train_steps=6,
+ eval_steps=7,
+ inv_t=100,
+ rand_init=True):
+ super().__init__()
+
+ self.S = MD_S
+ self.R = MD_R
+
+ self.train_steps = train_steps
+ self.eval_steps = eval_steps
+
+ self.inv_t = inv_t
+
+ self.rand_init = rand_init
+
+ def _build_bases(self, B, S, D, R, cuda=False):
+ raise NotImplementedError
+
+ def local_step(self, x, bases, coef):
+ raise NotImplementedError
+
+ def local_inference(self, x, bases):
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ coef = torch.bmm(x.transpose(1, 2), bases)
+ coef = F.softmax(self.inv_t * coef, dim=-1)
+
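+        # Refine bases and coefficients with a fixed number of Multiplicative
+        # Update steps (train_steps in training, eval_steps in evaluation).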
+ steps = self.train_steps if self.training else self.eval_steps
+ for _ in range(steps):
+ bases, coef = self.local_step(x, bases, coef)
+
+ return bases, coef
+
+ def compute_coef(self, x, bases, coef):
+ raise NotImplementedError
+
+ def forward(self, x, return_bases=False):
+ """Forward Function."""
+ B, C, H, W = x.shape
+
+ # (B, C, H, W) -> (B * S, D, N)
+ D = C // self.S
+ N = H * W
+ x = x.view(B * self.S, D, N)
+ cuda = 'cuda' in str(x.device)
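+        # With rand_init=False the bases are built once and cached as a
+        # buffer; with rand_init=True they are re-sampled every forward pass.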
+ if not self.rand_init and not hasattr(self, 'bases'):
+ bases = self._build_bases(1, self.S, D, self.R, cuda=cuda)
+ self.register_buffer('bases', bases)
+
+ # (S, D, R) -> (B * S, D, R)
+ if self.rand_init:
+ bases = self._build_bases(B, self.S, D, self.R, cuda=cuda)
+ else:
+ bases = self.bases.repeat(B, 1, 1)
+
+ bases, coef = self.local_inference(x, bases)
+
+ # (B * S, N, R)
+ coef = self.compute_coef(x, bases, coef)
+
+ # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
+ x = torch.bmm(bases, coef.transpose(1, 2))
+
+ # (B * S, D, N) -> (B, C, H, W)
+ x = x.view(B, C, H, W)
+
+ return x
+
+
+class NMF2D(Matrix_Decomposition_2D_Base):
+ """Non-negative Matrix Factorization (NMF) module.
+
+ It is inherited from ``Matrix_Decomposition_2D_Base`` module.
+ """
+
+ def __init__(self, args=dict()):
+ super().__init__(**args)
+
+ self.inv_t = 1
+
+ def _build_bases(self, B, S, D, R, cuda=False):
+ """Build bases in initialization."""
+ if cuda:
+ bases = torch.rand((B * S, D, R)).cuda()
+ else:
+ bases = torch.rand((B * S, D, R))
+
+ bases = F.normalize(bases, dim=1)
+
+ return bases
+
+ def local_step(self, x, bases, coef):
+ """Local step in iteration to renew bases and coefficient."""
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # Multiplicative Update
+ coef = coef * numerator / (denominator + 1e-6)
+
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
+ numerator = torch.bmm(x, coef)
+ # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
+ denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
+ # Multiplicative Update
+ bases = bases * numerator / (denominator + 1e-6)
+
+ return bases, coef
+
+ def compute_coef(self, x, bases, coef):
+ """Compute coefficient."""
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # multiplication update
+ coef = coef * numerator / (denominator + 1e-6)
+
+ return coef
+
+
+class Hamburger(nn.Module):
+ """Hamburger Module. It consists of one slice of "ham" (matrix
+ decomposition) and two slices of "bread" (linear transformation).
+
+ Args:
+ ham_channels (int): Input and output channels of feature.
+ ham_kwargs (dict): Config of matrix decomposition module.
+ norm_cfg (dict | None): Config of norm layers.
+ """
+
+ def __init__(self,
+ ham_channels=512,
+ ham_kwargs=dict(),
+ norm_cfg=None,
+ **kwargs):
+ super().__init__()
+
+ self.ham_in = ConvModule(
+ ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
+
+ self.ham = NMF2D(ham_kwargs)
+
+ self.ham_out = ConvModule(
+ ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+
+ def forward(self, x):
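+        # Lower "bread" (1x1 conv) -> "ham" (matrix decomposition) -> upper
+        # "bread" (1x1 conv), then a residual connection followed by ReLU.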
+ enjoy = self.ham_in(x)
+ enjoy = F.relu(enjoy, inplace=True)
+ enjoy = self.ham(enjoy)
+ enjoy = self.ham_out(enjoy)
+ ham = F.relu(x + enjoy, inplace=True)
+
+ return ham
+
+
+@MODELS.register_module()
+class LightHamHead(BaseDecodeHead):
+ """SegNeXt decode head.
+
+    This decode head is the implementation of `SegNeXt: Rethinking
+    Convolutional Attention Design for Semantic
+    Segmentation <https://arxiv.org/abs/2209.08575>`_.
+ Inspiration from https://github.com/visual-attention-network/segnext.
+
+ Specifically, LightHamHead is inspired by HamNet from
+    `Is Attention Better Than Matrix Decomposition?
+    <https://arxiv.org/abs/2109.04553>`_.
+
+ Args:
+        ham_channels (int): Input channels of the Hamburger module.
+            Defaults: 512.
+        ham_kwargs (dict): Keyword arguments of the matrix decomposition
+            module in Hamburger. Defaults: dict().
+ """
+
+ def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
+ super().__init__(input_transform='multiple_select', **kwargs)
+ self.ham_channels = ham_channels
+
+ self.squeeze = ConvModule(
+ sum(self.in_channels),
+ self.ham_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
+
+ self.align = ConvModule(
+ self.ham_channels,
+ self.channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+
+ def forward(self, inputs):
+ """Forward function."""
+ inputs = self._transform_inputs(inputs)
+
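+        # Resize all selected feature levels to the spatial size of the first
+        # (highest-resolution) selected level before concatenation.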
+ inputs = [
+ resize(
+ level,
+ size=inputs[0].shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners) for level in inputs
+ ]
+
+ inputs = torch.cat(inputs, dim=1)
+ # apply a conv block to squeeze feature map
+ x = self.squeeze(inputs)
+ # apply hamburger module
+ x = self.hamburger(x)
+
+ # apply a conv block to align feature map
+ output = self.align(x)
+ output = self.cls_seg(output)
+ return output
diff --git a/model-index.yml b/model-index.yml
index be7210e120..130031a303 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -39,6 +39,7 @@ Import:
- configs/resnest/resnest.yml
- configs/segformer/segformer.yml
- configs/segmenter/segmenter.yml
+- configs/segnext/segnext.yml
- configs/sem_fpn/sem_fpn.yml
- configs/setr/setr.yml
- configs/stdc/stdc.yml
diff --git a/tests/test_models/test_backbones/test_mscan.py b/tests/test_models/test_backbones/test_mscan.py
new file mode 100644
index 0000000000..84dfb8e450
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mscan.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.backbones import MSCAN
+from mmseg.models.backbones.mscan import (MSCAAttention, MSCASpatialAttention,
+ OverlapPatchEmbed, StemConv)
+
+
+def test_mscan_backbone():
+ # Test MSCAN Standard Forward
+ model = MSCAN(
+ embed_dims=[8, 16, 32, 64],
+ norm_cfg=dict(type='BN', requires_grad=True))
+ model.init_weights()
+ model.train()
+ batch_size = 2
+ imgs = torch.randn(batch_size, 3, 64, 128)
+ feat = model(imgs)
+
+ assert len(feat) == 4
+ # output for segment Head
+ assert feat[0].shape == torch.Size([batch_size, 8, 16, 32])
+ assert feat[1].shape == torch.Size([batch_size, 16, 8, 16])
+ assert feat[2].shape == torch.Size([batch_size, 32, 4, 8])
+ assert feat[3].shape == torch.Size([batch_size, 64, 2, 4])
+
+ # Test input with rare shape
+ batch_size = 2
+ imgs = torch.randn(batch_size, 3, 95, 27)
+ feat = model(imgs)
+ assert len(feat) == 4
+
+
+def test_mscan_overlap_patch_embed_module():
+ x_overlap_patch_embed = OverlapPatchEmbed(
+ norm_cfg=dict(type='BN', requires_grad=True))
+ assert x_overlap_patch_embed.proj.in_channels == 3
+ assert x_overlap_patch_embed.norm.weight.shape == torch.Size([768])
+ x = torch.randn(2, 3, 16, 32)
+ x_out, H, W = x_overlap_patch_embed(x)
+ assert x_out.shape == torch.Size([2, 32, 768])
+
+
+def test_mscan_spatial_attention_module():
+ x_spatial_attention = MSCASpatialAttention(8)
+ assert x_spatial_attention.proj_1.kernel_size == (1, 1)
+ assert x_spatial_attention.proj_2.stride == (1, 1)
+ x = torch.randn(2, 8, 16, 32)
+ x_out = x_spatial_attention(x)
+ assert x_out.shape == torch.Size([2, 8, 16, 32])
+
+
+def test_mscan_attention_module():
+ x_attention = MSCAAttention(8)
+ assert x_attention.conv0.weight.shape[0] == 8
+ assert x_attention.conv3.kernel_size == (1, 1)
+ x = torch.randn(2, 8, 16, 32)
+ x_out = x_attention(x)
+ assert x_out.shape == torch.Size([2, 8, 16, 32])
+
+
+def test_mscan_stem_module():
+ x_stem = StemConv(8, 8, norm_cfg=dict(type='BN', requires_grad=True))
+ assert x_stem.proj[0].weight.shape[0] == 4
+ assert x_stem.proj[-1].weight.shape[0] == 8
+ x = torch.randn(2, 8, 16, 32)
+ x_out, H, W = x_stem(x)
+ assert x_out.shape == torch.Size([2, 32, 8])
+ assert (H, W) == (4, 8)
diff --git a/tests/test_models/test_heads/test_ham_head.py b/tests/test_models/test_heads/test_ham_head.py
new file mode 100644
index 0000000000..f802d2d8db
--- /dev/null
+++ b/tests/test_models/test_heads/test_ham_head.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.decode_heads import LightHamHead
+from .utils import _conv_has_norm, to_cuda
+
+ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True)
+
+
+def test_ham_head():
+
+ # test without sync_bn
+ head = LightHamHead(
+ in_channels=[16, 32, 64],
+ in_index=[1, 2, 3],
+ channels=64,
+ ham_channels=64,
+ dropout_ratio=0.1,
+ num_classes=19,
+ norm_cfg=ham_norm_cfg,
+ align_corners=False,
+ loss_decode=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ ham_kwargs=dict(
+ MD_S=1,
+ MD_R=64,
+ train_steps=6,
+ eval_steps=7,
+ inv_t=100,
+ rand_init=True))
+ assert not _conv_has_norm(head, sync_bn=False)
+
+ inputs = [
+ torch.randn(1, 8, 32, 32),
+ torch.randn(1, 16, 16, 16),
+ torch.randn(1, 32, 8, 8),
+ torch.randn(1, 64, 4, 4)
+ ]
+ if torch.cuda.is_available():
+ head, inputs = to_cuda(head, inputs)
+ assert head.in_channels == [16, 32, 64]
+ assert head.hamburger.ham_in.in_channels == 64
+ outputs = head(inputs)
+ assert outputs.shape == (1, head.num_classes, 16, 16)