diff --git a/configs/_base_/models/emanet_r50-d8.py b/configs/_base_/models/emanet_r50-d8.py
new file mode 100644
index 0000000000..326a25137a
--- /dev/null
+++ b/configs/_base_/models/emanet_r50-d8.py
@@ -0,0 +1,47 @@
+# model settings
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+    type='EncoderDecoder',
+    pretrained='open-mmlab://resnet50_v1c',
+    backbone=dict(
+        type='ResNetV1c',
+        depth=50,
+        num_stages=4,
+        out_indices=(0, 1, 2, 3),
+        dilations=(1, 1, 2, 4),
+        strides=(1, 2, 1, 1),
+        norm_cfg=norm_cfg,
+        norm_eval=False,
+        style='pytorch',
+        contract_dilation=True),
+    decode_head=dict(
+        type='EMAHead',
+        in_channels=2048,
+        in_index=3,
+        channels=256,
+        ema_channels=512,
+        num_bases=64,
+        num_stages=3,
+        momentum=0.1,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
+    auxiliary_head=dict(
+        type='FCNHead',
+        in_channels=1024,
+        in_index=2,
+        channels=256,
+        num_convs=1,
+        concat_input=False,
+        dropout_ratio=0.1,
+        num_classes=19,
+        norm_cfg=norm_cfg,
+        align_corners=False,
+        loss_decode=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
+# model training and testing settings
+train_cfg = dict()
+test_cfg = dict(mode='whole')
diff --git a/configs/emanet/README.md b/configs/emanet/README.md
new file mode 100644
index 0000000000..25ef985332
--- /dev/null
+++ b/configs/emanet/README.md
@@ -0,0 +1,22 @@
+# Expectation-Maximization Attention Networks for Semantic Segmentation
+
+## Introduction
+```
+@inproceedings{li2019expectation,
+  title={Expectation-maximization attention networks for semantic segmentation},
+  author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
+  booktitle={Proceedings of the IEEE International Conference on Computer Vision},
+  pages={9167--9176},
+  year={2019}
+}
+```
+
+## Results and models
+
+### Cityscapes
+
+| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
+|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|----------|
+| EMANet | R-50-D8  | 512x1024  | 80000   | 5.4      | 4.58           | 77.59 | 79.44         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
+| EMANet | R-101-D8 | 512x1024  | 80000   | 6.2      | 2.87           | 79.10 | 81.21         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
+| EMANet | R-50-D8  | 769x769   | 80000   | 8.9      | 1.97           | 79.33 | 80.49         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
+| EMANet | R-101-D8 | 769x769   | 80000   | 10.1     | 1.22           | 79.62 | 81.00         | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
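As context for the config files that follow, here is a minimal sketch of how one of the added configs can be turned into a model. It uses the existing `mmcv.Config` and `mmseg.models.build_segmentor` APIs and assumes a checkout with this patch applied; note that `train_cfg`/`test_cfg` sit at the top level of the config in this version.

```python
# Minimal sketch, assuming mmcv and mmseg are importable and this patch is
# applied so that configs/emanet/* and EMAHead exist.
from mmcv import Config

from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')
model = build_segmentor(
    cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
# The base config uses SyncBN, so a training-mode forward pass needs a
# distributed environment; for single-GPU experiments one would override
# norm_cfg to plain BN.
```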
diff --git a/configs/emanet/emanet_r101-d8_512x1024_80k_cityscapes.py b/configs/emanet/emanet_r101-d8_512x1024_80k_cityscapes.py
new file mode 100644
index 0000000000..58f28b43f5
--- /dev/null
+++ b/configs/emanet/emanet_r101-d8_512x1024_80k_cityscapes.py
@@ -0,0 +1,2 @@
+_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
diff --git a/configs/emanet/emanet_r101-d8_769x769_80k_cityscapes.py b/configs/emanet/emanet_r101-d8_769x769_80k_cityscapes.py
new file mode 100644
index 0000000000..c5dbf20b0f
--- /dev/null
+++ b/configs/emanet/emanet_r101-d8_769x769_80k_cityscapes.py
@@ -0,0 +1,2 @@
+_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
+model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
diff --git a/configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py b/configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py
new file mode 100644
index 0000000000..73b7788bf9
--- /dev/null
+++ b/configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py
@@ -0,0 +1,4 @@
+_base_ = [
+    '../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
+]
diff --git a/configs/emanet/emanet_r50-d8_769x769_80k_cityscapes.py b/configs/emanet/emanet_r50-d8_769x769_80k_cityscapes.py
new file mode 100644
index 0000000000..0fd9beea03
--- /dev/null
+++ b/configs/emanet/emanet_r50-d8_769x769_80k_cityscapes.py
@@ -0,0 +1,9 @@
+_base_ = [
+    '../_base_/models/emanet_r50-d8.py',
+    '../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
+    '../_base_/schedules/schedule_80k.py'
+]
+model = dict(
+    decode_head=dict(align_corners=True),
+    auxiliary_head=dict(align_corners=True))
+test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index 5828034018..a4a06d2af8 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -2,6 +2,7 @@
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
+from .ema_head import EMAHead
 from .enc_head import EncHead
 from .fcn_head import FCNHead
 from .fpn_head import FPNHead
@@ -17,5 +18,5 @@
 __all__ = [
     'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
     'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
-    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
+    'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
 ]
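Before the head implementation itself, a standalone shape sketch of the two einsum steps that `EMAModule` below iterates; the sizes are illustrative only.

```python
# Illustrative shape check of one EM iteration (E-step, then M-step);
# batch/channel/pixel/base sizes are arbitrary.
import torch
import torch.nn.functional as F

batch_size, channels, num_pixels, num_bases = 2, 512, 64 * 128, 64
feats = torch.randn(batch_size, channels, num_pixels)
bases = F.normalize(torch.randn(batch_size, channels, num_bases), dim=1, p=2)

# E-step: soft assignment (responsibility) of every pixel to every base.
attention = F.softmax(torch.einsum('bcn,bck->bnk', feats, bases), dim=2)
# M-step: re-estimate each base as a responsibility-weighted mean of features.
attention_normed = F.normalize(attention, dim=1, p=1)  # l1 norm over pixels
bases = F.normalize(
    torch.einsum('bcn,bnk->bck', feats, attention_normed), dim=1, p=2)

assert attention.shape == (batch_size, num_pixels, num_bases)
assert bases.shape == (batch_size, channels, num_bases)
```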
diff --git a/mmseg/models/decode_heads/ema_head.py b/mmseg/models/decode_heads/ema_head.py
new file mode 100644
index 0000000000..619d757046
--- /dev/null
+++ b/mmseg/models/decode_heads/ema_head.py
@@ -0,0 +1,169 @@
+import math
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from ..builder import HEADS
+from .decode_head import BaseDecodeHead
+
+
+def reduce_mean(tensor):
+    """Average the tensor across all processes in distributed training."""
+    if not (dist.is_available() and dist.is_initialized()):
+        return tensor
+    tensor = tensor.clone()
+    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+    return tensor
+
+
+class EMAModule(nn.Module):
+    """Expectation Maximization Attention Module used in EMANet.
+
+    Args:
+        channels (int): Channels of the whole module.
+        num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        momentum (float): Momentum to update the bases.
+    """
+
+    def __init__(self, channels, num_bases, num_stages, momentum):
+        super(EMAModule, self).__init__()
+        assert num_stages >= 1, 'num_stages must be at least 1!'
+        self.num_bases = num_bases
+        self.num_stages = num_stages
+        self.momentum = momentum
+
+        bases = torch.zeros(1, channels, self.num_bases)
+        bases.normal_(0, math.sqrt(2. / self.num_bases))
+        # [1, channels, num_bases]
+        bases = F.normalize(bases, dim=1, p=2)
+        self.register_buffer('bases', bases)
+
+    def forward(self, feats):
+        """Forward function."""
+        batch_size, channels, height, width = feats.size()
+        # [batch_size, channels, height*width]
+        feats = feats.view(batch_size, channels, height * width)
+        # [batch_size, channels, num_bases]
+        bases = self.bases.repeat(batch_size, 1, 1)
+
+        with torch.no_grad():
+            for i in range(self.num_stages):
+                # [batch_size, height*width, num_bases]
+                attention = torch.einsum('bcn,bck->bnk', feats, bases)
+                attention = F.softmax(attention, dim=2)
+                # l1 norm
+                attention_normed = F.normalize(attention, dim=1, p=1)
+                # [batch_size, channels, num_bases]
+                bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
+                # l2 norm
+                bases = F.normalize(bases, dim=1, p=2)
+
+        feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
+        feats_recon = feats_recon.view(batch_size, channels, height, width)
+
+        if self.training:
+            bases = bases.mean(dim=0, keepdim=True)
+            bases = reduce_mean(bases)
+            # l2 norm
+            bases = F.normalize(bases, dim=1, p=2)
+            self.bases = (1 -
+                          self.momentum) * self.bases + self.momentum * bases
+
+        return feats_recon
+
+
+@HEADS.register_module()
+class EMAHead(BaseDecodeHead):
+    """Expectation Maximization Attention Networks for Semantic Segmentation.
+
+    This head is the implementation of `EMANet
+    <https://arxiv.org/abs/1907.13426>`_.
+
+    Args:
+        ema_channels (int): EMA module channels.
+        num_bases (int): Number of bases.
+        num_stages (int): Number of the EM iterations.
+        concat_input (bool): Whether to concatenate the input and output of
+            convs before the classification layer. Default: True.
+        momentum (float): Momentum to update the bases. Default: 0.1.
+ """ + + def __init__(self, + ema_channels, + num_bases, + num_stages, + concat_input=True, + momentum=0.1, + **kwargs): + super(EMAHead, self).__init__(**kwargs) + self.ema_channels = ema_channels + self.num_bases = num_bases + self.num_stages = num_stages + self.concat_input = concat_input + self.momentum = momentum + self.ema_module = EMAModule(self.ema_channels, self.num_bases, + self.num_stages, self.momentum) + + self.ema_in_conv = ConvModule( + self.in_channels, + self.ema_channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + # project (0, inf) -> (-inf, inf) + self.ema_mid_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=None, + act_cfg=None) + for param in self.ema_mid_conv.parameters(): + param.requires_grad = False + + self.ema_out_conv = ConvModule( + self.ema_channels, + self.ema_channels, + 1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=None) + self.bottleneck = ConvModule( + self.ema_channels, + self.channels, + 3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + if self.concat_input: + self.conv_cat = ConvModule( + self.in_channels + self.channels, + self.channels, + kernel_size=3, + padding=1, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, inputs): + """Forward function.""" + x = self._transform_inputs(inputs) + feats = self.ema_in_conv(x) + identity = feats + feats = self.ema_mid_conv(feats) + recon = self.ema_module(feats) + recon = F.relu(recon, inplace=True) + recon = self.ema_out_conv(recon) + output = F.relu(identity + recon, inplace=True) + output = self.bottleneck(output) + if self.concat_input: + output = self.conv_cat(torch.cat([x, output], dim=1)) + output = self.cls_seg(output) + return output diff --git a/tests/test_models/test_forward.py b/tests/test_models/test_forward.py index ca2f7cb277..afca42c285 100644 --- a/tests/test_models/test_forward.py +++ b/tests/test_models/test_forward.py @@ -162,6 +162,11 @@ def test_mobilenet_v2_forward(): 'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py') +def test_emanet_forward(): + _test_encoder_decoder_forward( + 'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py') + + def get_world_size(process_group): return 1 diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py index 3ac6bb0aa2..8a36d1ffc7 100644 --- a/tests/test_models/test_heads.py +++ b/tests/test_models/test_heads.py @@ -6,9 +6,9 @@ from mmcv.utils.parrots_wrapper import SyncBatchNorm from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead, - DepthwiseSeparableASPPHead, EncHead, - FCNHead, GCHead, NLHead, OCRHead, - PSAHead, PSPHead, UPerHead) + DepthwiseSeparableASPPHead, EMAHead, + EncHead, FCNHead, GCHead, NLHead, + OCRHead, PSAHead, PSPHead, UPerHead) from mmseg.models.decode_heads.decode_head import BaseDecodeHead @@ -539,3 +539,21 @@ def test_dw_aspp_head(): assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24) outputs = head(inputs) assert outputs.shape == (1, head.num_classes, 45, 45) + + +def test_emanet_head(): + head = EMAHead( + in_channels=32, + ema_channels=24, + channels=16, + num_stages=3, + num_bases=16, + num_classes=19) + for param in head.ema_mid_conv.parameters(): + assert not param.requires_grad + assert hasattr(head, 'ema_module') + inputs = [torch.randn(1, 32, 45, 45)] + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + 
diff --git a/tests/test_models/test_heads.py b/tests/test_models/test_heads.py
index 3ac6bb0aa2..8a36d1ffc7 100644
--- a/tests/test_models/test_heads.py
+++ b/tests/test_models/test_heads.py
@@ -6,9 +6,9 @@
 from mmcv.utils.parrots_wrapper import SyncBatchNorm
 
 from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
-                                       DepthwiseSeparableASPPHead, EncHead,
-                                       FCNHead, GCHead, NLHead, OCRHead,
-                                       PSAHead, PSPHead, UPerHead)
+                                       DepthwiseSeparableASPPHead, EMAHead,
+                                       EncHead, FCNHead, GCHead, NLHead,
+                                       OCRHead, PSAHead, PSPHead, UPerHead)
 from mmseg.models.decode_heads.decode_head import BaseDecodeHead
 
 
@@ -539,3 +539,21 @@ def test_dw_aspp_head():
     assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
     outputs = head(inputs)
     assert outputs.shape == (1, head.num_classes, 45, 45)
+
+
+def test_emanet_head():
+    head = EMAHead(
+        in_channels=32,
+        ema_channels=24,
+        channels=16,
+        num_stages=3,
+        num_bases=16,
+        num_classes=19)
+    for param in head.ema_mid_conv.parameters():
+        assert not param.requires_grad
+    assert hasattr(head, 'ema_module')
+    inputs = [torch.randn(1, 32, 45, 45)]
+    if torch.cuda.is_available():
+        head, inputs = to_cuda(head, inputs)
+    outputs = head(inputs)
+    assert outputs.shape == (1, head.num_classes, 45, 45)
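One behavioral detail worth keeping in mind when using `EMAModule`: its `bases` buffer is a running average that is only updated in training mode (and averaged across GPUs via `reduce_mean`). A small sketch, assuming this patch is applied; sizes are illustrative.

```python
# Sketch of the bases buffer update behavior in train vs eval mode.
import torch
from mmseg.models.decode_heads.ema_head import EMAModule

ema = EMAModule(channels=24, num_bases=16, num_stages=3, momentum=0.1)
x = torch.randn(2, 24, 45, 45)

ema.train()
before = ema.bases.clone()
ema(x)  # runs the EM iterations and momentum-updates the buffer
assert not torch.equal(before, ema.bases)

ema.eval()
frozen = ema.bases.clone()
ema(x)  # at eval time the buffer stays fixed
assert torch.equal(frozen, ema.bases)
```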