[Feature] Support EMANet #34

Merged
merged 14 commits
Sep 7, 2020
4 changes: 2 additions & 2 deletions configs/_base_/models/emanet_r50-d8.py
@@ -18,11 +18,11 @@
type='EMAHead',
in_channels=2048,
in_index=3,
channels=512,
channels=256,
ema_channels=512,
num_bases=64,
num_stages=3,
momentum=0.1,
epsilon=1e-6,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
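For reference, the decode_head block of configs/_base_/models/emanet_r50-d8.py now reads roughly as below. Only the lines shown in the hunk above come from this diff; norm_cfg, align_corners and loss_decode are assumed to follow the usual mmsegmentation base configs.

norm_cfg = dict(type='SyncBN', requires_grad=True)  # assumed, defined earlier in the config
decode_head = dict(
    type='EMAHead',
    in_channels=2048,
    in_index=3,
    channels=256,      # bottleneck / classifier width, reduced from 512
    ema_channels=512,  # width of the EM attention branch (new argument)
    num_bases=64,
    num_stages=3,
    momentum=0.1,
    epsilon=1e-6,
    dropout_ratio=0.1,
    num_classes=19,
    norm_cfg=norm_cfg,
    align_corners=False,  # assumed
    loss_decode=dict(     # assumed
        type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0))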
2 changes: 1 addition & 1 deletion mmseg/models/decode_heads/__init__.py
@@ -16,5 +16,5 @@
__all__ = [
'FCNHead', 'PSPHead', 'ASPPHead', 'PSAHead', 'NLHead', 'GCHead', 'CCHead',
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'EMANet'
'EncHead', 'EMAHead'
]
100 changes: 67 additions & 33 deletions mmseg/models/decode_heads/ema_head.py
@@ -10,47 +10,62 @@
from .decode_head import BaseDecodeHead


def reduce_mean(tensor):
"""Reduce mean when distributed training."""
if not (dist.is_available() and dist.is_initialized()):
return tensor
tensor = tensor.clone()
dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
return tensor


class EMAModule(nn.Module):
"""Expectation Maximization Attention Module used in EMANet.

Args:
channels (int): Channels of the whole module.
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
epsilon (float): A small value for computation stability
epsilon (float): A small value for computation stability.
Default: 1e-6.
"""

def __init__(self, channels, num_bases, num_stages, momentum, epsilon):
def __init__(self,
channels,
num_bases,
num_stages,
momentum,
epsilon=1e-6):
super(EMAModule, self).__init__()
assert num_stages >= 1, 'num_stages must be at least 1!'
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
self.epsilon = epsilon

bases = torch.Tensor(1, channels, self.num_bases)
bases = torch.zeros(1, channels, self.num_bases)
bases.normal_(0, math.sqrt(2. / self.num_bases))
# [1, num_classes, num_bases]
bases = self._l2norm(bases, dim=1)
self.register_buffer('bases', bases)

def _l2norm(self, input, dim):
"""Normlize the inp tensor with l2-norm.
"""Normalize the inp tensor with l2-norm.

Returns a tensor where each sub-tensor of input along the given dim is
normalized such that the 2-norm of the sub-tensor is equal to 1.

Args:
inp (tensor): The input tensor.
input (tensor): The input tensor.
dim (int): The dimension to slice over to get the sub-tensors.
"""
return input / (self.epsilon + input.norm(dim=dim, keepdim=True))

def _l1norm(self, input, dim):
"""Normlize the inp tensor with l1-norm.
"""Normalize the inp tensor with l1-norm.

Args:
inp (tensor): The input tensor.
input (tensor): The input tensor.
dim (int): The dimension to slice over to get the sub-tensors.
"""
return input / (self.epsilon + input.sum(dim=dim, keepdim=True))
@@ -66,21 +81,23 @@ def forward(self, feats):
with torch.no_grad():
for i in range(self.num_stages):
# [batch_size, height*width, num_bases]
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
attention_normed = self._l1norm(attention, dim=1)
# [batch_size, num_classes, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
bases = self._l2norm(bases, dim=1)
attention = torch.einsum('bcn,bck->bnk', feats, bases)
attention = F.softmax(attention, dim=2)
attention_normed = self._l1norm(attention, dim=1)
# [batch_size, num_classes, num_bases]
bases = torch.einsum('bcn,bnk->bck', feats, attention_normed)
bases = self._l2norm(bases, dim=1)

feats_recon = torch.einsum('bck,bnk->bcn', bases, attention)
feats_recon = feats_recon.view(batch_size, num_classes, height, width)

if self.training:
base = dist.all_reduce(bases)
base = self._l2norm(base, dim=1)
self.base = (1 - self.momentum) * self.base + self.momentum * base

bases = bases.mean(dim=0, keepdim=True)
bases = reduce_mean(bases)
bases = self._l2norm(bases, dim=1)
self.bases = (1 -
self.momentum) * self.bases + self.momentum * bases

return feats_recon
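To make the einsum calls above easier to follow, here is a self-contained sketch of a single EM iteration written with torch.bmm instead of einsum. Variable names and normalization steps mirror the code in this hunk; the concrete sizes are made up, and the no_grad loop and the momentum update of the bases buffer are omitted.

import torch
import torch.nn.functional as F

eps = 1e-6
batch_size, channels, num_pixels, num_bases = 2, 512, 64 * 128, 64
feats = torch.randn(batch_size, channels, num_pixels)  # flattened feature map [B, C, N]
bases = F.normalize(torch.randn(batch_size, channels, num_bases), dim=1)  # [B, C, K]

# E-step: responsibility of each basis for each pixel.
# einsum('bcn,bck->bnk', feats, bases) is feats^T @ bases.
attention = torch.bmm(feats.transpose(1, 2), bases)  # [B, N, K]
attention = F.softmax(attention, dim=2)
attention_normed = attention / (eps + attention.sum(dim=1, keepdim=True))

# M-step: bases become responsibility-weighted averages of the features.
# einsum('bcn,bnk->bck', feats, attention_normed) is feats @ attention_normed.
bases = torch.bmm(feats, attention_normed)  # [B, C, K]
bases = bases / (eps + bases.norm(dim=1, keepdim=True))

# Reconstruction: express every pixel as a combination of the few bases.
# einsum('bck,bnk->bcn', bases, attention) is bases @ attention^T.
feats_recon = torch.bmm(bases, attention.transpose(1, 2))  # [B, C, N]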


@@ -94,50 +111,67 @@ class EMAHead(BaseDecodeHead):
Args:
num_bases (int): Number of bases.
num_stages (int): Number of the EM iterations.
epsilon (float): A small value for computation stability
momentum (float): Momentum to update the base. Default: 0.1.
epsilon (float): A small value for computation stability.
Default: 1e-6.
"""

def __init__(self, num_bases, num_stages, momentum, epsilon, **kwargs):
def __init__(self,
ema_channels,
num_bases,
num_stages,
momentum=0.1,
epsilon=1e-6,
**kwargs):
super(EMAHead, self).__init__(**kwargs)
self.ema_channels = ema_channels
self.num_bases = num_bases
self.num_stages = num_stages
self.momentum = momentum
self.epsilon = epsilon
self.ema_module = EMAModule(
self.channels,
self.num_bases,
self.num_stages,
self.momentum,
self.epsilon)
self.ema_module = EMAModule(self.ema_channels, self.num_bases,
self.num_stages, self.momentum,
self.epsilon)

self.ema_in_conv = ConvModule(
self.in_channels,
self.channels,
self.ema_channels,
3,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
self.ema_mid_conv = ConvModule(
self.channels,
self.channels,
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=None,
act_cfg=None)
self.ema_out_conv = ConvModule(
self.channels,
self.channels,
self.ema_channels,
self.ema_channels,
1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=None)
self.bottleneck = ConvModule(
self.ema_channels,
self.channels,
3,
padding=1,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)

def forward(self, inputs):
"""Forward function."""
feats = self.ema_in_conv(inputs)
identity = feats
x = self._transform_inputs(inputs)
feats = self.ema_in_conv(x)
feats = self.ema_mid_conv(feats)
recon = self.ema_module(feats)
recon = F.relu(recon, inplace=True)
recon = self.ema_out_conv(recon)
output = F.relu(identity + recon, inplace=True)
output = F.relu(feats + recon, inplace=True)
output = self.bottleneck(output)
output = self.cls_seg(output)
return output
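As a hypothetical smoke test of the new signature (the BaseDecodeHead keyword arguments and the dummy backbone shapes are assumptions matching the config above, not part of this PR):

import torch
from mmseg.models.decode_heads import EMAHead

head = EMAHead(
    ema_channels=512,
    num_bases=64,
    num_stages=3,
    momentum=0.1,
    epsilon=1e-6,
    in_channels=2048,
    channels=256,
    num_classes=19,
    in_index=3)
head.eval()  # skip dropout and the distributed momentum update of the bases

# Dummy backbone outputs; in_index=3 picks the last (2048-channel) map.
feats = [
    torch.randn(2, 256, 128, 128),
    torch.randn(2, 512, 64, 64),
    torch.randn(2, 1024, 64, 64),
    torch.randn(2, 2048, 64, 64),
]
with torch.no_grad():
    logits = head(feats)  # per-pixel scores with num_classes=19 channels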