diff --git a/README.md b/README.md index 69f751b9e5..a17557cb81 100644 --- a/README.md +++ b/README.md @@ -117,6 +117,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md). - [x] [ConvNeXt (CVPR'2022)](configs/convnext) - [x] [MAE (CVPR'2022)](configs/mae) - [x] [PoolFormer (CVPR'2022)](configs/poolformer) +- [x] [SegNeXt (NeurIPS'2022)](configs/segnext) diff --git a/README_zh-CN.md b/README_zh-CN.md index 709e6ef195..e09b515ab5 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -98,6 +98,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] [ConvNeXt (CVPR'2022)](configs/convnext) - [x] [MAE (CVPR'2022)](configs/mae) - [x] [PoolFormer (CVPR'2022)](configs/poolformer) +- [x] [SegNeXt (NeurIPS'2022)](configs/segnext) diff --git a/configs/segnext/README.md b/configs/segnext/README.md new file mode 100644 index 0000000000..315f4e23e8 --- /dev/null +++ b/configs/segnext/README.md @@ -0,0 +1,63 @@ +# SegNeXt + +> [SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation](https://arxiv.org/abs/2209.08575) + +## Introduction + + + +Official Repo + +Code Snippet + +## Abstract + + + +We present SegNeXt, a simple convolutional network architecture for semantic segmentation. Recent transformer-based models have dominated the field of semantic segmentation due to the efficiency of self-attention in encoding spatial information. In this paper, we show that convolutional attention is a more efficient and effective way to encode contextual information than the self-attention mechanism in transformers. By re-examining the characteristics owned by successful segmentation models, we discover several key components leading to the performance improvement of segmentation models. This motivates us to design a novel convolutional attention network that uses cheap convolutional operations. Without bells and whistles, our SegNeXt significantly improves the performance of previous state-of-the-art methods on popular benchmarks, including ADE20K, Cityscapes, COCO-Stuff, Pascal VOC, Pascal Context, and iSAID. Notably, SegNeXt outperforms EfficientNet-L2 w/ NAS-FPN and achieves 90.6% mIoU on the Pascal VOC 2012 test leaderboard using only 1/10 parameters of it. On average, SegNeXt achieves about 2.0% mIoU improvements compared to the state-of-the-art methods on the ADE20K datasets with the same or fewer computations. Code is available at [this https URL](https://github.com/uyzhang/JSeg) (Jittor) and [this https URL](https://github.com/Visual-Attention-Network/SegNeXt) (Pytorch). + + + +
+ +
+ +## Results and models + +### ADE20K + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ------- | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| SegNeXt | MSCAN-T | 512x512 | 160000 | 17.88 | 52.38 | 41.50 | 42.59 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244.log.json) | +| SegNeXt | MSCAN-S | 512x512 | 160000 | 21.47 | 42.27 | 44.16 | 45.81 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014.log.json) | +| SegNeXt | MSCAN-B | 512x512 | 160000 | 31.03 | 35.15 | 48.03 | 49.68 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053.log.json) | +| SegNeXt | MSCAN-L | 512x512 | 160000 | 43.32 | 22.91 | 50.99 | 52.10 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055.log.json) | + +Note: + +- When we integrated SegNeXt into MMSegmentation, we modified some layers' names to make them more precise and concise without changing the model architecture. Therefore, the keys of pre-trained weights are different from the [original weights](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/), but don't worry about these changes. 
We have converted them and uploaded the checkpoints; you can find the URLs of the pre-trained checkpoints in the config files and use them directly for training.
+
+- The total batch size is 16. We trained SegNeXt with a single GPU because the performance degrades significantly when using `SyncBN` (mainly in the `OverlapPatchEmbed` modules of `MSCAN`) with PyTorch 1.9.
+
+- There will be subtle differences during model testing because the Non-negative Matrix Factorization (NMF) in `LightHamHead` is initialized randomly. To control this randomness, please set the random seed when testing. You can modify [`./tools/test.py`](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/tools/test.py) like:
+
+```python
+def main():
+    from mmengine.runner import set_random_seed
+    random_seed = xxx  # set random seed recorded in training log
+    set_random_seed(random_seed, deterministic=False)
+    ...
+```
+
+- The model performance is sensitive to the random seed; please refer to the log file for the specific seed used in training. If you choose a different seed, the results might differ from those in the table. Taking SegNeXt Large as an example, its results range from 49.60 to 51.0.
+
+## Citation
+
+```bibtex
+@article{guo2022segnext,
+  title={SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation},
+  author={Guo, Meng-Hao and Lu, Cheng-Ze and Hou, Qibin and Liu, Zhengning and Cheng, Ming-Ming and Hu, Shi-Min},
+  journal={arXiv preprint arXiv:2209.08575},
+  year={2022}
+}
+```
diff --git a/configs/segnext/segnext.yml b/configs/segnext/segnext.yml
new file mode 100644
index 0000000000..3bcdea8403
--- /dev/null
+++ b/configs/segnext/segnext.yml
@@ -0,0 +1,103 @@
+Collections:
+- Name: SegNeXt
+  Metadata:
+    Training Data:
+    - ADE20K
+  Paper:
+    URL: https://arxiv.org/abs/2209.08575
+    Title: 'SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation'
+  README: configs/segnext/README.md
+  Code:
+    URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/mscan.py#L328
+    Version: dev-1.x
+  Converted From:
+    Code: https://github.com/visual-attention-network/segnext
+Models:
+- Name: segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512
+  In Collection: SegNeXt
+  Metadata:
+    backbone: MSCAN-T
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 19.09
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 17.88
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 41.5
+      mIoU(ms+flip): 42.59
+  Config: configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k/segnext_mscan-t_1x16_512x512_adamw_160k_ade20k_20230210_140244-05bd8466.pth
+- Name: segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512
+  In Collection: SegNeXt
+  Metadata:
+    backbone: MSCAN-S
+    crop size: (512,512)
+    lr schd: 160000
+    inference time (ms/im):
+    - value: 23.66
+      hardware: V100
+      backend: PyTorch
+      batch size: 1
+      mode: FP32
+      resolution: (512,512)
+    Training Memory (GB): 21.47
+  Results:
+  - Task: Semantic Segmentation
+    Dataset: ADE20K
+    Metrics:
+      mIoU: 44.16
+      mIoU(ms+flip): 45.81
+  Config: configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k/segnext_mscan-s_1x16_512x512_adamw_160k_ade20k_20230214_113014-43013668.pth
+- Name: 
segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512 + In Collection: SegNeXt + Metadata: + backbone: MSCAN-B + crop size: (512,512) + lr schd: 160000 + inference time (ms/im): + - value: 28.45 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 31.03 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 48.03 + mIoU(ms+flip): 49.68 + Config: configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k/segnext_mscan-b_1x16_512x512_adamw_160k_ade20k_20230209_172053-b6f6c70c.pth +- Name: segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512 + In Collection: SegNeXt + Metadata: + backbone: MSCAN-L + crop size: (512,512) + lr schd: 160000 + inference time (ms/im): + - value: 43.65 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,512) + Training Memory (GB): 43.32 + Results: + - Task: Semantic Segmentation + Dataset: ADE20K + Metrics: + mIoU: 50.99 + mIoU(ms+flip): 52.1 + Config: configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/segnext/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k/segnext_mscan-l_1x16_512x512_adamw_160k_ade20k_20230209_172055-19b14b63.pth diff --git a/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py new file mode 100644 index 0000000000..000f448483 --- /dev/null +++ b/configs/segnext/segnext_mscan-b_1xb16-adamw-160k_ade20k-512x512.py @@ -0,0 +1,28 @@ +_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py' + +# model settings +checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth' # noqa +ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + embed_dims=[64, 128, 320, 512], + depths=[3, 3, 12, 3], + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + drop_path_rate=0.1, + norm_cfg=dict(type='BN', requires_grad=True)), + decode_head=dict( + type='LightHamHead', + in_channels=[128, 320, 512], + in_index=[1, 2, 3], + channels=512, + ham_channels=512, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=ham_norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py new file mode 100644 index 0000000000..212d0a8557 --- /dev/null +++ b/configs/segnext/segnext_mscan-l_1xb16-adamw-160k_ade20k-512x512.py @@ -0,0 +1,27 @@ +_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py' +# model settings +checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_l_20230227-cef260d4.pth' # noqa +ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + embed_dims=[64, 128, 320, 512], + depths=[3, 5, 27, 3], + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + drop_path_rate=0.3, + norm_cfg=dict(type='BN', requires_grad=True)), + decode_head=dict( + type='LightHamHead', + in_channels=[128, 320, 512], + in_index=[1, 2, 3], + channels=1024, + 
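+        # Note: `ham_channels` is the feature width inside the Hamburger
+        # (matrix decomposition) module, while `channels` is the width after
+        # the align conv that feeds the classifier (see LightHamHead).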
ham_channels=1024, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=ham_norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py new file mode 100644 index 0000000000..9a90779a60 --- /dev/null +++ b/configs/segnext/segnext_mscan-s_1xb16-adamw-160k_ade20k-512x512.py @@ -0,0 +1,27 @@ +_base_ = './segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py' +# model settings +checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth' # noqa +ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + embed_dims=[64, 128, 320, 512], + depths=[2, 2, 4, 2], + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + norm_cfg=dict(type='BN', requires_grad=True)), + decode_head=dict( + type='LightHamHead', + in_channels=[128, 320, 512], + in_index=[1, 2, 3], + channels=256, + ham_channels=256, + ham_kwargs=dict(MD_R=16), + dropout_ratio=0.1, + num_classes=150, + norm_cfg=ham_norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py b/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py new file mode 100644 index 0000000000..c8d6da85ff --- /dev/null +++ b/configs/segnext/segnext_mscan-t_1xb16-adamw-160k_ade20k-512x512.py @@ -0,0 +1,84 @@ +_base_ = [ + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py', + '../_base_/datasets/ade20k.py' +] +# model settings +checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth' # noqa +ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) +crop_size = (512, 512) +data_preprocessor = dict( + type='SegDataPreProcessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_val=0, + seg_pad_val=255, + size=(512, 512), + test_cfg=dict(size_divisor=32)) +model = dict( + type='EncoderDecoder', + data_preprocessor=data_preprocessor, + pretrained=None, + backbone=dict( + type='MSCAN', + init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file), + embed_dims=[32, 64, 160, 256], + mlp_ratios=[8, 8, 4, 4], + drop_rate=0.0, + drop_path_rate=0.1, + depths=[3, 3, 5, 2], + attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]], + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='BN', requires_grad=True)), + decode_head=dict( + type='LightHamHead', + in_channels=[64, 160, 256], + in_index=[1, 2, 3], + channels=256, + ham_channels=256, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=ham_norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + ham_kwargs=dict( + MD_S=1, + MD_R=16, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) + +# dataset settings +train_dataloader = dict(batch_size=16) + +# optimizer +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + 
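+    # AdamW with per-parameter tweaks: no weight decay on norm layers and
+    # positional blocks, and a 10x learning rate on the decode head
+    # (see `paramwise_cfg` below).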
optimizer=dict( + type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01), + paramwise_cfg=dict( + custom_keys={ + 'pos_block': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.), + 'head': dict(lr_mult=10.) + })) + +param_scheduler = [ + dict( + type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), + dict( + type='PolyLR', + power=1.0, + begin=1500, + end=160000, + eta_min=0.0, + by_epoch=False, + ) +] diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 909b54f3ec..e3107306ea 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -11,6 +11,7 @@ from .mit import MixVisionTransformer from .mobilenet_v2 import MobileNetV2 from .mobilenet_v3 import MobileNetV3 +from .mscan import MSCAN from .pidnet import PIDNet from .resnest import ResNeSt from .resnet import ResNet, ResNetV1c, ResNetV1d @@ -27,5 +28,5 @@ 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT', - 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet' + 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN' ] diff --git a/mmseg/models/backbones/mscan.py b/mmseg/models/backbones/mscan.py new file mode 100644 index 0000000000..7150cb7a1c --- /dev/null +++ b/mmseg/models/backbones/mscan.py @@ -0,0 +1,467 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_norm_layer +from mmcv.cnn.bricks import DropPath +from mmengine.model import BaseModule +from mmengine.model.weight_init import (constant_init, normal_init, + trunc_normal_init) + +from mmseg.registry import MODELS + + +class Mlp(BaseModule): + """Multi Layer Perceptron (MLP) Module. + + Args: + in_features (int): The dimension of input features. + hidden_features (int): The dimension of hidden features. + Defaults: None. + out_features (int): The dimension of output features. + Defaults: None. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + drop (float): The number of dropout rate in MLP block. + Defaults: 0.0. + """ + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_cfg=dict(type='GELU'), + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = nn.Conv2d( + hidden_features, + hidden_features, + 3, + 1, + 1, + bias=True, + groups=hidden_features) + self.act = build_activation_layer(act_cfg) + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward function.""" + + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): The dimension of input channels. + out_channels (int): The dimension of output channels. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). + norm_cfg (dict): Config dict for normalization layer. 
+ Defaults: dict(type='SyncBN', requires_grad=True). + """ + + def __init__(self, + in_channels, + out_channels, + act_cfg=dict(type='GELU'), + norm_cfg=dict(type='SyncBN', requires_grad=True)): + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels // 2, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels // 2)[1], + build_activation_layer(act_cfg), + nn.Conv2d( + out_channels // 2, + out_channels, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1)), + build_norm_layer(norm_cfg, out_channels)[1], + ) + + def forward(self, x): + """Forward function.""" + + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAAttention(BaseModule): + """Attention Module in Multi-Scale Convolutional Attention Module (MSCA). + + Args: + channels (int): The dimension of channels. + kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + """ + + def __init__(self, + channels, + kernel_sizes=[5, [1, 7], [1, 11], [1, 21]], + paddings=[2, [0, 3], [0, 5], [0, 10]]): + super().__init__() + self.conv0 = nn.Conv2d( + channels, + channels, + kernel_size=kernel_sizes[0], + padding=paddings[0], + groups=channels) + for i, (kernel_size, + padding) in enumerate(zip(kernel_sizes[1:], paddings[1:])): + kernel_size_ = [kernel_size, kernel_size[::-1]] + padding_ = [padding, padding[::-1]] + conv_name = [f'conv{i}_1', f'conv{i}_2'] + for i_kernel, i_pad, i_conv in zip(kernel_size_, padding_, + conv_name): + self.add_module( + i_conv, + nn.Conv2d( + channels, + channels, + tuple(i_kernel), + padding=i_pad, + groups=channels)) + self.conv3 = nn.Conv2d(channels, channels, 1) + + def forward(self, x): + """Forward function.""" + + u = x.clone() + + attn = self.conv0(x) + + # Multi-Scale Feature extraction + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + + attn = attn + attn_0 + attn_1 + attn_2 + # Channel Mixing + attn = self.conv3(attn) + + # Convolutional Attention + x = attn * u + + return x + + +class MSCASpatialAttention(BaseModule): + """Spatial Attention Module in Multi-Scale Convolutional Attention Module + (MSCA). + + Args: + in_channels (int): The dimension of channels. + attention_kernel_sizes (list): The size of attention + kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]]. + attention_kernel_paddings (list): The number of + corresponding padding value in attention module. + Defaults: [2, [0, 3], [0, 5], [0, 10]]. + act_cfg (dict): Config dict for activation layer in block. + Default: dict(type='GELU'). 
+    """
+
+    def __init__(self,
+                 in_channels,
+                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+                 act_cfg=dict(type='GELU')):
+        super().__init__()
+        self.proj_1 = nn.Conv2d(in_channels, in_channels, 1)
+        self.activation = build_activation_layer(act_cfg)
+        self.spatial_gating_unit = MSCAAttention(in_channels,
+                                                 attention_kernel_sizes,
+                                                 attention_kernel_paddings)
+        self.proj_2 = nn.Conv2d(in_channels, in_channels, 1)
+
+    def forward(self, x):
+        """Forward function."""
+
+        shortcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+        return x
+
+
+class MSCABlock(BaseModule):
+    """Basic Multi-Scale Convolutional Attention Block. It leverages the
+    large-kernel attention (LKA) mechanism to build both channel and spatial
+    attention. In each branch, it uses two depth-wise strip convolutions to
+    approximate standard depth-wise convolutions with large kernels. The
+    kernel size for each branch is set to 7, 11, and 21, respectively.
+
+    Args:
+        channels (int): The dimension of channels.
+        attention_kernel_sizes (list): The size of attention
+            kernel. Defaults: [5, [1, 7], [1, 11], [1, 21]].
+        attention_kernel_paddings (list): The corresponding padding values of
+            the attention kernels. Defaults: [2, [0, 3], [0, 5], [0, 10]].
+        mlp_ratio (float): The ratio of the hidden feature dimension to the
+            input dimension in the MLP layer. Defaults: 4.0.
+        drop (float): Dropout rate of the MLP block. Defaults: 0.0.
+        drop_path (float): Drop path (stochastic depth) rate. Defaults: 0.0.
+        act_cfg (dict): Config dict for activation layer in block.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults: dict(type='SyncBN', requires_grad=True).
+    """
+
+    def __init__(self,
+                 channels,
+                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+                 mlp_ratio=4.,
+                 drop=0.,
+                 drop_path=0.,
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
+        super().__init__()
+        self.norm1 = build_norm_layer(norm_cfg, channels)[1]
+        self.attn = MSCASpatialAttention(channels, attention_kernel_sizes,
+                                         attention_kernel_paddings, act_cfg)
+        self.drop_path = DropPath(
+            drop_path) if drop_path > 0. else nn.Identity()
+        self.norm2 = build_norm_layer(norm_cfg, channels)[1]
+        mlp_hidden_channels = int(channels * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=channels,
+            hidden_features=mlp_hidden_channels,
+            act_cfg=act_cfg,
+            drop=drop)
+        layer_scale_init_value = 1e-2
+        self.layer_scale_1 = nn.Parameter(
+            layer_scale_init_value * torch.ones(channels), requires_grad=True)
+        self.layer_scale_2 = nn.Parameter(
+            layer_scale_init_value * torch.ones(channels), requires_grad=True)
+
+    def forward(self, x, H, W):
+        """Forward function."""
+
+        B, N, C = x.shape
+        x = x.permute(0, 2, 1).view(B, C, H, W)
+        x = x + self.drop_path(
+            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) *
+            self.attn(self.norm1(x)))
+        x = x + self.drop_path(
+            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) *
+            self.mlp(self.norm2(x)))
+        x = x.view(B, C, N).permute(0, 2, 1)
+        return x
+
+
+class OverlapPatchEmbed(BaseModule):
+    """Image to Patch Embedding.
+
+    Args:
+        patch_size (int): The patch size.
+            Defaults: 7.
+        stride (int): Stride of the convolutional layer.
+            Default: 4.
+        in_channels (int): The number of input channels.
+            Defaults: 3.
+        embed_dim (int): The dimension of the embedding.
+            Defaults: 768.
+        norm_cfg (dict): Config dict for normalization layer.
+            Defaults: dict(type='SyncBN', requires_grad=True).
+    """
+
+    def __init__(self,
+                 patch_size=7,
+                 stride=4,
+                 in_channels=3,
+                 embed_dim=768,
+                 norm_cfg=dict(type='SyncBN', requires_grad=True)):
+        super().__init__()
+
+        self.proj = nn.Conv2d(
+            in_channels,
+            embed_dim,
+            kernel_size=patch_size,
+            stride=stride,
+            padding=patch_size // 2)
+        self.norm = build_norm_layer(norm_cfg, embed_dim)[1]
+
+    def forward(self, x):
+        """Forward function."""
+
+        x = self.proj(x)
+        _, _, H, W = x.shape
+        x = self.norm(x)
+
+        x = x.flatten(2).transpose(1, 2)
+
+        return x, H, W
+
+
+@MODELS.register_module()
+class MSCAN(BaseModule):
+    """SegNeXt Multi-Scale Convolutional Attention Network (MSCAN) backbone.
+
+    This backbone is the implementation of `SegNeXt: Rethinking
+    Convolutional Attention Design for Semantic
+    Segmentation <https://arxiv.org/abs/2209.08575>`_.
+    Inspiration from https://github.com/visual-attention-network/segnext.
+
+    Args:
+        in_channels (int): The number of input channels. Defaults: 3.
+        embed_dims (list[int]): Embedding dimension of each stage.
+            Defaults: [64, 128, 256, 512].
+        mlp_ratios (list[int]): Ratio of mlp hidden dim to embedding dim.
+            Defaults: [4, 4, 4, 4].
+        drop_rate (float): Dropout rate. Defaults: 0.
+        drop_path_rate (float): Stochastic depth rate. Defaults: 0.
+        depths (list[int]): Depths of each MSCAN stage.
+            Default: [3, 4, 6, 3].
+        num_stages (int): The number of MSCAN stages. Default: 4.
+        attention_kernel_sizes (list): Size of attention kernel in
+            Attention Module (Figure 2(b) of original paper).
+            Defaults: [5, [1, 7], [1, 11], [1, 21]].
+        attention_kernel_paddings (list): Size of attention paddings
+            in Attention Module (Figure 2(b) of original paper).
+            Defaults: [2, [0, 3], [0, 5], [0, 10]].
+        act_cfg (dict): Config dict for activation layer in block.
+            Default: dict(type='GELU').
+        norm_cfg (dict): Config of norm layers.
+            Defaults: dict(type='SyncBN', requires_grad=True).
+        pretrained (str, optional): Model pretrained path.
+            Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
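+
+    Example (shapes mirror ``tests/test_models/test_backbones/test_mscan.py``):
+
+        >>> import torch
+        >>> from mmseg.models.backbones import MSCAN
+        >>> model = MSCAN(embed_dims=[8, 16, 32, 64],
+        ...               norm_cfg=dict(type='BN', requires_grad=True))
+        >>> outs = model(torch.randn(2, 3, 64, 128))
+        >>> [tuple(o.shape) for o in outs]
+        [(2, 8, 16, 32), (2, 16, 8, 16), (2, 32, 4, 8), (2, 64, 2, 4)]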
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 embed_dims=[64, 128, 256, 512],
+                 mlp_ratios=[4, 4, 4, 4],
+                 drop_rate=0.,
+                 drop_path_rate=0.,
+                 depths=[3, 4, 6, 3],
+                 num_stages=4,
+                 attention_kernel_sizes=[5, [1, 7], [1, 11], [1, 21]],
+                 attention_kernel_paddings=[2, [0, 3], [0, 5], [0, 10]],
+                 act_cfg=dict(type='GELU'),
+                 norm_cfg=dict(type='SyncBN', requires_grad=True),
+                 pretrained=None,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+
+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be set at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is not None:
+            raise TypeError('pretrained must be a str or None')
+
+        self.depths = depths
+        self.num_stages = num_stages
+
+        dpr = [
+            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+        ]  # stochastic depth decay rule
+        cur = 0
+
+        for i in range(num_stages):
+            if i == 0:
+                patch_embed = StemConv(
+                    in_channels, embed_dims[0], norm_cfg=norm_cfg)
+            else:
+                patch_embed = OverlapPatchEmbed(
+                    patch_size=7 if i == 0 else 3,
+                    stride=4 if i == 0 else 2,
+                    in_channels=in_channels if i == 0 else embed_dims[i - 1],
+                    embed_dim=embed_dims[i],
+                    norm_cfg=norm_cfg)
+
+            block = nn.ModuleList([
+                MSCABlock(
+                    channels=embed_dims[i],
+                    attention_kernel_sizes=attention_kernel_sizes,
+                    attention_kernel_paddings=attention_kernel_paddings,
+                    mlp_ratio=mlp_ratios[i],
+                    drop=drop_rate,
+                    drop_path=dpr[cur + j],
+                    act_cfg=act_cfg,
+                    norm_cfg=norm_cfg) for j in range(depths[i])
+            ])
+            norm = nn.LayerNorm(embed_dims[i])
+            cur += depths[i]
+
+            setattr(self, f'patch_embed{i + 1}', patch_embed)
+            setattr(self, f'block{i + 1}', block)
+            setattr(self, f'norm{i + 1}', norm)
+
+    def init_weights(self):
+        """Initialize modules of MSCAN."""
+
+        if self.init_cfg is None:
+            for m in self.modules():
+                if isinstance(m, nn.Linear):
+                    trunc_normal_init(m, std=.02, bias=0.)
+                elif isinstance(m, nn.LayerNorm):
+                    constant_init(m, val=1.0, bias=0.)
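+                # Conv layers use a zero-mean normal init with
+                # std = sqrt(2 / fan_out), i.e. Kaiming/He init in
+                # fan-out mode.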
+ elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[ + 1] * m.out_channels + fan_out //= m.groups + normal_init( + m, mean=0, std=math.sqrt(2.0 / fan_out), bias=0) + else: + super().init_weights() + + def forward(self, x): + """Forward function.""" + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f'patch_embed{i + 1}') + block = getattr(self, f'block{i + 1}') + norm = getattr(self, f'norm{i + 1}') + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return outs diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index e6eeafc248..18235456bc 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -12,6 +12,7 @@ from .fcn_head import FCNHead from .fpn_head import FPNHead from .gc_head import GCHead +from .ham_head import LightHamHead from .isa_head import ISAHead from .knet_head import IterativeDecodeHead, KernelUpdateHead, KernelUpdator from .lraspp_head import LRASPPHead @@ -40,5 +41,5 @@ 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead', 'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead', 'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead', - 'PIDHead' + 'LightHamHead', 'PIDHead' ] diff --git a/mmseg/models/decode_heads/ham_head.py b/mmseg/models/decode_heads/ham_head.py new file mode 100644 index 0000000000..d80025f77d --- /dev/null +++ b/mmseg/models/decode_heads/ham_head.py @@ -0,0 +1,257 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Originally from https://github.com/visual-attention-network/segnext +# Licensed under the Apache License, Version 2.0 (the "License") +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import ConvModule + +from mmseg.registry import MODELS +from ..utils import resize +from .decode_head import BaseDecodeHead + + +class Matrix_Decomposition_2D_Base(nn.Module): + """Base class of 2D Matrix Decomposition. + + Args: + MD_S (int): The number of spatial coefficient in + Matrix Decomposition, it may be used for calculation + of the number of latent dimension D in Matrix + Decomposition. Defaults: 1. + MD_R (int): The number of latent dimension R in + Matrix Decomposition. Defaults: 64. + train_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in training. Defaults: 6. + eval_steps (int): The number of iteration steps in + Multiplicative Update (MU) rule to solve Non-negative + Matrix Factorization (NMF) in evaluation. Defaults: 7. + inv_t (int): Inverted multiple number to make coefficient + smaller in softmax. Defaults: 100. + rand_init (bool): Whether to initialize randomly. + Defaults: True. 
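+
+    The forward pass reshapes the input to ``(B * MD_S, D, N)`` with
+    ``D = C // MD_S`` and ``N = H * W``, and approximates it by the product
+    ``bases @ coef^T``, where ``bases`` has shape ``(B * MD_S, D, R)`` and
+    ``coef`` has shape ``(B * MD_S, N, R)``. Subclasses such as ``NMF2D``
+    refine both factors with multiplicative updates, e.g.
+    ``coef <- coef * (x^T @ bases) / (coef @ bases^T @ bases)``.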
+ """ + + def __init__(self, + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True): + super().__init__() + + self.S = MD_S + self.R = MD_R + + self.train_steps = train_steps + self.eval_steps = eval_steps + + self.inv_t = inv_t + + self.rand_init = rand_init + + def _build_bases(self, B, S, D, R, cuda=False): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + """Forward Function.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + cuda = 'cuda' in str(x.device) + if not self.rand_init and not hasattr(self, 'bases'): + bases = self._build_bases(1, self.S, D, self.R, cuda=cuda) + self.register_buffer('bases', bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, cuda=cuda) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + + return x + + +class NMF2D(Matrix_Decomposition_2D_Base): + """Non-negative Matrix Factorization (NMF) module. + + It is inherited from ``Matrix_Decomposition_2D_Base`` module. + """ + + def __init__(self, args=dict()): + super().__init__(**args) + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, cuda=False): + """Build bases in initialization.""" + if cuda: + bases = torch.rand((B * S, D, R)).cuda() + else: + bases = torch.rand((B * S, D, R)) + + bases = F.normalize(bases, dim=1) + + return bases + + def local_step(self, x, bases, coef): + """Local step in iteration to renew bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + """Hamburger Module. It consists of one slice of "ham" (matrix + decomposition) and two slices of "bread" (linear transformation). + + Args: + ham_channels (int): Input and output channels of feature. 
+        ham_kwargs (dict): Config of matrix decomposition module.
+        norm_cfg (dict | None): Config of norm layers.
+    """
+
+    def __init__(self,
+                 ham_channels=512,
+                 ham_kwargs=dict(),
+                 norm_cfg=None,
+                 **kwargs):
+        super().__init__()
+
+        self.ham_in = ConvModule(
+            ham_channels, ham_channels, 1, norm_cfg=None, act_cfg=None)
+
+        self.ham = NMF2D(ham_kwargs)
+
+        self.ham_out = ConvModule(
+            ham_channels, ham_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+
+    def forward(self, x):
+        enjoy = self.ham_in(x)
+        enjoy = F.relu(enjoy, inplace=True)
+        enjoy = self.ham(enjoy)
+        enjoy = self.ham_out(enjoy)
+        ham = F.relu(x + enjoy, inplace=True)
+
+        return ham
+
+
+@MODELS.register_module()
+class LightHamHead(BaseDecodeHead):
+    """SegNeXt decode head.
+
+    This decode head is the implementation of `SegNeXt: Rethinking
+    Convolutional Attention Design for Semantic
+    Segmentation <https://arxiv.org/abs/2209.08575>`_.
+    Inspiration from https://github.com/visual-attention-network/segnext.
+
+    Specifically, LightHamHead is inspired by HamNet from
+    `Is Attention Better Than Matrix Decomposition?
+    <https://arxiv.org/abs/2109.04553>`_.
+
+    Args:
+        ham_channels (int): Input channels of the Hamburger module.
+            Defaults: 512.
+        ham_kwargs (dict): Keyword arguments of the Hamburger (matrix
+            decomposition) module. Defaults: dict().
+    """
+
+    def __init__(self, ham_channels=512, ham_kwargs=dict(), **kwargs):
+        super().__init__(input_transform='multiple_select', **kwargs)
+        self.ham_channels = ham_channels
+
+        self.squeeze = ConvModule(
+            sum(self.in_channels),
+            self.ham_channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.hamburger = Hamburger(ham_channels, ham_kwargs, **kwargs)
+
+        self.align = ConvModule(
+            self.ham_channels,
+            self.channels,
+            1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def forward(self, inputs):
+        """Forward function."""
+        inputs = self._transform_inputs(inputs)
+
+        inputs = [
+            resize(
+                level,
+                size=inputs[0].shape[2:],
+                mode='bilinear',
+                align_corners=self.align_corners) for level in inputs
+        ]
+
+        inputs = torch.cat(inputs, dim=1)
+        # apply a conv block to squeeze feature map
+        x = self.squeeze(inputs)
+        # apply hamburger module
+        x = self.hamburger(x)
+
+        # apply a conv block to align feature map
+        output = self.align(x)
+        output = self.cls_seg(output)
+        return output
diff --git a/model-index.yml b/model-index.yml
index be7210e120..130031a303 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -39,6 +39,7 @@ Import:
 - configs/resnest/resnest.yml
 - configs/segformer/segformer.yml
 - configs/segmenter/segmenter.yml
+- configs/segnext/segnext.yml
 - configs/sem_fpn/sem_fpn.yml
 - configs/setr/setr.yml
 - configs/stdc/stdc.yml
diff --git a/tests/test_models/test_backbones/test_mscan.py b/tests/test_models/test_backbones/test_mscan.py
new file mode 100644
index 0000000000..84dfb8e450
--- /dev/null
+++ b/tests/test_models/test_backbones/test_mscan.py
@@ -0,0 +1,69 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch + +from mmseg.models.backbones import MSCAN +from mmseg.models.backbones.mscan import (MSCAAttention, MSCASpatialAttention, + OverlapPatchEmbed, StemConv) + + +def test_mscan_backbone(): + # Test MSCAN Standard Forward + model = MSCAN( + embed_dims=[8, 16, 32, 64], + norm_cfg=dict(type='BN', requires_grad=True)) + model.init_weights() + model.train() + batch_size = 2 + imgs = torch.randn(batch_size, 3, 64, 128) + feat = model(imgs) + + assert len(feat) == 4 + # output for segment Head + assert feat[0].shape == torch.Size([batch_size, 8, 16, 32]) + assert feat[1].shape == torch.Size([batch_size, 16, 8, 16]) + assert feat[2].shape == torch.Size([batch_size, 32, 4, 8]) + assert feat[3].shape == torch.Size([batch_size, 64, 2, 4]) + + # Test input with rare shape + batch_size = 2 + imgs = torch.randn(batch_size, 3, 95, 27) + feat = model(imgs) + assert len(feat) == 4 + + +def test_mscan_overlap_patch_embed_module(): + x_overlap_patch_embed = OverlapPatchEmbed( + norm_cfg=dict(type='BN', requires_grad=True)) + assert x_overlap_patch_embed.proj.in_channels == 3 + assert x_overlap_patch_embed.norm.weight.shape == torch.Size([768]) + x = torch.randn(2, 3, 16, 32) + x_out, H, W = x_overlap_patch_embed(x) + assert x_out.shape == torch.Size([2, 32, 768]) + + +def test_mscan_spatial_attention_module(): + x_spatial_attention = MSCASpatialAttention(8) + assert x_spatial_attention.proj_1.kernel_size == (1, 1) + assert x_spatial_attention.proj_2.stride == (1, 1) + x = torch.randn(2, 8, 16, 32) + x_out = x_spatial_attention(x) + assert x_out.shape == torch.Size([2, 8, 16, 32]) + + +def test_mscan_attention_module(): + x_attention = MSCAAttention(8) + assert x_attention.conv0.weight.shape[0] == 8 + assert x_attention.conv3.kernel_size == (1, 1) + x = torch.randn(2, 8, 16, 32) + x_out = x_attention(x) + assert x_out.shape == torch.Size([2, 8, 16, 32]) + + +def test_mscan_stem_module(): + x_stem = StemConv(8, 8, norm_cfg=dict(type='BN', requires_grad=True)) + assert x_stem.proj[0].weight.shape[0] == 4 + assert x_stem.proj[-1].weight.shape[0] == 8 + x = torch.randn(2, 8, 16, 32) + x_out, H, W = x_stem(x) + assert x_out.shape == torch.Size([2, 32, 8]) + assert (H, W) == (4, 8) diff --git a/tests/test_models/test_heads/test_ham_head.py b/tests/test_models/test_heads/test_ham_head.py new file mode 100644 index 0000000000..f802d2d8db --- /dev/null +++ b/tests/test_models/test_heads/test_ham_head.py @@ -0,0 +1,44 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch + +from mmseg.models.decode_heads import LightHamHead +from .utils import _conv_has_norm, to_cuda + +ham_norm_cfg = dict(type='GN', num_groups=32, requires_grad=True) + + +def test_ham_head(): + + # test without sync_bn + head = LightHamHead( + in_channels=[16, 32, 64], + in_index=[1, 2, 3], + channels=64, + ham_channels=64, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=ham_norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + ham_kwargs=dict( + MD_S=1, + MD_R=64, + train_steps=6, + eval_steps=7, + inv_t=100, + rand_init=True)) + assert not _conv_has_norm(head, sync_bn=False) + + inputs = [ + torch.randn(1, 8, 32, 32), + torch.randn(1, 16, 16, 16), + torch.randn(1, 32, 8, 8), + torch.randn(1, 64, 4, 4) + ] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + assert head.in_channels == [16, 32, 64] + assert head.hamburger.ham_in.in_channels == 64 + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 16, 16)