From d271d88e3cacce5dc836abd0701e45e0f1f06812 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Thu, 14 Oct 2021 00:04:59 +0800 Subject: [PATCH 01/13] first commit --- configs/_base_/models/erfnet.py | 28 ++ .../erfnet_4x4_1024x1024_160k_cityscapes.py | 4 + mmseg/models/backbones/__init__.py | 3 +- mmseg/models/backbones/erfnet.py | 345 ++++++++++++++++++ mmseg/models/decode_heads/__init__.py | 4 +- mmseg/models/decode_heads/erf_head.py | 33 ++ .../test_models/test_backbones/test_erfnet.py | 146 ++++++++ tests/test_models/test_heads/test_erf_head.py | 17 + 8 files changed, 578 insertions(+), 2 deletions(-) create mode 100644 configs/_base_/models/erfnet.py create mode 100644 configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py create mode 100644 mmseg/models/backbones/erfnet.py create mode 100644 mmseg/models/decode_heads/erf_head.py create mode 100644 tests/test_models/test_backbones/test_erfnet.py create mode 100644 tests/test_models/test_heads/test_erf_head.py diff --git a/configs/_base_/models/erfnet.py b/configs/_base_/models/erfnet.py new file mode 100644 index 0000000000..9b8de4e77a --- /dev/null +++ b/configs/_base_/models/erfnet.py @@ -0,0 +1,28 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='ERFNet', + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + init_cfg=None), + decode_head=dict( + type='ERFHead', + in_channels=16, + channels=19, + num_classes=19, + norm_cfg=norm_cfg, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git 
a/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py b/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py new file mode 100644 index 0000000000..4f21accb9c --- /dev/null +++ b/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py @@ -0,0 +1,4 @@ +_base_ = [ + '../_base_/models/erfnet.py', '../_base_/datasets/cityscapes_1024x1024.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py index 6d320323b8..a196418c00 100644 --- a/mmseg/models/backbones/__init__.py +++ b/mmseg/models/backbones/__init__.py @@ -2,6 +2,7 @@ from .bisenetv1 import BiSeNetV1 from .bisenetv2 import BiSeNetV2 from .cgnet import CGNet +from .erfnet import ERFNet from .fast_scnn import FastSCNN from .hrnet import HRNet from .icnet import ICNet @@ -19,5 +20,5 @@ 'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN', 'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3', 'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', - 'BiSeNetV1', 'BiSeNetV2', 'ICNet' + 'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'ERFNet' ] diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py new file mode 100644 index 0000000000..ad86ef0ec5 --- /dev/null +++ b/mmseg/models/backbones/erfnet.py @@ -0,0 +1,345 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +import torch.nn as nn +from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmcv.runner import BaseModule + +from mmseg.ops import resize +from ..builder import BACKBONES + + +class DownsamplerBlock(BaseModule): + """Downsampler block of ERFNet. + + This module is a little different from basical ConvModule. Concatenation + of Conv and MaxPool will be used before Batch Norm. + + Args: + in_channels (int): Number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. 
+ norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(DownsamplerBlock, self).__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = build_conv_layer( + self.conv_cfg, + in_channels, + out_channels - in_channels, + kernel_size=3, + stride=2, + padding=1, + bias=True) + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + conv_out = self.conv(input) + pool_out = self.pool(input) + pool_out = resize( + input=pool_out, + size=conv_out.size()[2:], + mode='bilinear', + align_corners=False) + output = torch.cat([conv_out, pool_out], 1) + output = self.bn(output) + output = self.act(output) + return output + + +class non_bottleneck_1d(BaseModule): + """Non-bottleneck block of ERFNet. + + Args: + channels (int): Number of channels in Non-bottleneck block. + drop_rate (float): Probability of an element to be zeroed. + Default 0. + dilation (int): Dilation parameter for last two conv layers. + Default 1. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. 
+ """ + + def __init__(self, + channels, + drop_rate=0, + dilation=1, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(non_bottleneck_1d, self).__init__(init_cfg=init_cfg) + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.conv3x1_1 = build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=(1, 0), + bias=True) + + self.conv1x3_1 = build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=(0, 1), + bias=True) + + self.bn1 = build_norm_layer(self.norm_cfg, channels)[1] + + self.conv3x1_2 = build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=(1 * dilation, 0), + bias=True, + dilation=(dilation, 1)) + + self.conv1x3_2 = build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=(0, 1 * dilation), + bias=True, + dilation=(1, dilation)) + + self.bn2 = build_norm_layer(self.norm_cfg, channels)[1] + self.act = build_activation_layer(self.act_cfg) + self.dropout = nn.Dropout(p=drop_rate) + + def forward(self, input): + output = self.conv3x1_1(input) + output = self.act(output) + output = self.conv1x3_1(output) + output = self.bn1(output) + output = self.act(output) + + output = self.conv3x1_2(output) + output = self.act(output) + output = self.conv1x3_2(output) + output = self.bn2(output) + + output = self.dropout(output) + output = self.act(output + input) + return output + + +class UpsamplerBlock(BaseModule): + """Upsampler block of ERFNet. + + Args: + in_channels (int): Number of input channels. + out_channels (int): The number of output channels. + conv_cfg (dict | None): Config of conv layers. + Default: None. + norm_cfg (dict | None): Config of norm layers. + Default: dict(type='BN'). + act_cfg (dict): Config of activation layers. + Default: dict(type='ReLU'). 
+ init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None. + """ + + def __init__(self, + in_channels, + out_channels, + conv_cfg=None, + norm_cfg=dict(type='BN', eps=1e-3), + act_cfg=dict(type='ReLU'), + init_cfg=None): + super(UpsamplerBlock, self).__init__(init_cfg=init_cfg) + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + bias=True) + self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] + self.act = build_activation_layer(self.act_cfg) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + output = self.act(output) + return output + + +@BACKBONES.register_module() +class ERFNet(BaseModule): + """ERFNet backbone. + + This backbone is the implementation of `ERFNet: Efficient Residual + Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Args: + in_channels (int): The number of channels of input + image. Default: 3. + enc_downsample_channels (Tuple[int]): Size of channel + numbers of various Downsampler block in encoder. + Default: (16, 64, 128). + enc_num_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in encoder. + Default: (5, 8). + enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each + stage of Non-bottleneck block of encoder. + Default: (1, 2, 4, 8, 16). + enc_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in encoder. + Default: (64, 128). + dec_upsample_channels (Tuple[int]): Size of channel numbers of + various Deconvolution block in decoder. + Default: (64, 16). + dec_num_stages_non_bottleneck (Tuple[int]): Number of stages of + Non-bottleneck block in decoder. + Default: (2, 2). + dec_non_bottleneck_channels (Tuple[int]): Size of channel + numbers of various Non-bottleneck block in decoder. 
+ Default: (64, 16). + drop_rate (float): Probability of an element to be zeroed. + Default 0.1. + """ + + def __init__(self, + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + act_cfg=dict(type='ReLU'), + init_cfg=None): + + super(ERFNet, self).__init__(init_cfg=init_cfg) + assert len(enc_downsample_channels) \ + == len(dec_upsample_channels)+1, 'Number of downsample\ + block of encoder does not \ + match number of upsample block of decoder!' + assert len(enc_downsample_channels) \ + == len(enc_num_stages_non_bottleneck)+1, 'Number of \ + downsample block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(enc_downsample_channels) \ + == len(enc_non_bottleneck_channels)+1, 'Number of \ + downsample block of encoder does not match \ + number of channels of Non-bottleneck block of encoder!' + assert enc_num_stages_non_bottleneck[-1] \ + % len(enc_non_bottleneck_dilations) == 0, 'Number of \ + Non-bottleneck block of encoder does not match \ + number of Non-bottleneck block of encoder!' + assert len(dec_upsample_channels) \ + == len(dec_num_stages_non_bottleneck), 'Number of \ + upsample block of decoder does not match \ + number of Non-bottleneck block of decoder!' + assert len(dec_num_stages_non_bottleneck) \ + == len(dec_non_bottleneck_channels), 'Number of \ + Non-bottleneck block of decoder does not match \ + number of channels of Non-bottleneck block of decoder!' 
+ + self.in_channels = in_channels + self.enc_downsample_channels = enc_downsample_channels + self.enc_num_stages_non_bottleneck = enc_num_stages_non_bottleneck + self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations + self.enc_non_bottleneck_channels = enc_non_bottleneck_channels + self.dec_upsample_channels = dec_upsample_channels + self.dec_num_stages_non_bottleneck = dec_num_stages_non_bottleneck + self.dec_non_bottleneck_channels = dec_non_bottleneck_channels + self.dropout_ratio = dropout_ratio + + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + + self.encoder.append( + DownsamplerBlock(self.in_channels, enc_downsample_channels[0])) + + for i in range(len(enc_downsample_channels) - 1): + self.encoder.append( + DownsamplerBlock(enc_downsample_channels[i], + enc_downsample_channels[i + 1])) + # Last part of encoder is some dilated non_bottleneck_1d blocks. + if i == len(enc_downsample_channels) - 2: + iteration_times = int(enc_num_stages_non_bottleneck[-1] / + len(enc_non_bottleneck_dilations)) + for j in range(iteration_times): + for k in range(len(enc_non_bottleneck_dilations)): + self.encoder.append( + non_bottleneck_1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) + else: + for j in range(enc_num_stages_non_bottleneck[i]): + self.encoder.append( + non_bottleneck_1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) + + for i in range(len(dec_upsample_channels)): + if i == 0: + self.decoder.append( + UpsamplerBlock(enc_downsample_channels[-1], + dec_non_bottleneck_channels[i])) + for j in range(dec_num_stages_non_bottleneck[i]): + self.decoder.append( + non_bottleneck_1d(dec_non_bottleneck_channels[i])) + else: + self.decoder.append( + UpsamplerBlock(dec_non_bottleneck_channels[i - 1], + dec_non_bottleneck_channels[i])) + for j in range(dec_num_stages_non_bottleneck[i]): + self.decoder.append( + 
non_bottleneck_1d(dec_non_bottleneck_channels[i])) + + def forward(self, x): + for enc in self.encoder: + x = enc(x) + for dec in self.decoder: + x = dec(x) + return x diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index 14a2b2d6f1..050ce5ea71 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -9,6 +9,7 @@ from .dpt_head import DPTHead from .ema_head import EMAHead from .enc_head import EncHead +from .erf_head import ERFHead from .fcn_head import FCNHead from .fpn_head import FPNHead from .gc_head import GCHead @@ -31,5 +32,6 @@ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', - 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead' + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead', + 'ERFHead' ] diff --git a/mmseg/models/decode_heads/erf_head.py b/mmseg/models/decode_heads/erf_head.py new file mode 100644 index 0000000000..9af5d1fe42 --- /dev/null +++ b/mmseg/models/decode_heads/erf_head.py @@ -0,0 +1,33 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.nn as nn + +from ..builder import HEADS +from .decode_head import BaseDecodeHead + + +@HEADS.register_module() +class ERFHead(BaseDecodeHead): + """ERFNet backbone. + + This decoder head is the implementation of `ERFNet: Efficient + Residual Factorized ConvNet for Real-time SemanticSegmentation + `_. + + Actually it is one ConvTranspose2d operation. 
+ """ + + def __init__(self, **kwargs): + super(ERFHead, self).__init__(**kwargs) + self.output_conv = nn.ConvTranspose2d( + in_channels=self.in_channels, + out_channels=self.channels, + kernel_size=2, + stride=2, + padding=0, + output_padding=0, + bias=True) + + def forward(self, inputs): + """Forward function.""" + output = self.output_conv(inputs) + return output diff --git a/tests/test_models/test_backbones/test_erfnet.py b/tests/test_models/test_backbones/test_erfnet.py new file mode 100644 index 0000000000..f5bc6f4346 --- /dev/null +++ b/tests/test_models/test_backbones/test_erfnet.py @@ -0,0 +1,146 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import pytest +import torch + +from mmseg.models.backbones import ERFNet +from mmseg.models.backbones.erfnet import (DownsamplerBlock, UpsamplerBlock, + non_bottleneck_1d) + + +def test_erfnet_backbone(): + # Test ERFNet Standard Forward. + model = ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + model.init_weights() + model.train() + batch_size = 2 + imgs = torch.randn(batch_size, 3, 256, 512) + output = model(imgs) + + # output for segment Head + assert output.shape == torch.Size([batch_size, 16, 128, 256]) + + # Test input with rare shape + batch_size = 2 + imgs = torch.randn(batch_size, 3, 527, 279) + output = model(imgs) + assert len(output) == batch_size + + with pytest.raises(AssertionError): + # Number of encoder downsample block and decoder upsample block. 
+ ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(128, 64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + with pytest.raises(AssertionError): + # Number of encoder downsample block and encoder Non-bottleneck block. + ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8, 10), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + with pytest.raises(AssertionError): + # Number of encoder downsample block and + # channels of encoder Non-bottleneck block. + ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128, 256), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + + with pytest.raises(AssertionError): + # Number of encoder Non-bottleneck block and number of its channels. + ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8, 3), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + with pytest.raises(AssertionError): + # Number of decoder upsample block and decoder Non-bottleneck block. 
+ ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2, 3), + dec_non_bottleneck_channels=(64, 16), + dropout_ratio=0.1, + ) + with pytest.raises(AssertionError): + # Number of decoder Non-bottleneck block and number of its channels. + ERFNet( + in_channels=3, + enc_downsample_channels=(16, 64, 128), + enc_num_stages_non_bottleneck=(5, 8), + enc_non_bottleneck_dilations=(2, 4, 8, 16), + enc_non_bottleneck_channels=(64, 128), + dec_upsample_channels=(64, 16), + dec_num_stages_non_bottleneck=(2, 2), + dec_non_bottleneck_channels=(64, 16, 8), + dropout_ratio=0.1, + ) + + +def test_erfnet_downsampler_block(): + x_db = DownsamplerBlock(16, 64) + assert x_db.conv.in_channels == 16 + assert x_db.conv.out_channels == 48 + assert len(x_db.bn.weight) == 64 + assert x_db.pool.kernel_size == 2 + assert x_db.pool.stride == 2 + + +def test_erfnet_non_bottleneck_1d(): + x_nb1d = non_bottleneck_1d(16, 0, 1) + assert x_nb1d.conv3x1_1.in_channels == 16 + assert x_nb1d.conv3x1_1.out_channels == 16 + assert x_nb1d.conv1x3_1.in_channels == 16 + assert x_nb1d.conv1x3_1.out_channels == 16 + assert x_nb1d.conv3x1_2.in_channels == 16 + assert x_nb1d.conv3x1_2.out_channels == 16 + assert x_nb1d.conv1x3_2.in_channels == 16 + assert x_nb1d.conv1x3_2.out_channels == 16 + assert x_nb1d.dropout.p == 0 + + +def test_erfnet_upsampler_block(): + x_ub = UpsamplerBlock(64, 16) + assert x_ub.conv.in_channels == 64 + assert x_ub.conv.out_channels == 16 + assert len(x_ub.bn.weight) == 16 diff --git a/tests/test_models/test_heads/test_erf_head.py b/tests/test_models/test_heads/test_erf_head.py new file mode 100644 index 0000000000..c35f4fb4d8 --- /dev/null +++ b/tests/test_models/test_heads/test_erf_head.py @@ -0,0 +1,17 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import torch + +from mmseg.models.decode_heads import ERFHead +from .utils import to_cuda + + +def test_erf_head(): + head = ERFHead(in_channels=16, channels=19, num_classes=19) + assert head.output_conv.in_channels == 16 + assert head.output_conv.out_channels == 19 + + inputs = torch.randn(1, 16, 45, 45).cuda() + if torch.cuda.is_available(): + head, inputs = to_cuda(head, inputs) + outputs = head(inputs) + assert outputs.shape == (1, head.num_classes, 90, 90) From 8512be7ec629205b011800e7885f09e751cf7f5d Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Thu, 14 Oct 2021 00:39:43 +0800 Subject: [PATCH 02/13] Fixing Unittest Error --- mmseg/models/backbones/erfnet.py | 2 +- mmseg/models/decode_heads/erf_head.py | 3 ++- tests/test_models/test_backbones/test_erfnet.py | 4 ++-- tests/test_models/test_heads/test_erf_head.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index ad86ef0ec5..170f0416be 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -342,4 +342,4 @@ def forward(self, x): x = enc(x) for dec in self.decoder: x = dec(x) - return x + return [x] diff --git a/mmseg/models/decode_heads/erf_head.py b/mmseg/models/decode_heads/erf_head.py index 9af5d1fe42..f8fb46d9a9 100644 --- a/mmseg/models/decode_heads/erf_head.py +++ b/mmseg/models/decode_heads/erf_head.py @@ -29,5 +29,6 @@ def __init__(self, **kwargs): def forward(self, inputs): """Forward function.""" - output = self.output_conv(inputs) + x = self._transform_inputs(inputs) + output = self.output_conv(x) return output diff --git a/tests/test_models/test_backbones/test_erfnet.py b/tests/test_models/test_backbones/test_erfnet.py index f5bc6f4346..72b086c8ba 100644 --- a/tests/test_models/test_backbones/test_erfnet.py +++ b/tests/test_models/test_backbones/test_erfnet.py @@ -27,13 +27,13 @@ def test_erfnet_backbone(): output = model(imgs) # output for segment Head - assert output.shape 
== torch.Size([batch_size, 16, 128, 256]) + assert output[0].shape == torch.Size([batch_size, 16, 128, 256]) # Test input with rare shape batch_size = 2 imgs = torch.randn(batch_size, 3, 527, 279) output = model(imgs) - assert len(output) == batch_size + assert len(output[0]) == batch_size with pytest.raises(AssertionError): # Number of encoder downsample block and decoder upsample block. diff --git a/tests/test_models/test_heads/test_erf_head.py b/tests/test_models/test_heads/test_erf_head.py index c35f4fb4d8..6e44ab5b54 100644 --- a/tests/test_models/test_heads/test_erf_head.py +++ b/tests/test_models/test_heads/test_erf_head.py @@ -10,7 +10,7 @@ def test_erf_head(): assert head.output_conv.in_channels == 16 assert head.output_conv.out_channels == 19 - inputs = torch.randn(1, 16, 45, 45).cuda() + inputs = [torch.randn(1, 16, 45, 45)] if torch.cuda.is_available(): head, inputs = to_cuda(head, inputs) outputs = head(inputs) From 17fbaef2689d3328261d319c3714a02292e757c7 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Mon, 18 Oct 2021 22:47:25 +0800 Subject: [PATCH 03/13] first refactory of ERFNet --- mmseg/models/backbones/erfnet.py | 40 +++++++++---------- .../test_models/test_backbones/test_erfnet.py | 6 +-- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index 170f0416be..ff4d03be7c 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -12,11 +12,11 @@ class DownsamplerBlock(BaseModule): """Downsampler block of ERFNet. This module is a little different from basical ConvModule. Concatenation - of Conv and MaxPool will be used before Batch Norm. + of Conv and MaxPool will be used before BatchNorm. Args: in_channels (int): Number of input channels. - out_channels (int): The number of output channels. + out_channels (int): Number of output channels. conv_cfg (dict | None): Config of conv layers. Default: None. 
norm_cfg (dict | None): Config of norm layers. @@ -45,8 +45,7 @@ def __init__(self, out_channels - in_channels, kernel_size=3, stride=2, - padding=1, - bias=True) + padding=1) self.pool = nn.MaxPool2d(kernel_size=2, stride=2) self.bn = build_norm_layer(self.norm_cfg, out_channels)[1] self.act = build_activation_layer(self.act_cfg) @@ -65,14 +64,14 @@ def forward(self, input): return output -class non_bottleneck_1d(BaseModule): +class NonBottleneck1d(BaseModule): """Non-bottleneck block of ERFNet. Args: channels (int): Number of channels in Non-bottleneck block. drop_rate (float): Probability of an element to be zeroed. Default 0. - dilation (int): Dilation parameter for last two conv layers. + dilation (int): Dilation rate for last two conv layers. Default 1. conv_cfg (dict | None): Config of conv layers. Default: None. @@ -92,7 +91,7 @@ def __init__(self, norm_cfg=dict(type='BN', eps=1e-3), act_cfg=dict(type='ReLU'), init_cfg=None): - super(non_bottleneck_1d, self).__init__(init_cfg=init_cfg) + super(NonBottleneck1d, self).__init__(init_cfg=init_cfg) self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg @@ -133,7 +132,7 @@ def __init__(self, channels, kernel_size=(1, 3), stride=1, - padding=(0, 1 * dilation), + padding=(0, dilation), bias=True, dilation=(1, dilation)) @@ -163,7 +162,7 @@ class UpsamplerBlock(BaseModule): Args: in_channels (int): Number of input channels. - out_channels (int): The number of output channels. + out_channels (int): Number of output channels. conv_cfg (dict | None): Config of conv layers. Default: None. norm_cfg (dict | None): Config of norm layers. @@ -223,7 +222,7 @@ class ERFNet(BaseModule): Default: (5, 8). enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each stage of Non-bottleneck block of encoder. - Default: (1, 2, 4, 8, 16). + Default: (2, 4, 8, 16). enc_non_bottleneck_channels (Tuple[int]): Size of channel numbers of various Non-bottleneck block in encoder. Default: (64, 128). 
@@ -305,37 +304,34 @@ def __init__(self, self.encoder.append( DownsamplerBlock(enc_downsample_channels[i], enc_downsample_channels[i + 1])) - # Last part of encoder is some dilated non_bottleneck_1d blocks. + # Last part of encoder is some dilated NonBottleneck1d blocks. if i == len(enc_downsample_channels) - 2: iteration_times = int(enc_num_stages_non_bottleneck[-1] / len(enc_non_bottleneck_dilations)) for j in range(iteration_times): for k in range(len(enc_non_bottleneck_dilations)): self.encoder.append( - non_bottleneck_1d(enc_downsample_channels[-1], - self.dropout_ratio, - enc_non_bottleneck_dilations[k])) + NonBottleneck1d(enc_downsample_channels[-1], + self.dropout_ratio, + enc_non_bottleneck_dilations[k])) else: for j in range(enc_num_stages_non_bottleneck[i]): self.encoder.append( - non_bottleneck_1d(enc_downsample_channels[i + 1], - self.dropout_ratio)) + NonBottleneck1d(enc_downsample_channels[i + 1], + self.dropout_ratio)) for i in range(len(dec_upsample_channels)): if i == 0: self.decoder.append( UpsamplerBlock(enc_downsample_channels[-1], dec_non_bottleneck_channels[i])) - for j in range(dec_num_stages_non_bottleneck[i]): - self.decoder.append( - non_bottleneck_1d(dec_non_bottleneck_channels[i])) else: self.decoder.append( UpsamplerBlock(dec_non_bottleneck_channels[i - 1], dec_non_bottleneck_channels[i])) - for j in range(dec_num_stages_non_bottleneck[i]): - self.decoder.append( - non_bottleneck_1d(dec_non_bottleneck_channels[i])) + for j in range(dec_num_stages_non_bottleneck[i]): + self.decoder.append( + NonBottleneck1d(dec_non_bottleneck_channels[i])) def forward(self, x): for enc in self.encoder: diff --git a/tests/test_models/test_backbones/test_erfnet.py b/tests/test_models/test_backbones/test_erfnet.py index 72b086c8ba..60b392cc2c 100644 --- a/tests/test_models/test_backbones/test_erfnet.py +++ b/tests/test_models/test_backbones/test_erfnet.py @@ -3,8 +3,8 @@ import torch from mmseg.models.backbones import ERFNet -from 
mmseg.models.backbones.erfnet import (DownsamplerBlock, UpsamplerBlock, - non_bottleneck_1d) +from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d, + UpsamplerBlock) def test_erfnet_backbone(): @@ -127,7 +127,7 @@ def test_erfnet_downsampler_block(): def test_erfnet_non_bottleneck_1d(): - x_nb1d = non_bottleneck_1d(16, 0, 1) + x_nb1d = NonBottleneck1d(16, 0, 1) assert x_nb1d.conv3x1_1.in_channels == 16 assert x_nb1d.conv3x1_1.out_channels == 16 assert x_nb1d.conv1x3_1.in_channels == 16 From d196d426dd01150eb5ad169b4596648eb829561e Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Tue, 26 Oct 2021 03:08:32 +0800 Subject: [PATCH 04/13] Refactorying NonBottleneck1d Module --- mmseg/models/backbones/erfnet.py | 93 ++++++++----------- .../test_models/test_backbones/test_erfnet.py | 18 ++-- 2 files changed, 48 insertions(+), 63 deletions(-) diff --git a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index ff4d03be7c..87ac64dd6b 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -96,63 +96,48 @@ def __init__(self, self.conv_cfg = conv_cfg self.norm_cfg = norm_cfg self.act_cfg = act_cfg - self.conv3x1_1 = build_conv_layer( - self.conv_cfg, - channels, - channels, - kernel_size=(3, 1), - stride=1, - padding=(1, 0), - bias=True) - - self.conv1x3_1 = build_conv_layer( - self.conv_cfg, - channels, - channels, - kernel_size=(1, 3), - stride=1, - padding=(0, 1), - bias=True) - - self.bn1 = build_norm_layer(self.norm_cfg, channels)[1] - - self.conv3x1_2 = build_conv_layer( - self.conv_cfg, - channels, - channels, - kernel_size=(3, 1), - stride=1, - padding=(1 * dilation, 0), - bias=True, - dilation=(dilation, 1)) - - self.conv1x3_2 = build_conv_layer( - self.conv_cfg, - channels, - channels, - kernel_size=(1, 3), - stride=1, - padding=(0, dilation), - bias=True, - dilation=(1, dilation)) - - self.bn2 = build_norm_layer(self.norm_cfg, channels)[1] self.act = build_activation_layer(self.act_cfg) - 
self.dropout = nn.Dropout(p=drop_rate) - - def forward(self, input): - output = self.conv3x1_1(input) - output = self.act(output) - output = self.conv1x3_1(output) - output = self.bn1(output) - output = self.act(output) - output = self.conv3x1_2(output) - output = self.act(output) - output = self.conv1x3_2(output) - output = self.bn2(output) + self.convs_layer = nn.ModuleList() + for conv_layer in range(2): + conv_first_padding = (1, 0) if conv_layer == 0 else (1 * dilation, + 0) + conv_first_dilation = 1 if conv_layer == 0 else (dilation, 1) + conv_second_padding = (0, 1) if conv_layer == 0 else (0, dilation) + conv_second_dilation = 1 if conv_layer == 0 else (1, dilation) + + self.convs_layer.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(3, 1), + stride=1, + padding=conv_first_padding, + bias=True, + dilation=conv_first_dilation)) + self.convs_layer.append(self.act) + self.convs_layer.append( + build_conv_layer( + self.conv_cfg, + channels, + channels, + kernel_size=(1, 3), + stride=1, + padding=conv_second_padding, + bias=True, + dilation=conv_second_dilation)) + self.convs_layer.append( + build_norm_layer(self.norm_cfg, channels)[1]) + if conv_layer == 0: + self.convs_layer.append(self.act) + else: + self.convs_layer.append(nn.Dropout(p=drop_rate)) - output = self.dropout(output) + def forward(self, input): + output = input + for op in self.convs_layer: + output = op(output) output = self.act(output + input) return output diff --git a/tests/test_models/test_backbones/test_erfnet.py b/tests/test_models/test_backbones/test_erfnet.py index 60b392cc2c..a963cd10f0 100644 --- a/tests/test_models/test_backbones/test_erfnet.py +++ b/tests/test_models/test_backbones/test_erfnet.py @@ -128,15 +128,15 @@ def test_erfnet_downsampler_block(): def test_erfnet_non_bottleneck_1d(): x_nb1d = NonBottleneck1d(16, 0, 1) - assert x_nb1d.conv3x1_1.in_channels == 16 - assert x_nb1d.conv3x1_1.out_channels == 16 - assert 
x_nb1d.conv1x3_1.in_channels == 16 - assert x_nb1d.conv1x3_1.out_channels == 16 - assert x_nb1d.conv3x1_2.in_channels == 16 - assert x_nb1d.conv3x1_2.out_channels == 16 - assert x_nb1d.conv1x3_2.in_channels == 16 - assert x_nb1d.conv1x3_2.out_channels == 16 - assert x_nb1d.dropout.p == 0 + assert x_nb1d.convs_layer[0].in_channels == 16 + assert x_nb1d.convs_layer[0].out_channels == 16 + assert x_nb1d.convs_layer[2].in_channels == 16 + assert x_nb1d.convs_layer[2].out_channels == 16 + assert x_nb1d.convs_layer[5].in_channels == 16 + assert x_nb1d.convs_layer[5].out_channels == 16 + assert x_nb1d.convs_layer[7].in_channels == 16 + assert x_nb1d.convs_layer[7].out_channels == 16 + assert x_nb1d.convs_layer[9].p == 0 def test_erfnet_upsampler_block(): From 02f4a269ec66df808037ee3b5205cf07c76b3455 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Mon, 8 Nov 2021 15:30:29 +0800 Subject: [PATCH 05/13] uploading models&logs --- .../models/{erfnet.py => fcn_erfnet.py} | 8 +++- configs/erfnet/README.md | 38 +++++++++++++++++++ configs/erfnet/erfnet.yml | 37 ++++++++++++++++++ .../erfnet_4x4_1024x1024_160k_cityscapes.py | 4 -- ...fcn_erfnet_4x4_512x1024_160k_cityscapes.py | 8 ++++ model-index.yml | 1 + 6 files changed, 90 insertions(+), 6 deletions(-) rename configs/_base_/models/{erfnet.py => fcn_erfnet.py} (85%) create mode 100644 configs/erfnet/README.md create mode 100644 configs/erfnet/erfnet.yml delete mode 100644 configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py create mode 100644 configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py diff --git a/configs/_base_/models/erfnet.py b/configs/_base_/models/fcn_erfnet.py similarity index 85% rename from configs/_base_/models/erfnet.py rename to configs/_base_/models/fcn_erfnet.py index 9b8de4e77a..a1367b3732 100644 --- a/configs/_base_/models/erfnet.py +++ b/configs/_base_/models/fcn_erfnet.py @@ -16,11 +16,15 @@ dropout_ratio=0.1, init_cfg=None), decode_head=dict( - type='ERFHead', + type='FCNHead', in_channels=16, 
- channels=19, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, + align_corners=False, loss_decode=dict( type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), # model training and testing settings diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md new file mode 100644 index 0000000000..5966e02c2c --- /dev/null +++ b/configs/erfnet/README.md @@ -0,0 +1,38 @@ +# ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation + +## Introduction + + + +Official Repo + +Code Snippet + +
+ERFNet (T-ITS) + +```latex +@article{romera2017erfnet, + title={Erfnet: Efficient residual factorized convnet for real-time semantic segmentation}, + author={Romera, Eduardo and Alvarez, Jos{\'e} M and Bergasa, Luis M and Arroyo, Roberto}, + journal={IEEE Transactions on Intelligent Transportation Systems}, + volume={19}, + number={1}, + pages={263--272}, + year={2017}, + publisher={IEEE} +} +``` + +
+ +## Results and models + +### Cityscapes + +| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | +| --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| FCN | ERFNet | 512x1024 | 160000 | 16.40 | 2.16 | 71.4 | 72.96 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334.log.json) | +Note: + +- Last deconvolution layer in original paper is replaced by normal `FCN` decoder head and upsampling operation. 
diff --git a/configs/erfnet/erfnet.yml b/configs/erfnet/erfnet.yml new file mode 100644 index 0000000000..00f57cd6e3 --- /dev/null +++ b/configs/erfnet/erfnet.yml @@ -0,0 +1,37 @@ +Collections: +- Name: erfnet + Metadata: + Training Data: + - Cityscapes + Paper: + URL: http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf + Title: 'ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation' + README: configs/erfnet/README.md + Code: + URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/erfnet.py#L321 + Version: v0.20.0 + Converted From: + Code: https://github.com/Eromera/erfnet_pytorch +Models: +- Name: fcn_erfnet_4x4_512x1024_160k_cityscapes + In Collection: erfnet + Metadata: + backbone: ERFNet + crop size: (512,1024) + lr schd: 160000 + inference time (ms/im): + - value: 462.96 + hardware: V100 + backend: PyTorch + batch size: 1 + mode: FP32 + resolution: (512,1024) + memory (GB): 16.4 + Results: + - Task: Semantic Segmentation + Dataset: Cityscapes + Metrics: + mIoU: 71.4 + mIoU(ms+flip): 72.96 + Config: configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth diff --git a/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py b/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py deleted file mode 100644 index 4f21accb9c..0000000000 --- a/configs/erfnet/erfnet_4x4_1024x1024_160k_cityscapes.py +++ /dev/null @@ -1,4 +0,0 @@ -_base_ = [ - '../_base_/models/erfnet.py', '../_base_/datasets/cityscapes_1024x1024.py', - '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' -] diff --git a/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py b/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py new file mode 100644 index 0000000000..31f97c5180 --- /dev/null +++ 
b/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py @@ -0,0 +1,8 @@ +_base_ = [ + '../_base_/models/fcn_erfnet.py', '../_base_/datasets/cityscapes.py', + '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' +] +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +data = dict( + samples_per_gpu=4, + workers_per_gpu=4,) diff --git a/model-index.yml b/model-index.yml index 00da8d6a2a..31a14b6be7 100644 --- a/model-index.yml +++ b/model-index.yml @@ -13,6 +13,7 @@ Import: - configs/dpt/dpt.yml - configs/emanet/emanet.yml - configs/encnet/encnet.yml +- configs/erfnet/erfnet.yml - configs/fastfcn/fastfcn.yml - configs/fastscnn/fastscnn.yml - configs/fcn/fcn.yml From f08234251566643c1f18be22fdc9802b5ac35a7c Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Mon, 8 Nov 2021 15:32:59 +0800 Subject: [PATCH 06/13] uploading models&logs --- ...fcn_erfnet_4x4_512x1024_160k_cityscapes.py | 3 +- mmseg/models/decode_heads/__init__.py | 5 +-- mmseg/models/decode_heads/erf_head.py | 34 ------------------- tests/test_models/test_heads/test_erf_head.py | 17 ---------- 4 files changed, 3 insertions(+), 56 deletions(-) delete mode 100644 mmseg/models/decode_heads/erf_head.py delete mode 100644 tests/test_models/test_heads/test_erf_head.py diff --git a/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py b/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py index 31f97c5180..0f6e20d6a2 100644 --- a/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py +++ b/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py @@ -5,4 +5,5 @@ optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) data = dict( samples_per_gpu=4, - workers_per_gpu=4,) + workers_per_gpu=4, +) diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index 050ce5ea71..e18527ca71 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -9,7 +9,6 @@ from .dpt_head import 
DPTHead from .ema_head import EMAHead from .enc_head import EncHead -from .erf_head import ERFHead from .fcn_head import FCNHead from .fpn_head import FPNHead from .gc_head import GCHead @@ -32,6 +31,4 @@ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', - 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead', - 'ERFHead' -] + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead'] diff --git a/mmseg/models/decode_heads/erf_head.py b/mmseg/models/decode_heads/erf_head.py deleted file mode 100644 index f8fb46d9a9..0000000000 --- a/mmseg/models/decode_heads/erf_head.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.nn as nn - -from ..builder import HEADS -from .decode_head import BaseDecodeHead - - -@HEADS.register_module() -class ERFHead(BaseDecodeHead): - """ERFNet backbone. - - This decoder head is the implementation of `ERFNet: Efficient - Residual Factorized ConvNet for Real-time SemanticSegmentation - `_. - - Actually it is one ConvTranspose2d operation. - """ - - def __init__(self, **kwargs): - super(ERFHead, self).__init__(**kwargs) - self.output_conv = nn.ConvTranspose2d( - in_channels=self.in_channels, - out_channels=self.channels, - kernel_size=2, - stride=2, - padding=0, - output_padding=0, - bias=True) - - def forward(self, inputs): - """Forward function.""" - x = self._transform_inputs(inputs) - output = self.output_conv(x) - return output diff --git a/tests/test_models/test_heads/test_erf_head.py b/tests/test_models/test_heads/test_erf_head.py deleted file mode 100644 index 6e44ab5b54..0000000000 --- a/tests/test_models/test_heads/test_erf_head.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. 
-import torch - -from mmseg.models.decode_heads import ERFHead -from .utils import to_cuda - - -def test_erf_head(): - head = ERFHead(in_channels=16, channels=19, num_classes=19) - assert head.output_conv.in_channels == 16 - assert head.output_conv.out_channels == 19 - - inputs = [torch.randn(1, 16, 45, 45)] - if torch.cuda.is_available(): - head, inputs = to_cuda(head, inputs) - outputs = head(inputs) - assert outputs.shape == (1, head.num_classes, 90, 90) From cd7c8147808f44e2a5a3dac6a9691fdbfe87160e Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Tue, 9 Nov 2021 18:05:05 +0800 Subject: [PATCH 07/13] fix partial bugs & typos --- .../models/{fcn_erfnet.py => erfnet_fcn.py} | 0 configs/erfnet/README.md | 2 +- configs/erfnet/erfnet.yml | 6 ++--- ...rfnet_fcn_4x4_512x1024_160k_cityscapes.py} | 3 +-- mmseg/models/backbones/erfnet.py | 27 ++++++++++--------- mmseg/models/decode_heads/__init__.py | 3 ++- 6 files changed, 22 insertions(+), 19 deletions(-) rename configs/_base_/models/{fcn_erfnet.py => erfnet_fcn.py} (100%) rename configs/erfnet/{fcn_erfnet_4x4_512x1024_160k_cityscapes.py => erfnet_fcn_4x4_512x1024_160k_cityscapes.py} (53%) diff --git a/configs/_base_/models/fcn_erfnet.py b/configs/_base_/models/erfnet_fcn.py similarity index 100% rename from configs/_base_/models/fcn_erfnet.py rename to configs/_base_/models/erfnet_fcn.py diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md index 5966e02c2c..ff68d1e782 100644 --- a/configs/erfnet/README.md +++ b/configs/erfnet/README.md @@ -35,4 +35,4 @@ | FCN | ERFNet | 512x1024 | 160000 | 16.40 | 2.16 | 71.4 | 72.96 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth) | 
[log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334.log.json) | Note: -- Last deconvolution layer in original paper is replaced by normal `FCN` decoder head and upsampling operation. +- Last deconvolution layer in the original paper is replaced by a naive `FCN` decoder head and a bilinear upsampling layer. diff --git a/configs/erfnet/erfnet.yml b/configs/erfnet/erfnet.yml index 00f57cd6e3..322d8ceb85 100644 --- a/configs/erfnet/erfnet.yml +++ b/configs/erfnet/erfnet.yml @@ -13,7 +13,7 @@ Collections: Converted From: Code: https://github.com/Eromera/erfnet_pytorch Models: -- Name: fcn_erfnet_4x4_512x1024_160k_cityscapes +- Name: '' In Collection: erfnet Metadata: backbone: ERFNet @@ -33,5 +33,5 @@ Models: Metrics: mIoU: 71.4 mIoU(ms+flip): 72.96 - Config: configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth + Config: '' + Weights: '' diff --git a/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py b/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py similarity index 53% rename from configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py rename to configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py index 0f6e20d6a2..8cb8e51492 100644 --- a/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py +++ b/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py @@ -1,8 +1,7 @@ _base_ = [ - '../_base_/models/fcn_erfnet.py', '../_base_/datasets/cityscapes.py', + '../_base_/models/erfnet_fcn.py', '../_base_/datasets/cityscapes.py', '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' ] -optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) data = dict( samples_per_gpu=4, workers_per_gpu=4, diff --git 
a/mmseg/models/backbones/erfnet.py b/mmseg/models/backbones/erfnet.py index 87ac64dd6b..4f2427461b 100644 --- a/mmseg/models/backbones/erfnet.py +++ b/mmseg/models/backbones/erfnet.py @@ -11,8 +11,9 @@ class DownsamplerBlock(BaseModule): """Downsampler block of ERFNet. - This module is a little different from basical ConvModule. Concatenation - of Conv and MaxPool will be used before BatchNorm. + This module is a little different from basical ConvModule. + The features from Conv and MaxPool layers are + concatenated before BatchNorm. Args: in_channels (int): Number of input channels. @@ -73,6 +74,8 @@ class NonBottleneck1d(BaseModule): Default 0. dilation (int): Dilation rate for last two conv layers. Default 1. + num_conv_layer (int): Number of 3x1 and 1x3 convolution layers. + Default 2. conv_cfg (dict | None): Config of conv layers. Default: None. norm_cfg (dict | None): Config of norm layers. @@ -87,6 +90,7 @@ def __init__(self, channels, drop_rate=0, dilation=1, + num_conv_layer=2, conv_cfg=None, norm_cfg=dict(type='BN', eps=1e-3), act_cfg=dict(type='ReLU'), @@ -99,12 +103,11 @@ def __init__(self, self.act = build_activation_layer(self.act_cfg) self.convs_layer = nn.ModuleList() - for conv_layer in range(2): - conv_first_padding = (1, 0) if conv_layer == 0 else (1 * dilation, - 0) - conv_first_dilation = 1 if conv_layer == 0 else (dilation, 1) - conv_second_padding = (0, 1) if conv_layer == 0 else (0, dilation) - conv_second_dilation = 1 if conv_layer == 0 else (1, dilation) + for conv_layer in range(num_conv_layer): + first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0) + first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1) + second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation) + second_conv_dilation = 1 if conv_layer == 0 else (1, dilation) self.convs_layer.append( build_conv_layer( @@ -113,9 +116,9 @@ def __init__(self, channels, kernel_size=(3, 1), stride=1, - padding=conv_first_padding, + padding=first_conv_padding, 
bias=True, - dilation=conv_first_dilation)) + dilation=first_conv_dilation)) self.convs_layer.append(self.act) self.convs_layer.append( build_conv_layer( @@ -124,9 +127,9 @@ def __init__(self, channels, kernel_size=(1, 3), stride=1, - padding=conv_second_padding, + padding=second_conv_padding, bias=True, - dilation=conv_second_dilation)) + dilation=second_conv_dilation)) self.convs_layer.append( build_norm_layer(self.norm_cfg, channels)[1]) if conv_layer == 0: diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py index e18527ca71..14a2b2d6f1 100644 --- a/mmseg/models/decode_heads/__init__.py +++ b/mmseg/models/decode_heads/__init__.py @@ -31,4 +31,5 @@ 'UPerHead', 'DepthwiseSeparableASPPHead', 'ANNHead', 'DAHead', 'OCRHead', 'EncHead', 'DepthwiseSeparableFCNHead', 'FPNHead', 'EMAHead', 'DNLHead', 'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead', - 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead'] + 'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegformerHead', 'ISAHead' +] From d7c09d4363957e53882f376517ded849d5aadd60 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Fri, 26 Nov 2021 01:17:10 +0800 Subject: [PATCH 08/13] ERFNet --- configs/_base_/models/erfnet.py | 4 ++-- configs/erfnet/README.md | 16 ++++++++++++++-- configs/erfnet/erfnet.yml | 10 +++++----- .../small_erfnet_4x4_512x1024_160k_cityscapes.py | 2 -- 4 files changed, 21 insertions(+), 11 deletions(-) delete mode 100644 configs/erfnet/small_erfnet_4x4_512x1024_160k_cityscapes.py diff --git a/configs/_base_/models/erfnet.py b/configs/_base_/models/erfnet.py index e6f5ca5bae..df195420ef 100644 --- a/configs/_base_/models/erfnet.py +++ b/configs/_base_/models/erfnet.py @@ -19,8 +19,8 @@ type='DepthwiseSeparableFCNHead', in_channels=16, channels=128, - num_convs=2, - concat_input=True, + num_convs=1, + concat_input=False, num_classes=19, norm_cfg=norm_cfg, align_corners=False, diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md 
index ff68d1e782..b58e13d0d7 100644 --- a/configs/erfnet/README.md +++ b/configs/erfnet/README.md @@ -8,6 +8,15 @@ Code Snippet +## Abstract + +Semantic segmentation is a challenging task that addresses most of the perception needs of intelligent vehicles (IVs) in an unified way. Deep neural networks excel at this task, as they can be trained end-to-end to accurately classify multiple object categories in an image at pixel level. However, a good tradeoff between high quality and computational resources is yet not present in the state-of-the-art semantic segmentation approaches, limiting their application in real vehicles. In this paper, we propose a deep architecture that is able to run in real time while providing accurate semantic segmentation. The core of our architecture is a novel layer that uses residual connections and factorized convolutions in order to remain efficient while retaining remarkable accuracy. Our approach is able to run at over 83 FPS in a single Titan X, and 7 FPS in a Jetson TX1 (embedded device). A comprehensive set of experiments on the publicly available Cityscapes data set demonstrates that our system achieves an accuracy that is similar to the state of the art, while being orders of magnitude faster to compute than other architectures that achieve top precision. The resulting tradeoff makes our model an ideal approach for scene understanding in IV applications. The code is publicly available at: https://github.com/Eromera/erfnet. + + +
+ +
+
ERFNet (T-ITS) @@ -32,7 +41,10 @@ | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| FCN | ERFNet | 512x1024 | 160000 | 16.40 | 2.16 | 71.4 | 72.96 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334.log.json) | +| ERFNet | ERFNet | 512x1024 | 160000 | 6.10 | 15.17 | 70.8 | 72.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608-80bcacef.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608.log.json) | + Note: -- Last deconvolution layer in the original paper is replaced by a naive `FCN` decoder head and a bilinear upsampling layer. +- The model is trained from scratch. 
+ +- Last deconvolution layer in the [original paper](https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py#L123) is replaced by a naive `DepthwiseSeparableFCNHead` decoder head and a bilinear upsampling layer, found more effective and efficient. diff --git a/configs/erfnet/erfnet.yml b/configs/erfnet/erfnet.yml index a538fe6bef..06e8c6b27a 100644 --- a/configs/erfnet/erfnet.yml +++ b/configs/erfnet/erfnet.yml @@ -20,18 +20,18 @@ Models: crop size: (512,1024) lr schd: 160000 inference time (ms/im): - - value: 462.96 + - value: 65.92 hardware: V100 backend: PyTorch batch size: 1 mode: FP32 resolution: (512,1024) - memory (GB): 16.4 + memory (GB): 6.1 Results: - Task: Semantic Segmentation Dataset: Cityscapes Metrics: - mIoU: 71.4 - mIoU(ms+flip): 72.96 + mIoU: 70.8 + mIoU(ms+flip): 72.09 Config: configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/fcn_erfnet_4x4_512x1024_160k_cityscapes/fcn_erfnet_4x4_512x1024_160k_cityscapes_20211103_011334-8f691334.pth + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608-80bcacef.pth diff --git a/configs/erfnet/small_erfnet_4x4_512x1024_160k_cityscapes.py b/configs/erfnet/small_erfnet_4x4_512x1024_160k_cityscapes.py deleted file mode 100644 index b000788c03..0000000000 --- a/configs/erfnet/small_erfnet_4x4_512x1024_160k_cityscapes.py +++ /dev/null @@ -1,2 +0,0 @@ -_base_ = './erfnet_4x4_512x1024_160k_cityscapes.py' -model = dict(decode_head=dict(num_convs=1, concat_input=False)) From f17ecc00e1d0b763a04631b0410884cfb08181f4 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Sat, 27 Nov 2021 04:44:46 +0800 Subject: [PATCH 09/13] add ERFNet with FCNHead --- configs/_base_/models/{erfnet.py => erfnet_fcn.py} | 3 ++- configs/erfnet/README.md | 4 ++-- configs/erfnet/erfnet.yml | 14 +++++++------- ... 
=> erfnet_fcn_4x4_512x1024_160k_cityscapes.py} | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) rename configs/_base_/models/{erfnet.py => erfnet_fcn.py} (94%) rename configs/erfnet/{erfnet_4x4_512x1024_160k_cityscapes.py => erfnet_fcn_4x4_512x1024_160k_cityscapes.py} (66%) diff --git a/configs/_base_/models/erfnet.py b/configs/_base_/models/erfnet_fcn.py similarity index 94% rename from configs/_base_/models/erfnet.py rename to configs/_base_/models/erfnet_fcn.py index df195420ef..7f2e9bff8d 100644 --- a/configs/_base_/models/erfnet.py +++ b/configs/_base_/models/erfnet_fcn.py @@ -16,11 +16,12 @@ dropout_ratio=0.1, init_cfg=None), decode_head=dict( - type='DepthwiseSeparableFCNHead', + type='FCNHead', in_channels=16, channels=128, num_convs=1, concat_input=False, + dropout_ratio=0.1, num_classes=19, norm_cfg=norm_cfg, align_corners=False, diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md index b58e13d0d7..06a1b957d1 100644 --- a/configs/erfnet/README.md +++ b/configs/erfnet/README.md @@ -41,10 +41,10 @@ Semantic segmentation is a challenging task that addresses most of the perceptio | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| ERFNet | ERFNet | 512x1024 | 160000 | 6.10 | 15.17 | 70.8 | 72.09 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py) | 
[model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608-80bcacef.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608.log.json) | +| ERFNet | ERFNet | 512x1024 | 160000 | 6.04 | 15.26 | 71.08 | 72.6 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056.log.json) | Note: - The model is trained from scratch. -- Last deconvolution layer in the [original paper](https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py#L123) is replaced by a naive `DepthwiseSeparableFCNHead` decoder head and a bilinear upsampling layer, found more effective and efficient. +- Last deconvolution layer in the [original paper](https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py#L123) is replaced by a naive `FCNHead` decoder head and a bilinear upsampling layer, found more effective and efficient. 
diff --git a/configs/erfnet/erfnet.yml b/configs/erfnet/erfnet.yml index 06e8c6b27a..8dbda2a176 100644 --- a/configs/erfnet/erfnet.yml +++ b/configs/erfnet/erfnet.yml @@ -13,25 +13,25 @@ Collections: Converted From: Code: https://github.com/Eromera/erfnet_pytorch Models: -- Name: erfnet_4x4_512x1024_160k_cityscapes +- Name: erfnet_fcn_4x4_512x1024_160k_cityscapes In Collection: erfnet Metadata: backbone: ERFNet crop size: (512,1024) lr schd: 160000 inference time (ms/im): - - value: 65.92 + - value: 65.53 hardware: V100 backend: PyTorch batch size: 1 mode: FP32 resolution: (512,1024) - memory (GB): 6.1 + memory (GB): 6.04 Results: - Task: Semantic Segmentation Dataset: Cityscapes Metrics: - mIoU: 70.8 - mIoU(ms+flip): 72.09 - Config: configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py - Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_4x4_512x1024_160k_cityscapes/erfnet_4x4_512x1024_160k_cityscapes_20211123_021608-80bcacef.pth + mIoU: 71.08 + mIoU(ms+flip): 72.6 + Config: configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py + Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth diff --git a/configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py b/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py similarity index 66% rename from configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py rename to configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py index 5fed9fc87a..8cb8e51492 100644 --- a/configs/erfnet/erfnet_4x4_512x1024_160k_cityscapes.py +++ b/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py @@ -1,5 +1,5 @@ _base_ = [ - '../_base_/models/erfnet.py', '../_base_/datasets/cityscapes.py', + '../_base_/models/erfnet_fcn.py', '../_base_/datasets/cityscapes.py', '../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py' ] data = dict( From 84bf2d29e2c86b8ced0efa02f7cc1f26f966857e Mon Sep 
17 00:00:00 2001 From: MengzhangLI Date: Sat, 27 Nov 2021 05:44:06 +0800 Subject: [PATCH 10/13] fix typos of ERFNet --- configs/erfnet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md index 06a1b957d1..aacb37a6b6 100644 --- a/configs/erfnet/README.md +++ b/configs/erfnet/README.md @@ -41,7 +41,7 @@ Semantic segmentation is a challenging task that addresses most of the perceptio | Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download | | --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| ERFNet | ERFNet | 512x1024 | 160000 | 6.04 | 15.26 | 71.08 | 72.6 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056.log.json) | +| FCN | ERFNet | 512x1024 | 160000 | 6.04 | 15.26 | 71.08 | 72.6 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth) | 
[log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056.log.json) | Note: From 57ea3b733f479c8ca1e7012eab820241cd01f07b Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Sun, 28 Nov 2021 16:53:40 +0800 Subject: [PATCH 11/13] add name on README.md cover --- README.md | 1 + README_zh-CN.md | 1 + 2 files changed, 2 insertions(+) diff --git a/README.md b/README.md index 8f4d8573c6..698fd3a178 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ Supported backbones: Supported methods: - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn) +- [x] [ERFNet (T'ITS'2017)](configs/erfnet) - [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) diff --git a/README_zh-CN.md b/README_zh-CN.md index fa48aff05c..bfc55e8181 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -69,6 +69,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O 已支持的算法: - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn) +- [x] [ERFNet (T'ITS'2017)](configs/erfnet) - [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) From 4b027433797e0a5184c997973e0521367379747e Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Tue, 30 Nov 2021 20:59:09 +0800 Subject: [PATCH 12/13] change name to T-ITS'2017 --- README.md | 2 +- README_zh-CN.md | 2 +- configs/erfnet/README.md | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 698fd3a178..63e717bc5f 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,7 @@ Supported backbones: Supported methods: - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn) -- [x] [ERFNet (T'ITS'2017)](configs/erfnet) +- [x] [ERFNet (T-ITS'2017)](configs/erfnet) - [x] [UNet (MICCAI'2016/Nat.
Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) diff --git a/README_zh-CN.md b/README_zh-CN.md index bfc55e8181..1a9ea5e2e6 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -69,7 +69,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O 已支持的算法: - [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn) -- [x] [ERFNet (T'ITS'2017)](configs/erfnet) +- [x] [ERFNet (T-ITS'2017)](configs/erfnet) - [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet) - [x] [PSPNet (CVPR'2017)](configs/pspnet) - [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3) diff --git a/configs/erfnet/README.md b/configs/erfnet/README.md index aacb37a6b6..6d7477c37b 100644 --- a/configs/erfnet/README.md +++ b/configs/erfnet/README.md @@ -18,7 +18,7 @@ Semantic segmentation is a challenging task that addresses most of the perceptio
-ERFNet (T-ITS) +ERFNet (T-ITS'2017) ```latex @article{romera2017erfnet, From efee39b20fca40bb9e47bd17b6c86f728e6bed23 Mon Sep 17 00:00:00 2001 From: MengzhangLI Date: Wed, 1 Dec 2021 16:57:00 +0800 Subject: [PATCH 13/13] fix lint error --- configs/erfnet/erfnet.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/erfnet/erfnet.yml b/configs/erfnet/erfnet.yml index 8dbda2a176..f0d8fb7bbf 100644 --- a/configs/erfnet/erfnet.yml +++ b/configs/erfnet/erfnet.yml @@ -26,7 +26,7 @@ Models: batch size: 1 mode: FP32 resolution: (512,1024) - memory (GB): 6.04 + Training Memory (GB): 6.04 Results: - Task: Semantic Segmentation Dataset: Cityscapes