Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support EMANet #34

Merged
merged 14 commits into from
Sep 7, 2020
47 changes: 47 additions & 0 deletions configs/_base_/models/emanet_r50-d8.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(
type='ResNetV1c',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='EMAHead',
in_channels=2048,
in_index=3,
channels=256,
ema_channels=512,
num_bases=64,
num_stages=3,
momentum=0.1,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
# model training and testing settings
train_cfg = dict()
test_cfg = dict(mode='whole')
22 changes: 22 additions & 0 deletions configs/emanet/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Expectation-Maximization Attention Networks for Semantic Segmentation

## Introduction
```
@inproceedings{li2019expectation,
title={Expectation-maximization attention networks for semantic segmentation},
author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong},
booktitle={Proceedings of the IEEE International Conference on Computer Vision},
pages={9167--9176},
year={2019}
}
```

## Results and models

### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download |
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| EMANet | R-50-D8 | 512x1024 | 80000 | 5.4 | 4.58 | 77.59 | 79.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 512x1024 | 80000 | 6.2 | 2.87 | 79.10 | 81.21 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-50-D8 | 769x769 | 80000 | 8.9 | 1.97 | 79.33 | 80.49 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
| EMANet | R-101-D8 | 769x769 | 80000 | 10.1 | 1.22 | 79.62 | 81.00 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) |
2 changes: 2 additions & 0 deletions configs/emanet/emanet_r101-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
2 changes: 2 additions & 0 deletions configs/emanet/emanet_r101-d8_769x769_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py'
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101))
4 changes: 4 additions & 0 deletions configs/emanet/emanet_r50-d8_512x1024_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py'
]
9 changes: 9 additions & 0 deletions configs/emanet/emanet_r50-d8_769x769_80k_cityscapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = [
'../_base_/models/emanet_r50-d8.py',
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py',
'../_base_/schedules/schedule_80k.py'
]
model = dict(
decode_head=dict(align_corners=True),
auxiliary_head=dict(align_corners=True))
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513))
3 changes: 2 additions & 1 deletion mmseg/models/decode_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
Expand All @@ -17,5 +18,5 @@
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead'
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead'
]
168 changes: 168 additions & 0 deletions mmseg/models/decode_heads/ema_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import math

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule

from ..builder import HEADS
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
"""Reduce mean when distributed training."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor


class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.

Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
"""

def __init__(self, channels, num_bases, num_stages, momentum):
super(EMAModule, self).__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum

bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, channels, num_bases]
bases = F.normalize(bases, dim=1, p=2)
self.register_buffer('bases', bases)

def forward(self, feats):
"""Forward function."""
batch_size, channels, height, width = feats.size()
# [batch_size, channels, height*width]
feats = feats.view(batch_size, channels, height * width)
# [batch_size, channels, num_bases]
bases = self.bases.repeat(batch_size, 1, 1)

with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
# l1 norm
attention_normed = F.normalize(attention, dim=1, p=1)
# [batch_size, channels, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)

feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, channels, height, width)

if self.training:
bases = bases.mean(dim=0, keepdim=True)
bases = reduce_mean(bases)
# l2 norm
bases = F.normalize(bases, dim=1, p=2)
self.bases = (1 -
self.momentum) * self.bases + self.momentum * bases

return feats_recon


@HEADS.register_module()
class EMAHead(BaseDecodeHead):
"""Expectation Maximization Attention Networks for Semantic Segmentation.

This head is the implementation of `EMANet
<https://arxiv.org/abs/1907.13426>`_.

Args:
ema_channels (int): EMA module channels
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
concat_input (bool): Whether concat the input and output of convs
before classification layer. Default: True
momentum (float): Momentum to update the base. Default: 0.1.
"""

def __init__(self,
ema_channels,
num_bases,
num_stages,
concat_input=True,
momentum=0.1,
**kwargs):
super(EMAHead, self).__init__(**kwargs)
self.ema_channels = ema_channels
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
self.num_bases = num_bases
self.num_stages = num_stages
self.concat_input = concat_input
self.momentum = momentum
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum)

self.ema_in_conv = ConvModule(
self.in_channels,
self.ema_channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
# project (0, inf) -> (-inf, inf)
self.ema_mid_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=None)
for param in self.ema_mid_conv.parameters():
param.requires_grad = False

self.ema_out_conv = ConvModule(
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.bottleneck = ConvModule(
self.ema_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
if self.concat_input:
self.conv_cat = ConvModule(
self.in_channels + self.channels,
self.channels,
kernel_size=3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)

def forward(self, inputs):
"""Forward function."""
x = self._transform_inputs(inputs)
feats = self.ema_in_conv(x)
identity = feats
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = self.bottleneck(output)
if self.concat_input:
output = self.conv_cat(torch.cat([x, output], dim=1))
output = self.cls_seg(output)
return output
5 changes: 5 additions & 0 deletions tests/test_models/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,11 @@ def test_mobilenet_v2_forward():
'mobilenet_v2/pspnet_m-v2-d8_512x1024_80k_cityscapes.py')


def test_emanet_forward():
_test_encoder_decoder_forward(
'emanet/emanet_r50-d8_512x1024_80k_cityscapes.py')


def get_world_size(process_group):

return 1
Expand Down
24 changes: 21 additions & 3 deletions tests/test_models/test_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from mmcv.utils.parrots_wrapper import SyncBatchNorm

from mmseg.models.decode_heads import (ANNHead, ASPPHead, CCHead, DAHead,
DepthwiseSeparableASPPHead, EncHead,
FCNHead, GCHead, NLHead, OCRHead,
PSAHead, PSPHead, UPerHead)
DepthwiseSeparableASPPHead, EMAHead,
EncHead, FCNHead, GCHead, NLHead,
OCRHead, PSAHead, PSPHead, UPerHead)
from mmseg.models.decode_heads.decode_head import BaseDecodeHead


Expand Down Expand Up @@ -539,3 +539,21 @@ def test_dw_aspp_head():
assert head.aspp_modules[2].depthwise_conv.dilation == (24, 24)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45)


def test_emanet_head():
head = EMAHead(
in_channels=32,
ema_channels=24,
channels=16,
num_stages=3,
num_bases=16,
num_classes=19)
for param in head.ema_mid_conv.parameters():
assert not param.requires_grad
assert hasattr(head, 'ema_module')
inputs = [torch.randn(1, 32, 45, 45)]
if torch.cuda.is_available():
head, inputs = to_cuda(head, inputs)
outputs = head(inputs)
assert outputs.shape == (1, head.num_classes, 45, 45)