From 9f4c1c1bc281df23028042cdb675dd64509f8a3a Mon Sep 17 00:00:00 2001
From: ETTR123 <740580207@qq.com>
Date: Sun, 12 Dec 2021 17:18:10 +0800
Subject: [PATCH 01/12] add enet

---
 configs/enet/README.md                        |  13 +
 ...net_cityscapes_1024x512_adam_0.002_80k.yml |  47 ++
 paddleseg/models/__init__.py                  |   1 +
 paddleseg/models/enet.py                      | 631 ++++++++++++++++++
 4 files changed, 692 insertions(+)
 create mode 100644 configs/enet/README.md
 create mode 100644 configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml
 create mode 100644 paddleseg/models/enet.py

diff --git a/configs/enet/README.md b/configs/enet/README.md
new file mode 100644
index 0000000000..b04ba2fa40
--- /dev/null
+++ b/configs/enet/README.md
@@ -0,0 +1,13 @@
+# ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation
+
+## Reference
+> Adam Paszke, Abhishek Chaurasia, Sangpil Kim, Eugenio Culurciello. "ENet: A Deep Neural Network Architecture for
+Real-Time Semantic Segmentation." arXiv preprint arXiv:1606.02147(2016).
+
+## Performance
+
+### Cityscapes
+
+| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links |
+|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:|
+|ENet|-|1024x512|80000|58.3%|-|-|[]|
diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml
new file mode 100644
index 0000000000..d0eedbe509
--- /dev/null
+++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml
@@ -0,0 +1,47 @@
+iters: 80000
+batch_size: 8
+
+train_dataset:
+  type: Cityscapes
+  dataset_root: data/cityscapes
+  transforms:
+    - type: ResizeStepScaling
+      min_scale_factor: 0.5
+      max_scale_factor: 2.0
+      scale_step_size: 0.25
+    - type: RandomPaddingCrop
+      crop_size: [1024, 512]
+    - type: RandomHorizontalFlip
+    - type: RandomDistort
+      brightness_range: 0.4
+      contrast_range: 0.4
+      saturation_range: 0.4
+    - type: Normalize
+  mode: train
+
+val_dataset:
+  type: Cityscapes
+  dataset_root: data/cityscapes
+  transforms:
+    - type: Normalize
+  mode: val
+
+model:
+  type: ENet
+  num_classes: 19
+  pretrained: Null
+
+optimizer:
+  type: adam
+  weight_decay: 0.0002
+
+lr_scheduler:
+  type: PolynomialDecay
+  learning_rate: 0.001
+  end_lr: 0
+  power: 0.9
+
+loss:
+  types:
+    - type: CrossEntropyLoss
+  coef: [1]
diff --git a/paddleseg/models/__init__.py b/paddleseg/models/__init__.py
index e7b186dc17..08ea770c18 100644
--- a/paddleseg/models/__init__.py
+++ b/paddleseg/models/__init__.py
@@ -48,3 +48,4 @@ from .segnet import SegNet
 from .hrnet_contrast import HRNetW48Contrast
 from .espnet import ESPNetV2
+from .enet import ENet
diff --git a/paddleseg/models/enet.py b/paddleseg/models/enet.py
new file mode 100644
index 0000000000..7528f93cb1
--- /dev/null
+++ b/paddleseg/models/enet.py
@@ -0,0 +1,631 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddleseg import utils +from paddleseg.models import layers +from paddleseg.cvlibs import manager, param_init + +__all__ = ['ENet'] + + +@manager.MODELS.add_component +class ENet(nn.Layer): + """ + The ENet implementation based on PaddlePaddle. + The original article refers to + Adam Paszke, Abhishek Chaurasia, Sangpil Kim, Eugenio Culurciello, et al."ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation" + (https://arxiv.org/abs/1606.02147). + + Args: + num_classes (int): The unique number of target classes. + pretrained (str, optional): The path or url of pretrained model. Default: None. + encoder_relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: False. + decoder_relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: True. + """ + def __init__(self, + num_classes, + pretrained=None, + encoder_relu=False, + decoder_relu=True): + super(ENet, self).__init__() + + self.numclasses = num_classes + self.initial_block = InitialBlock(3, 16, relu=encoder_relu) + + # Stage 1 - Encoder + self.downsample1_0 = DownsamplingBottleneck(16, + 64, + return_indices=True, + dropout_prob=0.01, + relu=encoder_relu) + self.regular1_1 = RegularBottleneck(64, + padding=1, + dropout_prob=0.01, + relu=encoder_relu) + self.regular1_2 = RegularBottleneck(64, + padding=1, + dropout_prob=0.01, + relu=encoder_relu) + self.regular1_3 = RegularBottleneck(64, + padding=1, + dropout_prob=0.01, + relu=encoder_relu) + self.regular1_4 = RegularBottleneck(64, + padding=1, + dropout_prob=0.01, + relu=encoder_relu) + + # Stage 2 - Encoder + self.downsample2_0 = DownsamplingBottleneck(64, + 128, + return_indices=True, + dropout_prob=0.1, + relu=encoder_relu) + self.regular2_1 = RegularBottleneck(128, + padding=1, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated2_2 = RegularBottleneck(128, + dilation=2, + padding=2, + dropout_prob=0.1, + relu=encoder_relu) + self.asymmetric2_3 = RegularBottleneck(128, + kernel_size=5, + padding=2, + asymmetric=True, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated2_4 = RegularBottleneck(128, + dilation=4, + padding=4, + dropout_prob=0.1, + relu=encoder_relu) + self.regular2_5 = RegularBottleneck(128, + padding=1, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated2_6 = RegularBottleneck(128, + dilation=8, + padding=8, + dropout_prob=0.1, + relu=encoder_relu) + self.asymmetric2_7 = RegularBottleneck(128, + kernel_size=5, + asymmetric=True, + padding=2, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated2_8 = RegularBottleneck(128, + dilation=16, + padding=16, + dropout_prob=0.1, + relu=encoder_relu) + + # Stage 3 - Encoder + self.regular3_0 = RegularBottleneck(128, + padding=1, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated3_1 = RegularBottleneck(128, + dilation=2, + padding=2, + dropout_prob=0.1, + relu=encoder_relu) + self.asymmetric3_2 = RegularBottleneck(128, + kernel_size=5, + padding=2, + asymmetric=True, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated3_3 = RegularBottleneck(128, + dilation=4, + padding=4, + dropout_prob=0.1, + relu=encoder_relu) + self.regular3_4 = RegularBottleneck(128, + padding=1, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated3_5 = RegularBottleneck(128, + dilation=8, + padding=8, + dropout_prob=0.1, + relu=encoder_relu) + self.asymmetric3_6 = RegularBottleneck(128, + kernel_size=5, + 
asymmetric=True, + padding=2, + dropout_prob=0.1, + relu=encoder_relu) + self.dilated3_7 = RegularBottleneck(128, + dilation=16, + padding=16, + dropout_prob=0.1, + relu=encoder_relu) + + # Stage 4 - Decoder + self.upsample4_0 = UpsamplingBottleneck(128, + 64, + dropout_prob=0.1, + relu=decoder_relu) + self.regular4_1 = RegularBottleneck(64, + padding=1, + dropout_prob=0.1, + relu=decoder_relu) + self.regular4_2 = RegularBottleneck(64, + padding=1, + dropout_prob=0.1, + relu=decoder_relu) + + # Stage 5 - Decoder + self.upsample5_0 = UpsamplingBottleneck(64, + 16, + dropout_prob=0.1, + relu=decoder_relu) + self.regular5_1 = RegularBottleneck(16, + padding=1, + dropout_prob=0.1, + relu=decoder_relu) + self.transposed_conv = nn.Conv2DTranspose(16, + num_classes, + kernel_size=3, + stride=2, + padding=1, + bias_attr=False) + + self.pretrained = pretrained + self.init_weight() + + def forward(self, x): + + input_size = x.shape + x = self.initial_block(x) + + # Stage 1 - Encoder + stage1_input_size = x.shape + x, max_indices1_0 = self.downsample1_0(x) + x = self.regular1_1(x) + x = self.regular1_2(x) + x = self.regular1_3(x) + x = self.regular1_4(x) + + # Stage 2 - Encoder + stage2_input_size = x.shape + x, max_indices2_0 = self.downsample2_0(x) + x = self.regular2_1(x) + x = self.dilated2_2(x) + x = self.asymmetric2_3(x) + x = self.dilated2_4(x) + x = self.regular2_5(x) + x = self.dilated2_6(x) + x = self.asymmetric2_7(x) + x = self.dilated2_8(x) + + # Stage 3 - Encoder + x = self.regular3_0(x) + x = self.dilated3_1(x) + x = self.asymmetric3_2(x) + x = self.dilated3_3(x) + x = self.regular3_4(x) + x = self.dilated3_5(x) + x = self.asymmetric3_6(x) + x = self.dilated3_7(x) + + # Stage 4 - Decoder + x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size) + x = self.regular4_1(x) + x = self.regular4_2(x) + + # Stage 5 - Decoder + x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size) + x = self.regular5_1(x) + x = self.transposed_conv(x, output_size=input_size[2:]) + return [x] + + def init_weight(self): + if self.pretrained is not None: + utils.load_pretrained_model(self, self.pretrained) + + +class InitialBlock(nn.Layer): + """ + The initial block is composed of two branches: + 1. a main branch which performs a regular convolution with stride 2; + 2. an extension branch which performs max-pooling. + Doing both operations in parallel and concatenating their results + allows for efficient downsampling and expansion. The main branch + outputs 13 feature maps while the extension branch outputs 3, for a + total of 16 feature maps after concatenation. + Args: + in_channels (int): the number of input channels. + out_channels (int): the number output channels. + kernel_size (int, optional): the kernel size of the filters used in + the convolution layer. Default: 3. + padding (int, optional): zero-padding added to both sides of the + input. Default: 0. + bias (bool, optional): Adds a learnable bias to the output if + ``True``. Default: False. + relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: True. 
+ """ + def __init__(self, in_channels, out_channels, bias=False, relu=True): + super(InitialBlock, self).__init__() + + if relu: + activation = nn.ReLU + else: + activation = nn.PReLU + + self.main_branch = nn.Conv2D(in_channels, + out_channels - 3, + kernel_size=3, + stride=2, + padding=1, + bias_attr=bias) + + self.ext_branch = nn.MaxPool2D(3, stride=2, padding=1) + + self.batch_norm = layers.SyncBatchNorm(out_channels) + + self.out_activation = activation() + + def forward(self, x): + main = self.main_branch(x) + ext = self.ext_branch(x) + + out = paddle.concat((main, ext), 1) + + out = self.batch_norm(out) + + return self.out_activation(out) + + +class RegularBottleneck(nn.Layer): + """ + Regular bottlenecks are the main building block of ENet. + Main branch: + 1. Shortcut connection. + Extension branch: + 1. 1x1 convolution which decreases the number of channels by + ``internal_ratio``, also called a projection; + 2. regular, dilated or asymmetric convolution; + 3. 1x1 convolution which increases the number of channels back to + ``channels``, also called an expansion; + 4. dropout as a regularizer. + Args: + channels (int): the number of input and output channels. + internal_ratio (int, optional): a scale factor applied to + ``channels`` used to compute the number of + channels after the projection. eg. given ``channels`` equal to 128 and + internal_ratio equal to 2 the number of channels after the projection + is 64. Default: 4. + kernel_size (int, optional): the kernel size of the filters used in + the convolution layer described above in item 2 of the extension + branch. Default: 3. + padding (int, optional): zero-padding added to both sides of the + input. Default: 0. + dilation (int, optional): spacing between kernel elements for the + convolution described in item 2 of the extension branch. Default: 1. + asymmetric (bool, optional): flags if the convolution described in + item 2 of the extension branch is asymmetric or not. Default: False. + dropout_prob (float, optional): probability of an element to be + zeroed. Default: 0 (no dropout). + bias (bool, optional): Adds a learnable bias to the output if + ``True``. Default: False. + relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: True. + """ + def __init__(self, + channels, + internal_ratio=4, + kernel_size=3, + padding=0, + dilation=1, + asymmetric=False, + dropout_prob=0, + bias=False, + relu=True): + super(RegularBottleneck, self).__init__() + + if internal_ratio <= 1 or internal_ratio > channels: + raise RuntimeError( + "Value out of range. 
Expected value in the " + "interval [1, {0}], got internal_scale={1}.".format( + channels, internal_ratio)) + + internal_channels = channels // internal_ratio + + if relu: + activation = nn.ReLU + else: + activation = nn.PReLU + + self.ext_conv1 = nn.Sequential( + nn.Conv2D(channels, + internal_channels, + kernel_size=1, + stride=1, + bias_attr=bias), layers.SyncBatchNorm(internal_channels), + activation()) + + if asymmetric: + self.ext_conv2 = nn.Sequential( + nn.Conv2D(internal_channels, + internal_channels, + kernel_size=(kernel_size, 1), + stride=1, + padding=(padding, 0), + dilation=dilation, + bias_attr=bias), + layers.SyncBatchNorm(internal_channels), activation(), + nn.Conv2D(internal_channels, + internal_channels, + kernel_size=(1, kernel_size), + stride=1, + padding=(0, padding), + dilation=dilation, + bias_attr=bias), + layers.SyncBatchNorm(internal_channels), activation()) + else: + self.ext_conv2 = nn.Sequential( + nn.Conv2D(internal_channels, + internal_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + dilation=dilation, + bias_attr=bias), + layers.SyncBatchNorm(internal_channels), activation()) + + self.ext_conv3 = nn.Sequential( + nn.Conv2D(internal_channels, + channels, + kernel_size=1, + stride=1, + bias_attr=bias), layers.SyncBatchNorm(channels), + activation()) + + self.ext_regul = nn.Dropout2D(p=dropout_prob) + + self.out_activation = activation() + + def forward(self, x): + # Main branch shortcut + main = x + + ext = self.ext_conv1(x) + ext = self.ext_conv2(ext) + ext = self.ext_conv3(ext) + ext = self.ext_regul(ext) + + out = main + ext + + return self.out_activation(out) + + +class DownsamplingBottleneck(nn.Layer): + """ + Downsampling bottlenecks further downsample the feature map size. + Main branch: + 1. max pooling with stride 2; indices are saved to be used for + unpooling later. + Extension branch: + 1. 2x2 convolution with stride 2 that decreases the number of channels + by ``internal_ratio``, also called a projection; + 2. regular convolution (by default, 3x3); + 3. 1x1 convolution which increases the number of channels to + ``out_channels``, also called an expansion; + 4. dropout as a regularizer. + Args: + in_channels (int): the number of input channels. + out_channels (int): the number of output channels. + internal_ratio (int, optional): a scale factor applied to ``channels`` + used to compute the number of channels after the projection. eg. given + ``channels`` equal to 128 and internal_ratio equal to 2 the number of + channels after the projection is 64. Default: 4. + return_indices (bool, optional): if ``True``, will return the max + indices along with the outputs. Useful when unpooling later. + dropout_prob (float, optional): probability of an element to be + zeroed. Default: 0 (no dropout). + bias (bool, optional): Adds a learnable bias to the output if + ``True``. Default: False. + relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: True. + """ + def __init__(self, + in_channels, + out_channels, + internal_ratio=4, + return_indices=False, + dropout_prob=0, + bias=False, + relu=True): + super(DownsamplingBottleneck, self).__init__() + + self.return_indices = return_indices + + if internal_ratio <= 1 or internal_ratio > in_channels: + raise RuntimeError( + "Value out of range. Expected value in the " + "interval [1, {0}], got internal_scale={1}. 
".format( + in_channels, internal_ratio)) + + internal_channels = in_channels // internal_ratio + + if relu: + activation = nn.ReLU + else: + activation = nn.PReLU + + self.main_max1 = nn.MaxPool2D(2, stride=2, return_mask=return_indices) + + self.ext_conv1 = nn.Sequential( + nn.Conv2D(in_channels, + internal_channels, + kernel_size=2, + stride=2, + bias_attr=bias), layers.SyncBatchNorm(internal_channels), + activation()) + + self.ext_conv2 = nn.Sequential( + nn.Conv2D(internal_channels, + internal_channels, + kernel_size=3, + stride=1, + padding=1, + bias_attr=bias), layers.SyncBatchNorm(internal_channels), + activation()) + + self.ext_conv3 = nn.Sequential( + nn.Conv2D(internal_channels, + out_channels, + kernel_size=1, + stride=1, + bias_attr=bias), layers.SyncBatchNorm(out_channels), + activation()) + + self.ext_regul = nn.Dropout2D(p=dropout_prob) + + self.out_activation = activation() + + def forward(self, x): + # Main branch shortcut + if self.return_indices: + main, max_indices = self.main_max1(x) + else: + main = self.main_max1(x) + + ext = self.ext_conv1(x) + ext = self.ext_conv2(ext) + ext = self.ext_conv3(ext) + ext = self.ext_regul(ext) + + # Main branch channel padding + n, ch_ext, h, w = ext.shape + ch_main = main.shape[1] + padding = paddle.zeros((n, ch_ext - ch_main, h, w)) + + main = paddle.concat((main, padding), 1) + + out = main + ext + + return self.out_activation(out), max_indices + + +class UpsamplingBottleneck(nn.Layer): + """ + The upsampling bottlenecks upsample the feature map resolution using max + pooling indices stored from the corresponding downsampling bottleneck. + Main branch: + 1. 1x1 convolution with stride 1 that decreases the number of channels by + ``internal_ratio``, also called a projection; + 2. max unpool layer using the max pool indices from the corresponding + downsampling max pool layer. + Extension branch: + 1. 1x1 convolution with stride 1 that decreases the number of channels by + ``internal_ratio``, also called a projection; + 2. transposed convolution (by default, 3x3); + 3. 1x1 convolution which increases the number of channels to + ``out_channels``, also called an expansion; + 4. dropout as a regularizer. + Args: + in_channels (int): the number of input channels. + out_channels (int): the number of output channels. + internal_ratio (int, optional): a scale factor applied to ``in_channels`` + used to compute the number of channels after the projection. eg. given + ``in_channels`` equal to 128 and ``internal_ratio`` equal to 2 the number + of channels after the projection is 64. Default: 4. + dropout_prob (float, optional): probability of an element to be zeroed. + Default: 0 (no dropout). + bias (bool, optional): Adds a learnable bias to the output if ``True``. + Default: False. + relu (bool, optional): When ``True`` ReLU is used as the activation + function; otherwise, PReLU is used. Default: True. + """ + def __init__(self, + in_channels, + out_channels, + internal_ratio=4, + dropout_prob=0, + bias=False, + relu=True): + super(UpsamplingBottleneck, self).__init__() + + if internal_ratio <= 1 or internal_ratio > in_channels: + raise RuntimeError( + "Value out of range. Expected value in the " + "interval [1, {0}], got internal_scale={1}. 
".format( + in_channels, internal_ratio)) + + internal_channels = in_channels // internal_ratio + + if relu: + activation = nn.ReLU + else: + activation = nn.PReLU + + self.main_conv1 = nn.Sequential( + nn.Conv2D(in_channels, out_channels, kernel_size=1, bias_attr=bias), + layers.SyncBatchNorm(out_channels)) + + self.ext_conv1 = nn.Sequential( + nn.Conv2D(in_channels, + internal_channels, + kernel_size=1, + bias_attr=bias), layers.SyncBatchNorm(internal_channels), + activation()) + + self.ext_tconv1 = nn.Conv2DTranspose(internal_channels, + internal_channels, + kernel_size=2, + stride=2, + bias_attr=bias) + self.ext_tconv1_bnorm = layers.SyncBatchNorm(internal_channels) + self.ext_tconv1_activation = activation() + + self.ext_conv2 = nn.Sequential( + nn.Conv2D(internal_channels, + out_channels, + kernel_size=1, + bias_attr=bias), layers.SyncBatchNorm(out_channels)) + + self.ext_regul = nn.Dropout2D(p=dropout_prob) + + self.out_activation = activation() + + def forward(self, x, max_indices, output_size): + # Main branch shortcut + main = self.main_conv1(x) + main = F.max_unpool2d(main, + max_indices, + kernel_size=2, + output_size=output_size) + + ext = self.ext_conv1(x) + ext = self.ext_tconv1(ext, output_size=output_size[2:]) + ext = self.ext_tconv1_bnorm(ext) + ext = self.ext_tconv1_activation(ext) + ext = self.ext_conv2(ext) + ext = self.ext_regul(ext) + + out = main + ext + + return self.out_activation(out) From e55e463aaca290128b727b7e2318741476b13d94 Mon Sep 17 00:00:00 2001 From: ETTR123 <740580207@qq.com> Date: Tue, 14 Dec 2021 10:33:56 +0800 Subject: [PATCH 02/12] Update README.md --- configs/enet/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/enet/README.md b/configs/enet/README.md index b04ba2fa40..d13c517f06 100644 --- a/configs/enet/README.md +++ b/configs/enet/README.md @@ -8,6 +8,6 @@ Real-Time Semantic Segmentation." arXiv preprint arXiv:1606.02147(2016). ### Cityscapes -| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | +| Model | Backbone | Resolution | Training Iters | mIoU | Links | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| -|ENet|-|1024x512|80000|58.3%|-|-|[]| +|ENet|-|1024x512|80000|61.3%|[提取码:2fle](https://pan.baidu.com/share/init?surl=7k-Mq_BmZp0FaY1IDjAVoQ)| From da9d98484f2e3f0674f6155eecae5e26146698e9 Mon Sep 17 00:00:00 2001 From: ETTR123 <740580207@qq.com> Date: Tue, 14 Dec 2021 10:44:23 +0800 Subject: [PATCH 03/12] Update README.md --- configs/enet/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/enet/README.md b/configs/enet/README.md index d13c517f06..04649ec273 100644 --- a/configs/enet/README.md +++ b/configs/enet/README.md @@ -9,5 +9,5 @@ Real-Time Semantic Segmentation." arXiv preprint arXiv:1606.02147(2016). 
### Cityscapes | Model | Backbone | Resolution | Training Iters | mIoU | Links | -|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +|:-:|:-:|:-:|:-:|:-:|:-:| |ENet|-|1024x512|80000|61.3%|[提取码:2fle](https://pan.baidu.com/share/init?surl=7k-Mq_BmZp0FaY1IDjAVoQ)| From ebb49d4ff99e73706737f8af1f7e2bd98aa915fc Mon Sep 17 00:00:00 2001 From: ETTR123 <740580207@qq.com> Date: Tue, 14 Dec 2021 10:58:31 +0800 Subject: [PATCH 04/12] Update enet_cityscapes_1024x512_adam_0.002_80k.yml --- configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml | 5 ----- 1 file changed, 5 deletions(-) diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml index d0eedbe509..d2d6ebfe86 100644 --- a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml +++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml @@ -5,13 +5,8 @@ train_dataset: type: Cityscapes dataset_root: data/cityscapes transforms: - - type: ResizeStepScaling - min_scale_factor: 0.5 - max_scale_factor: 2.0 - scale_step_size: 0.25 - type: RandomPaddingCrop crop_size: [1024, 512] - - type: RandomHorizontalFlip - type: RandomDistort brightness_range: 0.4 contrast_range: 0.4 From f3330ce2a3a85cfca474b866c2efb5c0079d00d5 Mon Sep 17 00:00:00 2001 From: ETTR123 <740580207@qq.com> Date: Tue, 14 Dec 2021 14:47:22 +0800 Subject: [PATCH 05/12] update --- ...net_cityscapes_1024x512_adam_0.002_80k.yml | 21 ++----------------- paddleseg/models/enet.py | 19 +++++------------ 2 files changed, 7 insertions(+), 33 deletions(-) diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml index d2d6ebfe86..fc81e58797 100644 --- a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml +++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml @@ -1,4 +1,5 @@ -iters: 80000 +_base_: '../_base_/cityscapes.yml' + batch_size: 8 train_dataset: @@ -14,13 +15,6 @@ train_dataset: - type: Normalize mode: train -val_dataset: - type: Cityscapes - dataset_root: data/cityscapes - transforms: - - type: Normalize - mode: val - model: type: ENet num_classes: 19 @@ -29,14 +23,3 @@ model: optimizer: type: adam weight_decay: 0.0002 - -lr_scheduler: - type: PolynomialDecay - learning_rate: 0.001 - end_lr: 0 - power: 0.9 - -loss: - types: - - type: CrossEntropyLoss - coef: [1] diff --git a/paddleseg/models/enet.py b/paddleseg/models/enet.py index 7528f93cb1..3971a48e41 100644 --- a/paddleseg/models/enet.py +++ b/paddleseg/models/enet.py @@ -27,6 +27,7 @@ class ENet(nn.Layer): """ The ENet implementation based on PaddlePaddle. + The original article refers to Adam Paszke, Abhishek Chaurasia, Sangpil Kim, Eugenio Culurciello, et al."ENet: A Deep Neural Network Architecture for Real-Time Semantic Segmentation" (https://arxiv.org/abs/1606.02147). 
@@ -49,7 +50,6 @@ def __init__(self, self.numclasses = num_classes self.initial_block = InitialBlock(3, 16, relu=encoder_relu) - # Stage 1 - Encoder self.downsample1_0 = DownsamplingBottleneck(16, 64, return_indices=True, @@ -72,7 +72,6 @@ def __init__(self, dropout_prob=0.01, relu=encoder_relu) - # Stage 2 - Encoder self.downsample2_0 = DownsamplingBottleneck(64, 128, return_indices=True, @@ -119,7 +118,6 @@ def __init__(self, dropout_prob=0.1, relu=encoder_relu) - # Stage 3 - Encoder self.regular3_0 = RegularBottleneck(128, padding=1, dropout_prob=0.1, @@ -161,7 +159,6 @@ def __init__(self, dropout_prob=0.1, relu=encoder_relu) - # Stage 4 - Decoder self.upsample4_0 = UpsamplingBottleneck(128, 64, dropout_prob=0.1, @@ -175,7 +172,6 @@ def __init__(self, dropout_prob=0.1, relu=decoder_relu) - # Stage 5 - Decoder self.upsample5_0 = UpsamplingBottleneck(64, 16, dropout_prob=0.1, @@ -199,7 +195,6 @@ def forward(self, x): input_size = x.shape x = self.initial_block(x) - # Stage 1 - Encoder stage1_input_size = x.shape x, max_indices1_0 = self.downsample1_0(x) x = self.regular1_1(x) @@ -207,7 +202,6 @@ def forward(self, x): x = self.regular1_3(x) x = self.regular1_4(x) - # Stage 2 - Encoder stage2_input_size = x.shape x, max_indices2_0 = self.downsample2_0(x) x = self.regular2_1(x) @@ -219,7 +213,6 @@ def forward(self, x): x = self.asymmetric2_7(x) x = self.dilated2_8(x) - # Stage 3 - Encoder x = self.regular3_0(x) x = self.dilated3_1(x) x = self.asymmetric3_2(x) @@ -229,12 +222,10 @@ def forward(self, x): x = self.asymmetric3_6(x) x = self.dilated3_7(x) - # Stage 4 - Decoder x = self.upsample4_0(x, max_indices2_0, output_size=stage2_input_size) x = self.regular4_1(x) x = self.regular4_2(x) - # Stage 5 - Decoder x = self.upsample5_0(x, max_indices1_0, output_size=stage1_input_size) x = self.regular5_1(x) x = self.transposed_conv(x, output_size=input_size[2:]) @@ -254,6 +245,7 @@ class InitialBlock(nn.Layer): allows for efficient downsampling and expansion. The main branch outputs 13 feature maps while the extension branch outputs 3, for a total of 16 feature maps after concatenation. + Args: in_channels (int): the number of input channels. out_channels (int): the number output channels. @@ -310,6 +302,7 @@ class RegularBottleneck(nn.Layer): 3. 1x1 convolution which increases the number of channels back to ``channels``, also called an expansion; 4. dropout as a regularizer. + Args: channels (int): the number of input and output channels. internal_ratio (int, optional): a scale factor applied to @@ -408,7 +401,6 @@ def __init__(self, self.out_activation = activation() def forward(self, x): - # Main branch shortcut main = x ext = self.ext_conv1(x) @@ -434,6 +426,7 @@ class DownsamplingBottleneck(nn.Layer): 3. 1x1 convolution which increases the number of channels to ``out_channels``, also called an expansion; 4. dropout as a regularizer. + Args: in_channels (int): the number of input channels. out_channels (int): the number of output channels. @@ -507,7 +500,6 @@ def __init__(self, self.out_activation = activation() def forward(self, x): - # Main branch shortcut if self.return_indices: main, max_indices = self.main_max1(x) else: @@ -518,7 +510,6 @@ def forward(self, x): ext = self.ext_conv3(ext) ext = self.ext_regul(ext) - # Main branch channel padding n, ch_ext, h, w = ext.shape ch_main = main.shape[1] padding = paddle.zeros((n, ch_ext - ch_main, h, w)) @@ -546,6 +537,7 @@ class UpsamplingBottleneck(nn.Layer): 3. 
1x1 convolution which increases the number of channels to ``out_channels``, also called an expansion; 4. dropout as a regularizer. + Args: in_channels (int): the number of input channels. out_channels (int): the number of output channels. @@ -612,7 +604,6 @@ def __init__(self, self.out_activation = activation() def forward(self, x, max_indices, output_size): - # Main branch shortcut main = self.main_conv1(x) main = F.max_unpool2d(main, max_indices, From f691133e24f206a5cab79a5f8cb345b08561c872 Mon Sep 17 00:00:00 2001 From: ETTR123 <56824848+ETTR123@users.noreply.github.com> Date: Thu, 13 Jan 2022 17:35:15 +0800 Subject: [PATCH 06/12] Update train.py --- paddleseg/core/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py index d3ee4f7e5f..e7a1a35608 100644 --- a/paddleseg/core/train.py +++ b/paddleseg/core/train.py @@ -314,7 +314,7 @@ def train(model, batch_start = time.time() # Calculate flops. - if local_rank == 0: + if local_rank == 0 and model.__class__.__name__ != 'ENet': _, c, h, w = images.shape _ = paddle.flops( model, [1, c, h, w], From 00b3ada22fef5f34770f4eec31867fe348c8ec19 Mon Sep 17 00:00:00 2001 From: ETTR123 <56824848+ETTR123@users.noreply.github.com> Date: Sat, 15 Jan 2022 17:11:41 +0800 Subject: [PATCH 07/12] Update enet_cityscapes_1024x512_adam_0.002_80k.yml Cancel inheritance of optimizer parameters in ../_base_/cityscapes.yml. --- configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml index fc81e58797..d53b6ab7f9 100644 --- a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml +++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml @@ -21,5 +21,6 @@ model: pretrained: Null optimizer: + _inherited_: False type: adam weight_decay: 0.0002 From 7fdee74b89f6b82693481c3f0d80cf0d15c6b72e Mon Sep 17 00:00:00 2001 From: ETTR123 <56824848+ETTR123@users.noreply.github.com> Date: Tue, 18 Jan 2022 10:05:07 +0800 Subject: [PATCH 08/12] Update enet_cityscapes_1024x512_adam_0.002_80k.yml --- configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml index d53b6ab7f9..48fb6df9fd 100644 --- a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml +++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml @@ -24,3 +24,9 @@ optimizer: _inherited_: False type: adam weight_decay: 0.0002 + +lr_scheduler: + end_lr: 0 + learning_rate: 0.001 + power: 0.9 + type: PolynomialDecay From f519ceb1c42c4f9122bc529a7bd913435f13ad1c Mon Sep 17 00:00:00 2001 From: ETTR123 <56824848+ETTR123@users.noreply.github.com> Date: Wed, 19 Jan 2022 11:23:50 +0800 Subject: [PATCH 09/12] add Miou --- configs/enet/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/configs/enet/README.md b/configs/enet/README.md index 04649ec273..aa0c0b8259 100644 --- a/configs/enet/README.md +++ b/configs/enet/README.md @@ -8,6 +8,6 @@ Real-Time Semantic Segmentation." arXiv preprint arXiv:1606.02147(2016). 
### Cityscapes -| Model | Backbone | Resolution | Training Iters | mIoU | Links | -|:-:|:-:|:-:|:-:|:-:|:-:| -|ENet|-|1024x512|80000|61.3%|[提取码:2fle](https://pan.baidu.com/share/init?surl=7k-Mq_BmZp0FaY1IDjAVoQ)| +| Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | +|:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| +|ENet|-|1024x512|80000|67.42%|68.11%|67.99%|[]()|| From 50d5fa720ed8dac27e7cecf02f924d5ac46eec8e Mon Sep 17 00:00:00 2001 From: shiyutang <34859558+shiyutang@users.noreply.github.com> Date: Thu, 20 Jan 2022 16:59:58 +0800 Subject: [PATCH 10/12] Update README.md --- configs/enet/README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/enet/README.md b/configs/enet/README.md index aa0c0b8259..e098df2a7a 100644 --- a/configs/enet/README.md +++ b/configs/enet/README.md @@ -10,4 +10,5 @@ Real-Time Semantic Segmentation." arXiv preprint arXiv:1606.02147(2016). | Model | Backbone | Resolution | Training Iters | mIoU | mIoU (flip) | mIoU (ms+flip) | Links | |:-:|:-:|:-:|:-:|:-:|:-:|:-:|:-:| -|ENet|-|1024x512|80000|67.42%|68.11%|67.99%|[]()|| +|ENet|-|1024x512|80000|67.42%|68.11%|67.99%|[model](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/bisenetv1_cityscapes_1024x512_160k/model.pdparams)\|[log](https://bj.bcebos.com/paddleseg/dygraph/cityscapes/bisenetv1_cityscapes_1024x512_160k/train.log)\|[vdl](https://paddlepaddle.org.cn/paddle/visualdl/service/app?id=5d57386cdfcdb6a6bcb5135af134a0f2)| +| From 97becbf88c5e1079f8e55125a62ebb552362776a Mon Sep 17 00:00:00 2001 From: shiyutang <34859558+shiyutang@users.noreply.github.com> Date: Thu, 20 Jan 2022 17:03:39 +0800 Subject: [PATCH 11/12] Update train.py --- paddleseg/core/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddleseg/core/train.py b/paddleseg/core/train.py index e7a1a35608..d3ee4f7e5f 100644 --- a/paddleseg/core/train.py +++ b/paddleseg/core/train.py @@ -314,7 +314,7 @@ def train(model, batch_start = time.time() # Calculate flops. - if local_rank == 0 and model.__class__.__name__ != 'ENet': + if local_rank == 0: _, c, h, w = images.shape _ = paddle.flops( model, [1, c, h, w], From 78acede389b0ee405681bc5b3792717f0fb4f98f Mon Sep 17 00:00:00 2001 From: shiyutang <34859558+shiyutang@users.noreply.github.com> Date: Thu, 20 Jan 2022 17:04:22 +0800 Subject: [PATCH 12/12] trigger --- configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml index 48fb6df9fd..7abfa2f89b 100644 --- a/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml +++ b/configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml @@ -1,5 +1,4 @@ _base_: '../_base_/cityscapes.yml' - batch_size: 8 train_dataset:
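The patches above only register the model, its config, and the README; as a quick sanity check, the snippet below sketches how the new component can be exercised once the series is applied. This is an illustrative addendum, not part of the patch: it assumes paddlepaddle and PaddleSeg (with this series applied) are importable, and the dummy input shape is simply chosen to match the 1024x512 crop used in the config.

# Illustrative smoke test for the ENet model added in PATCH 01/12 (not part of the patch).
# Assumes the series is applied and a compatible paddlepaddle is installed.
import paddle
from paddleseg.models import ENet

model = ENet(num_classes=19)   # Cityscapes classes, matching the config
model.eval()

# Dummy NCHW batch; 512x1024 matches the RandomPaddingCrop size [1024, 512].
x = paddle.rand([1, 3, 512, 1024])
with paddle.no_grad():
    logits = model(x)          # forward() returns a single-element list of logits

print(logits[0].shape)         # expected: [1, 19, 512, 1024]

Training with the accompanying config then goes through the standard PaddleSeg entry point, e.g. python train.py --config configs/enet/enet_cityscapes_1024x512_adam_0.002_80k.yml --do_eval; the exact flag names depend on the PaddleSeg version in use.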