diff --git a/README.md b/README.md index 533e39d740..17cd7a8edf 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,7 @@ Supported methods: - [x] [PointRend (CVPR'2020)](configs/point_rend) - [x] [CGNet (TIP'2020)](configs/cgnet) - [x] [SETR (CVPR'2021)](configs/setr) +- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2) - [x] [SegFormer (ArXiv'2021)](configs/segformer) Supported datasets: diff --git a/README_zh-CN.md b/README_zh-CN.md index 17bde8a6c0..0889ce11a4 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -93,6 +93,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O - [x] [PointRend (CVPR'2020)](configs/point_rend) - [x] [CGNet (TIP'2020)](configs/cgnet) - [x] [SETR (CVPR'2021)](configs/setr) +- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2) - [x] [SegFormer (ArXiv'2021)](configs/segformer) 已支持的数据集: diff --git a/configs/_base_/datasets/cityscapes_1024x1024.py b/configs/_base_/datasets/cityscapes_1024x1024.py new file mode 100644 index 0000000000..f98d929723 --- /dev/null +++ b/configs/_base_/datasets/cityscapes_1024x1024.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (1024, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/configs/_base_/models/bisenetv2.py b/configs/_base_/models/bisenetv2.py new file mode 100644 index 0000000000..f8fffeecad --- /dev/null +++ b/configs/_base_/models/bisenetv2.py @@ -0,0 +1,80 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='BiSeNetV2', + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + init_cfg=None, + align_corners=False), + decode_head=dict( + type='FCNHead', + in_channels=128, + in_index=0, + channels=1024, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=16, + channels=16, + num_convs=2, + num_classes=19, + in_index=1, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=32, + channels=64, + num_convs=2, + num_classes=19, + in_index=2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + 
loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=64, + channels=256, + num_convs=2, + num_classes=19, + in_index=3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='FCNHead', + in_channels=128, + channels=1024, + num_convs=2, + num_classes=19, + in_index=4, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/configs/bisenetv2/README.md b/configs/bisenetv2/README.md new file mode 100644 index 0000000000..48ecf06557 --- /dev/null +++ b/configs/bisenetv2/README.md @@ -0,0 +1,33 @@ +# Bisenet v2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation + +## Introduction + + + +```latex +@article{yu2021bisenet, + title={Bisenet v2: Bilateral network with guided aggregation for real-time semantic segmentation}, + author={Yu, Changqian and Gao, Changxin and Wang, Jingbo and Yu, Gang and Shen, Chunhua and Sang, Nong}, + journal={International Journal of Computer Vision}, + pages={1--18}, + year={2021}, + publisher={Springer} +} +``` + +## Results and models + +### Cityscapes + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| BiSeNetV2 | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | 31.77 | 73.21 | 75.74 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551.log.json) | +| BiSeNetV2 (OHEM) | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | - | 73.57 | 75.80 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947.log.json) | +| BiSeNetV2 (4x8) | BiSeNetV2 | 1024x1024 | 160000 | 15.05 | - | 75.76 | 77.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032.log.json) |
+| BiSeNetV2 (FP16) | BiSeNetV2 | 1024x1024 | 160000 | 5.77 | 36.65 | 73.07 | 75.13 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942.log.json) |
+
+Note:
+
+- `OHEM` means Online Hard Example Mining (OHEM) is adopted in training.
+- `FP16` means Mixed Precision (FP16) is adopted in training.
+- `4x8` means 4 GPUs with 8 samples per GPU in training.
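+
+As a quick sanity check, the checkpoints above can be loaded with the standard
+MMSegmentation inference API. The snippet below is a minimal sketch (not part
+of the benchmark tooling): it assumes the checkpoint file has been downloaded
+locally and that a CUDA device is available; `demo/demo.png` is the sample
+image shipped with this repository.
+
+```python
+from mmseg.apis import inference_segmentor, init_segmentor
+
+config_file = 'configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py'
+# local copy of the checkpoint linked in the table above
+checkpoint_file = 'bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth'
+
+# build the segmentor from the config and load the trained weights
+model = init_segmentor(config_file, checkpoint_file, device='cuda:0')
+# run inference on a single image; returns a list with one per-pixel label map
+result = inference_segmentor(model, 'demo/demo.png')
+```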
diff --git a/configs/bisenetv2/bisenetv2.yml b/configs/bisenetv2/bisenetv2.yml
new file mode 100644
index 0000000000..7c98dd23d1
--- /dev/null
+++ b/configs/bisenetv2/bisenetv2.yml
@@ -0,0 +1,80 @@
+Collections:
+- Metadata:
+    Training Data:
+    - Cityscapes
+  Name: bisenetv2
+Models:
+- Config: configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
+  In Collection: bisenetv2
+  Metadata:
+    backbone: BiSeNetV2
+    crop size: (1024,1024)
+    inference time (ms/im):
+    - backend: PyTorch
+      batch size: 1
+      hardware: V100
+      mode: FP32
+      resolution: (1024,1024)
+      value: 31.48
+    lr schd: 160000
+    memory (GB): 7.64
+  Name: bisenetv2_fcn_4x4_1024x1024_160k_cityscapes
+  Results:
+    Dataset: Cityscapes
+    Metrics:
+      mIoU: 73.21
+      mIoU(ms+flip): 75.74
+    Task: Semantic Segmentation
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth
+- Config: configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
+  In Collection: bisenetv2
+  Metadata:
+    backbone: BiSeNetV2
+    crop size: (1024,1024)
+    lr schd: 160000
+    memory (GB): 7.64
+  Name: bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes
+  Results:
+    Dataset: Cityscapes
+    Metrics:
+      mIoU: 73.57
+      mIoU(ms+flip): 75.8
+    Task: Semantic Segmentation
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth
+- Config: configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
+  In Collection: bisenetv2
+  Metadata:
+    backbone: BiSeNetV2
+    crop size: (1024,1024)
+    lr schd: 160000
+    memory (GB): 15.05
+  Name: bisenetv2_fcn_4x8_1024x1024_160k_cityscapes
+  Results:
+    Dataset: Cityscapes
+    Metrics:
+      mIoU: 75.76
+      mIoU(ms+flip): 77.79
+    Task: Semantic Segmentation
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth
+- Config: configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
+  In Collection: bisenetv2
+  Metadata:
+    backbone: BiSeNetV2
+    crop size: (1024,1024)
+    inference time (ms/im):
+    - backend: PyTorch
+      batch size: 1
+      hardware: V100
+      mode: FP16
+      resolution: (1024,1024)
+      value: 27.29
+    lr schd: 160000
+    memory (GB): 5.77
+  Name: bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes
+  Results:
+    Dataset: Cityscapes
+    Metrics:
+      mIoU: 73.07
+      mIoU(ms+flip): 75.13
+    Task: Semantic Segmentation
+  Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth
diff --git a/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py b/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
new file mode 100644
index 0000000000..1248bd87aa
--- /dev/null
+++ b/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
@@ -0,0 +1,11 @@
+_base_ = [
+    '../_base_/models/bisenetv2.py',
+    '../_base_/datasets/cityscapes_1024x1024.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+lr_config = dict(warmup='linear', warmup_iters=1000)
+optimizer = dict(lr=0.05)
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+)
diff --git a/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py b/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
new file mode 100644
index 0000000000..babc2cd74a
--- /dev/null
+++ b/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
@@ -0,0 +1,11 @@
+_base_ = [
+    '../_base_/models/bisenetv2.py',
+    '../_base_/datasets/cityscapes_1024x1024.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+lr_config = dict(warmup='linear', warmup_iters=1000)
+optimizer = dict(lr=0.05)
+data = dict(
+    samples_per_gpu=8,
+    workers_per_gpu=8,
+)
diff --git a/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py b/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
new file mode 100644
index 0000000000..0196214b7d
--- /dev/null
+++ b/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
@@ -0,0 +1,5 @@
+_base_ = './bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py'
+# fp16 settings
+optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
+# fp16 placeholder
+fp16 = dict()
diff --git a/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py b/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
new file mode 100644
index 0000000000..f14e528130
--- /dev/null
+++ b/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
@@ -0,0 +1,15 @@
+_base_ = [
+    '../_base_/models/bisenetv2.py',
+    '../_base_/datasets/cityscapes_1024x1024.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
+]
+# plug the OHEM pixel sampler into the decode head; a bare top-level
+# `sampler` dict is never consumed by the model builder and has no effect
+model = dict(
+    decode_head=dict(
+        sampler=dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)))
+lr_config = dict(warmup='linear', warmup_iters=1000)
+optimizer = dict(lr=0.05)
+data = dict(
+    samples_per_gpu=4,
+    workers_per_gpu=4,
+)
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index 75ef2c3a86..24e2397235 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .bisenetv2 import BiSeNetV2 from .cgnet import CGNet from .fast_scnn import FastSCNN from .hrnet import HRNet @@ -15,5 +16,5 @@ __all__ = [ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', - 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer' + 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV2' ] diff --git a/mmseg/models/backbones/bisenetv2.py b/mmseg/models/backbones/bisenetv2.py new file mode 100644 index 0000000000..eb05e10d52 --- /dev/null +++ b/mmseg/models/backbones/bisenetv2.py @@ -0,0 +1,622 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, + build_activation_layer, build_norm_layer) +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES + + +class DetailBranch(BaseModule): + """Detail Branch with wide channels and shallow layers to capture low-level + details and generate high-resolution feature representation. + + Args: + detail_channels (Tuple[int]): Size of channel numbers of each stage + in Detail Branch, in paper it has 3 stages. + Default: (64, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Feature map of Detail Branch. + """ + + def __init__(self, + detail_channels=(64, 64, 128), + in_channels=3, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(DetailBranch, self).__init__(init_cfg=init_cfg) + detail_branch = [] + for i in range(len(detail_channels)): + if i == 0: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + else: + detail_branch.append( + nn.Sequential( + ConvModule( + in_channels=detail_channels[i - 1], + out_channels=detail_channels[i], + kernel_size=3, + stride=2, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ConvModule( + in_channels=detail_channels[i], + out_channels=detail_channels[i], + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg))) + self.detail_branch = nn.ModuleList(detail_branch) + + def forward(self, x): + for stage in self.detail_branch: + x = stage(x) + return x + + +class StemBlock(BaseModule): + """Stem Block at the beginning of Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). 
+        act_cfg (dict): Config of activation layers.
+            Default: dict(type='ReLU').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    Returns:
+        x (torch.Tensor): First feature map in Semantic Branch.
+    """
+
+    def __init__(self,
+                 in_channels=3,
+                 out_channels=16,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super(StemBlock, self).__init__(init_cfg=init_cfg)
+
+        self.conv_first = ConvModule(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.convs = nn.Sequential(
+            ConvModule(
+                in_channels=out_channels,
+                out_channels=out_channels // 2,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                in_channels=out_channels // 2,
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg))
+        self.pool = nn.MaxPool2d(
+            kernel_size=3, stride=2, padding=1, ceil_mode=False)
+        self.fuse_last = ConvModule(
+            in_channels=out_channels * 2,
+            out_channels=out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+
+    def forward(self, x):
+        x = self.conv_first(x)
+        x_left = self.convs(x)
+        x_right = self.pool(x)
+        x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
+        return x
+
+
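+# Shape sketch for the Gather-and-Expansion layer below (illustrative numbers,
+# not from the paper): conv1 keeps the input shape, the depthwise 3x3 convs
+# expand channels by `exp_ratio`, and the 1x1 conv2 projects to `out_channels`.
+# With stride=2 a depthwise-separable shortcut matches the identity path, e.g.
+# GELayer(32, 64, exp_ratio=6, stride=2) maps (N, 32, 64, 128) through 192 mid
+# channels to (N, 64, 32, 64). With stride=1 the raw identity is added, so
+# in_channels must equal out_channels.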
+ """ + + def __init__(self, + in_channels, + out_channels, + exp_ratio=6, + stride=1, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(GELayer, self).__init__(init_cfg=init_cfg) + mid_channel = in_channels * exp_ratio + self.conv1 = ConvModule( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + if stride == 1: + self.dwconv = nn.Sequential( + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg)) + self.shortcut = None + else: + self.dwconv = nn.Sequential( + ConvModule( + in_channels=in_channels, + out_channels=mid_channel, + kernel_size=3, + stride=stride, + padding=1, + groups=in_channels, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None), + # ReLU in ConvModule not shown in paper + ConvModule( + in_channels=mid_channel, + out_channels=mid_channel, + kernel_size=3, + stride=1, + padding=1, + groups=mid_channel, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg), + ) + self.shortcut = nn.Sequential( + DepthwiseSeparableConvModule( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + padding=1, + dw_norm_cfg=norm_cfg, + dw_act_cfg=None, + pw_norm_cfg=norm_cfg, + pw_act_cfg=None, + )) + + self.conv2 = nn.Sequential( + ConvModule( + in_channels=mid_channel, + out_channels=out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None, + )) + + self.act = build_activation_layer(act_cfg) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.dwconv(x) + x = self.conv2(x) + if self.shortcut is not None: + shortcut = self.shortcut(identity) + x = x + shortcut + else: + x = x + identity + x = self.act(x) + return x + + +class CEBlock(BaseModule): + """Context Embedding Block for large receptive filed in Semantic Branch. + + Args: + in_channels (int): Number of input channels. + Default: 3. + out_channels (int): Number of output channels. + Default: 16. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + x (torch.Tensor): Last feature map in Semantic Branch. 
+ """ + + def __init__(self, + in_channels=3, + out_channels=16, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(CEBlock, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_channels = out_channels + self.gap = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + build_norm_layer(norm_cfg, self.in_channels)[1]) + self.conv_gap = ConvModule( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=1, + stride=1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + # Note: in paper here is naive conv2d, no bn-relu + self.conv_last = ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + + def forward(self, x): + identity = x + x = self.gap(x) + x = self.conv_gap(x) + x = identity + x + x = self.conv_last(x) + return x + + +class SemanticBranch(BaseModule): + """Semantic Branch which is lightweight with narrow channels and deep + layers to obtain high-level semantic context. + + Args: + semantic_channels(Tuple[int]): Size of channel numbers of + various stages in Semantic Branch. + Default: (16, 32, 64, 128). + in_channels (int): Number of channels of input image. Default: 3. + exp_ratio (int): Expansion ratio for middle channels. + Default: 6. + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + Returns: + semantic_outs (List[torch.Tensor]): List of several feature maps + for auxiliary heads (Booster) and Bilateral + Guided Aggregation Layer. + """ + + def __init__(self, + semantic_channels=(16, 32, 64, 128), + in_channels=3, + exp_ratio=6, + init_cfg=None): + super(SemanticBranch, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.semantic_channels = semantic_channels + self.semantic_stages = [] + for i in range(len(semantic_channels)): + stage_name = f'stage{i + 1}' + self.semantic_stages.append(stage_name) + if i == 0: + self.add_module( + stage_name, + StemBlock(self.in_channels, semantic_channels[i])) + elif i == (len(semantic_channels) - 1): + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + else: + self.add_module( + stage_name, + nn.Sequential( + GELayer(semantic_channels[i - 1], semantic_channels[i], + exp_ratio, 2), + GELayer(semantic_channels[i], semantic_channels[i], + exp_ratio, 1))) + + self.add_module(f'stage{len(semantic_channels)}_CEBlock', + CEBlock(semantic_channels[-1], semantic_channels[-1])) + self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock') + + def forward(self, x): + semantic_outs = [] + for stage_name in self.semantic_stages: + semantic_stage = getattr(self, stage_name) + x = semantic_stage(x) + semantic_outs.append(x) + return semantic_outs + + +class BGALayer(BaseModule): + """Bilateral Guided Aggregation Layer to fuse the complementary information + from both Detail Branch and Semantic Branch. + + Args: + out_channels (int): Number of output channels. + Default: 128. + align_corners (bool): align_corners argument of F.interpolate. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. 
+class BGALayer(BaseModule):
+    """Bilateral Guided Aggregation Layer to fuse the complementary information
+    from both Detail Branch and Semantic Branch.
+
+    Args:
+        out_channels (int): Number of output channels.
+            Default: 128.
+        align_corners (bool): align_corners argument of F.interpolate.
+            Default: False.
+        conv_cfg (dict | None): Config of conv layers.
+            Default: None.
+        norm_cfg (dict | None): Config of norm layers.
+            Default: dict(type='BN').
+        act_cfg (dict): Config of activation layers.
+            Default: dict(type='ReLU').
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None.
+    Returns:
+        output (torch.Tensor): Output feature map for segmentation heads.
+    """
+
+    def __init__(self,
+                 out_channels=128,
+                 align_corners=False,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 init_cfg=None):
+        super(BGALayer, self).__init__(init_cfg=init_cfg)
+        self.out_channels = out_channels
+        self.align_corners = align_corners
+        self.detail_dwconv = nn.Sequential(
+            DepthwiseSeparableConvModule(
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                dw_norm_cfg=norm_cfg,
+                dw_act_cfg=None,
+                pw_norm_cfg=None,
+                pw_act_cfg=None,
+            ))
+        self.detail_down = nn.Sequential(
+            ConvModule(
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=2,
+                padding=1,
+                bias=False,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None),
+            nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
+        self.semantic_conv = nn.Sequential(
+            ConvModule(
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                bias=False,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None))
+        self.semantic_dwconv = nn.Sequential(
+            DepthwiseSeparableConvModule(
+                in_channels=self.out_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+                dw_norm_cfg=norm_cfg,
+                dw_act_cfg=None,
+                pw_norm_cfg=None,
+                pw_act_cfg=None,
+            ))
+        self.conv = ConvModule(
+            in_channels=self.out_channels,
+            out_channels=self.out_channels,
+            kernel_size=3,
+            stride=1,
+            padding=1,
+            inplace=True,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg,
+        )
+
+    def forward(self, x_d, x_s):
+        detail_dwconv = self.detail_dwconv(x_d)
+        detail_down = self.detail_down(x_d)
+        semantic_conv = self.semantic_conv(x_s)
+        semantic_dwconv = self.semantic_dwconv(x_s)
+        semantic_conv = resize(
+            input=semantic_conv,
+            size=detail_dwconv.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
+        fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
+        fuse_2 = resize(
+            input=fuse_2,
+            size=fuse_1.shape[2:],
+            mode='bilinear',
+            align_corners=self.align_corners)
+        output = self.conv(fuse_1 + fuse_2)
+        return output
+
+
+@BACKBONES.register_module()
+class BiSeNetV2(BaseModule):
+    """BiSeNetV2: Bilateral Network with Guided Aggregation for
+    Real-time Semantic Segmentation.
+
+    This backbone is the implementation of
+    `BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
+
+    Args:
+        in_channels (int): Number of channels of input image. Default: 3.
+        detail_channels (Tuple[int], optional): Channels of each stage
+            in Detail Branch. Default: (64, 64, 128).
+        semantic_channels (Tuple[int], optional): Channels of each stage
+            in Semantic Branch. Default: (16, 32, 64, 128).
+            See Table 1 and Figure 3 of the paper for more details.
+        semantic_expansion_ratio (int, optional): The expansion factor
+            for the middle channels in Semantic Branch.
+            Default: 6.
+        bga_channels (int, optional): Number of middle channels in
+            Bilateral Guided Aggregation Layer. Default: 128.
+        out_indices (Tuple[int] | int, optional): Output from which stages.
+            Default: (0, 1, 2, 3, 4).
+ align_corners (bool, optional): The align_corners argument of + resize operation in Bilateral Guided Aggregation Layer. + Default: False. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels=3, + detail_channels=(64, 64, 128), + semantic_channels=(16, 32, 64, 128), + semantic_expansion_ratio=6, + bga_channels=128, + out_indices=(0, 1, 2, 3, 4), + align_corners=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + init_cfg=None): + if init_cfg is None: + init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm']) + ] + super(BiSeNetV2, self).__init__(init_cfg=init_cfg) + self.in_channels = in_channels + self.out_indices = out_indices + self.detail_channels = detail_channels + self.semantic_channels = semantic_channels + self.semantic_expansion_ratio = semantic_expansion_ratio + self.bga_channels = bga_channels + self.align_corners = align_corners + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.detail = DetailBranch(self.detail_channels, self.in_channels) + self.semantic = SemanticBranch(self.semantic_channels, + self.in_channels, + self.semantic_expansion_ratio) + self.bga = BGALayer(self.bga_channels, self.align_corners) + + def forward(self, x): + # stole refactoring code from Coin Cheung, thanks + x_detail = self.detail(x) + x_semantic_lst = self.semantic(x) + x_head = self.bga(x_detail, x_semantic_lst[-1]) + outs = [x_head] + x_semantic_lst[:-1] + outs = [outs[i] for i in self.out_indices] + return tuple(outs) diff --git a/mmseg/models/backbones/swin.py b/mmseg/models/backbones/swin.py index 424c456cb3..7de1883678 100644 --- a/mmseg/models/backbones/swin.py +++ b/mmseg/models/backbones/swin.py @@ -25,6 +25,7 @@ class PatchMerging(BaseModule): This layer use nn.Unfold to group feature map by kernel_size, and use norm and linear layer to embed grouped feature map. + Args: in_channels (int): The num of input channels. out_channels (int): The num of output channels. diff --git a/model-index.yml b/model-index.yml index 21a1f0bdc9..1fa927ad92 100644 --- a/model-index.yml +++ b/model-index.yml @@ -1,6 +1,7 @@ Import: - configs/ann/ann.yml - configs/apcnet/apcnet.yml +- configs/bisenetv2/bisenetv2.yml - configs/ccnet/ccnet.yml - configs/cgnet/cgnet.yml - configs/danet/danet.yml diff --git a/tests/test_models/test_backbones/test_bisenetv2.py b/tests/test_models/test_backbones/test_bisenetv2.py new file mode 100644 index 0000000000..a1d1adc5f9 --- /dev/null +++ b/tests/test_models/test_backbones/test_bisenetv2.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch
+from mmcv.cnn import ConvModule
+
+from mmseg.models.backbones import BiSeNetV2
+from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
+                                              SemanticBranch)
+
+
+def test_bisenetv2_backbone():
+    # Test BiSeNetV2 Standard Forward
+    model = BiSeNetV2()
+    model.init_weights()
+    model.train()
+    batch_size = 2
+    imgs = torch.randn(batch_size, 3, 512, 1024)
+    feat = model(imgs)
+
+    assert len(feat) == 5
+    # output for segment head
+    assert feat[0].shape == torch.Size([batch_size, 128, 64, 128])
+    # for auxiliary head 1
+    assert feat[1].shape == torch.Size([batch_size, 16, 128, 256])
+    # for auxiliary head 2
+    assert feat[2].shape == torch.Size([batch_size, 32, 64, 128])
+    # for auxiliary head 3
+    assert feat[3].shape == torch.Size([batch_size, 64, 32, 64])
+    # for auxiliary head 4
+    assert feat[4].shape == torch.Size([batch_size, 128, 16, 32])
+
+    # Test input with a non-standard shape
+    batch_size = 2
+    imgs = torch.randn(batch_size, 3, 527, 952)
+    feat = model(imgs)
+    assert len(feat) == 5
+
+
+def test_bisenetv2_DetailBranch():
+    x = torch.randn(1, 3, 512, 1024)
+    detail_branch = DetailBranch(detail_channels=(64, 64, 128))
+    assert isinstance(detail_branch.detail_branch[0][0], ConvModule)
+    x_out = detail_branch(x)
+    assert x_out.shape == torch.Size([1, 128, 64, 128])
+
+
+def test_bisenetv2_SemanticBranch():
+    semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))
+    assert semantic_branch.stage1.pool.stride == 2
+
+
+def test_bisenetv2_BGALayer():
+    x_a = torch.randn(1, 128, 64, 128)
+    x_b = torch.randn(1, 128, 16, 32)
+    bga = BGALayer()
+    assert isinstance(bga.conv, ConvModule)
+    x_out = bga(x_a, x_b)
+    assert x_out.shape == torch.Size([1, 128, 64, 128])
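+
+
+def test_bisenetv2_GELayer():
+    # A minimal extra shape check, sketched here for illustration (it assumes
+    # GELayer can be imported as below): the stride-2 path should halve the
+    # spatial size and route 32 -> 32 * 6 = 192 mid channels -> 64 channels.
+    from mmseg.models.backbones.bisenetv2 import GELayer
+    x = torch.randn(1, 32, 64, 128)
+    ge_layer = GELayer(32, 64, exp_ratio=6, stride=2)
+    x_out = ge_layer(x)
+    assert x_out.shape == torch.Size([1, 64, 32, 64])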