
[Feature] Support Real-time model ERFNet #960

Merged: 16 commits, Dec 2, 2021
28 changes: 28 additions & 0 deletions configs/_base_/models/erfnet.py
@@ -0,0 +1,28 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='ERFNet',
in_channels=3,
enc_downsample_channels=(16, 64, 128),
enc_num_stages_non_bottleneck=(5, 8),
enc_non_bottleneck_dilations=(2, 4, 8, 16),
enc_non_bottleneck_channels=(64, 128),
dec_upsample_channels=(64, 16),
dec_num_stages_non_bottleneck=(2, 2),
dec_non_bottleneck_channels=(64, 16),
dropout_ratio=0.1,
init_cfg=None),
decode_head=dict(
type='ERFHead',
in_channels=16,
channels=19,
num_classes=19,
norm_cfg=norm_cfg,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
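
For a quick sanity check of the settings above, the config can be materialized with mmseg's builder. A minimal sketch, assuming an mmseg 0.x environment with this PR checked out; the dummy input and printed shape are illustrative, not part of the PR:

from mmcv import Config
import torch

from mmseg.models import build_segmentor

cfg = Config.fromfile('configs/_base_/models/erfnet.py')
# Builds the EncoderDecoder with the ERFNet backbone and ERFHead above.
model = build_segmentor(cfg.model)
model.eval()
with torch.no_grad():
    # extract_feat runs only the backbone; ERFNet returns a one-element list.
    feats = model.extract_feat(torch.rand(1, 3, 512, 1024))
print(feats[0].shape)  # torch.Size([1, 16, 256, 512])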
4 changes: 4 additions & 0 deletions configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py
@@ -0,0 +1,4 @@
_base_ = [
'../_base_/models/erfnet.py', '../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
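
The config is four lines because everything is inherited through _base_: the model from erfnet.py above, the 1024x1024 Cityscapes pipeline, the default runtime, and the 160k-iteration schedule. A sketch of inspecting the merged result, assuming the standard mmseg repository layout (illustrative):

from mmcv import Config

cfg = Config.fromfile(
    'configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py')
print(cfg.model.backbone.type)  # 'ERFNet', from the model base
print(cfg.runner.max_iters)     # 160000, from schedule_160k.py
# Training is then launched as usual, e.g.
#   python tools/train.py configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py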
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
@@ -2,6 +2,7 @@
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
from .icnet import ICNet
@@ -19,5 +20,5 @@
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet'
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'ERFNet'
]
326 changes: 326 additions & 0 deletions mmseg/models/backbones/erfnet.py
@@ -0,0 +1,326 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule

from mmseg.ops import resize
from ..builder import BACKBONES


class DownsamplerBlock(BaseModule):
"""Downsampler block of ERFNet.

This module is a little different from the basic ConvModule: the outputs
of the stride-2 Conv and the MaxPool branches are concatenated before
BatchNorm is applied.

Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""

def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(DownsamplerBlock, self).__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

self.conv = build_conv_layer(
self.conv_cfg,
in_channels,
out_channels - in_channels,
kernel_size=3,
stride=2,
padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
self.act = build_activation_layer(self.act_cfg)

def forward(self, input):
conv_out = self.conv(input)
pool_out = self.pool(input)
pool_out = resize(
input=pool_out,
size=conv_out.size()[2:],
mode='bilinear',
align_corners=False)
output = torch.cat([conv_out, pool_out], 1)
output = self.bn(output)
output = self.act(output)
return output
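
# Illustrative usage, not part of this file: the stride-2 conv contributes
# (out_channels - in_channels) feature maps and the max-pool keeps the
# original in_channels, so their concatenation yields exactly out_channels
# at half resolution.
# >>> block = DownsamplerBlock(in_channels=16, out_channels=64)
# >>> block(torch.rand(1, 16, 256, 256)).shape
# torch.Size([1, 64, 128, 128])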


class NonBottleneck1d(BaseModule):
"""Non-bottleneck block of ERFNet.

Args:
channels (int): Number of channels in Non-bottleneck block.
drop_rate (float): Probability of an element to be zeroed.
Default 0.
dilation (int): Dilation rate for last two conv layers.
Default 1.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""

def __init__(self,
channels,
drop_rate=0,
dilation=1,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(NonBottleneck1d, self).__init__(init_cfg=init_cfg)

self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.act = build_activation_layer(self.act_cfg)

self.convs_layer = nn.ModuleList()
for conv_layer in range(2):
conv_first_padding = (1, 0) if conv_layer == 0 else (1 * dilation, 0)
conv_first_dilation = 1 if conv_layer == 0 else (dilation, 1)
conv_second_padding = (0, 1) if conv_layer == 0 else (0, dilation)
conv_second_dilation = 1 if conv_layer == 0 else (1, dilation)

self.convs_layer.append(
build_conv_layer(
self.conv_cfg,
channels,
channels,
kernel_size=(3, 1),
stride=1,
padding=conv_first_padding,
bias=True,
dilation=conv_first_dilation))
self.convs_layer.append(self.act)
self.convs_layer.append(
build_conv_layer(
self.conv_cfg,
channels,
channels,
kernel_size=(1, 3),
stride=1,
padding=conv_second_padding,
bias=True,
dilation=conv_second_dilation))
self.convs_layer.append(
build_norm_layer(self.norm_cfg, channels)[1])
if conv_layer == 0:
self.convs_layer.append(self.act)
else:
self.convs_layer.append(nn.Dropout(p=drop_rate))

def forward(self, input):
output = input
for op in self.convs_layer:
output = op(output)
output = self.act(output + input)
return output
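
# Illustrative note, not part of this file: the factorized 3x1 + 1x3 pairs
# keep a 3x3 receptive field at lower cost. For C channels, a full 3x3 conv
# needs 9*C*C weights while each 3x1 + 1x3 pair needs 6*C*C, a one-third
# saving. The residual connection preserves the input shape:
# >>> block = NonBottleneck1d(channels=64, drop_rate=0.1, dilation=2)
# >>> block(torch.rand(2, 64, 64, 128)).shape
# torch.Size([2, 64, 64, 128])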


class UpsamplerBlock(BaseModule):
"""Upsampler block of ERFNet.

Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""

def __init__(self,
in_channels,
out_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(UpsamplerBlock, self).__init__(init_cfg=init_cfg)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

self.conv = nn.ConvTranspose2d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
output_padding=1,
bias=True)
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
self.act = build_activation_layer(self.act_cfg)

def forward(self, input):
output = self.conv(input)
output = self.bn(output)
output = self.act(output)
return output
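
# Illustrative note, not part of this file: with kernel_size=3, stride=2,
# padding=1 and output_padding=1, the transposed conv maps a spatial size H
# to (H - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H, i.e. each UpsamplerBlock exactly
# doubles the resolution.
# >>> block = UpsamplerBlock(in_channels=128, out_channels=64)
# >>> block(torch.rand(1, 128, 64, 128)).shape
# torch.Size([1, 64, 128, 256])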


@BACKBONES.register_module()
class ERFNet(BaseModule):
"""ERFNet backbone.

This backbone is the implementation of `ERFNet: Efficient Residual
Factorized ConvNet for Real-Time Semantic Segmentation
<https://ieeexplore.ieee.org/document/8063438>`_.

Args:
in_channels (int): Number of channels of the input image. Default: 3.
enc_downsample_channels (Tuple[int]): Number of output channels of
each DownsamplerBlock in the encoder.
Default: (16, 64, 128).
enc_num_stages_non_bottleneck (Tuple[int]): Number of NonBottleneck1d
blocks in each encoder stage.
Default: (5, 8).
enc_non_bottleneck_dilations (Tuple[int]): Dilation rates of the
NonBottleneck1d blocks in the last encoder stage.
Default: (2, 4, 8, 16).
enc_non_bottleneck_channels (Tuple[int]): Number of channels of the
NonBottleneck1d blocks in each encoder stage.
Default: (64, 128).
dec_upsample_channels (Tuple[int]): Number of output channels of
each UpsamplerBlock (deconvolution) in the decoder.
Default: (64, 16).
dec_num_stages_non_bottleneck (Tuple[int]): Number of NonBottleneck1d
blocks after each UpsamplerBlock in the decoder.
Default: (2, 2).
dec_non_bottleneck_channels (Tuple[int]): Number of channels of the
NonBottleneck1d blocks in each decoder stage.
Default: (64, 16).
dropout_ratio (float): Probability of an element to be zeroed.
Default: 0.1.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN', requires_grad=True).
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""

def __init__(self,
in_channels=3,
enc_downsample_channels=(16, 64, 128),
enc_num_stages_non_bottleneck=(5, 8),
enc_non_bottleneck_dilations=(2, 4, 8, 16),
enc_non_bottleneck_channels=(64, 128),
dec_upsample_channels=(64, 16),
dec_num_stages_non_bottleneck=(2, 2),
dec_non_bottleneck_channels=(64, 16),
dropout_ratio=0.1,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='ReLU'),
init_cfg=None):

super(ERFNet, self).__init__(init_cfg=init_cfg)
assert len(enc_downsample_channels) \
== len(dec_upsample_channels) + 1, 'Number of downsample \
blocks of encoder does not \
match number of upsample blocks of decoder!'
assert len(enc_downsample_channels) \
== len(enc_num_stages_non_bottleneck) + 1, 'Number of \
downsample blocks of encoder does not match \
number of Non-bottleneck stages of encoder!'
assert len(enc_downsample_channels) \
== len(enc_non_bottleneck_channels) + 1, 'Number of \
downsample blocks of encoder does not match \
number of channel settings of Non-bottleneck blocks of encoder!'
assert enc_num_stages_non_bottleneck[-1] \
% len(enc_non_bottleneck_dilations) == 0, 'Number of \
Non-bottleneck blocks in the last encoder stage must be \
a multiple of the number of dilation rates!'
assert len(dec_upsample_channels) \
== len(dec_num_stages_non_bottleneck), 'Number of \
upsample blocks of decoder does not match \
number of Non-bottleneck stages of decoder!'
assert len(dec_num_stages_non_bottleneck) \
== len(dec_non_bottleneck_channels), 'Number of \
Non-bottleneck stages of decoder does not match \
number of channel settings of Non-bottleneck blocks of decoder!'

self.in_channels = in_channels
self.enc_downsample_channels = enc_downsample_channels
self.enc_num_stages_non_bottleneck = enc_num_stages_non_bottleneck
self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
self.dec_upsample_channels = dec_upsample_channels
self.dec_num_stages_non_bottleneck = dec_num_stages_non_bottleneck
self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
self.dropout_ratio = dropout_ratio

self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()

self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg

self.encoder.append(
DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))

for i in range(len(enc_downsample_channels) - 1):
self.encoder.append(
DownsamplerBlock(enc_downsample_channels[i],
enc_downsample_channels[i + 1]))
# The last stage of the encoder consists of dilated NonBottleneck1d blocks.
if i == len(enc_downsample_channels) - 2:
iteration_times = int(enc_num_stages_non_bottleneck[-1] /
len(enc_non_bottleneck_dilations))
for j in range(iteration_times):
for k in range(len(enc_non_bottleneck_dilations)):
self.encoder.append(
NonBottleneck1d(enc_downsample_channels[-1],
self.dropout_ratio,
enc_non_bottleneck_dilations[k]))
else:
for j in range(enc_num_stages_non_bottleneck[i]):
self.encoder.append(
NonBottleneck1d(enc_downsample_channels[i + 1],
self.dropout_ratio))

for i in range(len(dec_upsample_channels)):
if i == 0:
self.decoder.append(
UpsamplerBlock(enc_downsample_channels[-1],
dec_non_bottleneck_channels[i]))
else:
self.decoder.append(
UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
dec_non_bottleneck_channels[i]))
for j in range(dec_num_stages_non_bottleneck[i]):
self.decoder.append(
NonBottleneck1d(dec_non_bottleneck_channels[i]))

def forward(self, x):
for enc in self.encoder:
x = enc(x)
for dec in self.decoder:
x = dec(x)
return [x]
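
With the default arguments, the three DownsamplerBlocks reduce the input to 1/8 resolution and the two UpsamplerBlocks bring it back to 1/2, so the backbone emits a single 16-channel map at half the input size, matching in_channels=16 of ERFHead in the config above. A minimal smoke test, assuming this PR is installed (input size and output are illustrative):

import torch
from mmseg.models.backbones import ERFNet

model = ERFNet()  # defaults mirror configs/_base_/models/erfnet.py
model.eval()
with torch.no_grad():
    out = model(torch.rand(1, 3, 512, 1024))
print(out[0].shape)  # torch.Size([1, 16, 256, 512])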
4 changes: 3 additions & 1 deletion mmseg/models/decode_heads/__init__.py
@@ -9,6 +9,7 @@
from .dpt_head import DPTHead
from .ema_head import EMAHead
from .enc_head import EncHead
from .erf_head import ERFHead
from .fcn_head import FCNHead
from .fpn_head import FPNHead
from .gc_head import GCHead
@@ -31,5 +32,6 @@
'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead',
'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead',
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead'
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead',
'ERFHead'
]