forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support EMANet (open-mmlab#34)
* add emanet * fixed bug and typos * add emanet config * fixed padding * fixed identity * rename * rename * add concat_input * fallback to update last * Fixed concat * update EMANet * Add tests * remove self-implement norm Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com>
- Loading branch information
Showing
10 changed files
with
282 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# model settings | ||
norm_cfg = dict(type='SyncBN', requires_grad=True) | ||
model = dict( | ||
type='EncoderDecoder', | ||
pretrained='open-mmlab://resnet50_v1c', | ||
backbone=dict( | ||
type='ResNetV1c', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
dilations=(1, 1, 2, 4), | ||
strides=(1, 2, 1, 1), | ||
norm_cfg=norm_cfg, | ||
norm_eval=False, | ||
style='pytorch', | ||
contract_dilation=True), | ||
decode_head=dict( | ||
type='EMAHead', | ||
in_channels=2048, | ||
in_index=3, | ||
channels=256, | ||
ema_channels=512, | ||
num_bases=64, | ||
num_stages=3, | ||
momentum=0.1, | ||
dropout_ratio=0.1, | ||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), | ||
auxiliary_head=dict( | ||
type='FCNHead', | ||
in_channels=1024, | ||
in_index=2, | ||
channels=256, | ||
num_convs=1, | ||
concat_input=False, | ||
dropout_ratio=0.1, | ||
num_classes=19, | ||
norm_cfg=norm_cfg, | ||
align_corners=False, | ||
loss_decode=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4))) | ||
# model training and testing settings | ||
train_cfg = dict() | ||
test_cfg = dict(mode='whole') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Expectation-Maximization Attention Networks for Semantic Segmentation | ||
|
||
## Introduction | ||
``` | ||
@inproceedings{li2019expectation, | ||
title={Expectation-maximization attention networks for semantic segmentation}, | ||
author={Li, Xia and Zhong, Zhisheng and Wu, Jianlong and Yang, Yibo and Lin, Zhouchen and Liu, Hong}, | ||
booktitle={Proceedings of the IEEE International Conference on Computer Vision}, | ||
pages={9167--9176}, | ||
year={2019} | ||
} | ||
``` | ||
|
||
## Results and models | ||
|
||
### Cityscapes | ||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | download | | ||
|--------|----------|-----------|--------:|---------:|----------------|------:|---------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| | ||
| EMANet | R-50-D8 | 512x1024 | 80000 | 5.4 | 4.58 | 77.59 | 79.44 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes_20200901_100301-c43fcef1.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_512x1024_80k_cityscapes/emanet_r50-d8_512x1024_80k_cityscapes-20200901_100301.log.json) | | ||
| EMANet | R-101-D8 | 512x1024 | 80000 | 6.2 | 2.87 | 79.10 | 81.21 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes_20200901_100301-2d970745.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_512x1024_80k_cityscapes/emanet_r101-d8_512x1024_80k_cityscapes-20200901_100301.log.json) | | ||
| EMANet | R-50-D8 | 769x769 | 80000 | 8.9 | 1.97 | 79.33 | 80.49 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes_20200901_100301-16f8de52.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r50-d8_769x769_80k_cityscapes/emanet_r50-d8_769x769_80k_cityscapes-20200901_100301.log.json) | | ||
| EMANet | R-101-D8 | 769x769 | 80000 | 10.1 | 1.22 | 79.62 | 81.00 | [model](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes_20200901_100301-47a324ce.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmsegmentation/v0.5/emanet/emanet_r101-d8_769x769_80k_cityscapes/emanet_r101-d8_769x769_80k_cityscapes-20200901_100301.log.json) | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
_base_ = './emanet_r50-d8_512x1024_80k_cityscapes.py' | ||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
_base_ = './emanet_r50-d8_769x769_80k_cityscapes.py' | ||
model = dict(pretrained='open-mmlab://resnet101_v1c', backbone=dict(depth=101)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
_base_ = [ | ||
'../_base_/models/emanet_r50-d8.py', '../_base_/datasets/cityscapes.py', | ||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_80k.py' | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = [ | ||
'../_base_/models/emanet_r50-d8.py', | ||
'../_base_/datasets/cityscapes_769x769.py', '../_base_/default_runtime.py', | ||
'../_base_/schedules/schedule_80k.py' | ||
] | ||
model = dict( | ||
decode_head=dict(align_corners=True), | ||
auxiliary_head=dict(align_corners=True)) | ||
test_cfg = dict(mode='slide', crop_size=(769, 769), stride=(513, 513)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
import math | ||
|
||
import torch | ||
import torch.distributed as dist | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
from mmcv.cnn import ConvModule | ||
|
||
from ..builder import HEADS | ||
from .decode_head import BaseDecodeHead | ||
|
||
|
||
def reduce_mean(tensor): | ||
"""Reduce mean when distributed training.""" | ||
if not (dist.is_available() and dist.is_initialized()): | ||
return tensor | ||
tensor = tensor.clone() | ||
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM) | ||
return tensor | ||
|
||
|
||
class EMAModule(nn.Module): | ||
"""Expectation Maximization Attention Module used in EMANet. | ||
Args: | ||
channels (int): Channels of the whole module. | ||
num_bases (int): Number of bases. | ||
num_stages (int): Number of the EM iterations. | ||
""" | ||
|
||
def __init__(self, channels, num_bases, num_stages, momentum): | ||
super(EMAModule, self).__init__() | ||
assert num_stages >= 1, 'num_stages must be at least 1!' | ||
self.num_bases = num_bases | ||
self.num_stages = num_stages | ||
self.momentum = momentum | ||
|
||
bases = torch.zeros(1, channels, self.num_bases) | ||
bases.normal_(0, math.sqrt(2. / self.num_bases)) | ||
# [1, channels, num_bases] | ||
bases = F.normalize(bases, dim=1, p=2) | ||
self.register_buffer('bases', bases) | ||
|
||
def forward(self, feats): | ||
"""Forward function.""" | ||
batch_size, channels, height, width = feats.size() | ||
# [batch_size, channels, height*width] | ||
feats = feats.view(batch_size, channels, height * width) | ||
# [batch_size, channels, num_bases] | ||
bases = self.bases.repeat(batch_size, 1, 1) | ||
|
||
with torch.no_grad(): | ||
for i in range(self.num_stages): | ||
# [batch_size, height*width, num_bases] | ||
attention = torch.einsum('bcn,bck->bnk', feats, bases) | ||
attention = F.softmax(attention, dim=2) | ||
# l1 norm | ||
attention_normed = F.normalize(attention, dim=1, p=1) | ||
# [batch_size, channels, num_bases] | ||
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed) | ||
# l2 norm | ||
bases = F.normalize(bases, dim=1, p=2) | ||
|
||
feats_recon = torch.einsum('bck,bnk->bcn', bases, attention) | ||
feats_recon = feats_recon.view(batch_size, channels, height, width) | ||
|
||
if self.training: | ||
bases = bases.mean(dim=0, keepdim=True) | ||
bases = reduce_mean(bases) | ||
# l2 norm | ||
bases = F.normalize(bases, dim=1, p=2) | ||
self.bases = (1 - | ||
self.momentum) * self.bases + self.momentum * bases | ||
|
||
return feats_recon | ||
|
||
|
||
@HEADS.register_module() | ||
class EMAHead(BaseDecodeHead): | ||
"""Expectation Maximization Attention Networks for Semantic Segmentation. | ||
This head is the implementation of `EMANet | ||
<https://arxiv.org/abs/1907.13426>`_. | ||
Args: | ||
ema_channels (int): EMA module channels | ||
num_bases (int): Number of bases. | ||
num_stages (int): Number of the EM iterations. | ||
concat_input (bool): Whether concat the input and output of convs | ||
before classification layer. Default: True | ||
momentum (float): Momentum to update the base. Default: 0.1. | ||
""" | ||
|
||
def __init__(self, | ||
ema_channels, | ||
num_bases, | ||
num_stages, | ||
concat_input=True, | ||
momentum=0.1, | ||
**kwargs): | ||
super(EMAHead, self).__init__(**kwargs) | ||
self.ema_channels = ema_channels | ||
self.num_bases = num_bases | ||
self.num_stages = num_stages | ||
self.concat_input = concat_input | ||
self.momentum = momentum | ||
self.ema_module = EMAModule(self.ema_channels, self.num_bases, | ||
self.num_stages, self.momentum) | ||
|
||
self.ema_in_conv = ConvModule( | ||
self.in_channels, | ||
self.ema_channels, | ||
3, | ||
padding=1, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg) | ||
# project (0, inf) -> (-inf, inf) | ||
self.ema_mid_conv = ConvModule( | ||
self.ema_channels, | ||
self.ema_channels, | ||
1, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=None, | ||
act_cfg=None) | ||
for param in self.ema_mid_conv.parameters(): | ||
param.requires_grad = False | ||
|
||
self.ema_out_conv = ConvModule( | ||
self.ema_channels, | ||
self.ema_channels, | ||
1, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=None) | ||
self.bottleneck = ConvModule( | ||
self.ema_channels, | ||
self.channels, | ||
3, | ||
padding=1, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg) | ||
if self.concat_input: | ||
self.conv_cat = ConvModule( | ||
self.in_channels + self.channels, | ||
self.channels, | ||
kernel_size=3, | ||
padding=1, | ||
conv_cfg=self.conv_cfg, | ||
norm_cfg=self.norm_cfg, | ||
act_cfg=self.act_cfg) | ||
|
||
def forward(self, inputs): | ||
"""Forward function.""" | ||
x = self._transform_inputs(inputs) | ||
feats = self.ema_in_conv(x) | ||
identity = feats | ||
feats = self.ema_mid_conv(feats) | ||
recon = self.ema_module(feats) | ||
recon = F.relu(recon, inplace=True) | ||
recon = self.ema_out_conv(recon) | ||
output = F.relu(identity + recon, inplace=True) | ||
output = self.bottleneck(output) | ||
if self.concat_input: | ||
output = self.conv_cat(torch.cat([x, output], dim=1)) | ||
output = self.cls_seg(output) | ||
return output |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters