From 91a026c7dd275f06c9c6e6dd6af9889ebe55294d Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 25 Oct 2019 23:28:29 +0300 Subject: [PATCH 01/79] Global refactoring of encoders & decoders --- pytorch_toolbelt/modules/__init__.py | 6 +- pytorch_toolbelt/modules/backbone/hrnet.py | 563 +++++++++++ .../modules/backbone/mobilenet.py | 3 +- .../modules/backbone/mobilenetv3.py | 1 - pytorch_toolbelt/modules/backbone/senet.py | 3 +- .../modules/backbone/wider_resnet.py | 3 +- pytorch_toolbelt/modules/coord_conv.py | 4 +- pytorch_toolbelt/modules/decoders/__init__.py | 8 + pytorch_toolbelt/modules/decoders/deeplab.py | 62 ++ pytorch_toolbelt/modules/decoders/fpn_cat.py | 100 ++ pytorch_toolbelt/modules/decoders/fpn_sum.py | 211 ++++ pytorch_toolbelt/modules/decoders/hrnet.py | 26 + .../modules/decoders/pyramid_pooling.py | 64 ++ .../modules/decoders/unet_decoder.py | 156 +++ pytorch_toolbelt/modules/decoders/upernet.py | 119 +++ pytorch_toolbelt/modules/encoders.py | 905 ------------------ pytorch_toolbelt/modules/encoders/__init__.py | 14 + pytorch_toolbelt/modules/encoders/common.py | 49 + pytorch_toolbelt/modules/encoders/densenet.py | 133 +++ .../modules/encoders/efficientnet.py | 146 +++ pytorch_toolbelt/modules/encoders/hrnet.py | 32 + .../modules/encoders/inception.py | 39 + .../modules/encoders/mobilenet.py | 88 ++ pytorch_toolbelt/modules/encoders/resnet.py | 109 +++ pytorch_toolbelt/modules/encoders/seresnet.py | 113 +++ .../modules/encoders/squeezenet.py | 57 ++ .../modules/encoders/wide_resnet.py | 179 ++++ pytorch_toolbelt/modules/fpn.py | 11 +- pytorch_toolbelt/modules/hypercolumn.py | 3 - pytorch_toolbelt/modules/pooling.py | 3 +- 30 files changed, 2288 insertions(+), 922 deletions(-) create mode 100644 pytorch_toolbelt/modules/backbone/hrnet.py create mode 100644 pytorch_toolbelt/modules/decoders/__init__.py create mode 100644 pytorch_toolbelt/modules/decoders/deeplab.py create mode 100644 pytorch_toolbelt/modules/decoders/fpn_cat.py create mode 100644 pytorch_toolbelt/modules/decoders/fpn_sum.py create mode 100644 pytorch_toolbelt/modules/decoders/hrnet.py create mode 100644 pytorch_toolbelt/modules/decoders/pyramid_pooling.py create mode 100644 pytorch_toolbelt/modules/decoders/unet_decoder.py create mode 100644 pytorch_toolbelt/modules/decoders/upernet.py delete mode 100644 pytorch_toolbelt/modules/encoders.py create mode 100644 pytorch_toolbelt/modules/encoders/__init__.py create mode 100644 pytorch_toolbelt/modules/encoders/common.py create mode 100644 pytorch_toolbelt/modules/encoders/densenet.py create mode 100644 pytorch_toolbelt/modules/encoders/efficientnet.py create mode 100644 pytorch_toolbelt/modules/encoders/hrnet.py create mode 100644 pytorch_toolbelt/modules/encoders/inception.py create mode 100644 pytorch_toolbelt/modules/encoders/mobilenet.py create mode 100644 pytorch_toolbelt/modules/encoders/resnet.py create mode 100644 pytorch_toolbelt/modules/encoders/seresnet.py create mode 100644 pytorch_toolbelt/modules/encoders/squeezenet.py create mode 100644 pytorch_toolbelt/modules/encoders/wide_resnet.py diff --git a/pytorch_toolbelt/modules/__init__.py b/pytorch_toolbelt/modules/__init__.py index 53647f62a..bc41cd69e 100644 --- a/pytorch_toolbelt/modules/__init__.py +++ b/pytorch_toolbelt/modules/__init__.py @@ -1,8 +1,8 @@ from __future__ import absolute_import from .abn import * -from .identity import * from .dsconv import * -from .scse import * -from .hypercolumn import * from .fpn import * +from .hypercolumn import * +from .identity import * +from 
.scse import * diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py new file mode 100644 index 000000000..b94d233b5 --- /dev/null +++ b/pytorch_toolbelt/modules/backbone/hrnet.py @@ -0,0 +1,563 @@ +""" +This HRNet implementation is modified from the following repository: +https://github.com/HRNet/HRNet-Semantic-Segmentation +""" + +import os +import sys +from urllib.request import urlretrieve + +import torch +import torch.nn as nn +import torch.nn.functional as F + +model_urls = { + "hrnetv2_48": "http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth" +} + +HRNETV2_BN_MOMENTUM = 0.1 + + +def hrnet_conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, + bias=False + ) + + +class HRNetBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(HRNetBasicBlock, self).__init__() + self.conv1 = hrnet_conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = hrnet_conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HRNetBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(HRNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, planes, kernel_size=3, stride=stride, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False + ) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=HRNETV2_BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__( + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True, + ): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels + ) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels + ) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches( + self, num_branches, blocks, num_blocks, num_inchannels, + num_channels + ): + if num_branches != len(num_blocks): + error_msg = 
"NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( + num_branches, len(num_blocks) + ) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( + num_branches, len(num_channels) + ) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( + num_branches, len(num_inchannels) + ) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if ( + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion + ): + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=HRNETV2_BN_MOMENTUM, + ), + ) + + layers = [] + layers.append( + block( + self.num_inchannels[branch_index], + num_channels[branch_index], + stride, + downsample, + ) + ) + self.num_inchannels[branch_index] = num_channels[ + branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index]) + ) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False, + ), + nn.BatchNorm2d( + num_inchannels[i], momentum=HRNETV2_BN_MOMENTUM + ), + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ), + nn.BatchNorm2d( + num_outchannels_conv3x3, + momentum=HRNETV2_BN_MOMENTUM, + ), + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False, + ), + nn.BatchNorm2d( + num_outchannels_conv3x3, + momentum=HRNETV2_BN_MOMENTUM, + ), + nn.ReLU(inplace=True), + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode="bilinear", + align_corners=False, + 
) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HRNetV2(nn.Module): + def __init__(self, width=48, **kwargs): + super(HRNetV2, self).__init__() + blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} + + extra = { + "STAGE2": { + "NUM_MODULES": 1, + "NUM_BRANCHES": 2, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4), + "NUM_CHANNELS": (width, width * 2), + "FUSE_METHOD": "SUM", + }, + "STAGE3": { + "NUM_MODULES": 4, + "NUM_BRANCHES": 3, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4, 4), + "NUM_CHANNELS": (width, width * 2, width * 4), + "FUSE_METHOD": "SUM", + }, + "STAGE4": { + "NUM_MODULES": 3, + "NUM_BRANCHES": 4, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4, 4, 4), + "NUM_CHANNELS": (width, width * 2, width * 4, width * 8), + "FUSE_METHOD": "SUM", + }, + "FINAL_CONV_KERNEL": 1, + } + + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.layer1 = self._make_layer(HRNetBottleneck, 64, 64, 4) + + self.stage2_cfg = extra["STAGE2"] + num_channels = self.stage2_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage2_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels + ) + + self.stage3_cfg = extra["STAGE3"] + num_channels = self.stage3_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage3_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels + ) + + self.stage4_cfg = extra["STAGE4"] + num_channels = self.stage4_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage4_cfg["BLOCK"]] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True + ) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False, + ), + nn.BatchNorm2d( + num_channels_cur_layer[i], + momentum=HRNETV2_BN_MOMENTUM + ), + nn.ReLU(inplace=True), + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = ( + num_channels_cur_layer[i] + if j == i - num_branches_pre + else inchannels + ) + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, + bias=False), + nn.BatchNorm2d(outchannels, + momentum=HRNETV2_BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + ) + 
transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(planes * block.expansion, + momentum=HRNETV2_BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + + blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} + + num_modules = layer_config["NUM_MODULES"] + num_branches = layer_config["NUM_BRANCHES"] + num_blocks = layer_config["NUM_BLOCKS"] + num_channels = layer_config["NUM_CHANNELS"] + block = blocks_dict[layer_config["BLOCK"]] + fuse_method = layer_config["FUSE_METHOD"] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x, return_feature_maps=False): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg["NUM_BRANCHES"]): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg["NUM_BRANCHES"]): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg["NUM_BRANCHES"]): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=False + ) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=False + ) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=False + ) + + x = torch.cat([x[0], x1, x2, x3], 1) + + # x = self.last_layer(x) + return [x] + + +def hrnetv2(pretrained=False, **kwargs): + model = HRNetV2(**kwargs) + if pretrained: + + def load_url(url, model_dir="./pretrained", map_location=None): + if not os.path.exists(model_dir): + os.makedirs(model_dir) + filename = url.split("/")[-1] + cached_file = os.path.join(model_dir, filename) + if not os.path.exists(cached_file): + sys.stderr.write( + 'Downloading: "{}" to {}\n'.format(url, cached_file)) + urlretrieve(url, cached_file) + return torch.load(cached_file, map_location=map_location) + + model.load_state_dict(load_url(model_urls["hrnetv2_48"]), strict=False) + + return model diff --git a/pytorch_toolbelt/modules/backbone/mobilenet.py b/pytorch_toolbelt/modules/backbone/mobilenet.py 
index cd260ef44..4110e3292 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenet.py +++ b/pytorch_toolbelt/modules/backbone/mobilenet.py @@ -1,8 +1,9 @@ from __future__ import absolute_import -import torch.nn as nn import math +import torch.nn as nn + from ..activations import get_activation_module diff --git a/pytorch_toolbelt/modules/backbone/mobilenetv3.py b/pytorch_toolbelt/modules/backbone/mobilenetv3.py index 6de1037eb..45f8292a0 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenetv3.py +++ b/pytorch_toolbelt/modules/backbone/mobilenetv3.py @@ -5,7 +5,6 @@ import torch import torch.nn as nn import torch.nn.functional as F - # from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid from pytorch_toolbelt.modules.identity import Identity diff --git a/pytorch_toolbelt/modules/backbone/senet.py b/pytorch_toolbelt/modules/backbone/senet.py index 460dcf3c6..6643bcb36 100644 --- a/pytorch_toolbelt/modules/backbone/senet.py +++ b/pytorch_toolbelt/modules/backbone/senet.py @@ -3,8 +3,9 @@ https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py """ from __future__ import print_function, division, absolute_import -from collections import OrderedDict + import math +from collections import OrderedDict import torch.nn as nn from torch.utils import model_zoo diff --git a/pytorch_toolbelt/modules/backbone/wider_resnet.py b/pytorch_toolbelt/modules/backbone/wider_resnet.py index 1ce3138a6..445c366d6 100644 --- a/pytorch_toolbelt/modules/backbone/wider_resnet.py +++ b/pytorch_toolbelt/modules/backbone/wider_resnet.py @@ -2,11 +2,10 @@ from functools import partial import torch -from torch import nn - from pytorch_toolbelt.modules.abn import ABN from pytorch_toolbelt.modules.pooling import GlobalAvgPool2d from pytorch_toolbelt.utils.torch_utils import count_parameters +from torch import nn class IdentityResidualBlock(nn.Module): diff --git a/pytorch_toolbelt/modules/coord_conv.py b/pytorch_toolbelt/modules/coord_conv.py index ff395b66b..ac70c41bf 100644 --- a/pytorch_toolbelt/modules/coord_conv.py +++ b/pytorch_toolbelt/modules/coord_conv.py @@ -1,5 +1,5 @@ -"""Implementation of the CoordConv modules from https://arxiv.org/abs/1807.03247 - +""" +Implementation of the CoordConv modules from https://arxiv.org/abs/1807.03247 """ import torch diff --git a/pytorch_toolbelt/modules/decoders/__init__.py b/pytorch_toolbelt/modules/decoders/__init__.py new file mode 100644 index 000000000..9ed669c49 --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/__init__.py @@ -0,0 +1,8 @@ +from __future__ import absolute_import + +from .fpn_sum import * +from .fpn_cat import * +from .deeplab import * +from .upernet import * +from .pyramid_pooling import * +from .unet_decoder import * \ No newline at end of file diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py new file mode 100644 index 000000000..490c3d8ee --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -0,0 +1,62 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = ["DeeplabV3Decoder"] + + +class DeeplabV3Decoder(nn.Module): + def __init__( + self, + high_level_features: int, + low_level_features: int, + num_classes: int, + dropout=0.5, + ): + super(DeeplabV3Decoder, self).__init__() + + self.conv1 = nn.Conv2d(low_level_features, 48, 1, bias=False) + self.bn1 = nn.BatchNorm2d(48) + self.relu = nn.ReLU(inplace=True) + + 
self.last_conv = nn.Sequential( + nn.Conv2d( + high_level_features + 48, + 256, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Dropout(dropout * 0.2), # 5 times smaller dropout rate + nn.Conv2d(256, num_classes, kernel_size=1, stride=1), + ) + self.reset_parameters() + + def forward(self, x, low_level_feat): + low_level_feat = self.conv1(low_level_feat) + low_level_feat = self.bn1(low_level_feat) + low_level_feat = self.relu(low_level_feat) + + x = F.interpolate( + x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True + ) + x = torch.cat((x, low_level_feat), dim=1) + x = self.last_conv(x) + + return x + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py new file mode 100644 index 000000000..77e9e06aa --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -0,0 +1,100 @@ +from typing import List, Tuple + +import torch +from pytorch_toolbelt.modules import ABN +from pytorch_toolbelt.modules.decoders import DecoderModule +from pytorch_toolbelt.utils.torch_utils import count_parameters +from torch import nn, Tensor + +from torch.nn import functional as F +from pytorch_toolbelt.modules.decoders import ( + FPNDecoder, + FPNBottleneckBlock, + FPNPredictionBlock, +) +from pytorch_toolbelt.modules.fpn import FPNFuse, UpsampleAdd + +__all__ = ["FPNCatDecoder"] + +from ..modules import DoubleConvBNRelu + + +class FPNCatDecoder(DecoderModule): + """ + + """ + + def __init__( + self, + feature_maps: List[int], + num_classes: int, + fpn_channels=128, + dropout=0.0, + abn_block=ABN, + ): + super().__init__() + + self.fpn = FPNDecoder( + feature_maps, + upsample_add_block=UpsampleAdd, + prediction_block=DoubleConvBNRelu, + fpn_features=fpn_channels, + prediction_features=fpn_channels, + ) + + self.fuse = FPNFuse() + self.dropout = nn.Dropout2d(dropout, inplace=True) + + self.dsv = nn.ModuleList( + [ + nn.Conv2d(fpn_features, num_classes, kernel_size=1) + for fpn_features in [fpn_channels] * len(feature_maps) + ] + ) + + features = sum(self.fpn.output_filters) + + self.final_block = nn.Sequential( + nn.Conv2d(features, features // 2, kernel_size=1), + nn.BatchNorm2d(features // 2), + nn.Conv2d( + features // 2, features // 4, kernel_size=3, padding=1, bias=True + ), + nn.LeakyReLU(inplace=True), + nn.BatchNorm2d(features // 4), + nn.Conv2d( + features // 4, features // 4, kernel_size=3, padding=1, bias=False + ), + nn.LeakyReLU(inplace=True), + nn.Conv2d(features // 4, num_classes, kernel_size=1, bias=True), + ) + + def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + fpn_maps = self.fpn(feature_maps) + + fused = self.fuse(fpn_maps) + fused = self.dropout(fused) + + dsv_masks = [] + for dsv_block, fpn in zip(self.dsv, fpn_maps): + dsv = dsv_block(fpn) + dsv_masks.append(dsv) + + x = self.final_block(fused) + return x, dsv_masks + + +@torch.no_grad() +def test_fpn_cat(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNCatDecoder(channels, 5).eval() + + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output, dsv_masks = net(input) + 
+ print(output.size(), output.mean(), output.std()) + for dsv in dsv_masks: + print(dsv.size(), dsv.mean(), dsv.std()) + print(count_parameters(net)) diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py new file mode 100644 index 000000000..f08817c33 --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -0,0 +1,211 @@ +from itertools import repeat +from typing import List, Tuple + +import torch +from pytorch_toolbelt.modules import Identity, ABN +from pytorch_toolbelt.modules.decoders import DecoderModule +from pytorch_toolbelt.utils.torch_utils import count_parameters + + +from torch import Tensor, nn +import torch.nn.functional as F + +__all__ = ["FPNSumDecoder", "FPNSumTransitionBlock", "FPNSumCenterBlock"] + + +class FPNSumCenterBlock(nn.Module): + def __init__( + self, + encoder_features: int, + decoder_features: int, + num_classes: int, + abn_block=ABN, + dropout=0.0, + ): + super().__init__() + self.bottleneck = nn.Conv2d( + encoder_features, encoder_features // 2, kernel_size=1 + ) + + self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) + self.proj2 = nn.Conv2d( + encoder_features // 2, encoder_features // 8, kernel_size=1 + ) + + self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) + self.proj4 = nn.Conv2d( + encoder_features // 2, encoder_features // 8, kernel_size=1 + ) + + self.pool8 = nn.AvgPool2d(kernel_size=8, stride=8) + self.proj8 = nn.Conv2d( + encoder_features // 2, encoder_features // 8, kernel_size=1 + ) + + self.blend = nn.Conv2d( + encoder_features // 2 + 3 * encoder_features // 8, + decoder_features, + kernel_size=1, + ) + self.dropout = nn.Dropout2d(dropout, inplace=True) + + self.conv1 = nn.Conv2d( + decoder_features, decoder_features, kernel_size=3, padding=1, bias=False + ) + self.abn1 = abn_block(decoder_features) + + self.dsv = nn.Conv2d(decoder_features, num_classes, kernel_size=1) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + x = self.bottleneck(x) + + p2 = self.proj2(self.pool2(x)) + p4 = self.proj2(self.pool4(x)) + p8 = self.proj2(self.pool8(x)) + + x_size = x.size()[2:] + x = torch.cat( + [ + x, + F.interpolate(p2, size=x_size, mode="bilinear", align_corners=True), + F.interpolate(p4, size=x_size, mode="bilinear", align_corners=True), + F.interpolate(p8, size=x_size, mode="bilinear", align_corners=True), + ], + dim=1, + ) + + x = self.blend(x) + x = self.dropout(x) + + x = self.conv1(x) + x = self.abn1(x) + + dsv = self.dsv(x) + + return x, dsv + + +class FPNSumTransitionBlock(nn.Module): + def __init__( + self, + encoder_features: int, + decoder_features: int, + output_features: int, + num_classes: int, + abn_block=ABN, + dropout=0.0, + ): + super().__init__() + self.skip = nn.Conv2d(encoder_features, decoder_features, kernel_size=1) + if decoder_features == output_features: + self.reduction = Identity() + else: + self.reduction = nn.Conv2d(decoder_features, output_features, kernel_size=1) + + self.dropout = nn.Dropout2d(dropout, inplace=True) + self.conv1 = nn.Conv2d( + output_features, output_features, kernel_size=3, padding=1, bias=False + ) + self.abn1 = abn_block(output_features) + + self.dsv = nn.Conv2d(output_features, num_classes, kernel_size=1) + + def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Tuple[Tensor, Tensor]: + """ + + :param decoder_fm: + :param encoder_fm: + :return: + """ + decoder_fm = F.interpolate( + decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=True + ) + + encoder_fm = self.skip(encoder_fm) + x = decoder_fm + 
encoder_fm + + x = self.reduction(x) + x = self.dropout(x) + + x = self.conv1(x) + x = self.abn1(x) + + dsv = self.dsv(x) + + return x, dsv + + +class FPNSumDecoder(DecoderModule): + """ + + """ + + def __init__( + self, + feature_maps: List[int], + num_classes: int, + fpn_channels=256, + dropout=0.0, + abn_block=ABN, + ): + super().__init__() + + self.center = FPNSumCenterBlock( + feature_maps[-1], + fpn_channels, + num_classes=num_classes, + dropout=dropout, + abn_block=abn_block, + ) + + self.fpn_modules = nn.ModuleList( + [ + FPNSumTransitionBlock( + encoder_fm, + decoder_fm, + decoder_fm, + num_classes=num_classes, + dropout=dropout, + abn_block=abn_block, + ) + for decoder_fm, encoder_fm in zip( + repeat(fpn_channels), reversed(feature_maps[:-1]) + ) + ] + ) + + self.final_block = nn.Sequential( + nn.Conv2d(fpn_channels, num_classes, kernel_size=1) + ) + + def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: + last_feature_map = feature_maps[-1] + feature_maps = reversed(feature_maps[:-1]) + + dsv_masks = [] + x, dsv = self.center(last_feature_map) + + dsv_masks.append(dsv) + + for transition_unit, encoder_fm in zip(self.fpn_modules, feature_maps): + x, dsv = transition_unit(x, encoder_fm) + dsv_masks.append(dsv) + + x = self.final_block(x) + return x, dsv_masks + + +@torch.no_grad() +def test_fpn_sum(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNSumDecoder(channels, 5).eval() + + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output, dsv_masks = net(input) + + print(output.size(), output.mean(), output.std()) + for dsv in dsv_masks: + print(dsv.size(), dsv.mean(), dsv.std()) + print(count_parameters(net)) diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py new file mode 100644 index 000000000..a582274f3 --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -0,0 +1,26 @@ + +class HRNetDecoder(DecoderModule): + def __init__(self, features: int, num_classes: int, dropout=0.): + super().__init__() + + self.last_layer = nn.Sequential( + nn.Conv2d( + in_channels=features, + out_channels=features, + kernel_size=1, + stride=1, + padding=0), + nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), + nn.ReLU(inplace=True), + nn.Dropout(dropout), + nn.Conv2d( + in_channels=features, + out_channels=num_classes, + kernel_size=3, + stride=1, + padding=1) + ) + + def forward(self, features): + return self.last_layer(features[-1]) + diff --git a/pytorch_toolbelt/modules/decoders/pyramid_pooling.py b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py new file mode 100644 index 000000000..c15d421cf --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py @@ -0,0 +1,64 @@ +from typing import List + +import torch +import torch.nn +import torch.nn.functional as F +from torch import nn + + +class PPMDecoder(nn.Module): + """ + https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/42b7567a43b1dab568e2bbfcbc8872778fbda92a/models/models.py + """ + + def __init__( + self, + feature_maps: List[int], + num_classes=150, + channels=512, + pool_scales=(1, 2, 3, 6), + ): + super(PPMDecoder, self).__init__() + + fc_dim = feature_maps[-1] + self.ppm = [] + for scale in pool_scales: + self.ppm.append( + nn.Sequential( + nn.AdaptiveAvgPool2d(scale), + nn.Conv2d(fc_dim, channels, kernel_size=1, bias=False), + nn.BatchNorm2d(channels), + nn.ReLU(inplace=True), + ) + ) + self.ppm = nn.ModuleList(self.ppm) + + self.conv_last = nn.Sequential( + nn.Conv2d( 
+ fc_dim + len(pool_scales) * channels, + channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.BatchNorm2d(channels), + nn.ReLU(inplace=True), + nn.Dropout2d(0.1), + nn.Conv2d(channels, num_classes, kernel_size=1), + ) + + def forward(self, feature_maps: List[torch.Tensor]): + last_fm = feature_maps[-1] + + input_size = last_fm.size() + ppm_out = [last_fm] + for pool_scale in self.ppm: + input_pooled = pool_scale(last_fm) + input_pooled = F.interpolate( + input_pooled, size=input_size[2:], mode="bilinear", align_corners=False + ) + ppm_out.append(input_pooled) + ppm_out = torch.cat(ppm_out, dim=1) + + x = self.conv_last(ppm_out) + return x diff --git a/pytorch_toolbelt/modules/decoders/unet_decoder.py b/pytorch_toolbelt/modules/decoders/unet_decoder.py new file mode 100644 index 000000000..5eb611b4f --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/unet_decoder.py @@ -0,0 +1,156 @@ +from typing import List + +import torch +import torch.nn.functional as F +from pytorch_toolbelt.modules import ABN +from pytorch_toolbelt.modules.decoders import DecoderModule +from pytorch_toolbelt.modules.encoders import SEResnet101Encoder +from pytorch_toolbelt.utils.torch_utils import count_parameters +from torch import nn + +__all__ = ["UNetDecoderV2", "UnetCentralBlockV2", "UnetDecoderBlockV2"] + + +class UnetCentralBlockV2(nn.Module): + def __init__(self, in_dec_filters, out_filters, mask_channels, abn_block=ABN): + super().__init__() + self.bottleneck = nn.Conv2d(in_dec_filters, out_filters, kernel_size=1) + + self.conv1 = nn.Conv2d( + out_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False + ) + self.abn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d( + out_filters, out_filters, kernel_size=3, padding=1, bias=False + ) + self.abn2 = abn_block(out_filters) + self.dsv = nn.Conv2d(out_filters, mask_channels, kernel_size=1) + + def forward(self, x): + x = self.bottleneck(x) + + x = self.conv1(x) + x = self.abn1(x) + x = self.conv2(x) + x = self.abn2(x) + + dsv = self.dsv(x) + + return x, dsv + + +class UnetDecoderBlockV2(nn.Module): + """ + """ + + def __init__( + self, + in_dec_filters: int, + in_enc_filters: int, + out_filters: int, + mask_channels: int, + abn_block=ABN, + pre_dropout_rate=0.0, + post_dropout_rate=0.0, + ): + super(UnetDecoderBlockV2, self).__init__() + + self.bottleneck = nn.Conv2d( + in_dec_filters + in_enc_filters, out_filters, kernel_size=1 + ) + + self.conv1 = nn.Conv2d( + out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False + ) + self.abn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d( + out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False + ) + self.abn2 = abn_block(out_filters) + + self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) + + self.post_drop = nn.Dropout2d(post_dropout_rate, inplace=True) + + self.dsv = nn.Conv2d(out_filters, mask_channels, kernel_size=1) + + def forward(self, x, enc): + lat_size = enc.size()[2:] + x = F.interpolate(x, size=lat_size, mode="bilinear", align_corners=True) + + x = torch.cat([x, enc], 1) + x = self.bottleneck(x) + x = self.pre_drop(x) + + x = self.conv1(x) + x = self.abn1(x) + + x = self.conv2(x) + x = self.abn2(x) + + x = self.post_drop(x) + + dsv = self.dsv(x) + return x, dsv + + +class UNetDecoderV2(DecoderModule): + def __init__(self, features: List[int], decoder_features: int, mask_channels: int): + super().__init__() + + if not isinstance(decoder_features, list): + decoder_features = [ + decoder_features * (2 ** i) for i in range(len(features)) + 
] + + blocks = [] + for block_index, in_enc_features in enumerate(features[:-1]): + blocks.append( + UnetDecoderBlockV2( + decoder_features[block_index + 1], + in_enc_features, + decoder_features[block_index], + mask_channels, + ) + ) + + self.center = UnetCentralBlockV2(features[-1], decoder_features[-1], mask_channels) + self.blocks = nn.ModuleList(blocks) + self.output_filters = decoder_features + + def forward(self, feature_maps): + + output, dsv = self.center(feature_maps[-1]) + decoder_outputs = [output] + dsv_list = [dsv] + + for decoder_block, encoder_output in zip( + reversed(self.blocks), reversed(feature_maps[:-1]) + ): + output, dsv = decoder_block(output, encoder_output) + decoder_outputs.append(output) + dsv_list.append(dsv) + + dsv_list = list(reversed(dsv_list)) + decoder_outputs = list(reversed(decoder_outputs)) + + return decoder_outputs, dsv_list + + +@torch.no_grad() +def test_unetv2(): + encoder = SEResnet101Encoder().cuda().eval() + decoder = ( + UNetDecoderV2(encoder.output_filters, [128, 192, 256, 512], 5).cuda().eval() + ) + + print(count_parameters(encoder)) + print(count_parameters(decoder)) + print(decoder) + + x = torch.rand((1, 3, 256, 512)).cuda() + fm = encoder(x) + fm2 = decoder(fm) + + for fm, dsv in fm2: + print(fm.size(), dsv.size()) diff --git a/pytorch_toolbelt/modules/decoders/upernet.py b/pytorch_toolbelt/modules/decoders/upernet.py new file mode 100644 index 000000000..1c23c1c73 --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/upernet.py @@ -0,0 +1,119 @@ +from typing import List + +import torch +import torch.nn +import torch.nn.functional as F +from torch import nn + + +def conv3x3_bn_relu(in_planes, out_planes, stride=1): + "3x3 convolution + BN + relu" + return nn.Sequential( + nn.Conv2d( + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False + ), + nn.BatchNorm2d(out_planes), + nn.ReLU(inplace=True), + ) + + +class UPerNet(nn.Module): + def __init__( + self, + output_filters: List[int], + num_classes=150, + pool_scales=(1, 2, 3, 6), + fpn_dim=256, + ): + super(UPerNet, self).__init__() + + last_fm_dim = output_filters[-1] + + # PPM Module + self.ppm_pooling = [] + self.ppm_conv = [] + + for scale in pool_scales: + self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) + self.ppm_conv.append( + nn.Sequential( + nn.Conv2d(last_fm_dim, 512, kernel_size=1, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ) + ) + self.ppm_pooling = nn.ModuleList(self.ppm_pooling) + self.ppm_conv = nn.ModuleList(self.ppm_conv) + self.ppm_last_conv = conv3x3_bn_relu( + last_fm_dim + len(pool_scales) * 512, fpn_dim, 1 + ) + + # FPN Module + self.fpn_in = [] + for fpn_inplane in output_filters[:-1]: # skip the top layer + self.fpn_in.append( + nn.Sequential( + nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), + nn.BatchNorm2d(fpn_dim), + nn.ReLU(inplace=True), + ) + ) + self.fpn_in = nn.ModuleList(self.fpn_in) + + self.fpn_out = [] + for i in range(len(output_filters) - 1): # skip the top layer + self.fpn_out.append(nn.Sequential(conv3x3_bn_relu(fpn_dim, fpn_dim, 1))) + self.fpn_out = nn.ModuleList(self.fpn_out) + + self.conv_last = nn.Sequential( + conv3x3_bn_relu(len(output_filters) * fpn_dim, fpn_dim, 1), + nn.Conv2d(fpn_dim, num_classes, kernel_size=1), + ) + + def forward(self, feature_maps): + last_fm = feature_maps[-1] + + input_size = last_fm.size() + ppm_out = [last_fm] + for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): + ppm_out.append( + pool_conv( + F.interpolate( + pool_scale(last_fm), + 
(input_size[2], input_size[3]), + mode="bilinear", + align_corners=False, + ) + ) + ) + ppm_out = torch.cat(ppm_out, 1) + f = self.ppm_last_conv(ppm_out) + + fpn_feature_list = [f] + for i in reversed(range(len(feature_maps) - 1)): + conv_x = feature_maps[i] + conv_x = self.fpn_in[i](conv_x) # lateral branch + + f = F.interpolate( + f, size=conv_x.size()[2:], mode="bilinear", align_corners=False + ) # top-down branch + f = conv_x + f + + fpn_feature_list.append(self.fpn_out[i](f)) + + fpn_feature_list.reverse() # [P2 - P5] + output_size = fpn_feature_list[0].size()[2:] + fusion_list = [fpn_feature_list[0]] + for i in range(1, len(fpn_feature_list)): + fusion_list.append( + F.interpolate( + fpn_feature_list[i], + output_size, + mode="bilinear", + align_corners=False, + ) + ) + + fusion_out = torch.cat(fusion_list, 1) + x = self.conv_last(fusion_out) + return x diff --git a/pytorch_toolbelt/modules/encoders.py b/pytorch_toolbelt/modules/encoders.py deleted file mode 100644 index 2750c26fb..000000000 --- a/pytorch_toolbelt/modules/encoders.py +++ /dev/null @@ -1,905 +0,0 @@ -"""Wrappers for different backbones for models that follows Encoder-Decoder architecture. - -Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. -""" - -from collections import OrderedDict -from typing import List - -from torch import nn -from torchvision.models import ( - resnet50, - resnet34, - resnet18, - resnet101, - resnet152, - squeezenet1_1, - densenet121, - densenet161, - densenet169, - densenet201, - DenseNet, -) - -from pytorch_toolbelt.modules.abn import ABN -from pytorch_toolbelt.modules.backbone.efficient_net import ( - efficient_net_b0, - efficient_net_b6, - efficient_net_b1, - efficient_net_b2, - efficient_net_b3, - efficient_net_b4, - efficient_net_b5, - efficient_net_b7, -) -from pytorch_toolbelt.modules.backbone.inceptionv4 import InceptionV4, \ - inceptionv4 -from pytorch_toolbelt.modules.backbone.mobilenetv3 import MobileNetV3 -from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, \ - WiderResNetA2 -from .backbone.mobilenet import MobileNetV2 -from .backbone.senet import ( - SENet, - se_resnext50_32x4d, - se_resnext101_32x4d, - se_resnet50, - se_resnet101, - se_resnet152, - senet154, -) - -__all__ = [ - "EncoderModule", - "ResnetEncoder", - "SEResnetEncoder", - "Resnet18Encoder", - "Resnet34Encoder", - "Resnet50Encoder", - "Resnet101Encoder", - "Resnet152Encoder", - "SEResNeXt50Encoder", - "SEResnet101Encoder", - "SEResNeXt101Encoder", - "SEResnet152Encoder", - "SENet154Encoder", - "MobilenetV2Encoder", - "MobilenetV3Encoder", - "SqueezenetEncoder", - "WiderResnetEncoder", - "WiderResnet16Encoder", - "WiderResnet20Encoder", - "WiderResnet38Encoder", - "WiderResnetA2Encoder", - "WiderResnet16A2Encoder", - "WiderResnet38A2Encoder", - "WiderResnet20A2Encoder", - "DenseNetEncoder", - "DenseNet121Encoder", - "DenseNet169Encoder", - "DenseNet201Encoder", - "EfficientNetEncoder", - "EfficientNetB0Encoder", - "EfficientNetB1Encoder", - "EfficientNetB2Encoder", - "EfficientNetB3Encoder", - "EfficientNetB4Encoder", - "EfficientNetB5Encoder", - "EfficientNetB6Encoder", - "EfficientNetB7Encoder", - "InceptionV4Encoder" -] - - -def _take(elements, indexes): - return list([elements[i] for i in indexes]) - - -class EncoderModule(nn.Module): - def __init__(self, channels: List[int], strides: List[int], - layers: List[int]): - super().__init__() - assert len(channels) == len(strides) - - self._layers = layers - - self._output_strides = _take(strides, 
layers) - self._output_filters = _take(channels, layers) - - def forward(self, x): - input = x - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - input = output - # Return only features that were requested - return _take(output_features, self._layers) - - @property - def output_strides(self) -> List[int]: - return self._output_strides - - @property - def output_filters(self) -> List[int]: - return self._output_filters - - @property - def encoder_layers(self): - raise NotImplementedError - - def set_trainable(self, trainable): - for param in self.parameters(): - param.requires_grad = bool(trainable) - - -class ResnetEncoder(EncoderModule): - def __init__(self, resnet, filters, strides, layers=None): - if layers is None: - layers = [1, 2, 3, 4] - super().__init__(filters, strides, layers) - - self.layer0 = nn.Sequential( - OrderedDict( - [("conv1", resnet.conv1), ("bn1", resnet.bn1), - ("relu", resnet.relu)] - ) - ) - self.maxpool = resnet.maxpool - - self.layer1 = resnet.layer1 - self.layer2 = resnet.layer2 - self.layer3 = resnet.layer3 - self.layer4 = resnet.layer4 - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2, self.layer3, - self.layer4] - - def forward(self, x): - input = x - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - - if layer == self.layer0: - # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2 - output = self.maxpool(output) - input = output - - # Return only features that were requested - return _take(output_features, self._layers) - - -class Resnet18Encoder(ResnetEncoder): - def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet18(pretrained=pretrained), - [64, 64, 128, 256, 512], - [2, 4, 8, 16, 32], - layers, - ) - - -class Resnet34Encoder(ResnetEncoder): - def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet34(pretrained=pretrained), - [64, 64, 128, 256, 512], - [2, 4, 8, 16, 32], - layers, - ) - - -class Resnet50Encoder(ResnetEncoder): - def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet50(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) - - -class Resnet101Encoder(ResnetEncoder): - def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet101(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) - - -class Resnet152Encoder(ResnetEncoder): - def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet152(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) - - -class SEResnetEncoder(EncoderModule): - """ - The only difference from vanilla ResNet is that it has 'layer0' module - """ - - def __init__(self, seresnet: SENet, channels, strides, layers=None): - if layers is None: - layers = [1, 2, 3, 4] - super().__init__(channels, strides, layers) - - self.maxpool = seresnet.layer0.pool - del seresnet.layer0.pool - - self.layer0 = seresnet.layer0 - self.layer1 = seresnet.layer1 - self.layer2 = seresnet.layer2 - self.layer3 = seresnet.layer3 - self.layer4 = seresnet.layer4 - - self._output_strides = _take(strides, layers) - self._output_filters = _take(channels, layers) - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2, self.layer3, - self.layer4] - - @property - def output_strides(self): - 
return self._output_strides - - @property - def output_filters(self): - return self._output_filters - - def forward(self, x): - input = x - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - - if layer == self.layer0: - # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2 - output = self.maxpool(output) - input = output - - # Return only features that were requested - return _take(output_features, self._layers) - - -class SEResnet50Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = se_resnet50(pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SEResnet101Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = se_resnet101(pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SEResnet152Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = se_resnet152(pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SENet154Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = senet154(pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SEResNeXt50Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = se_resnext50_32x4d( - pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SEResNeXt101Encoder(SEResnetEncoder): - def __init__(self, pretrained=True, layers=None): - encoder = se_resnext101_32x4d( - pretrained="imagenet" if pretrained else None) - super().__init__(encoder, [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], layers) - - -class SqueezenetEncoder(EncoderModule): - def __init__(self, pretrained=True, layers=[1, 2, 3]): - super().__init__([64, 128, 256, 512], [4, 8, 16, 16], layers) - squeezenet = squeezenet1_1(pretrained=pretrained) - - # nn.Conv2d(3, 64, kernel_size=3, stride=2), - # nn.ReLU(inplace=True), - # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - self.layer0 = nn.Sequential( - squeezenet.features[0], - squeezenet.features[1], - # squeezenet.features[2], - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - ) - - # Fire(64, 16, 64, 64), - # Fire(128, 16, 64, 64), - # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - self.layer1 = nn.Sequential( - squeezenet.features[3], - squeezenet.features[4], - # squeezenet.features[5], - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - ) - - # Fire(128, 32, 128, 128), - # Fire(256, 32, 128, 128), - # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), - self.layer2 = nn.Sequential( - squeezenet.features[6], - squeezenet.features[7], - # squeezenet.features[8], - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), - ) - - # Fire(256, 48, 192, 192), - # Fire(384, 48, 192, 192), - # Fire(384, 64, 256, 256), - # Fire(512, 64, 256, 256), - self.layer3 = nn.Sequential( - squeezenet.features[9], - squeezenet.features[10], - squeezenet.features[11], - squeezenet.features[12], - ) - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2, self.layer3] - - -class 
MobilenetV2Encoder(EncoderModule): - def __init__(self, layers=[2, 3, 5, 7], activation="relu6"): - super().__init__( - [32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], - layers - ) - encoder = MobileNetV2(activation=activation) - - self.layer0 = encoder.layer0 - self.layer1 = encoder.layer1 - self.layer2 = encoder.layer2 - self.layer3 = encoder.layer3 - self.layer4 = encoder.layer4 - self.layer5 = encoder.layer5 - self.layer6 = encoder.layer6 - self.layer7 = encoder.layer7 - - @property - def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - self.layer7, - ] - - -class MobilenetV3Encoder(EncoderModule): - def __init__( - self, input_channels=3, small=False, drop_prob=0.0, - layers=[1, 2, 3, 4] - ): - super().__init__( - [24, 24, 40, 96, 96] if small else [16, 40, 80, 160, 160], - [4, 8, 16, 32, 32], - layers, - ) - encoder = MobileNetV3( - in_channels=input_channels, small=small, drop_prob=drop_prob - ) - - self.conv1 = encoder.conv1 - self.bn1 = encoder.bn1 - self.act1 = encoder.act1 - - self.layer0 = encoder.layer0 - self.layer1 = encoder.layer1 - self.layer2 = encoder.layer2 - self.layer3 = encoder.layer3 - self.layer4 = encoder.layer4 - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.act1(x) - - output_features = [] - - x = self.layer0(x) - output_features.append(x) - - x = self.layer1(x) - output_features.append(x) - - x = self.layer2(x) - output_features.append(x) - - x = self.layer3(x) - output_features.append(x) - - x = self.layer4(x) - output_features.append(x) - - # Return only features that were requested - return _take(output_features, self._layers) - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2, self.layer3, - self.layer4] - - -class WiderResnetEncoder(EncoderModule): - def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): - super().__init__( - [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], - layers - ) - - encoder = WiderResNet(structure, classes=0, norm_act=norm_act) - self.layer0 = encoder.mod1 - self.layer1 = encoder.mod2 - self.layer2 = encoder.mod3 - self.layer3 = encoder.mod4 - self.layer4 = encoder.mod5 - self.layer5 = encoder.mod6 - self.layer6 = encoder.mod7 - - self.pool2 = encoder.pool2 - self.pool3 = encoder.pool3 - self.pool4 = encoder.pool4 - self.pool5 = encoder.pool5 - self.pool6 = encoder.pool6 - - @property - def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - ] - - def forward(self, input): - output_features = [] - - x = self.layer0(input) - output_features.append(x) - - x = self.layer1(self.pool2(x)) - output_features.append(x) - - x = self.layer2(self.pool3(x)) - output_features.append(x) - - x = self.layer3(self.pool4(x)) - output_features.append(x) - - x = self.layer4(self.pool5(x)) - output_features.append(x) - - x = self.layer5(self.pool6(x)) - output_features.append(x) - - x = self.layer6(x) - output_features.append(x) - - # Return only features that were requested - return _take(output_features, self._layers) - - -class WiderResnet16Encoder(WiderResnetEncoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - super().__init__(structure=[1, 1, 1, 1, 1, 1], layers=layers) - - -class WiderResnet20Encoder(WiderResnetEncoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - 
super().__init__(structure=[1, 1, 1, 3, 1, 1], layers=layers) - - -class WiderResnet38Encoder(WiderResnetEncoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - super().__init__(structure=[3, 3, 6, 3, 1, 1], layers=layers) - - -class WiderResnetA2Encoder(EncoderModule): - def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): - super().__init__( - [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], - layers - ) - - encoder = WiderResNetA2(structure=structure, classes=0, - norm_act=norm_act) - self.layer0 = encoder.mod1 - self.layer1 = encoder.mod2 - self.layer2 = encoder.mod3 - self.layer3 = encoder.mod4 - self.layer4 = encoder.mod5 - self.layer5 = encoder.mod6 - self.layer6 = encoder.mod7 - - self.pool2 = encoder.pool2 - self.pool3 = encoder.pool3 - - @property - def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - ] - - def forward(self, input): - output_features = [] - - out = self.layer0(input) - output_features.append(out) - - out = self.layer1(self.pool2(out)) - output_features.append(out) - - out = self.layer2(self.pool3(out)) - output_features.append(out) - - out = self.layer3(out) - output_features.append(out) - - out = self.layer4(out) - output_features.append(out) - - out = self.layer5(out) - output_features.append(out) - - out = self.layer6(out) - output_features.append(out) - - # Return only features that were requested - return _take(output_features, self._layers) - - -class WiderResnet16A2Encoder(WiderResnetA2Encoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - super().__init__(structure=[1, 1, 1, 1, 1, 1], layers=layers) - - -class WiderResnet20A2Encoder(WiderResnetA2Encoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - super().__init__(structure=[1, 1, 1, 3, 1, 1], layers=layers) - - -class WiderResnet38A2Encoder(WiderResnetA2Encoder): - def __init__(self, layers=None): - if layers is None: - layers = [2, 3, 4, 5, 6] - super().__init__(structure=[3, 3, 6, 3, 1, 1], layers=layers) - - -class DenseNetEncoder(EncoderModule): - def __init__( - self, - densenet: DenseNet, - strides: List[int], - channels: List[int], - layers: List[int], - first_avg_pool=False, - ): - if layers is None: - layers = [1, 2, 3, 4] - - super().__init__(channels, strides, layers) - - def except_pool(block: nn.Module): - del block.pool - return block - - self.layer0 = nn.Sequential( - densenet.features.conv0, densenet.features.norm0, - densenet.features.relu0 - ) - - self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) - self.pool0 = self.avg_pool if first_avg_pool else densenet.features.pool0 - - self.layer1 = nn.Sequential( - densenet.features.denseblock1, - except_pool(densenet.features.transition1) - ) - - self.layer2 = nn.Sequential( - densenet.features.denseblock2, - except_pool(densenet.features.transition2) - ) - - self.layer3 = nn.Sequential( - densenet.features.denseblock3, - except_pool(densenet.features.transition3) - ) - - self.layer4 = nn.Sequential(densenet.features.denseblock4) - - self._output_strides = _take(strides, layers) - self._output_filters = _take(channels, layers) - - @property - def encoder_layers(self): - return [self.layer0, self.layer1, self.layer2, self.layer3, - self.layer4] - - @property - def output_strides(self): - return self._output_strides - - @property - def output_filters(self): - return self._output_filters - - def 
forward(self, x): - input = x - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - - if layer == self.layer0: - # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2 - output = self.pool0(output) - else: - output = self.avg_pool(output) - - input = output - - # Return only features that were requested - return _take(output_features, self._layers) - - -class DenseNet121Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, - first_avg_pool=False - ): - densenet = densenet121(pretrained=pretrained, - memory_efficient=memory_efficient) - strides = [2, 4, 8, 16, 32] - channels = [64, 128, 256, 512, 1024] - super().__init__(densenet, strides, channels, layers, first_avg_pool) - - -class DenseNet161Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, - first_avg_pool=False - ): - densenet = densenet161(pretrained=pretrained, - memory_efficient=memory_efficient) - strides = [2, 4, 8, 16, 32] - channels = [96, 192, 384, 1056, 2208] - super().__init__(densenet, strides, channels, layers, first_avg_pool) - - -class DenseNet169Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, - first_avg_pool=False - ): - densenet = densenet169(pretrained=pretrained, - memory_efficient=memory_efficient) - strides = [2, 4, 8, 16, 32] - channels = [64, 128, 256, 640, 1664] - super().__init__(densenet, strides, channels, layers, first_avg_pool) - - -class DenseNet201Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, - first_avg_pool=False - ): - densenet = densenet201(pretrained=pretrained, - memory_efficient=memory_efficient) - strides = [2, 4, 8, 16, 32] - channels = [64, 128, 256, 896, 1920] - super().__init__(densenet, strides, channels, layers, first_avg_pool) - - -class EfficientNetEncoder(EncoderModule): - def __init__(self, efficientnet, filters, strides, layers): - if layers is None: - layers = [1, 2, 4, 6] - - super().__init__(filters, strides, layers) - - self.stem = efficientnet.stem - - self.block0 = efficientnet.block0 - self.block1 = efficientnet.block1 - self.block2 = efficientnet.block2 - self.block3 = efficientnet.block3 - self.block4 = efficientnet.block4 - self.block5 = efficientnet.block5 - self.block6 = efficientnet.block6 - - @property - def encoder_layers(self): - return [ - self.block0, - self.block1, - self.block2, - self.block3, - self.block4, - self.block5, - self.block6, - ] - - def forward(self, x): - input = self.stem(x) - - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - input = output - - # Return only features that were requested - return _take(output_features, self._layers) - - -class EfficientNetB0Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b0(num_classes=1, **kwargs), - [16, 24, 40, 80, 112, 192, 320], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB1Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b1(num_classes=1, **kwargs), - [16, 24, 40, 80, 112, 192, 320], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB2Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b2(num_classes=1, 
**kwargs), - [16, 24, 48, 88, 120, 208, 352], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB3Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b3(num_classes=1, **kwargs), - [24, 32, 48, 96, 136, 232, 384], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB4Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b4(num_classes=1, **kwargs), - [24, 32, 56, 112, 160, 272, 448], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB5Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b5(num_classes=1, **kwargs), - [24, 40, 64, 128, 176, 304, 512], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB6Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b6(num_classes=1, **kwargs), - [32, 40, 72, 144, 200, 344, 576], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class EfficientNetB7Encoder(EfficientNetEncoder): - def __init__(self, layers=None, **kwargs): - super().__init__( - efficient_net_b7(num_classes=1, **kwargs), - [32, 48, 80, 160, 224, 384, 640], - [2, 4, 8, 16, 16, 32, 32], - layers, - ) - - -class InceptionV4Encoder(EncoderModule): - def __init__(self, pretrained=True, layers=None, **kwargs): - backbone = inceptionv4(pretrained="imagenet" if pretrained else None) - channels = [64, 192, 384, 1024, 1536] - strides = [2, 4, 8, 16, 32] # Note output strides are approximate - if layers is None: - layers = [1, 2, 3, 4] - features = backbone.features - super().__init__(channels, strides, layers) - - self.layer0 = features[0:3] - self.layer1 = features[3:5] - self.layer2 = features[5:10] - self.layer3 = features[10:18] - self.layer4 = features[18:22] - - self._output_strides = _take(strides, layers) - self._output_filters = _take(channels, layers) - - def forward(self, x): - input = x - output_features = [] - for layer in self.encoder_layers: - output = layer(input) - output_features.append(output) - input = output - - # Return only features that were requested - return _take(output_features, self._layers) - - @property - def encoder_layers(self): - return [self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4] diff --git a/pytorch_toolbelt/modules/encoders/__init__.py b/pytorch_toolbelt/modules/encoders/__init__.py new file mode 100644 index 000000000..9b4f975e3 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/__init__.py @@ -0,0 +1,14 @@ +"""Wrappers for different backbones for models that follows Encoder-Decoder architecture. + +Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. +""" +from .common import * +from .densenet import * +from .efficientnet import * +from .hrnet import * +from .inception import * +from .mobilenet import * +from .resnet import * +from .seresnet import * +from .squeezenet import * +from .wide_resnet import * diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py new file mode 100644 index 000000000..7f35c20a5 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -0,0 +1,49 @@ +"""Wrappers for different backbones for models that follows Encoder-Decoder architecture. + +Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. 
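As a point of reference, a minimal usage sketch of the encoder wrappers added in this patch; the class name and attributes come from the modules introduced below, while the batch and input size are arbitrary and chosen only so the stride arithmetic is easy to follow.

import torch
from pytorch_toolbelt.modules.encoders import Resnet34Encoder

# Any encoder class from this package could be substituted here; only the
# declared output_filters / output_strides would change.
encoder = Resnet34Encoder(pretrained=False, layers=[1, 2, 3, 4])
x = torch.randn(2, 3, 256, 256)
features = encoder(x)  # a list of feature maps, one per requested layer

for fm, channels, stride in zip(features, encoder.output_filters, encoder.output_strides):
    assert fm.size(1) == channels          # channel count matches the declared filters
    assert fm.size(2) == 256 // stride     # spatial size follows the declared stride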
+""" + +from typing import List + +from torch import nn + + +def _take(elements, indexes): + return list([elements[i] for i in indexes]) + + +class EncoderModule(nn.Module): + def __init__(self, channels: List[int], strides: List[int], layers: List[int]): + super().__init__() + assert len(channels) == len(strides) + + self._layers = layers + + self._output_strides = _take(strides, layers) + self._output_filters = _take(channels, layers) + + def forward(self, x): + input = x + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + input = output + # Return only features that were requested + return _take(output_features, self._layers) + + @property + def output_strides(self) -> List[int]: + return self._output_strides + + @property + def output_filters(self) -> List[int]: + return self._output_filters + + @property + def encoder_layers(self): + raise NotImplementedError + + def set_trainable(self, trainable): + for param in self.parameters(): + param.requires_grad = bool(trainable) diff --git a/pytorch_toolbelt/modules/encoders/densenet.py b/pytorch_toolbelt/modules/encoders/densenet.py new file mode 100644 index 000000000..537ea0539 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/densenet.py @@ -0,0 +1,133 @@ +from typing import List + +from torch import nn +from torchvision.models import ( + densenet121, + densenet161, + densenet169, + densenet201, + DenseNet, +) + +from .common import EncoderModule, _take + +__all__ = [ + "DenseNetEncoder", + "DenseNet121Encoder", + "DenseNet169Encoder", + "DenseNet161Encoder", + "DenseNet201Encoder", +] + + +class DenseNetEncoder(EncoderModule): + def __init__( + self, + densenet: DenseNet, + strides: List[int], + channels: List[int], + layers: List[int], + first_avg_pool=False, + ): + if layers is None: + layers = [1, 2, 3, 4] + + super().__init__(channels, strides, layers) + + def except_pool(block: nn.Module): + del block.pool + return block + + self.layer0 = nn.Sequential( + densenet.features.conv0, densenet.features.norm0, densenet.features.relu0 + ) + + self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) + self.pool0 = self.avg_pool if first_avg_pool else densenet.features.pool0 + + self.layer1 = nn.Sequential( + densenet.features.denseblock1, except_pool(densenet.features.transition1) + ) + + self.layer2 = nn.Sequential( + densenet.features.denseblock2, except_pool(densenet.features.transition2) + ) + + self.layer3 = nn.Sequential( + densenet.features.denseblock3, except_pool(densenet.features.transition3) + ) + + self.layer4 = nn.Sequential(densenet.features.denseblock4) + + self._output_strides = _take(strides, layers) + self._output_filters = _take(channels, layers) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] + + @property + def output_strides(self): + return self._output_strides + + @property + def output_filters(self): + return self._output_filters + + def forward(self, x): + input = x + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + + if layer == self.layer0: + # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2 + output = self.pool0(output) + else: + output = self.avg_pool(output) + + input = output + + # Return only features that were requested + return _take(output_features, self._layers) + + +class DenseNet121Encoder(DenseNetEncoder): + def __init__( + self, layers=None, 
pretrained=True, memory_efficient=False, first_avg_pool=False + ): + densenet = densenet121(pretrained=pretrained, memory_efficient=memory_efficient) + strides = [2, 4, 8, 16, 32] + channels = [64, 128, 256, 512, 1024] + super().__init__(densenet, strides, channels, layers, first_avg_pool) + + +class DenseNet161Encoder(DenseNetEncoder): + def __init__( + self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False + ): + densenet = densenet161(pretrained=pretrained, memory_efficient=memory_efficient) + strides = [2, 4, 8, 16, 32] + channels = [96, 192, 384, 1056, 2208] + super().__init__(densenet, strides, channels, layers, first_avg_pool) + + +class DenseNet169Encoder(DenseNetEncoder): + def __init__( + self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False + ): + densenet = densenet169(pretrained=pretrained, memory_efficient=memory_efficient) + strides = [2, 4, 8, 16, 32] + channels = [64, 128, 256, 640, 1664] + super().__init__(densenet, strides, channels, layers, first_avg_pool) + + +class DenseNet201Encoder(DenseNetEncoder): + def __init__( + self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False + ): + densenet = densenet201(pretrained=pretrained, memory_efficient=memory_efficient) + strides = [2, 4, 8, 16, 32] + channels = [64, 128, 256, 896, 1920] + super().__init__(densenet, strides, channels, layers, first_avg_pool) diff --git a/pytorch_toolbelt/modules/encoders/efficientnet.py b/pytorch_toolbelt/modules/encoders/efficientnet.py new file mode 100644 index 000000000..42a886c8f --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/efficientnet.py @@ -0,0 +1,146 @@ +from pytorch_toolbelt.modules.backbone.efficient_net import ( + efficient_net_b0, + efficient_net_b6, + efficient_net_b1, + efficient_net_b2, + efficient_net_b3, + efficient_net_b4, + efficient_net_b5, + efficient_net_b7, +) + +from .common import EncoderModule, _take + +__all__ = [ + "EfficientNetEncoder", + "EfficientNetB0Encoder", + "EfficientNetB1Encoder", + "EfficientNetB2Encoder", + "EfficientNetB3Encoder", + "EfficientNetB4Encoder", + "EfficientNetB5Encoder", + "EfficientNetB6Encoder", + "EfficientNetB7Encoder", +] + + +class EfficientNetEncoder(EncoderModule): + def __init__(self, efficientnet, filters, strides, layers): + if layers is None: + layers = [1, 2, 4, 6] + + super().__init__(filters, strides, layers) + + self.stem = efficientnet.stem + + self.block0 = efficientnet.block0 + self.block1 = efficientnet.block1 + self.block2 = efficientnet.block2 + self.block3 = efficientnet.block3 + self.block4 = efficientnet.block4 + self.block5 = efficientnet.block5 + self.block6 = efficientnet.block6 + + @property + def encoder_layers(self): + return [ + self.block0, + self.block1, + self.block2, + self.block3, + self.block4, + self.block5, + self.block6, + ] + + def forward(self, x): + input = self.stem(x) + + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + input = output + + # Return only features that were requested + return _take(output_features, self._layers) + + +class EfficientNetB0Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b0(num_classes=1, **kwargs), + [16, 24, 40, 80, 112, 192, 320], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB1Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b1(num_classes=1, **kwargs), + [16, 24, 40, 80, 
112, 192, 320], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB2Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b2(num_classes=1, **kwargs), + [16, 24, 48, 88, 120, 208, 352], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB3Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b3(num_classes=1, **kwargs), + [24, 32, 48, 96, 136, 232, 384], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB4Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b4(num_classes=1, **kwargs), + [24, 32, 56, 112, 160, 272, 448], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB5Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b5(num_classes=1, **kwargs), + [24, 40, 64, 128, 176, 304, 512], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB6Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b6(num_classes=1, **kwargs), + [32, 40, 72, 144, 200, 344, 576], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) + + +class EfficientNetB7Encoder(EfficientNetEncoder): + def __init__(self, layers=None, **kwargs): + super().__init__( + efficient_net_b7(num_classes=1, **kwargs), + [32, 48, 80, 160, 224, 384, 640], + [2, 4, 8, 16, 16, 32, 32], + layers, + ) diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py new file mode 100644 index 000000000..a639bcb93 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -0,0 +1,32 @@ +from pytorch_toolbelt.modules.backbone.hrnet import hrnetv2 + +from .common import EncoderModule + +__all__ = ["HRNetV2Encoder48", "HRNetV2Encoder18", "HRNetV2Encoder34"] + + +class HRNetV2Encoder18(EncoderModule): + def __init__(self, pretrained=False): + super().__init__([144 + 72 + 36 + 18], [4], [0]) + self.hrnet = hrnetv2(width=18, pretrained=False) + + def forward(self, x): + return self.hrnet(x) + + +class HRNetV2Encoder34(EncoderModule): + def __init__(self, pretrained=False): + super().__init__([34 * 8 + 34 * 4 + 34 * 2 + 34], [4], [0]) + self.hrnet = hrnetv2(width=34, pretrained=False) + + def forward(self, x): + return self.hrnet(x) + + +class HRNetV2Encoder48(EncoderModule): + def __init__(self, pretrained=False): + super().__init__([720], [4], [0]) + self.hrnet = hrnetv2(pretrained=False) + + def forward(self, x): + return self.hrnet(x) diff --git a/pytorch_toolbelt/modules/encoders/inception.py b/pytorch_toolbelt/modules/encoders/inception.py new file mode 100644 index 000000000..d6c91122c --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/inception.py @@ -0,0 +1,39 @@ +from .common import EncoderModule, _take +from ..backbone.inceptionv4 import inceptionv4 + +__all__ = ["InceptionV4Encoder"] + + +class InceptionV4Encoder(EncoderModule): + def __init__(self, pretrained=True, layers=None, **kwargs): + backbone = inceptionv4(pretrained="imagenet" if pretrained else None) + channels = [64, 192, 384, 1024, 1536] + strides = [2, 4, 8, 16, 32] # Note output strides are approximate + if layers is None: + layers = [1, 2, 3, 4] + features = backbone.features + super().__init__(channels, strides, layers) + + self.layer0 = features[0:3] + self.layer1 = features[3:5] + self.layer2 = features[5:10] + self.layer3 = features[10:18] + self.layer4 = features[18:22] + + 
self._output_strides = _take(strides, layers) + self._output_filters = _take(channels, layers) + + def forward(self, x): + input = x + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + input = output + + # Return only features that were requested + return _take(output_features, self._layers) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] diff --git a/pytorch_toolbelt/modules/encoders/mobilenet.py b/pytorch_toolbelt/modules/encoders/mobilenet.py new file mode 100644 index 000000000..b78eb0bbe --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/mobilenet.py @@ -0,0 +1,88 @@ +from .common import EncoderModule, _take +from ..backbone.mobilenet import MobileNetV2 +from ..backbone.mobilenetv3 import MobileNetV3 + +__all__ = ["MobilenetV2Encoder", "MobilenetV3Encoder"] + + +class MobilenetV2Encoder(EncoderModule): + def __init__(self, layers=[2, 3, 5, 7], activation="relu6"): + super().__init__( + [32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers + ) + encoder = MobileNetV2(activation=activation) + + self.layer0 = encoder.layer0 + self.layer1 = encoder.layer1 + self.layer2 = encoder.layer2 + self.layer3 = encoder.layer3 + self.layer4 = encoder.layer4 + self.layer5 = encoder.layer5 + self.layer6 = encoder.layer6 + self.layer7 = encoder.layer7 + + @property + def encoder_layers(self): + return [ + self.layer0, + self.layer1, + self.layer2, + self.layer3, + self.layer4, + self.layer5, + self.layer6, + self.layer7, + ] + + +class MobilenetV3Encoder(EncoderModule): + def __init__( + self, input_channels=3, small=False, drop_prob=0.0, layers=[1, 2, 3, 4] + ): + super().__init__( + [24, 24, 40, 96, 96] if small else [16, 40, 80, 160, 160], + [4, 8, 16, 32, 32], + layers, + ) + encoder = MobileNetV3( + in_channels=input_channels, small=small, drop_prob=drop_prob + ) + + self.conv1 = encoder.conv1 + self.bn1 = encoder.bn1 + self.act1 = encoder.act1 + + self.layer0 = encoder.layer0 + self.layer1 = encoder.layer1 + self.layer2 = encoder.layer2 + self.layer3 = encoder.layer3 + self.layer4 = encoder.layer4 + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.act1(x) + + output_features = [] + + x = self.layer0(x) + output_features.append(x) + + x = self.layer1(x) + output_features.append(x) + + x = self.layer2(x) + output_features.append(x) + + x = self.layer3(x) + output_features.append(x) + + x = self.layer4(x) + output_features.append(x) + + # Return only features that were requested + return _take(output_features, self._layers) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py new file mode 100644 index 000000000..cacbf9d14 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -0,0 +1,109 @@ +"""Wrappers for different backbones for models that follows Encoder-Decoder architecture. + +Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. 
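The `layers` argument accepted by every wrapper above is a list of indices into the per-stage feature list, resolved by `_take` from common.py; a tiny sketch of that selection with stand-in strings instead of tensors (indices taken from the MobilenetV2Encoder default):

stage_outputs = ["s0", "s1", "s2", "s3", "s4", "s5", "s6", "s7"]  # one entry per encoder stage
layers = [2, 3, 5, 7]                                             # MobilenetV2Encoder default
selected = [stage_outputs[i] for i in layers]                     # what _take(stage_outputs, layers) returns
assert selected == ["s2", "s3", "s5", "s7"]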
+""" + +from collections import OrderedDict + +from torch import nn +from torchvision.models import resnet50, resnet34, resnet18, resnet101, \ + resnet152 + +from .common import EncoderModule, _take + +__all__ = [ + "ResnetEncoder", + "Resnet18Encoder", + "Resnet34Encoder", + "Resnet50Encoder", + "Resnet101Encoder", + "Resnet152Encoder", +] + + +class ResnetEncoder(EncoderModule): + def __init__(self, resnet, filters, strides, layers=None): + if layers is None: + layers = [1, 2, 3, 4] + super().__init__(filters, strides, layers) + + self.layer0 = nn.Sequential( + OrderedDict( + [("conv1", resnet.conv1), ("bn1", resnet.bn1), ("relu", resnet.relu)] + ) + ) + self.maxpool = resnet.maxpool + + self.layer1 = resnet.layer1 + self.layer2 = resnet.layer2 + self.layer3 = resnet.layer3 + self.layer4 = resnet.layer4 + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] + + def forward(self, x): + input = x + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + + if layer == self.layer0: + # Fist maxpool operator is not a part of layer0 because we want that layer0 output to have stride of 2 + output = self.maxpool(output) + input = output + + # Return only features that were requested + return _take(output_features, self._layers) + + +class Resnet18Encoder(ResnetEncoder): + def __init__(self, pretrained=True, layers=None): + super().__init__( + resnet18(pretrained=pretrained), + [64, 64, 128, 256, 512], + [2, 4, 8, 16, 32], + layers, + ) + + +class Resnet34Encoder(ResnetEncoder): + def __init__(self, pretrained=True, layers=None): + super().__init__( + resnet34(pretrained=pretrained), + [64, 64, 128, 256, 512], + [2, 4, 8, 16, 32], + layers, + ) + + +class Resnet50Encoder(ResnetEncoder): + def __init__(self, pretrained=True, layers=None): + super().__init__( + resnet50(pretrained=pretrained), + [64, 256, 512, 1024, 2048], + [2, 4, 8, 16, 32], + layers, + ) + + +class Resnet101Encoder(ResnetEncoder): + def __init__(self, pretrained=True, layers=None): + super().__init__( + resnet101(pretrained=pretrained), + [64, 256, 512, 1024, 2048], + [2, 4, 8, 16, 32], + layers, + ) + + +class Resnet152Encoder(ResnetEncoder): + def __init__(self, pretrained=True, layers=None): + super().__init__( + resnet152(pretrained=pretrained), + [64, 256, 512, 1024, 2048], + [2, 4, 8, 16, 32], + layers, + ) diff --git a/pytorch_toolbelt/modules/encoders/seresnet.py b/pytorch_toolbelt/modules/encoders/seresnet.py new file mode 100644 index 000000000..b18dec903 --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/seresnet.py @@ -0,0 +1,113 @@ +"""Wrappers for different backbones for models that follows Encoder-Decoder architecture. + +Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. 
+""" + +from .common import EncoderModule, _take + +from ..backbone.senet import ( + SENet, + se_resnext50_32x4d, + se_resnext101_32x4d, + se_resnet50, + se_resnet101, + se_resnet152, + senet154, +) + +__all__ = [ + "SEResnetEncoder", + "SEResnet50Encoder", + "SEResnet101Encoder", + "SEResnet152Encoder", + "SEResNeXt50Encoder", + "SEResNeXt101Encoder", + "SENet154Encoder", +] + + +class SEResnetEncoder(EncoderModule): + """ + The only difference from vanilla ResNet is that it has 'layer0' module + """ + + def __init__(self, seresnet: SENet, channels, strides, layers=None): + if layers is None: + layers = [1, 2, 3, 4] + super().__init__(channels, strides, layers) + + self.maxpool = seresnet.layer0.pool + del seresnet.layer0.pool + + self.layer0 = seresnet.layer0 + self.layer1 = seresnet.layer1 + self.layer2 = seresnet.layer2 + self.layer3 = seresnet.layer3 + self.layer4 = seresnet.layer4 + + self._output_strides = _take(strides, layers) + self._output_filters = _take(channels, layers) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] + + @property + def output_strides(self): + return self._output_strides + + @property + def output_filters(self): + return self._output_filters + + def forward(self, x): + input = x + output_features = [] + for layer in self.encoder_layers: + output = layer(input) + output_features.append(output) + + if layer == self.layer0: + # Fist maxpool operator is not a part of layer0 + # because we want that layer0 output to have stride of 2 + output = self.maxpool(output) + input = output + + # Return only features that were requested + return _take(output_features, self._layers) + + +class SEResnet50Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = se_resnet50(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) + + +class SEResnet101Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = se_resnet101(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) + + +class SEResnet152Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = se_resnet152(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) + + +class SENet154Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = senet154(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) + + +class SEResNeXt50Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = se_resnext50_32x4d(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) + + +class SEResNeXt101Encoder(SEResnetEncoder): + def __init__(self, pretrained=True, layers=None): + encoder = se_resnext101_32x4d(pretrained="imagenet" if pretrained else None) + super().__init__(encoder, [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) diff --git a/pytorch_toolbelt/modules/encoders/squeezenet.py b/pytorch_toolbelt/modules/encoders/squeezenet.py new file mode 100644 index 000000000..9b66e5baf --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/squeezenet.py @@ -0,0 +1,57 @@ +from torch import nn +from torchvision.models import squeezenet1_1 + +from 
.common import EncoderModule + +__all__ = ["SqueezenetEncoder"] + + +class SqueezenetEncoder(EncoderModule): + def __init__(self, pretrained=True, layers=[1, 2, 3]): + super().__init__([64, 128, 256, 512], [4, 8, 16, 16], layers) + squeezenet = squeezenet1_1(pretrained=pretrained) + + # nn.Conv2d(3, 64, kernel_size=3, stride=2), + # nn.ReLU(inplace=True), + # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + self.layer0 = nn.Sequential( + squeezenet.features[0], + squeezenet.features[1], + # squeezenet.features[2], + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + # Fire(64, 16, 64, 64), + # Fire(128, 16, 64, 64), + # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + self.layer1 = nn.Sequential( + squeezenet.features[3], + squeezenet.features[4], + # squeezenet.features[5], + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + # Fire(128, 32, 128, 128), + # Fire(256, 32, 128, 128), + # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), + self.layer2 = nn.Sequential( + squeezenet.features[6], + squeezenet.features[7], + # squeezenet.features[8], + nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + ) + + # Fire(256, 48, 192, 192), + # Fire(384, 48, 192, 192), + # Fire(384, 64, 256, 256), + # Fire(512, 64, 256, 256), + self.layer3 = nn.Sequential( + squeezenet.features[9], + squeezenet.features[10], + squeezenet.features[11], + squeezenet.features[12], + ) + + @property + def encoder_layers(self): + return [self.layer0, self.layer1, self.layer2, self.layer3] diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py new file mode 100644 index 000000000..84579df7e --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -0,0 +1,179 @@ +from typing import List + +from pytorch_toolbelt.modules.abn import ABN +from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, \ + WiderResNetA2 + +from .common import EncoderModule, _take + +__all__ = [ + "WiderResnetEncoder", + "WiderResnet16A2Encoder", + "WiderResnet16Encoder", + "WiderResnet20Encoder", + "WiderResnet38A2Encoder", + "WiderResnet38Encoder", + "WiderResnetA2Encoder", + "WiderResnet20A2Encoder", +] + + +class WiderResnetEncoder(EncoderModule): + def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): + super().__init__( + [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers + ) + + encoder = WiderResNet(structure, classes=0, norm_act=norm_act) + self.layer0 = encoder.mod1 + self.layer1 = encoder.mod2 + self.layer2 = encoder.mod3 + self.layer3 = encoder.mod4 + self.layer4 = encoder.mod5 + self.layer5 = encoder.mod6 + self.layer6 = encoder.mod7 + + self.pool2 = encoder.pool2 + self.pool3 = encoder.pool3 + self.pool4 = encoder.pool4 + self.pool5 = encoder.pool5 + self.pool6 = encoder.pool6 + + @property + def encoder_layers(self): + return [ + self.layer0, + self.layer1, + self.layer2, + self.layer3, + self.layer4, + self.layer5, + self.layer6, + ] + + def forward(self, input): + output_features = [] + + x = self.layer0(input) + output_features.append(x) + + x = self.layer1(self.pool2(x)) + output_features.append(x) + + x = self.layer2(self.pool3(x)) + output_features.append(x) + + x = self.layer3(self.pool4(x)) + output_features.append(x) + + x = self.layer4(self.pool5(x)) + output_features.append(x) + + x = self.layer5(self.pool6(x)) + output_features.append(x) + + x = self.layer6(x) + output_features.append(x) + + # Return only features that were requested + return _take(output_features, 
self._layers) + + +class WiderResnet16Encoder(WiderResnetEncoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[1, 1, 1, 1, 1, 1], layers=layers) + + +class WiderResnet20Encoder(WiderResnetEncoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[1, 1, 1, 3, 1, 1], layers=layers) + + +class WiderResnet38Encoder(WiderResnetEncoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[3, 3, 6, 3, 1, 1], layers=layers) + + +class WiderResnetA2Encoder(EncoderModule): + def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): + super().__init__( + [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers + ) + + encoder = WiderResNetA2(structure=structure, classes=0, norm_act=norm_act) + self.layer0 = encoder.mod1 + self.layer1 = encoder.mod2 + self.layer2 = encoder.mod3 + self.layer3 = encoder.mod4 + self.layer4 = encoder.mod5 + self.layer5 = encoder.mod6 + self.layer6 = encoder.mod7 + + self.pool2 = encoder.pool2 + self.pool3 = encoder.pool3 + + @property + def encoder_layers(self): + return [ + self.layer0, + self.layer1, + self.layer2, + self.layer3, + self.layer4, + self.layer5, + self.layer6, + ] + + def forward(self, input): + output_features = [] + + out = self.layer0(input) + output_features.append(out) + + out = self.layer1(self.pool2(out)) + output_features.append(out) + + out = self.layer2(self.pool3(out)) + output_features.append(out) + + out = self.layer3(out) + output_features.append(out) + + out = self.layer4(out) + output_features.append(out) + + out = self.layer5(out) + output_features.append(out) + + out = self.layer6(out) + output_features.append(out) + + # Return only features that were requested + return _take(output_features, self._layers) + + +class WiderResnet16A2Encoder(WiderResnetA2Encoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[1, 1, 1, 1, 1, 1], layers=layers) + + +class WiderResnet20A2Encoder(WiderResnetA2Encoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[1, 1, 1, 3, 1, 1], layers=layers) + + +class WiderResnet38A2Encoder(WiderResnetA2Encoder): + def __init__(self, layers=None): + if layers is None: + layers = [2, 3, 4, 5, 6] + super().__init__(structure=[3, 3, 6, 3, 1, 1], layers=layers) diff --git a/pytorch_toolbelt/modules/fpn.py b/pytorch_toolbelt/modules/fpn.py index 11f214d0b..9ccec94a1 100644 --- a/pytorch_toolbelt/modules/fpn.py +++ b/pytorch_toolbelt/modules/fpn.py @@ -1,6 +1,6 @@ from __future__ import absolute_import -import torch +import torch from torch import nn from torch.nn import functional as F @@ -176,10 +176,15 @@ def forward(self, features): class HFF(nn.Module): """ - Hierarchical feature fusion - + Hierarchical feature fusion module. 
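Read as a loop, the fusion described in this docstring starts from the lowest-resolution map and repeatedly upsamples and adds towards the highest resolution; a loose sketch of that data flow only (the upsampling mode is a placeholder and the real module may interleave convolutions):

import torch.nn.functional as F

def hff_dataflow(feature_maps):
    # feature_maps[0] has the highest resolution, feature_maps[-1] the smallest
    fused = feature_maps[-1]
    for fm in reversed(feature_maps[:-1]):
        fused = fm + F.interpolate(fused, size=fm.shape[2:], mode="nearest")
    return fused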
https://arxiv.org/pdf/1811.11431.pdf https://arxiv.org/pdf/1803.06815.pdf + + What it does is easily explained in code: + feature_map_N - feature_map of the smallest resolution + feature_map_0 - feature_map of the highest resolution + + >>> feature_map = feature_map_0 + up(feature_map[1] + up(feature_map[2] + up(feature_map[3] + ...)))) """ def __init__( diff --git a/pytorch_toolbelt/modules/hypercolumn.py b/pytorch_toolbelt/modules/hypercolumn.py index 1efd74f3d..15e9a62f5 100644 --- a/pytorch_toolbelt/modules/hypercolumn.py +++ b/pytorch_toolbelt/modules/hypercolumn.py @@ -3,9 +3,6 @@ Original paper: https://arxiv.org/abs/1411.5752 """ -import torch -from torch import nn -from torch.nn import functional as F from .fpn import FPNFuse __all__ = ["HyperColumn"] diff --git a/pytorch_toolbelt/modules/pooling.py b/pytorch_toolbelt/modules/pooling.py index bcaf67be0..ec845e569 100644 --- a/pytorch_toolbelt/modules/pooling.py +++ b/pytorch_toolbelt/modules/pooling.py @@ -42,7 +42,8 @@ def forward(self, x): class GWAP(nn.Module): """ - Global Weighted Average Pooling from paper "Global Weighted Average Pooling Bridges Pixel-level Localization and Image-level Classification" + Global Weighted Average Pooling from paper "Global Weighted Average + Pooling Bridges Pixel-level Localization and Image-level Classification" """ def __init__(self, features): From 902b4821c7d5872c5e26437c4989be8ffcdc025a Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 26 Oct 2019 21:36:00 +0300 Subject: [PATCH 02/79] Refactor encoders & decoders --- pytorch_toolbelt/__init__.py | 2 +- pytorch_toolbelt/modules/abn.py | 3 +- .../modules/backbone/efficient_net.py | 24 ++- pytorch_toolbelt/modules/backbone/hrnet.py | 124 +++++------- .../modules/backbone/inceptionv4.py | 35 ++-- .../modules/backbone/mobilenetv3.py | 1 + pytorch_toolbelt/modules/coord_conv.py | 13 +- pytorch_toolbelt/modules/decoders/__init__.py | 12 +- pytorch_toolbelt/modules/decoders/common.py | 15 ++ pytorch_toolbelt/modules/decoders/deeplab.py | 35 ++-- .../modules/{decoders.py => decoders/fpn.py} | 61 +----- pytorch_toolbelt/modules/decoders/fpn_cat.py | 7 - pytorch_toolbelt/modules/decoders/fpn_sum.py | 8 +- pytorch_toolbelt/modules/decoders/hrnet.py | 16 +- .../modules/decoders/pyramid_pooling.py | 6 +- pytorch_toolbelt/modules/decoders/unet.py | 183 ++++++++++++++++++ .../decoders/{unet_decoder.py => unet_v2.py} | 30 +-- pytorch_toolbelt/modules/dropblock.py | 2 + pytorch_toolbelt/modules/encoders/__init__.py | 1 + pytorch_toolbelt/modules/encoders/resnet.py | 3 +- pytorch_toolbelt/modules/encoders/unet.py | 66 +++++++ .../modules/encoders/wide_resnet.py | 3 +- pytorch_toolbelt/modules/unet.py | 131 ------------- 23 files changed, 417 insertions(+), 364 deletions(-) create mode 100644 pytorch_toolbelt/modules/decoders/common.py rename pytorch_toolbelt/modules/{decoders.py => decoders/fpn.py} (58%) create mode 100644 pytorch_toolbelt/modules/decoders/unet.py rename pytorch_toolbelt/modules/decoders/{unet_decoder.py => unet_v2.py} (83%) create mode 100644 pytorch_toolbelt/modules/encoders/unet.py delete mode 100644 pytorch_toolbelt/modules/unet.py diff --git a/pytorch_toolbelt/__init__.py b/pytorch_toolbelt/__init__.py index 55e84cc4a..5f013188d 100644 --- a/pytorch_toolbelt/__init__.py +++ b/pytorch_toolbelt/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import -__version__ = "0.2.2" +__version__ = "0.2.2-alpha" diff --git a/pytorch_toolbelt/modules/abn.py b/pytorch_toolbelt/modules/abn.py index 110c0376f..1183460be 100644 --- 
a/pytorch_toolbelt/modules/abn.py +++ b/pytorch_toolbelt/modules/abn.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as functional -from pytorch_toolbelt.modules.activations import ( + +from .activations import ( ACT_LEAKY_RELU, ACT_NONE, ACT_HARD_SIGMOID, diff --git a/pytorch_toolbelt/modules/backbone/efficient_net.py b/pytorch_toolbelt/modules/backbone/efficient_net.py index 360d73517..c5f3768ae 100644 --- a/pytorch_toolbelt/modules/backbone/efficient_net.py +++ b/pytorch_toolbelt/modules/backbone/efficient_net.py @@ -163,17 +163,23 @@ def __init__(self, block_args: EfficientNetBlockArgs, abn_block: ABN, abn_params def reset_parameters(self): if hasattr(self, "expand_conv"): - kaiming_normal_(self.expand_conv.weight, - a=self.abn0.slope, - nonlinearity=self.abn0.activation) + kaiming_normal_( + self.expand_conv.weight, + a=self.abn0.slope, + nonlinearity=self.abn0.activation, + ) - kaiming_normal_(self.depthwise_conv.weight, - a=self.abn1.slope, - nonlinearity=self.abn1.activation) + kaiming_normal_( + self.depthwise_conv.weight, + a=self.abn1.slope, + nonlinearity=self.abn1.activation, + ) - kaiming_normal_(self.project_conv.weight, - a=self.abn1.slope, - nonlinearity=self.abn2.activation) + kaiming_normal_( + self.project_conv.weight, + a=self.abn1.slope, + nonlinearity=self.abn2.activation, + ) def forward(self, inputs, drop_connect_rate=None): """ diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py index b94d233b5..90a3bb536 100644 --- a/pytorch_toolbelt/modules/backbone/hrnet.py +++ b/pytorch_toolbelt/modules/backbone/hrnet.py @@ -5,15 +5,13 @@ import os import sys +from collections import OrderedDict from urllib.request import urlretrieve import torch import torch.nn as nn import torch.nn.functional as F -model_urls = { - "hrnetv2_48": "http://sceneparsing.csail.mit.edu/model/pretrained_resnet/hrnetv2_w48-imagenet.pth" -} HRNETV2_BN_MOMENTUM = 0.1 @@ -21,8 +19,7 @@ def hrnet_conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, - bias=False + in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False ) @@ -72,8 +69,7 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): self.conv3 = nn.Conv2d( planes, planes * self.expansion, kernel_size=1, bias=False ) - self.bn3 = nn.BatchNorm2d(planes * self.expansion, - momentum=HRNETV2_BN_MOMENTUM) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=HRNETV2_BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride @@ -103,14 +99,14 @@ def forward(self, x): class HighResolutionModule(nn.Module): def __init__( - self, - num_branches, - blocks, - num_blocks, - num_inchannels, - num_channels, - fuse_method, - multi_scale_output=True, + self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True, ): super(HighResolutionModule, self).__init__() self._check_branches( @@ -130,8 +126,7 @@ def __init__( self.relu = nn.ReLU(inplace=True) def _check_branches( - self, num_branches, blocks, num_blocks, num_inchannels, - num_channels + self, num_branches, blocks, num_blocks, num_inchannels, num_channels ): if num_branches != len(num_blocks): error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( @@ -151,13 +146,12 @@ def _check_branches( ) raise ValueError(error_msg) - def _make_one_branch(self, branch_index, block, num_blocks, 
num_channels, - stride=1): + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None if ( - stride != 1 - or self.num_inchannels[branch_index] - != num_channels[branch_index] * block.expansion + stride != 1 + or self.num_inchannels[branch_index] + != num_channels[branch_index] * block.expansion ): downsample = nn.Sequential( nn.Conv2d( @@ -182,12 +176,10 @@ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, downsample, ) ) - self.num_inchannels[branch_index] = num_channels[ - branch_index] * block.expansion + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): layers.append( - block(self.num_inchannels[branch_index], - num_channels[branch_index]) + block(self.num_inchannels[branch_index], num_channels[branch_index]) ) return nn.Sequential(*layers) @@ -196,8 +188,7 @@ def _make_branches(self, num_branches, block, num_blocks, num_channels): branches = [] for i in range(num_branches): - branches.append( - self._make_one_branch(i, block, num_blocks, num_channels)) + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) return nn.ModuleList(branches) @@ -340,13 +331,28 @@ def __init__(self, width=48, **kwargs): } # stem net - self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, - bias=False) - self.bn1 = nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM) - self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, - bias=False) - self.bn2 = nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM) - self.relu = nn.ReLU(inplace=True) + self.layer0 = nn.Sequential( + OrderedDict( + [ + ( + "conv1", + nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False + ), + ), + ("bn1", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), + ("relu", nn.ReLU(inplace=True)), + ( + "conv2", + nn.Conv2d( + 64, 64, kernel_size=3, stride=2, padding=1, bias=False + ), + ), + ("bn2", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), + ("relu2", nn.ReLU(inplace=True)), + ] + ) + ) self.layer1 = self._make_layer(HRNetBottleneck, 64, 64, 4) @@ -367,8 +373,7 @@ def __init__(self, width=48, **kwargs): num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels)) ] - self.transition2 = self._make_transition_layer(pre_stage_channels, - num_channels) + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage3, pre_stage_channels = self._make_stage( self.stage3_cfg, num_channels ) @@ -379,14 +384,12 @@ def __init__(self, width=48, **kwargs): num_channels = [ num_channels[i] * block.expansion for i in range(len(num_channels)) ] - self.transition3 = self._make_transition_layer(pre_stage_channels, - num_channels) + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) self.stage4, pre_stage_channels = self._make_stage( self.stage4_cfg, num_channels, multi_scale_output=True ) - def _make_transition_layer(self, num_channels_pre_layer, - num_channels_cur_layer): + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) num_branches_pre = len(num_channels_pre_layer) @@ -405,8 +408,7 @@ def _make_transition_layer(self, num_channels_pre_layer, bias=False, ), nn.BatchNorm2d( - num_channels_cur_layer[i], - momentum=HRNETV2_BN_MOMENTUM + num_channels_cur_layer[i], momentum=HRNETV2_BN_MOMENTUM ), nn.ReLU(inplace=True), ) @@ -424,10 +426,8 @@ def _make_transition_layer(self, num_channels_pre_layer, ) 
conv3x3s.append( nn.Sequential( - nn.Conv2d(inchannels, outchannels, 3, 2, 1, - bias=False), - nn.BatchNorm2d(outchannels, - momentum=HRNETV2_BN_MOMENTUM), + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), ) ) @@ -446,8 +446,7 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): stride=stride, bias=False, ), - nn.BatchNorm2d(planes * block.expansion, - momentum=HRNETV2_BN_MOMENTUM), + nn.BatchNorm2d(planes * block.expansion, momentum=HRNETV2_BN_MOMENTUM), ) layers = [] @@ -458,8 +457,7 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): return nn.Sequential(*layers) - def _make_stage(self, layer_config, num_inchannels, - multi_scale_output=True): + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} @@ -493,12 +491,7 @@ def _make_stage(self, layer_config, num_inchannels, return nn.Sequential(*modules), num_inchannels def forward(self, x, return_feature_maps=False): - x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.relu(x) + x = self.layer0(x) x = self.layer1(x) x_list = [] @@ -543,21 +536,6 @@ def forward(self, x, return_feature_maps=False): return [x] -def hrnetv2(pretrained=False, **kwargs): +def hrnetv2(**kwargs): model = HRNetV2(**kwargs) - if pretrained: - - def load_url(url, model_dir="./pretrained", map_location=None): - if not os.path.exists(model_dir): - os.makedirs(model_dir) - filename = url.split("/")[-1] - cached_file = os.path.join(model_dir, filename) - if not os.path.exists(cached_file): - sys.stderr.write( - 'Downloading: "{}" to {}\n'.format(url, cached_file)) - urlretrieve(url, cached_file) - return torch.load(cached_file, map_location=map_location) - - model.load_state_dict(load_url(model_urls["hrnetv2_48"]), strict=False) - return model diff --git a/pytorch_toolbelt/modules/backbone/inceptionv4.py b/pytorch_toolbelt/modules/backbone/inceptionv4.py index ebb6d3a7a..f8293a433 100644 --- a/pytorch_toolbelt/modules/backbone/inceptionv4.py +++ b/pytorch_toolbelt/modules/backbone/inceptionv4.py @@ -167,22 +167,16 @@ def __init__(self): self.branch1 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, - padding=(0, 3)), - BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, - padding=(3, 0)), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)), ) self.branch2 = nn.Sequential( BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, - padding=(3, 0)), - BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, - padding=(0, 3)), - BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, - padding=(3, 0)), - BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, - padding=(0, 3)), + BasicConv2d(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)), + BasicConv2d(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), ) self.branch3 = nn.Sequential( @@ -210,10 +204,8 @@ def __init__(self): self.branch1 = nn.Sequential( BasicConv2d(1024, 256, kernel_size=1, stride=1), - BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, - padding=(0, 3)), - BasicConv2d(256, 320, kernel_size=(7, 
1), stride=1, - padding=(3, 0)), + BasicConv2d(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)), + BasicConv2d(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)), BasicConv2d(320, 320, kernel_size=3, stride=2), ) @@ -293,7 +285,7 @@ def __init__(self, num_classes=1001): self.features = nn.Sequential( BasicConv2d(3, 32, kernel_size=3, stride=2), # 0, layer0 BasicConv2d(32, 32, kernel_size=3, stride=1), # 1, layer0 - BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), # 2 + BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), # 2 Mixed_3a(), # 3 Mixed_4a(), # 4 Mixed_5a(), # 5 @@ -333,10 +325,11 @@ def forward(self, input): def inceptionv4(num_classes=1000, pretrained="imagenet"): if pretrained: settings = pretrained_settings["inceptionv4"][pretrained] - assert (num_classes == settings["num_classes"]), \ - "num_classes should be {}, but is {}".format( - settings["num_classes"], num_classes - ) + assert ( + num_classes == settings["num_classes"] + ), "num_classes should be {}, but is {}".format( + settings["num_classes"], num_classes + ) # both 'imagenet'&'imagenet+background' are loaded from same parameters model = InceptionV4(num_classes=1001) diff --git a/pytorch_toolbelt/modules/backbone/mobilenetv3.py b/pytorch_toolbelt/modules/backbone/mobilenetv3.py index 45f8292a0..6de1037eb 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenetv3.py +++ b/pytorch_toolbelt/modules/backbone/mobilenetv3.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + # from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid from pytorch_toolbelt.modules.identity import Identity diff --git a/pytorch_toolbelt/modules/coord_conv.py b/pytorch_toolbelt/modules/coord_conv.py index ac70c41bf..4e6a0ed22 100644 --- a/pytorch_toolbelt/modules/coord_conv.py +++ b/pytorch_toolbelt/modules/coord_conv.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +__all__ = ["append_coords", "AddCoords", "CoordConv"] + def append_coords(input_tensor, with_r=False): batch_size, _, x_dim, y_dim = input_tensor.size() @@ -40,13 +42,12 @@ def append_coords(input_tensor, with_r=False): return ret -""" -An alternative implementation for PyTorch with auto-infering the x-y dimensions. -https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py -""" - - class AddCoords(nn.Module): + """ + An alternative implementation for PyTorch with auto-infering the x-y dimensions. 
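Per the CoordConv construction, AddCoords concatenates normalized x/y coordinate grids (plus a radius channel when with_r=True) to its input, so a convolution applied afterwards sees two or three extra channels; a quick sketch under that assumption, using the class from this file:

import torch
from pytorch_toolbelt.modules.coord_conv import AddCoords

x = torch.randn(2, 3, 32, 32)
y = AddCoords(with_r=False)(x)
assert y.shape == (2, 5, 32, 32)   # 3 input channels + x and y coordinate channels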
+ https://github.com/mkocabas/CoordConv-pytorch/blob/master/CoordConv.py + """ + def __init__(self, with_r=False): super().__init__() self.with_r = with_r diff --git a/pytorch_toolbelt/modules/decoders/__init__.py b/pytorch_toolbelt/modules/decoders/__init__.py index 9ed669c49..f50125464 100644 --- a/pytorch_toolbelt/modules/decoders/__init__.py +++ b/pytorch_toolbelt/modules/decoders/__init__.py @@ -1,8 +1,12 @@ from __future__ import absolute_import -from .fpn_sum import * -from .fpn_cat import * +from .common import * from .deeplab import * -from .upernet import * +from .fpn import * +from .fpn_cat import * +from .fpn_sum import * +from .hrnet import * from .pyramid_pooling import * -from .unet_decoder import * \ No newline at end of file +from .unet import * +from .unet_v2 import * +from .upernet import * diff --git a/pytorch_toolbelt/modules/decoders/common.py b/pytorch_toolbelt/modules/decoders/common.py new file mode 100644 index 000000000..62ecbc18a --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/common.py @@ -0,0 +1,15 @@ +__all__ = ["DecoderModule"] + +from torch import nn + + +class DecoderModule(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, features): + raise NotImplementedError + + def set_trainable(self, trainable): + for param in self.parameters(): + param.requires_grad = bool(trainable) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 490c3d8ee..caf5378ae 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -1,21 +1,20 @@ -import math +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F +from .common import DecoderModule __all__ = ["DeeplabV3Decoder"] -class DeeplabV3Decoder(nn.Module): - def __init__( - self, - high_level_features: int, - low_level_features: int, - num_classes: int, - dropout=0.5, - ): +class DeeplabV3Decoder(DecoderModule): + def __init__(self, feature_maps: List[int], num_classes: int, dropout=0.5): super(DeeplabV3Decoder, self).__init__() + low_level_features = feature_maps[0] + high_level_features = feature_maps[-1] + self.conv1 = nn.Conv2d(low_level_features, 48, 1, bias=False) self.bn1 = nn.BatchNorm2d(48) self.relu = nn.ReLU(inplace=True) @@ -40,18 +39,24 @@ def __init__( ) self.reset_parameters() - def forward(self, x, low_level_feat): + def forward(self, feature_maps): + high_level_features = feature_maps[-1] + low_level_feat = feature_maps[0] + low_level_feat = self.conv1(low_level_feat) low_level_feat = self.bn1(low_level_feat) low_level_feat = self.relu(low_level_feat) - x = F.interpolate( - x, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True + high_level_features = F.interpolate( + high_level_features, + size=low_level_feat.size()[2:], + mode="bilinear", + align_corners=True, ) - x = torch.cat((x, low_level_feat), dim=1) - x = self.last_conv(x) + high_level_features = torch.cat((high_level_features, low_level_feat), dim=1) + high_level_features = self.last_conv(high_level_features) - return x + return high_level_features def reset_parameters(self): for m in self.modules(): diff --git a/pytorch_toolbelt/modules/decoders.py b/pytorch_toolbelt/modules/decoders/fpn.py similarity index 58% rename from pytorch_toolbelt/modules/decoders.py rename to pytorch_toolbelt/modules/decoders/fpn.py index c6fab2fd7..924af153d 100644 --- a/pytorch_toolbelt/modules/decoders.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -1,63 +1,6 @@ from 
torch import nn - -from .fpn import FPNBottleneckBlock, FPNPredictionBlock, UpsampleAddConv -from .unet import UnetCentralBlock, UnetDecoderBlock - - -class DecoderModule(nn.Module): - def __init__(self): - super().__init__() - - def forward(self, features): - raise NotImplementedError - - def set_trainable(self, trainable): - for param in self.parameters(): - param.requires_grad = bool(trainable) - - -class UNetDecoder(DecoderModule): - def __init__( - self, features, start_features: int, dilation_factors=[1, 1, 1, 1], **kwargs - ): - super().__init__() - decoder_features = start_features - reversed_features = list(reversed(features)) - - output_filters = [decoder_features] - self.center = UnetCentralBlock(reversed_features[0], decoder_features) - - if dilation_factors is None: - dilation_factors = [1] * len(reversed_features) - - blocks = [] - for block_index, encoder_features in enumerate(reversed_features): - blocks.append( - UnetDecoderBlock( - output_filters[-1], - encoder_features, - decoder_features, - dilation=dilation_factors[block_index], - ) - ) - output_filters.append(decoder_features) - # print(block_index, decoder_features, encoder_features, decoder_features) - decoder_features = decoder_features // 2 - - self.blocks = nn.ModuleList(blocks) - self.output_filters = output_filters - - def forward(self, features): - reversed_features = list(reversed(features)) - decoder_outputs = [self.center(reversed_features[0])] - - for block_index, decoder_block, encoder_output in zip( - range(len(self.blocks)), self.blocks, reversed_features - ): - # print(block_index, decoder_outputs[-1].size(), encoder_output.size()) - decoder_outputs.append(decoder_block(decoder_outputs[-1], encoder_output)) - - return decoder_outputs +from .common import DecoderModule +from ..fpn import FPNBottleneckBlock, UpsampleAddConv, FPNPredictionBlock class FPNDecoder(DecoderModule): diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 77e9e06aa..936e7d580 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -7,17 +7,10 @@ from torch import nn, Tensor from torch.nn import functional as F -from pytorch_toolbelt.modules.decoders import ( - FPNDecoder, - FPNBottleneckBlock, - FPNPredictionBlock, -) from pytorch_toolbelt.modules.fpn import FPNFuse, UpsampleAdd __all__ = ["FPNCatDecoder"] -from ..modules import DoubleConvBNRelu - class FPNCatDecoder(DecoderModule): """ diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index f08817c33..9c68944c3 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -3,7 +3,7 @@ import torch from pytorch_toolbelt.modules import Identity, ABN -from pytorch_toolbelt.modules.decoders import DecoderModule +from .modules.decoders import DecoderModule from pytorch_toolbelt.utils.torch_utils import count_parameters @@ -67,9 +67,9 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: x = torch.cat( [ x, - F.interpolate(p2, size=x_size, mode="bilinear", align_corners=True), - F.interpolate(p4, size=x_size, mode="bilinear", align_corners=True), - F.interpolate(p8, size=x_size, mode="bilinear", align_corners=True), + F.interpolate(p2, size=x_size, mode="bilinear", align_corners=False), + F.interpolate(p4, size=x_size, mode="bilinear", align_corners=False), + F.interpolate(p8, size=x_size, mode="bilinear", align_corners=False), ], dim=1, ) diff --git 
a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index a582274f3..04b10e728 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -1,6 +1,13 @@ +from torch import nn + +from .common import DecoderModule +from ..backbone.hrnet import HRNETV2_BN_MOMENTUM + +__all__ = ["HRNetDecoder"] + class HRNetDecoder(DecoderModule): - def __init__(self, features: int, num_classes: int, dropout=0.): + def __init__(self, features: int, num_classes: int, dropout=0.0): super().__init__() self.last_layer = nn.Sequential( @@ -9,7 +16,8 @@ def __init__(self, features: int, num_classes: int, dropout=0.): out_channels=features, kernel_size=1, stride=1, - padding=0), + padding=0, + ), nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Dropout(dropout), @@ -18,9 +26,9 @@ def __init__(self, features: int, num_classes: int, dropout=0.): out_channels=num_classes, kernel_size=3, stride=1, - padding=1) + padding=1, + ), ) def forward(self, features): return self.last_layer(features[-1]) - diff --git a/pytorch_toolbelt/modules/decoders/pyramid_pooling.py b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py index c15d421cf..023858fcd 100644 --- a/pytorch_toolbelt/modules/decoders/pyramid_pooling.py +++ b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py @@ -5,9 +5,13 @@ import torch.nn.functional as F from torch import nn +from .common import DecoderModule -class PPMDecoder(nn.Module): + +class PPMDecoder(DecoderModule): """ + Pyramid pooling decoder module + https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/42b7567a43b1dab568e2bbfcbc8872778fbda92a/models/models.py """ diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py new file mode 100644 index 000000000..d3e14599a --- /dev/null +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -0,0 +1,183 @@ +from typing import List + +import torch +import torch.nn.functional as F +from torch import nn + +from ..abn import ABN +from .common import DecoderModule + +__all__ = ["UnetCentralBlock", "UnetDecoderBlock", "UNetDecoder"] + + +class UnetCentralBlock(nn.Module): + def __init__(self, in_dec_filters, out_filters, abn_block=ABN, **kwargs): + super().__init__() + self.conv1 = nn.Conv2d( + in_dec_filters, + out_filters, + kernel_size=3, + padding=1, + stride=2, + bias=False, + **kwargs + ) + self.bn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d( + out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs + ) + self.bn2 = abn_block(out_filters) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + return x + + +class UnetDecoderBlock(nn.Module): + """ + """ + + def __init__( + self, + in_dec_filters, + in_enc_filters, + out_filters, + abn_block=ABN, + pre_dropout_rate=0.0, + post_dropout_rate=0.0, + **kwargs + ): + super(UnetDecoderBlock, self).__init__() + + self.pre_drop = nn.Dropout(pre_dropout_rate, inplace=True) + + self.conv1 = nn.Conv2d( + in_dec_filters + in_enc_filters, + out_filters, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **kwargs + ) + self.bn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d( + out_filters, + out_filters, + kernel_size=3, + stride=1, + padding=1, + bias=False, + **kwargs + ) + self.bn2 = abn_block(out_filters) + + self.post_drop = nn.Dropout(post_dropout_rate, inplace=True) + + def forward(self, x, enc): + lat_size = enc.size()[2:] + x = F.interpolate(x, size=lat_size, 
mode="bilinear", align_corners=False) + + x = torch.cat([x, enc], 1) + + x = self.pre_drop(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.post_drop(x) + return x + + +# class UNetDecoder(DecoderModule): +# def __init__( +# self, features, start_features: int, dilation_factors=[1, 1, 1, 1], **kwargs +# ): +# super().__init__() +# decoder_features = start_features +# reversed_features = list(reversed(features)) +# +# output_filters = [decoder_features] +# self.center = UnetCentralBlock(reversed_features[0], decoder_features) +# +# if dilation_factors is None: +# dilation_factors = [1] * len(reversed_features) +# +# blocks = [] +# for block_index, encoder_features in enumerate(reversed_features): +# blocks.append( +# UnetDecoderBlock( +# output_filters[-1], +# encoder_features, +# decoder_features, +# dilation=dilation_factors[block_index], +# ) +# ) +# output_filters.append(decoder_features) +# # print(block_index, decoder_features, encoder_features, decoder_features) +# decoder_features = decoder_features // 2 +# +# self.blocks = nn.ModuleList(blocks) +# self.output_filters = output_filters +# +# def forward(self, features): +# reversed_features = list(reversed(features)) +# decoder_outputs = [self.center(reversed_features[0])] +# +# for block_index, decoder_block, encoder_output in zip( +# range(len(self.blocks)), self.blocks, reversed_features +# ): +# # print(block_index, decoder_outputs[-1].size(), encoder_output.size()) +# decoder_outputs.append(decoder_block(decoder_outputs[-1], encoder_output)) +# +# return decoder_outputs + + +class UNetDecoder(DecoderModule): + def __init__( + self, feature_maps: List[int], decoder_features: int, mask_channels: int + ): + super().__init__() + + if not isinstance(decoder_features, list): + decoder_features = [ + decoder_features * (2 ** i) for i in range(len(feature_maps)) + ] + + blocks = [] + for block_index, in_enc_features in enumerate(feature_maps[:-1]): + blocks.append( + UnetDecoderBlock( + decoder_features[block_index + 1], + in_enc_features, + decoder_features[block_index], + mask_channels, + ) + ) + + self.center = UnetCentralBlock( + feature_maps[-1], decoder_features[-1], mask_channels + ) + self.blocks = nn.ModuleList(blocks) + self.output_filters = decoder_features + + def forward(self, feature_maps): + + output, dsv = self.center(feature_maps[-1]) + decoder_outputs = [output] + dsv_list = [dsv] + + for decoder_block, encoder_output in zip( + reversed(self.blocks), reversed(feature_maps[:-1]) + ): + output, dsv = decoder_block(output, encoder_output) + decoder_outputs.append(output) + dsv_list.append(dsv) + + dsv_list = list(reversed(dsv_list)) + decoder_outputs = list(reversed(decoder_outputs)) + + return decoder_outputs, dsv_list diff --git a/pytorch_toolbelt/modules/decoders/unet_decoder.py b/pytorch_toolbelt/modules/decoders/unet_v2.py similarity index 83% rename from pytorch_toolbelt/modules/decoders/unet_decoder.py rename to pytorch_toolbelt/modules/decoders/unet_v2.py index 5eb611b4f..b0f1053a2 100644 --- a/pytorch_toolbelt/modules/decoders/unet_decoder.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -2,12 +2,11 @@ import torch import torch.nn.functional as F -from pytorch_toolbelt.modules import ABN -from pytorch_toolbelt.modules.decoders import DecoderModule -from pytorch_toolbelt.modules.encoders import SEResnet101Encoder -from pytorch_toolbelt.utils.torch_utils import count_parameters from torch import nn +from ..abn import ABN +from .common import DecoderModule + __all__ = 
["UNetDecoderV2", "UnetCentralBlockV2", "UnetDecoderBlockV2"] @@ -114,7 +113,9 @@ def __init__(self, features: List[int], decoder_features: int, mask_channels: in ) ) - self.center = UnetCentralBlockV2(features[-1], decoder_features[-1], mask_channels) + self.center = UnetCentralBlockV2( + features[-1], decoder_features[-1], mask_channels + ) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features @@ -135,22 +136,3 @@ def forward(self, feature_maps): decoder_outputs = list(reversed(decoder_outputs)) return decoder_outputs, dsv_list - - -@torch.no_grad() -def test_unetv2(): - encoder = SEResnet101Encoder().cuda().eval() - decoder = ( - UNetDecoderV2(encoder.output_filters, [128, 192, 256, 512], 5).cuda().eval() - ) - - print(count_parameters(encoder)) - print(count_parameters(decoder)) - print(decoder) - - x = torch.rand((1, 3, 256, 512)).cuda() - fm = encoder(x) - fm2 = decoder(fm) - - for fm, dsv in fm2: - print(fm.size(), dsv.size()) diff --git a/pytorch_toolbelt/modules/dropblock.py b/pytorch_toolbelt/modules/dropblock.py index 8e809fb4f..5a196743d 100644 --- a/pytorch_toolbelt/modules/dropblock.py +++ b/pytorch_toolbelt/modules/dropblock.py @@ -2,6 +2,8 @@ import torch.functional as F from torch import nn +__all__ = ["DropBlock2D", "DropBlock3D", "DropBlockScheduled"] + class DropBlock2D(nn.Module): r"""Randomly zeroes 2D spatial blocks of the input tensor. diff --git a/pytorch_toolbelt/modules/encoders/__init__.py b/pytorch_toolbelt/modules/encoders/__init__.py index 9b4f975e3..25ec0c0c1 100644 --- a/pytorch_toolbelt/modules/encoders/__init__.py +++ b/pytorch_toolbelt/modules/encoders/__init__.py @@ -11,4 +11,5 @@ from .resnet import * from .seresnet import * from .squeezenet import * +from .unet import * from .wide_resnet import * diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py index cacbf9d14..e32af6efd 100644 --- a/pytorch_toolbelt/modules/encoders/resnet.py +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -6,8 +6,7 @@ from collections import OrderedDict from torch import nn -from torchvision.models import resnet50, resnet34, resnet18, resnet101, \ - resnet152 +from torchvision.models import resnet50, resnet34, resnet18, resnet101, resnet152 from .common import EncoderModule, _take diff --git a/pytorch_toolbelt/modules/encoders/unet.py b/pytorch_toolbelt/modules/encoders/unet.py new file mode 100644 index 000000000..28703834b --- /dev/null +++ b/pytorch_toolbelt/modules/encoders/unet.py @@ -0,0 +1,66 @@ +from torch import nn + + +from ..abn import ABN + +from .common import EncoderModule, _take + +__all__ = ["UnetEncoderBlock", "UnetEncoder"] + + +class UnetEncoderBlock(nn.Module): + def __init__(self, in_dec_filters, out_filters, abn_block=ABN, stride=1, **kwargs): + super().__init__() + self.conv1 = nn.Conv2d( + in_dec_filters, + out_filters, + kernel_size=3, + padding=1, + stride=1, + bias=False, + **kwargs, + ) + self.bn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d( + out_filters, + out_filters, + kernel_size=3, + padding=1, + stride=stride, + bias=False, + **kwargs, + ) + self.bn2 = abn_block(out_filters) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.conv2(x) + x = self.bn2(x) + return x + + +class UnetEncoder(EncoderModule): + def __init__( + self, + input_channels=3, + features=32, + num_layers=4, + growth_factor=2, + abn_block=ABN, + ): + feature_maps = [features * growth_factor * (i + 1) for i in range(num_layers)] + strides = [2 * (i + 1) for i in 
range(num_layers)] + super().__init__(feature_maps, strides, layers=list(range(num_layers))) + + input_filters = input_channels + output_filters = feature_maps[0] + self.num_layers = num_layers + for layer in range(num_layers): + block = UnetEncoderBlock(input_filters, output_filters, abn_block=abn_block) + + self.add_module(f"layer{layer}", block) + + @property + def encoder_layers(self): + return [self[f"layer{layer}"] for layer in range(self.num_layers)] diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py index 84579df7e..3b7e35a25 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -1,8 +1,7 @@ from typing import List from pytorch_toolbelt.modules.abn import ABN -from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, \ - WiderResNetA2 +from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, WiderResNetA2 from .common import EncoderModule, _take diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py deleted file mode 100644 index b0e8b65d7..000000000 --- a/pytorch_toolbelt/modules/unet.py +++ /dev/null @@ -1,131 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .abn import ABN, ACT_RELU - -__all__ = ["UnetEncoderBlock", "UnetDecoderBlock", "UnetCentralBlock"] - - -class UnetEncoderBlock(nn.Module): - def __init__( - self, - in_dec_filters, - out_filters, - abn_block=ABN, - activation=ACT_RELU, - stride=1, - **kwargs - ): - super().__init__() - self.conv1 = nn.Conv2d( - in_dec_filters, - out_filters, - kernel_size=3, - padding=1, - stride=1, - bias=False, - **kwargs - ) - self.bn1 = abn_block(out_filters, activation=activation) - self.conv2 = nn.Conv2d( - out_filters, - out_filters, - kernel_size=3, - padding=1, - stride=stride, - bias=False, - **kwargs - ) - self.bn2 = abn_block(out_filters, activation=activation) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - return x - - -class UnetCentralBlock(nn.Module): - def __init__( - self, in_dec_filters, out_filters, abn_block=ABN, activation=ACT_RELU, **kwargs - ): - super().__init__() - self.conv1 = nn.Conv2d( - in_dec_filters, - out_filters, - kernel_size=3, - padding=1, - stride=2, - bias=False, - **kwargs - ) - self.bn1 = abn_block(out_filters, activation=activation) - self.conv2 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs - ) - self.bn2 = abn_block(out_filters, activation=activation) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - return x - - -class UnetDecoderBlock(nn.Module): - """ - """ - - def __init__( - self, - in_dec_filters, - in_enc_filters, - out_filters, - abn_block=ABN, - activation=ACT_RELU, - pre_dropout_rate=0.0, - post_dropout_rate=0.0, - **kwargs - ): - super(UnetDecoderBlock, self).__init__() - - self.conv1 = nn.Conv2d( - in_dec_filters + in_enc_filters, - out_filters, - kernel_size=3, - stride=1, - padding=1, - bias=False, - **kwargs - ) - self.bn1 = abn_block(out_filters, activation=activation) - self.conv2 = nn.Conv2d( - out_filters, - out_filters, - kernel_size=3, - stride=1, - padding=1, - bias=False, - **kwargs - ) - self.bn2 = abn_block(out_filters, activation=activation) - - self.pre_drop = nn.Dropout(pre_dropout_rate, inplace=True) - self.post_drop = nn.Dropout(post_dropout_rate, inplace=True) - - def forward(self, x, enc): - lat_size = 
enc.size()[2:] - x = F.interpolate(x, size=lat_size, mode="bilinear", align_corners=False) - - x = torch.cat([x, enc], 1) - - x = self.pre_drop(x) - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.post_drop(x) - return x From b6bc4c9d55c6fc436237c9c1cde50687d58940bf Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 26 Oct 2019 21:53:24 +0300 Subject: [PATCH 03/79] Refactor encoders & decoders --- .../modules/backbone/efficient_net.py | 7 +- .../modules/backbone/mobilenetv3.py | 25 ----- pytorch_toolbelt/modules/decoders/fpn.py | 4 +- pytorch_toolbelt/modules/decoders/fpn_cat.py | 15 --- pytorch_toolbelt/modules/decoders/fpn_sum.py | 15 --- .../modules/encoders/efficientnet.py | 2 +- .../modules/encoders/wide_resnet.py | 2 +- pytorch_toolbelt/modules/fpn.py | 13 +-- pytorch_toolbelt/modules/scse.py | 9 -- pytorch_toolbelt/utils/catalyst/metrics.py | 4 +- tests/test_decoders.py | 42 +++++++++ tests/test_encoders.py | 92 +++++++++++++++++++ tests/test_modules.py | 81 ---------------- 13 files changed, 147 insertions(+), 164 deletions(-) create mode 100644 tests/test_decoders.py create mode 100644 tests/test_encoders.py diff --git a/pytorch_toolbelt/modules/backbone/efficient_net.py b/pytorch_toolbelt/modules/backbone/efficient_net.py index c5f3768ae..fc7ee6669 100644 --- a/pytorch_toolbelt/modules/backbone/efficient_net.py +++ b/pytorch_toolbelt/modules/backbone/efficient_net.py @@ -4,13 +4,14 @@ from typing import List import torch -from pytorch_toolbelt.modules import ABN, SpatialGate2d -from pytorch_toolbelt.modules.activations import ACT_HARD_SWISH -from pytorch_toolbelt.modules.agn import AGN from torch import nn from torch.nn import functional as F from torch.nn.init import kaiming_normal_ +from pytorch_toolbelt.modules import ABN, SpatialGate2d +from pytorch_toolbelt.modules.activations import ACT_HARD_SWISH +from pytorch_toolbelt.modules.agn import AGN + def round_filters(filters, width_coefficient, depth_divisor, min_depth): """ diff --git a/pytorch_toolbelt/modules/backbone/mobilenetv3.py b/pytorch_toolbelt/modules/backbone/mobilenetv3.py index 6de1037eb..320d70991 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenetv3.py +++ b/pytorch_toolbelt/modules/backbone/mobilenetv3.py @@ -364,28 +364,3 @@ def forward(self, x): return x -if __name__ == "__main__": - """Testing - """ - from pytorch_toolbelt.utils.torch_utils import count_parameters - - model1 = MobileNetV3() - print(model1, count_parameters(model1)) - - model2 = MobileNetV3(scale=0.35) - print(model2, count_parameters(model2)) - - model3 = MobileNetV3(in_channels=2, num_classes=10) - print(model3, count_parameters(model3)) - - x = torch.randn(1, 2, 224, 224) - print(model3(x)) - - model4_size = 32 * 10 - model4 = MobileNetV3(num_classes=10) - print(model4, count_parameters(model4)) - x2 = torch.randn(1, 3, model4_size, model4_size) - print(model4(x2)) - - model5 = MobileNetV3(scale=0.35, small=True) - print(model5, count_parameters(model5)) diff --git a/pytorch_toolbelt/modules/decoders/fpn.py b/pytorch_toolbelt/modules/decoders/fpn.py index 924af153d..732a26bea 100644 --- a/pytorch_toolbelt/modules/decoders/fpn.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -1,6 +1,6 @@ from torch import nn from .common import DecoderModule -from ..fpn import FPNBottleneckBlock, UpsampleAddConv, FPNPredictionBlock +from ..fpn import FPNBottleneckBlock, UpsampleAdd, FPNPredictionBlock class FPNDecoder(DecoderModule): @@ -8,7 +8,7 @@ def __init__( self, features, 
bottleneck=FPNBottleneckBlock, - upsample_add_block=UpsampleAddConv, + upsample_add_block=UpsampleAdd, prediction_block=FPNPredictionBlock, fpn_features=128, prediction_features=128, diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 936e7d580..7ced2d88c 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -76,18 +76,3 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: x = self.final_block(fused) return x, dsv_masks - -@torch.no_grad() -def test_fpn_cat(): - channels = [256, 512, 1024, 2048] - sizes = [64, 32, 16, 8] - - net = FPNCatDecoder(channels, 5).eval() - - input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] - output, dsv_masks = net(input) - - print(output.size(), output.mean(), output.std()) - for dsv in dsv_masks: - print(dsv.size(), dsv.mean(), dsv.std()) - print(count_parameters(net)) diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index 9c68944c3..d706ae2f9 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -194,18 +194,3 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: x = self.final_block(x) return x, dsv_masks - -@torch.no_grad() -def test_fpn_sum(): - channels = [256, 512, 1024, 2048] - sizes = [64, 32, 16, 8] - - net = FPNSumDecoder(channels, 5).eval() - - input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] - output, dsv_masks = net(input) - - print(output.size(), output.mean(), output.std()) - for dsv in dsv_masks: - print(dsv.size(), dsv.mean(), dsv.std()) - print(count_parameters(net)) diff --git a/pytorch_toolbelt/modules/encoders/efficientnet.py b/pytorch_toolbelt/modules/encoders/efficientnet.py index 42a886c8f..a8b02dc7d 100644 --- a/pytorch_toolbelt/modules/encoders/efficientnet.py +++ b/pytorch_toolbelt/modules/encoders/efficientnet.py @@ -1,4 +1,4 @@ -from pytorch_toolbelt.modules.backbone.efficient_net import ( +from ..backbone.efficient_net import ( efficient_net_b0, efficient_net_b6, efficient_net_b1, diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py index 3b7e35a25..8ccdb1ad0 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -1,6 +1,6 @@ from typing import List -from pytorch_toolbelt.modules.abn import ABN +from ..modules.abn import ABN from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, WiderResNetA2 from .common import EncoderModule, _take diff --git a/pytorch_toolbelt/modules/fpn.py b/pytorch_toolbelt/modules/fpn.py index 9ccec94a1..91b722fbe 100644 --- a/pytorch_toolbelt/modules/fpn.py +++ b/pytorch_toolbelt/modules/fpn.py @@ -39,7 +39,7 @@ def forward(self, x): class FPNPredictionBlock(nn.Module): - def __init__(self, input_channels, output_channels, mode="nearest"): + def __init__(self, input_channels, output_channels, mode="nearest", align_corners=None): super().__init__() self.input_channels = input_channels self.output_channels = output_channels @@ -47,16 +47,9 @@ def __init__(self, input_channels, output_channels, mode="nearest"): self.input_channels, self.output_channels, kernel_size=3, padding=1 ) self.mode = mode + self.align_corners = align_corners - def forward(self, x, y=None): - if y is not None: - x = x + F.interpolate( - y, - size=x.size()[2:], - mode=self.mode, 
- align_corners=False if self.mode == "bilinear" else None, - ) - + def forward(self, x): x = self.conv(x) return x diff --git a/pytorch_toolbelt/modules/scse.py b/pytorch_toolbelt/modules/scse.py index 3ce563cd1..d81bbdc16 100644 --- a/pytorch_toolbelt/modules/scse.py +++ b/pytorch_toolbelt/modules/scse.py @@ -5,8 +5,6 @@ from torch import nn, Tensor from torch.nn import functional as F -from torch.nn.init import kaiming_normal_ - __all__ = [ "ChannelGate2d", @@ -60,12 +58,6 @@ def __init__(self, channels, reduction=None, squeeze_channels=None): self.squeeze = nn.Conv2d(channels, squeeze_channels, kernel_size=1) self.expand = nn.Conv2d(squeeze_channels, channels, kernel_size=1) - self.reset_parameters() - - def reset_parameters(self): - kaiming_normal_(self.squeeze.weight, nonlinearity="relu") - kaiming_normal_(self.expand.weight, nonlinearity="sigmoid") - def forward(self, x: Tensor): module_input = x x = self.avg_pool(x) @@ -73,7 +65,6 @@ def forward(self, x: Tensor): x = F.relu(x, inplace=True) x = self.expand(x) x = x.sigmoid() - # print(module_input.mean().item(), module_input.std().item(), x.mean().item(), x.std().item()) return module_input * x diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 36beec60d..3a0eacb04 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -3,8 +3,8 @@ import numpy as np import torch from catalyst.dl import Callback, RunnerState, MetricCallback, CallbackOrder -from pytorch_toolbelt.utils.catalyst.visualization import get_tensorboard_logger -from pytorch_toolbelt.utils.torch_utils import to_numpy +from .visualization import get_tensorboard_logger +from ..torch_utils import to_numpy from pytorch_toolbelt.utils.visualization import ( render_figure_to_tensor, plot_confusion_matrix, diff --git a/tests/test_decoders.py b/tests/test_decoders.py new file mode 100644 index 000000000..b850ecb7f --- /dev/null +++ b/tests/test_decoders.py @@ -0,0 +1,42 @@ +import pytest +import torch + +import pytorch_toolbelt.modules.encoders as E +from pytorch_toolbelt.modules.backbone.inceptionv4 import inceptionv4 +from pytorch_toolbelt.modules.decoders import FPNSumDecoder, FPNCatDecoder +from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters + +skip_if_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="Cuda is not available" +) + +@torch.no_grad() +def test_fpn_sum(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNSumDecoder(channels, 5).eval() + + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output, dsv_masks = net(input) + + print(output.size(), output.mean(), output.std()) + for dsv in dsv_masks: + print(dsv.size(), dsv.mean(), dsv.std()) + print(count_parameters(net)) + + +@torch.no_grad() +def test_fpn_cat(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNCatDecoder(channels, 5).eval() + + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output, dsv_masks = net(input) + + print(output.size(), output.mean(), output.std()) + for dsv in dsv_masks: + print(dsv.size(), dsv.mean(), dsv.std()) + print(count_parameters(net)) diff --git a/tests/test_encoders.py b/tests/test_encoders.py new file mode 100644 index 000000000..e45e402db --- /dev/null +++ b/tests/test_encoders.py @@ -0,0 +1,92 @@ +import pytest +import torch + +import pytorch_toolbelt.modules.encoders as E +from pytorch_toolbelt.modules.backbone.inceptionv4 
import inceptionv4 +from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters + +skip_if_no_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), reason="Cuda is not available" +) + +@pytest.mark.parametrize( + ["encoder", "encoder_params"], + [ + [E.Resnet34Encoder, {"pretrained": False}], + [E.Resnet50Encoder, {"pretrained": False}], + [E.SEResNeXt50Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], + [E.SEResnet50Encoder, {"pretrained": False}], + [E.Resnet152Encoder, {"pretrained": False}], + [E.Resnet101Encoder, {"pretrained": False}], + [E.SEResnet152Encoder, {"pretrained": False}], + [E.SEResNeXt101Encoder, {"pretrained": False}], + [E.SEResnet101Encoder, {"pretrained": False}], + [E.SENet154Encoder, {"pretrained": False}], + [E.WiderResnet16Encoder, {}], + [E.WiderResnet20Encoder, {}], + [E.WiderResnet38Encoder, {}], + [E.WiderResnet16A2Encoder, {}], + [E.WiderResnet20A2Encoder, {}], + [E.WiderResnet38A2Encoder, {}], + [E.EfficientNetB0Encoder, {}], + [E.EfficientNetB1Encoder, {}], + [E.EfficientNetB2Encoder, {}], + [E.EfficientNetB3Encoder, {}], + [E.EfficientNetB4Encoder, {}], + [E.EfficientNetB5Encoder, {}], + [E.EfficientNetB6Encoder, {}], + [E.EfficientNetB7Encoder, {}], + [E.DenseNet121Encoder, {}], + [E.DenseNet161Encoder, {}], + [E.DenseNet169Encoder, {}], + [E.DenseNet201Encoder, {}], + ], +) +@torch.no_grad() +@skip_if_no_cuda +def test_encoders(encoder: E.EncoderModule, encoder_params): + net = encoder(**encoder_params).eval() + print(net.__class__.__name__, count_parameters(net)) + print(net.output_strides) + print(net.output_filters) + input = torch.rand((4, 3, 256, 256)) + input = maybe_cuda(input) + net = maybe_cuda(net) + output = net(input) + assert len(output) == len(net.output_filters) + for feature_map, expected_stride, expected_channels in zip( + output, net.output_strides, net.output_filters + ): + assert feature_map.size(1) == expected_channels + assert feature_map.size(2) * expected_stride == 256 + assert feature_map.size(3) * expected_stride == 256 + + +@torch.no_grad() +@skip_if_no_cuda +def test_inceptionv4_encoder(): + backbone = inceptionv4(pretrained=False) + backbone.last_linear = None + + net = E.InceptionV4Encoder(backbone, layers=[0, 1, 2, 3, 4]).cuda() + + print(count_parameters(backbone)) + print(count_parameters(net)) + + x = torch.randn((4, 3, 512, 512)).cuda() + + out = net(x) + for fm in out: + print(fm.size()) + + +@torch.no_grad() +@skip_if_no_cuda +def test_densenet(): + from torchvision.models import densenet121 + + net1 = E.DenseNet121Encoder() + net2 = densenet121(pretrained=False) + net2.classifier = None + + print(count_parameters(net1), count_parameters(net2)) diff --git a/tests/test_modules.py b/tests/test_modules.py index 8c7104680..5b3321e72 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -11,59 +11,6 @@ ) -@pytest.mark.parametrize( - ["encoder", "encoder_params"], - [ - [E.Resnet34Encoder, {"pretrained": False}], - [E.Resnet50Encoder, {"pretrained": False}], - [E.SEResNeXt50Encoder, - {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], - [E.SEResnet50Encoder, {"pretrained": False}], - [E.Resnet152Encoder, {"pretrained": False}], - [E.Resnet101Encoder, {"pretrained": False}], - [E.SEResnet152Encoder, {"pretrained": False}], - [E.SEResNeXt101Encoder, {"pretrained": False}], - [E.SEResnet101Encoder, {"pretrained": False}], - [E.SENet154Encoder, {"pretrained": False}], - [E.WiderResnet16Encoder, {}], - [E.WiderResnet20Encoder, {}], - [E.WiderResnet38Encoder, {}], - 
[E.WiderResnet16A2Encoder, {}], - [E.WiderResnet20A2Encoder, {}], - [E.WiderResnet38A2Encoder, {}], - [E.EfficientNetB0Encoder, {}], - [E.EfficientNetB1Encoder, {}], - [E.EfficientNetB2Encoder, {}], - [E.EfficientNetB3Encoder, {}], - [E.EfficientNetB4Encoder, {}], - [E.EfficientNetB5Encoder, {}], - [E.EfficientNetB6Encoder, {}], - [E.EfficientNetB7Encoder, {}], - [E.DenseNet121Encoder, {}], - [E.DenseNet161Encoder, {}], - [E.DenseNet169Encoder, {}], - [E.DenseNet201Encoder, {}], - ], -) -@torch.no_grad() -@skip_if_no_cuda -def test_encoders(encoder: E.EncoderModule, encoder_params): - net = encoder(**encoder_params).eval() - print(net.__class__.__name__, count_parameters(net)) - print(net.output_strides) - print(net.output_filters) - input = torch.rand((4, 3, 256, 256)) - input = maybe_cuda(input) - net = maybe_cuda(net) - output = net(input) - assert len(output) == len(net.output_filters) - for feature_map, expected_stride, expected_channels in zip( - output, net.output_strides, net.output_filters - ): - assert feature_map.size(1) == expected_channels - assert feature_map.size(2) * expected_stride == 256 - assert feature_map.size(3) * expected_stride == 256 - def test_hff_dynamic_size(): feature_maps = [ @@ -93,31 +40,3 @@ def test_hff_static_size(): assert output.size(2) == 512 assert output.size(3) == 512 - -@torch.no_grad() -@skip_if_no_cuda -def test_inceptionv4_encoder(): - backbone = inceptionv4(pretrained=False) - backbone.last_linear = None - - net = E.InceptionV4Encoder(backbone, layers=[0, 1, 2, 3, 4]).cuda() - - print(count_parameters(backbone)) - print(count_parameters(net)) - - x = torch.randn((4, 3, 512, 512)).cuda() - - out = net(x) - for fm in out: - print(fm.size()) - - -@torch.no_grad() -@skip_if_no_cuda -def test_densenet(): - from torchvision.models import densenet121 - net1 = E.DenseNet121Encoder() - net2 = densenet121(pretrained=False) - net2.classifier = None - - print(count_parameters(net1), count_parameters(net2)) \ No newline at end of file From a9a0e1101615c51af48aa266ef178d5419cc0548 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Sun, 27 Oct 2019 19:18:12 +0200 Subject: [PATCH 04/79] Global refactoring of encoders & decoders --- black.toml | 25 +++ demo/demo_losses.py | 6 +- pytorch_toolbelt/inference/functional.py | 10 +- pytorch_toolbelt/inference/tiles.py | 56 +---- pytorch_toolbelt/inference/tta.py | 31 +-- pytorch_toolbelt/losses/dice.py | 18 +- pytorch_toolbelt/losses/focal.py | 24 +-- pytorch_toolbelt/losses/functional.py | 36 +--- pytorch_toolbelt/losses/jaccard.py | 18 +- pytorch_toolbelt/losses/lovasz.py | 21 +- pytorch_toolbelt/losses/wing_loss.py | 4 +- pytorch_toolbelt/modules/abn.py | 24 +-- pytorch_toolbelt/modules/activations.py | 11 + pytorch_toolbelt/modules/agn.py | 12 +- .../modules/backbone/efficient_net.py | 88 +++----- pytorch_toolbelt/modules/backbone/hrnet.py | 191 ++++-------------- .../modules/backbone/inceptionv4.py | 49 ++--- .../modules/backbone/mobilenet.py | 52 +---- .../modules/backbone/mobilenetv3.py | 70 ++----- pytorch_toolbelt/modules/backbone/senet.py | 71 +------ .../modules/backbone/wider_resnet.py | 162 +++------------ pytorch_toolbelt/modules/coord_conv.py | 12 +- pytorch_toolbelt/modules/decoders/common.py | 12 +- pytorch_toolbelt/modules/decoders/deeplab.py | 14 +- pytorch_toolbelt/modules/decoders/fpn.py | 20 +- pytorch_toolbelt/modules/decoders/fpn_cat.py | 69 ++++--- pytorch_toolbelt/modules/decoders/fpn_sum.py | 92 +++------ pytorch_toolbelt/modules/decoders/hrnet.py | 16 +- 
.../modules/decoders/pyramid_pooling.py | 20 +- pytorch_toolbelt/modules/decoders/unet.py | 55 +---- pytorch_toolbelt/modules/decoders/unet_v2.py | 37 +--- pytorch_toolbelt/modules/decoders/upernet.py | 41 +--- pytorch_toolbelt/modules/dropblock.py | 12 +- pytorch_toolbelt/modules/dsconv.py | 16 +- pytorch_toolbelt/modules/encoders/densenet.py | 55 +---- .../modules/encoders/efficientnet.py | 10 +- .../modules/encoders/mobilenet.py | 29 +-- pytorch_toolbelt/modules/encoders/resnet.py | 41 +--- .../modules/encoders/squeezenet.py | 5 +- pytorch_toolbelt/modules/encoders/unet.py | 29 +-- .../modules/encoders/wide_resnet.py | 32 +-- pytorch_toolbelt/modules/fpn.py | 63 ++---- pytorch_toolbelt/modules/pooling.py | 12 +- pytorch_toolbelt/modules/scse.py | 20 +- pytorch_toolbelt/modules/srm.py | 4 +- pytorch_toolbelt/optimization/functional.py | 3 +- pytorch_toolbelt/optimization/lr_schedules.py | 4 +- pytorch_toolbelt/utils/catalyst/criterions.py | 8 +- pytorch_toolbelt/utils/catalyst/metrics.py | 41 +--- .../utils/catalyst/visualization.py | 44 +--- pytorch_toolbelt/utils/dataset_utils.py | 28 +-- pytorch_toolbelt/utils/fs.py | 7 +- pytorch_toolbelt/utils/torch_utils.py | 8 +- pytorch_toolbelt/utils/visualization.py | 14 +- setup.py | 16 +- tests/test_decoders.py | 5 +- tests/test_encoders.py | 10 +- tests/test_modules.py | 6 +- tests/test_tiles.py | 38 +--- tests/test_tta.py | 47 +---- 60 files changed, 448 insertions(+), 1526 deletions(-) create mode 100644 black.toml diff --git a/black.toml b/black.toml new file mode 100644 index 000000000..d6600cc5b --- /dev/null +++ b/black.toml @@ -0,0 +1,25 @@ +# Example configuration for Black. + +# NOTE: you have to use single-quoted strings in TOML for regular expressions. +# It's the equivalent of r-strings in Python. Multiline strings are treated as +# verbose regular expressions by Black. Use [ ] to denote a significant space +# character. + +[tool.black] +line-length = 119 +target-version = ['py35', 'py36', 'py37', 'py38'] +include = '\.pyi?$' +exclude = ''' +/( + \.eggs + | \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' diff --git a/demo/demo_losses.py b/demo/demo_losses.py index 6020c04fd..0bf1ffefa 100644 --- a/demo/demo_losses.py +++ b/demo/demo_losses.py @@ -16,18 +16,14 @@ def main(): # "dice_log": L.BinaryDiceLogLoss(), # "sdice": L.BinarySymmetricDiceLoss(), # "sdice_log": L.BinarySymmetricDiceLoss(log_loss=True), - "bce+lovasz": L.JointLoss(BCEWithLogitsLoss(), L.BinaryLovaszLoss()), # "lovasz": L.BinaryLovaszLoss(), # "bce+jaccard": L.JointLoss(BCEWithLogitsLoss(), # L.BinaryJaccardLoss(), 1, 0.5), - # "bce+log_jaccard": L.JointLoss(BCEWithLogitsLoss(), # L.BinaryJaccardLogLoss(), 1, 0.5), - # "bce+log_dice": L.JointLoss(BCEWithLogitsLoss(), # L.BinaryDiceLogLoss(), 1, 0.5) - # "reduced_focal": L.BinaryFocalLoss(reduced=True) } @@ -55,5 +51,5 @@ def main(): f.show() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/pytorch_toolbelt/inference/functional.py b/pytorch_toolbelt/inference/functional.py index cfeb0fe4b..7df0194aa 100644 --- a/pytorch_toolbelt/inference/functional.py +++ b/pytorch_toolbelt/inference/functional.py @@ -62,11 +62,7 @@ def pad_image_tensor(image_tensor: Tensor, pad_size: int = 32): :return: Tuple of output tensor and pad params. 
Second argument can be used to reverse pad operation of model output """ rows, cols = image_tensor.size(2), image_tensor.size(3) - if ( - isinstance(pad_size, Sized) - and isinstance(pad_size, Iterable) - and len(pad_size) == 2 - ): + if isinstance(pad_size, Sized) and isinstance(pad_size, Iterable) and len(pad_size) == 2: pad_height, pad_width = [int(val) for val in pad_size] elif isinstance(pad_size, int): pad_height = pad_width = pad_size @@ -109,9 +105,7 @@ def unpad_image_tensor(image_tensor, pad): def unpad_xyxy_bboxes(bboxes_tensor: torch.Tensor, pad, dim=-1): pad_left, pad_right, pad_top, pad_btm = pad - pad = torch.tensor( - [pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype - ).to(bboxes_tensor.device) + pad = torch.tensor([pad_left, pad_top, pad_left, pad_top], dtype=bboxes_tensor.dtype).to(bboxes_tensor.device) if dim == -1: dim = len(bboxes_tensor.size()) - 1 diff --git a/pytorch_toolbelt/inference/tiles.py b/pytorch_toolbelt/inference/tiles.py index 2c07a4e73..a4d307473 100644 --- a/pytorch_toolbelt/inference/tiles.py +++ b/pytorch_toolbelt/inference/tiles.py @@ -47,9 +47,7 @@ class ImageSlicer: Helper class to slice image into tiles and merge them back """ - def __init__( - self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean" - ): + def __init__(self, image_shape, tile_size, tile_step=0, image_margin=0, weight="mean"): """ :param image_shape: Shape of the source image (H, W) @@ -75,21 +73,14 @@ def __init__( weights = {"mean": self._mean, "pyramid": self._pyramid} - self.weight = ( - weight - if isinstance(weight, np.ndarray) - else weights[weight](self.tile_size) - ) + self.weight = weight if isinstance(weight, np.ndarray) else weights[weight](self.tile_size) if self.tile_step[0] < 1 or self.tile_step[0] > self.tile_size[0]: raise ValueError() if self.tile_step[1] < 1 or self.tile_step[1] > self.tile_size[1]: raise ValueError() - overlap = [ - self.tile_size[0] - self.tile_step[0], - self.tile_size[1] - self.tile_step[1], - ] + overlap = [self.tile_size[0] - self.tile_step[0], self.tile_size[1] - self.tile_step[1]] self.margin_left = 0 self.margin_right = 0 @@ -111,14 +102,10 @@ def __init__( self.margin_bottom = extra_h - self.margin_top else: - if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[ - 1 - ] != 0: + if (self.image_width - overlap[1] + 2 * image_margin) % self.tile_step[1] != 0: raise ValueError() - if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[ - 0 - ] != 0: + if (self.image_height - overlap[0] + 2 * image_margin) % self.tile_step[0] != 0: raise ValueError() self.margin_left = image_margin @@ -130,32 +117,13 @@ def __init__( bbox_crops = [] for y in range( - 0, - self.image_height - + self.margin_top - + self.margin_bottom - - self.tile_size[0] - + 1, - self.tile_step[0], + 0, self.image_height + self.margin_top + self.margin_bottom - self.tile_size[0] + 1, self.tile_step[0] ): for x in range( - 0, - self.image_width - + self.margin_left - + self.margin_right - - self.tile_size[1] - + 1, - self.tile_step[1], + 0, self.image_width + self.margin_left + self.margin_right - self.tile_size[1] + 1, self.tile_step[1] ): crops.append((x, y, self.tile_size[1], self.tile_size[0])) - bbox_crops.append( - ( - x - self.margin_left, - y - self.margin_top, - self.tile_size[1], - self.tile_size[0], - ) - ) + bbox_crops.append((x - self.margin_left, y - self.margin_top, self.tile_size[1], self.tile_size[0])) self.crops = np.array(crops) self.bbox_crops = np.array(bbox_crops) @@ -189,9 +157,7 @@ 
def split(self, image, border_type=cv2.BORDER_CONSTANT, value=0): return tiles - def cut_patch( - self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0 - ): + def cut_patch(self, image: np.ndarray, slice_index, border_type=cv2.BORDER_CONSTANT, value=0): assert image.shape[0] == self.image_height assert image.shape[1] == self.image_width @@ -298,9 +264,7 @@ def integrate_batch(self, batch: torch.Tensor, crop_coords): :param crop_coords: Corresponding tile crops w.r.t to original image """ if len(batch) != len(crop_coords): - raise ValueError( - "Number of images in batch does not correspond to number of coordinates" - ) + raise ValueError("Number of images in batch does not correspond to number of coordinates") for tile, (x, y, tile_width, tile_height) in zip(batch, crop_coords): self.image[:, y : y + tile_height, x : x + tile_width] += tile * self.weight diff --git a/pytorch_toolbelt/inference/tta.py b/pytorch_toolbelt/inference/tta.py index 5e718457b..4e8380172 100644 --- a/pytorch_toolbelt/inference/tta.py +++ b/pytorch_toolbelt/inference/tta.py @@ -71,21 +71,11 @@ def fivecrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> T center_crop_y = (image_height - crop_height) // 2 center_crop_x = (image_width - crop_width) // 2 - crop_cc = image[ - ..., - center_crop_y : center_crop_y + crop_height, - center_crop_x : center_crop_x + crop_width, - ] + crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width] assert crop_cc.size(2) == crop_height assert crop_cc.size(3) == crop_width - output = ( - model(crop_tl) - + model(crop_tr) - + model(crop_bl) - + model(crop_br) - + model(crop_cc) - ) + output = model(crop_tl) + model(crop_tr) + model(crop_bl) + model(crop_br) + model(crop_cc) one_over_5 = float(1.0 / 5.0) return output * one_over_5 @@ -125,11 +115,7 @@ def tencrop_image2label(model: nn.Module, image: Tensor, crop_size: Tuple) -> Te center_crop_y = (image_height - crop_height) // 2 center_crop_x = (image_width - crop_width) // 2 - crop_cc = image[ - ..., - center_crop_y : center_crop_y + crop_height, - center_crop_x : center_crop_x + crop_width, - ] + crop_cc = image[..., center_crop_y : center_crop_y + crop_height, center_crop_x : center_crop_x + crop_width] assert crop_cc.size(2) == crop_height assert crop_cc.size(3) == crop_width @@ -202,8 +188,7 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor: output = model(image) for aug, deaug in zip( - [F.torch_rot90, F.torch_rot180, F.torch_rot270], - [F.torch_rot270, F.torch_rot180, F.torch_rot90], + [F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90] ): x = deaug(model(aug(image))) output = output + x @@ -258,13 +243,9 @@ def forward(self, input: Tensor) -> Tensor: for scale in self.scale_levels: dst_size = int(h * scale), int(w * scale) - input_scaled = interpolate( - input, dst_size, mode="bilinear", align_corners=False - ) + input_scaled = interpolate(input, dst_size, mode="bilinear", align_corners=False) output_scaled = self.model(input_scaled) - output_scaled = interpolate( - output_scaled, out_size, mode="bilinear", align_corners=False - ) + output_scaled = interpolate(output_scaled, out_size, mode="bilinear", align_corners=False) output += output_scaled return output / (1 + len(self.scale_levels)) diff --git a/pytorch_toolbelt/losses/dice.py b/pytorch_toolbelt/losses/dice.py index bf772675a..6df861060 100644 --- a/pytorch_toolbelt/losses/dice.py +++ b/pytorch_toolbelt/losses/dice.py 
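# --- Editor's note (annotation, not part of the patch) ---------------------------------------------
# The dice.py hunk below is pure Black reformatting under the 119-character line length introduced in
# black.toml: multi-line signatures, asserts and calls are collapsed onto single lines, and the loss
# math is unchanged. For context, a minimal, hypothetical sketch of the soft Dice score that DiceLoss
# delegates to; the name and defaults are assumed from the soft_dice_score signature reformatted
# further down in pytorch_toolbelt/losses/functional.py, not copied from the library.
import torch


def soft_dice_score_sketch(y_pred: torch.Tensor, y_true: torch.Tensor, smooth: float = 0.0, eps: float = 1e-7, dims=None) -> torch.Tensor:
    # Soft Dice: (2 * |X ∩ Y| + smooth) / (|X| + |Y| + smooth), guarded by eps for empty masks.
    if dims is not None:
        intersection = torch.sum(y_pred * y_true, dim=dims)
        cardinality = torch.sum(y_pred + y_true, dim=dims)
    else:
        intersection = torch.sum(y_pred * y_true)
        cardinality = torch.sum(y_pred + y_true)
    return (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
# ----------------------------------------------------------------------------------------------------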
@@ -21,15 +21,7 @@ class DiceLoss(_Loss): It supports binary, multiclass and multilabel cases """ - def __init__( - self, - mode: str, - classes: List[int] = None, - log_loss=False, - from_logits=True, - smooth=0, - eps=1e-7, - ): + def __init__(self, mode: str, classes: List[int] = None, log_loss=False, from_logits=True, smooth=0, eps=1e-7): """ :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} @@ -43,9 +35,7 @@ def __init__( super(DiceLoss, self).__init__() self.mode = mode if classes is not None: - assert ( - mode != BINARY_MODE - ), "Masking classes is not supported with mode=binary" + assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" classes = to_tensor(classes, dtype=torch.long) self.classes = classes @@ -89,9 +79,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: y_true = y_true.view(bs, num_classes, -1) y_pred = y_pred.view(bs, num_classes, -1) - scores = soft_dice_score( - y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims - ) + scores = soft_dice_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims) if self.log_loss: loss = -torch.log(scores) diff --git a/pytorch_toolbelt/losses/focal.py b/pytorch_toolbelt/losses/focal.py index aaa73e364..2e3b4d965 100644 --- a/pytorch_toolbelt/losses/focal.py +++ b/pytorch_toolbelt/losses/focal.py @@ -8,15 +8,7 @@ class BinaryFocalLoss(_Loss): - def __init__( - self, - alpha=0.5, - gamma=2, - ignore_index=None, - reduction="mean", - reduced=False, - threshold=0.5, - ): + def __init__(self, alpha=0.5, gamma=2, ignore_index=None, reduction="mean", reduced=False, threshold=0.5): """ :param alpha: @@ -31,16 +23,10 @@ def __init__( self.ignore_index = ignore_index if reduced: self.focal_loss = partial( - focal_loss_with_logits, - alpha=None, - gamma=gamma, - threshold=threshold, - reduction=reduction, + focal_loss_with_logits, alpha=None, gamma=gamma, threshold=threshold, reduction=reduction ) else: - self.focal_loss = partial( - focal_loss_with_logits, alpha=alpha, gamma=gamma, reduction=reduction - ) + self.focal_loss = partial(focal_loss_with_logits, alpha=alpha, gamma=gamma, reduction=reduction) def forward(self, label_input, label_target): """Compute focal loss for binary classification problem. 
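# --- Editor's note (annotation, not part of the patch) ---------------------------------------------
# The focal.py hunks above and below are formatting-only: the partial(...) bindings and the per-class
# loss accumulation are collapsed onto single lines with arguments and behaviour unchanged. As a
# reminder of what the wrapped function computes, a minimal sketch of binary focal loss with logits;
# argument names and defaults are assumed from the calls visible in this patch, and the real
# pytorch_toolbelt.losses.functional.focal_loss_with_logits exposes more options.
import torch
import torch.nn.functional as F


def focal_loss_with_logits_sketch(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0, alpha: float = 0.5) -> torch.Tensor:
    # Per-element BCE, then the focal modulation (1 - p_t) ** gamma, optionally class-balanced by alpha.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)  # probability the model assigns to the true class
    loss = (1.0 - p_t) ** gamma * bce
    if alpha is not None:
        loss = (alpha * targets + (1.0 - alpha) * (1.0 - targets)) * loss
    return loss.mean()
# ----------------------------------------------------------------------------------------------------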
@@ -88,7 +74,5 @@ def forward(self, label_input, label_target): cls_label_target = cls_label_target[not_ignored] cls_label_input = cls_label_input[not_ignored] - loss += focal_loss_with_logits( - cls_label_input, cls_label_target, gamma=self.gamma, alpha=self.alpha - ) + loss += focal_loss_with_logits(cls_label_input, cls_label_target, gamma=self.gamma, alpha=self.alpha) return loss diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index b4f6cca80..1d9d3a00a 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -4,13 +4,7 @@ import torch import torch.nn.functional as F -__all__ = [ - "focal_loss_with_logits", - "sigmoid_focal_loss", - "soft_jaccard_score", - "soft_dice_score", - "wing_loss", -] +__all__ = ["focal_loss_with_logits", "sigmoid_focal_loss", "soft_jaccard_score", "soft_dice_score", "wing_loss"] def focal_loss_with_logits( @@ -78,21 +72,11 @@ def focal_loss_with_logits( # TODO: Mark as deprecated and emit warning -def reduced_focal_loss( - input: torch.Tensor, - target: torch.Tensor, - threshold=0.5, - gamma=2.0, - reduction="mean", -): - return focal_loss_with_logits( - input, target, alpha=None, gamma=gamma, reduction=reduction, threshold=threshold - ) +def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"): + return focal_loss_with_logits(input, target, alpha=None, gamma=gamma, reduction=reduction, threshold=threshold) -def soft_jaccard_score( - y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None -) -> torch.Tensor: +def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: """ :param y_pred: @@ -122,9 +106,7 @@ def soft_jaccard_score( return jaccard_score -def soft_dice_score( - y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None -) -> torch.Tensor: +def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e-7, dims=None) -> torch.Tensor: """ :param y_pred: @@ -151,13 +133,7 @@ def soft_dice_score( return dice_score -def wing_loss( - prediction: torch.Tensor, - target: torch.Tensor, - width=5, - curvature=0.5, - reduction="mean", -): +def wing_loss(prediction: torch.Tensor, target: torch.Tensor, width=5, curvature=0.5, reduction="mean"): """ https://arxiv.org/pdf/1711.06753.pdf :param prediction: diff --git a/pytorch_toolbelt/losses/jaccard.py b/pytorch_toolbelt/losses/jaccard.py index 7a6847abd..1207d9138 100644 --- a/pytorch_toolbelt/losses/jaccard.py +++ b/pytorch_toolbelt/losses/jaccard.py @@ -21,15 +21,7 @@ class JaccardLoss(_Loss): It supports binary, multi-class and multi-label cases. 
""" - def __init__( - self, - mode: str, - classes: List[int] = None, - log_loss=False, - from_logits=True, - smooth=0, - eps=1e-7, - ): + def __init__(self, mode: str, classes: List[int] = None, log_loss=False, from_logits=True, smooth=0, eps=1e-7): """ :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} @@ -43,9 +35,7 @@ def __init__( super(JaccardLoss, self).__init__() self.mode = mode if classes is not None: - assert ( - mode != BINARY_MODE - ), "Masking classes is not supported with mode=binary" + assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" classes = to_tensor(classes, dtype=torch.long) self.classes = classes @@ -89,9 +79,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: y_true = y_true.view(bs, num_classes, -1) y_pred = y_pred.view(bs, num_classes, -1) - scores = soft_jaccard_score( - y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims - ) + scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims) if self.log_loss: loss = -torch.log(scores) diff --git a/pytorch_toolbelt/losses/lovasz.py b/pytorch_toolbelt/losses/lovasz.py index 68faeb704..fb02b202e 100644 --- a/pytorch_toolbelt/losses/lovasz.py +++ b/pytorch_toolbelt/losses/lovasz.py @@ -42,9 +42,7 @@ def _lovasz_hinge(logits, labels, per_image=True, ignore=None): """ if per_image: loss = mean( - _lovasz_hinge_flat( - *_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore) - ) + _lovasz_hinge_flat(*_flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) for log, lab in zip(logits, labels) ) else: @@ -101,16 +99,11 @@ def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=N """ if per_image: loss = mean( - _lovasz_softmax_flat( - *_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), - classes=classes - ) + _lovasz_softmax_flat(*_flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) for prob, lab in zip(probas, labels) ) else: - loss = _lovasz_softmax_flat( - *_flatten_probas(probas, labels, ignore), classes=classes - ) + loss = _lovasz_softmax_flat(*_flatten_probas(probas, labels, ignore), classes=classes) return loss @@ -195,9 +188,7 @@ def __init__(self, per_image=False, ignore=None): self.per_image = per_image def forward(self, logits, target): - return _lovasz_hinge( - logits, target, per_image=self.per_image, ignore=self.ignore - ) + return _lovasz_hinge(logits, target, per_image=self.per_image, ignore=self.ignore) class LovaszLoss(_Loss): @@ -207,6 +198,4 @@ def __init__(self, per_image=False, ignore=None): self.per_image = per_image def forward(self, logits, target): - return _lovasz_softmax( - logits, target, per_image=self.per_image, ignore=self.ignore - ) + return _lovasz_softmax(logits, target, per_image=self.per_image, ignore=self.ignore) diff --git a/pytorch_toolbelt/losses/wing_loss.py b/pytorch_toolbelt/losses/wing_loss.py index 4e2666814..910e7b415 100644 --- a/pytorch_toolbelt/losses/wing_loss.py +++ b/pytorch_toolbelt/losses/wing_loss.py @@ -12,6 +12,4 @@ def __init__(self, width=5, curvature=0.5, reduction="mean"): self.curvature = curvature def forward(self, prediction, target): - return F.wing_loss( - prediction, target, self.width, self.curvature, self.reduction - ) + return F.wing_loss(prediction, target, self.width, self.curvature, self.reduction) diff --git a/pytorch_toolbelt/modules/abn.py b/pytorch_toolbelt/modules/abn.py index 1183460be..edc6c6d03 100644 --- a/pytorch_toolbelt/modules/abn.py +++ 
b/pytorch_toolbelt/modules/abn.py @@ -25,15 +25,7 @@ class ABN(nn.Module): This gathers a `BatchNorm2d` and an activation function in a single module """ - def __init__( - self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - activation="leaky_relu", - slope=0.01, - ): + def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): """Create an Activated Batch Normalization module Parameters ---------- @@ -76,14 +68,7 @@ def reset_parameters(self): def forward(self, x): x = functional.batch_norm( - x, - self.running_mean, - self.running_var, - self.weight, - self.bias, - self.training, - self.momentum, - self.eps, + x, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps ) if self.activation == ACT_RELU: @@ -108,10 +93,7 @@ def forward(self, x): raise KeyError(self.activation) def __repr__(self): - rep = ( - "{name}({num_features}, eps={eps}, momentum={momentum}," - " affine={affine}, activation={activation}" - ) + rep = "{name}({num_features}, eps={eps}, momentum={momentum}," " affine={affine}, activation={activation}" if self.activation == "leaky_relu": rep += ", slope={slope})" else: diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index 2adc00c48..b0feec1d4 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -20,6 +20,7 @@ "HardSwish", "Swish", "get_activation_module", + "sanitize_activation_name" ] # Activation names @@ -107,3 +108,13 @@ def get_activation_module(activation_name: str, **kwargs) -> nn.Module: return partial(HardSwish, **kwargs) raise ValueError(f"Activation '{activation_name}' is not supported") + + +def sanitize_activation_name(activation_name): + """ + Return reasonable activation name for initialization in `kaiming_uniform_` for hipster activations + """ + if activation_name in {"swish", "mish"}: + return "leaky_relu" + + return activation_name diff --git a/pytorch_toolbelt/modules/agn.py b/pytorch_toolbelt/modules/agn.py index 3b99c56b0..d23b116f1 100644 --- a/pytorch_toolbelt/modules/agn.py +++ b/pytorch_toolbelt/modules/agn.py @@ -26,13 +26,7 @@ class AGN(nn.Module): """ def __init__( - self, - num_features: int, - num_groups: int, - eps=1e-5, - momentum=0.1, - activation=ACT_LEAKY_RELU, - slope=0.01, + self, num_features: int, num_groups: int, eps=1e-5, momentum=0.1, activation=ACT_LEAKY_RELU, slope=0.01 ): """Create an Activated Batch Normalization module Parameters @@ -93,9 +87,7 @@ def forward(self, x): raise KeyError(self.activation) def __repr__(self): - rep = ( - "{name}({num_features},{num_groups}, eps={eps}" ", activation={activation}" - ) + rep = "{name}({num_features},{num_groups}, eps={eps}" ", activation={activation}" if self.activation == "leaky_relu": rep += ", slope={slope})" else: diff --git a/pytorch_toolbelt/modules/backbone/efficient_net.py b/pytorch_toolbelt/modules/backbone/efficient_net.py index fc7ee6669..ba5245bd6 100644 --- a/pytorch_toolbelt/modules/backbone/efficient_net.py +++ b/pytorch_toolbelt/modules/backbone/efficient_net.py @@ -6,11 +6,12 @@ import torch from torch import nn from torch.nn import functional as F -from torch.nn.init import kaiming_normal_ +from torch.nn.init import kaiming_normal_, kaiming_uniform_ -from pytorch_toolbelt.modules import ABN, SpatialGate2d -from pytorch_toolbelt.modules.activations import ACT_HARD_SWISH -from pytorch_toolbelt.modules.agn import AGN +from ..abn import ABN +from ..agn import AGN +from 
..activations import ACT_HARD_SWISH, sanitize_activation_name +from ..scse import SpatialGate2d def round_filters(filters, width_coefficient, depth_divisor, min_depth): @@ -19,9 +20,7 @@ def round_filters(filters, width_coefficient, depth_divisor, min_depth): """ filters *= width_coefficient min_depth = min_depth or depth_divisor - new_filters = max( - min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor - ) + new_filters = max(min_depth, int(filters + depth_divisor / 2) // depth_divisor * depth_divisor) if new_filters < 0.9 * filters: # prevent rounding by more than 10% new_filters += depth_divisor return int(new_filters) @@ -45,9 +44,7 @@ def drop_connect(inputs, p, training): batch_size = inputs.shape[0] keep_prob = 1 - p random_tensor = keep_prob - random_tensor += torch.rand( - [batch_size, 1, 1, 1], dtype=inputs.dtype - ) # uniform [0,1) + random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype) # uniform [0,1) binary_tensor = torch.floor(random_tensor) output = inputs / keep_prob * binary_tensor return output @@ -82,19 +79,11 @@ def __init__( self.id_skip = id_skip def scale( - self, - width_coefficient: float, - depth_coefficient: float, - depth_divisor: float = 8.0, - min_filters: int = None, + self, width_coefficient: float, depth_coefficient: float, depth_divisor: float = 8.0, min_filters: int = None ): copy = deepcopy(self) - copy.input_filters = round_filters( - self.input_filters, width_coefficient, depth_divisor, min_filters - ) - copy.output_filters = round_filters( - self.output_filters, width_coefficient, depth_divisor, min_filters - ) + copy.input_filters = round_filters(self.input_filters, width_coefficient, depth_divisor, min_filters) + copy.output_filters = round_filters(self.output_filters, width_coefficient, depth_divisor, min_filters) copy.num_repeat = round_repeats(self.num_repeat, depth_coefficient) copy.width_coefficient = width_coefficient copy.depth_coefficient = depth_coefficient @@ -123,14 +112,10 @@ def __init__(self, block_args: EfficientNetBlockArgs, abn_block: ABN, abn_params # Expansion phase inp = block_args.input_filters # number of input channels - oup = ( - block_args.input_filters * block_args.expand_ratio - ) # number of output channels + oup = block_args.input_filters * block_args.expand_ratio # number of output channels if block_args.expand_ratio != 1: - self.expand_conv = nn.Conv2d( - in_channels=inp, out_channels=oup, kernel_size=1, bias=False - ) + self.expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) self.abn0 = abn_block(oup, **abn_params) # Depthwise convolution phase @@ -152,9 +137,7 @@ def __init__(self, block_args: EfficientNetBlockArgs, abn_block: ABN, abn_params # Output phase final_oup = self._block_args.output_filters - self.project_conv = nn.Conv2d( - in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False - ) + self.project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) self.abn2 = abn_block(final_oup, **abn_params) self.input_filters = block_args.input_filters @@ -164,22 +147,15 @@ def __init__(self, block_args: EfficientNetBlockArgs, abn_block: ABN, abn_params def reset_parameters(self): if hasattr(self, "expand_conv"): - kaiming_normal_( - self.expand_conv.weight, - a=self.abn0.slope, - nonlinearity=self.abn0.activation, + kaiming_uniform_( + self.expand_conv.weight, a=self.abn0.slope, nonlinearity=sanitize_activation_name(self.abn0.activation) ) - kaiming_normal_( - self.depthwise_conv.weight, - 
a=self.abn1.slope, - nonlinearity=self.abn1.activation, + kaiming_uniform_( + self.depthwise_conv.weight, a=self.abn1.slope, nonlinearity=sanitize_activation_name(self.abn1.activation) ) - - kaiming_normal_( - self.project_conv.weight, - a=self.abn1.slope, - nonlinearity=self.abn2.activation, + kaiming_uniform_( + self.project_conv.weight, a=self.abn1.slope, nonlinearity=sanitize_activation_name(self.abn2.activation) ) def forward(self, inputs, drop_connect_rate=None): @@ -202,11 +178,7 @@ def forward(self, inputs, drop_connect_rate=None): x = self.abn2(self.project_conv(x)) # Skip connection and drop connect - if ( - self.id_skip - and self._block_args.stride == 1 - and self.input_filters == self.output_filters - ): + if self.id_skip and self._block_args.stride == 1 and self.input_filters == self.output_filters: if drop_connect_rate: x = drop_connect(x, p=drop_connect_rate, training=self.training) x = x + inputs # skip connection @@ -333,12 +305,7 @@ def __init__( ( "conv", nn.Conv2d( - in_channels, - first_block_args.input_filters, - kernel_size=3, - padding=1, - stride=2, - bias=False, + in_channels, first_block_args.input_filters, kernel_size=3, padding=1, stride=2, bias=False ), ), ("abn", abn_block(first_block_args.input_filters, **abn_params)), @@ -365,15 +332,10 @@ def __init__( # Head out_channels = round_filters( - 1280, - last_block_args.width_coefficient, - last_block_args.depth_divisor, - last_block_args.min_filters, + 1280, last_block_args.width_coefficient, last_block_args.depth_divisor, last_block_args.min_filters ) - self.conv_head = nn.Conv2d( - last_block_args.output_filters, out_channels, kernel_size=1, bias=False - ) + self.conv_head = nn.Conv2d(last_block_args.output_filters, out_channels, kernel_size=1, bias=False) self.abn_head = abn_block(out_channels, **abn_params) # Final linear layer @@ -503,9 +465,7 @@ def test_efficient_net_group_norm(): ]: print("=======", model_fn.__name__, "=======") agn_params = {"num_groups": 8, "activation": ACT_HARD_SWISH} - model = ( - model_fn(num_classes, abn_block=AGN, abn_params=agn_params).eval().cuda() - ) + model = model_fn(num_classes, abn_block=AGN, abn_params=agn_params).eval().cuda() print(count_parameters(model)) # print(model) print() diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py index 90a3bb536..cd0a91c20 100644 --- a/pytorch_toolbelt/modules/backbone/hrnet.py +++ b/pytorch_toolbelt/modules/backbone/hrnet.py @@ -3,24 +3,18 @@ https://github.com/HRNet/HRNet-Semantic-Segmentation """ -import os -import sys from collections import OrderedDict -from urllib.request import urlretrieve import torch import torch.nn as nn import torch.nn.functional as F - HRNETV2_BN_MOMENTUM = 0.1 def hrnet_conv3x3(in_planes, out_planes, stride=1): """3x3 convolution with padding""" - return nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class HRNetBasicBlock(nn.Module): @@ -62,13 +56,9 @@ def __init__(self, inplanes, planes, stride=1, downsample=None): super(HRNetBottleneck, self).__init__() self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, stride=stride, padding=1, bias=False - ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) self.bn2 = nn.BatchNorm2d(planes, 
momentum=HRNETV2_BN_MOMENTUM) - self.conv3 = nn.Conv2d( - planes, planes * self.expansion, kernel_size=1, bias=False - ) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=HRNETV2_BN_MOMENTUM) self.relu = nn.ReLU(inplace=True) self.downsample = downsample @@ -99,19 +89,10 @@ def forward(self, x): class HighResolutionModule(nn.Module): def __init__( - self, - num_branches, - blocks, - num_blocks, - num_inchannels, - num_channels, - fuse_method, - multi_scale_output=True, + self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True ): super(HighResolutionModule, self).__init__() - self._check_branches( - num_branches, blocks, num_blocks, num_inchannels, num_channels - ) + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) self.num_inchannels = num_inchannels self.fuse_method = fuse_method @@ -119,40 +100,26 @@ def __init__( self.multi_scale_output = multi_scale_output - self.branches = self._make_branches( - num_branches, blocks, num_blocks, num_channels - ) + self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) self.fuse_layers = self._make_fuse_layers() self.relu = nn.ReLU(inplace=True) - def _check_branches( - self, num_branches, blocks, num_blocks, num_inchannels, num_channels - ): + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): if num_branches != len(num_blocks): - error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format( - num_branches, len(num_blocks) - ) + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks)) raise ValueError(error_msg) if num_branches != len(num_channels): - error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format( - num_branches, len(num_channels) - ) + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(num_branches, len(num_channels)) raise ValueError(error_msg) if num_branches != len(num_inchannels): - error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format( - num_branches, len(num_inchannels) - ) + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(num_branches, len(num_inchannels)) raise ValueError(error_msg) def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): downsample = None - if ( - stride != 1 - or self.num_inchannels[branch_index] - != num_channels[branch_index] * block.expansion - ): + if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: downsample = nn.Sequential( nn.Conv2d( self.num_inchannels[branch_index], @@ -161,26 +128,14 @@ def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride stride=stride, bias=False, ), - nn.BatchNorm2d( - num_channels[branch_index] * block.expansion, - momentum=HRNETV2_BN_MOMENTUM, - ), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=HRNETV2_BN_MOMENTUM), ) layers = [] - layers.append( - block( - self.num_inchannels[branch_index], - num_channels[branch_index], - stride, - downsample, - ) - ) + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion for i in range(1, num_blocks[branch_index]): - layers.append( - block(self.num_inchannels[branch_index], num_channels[branch_index]) - ) + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) 
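In `_make_one_branch` (and again in `_make_layer` below) the identity path is replaced by a 1x1 strided convolution plus BatchNorm whenever the block changes resolution or width. A condensed sketch of that projection rule, assuming plain BatchNorm as used throughout this backbone:

    import torch.nn as nn

    def make_shortcut(in_ch: int, out_ch: int, stride: int = 1, bn_momentum: float = 0.1):
        # Identity shortcut unless the residual branch changes shape; otherwise project with 1x1 conv + BN
        if stride == 1 and in_ch == out_ch:
            return nn.Identity()
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch, momentum=bn_momentum),
        )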
return nn.Sequential(*layers) @@ -205,17 +160,8 @@ def _make_fuse_layers(self): if j > i: fuse_layer.append( nn.Sequential( - nn.Conv2d( - num_inchannels[j], - num_inchannels[i], - 1, - 1, - 0, - bias=False, - ), - nn.BatchNorm2d( - num_inchannels[i], momentum=HRNETV2_BN_MOMENTUM - ), + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=HRNETV2_BN_MOMENTUM), ) ) elif j == i: @@ -227,36 +173,16 @@ def _make_fuse_layers(self): num_outchannels_conv3x3 = num_inchannels[i] conv3x3s.append( nn.Sequential( - nn.Conv2d( - num_inchannels[j], - num_outchannels_conv3x3, - 3, - 2, - 1, - bias=False, - ), - nn.BatchNorm2d( - num_outchannels_conv3x3, - momentum=HRNETV2_BN_MOMENTUM, - ), + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), ) ) else: num_outchannels_conv3x3 = num_inchannels[j] conv3x3s.append( nn.Sequential( - nn.Conv2d( - num_inchannels[j], - num_outchannels_conv3x3, - 3, - 2, - 1, - bias=False, - ), - nn.BatchNorm2d( - num_outchannels_conv3x3, - momentum=HRNETV2_BN_MOMENTUM, - ), + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), ) ) @@ -334,20 +260,10 @@ def __init__(self, width=48, **kwargs): self.layer0 = nn.Sequential( OrderedDict( [ - ( - "conv1", - nn.Conv2d( - 3, 64, kernel_size=3, stride=2, padding=1, bias=False - ), - ), + ("conv1", nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)), ("bn1", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), ("relu", nn.ReLU(inplace=True)), - ( - "conv2", - nn.Conv2d( - 64, 64, kernel_size=3, stride=2, padding=1, bias=False - ), - ), + ("conv2", nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)), ("bn2", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), ("relu2", nn.ReLU(inplace=True)), ] @@ -359,35 +275,23 @@ def __init__(self, width=48, **kwargs): self.stage2_cfg = extra["STAGE2"] num_channels = self.stage2_cfg["NUM_CHANNELS"] block = blocks_dict[self.stage2_cfg["BLOCK"]] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition1 = self._make_transition_layer([256], num_channels) - self.stage2, pre_stage_channels = self._make_stage( - self.stage2_cfg, num_channels - ) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) self.stage3_cfg = extra["STAGE3"] num_channels = self.stage3_cfg["NUM_CHANNELS"] block = blocks_dict[self.stage3_cfg["BLOCK"]] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage3, pre_stage_channels = self._make_stage( - self.stage3_cfg, num_channels - ) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) self.stage4_cfg = extra["STAGE4"] num_channels = self.stage4_cfg["NUM_CHANNELS"] block = blocks_dict[self.stage4_cfg["BLOCK"]] - num_channels = [ - num_channels[i] * block.expansion for i in range(len(num_channels)) - ] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage4, 
pre_stage_channels = self._make_stage( - self.stage4_cfg, num_channels, multi_scale_output=True - ) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): num_branches_cur = len(num_channels_cur_layer) @@ -399,17 +303,8 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer) if num_channels_cur_layer[i] != num_channels_pre_layer[i]: transition_layers.append( nn.Sequential( - nn.Conv2d( - num_channels_pre_layer[i], - num_channels_cur_layer[i], - 3, - 1, - 1, - bias=False, - ), - nn.BatchNorm2d( - num_channels_cur_layer[i], momentum=HRNETV2_BN_MOMENTUM - ), + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), ) ) @@ -419,11 +314,7 @@ def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer) conv3x3s = [] for j in range(i + 1 - num_branches_pre): inchannels = num_channels_pre_layer[-1] - outchannels = ( - num_channels_cur_layer[i] - if j == i - num_branches_pre - else inchannels - ) + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels conv3x3s.append( nn.Sequential( nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), @@ -439,13 +330,7 @@ def _make_layer(self, block, inplanes, planes, blocks, stride=1): downsample = None if stride != 1 or inplanes != planes * block.expansion: downsample = nn.Sequential( - nn.Conv2d( - inplanes, - planes * block.expansion, - kernel_size=1, - stride=stride, - bias=False, - ), + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), nn.BatchNorm2d(planes * block.expansion, momentum=HRNETV2_BN_MOMENTUM), ) @@ -520,15 +405,9 @@ def forward(self, x, return_feature_maps=False): # Upsampling x0_h, x0_w = x[0].size(2), x[0].size(3) - x1 = F.interpolate( - x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=False - ) - x2 = F.interpolate( - x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=False - ) - x3 = F.interpolate( - x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=False - ) + x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=False) + x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=False) + x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=False) x = torch.cat([x[0], x1, x2, x3], 1) diff --git a/pytorch_toolbelt/modules/backbone/inceptionv4.py b/pytorch_toolbelt/modules/backbone/inceptionv4.py index f8293a433..7cc87c29b 100644 --- a/pytorch_toolbelt/modules/backbone/inceptionv4.py +++ b/pytorch_toolbelt/modules/backbone/inceptionv4.py @@ -38,18 +38,10 @@ class BasicConv2d(nn.Module): def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0): super(BasicConv2d, self).__init__() self.conv = nn.Conv2d( - in_planes, - out_planes, - kernel_size=kernel_size, - stride=stride, - padding=padding, - bias=False, + in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=False ) # verify bias false self.bn = nn.BatchNorm2d( - out_planes, - eps=0.001, # value found in tensorflow - momentum=0.1, # default pytorch value - affine=True, + out_planes, eps=0.001, momentum=0.1, affine=True # value found in tensorflow # default pytorch value ) self.relu = nn.ReLU(inplace=True) @@ -78,8 +70,7 @@ def __init__(self): super(Mixed_4a, self).__init__() self.branch0 = 
nn.Sequential( - BasicConv2d(160, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1), + BasicConv2d(160, 64, kernel_size=1, stride=1), BasicConv2d(64, 96, kernel_size=3, stride=1) ) self.branch1 = nn.Sequential( @@ -115,8 +106,7 @@ def __init__(self): self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1) self.branch1 = nn.Sequential( - BasicConv2d(384, 64, kernel_size=1, stride=1), - BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1), + BasicConv2d(384, 64, kernel_size=1, stride=1), BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1) ) self.branch2 = nn.Sequential( @@ -198,8 +188,7 @@ def __init__(self): super(Reduction_B, self).__init__() self.branch0 = nn.Sequential( - BasicConv2d(1024, 192, kernel_size=1, stride=1), - BasicConv2d(192, 192, kernel_size=3, stride=2), + BasicConv2d(1024, 192, kernel_size=1, stride=1), BasicConv2d(192, 192, kernel_size=3, stride=2) ) self.branch1 = nn.Sequential( @@ -226,26 +215,14 @@ def __init__(self): self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1) self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch1_1a = BasicConv2d( - 384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch1_1b = BasicConv2d( - 384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) + self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1) - self.branch2_1 = BasicConv2d( - 384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) - self.branch2_2 = BasicConv2d( - 448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch2_3a = BasicConv2d( - 512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1) - ) - self.branch2_3b = BasicConv2d( - 512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0) - ) + self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0)) + self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1)) + self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0)) self.branch3 = nn.Sequential( nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False), @@ -325,9 +302,7 @@ def forward(self, input): def inceptionv4(num_classes=1000, pretrained="imagenet"): if pretrained: settings = pretrained_settings["inceptionv4"][pretrained] - assert ( - num_classes == settings["num_classes"] - ), "num_classes should be {}, but is {}".format( + assert num_classes == settings["num_classes"], "num_classes should be {}, but is {}".format( settings["num_classes"], num_classes ) diff --git a/pytorch_toolbelt/modules/backbone/mobilenet.py b/pytorch_toolbelt/modules/backbone/mobilenet.py index 4110e3292..b7006fe93 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenet.py +++ b/pytorch_toolbelt/modules/backbone/mobilenet.py @@ -8,19 +8,11 @@ def conv_bn(inp, oup, stride, activation: nn.Module): - return nn.Sequential( - nn.Conv2d(inp, oup, 3, stride, 1, bias=False), - nn.BatchNorm2d(oup), - activation(inplace=True), - ) + return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), activation(inplace=True)) def conv_1x1_bn(inp, oup, activation: nn.Module): - return nn.Sequential( - nn.Conv2d(inp, oup, 1, 1, 0, bias=False), - nn.BatchNorm2d(oup), - activation(inplace=True), - ) + return 
nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), activation(inplace=True)) class InvertedResidual(nn.Module): @@ -35,9 +27,7 @@ def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module): if expand_ratio == 1: self.conv = nn.Sequential( # dw - nn.Conv2d( - hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False - ), + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), activation(inplace=True), # pw-linear @@ -51,9 +41,7 @@ def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module): nn.BatchNorm2d(hidden_dim), activation(inplace=True), # dw - nn.Conv2d( - hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False - ), + nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), activation(inplace=True), # pw-linear @@ -69,9 +57,7 @@ def forward(self, x): class MobileNetV2(nn.Module): - def __init__( - self, n_class=1000, input_size=224, width_mult=1.0, activation="relu6" - ): + def __init__(self, n_class=1000, input_size=224, width_mult=1.0, activation="relu6"): super(MobileNetV2, self).__init__() act = get_activation_module(activation) @@ -93,9 +79,7 @@ def __init__( # building first layer assert input_size % 32 == 0 input_channel = int(input_channel * width_mult) - self.last_channel = ( - int(last_channel * width_mult) if width_mult > 1.0 else last_channel - ) + self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel self.layer0 = conv_bn(3, input_channel, 2, act) # building inverted residual blocks @@ -105,25 +89,9 @@ def __init__( blocks = [] for i in range(n): if i == 0: - blocks.append( - block( - input_channel, - output_channel, - s, - expand_ratio=t, - activation=act, - ) - ) + blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=act)) else: - blocks.append( - block( - input_channel, - output_channel, - 1, - expand_ratio=t, - activation=act, - ) - ) + blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=act)) input_channel = output_channel @@ -133,9 +101,7 @@ def __init__( self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=act) # building classifier - self.classifier = nn.Sequential( - nn.Dropout(0.2), nn.Linear(self.last_channel, n_class) - ) + self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(self.last_channel, n_class)) self._initialize_weights() diff --git a/pytorch_toolbelt/modules/backbone/mobilenetv3.py b/pytorch_toolbelt/modules/backbone/mobilenetv3.py index 320d70991..f29ffd447 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenetv3.py +++ b/pytorch_toolbelt/modules/backbone/mobilenetv3.py @@ -2,13 +2,11 @@ from collections import OrderedDict -import torch import torch.nn as nn import torch.nn.functional as F -# from pytorch_toolbelt.modules.dropblock import DropBlockScheduled, DropBlock2D -from pytorch_toolbelt.modules.activations import HardSwish, HardSigmoid -from pytorch_toolbelt.modules.identity import Identity +from ..activations import HardSwish, HardSigmoid +from ..identity import Identity def _make_divisible(v, divisor, min_value=None): @@ -42,13 +40,9 @@ def __init__(self, n_features, reduction=4): if n_features % reduction != 0: raise ValueError("n_features must be divisible by reduction (default = 4)") - self.linear1 = nn.Conv2d( - n_features, n_features // reduction, kernel_size=1, bias=True - ) + self.linear1 = nn.Conv2d(n_features, n_features // 
reduction, kernel_size=1, bias=True) self.nonlin1 = nn.ReLU(inplace=True) - self.linear2 = nn.Conv2d( - n_features // reduction, n_features, kernel_size=1, bias=True - ) + self.linear2 = nn.Conv2d(n_features // reduction, n_features, kernel_size=1, bias=True) self.nonlin2 = HardSigmoid(inplace=True) def forward(self, x): @@ -80,18 +74,10 @@ def __init__( self.db1 = nn.Dropout2d(drop_prob) # self.db1 = DropBlockScheduled(DropBlock2D(drop_prob=drop_prob, block_size=7), start_value=0., # stop_value=drop_prob, nr_steps=num_steps, start_step=start_step) - self.act1 = activation( - **act_params - ) # first does have act according to MobileNetV2 + self.act1 = activation(**act_params) # first does have act according to MobileNetV2 self.conv2 = nn.Conv2d( - expplanes, - expplanes, - kernel_size=k, - stride=stride, - padding=k // 2, - bias=False, - groups=expplanes, + expplanes, expplanes, kernel_size=k, stride=stride, padding=k // 2, bias=False, groups=expplanes ) self.bn2 = nn.BatchNorm2d(expplanes) self.db2 = nn.Dropout2d(drop_prob) @@ -185,9 +171,7 @@ def __init__(self, inplanes, num_classes, expplanes1, expplanes2): self.avgpool = nn.AdaptiveAvgPool2d(1) - self.conv2 = nn.Conv2d( - expplanes1, expplanes2, kernel_size=1, stride=1, bias=False - ) + self.conv2 = nn.Conv2d(expplanes1, expplanes2, kernel_size=1, stride=1, bias=False) self.act2 = HardSwish(inplace=True) self.dropout = nn.Dropout(p=0.2, inplace=True) @@ -221,14 +205,7 @@ class MobileNetV3(nn.Module): """ def __init__( - self, - num_classes=1000, - scale=1.0, - in_channels=3, - drop_prob=0.0, - num_steps=3e5, - start_step=0, - small=False, + self, num_classes=1000, scale=1.0, in_channels=3, drop_prob=0.0, num_steps=3e5, start_step=0, small=False ): super(MobileNetV3, self).__init__() @@ -272,47 +249,30 @@ def __init__( [96, 576, 96, 1, 5, drop_prob, True, HardSwish], # -> 7x7 ] - self.bottlenecks_setting = ( - self.bottlenecks_setting_small if small else self.bottlenecks_setting_large - ) + self.bottlenecks_setting = self.bottlenecks_setting_small if small else self.bottlenecks_setting_large for l in self.bottlenecks_setting: l[0] = _make_divisible(l[0] * self.scale, 8) l[1] = _make_divisible(l[1] * self.scale, 8) l[2] = _make_divisible(l[2] * self.scale, 8) self.conv1 = nn.Conv2d( - in_channels, - self.bottlenecks_setting[0][0], - kernel_size=3, - bias=False, - stride=2, - padding=1, + in_channels, self.bottlenecks_setting[0][0], kernel_size=3, bias=False, stride=2, padding=1 ) self.bn1 = nn.BatchNorm2d(self.bottlenecks_setting[0][0]) self.act1 = HardSwish(inplace=True) - self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = ( - self._make_bottlenecks() - ) + self.layer0, self.layer1, self.layer2, self.layer3, self.layer4 = self._make_bottlenecks() # Last convolution has 1280 output channels for scale <= 1 - self.last_exp2 = ( - 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8) - ) + self.last_exp2 = 1280 if self.scale <= 1 else _make_divisible(1280 * self.scale, 8) if small: self.last_exp1 = _make_divisible(576 * self.scale, 8) self.last_block = LastBlockSmall( - self.bottlenecks_setting[-1][2], - num_classes, - self.last_exp1, - self.last_exp2, + self.bottlenecks_setting[-1][2], num_classes, self.last_exp1, self.last_exp2 ) else: self.last_exp1 = _make_divisible(960 * self.scale, 8) self.last_block = LastBlockLarge( - self.bottlenecks_setting[-1][2], - num_classes, - self.last_exp1, - self.last_exp2, + self.bottlenecks_setting[-1][2], num_classes, self.last_exp1, self.last_exp2 ) def 
_make_bottlenecks(self): @@ -362,5 +322,3 @@ def forward(self, x): x = self.last_block(x) return x - - diff --git a/pytorch_toolbelt/modules/backbone/senet.py b/pytorch_toolbelt/modules/backbone/senet.py index 6643bcb36..df0c175ad 100644 --- a/pytorch_toolbelt/modules/backbone/senet.py +++ b/pytorch_toolbelt/modules/backbone/senet.py @@ -149,13 +149,7 @@ def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=Non self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) self.bn1 = nn.BatchNorm2d(planes * 2) self.conv2 = nn.Conv2d( - planes * 2, - planes * 4, - kernel_size=3, - stride=stride, - padding=1, - groups=groups, - bias=False, + planes * 2, planes * 4, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False ) self.bn2 = nn.BatchNorm2d(planes * 4) self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, bias=False) @@ -177,13 +171,9 @@ class SEResNetBottleneck(Bottleneck): def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None): super(SEResNetBottleneck, self).__init__() - self.conv1 = nn.Conv2d( - inplanes, planes, kernel_size=1, bias=False, stride=stride - ) + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, stride=stride) self.bn1 = nn.BatchNorm2d(planes) - self.conv2 = nn.Conv2d( - planes, planes, kernel_size=3, padding=1, groups=groups, bias=False - ) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(planes) self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -200,29 +190,12 @@ class SEResNeXtBottleneck(Bottleneck): expansion = 4 - def __init__( - self, - inplanes, - planes, - groups, - reduction, - stride=1, - downsample=None, - base_width=4, - ): + def __init__(self, inplanes, planes, groups, reduction, stride=1, downsample=None, base_width=4): super(SEResNeXtBottleneck, self).__init__() width = math.floor(planes * (base_width / 64)) * groups self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, stride=1) self.bn1 = nn.BatchNorm2d(width) - self.conv2 = nn.Conv2d( - width, - width, - kernel_size=3, - stride=stride, - padding=1, - groups=groups, - bias=False, - ) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, groups=groups, bias=False) self.bn2 = nn.BatchNorm2d(width) self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) self.bn3 = nn.BatchNorm2d(planes * 4) @@ -305,12 +278,7 @@ def __init__( ] else: layer0_modules = [ - ( - "conv1", - nn.Conv2d( - 3, inplanes, kernel_size=7, stride=2, padding=3, bias=False - ), - ), + ("conv1", nn.Conv2d(3, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), ("bn1", nn.BatchNorm2d(inplanes)), ("relu1", nn.ReLU(inplace=True)), ] @@ -362,15 +330,7 @@ def __init__( self.last_linear = nn.Linear(512 * block.expansion, num_classes) def _make_layer( - self, - block, - planes, - blocks, - groups, - reduction, - stride=1, - downsample_kernel_size=1, - downsample_padding=0, + self, block, planes, blocks, groups, reduction, stride=1, downsample_kernel_size=1, downsample_padding=0 ): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: @@ -387,9 +347,7 @@ def _make_layer( ) layers = [] - layers.append( - block(self.inplanes, planes, groups, reduction, stride, downsample) - ) + layers.append(block(self.inplanes, planes, groups, reduction, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, blocks): 
layers.append(block(self.inplanes, planes, groups, reduction)) @@ -419,9 +377,7 @@ def forward(self, x): def initialize_pretrained_model(model, num_classes, settings): - assert ( - num_classes == settings["num_classes"] - ), "num_classes should be {}, but is {}".format( + assert num_classes == settings["num_classes"], "num_classes should be {}, but is {}".format( settings["num_classes"], num_classes ) model.load_state_dict(model_zoo.load_url(settings["url"])) @@ -433,14 +389,7 @@ def initialize_pretrained_model(model, num_classes, settings): def senet154(num_classes=1000, pretrained="imagenet"): - model = SENet( - SEBottleneck, - [3, 8, 36, 3], - groups=64, - reduction=16, - dropout_p=0.2, - num_classes=num_classes, - ) + model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, dropout_p=0.2, num_classes=num_classes) if pretrained is not None: settings = pretrained_settings["senet154"][pretrained] initialize_pretrained_model(model, num_classes, settings) diff --git a/pytorch_toolbelt/modules/backbone/wider_resnet.py b/pytorch_toolbelt/modules/backbone/wider_resnet.py index 445c366d6..9647894a4 100644 --- a/pytorch_toolbelt/modules/backbone/wider_resnet.py +++ b/pytorch_toolbelt/modules/backbone/wider_resnet.py @@ -2,23 +2,13 @@ from functools import partial import torch -from pytorch_toolbelt.modules.abn import ABN -from pytorch_toolbelt.modules.pooling import GlobalAvgPool2d -from pytorch_toolbelt.utils.torch_utils import count_parameters +from ..abn import ABN +from ..pooling import GlobalAvgPool2d from torch import nn class IdentityResidualBlock(nn.Module): - def __init__( - self, - in_channels, - channels, - stride=1, - dilation=1, - groups=1, - norm_act=ABN, - dropout=None, - ): + def __init__(self, in_channels, channels, stride=1, dilation=1, groups=1, norm_act=ABN, dropout=None): """Identity-mapping residual block Parameters ---------- @@ -57,44 +47,20 @@ def __init__( ( "conv1", nn.Conv2d( - in_channels, - channels[0], - 3, - stride=stride, - padding=dilation, - bias=False, - dilation=dilation, + in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False, dilation=dilation ), ), ("bn2", norm_act(channels[0])), ( "conv2", - nn.Conv2d( - channels[0], - channels[1], - 3, - stride=1, - padding=dilation, - bias=False, - dilation=dilation, - ), + nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False, dilation=dilation), ), ] if dropout is not None: layers = layers[0:2] + [("dropout", dropout())] + layers[2:] else: layers = [ - ( - "conv1", - nn.Conv2d( - in_channels, - channels[0], - 1, - stride=stride, - padding=0, - bias=False, - ), - ), + ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)), ("bn2", norm_act(channels[0])), ( "conv2", @@ -110,21 +76,14 @@ def __init__( ), ), ("bn3", norm_act(channels[1])), - ( - "conv3", - nn.Conv2d( - channels[1], channels[2], 1, stride=1, padding=0, bias=False - ), - ), + ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)), ] if dropout is not None: layers = layers[0:4] + [("dropout", dropout())] + layers[4:] self.convs = nn.Sequential(OrderedDict(layers)) if need_proj_conv: - self.proj_conv = nn.Conv2d( - in_channels, channels[-1], 1, stride=stride, padding=0, bias=False - ) + self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) def forward(self, x): if hasattr(self, "proj_conv"): @@ -161,22 +120,11 @@ def __init__(self, structure, norm_act=ABN, classes=0): raise ValueError("Expected a structure 
with six values") # Initial layers - self.mod1 = nn.Sequential( - OrderedDict( - [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))] - ) - ) + self.mod1 = nn.Sequential(OrderedDict([("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))])) # Groups of residual blocks in_channels = 64 - channels = [ - (128, 128), - (256, 256), - (512, 512), - (512, 1024), - (512, 1024, 2048), - (1024, 2048, 4096), - ] + channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), (1024, 2048, 4096)] for mod_id, num in enumerate(structure): # Create blocks for module blocks = [] @@ -184,9 +132,7 @@ def __init__(self, structure, norm_act=ABN, classes=0): blocks.append( ( "block%d" % (block_id + 1), - IdentityResidualBlock( - in_channels, channels[mod_id], norm_act=norm_act - ), + IdentityResidualBlock(in_channels, channels[mod_id], norm_act=norm_act), ) ) @@ -195,21 +141,14 @@ def __init__(self, structure, norm_act=ABN, classes=0): # Create module if mod_id <= 4: - self.add_module( - "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1) - ) + self.add_module("pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) # Pooling and predictor self.bn_out = norm_act(in_channels) if classes != 0: self.classifier = nn.Sequential( - OrderedDict( - [ - ("avg_pool", GlobalAvgPool2d()), - ("fc", nn.Linear(in_channels, classes)), - ] - ) + OrderedDict([("avg_pool", GlobalAvgPool2d()), ("fc", nn.Linear(in_channels, classes))]) ) def forward(self, img): @@ -253,22 +192,11 @@ def __init__(self, structure, norm_act=ABN, classes=0, dilation=False): raise ValueError("Expected a structure with six values") # Initial layers - self.mod1 = nn.Sequential( - OrderedDict( - [("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))] - ) - ) + self.mod1 = nn.Sequential(OrderedDict([("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))])) # Groups of residual blocks in_channels = 64 - channels = [ - (128, 128), - (256, 256), - (512, 512), - (512, 1024), - (512, 1024, 2048), - (1024, 2048, 4096), - ] + channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), (1024, 2048, 4096)] for mod_id, num in enumerate(structure): # Create blocks for module blocks = [] @@ -296,12 +224,7 @@ def __init__(self, structure, norm_act=ABN, classes=0, dilation=False): ( "block%d" % (block_id + 1), IdentityResidualBlock( - in_channels, - channels[mod_id], - norm_act=norm_act, - stride=stride, - dilation=dil, - dropout=drop, + in_channels, channels[mod_id], norm_act=norm_act, stride=stride, dilation=dil, dropout=drop ), ) ) @@ -311,21 +234,14 @@ def __init__(self, structure, norm_act=ABN, classes=0, dilation=False): # Create module if mod_id < 2: - self.add_module( - "pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1) - ) + self.add_module("pool%d" % (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) # Pooling and predictor self.bn_out = norm_act(in_channels) if classes != 0: self.classifier = nn.Sequential( - OrderedDict( - [ - ("avg_pool", GlobalAvgPool2d()), - ("fc", nn.Linear(in_channels, classes)), - ] - ) + OrderedDict([("avg_pool", GlobalAvgPool2d()), ("fc", nn.Linear(in_channels, classes))]) ) def forward(self, img): @@ -345,53 +261,25 @@ def forward(self, img): def wider_resnet_16(num_classes=0, norm_act=ABN): - return WiderResNet( - structure=[1, 1, 1, 1, 1, 1], norm_act=norm_act, 
classes=num_classes - ) + return WiderResNet(structure=[1, 1, 1, 1, 1, 1], norm_act=norm_act, classes=num_classes) def wider_resnet_20(num_classes=0, norm_act=ABN): - return WiderResNet( - structure=[1, 1, 1, 3, 1, 1], norm_act=norm_act, classes=num_classes - ) + return WiderResNet(structure=[1, 1, 1, 3, 1, 1], norm_act=norm_act, classes=num_classes) def wider_resnet_38(num_classes=0, norm_act=ABN): - return WiderResNet( - structure=[3, 3, 6, 3, 1, 1], norm_act=norm_act, classes=num_classes - ) + return WiderResNet(structure=[3, 3, 6, 3, 1, 1], norm_act=norm_act, classes=num_classes) def wider_resnet_16_a2(num_classes=0, norm_act=ABN): - return WiderResNetA2( - structure=[1, 1, 1, 1, 1, 1], norm_act=norm_act, classes=num_classes - ) + return WiderResNetA2(structure=[1, 1, 1, 1, 1, 1], norm_act=norm_act, classes=num_classes) def wider_resnet_20_a2(num_classes=0, norm_act=ABN): - return WiderResNetA2( - structure=[1, 1, 1, 3, 1, 1], norm_act=norm_act, classes=num_classes - ) + return WiderResNetA2(structure=[1, 1, 1, 3, 1, 1], norm_act=norm_act, classes=num_classes) def wider_resnet_38_a2(num_classes=0, norm_act=ABN): - return WiderResNetA2( - structure=[3, 3, 6, 3, 1, 1], norm_act=norm_act, classes=num_classes - ) - - -def test_wider_resnet(): - for fn in [ - wider_resnet_16_a2, - wider_resnet_16, - wider_resnet_20, - wider_resnet_20_a2, - wider_resnet_38, - wider_resnet_38_a2, - ]: - net = fn().eval() - print(count_parameters(net)) - x = torch.randn((1, 3, 512, 512)) - out = net(x) - for o in out: - print(o.size()) + return WiderResNetA2(structure=[3, 3, 6, 3, 1, 1], norm_act=norm_act, classes=num_classes) + diff --git a/pytorch_toolbelt/modules/coord_conv.py b/pytorch_toolbelt/modules/coord_conv.py index 4e6a0ed22..0ee570c10 100644 --- a/pytorch_toolbelt/modules/coord_conv.py +++ b/pytorch_toolbelt/modules/coord_conv.py @@ -23,19 +23,11 @@ def append_coords(input_tensor, with_r=False): xx_channel = xx_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) yy_channel = yy_channel.repeat(batch_size, 1, 1, 1).transpose(2, 3) - ret = torch.cat( - [ - input_tensor, - xx_channel.type_as(input_tensor), - yy_channel.type_as(input_tensor), - ], - dim=1, - ) + ret = torch.cat([input_tensor, xx_channel.type_as(input_tensor), yy_channel.type_as(input_tensor)], dim=1) if with_r: rr = torch.sqrt( - torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) - + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(xx_channel.type_as(input_tensor) - 0.5, 2) + torch.pow(yy_channel.type_as(input_tensor) - 0.5, 2) ) ret = torch.cat([ret, rr], dim=1) diff --git a/pytorch_toolbelt/modules/decoders/common.py b/pytorch_toolbelt/modules/decoders/common.py index 62ecbc18a..d42e61134 100644 --- a/pytorch_toolbelt/modules/decoders/common.py +++ b/pytorch_toolbelt/modules/decoders/common.py @@ -1,7 +1,7 @@ -__all__ = ["DecoderModule"] - from torch import nn +__all__ = ["DecoderModule", "SegmentationDecoderModule"] + class DecoderModule(nn.Module): def __init__(self): @@ -13,3 +13,11 @@ def forward(self, features): def set_trainable(self, trainable): for param in self.parameters(): param.requires_grad = bool(trainable) + + +class SegmentationDecoderModule(DecoderModule): + """ + A placeholder for future. 
Indicates sub-class decoders are suitable for segmentation tasks + """ + + pass diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index caf5378ae..79ae89c81 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -20,14 +20,7 @@ def __init__(self, feature_maps: List[int], num_classes: int, dropout=0.5): self.relu = nn.ReLU(inplace=True) self.last_conv = nn.Sequential( - nn.Conv2d( - high_level_features + 48, - 256, - kernel_size=3, - stride=1, - padding=1, - bias=False, - ), + nn.Conv2d(high_level_features + 48, 256, kernel_size=3, stride=1, padding=1, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Dropout(dropout), @@ -48,10 +41,7 @@ def forward(self, feature_maps): low_level_feat = self.relu(low_level_feat) high_level_features = F.interpolate( - high_level_features, - size=low_level_feat.size()[2:], - mode="bilinear", - align_corners=True, + high_level_features, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True ) high_level_features = torch.cat((high_level_features, low_level_feat), dim=1) high_level_features = self.last_conv(high_level_features) diff --git a/pytorch_toolbelt/modules/decoders/fpn.py b/pytorch_toolbelt/modules/decoders/fpn.py index 732a26bea..d04030ec8 100644 --- a/pytorch_toolbelt/modules/decoders/fpn.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -34,9 +34,7 @@ def __init__( if isinstance(fpn_features, list) and len(fpn_features) != len(features): raise ValueError() - if isinstance(prediction_features, list) and len(prediction_features) != len( - features - ): + if isinstance(prediction_features, list) and len(prediction_features) != len(features): raise ValueError() if not isinstance(fpn_features, list): @@ -51,19 +49,12 @@ def __init__( ] integrators = [ - upsample_add_block( - output_channels, - upsample_scale=upsample_scale, - mode=mode, - align_corners=align_corners, - ) + upsample_add_block(output_channels, upsample_scale=upsample_scale, mode=mode, align_corners=align_corners) for output_channels in fpn_features ] predictors = [ prediction_block(input_channels, output_channels) - for input_channels, output_channels in zip( - fpn_features, prediction_features - ) + for input_channels, output_channels in zip(fpn_features, prediction_features) ] self.bottlenecks = nn.ModuleList(bottlenecks) @@ -76,10 +67,7 @@ def forward(self, features): fpn_outputs = [] prev_fpn = None for feature_map, bottleneck_module, upsample_add, output_module in zip( - reversed(features), - reversed(self.bottlenecks), - reversed(self.integrators), - reversed(self.predictors), + reversed(features), reversed(self.bottlenecks), reversed(self.integrators), reversed(self.predictors) ): curr_fpn = bottleneck_module(feature_map) curr_fpn = upsample_add(curr_fpn, prev_fpn) diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 7ced2d88c..febf629e2 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -1,36 +1,58 @@ from typing import List, Tuple -import torch -from pytorch_toolbelt.modules import ABN -from pytorch_toolbelt.modules.decoders import DecoderModule -from pytorch_toolbelt.utils.torch_utils import count_parameters from torch import nn, Tensor -from torch.nn import functional as F -from pytorch_toolbelt.modules.fpn import FPNFuse, UpsampleAdd +from .common import SegmentationDecoderModule +from .fpn import FPNDecoder +from ..abn import 
ABN +from ..fpn import FPNFuse, UpsampleAdd __all__ = ["FPNCatDecoder"] -class FPNCatDecoder(DecoderModule): +class FPNSumDecoderBlock(nn.Module): + """ + Simple prediction block composed of (Conv + BN + Activation) repeated twice + """ + + def __init__(self, input_features: int, output_features: int, abn_block=ABN, dropout=0.0): + super().__init__() + self.conv1 = nn.Conv2d(input_features, output_features, kernel_size=3, padding=1, bias=False) + self.abn1 = abn_block(output_features) + self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, padding=1, bias=False) + self.abn2 = abn_block(output_features) + self.drop2 = nn.Dropout(dropout) + + def forward(self, x: Tensor) -> Tensor: + x = self.conv1(x) + x = self.abn1(x) + x = self.conv2(x) + x = self.abn2(x) + x = self.drop2(x) + return x + + +class FPNCatDecoder(SegmentationDecoderModule): """ """ def __init__( - self, - feature_maps: List[int], - num_classes: int, - fpn_channels=128, - dropout=0.0, - abn_block=ABN, + self, + feature_maps: List[int], + num_classes: int, + fpn_channels=128, + dropout=0.0, + abn_block=ABN, + upsample_add=UpsampleAdd, + prediction_block=FPNSumDecoderBlock, ): super().__init__() self.fpn = FPNDecoder( feature_maps, - upsample_add_block=UpsampleAdd, - prediction_block=DoubleConvBNRelu, + upsample_add_block=upsample_add, + prediction_block=prediction_block, fpn_features=fpn_channels, prediction_features=fpn_channels, ) @@ -38,6 +60,7 @@ def __init__( self.fuse = FPNFuse() self.dropout = nn.Dropout2d(dropout, inplace=True) + # dsv blocks are for deep supervision self.dsv = nn.ModuleList( [ nn.Conv2d(fpn_features, num_classes, kernel_size=1) @@ -49,16 +72,11 @@ def __init__( self.final_block = nn.Sequential( nn.Conv2d(features, features // 2, kernel_size=1), - nn.BatchNorm2d(features // 2), - nn.Conv2d( - features // 2, features // 4, kernel_size=3, padding=1, bias=True - ), - nn.LeakyReLU(inplace=True), - nn.BatchNorm2d(features // 4), - nn.Conv2d( - features // 4, features // 4, kernel_size=3, padding=1, bias=False - ), - nn.LeakyReLU(inplace=True), + abn_block(features // 2), + nn.Conv2d(features // 2, features // 4, kernel_size=3, padding=1, bias=False), + abn_block(features // 2), + nn.Conv2d(features // 4, features // 4, kernel_size=3, padding=1, bias=False), + abn_block(features // 4), nn.Conv2d(features // 4, num_classes, kernel_size=1, bias=True), ) @@ -75,4 +93,3 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: x = self.final_block(fused) return x, dsv_masks - diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index d706ae2f9..ab5126ecb 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -2,56 +2,35 @@ from typing import List, Tuple import torch -from pytorch_toolbelt.modules import Identity, ABN -from .modules.decoders import DecoderModule -from pytorch_toolbelt.utils.torch_utils import count_parameters +from ..abn import ABN +from ..identity import Identity +from .common import SegmentationDecoderModule from torch import Tensor, nn import torch.nn.functional as F -__all__ = ["FPNSumDecoder", "FPNSumTransitionBlock", "FPNSumCenterBlock"] +__all__ = ["FPNSumDecoder", "FPNSumDecoderBlock", "FPNSumCenterBlock"] class FPNSumCenterBlock(nn.Module): - def __init__( - self, - encoder_features: int, - decoder_features: int, - num_classes: int, - abn_block=ABN, - dropout=0.0, - ): + def __init__(self, encoder_features: int, decoder_features: int, 
num_classes: int, abn_block=ABN, dropout=0.0): super().__init__() - self.bottleneck = nn.Conv2d( - encoder_features, encoder_features // 2, kernel_size=1 - ) + self.bottleneck = nn.Conv2d(encoder_features, encoder_features // 2, kernel_size=1) self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) - self.proj2 = nn.Conv2d( - encoder_features // 2, encoder_features // 8, kernel_size=1 - ) + self.proj2 = nn.Conv2d(encoder_features // 2, encoder_features // 8, kernel_size=1) self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4) - self.proj4 = nn.Conv2d( - encoder_features // 2, encoder_features // 8, kernel_size=1 - ) + self.proj4 = nn.Conv2d(encoder_features // 2, encoder_features // 8, kernel_size=1) self.pool8 = nn.AvgPool2d(kernel_size=8, stride=8) - self.proj8 = nn.Conv2d( - encoder_features // 2, encoder_features // 8, kernel_size=1 - ) + self.proj8 = nn.Conv2d(encoder_features // 2, encoder_features // 8, kernel_size=1) - self.blend = nn.Conv2d( - encoder_features // 2 + 3 * encoder_features // 8, - decoder_features, - kernel_size=1, - ) + self.blend = nn.Conv2d(encoder_features // 2 + 3 * encoder_features // 8, decoder_features, kernel_size=1) self.dropout = nn.Dropout2d(dropout, inplace=True) - self.conv1 = nn.Conv2d( - decoder_features, decoder_features, kernel_size=3, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(decoder_features, decoder_features, kernel_size=3, padding=1, bias=False) self.abn1 = abn_block(decoder_features) self.dsv = nn.Conv2d(decoder_features, num_classes, kernel_size=1) @@ -85,7 +64,7 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: return x, dsv -class FPNSumTransitionBlock(nn.Module): +class FPNSumDecoderBlock(nn.Module): def __init__( self, encoder_features: int, @@ -103,9 +82,7 @@ def __init__( self.reduction = nn.Conv2d(decoder_features, output_features, kernel_size=1) self.dropout = nn.Dropout2d(dropout, inplace=True) - self.conv1 = nn.Conv2d( - output_features, output_features, kernel_size=3, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(output_features, output_features, kernel_size=3, padding=1, bias=False) self.abn1 = abn_block(output_features) self.dsv = nn.Conv2d(output_features, num_classes, kernel_size=1) @@ -117,9 +94,7 @@ def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Tuple[Tensor, Tenso :param encoder_fm: :return: """ - decoder_fm = F.interpolate( - decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=True - ) + decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=True) encoder_fm = self.skip(encoder_fm) x = decoder_fm + encoder_fm @@ -135,48 +110,30 @@ def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Tuple[Tensor, Tenso return x, dsv -class FPNSumDecoder(DecoderModule): +class FPNSumDecoder(SegmentationDecoderModule): """ """ - def __init__( - self, - feature_maps: List[int], - num_classes: int, - fpn_channels=256, - dropout=0.0, - abn_block=ABN, - ): + def __init__(self, feature_maps: List[int], num_classes: int, fpn_channels=256, dropout=0.0, abn_block=ABN, + center_block=FPNSumCenterBlock, + decoder_block=FPNSumDecoderBlock): super().__init__() - self.center = FPNSumCenterBlock( - feature_maps[-1], - fpn_channels, - num_classes=num_classes, - dropout=dropout, - abn_block=abn_block, + self.center = center_block( + feature_maps[-1], fpn_channels, num_classes=num_classes, dropout=dropout, abn_block=abn_block ) self.fpn_modules = nn.ModuleList( [ - FPNSumTransitionBlock( - encoder_fm, - decoder_fm, - decoder_fm, - num_classes=num_classes, - 
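Each `FPNSumDecoderBlock` fuses a coarse decoder map with an encoder skip by resizing and summing before the refinement convolution; the heart of that step, written out with hypothetical names:

    import torch.nn.functional as F

    def sum_fuse(decoder_fm, encoder_fm, skip_conv):
        # Bring the decoder map to the skip's spatial size, project the skip to the same width, then add
        decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.shape[2:], mode="bilinear", align_corners=True)
        return decoder_fm + skip_conv(encoder_fm)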
dropout=dropout, - abn_block=abn_block, - ) - for decoder_fm, encoder_fm in zip( - repeat(fpn_channels), reversed(feature_maps[:-1]) + decoder_block( + encoder_fm, decoder_fm, decoder_fm, num_classes=num_classes, dropout=dropout, abn_block=abn_block ) + for decoder_fm, encoder_fm in zip(repeat(fpn_channels), reversed(feature_maps[:-1])) ] ) - self.final_block = nn.Sequential( - nn.Conv2d(fpn_channels, num_classes, kernel_size=1) - ) + self.final_block = nn.Sequential(nn.Conv2d(fpn_channels, num_classes, kernel_size=1)) def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: last_feature_map = feature_maps[-1] @@ -193,4 +150,3 @@ def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: x = self.final_block(x) return x, dsv_masks - diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index 04b10e728..909606659 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -11,23 +11,11 @@ def __init__(self, features: int, num_classes: int, dropout=0.0): super().__init__() self.last_layer = nn.Sequential( - nn.Conv2d( - in_channels=features, - out_channels=features, - kernel_size=1, - stride=1, - padding=0, - ), + nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0), nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Dropout(dropout), - nn.Conv2d( - in_channels=features, - out_channels=num_classes, - kernel_size=3, - stride=1, - padding=1, - ), + nn.Conv2d(in_channels=features, out_channels=num_classes, kernel_size=3, stride=1, padding=1), ) def forward(self, features): diff --git a/pytorch_toolbelt/modules/decoders/pyramid_pooling.py b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py index 023858fcd..131eb1626 100644 --- a/pytorch_toolbelt/modules/decoders/pyramid_pooling.py +++ b/pytorch_toolbelt/modules/decoders/pyramid_pooling.py @@ -15,13 +15,7 @@ class PPMDecoder(DecoderModule): https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/42b7567a43b1dab568e2bbfcbc8872778fbda92a/models/models.py """ - def __init__( - self, - feature_maps: List[int], - num_classes=150, - channels=512, - pool_scales=(1, 2, 3, 6), - ): + def __init__(self, feature_maps: List[int], num_classes=150, channels=512, pool_scales=(1, 2, 3, 6)): super(PPMDecoder, self).__init__() fc_dim = feature_maps[-1] @@ -38,13 +32,7 @@ def __init__( self.ppm = nn.ModuleList(self.ppm) self.conv_last = nn.Sequential( - nn.Conv2d( - fc_dim + len(pool_scales) * channels, - channels, - kernel_size=3, - padding=1, - bias=False, - ), + nn.Conv2d(fc_dim + len(pool_scales) * channels, channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(channels), nn.ReLU(inplace=True), nn.Dropout2d(0.1), @@ -58,9 +46,7 @@ def forward(self, feature_maps: List[torch.Tensor]): ppm_out = [last_fm] for pool_scale in self.ppm: input_pooled = pool_scale(last_fm) - input_pooled = F.interpolate( - input_pooled, size=input_size[2:], mode="bilinear", align_corners=False - ) + input_pooled = F.interpolate(input_pooled, size=input_size[2:], mode="bilinear", align_corners=False) ppm_out.append(input_pooled) ppm_out = torch.cat(ppm_out, dim=1) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index d3e14599a..b8137feb2 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -13,19 +13,9 @@ class UnetCentralBlock(nn.Module): def __init__(self, in_dec_filters, 
out_filters, abn_block=ABN, **kwargs): super().__init__() - self.conv1 = nn.Conv2d( - in_dec_filters, - out_filters, - kernel_size=3, - padding=1, - stride=2, - bias=False, - **kwargs - ) + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs - ) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) self.bn2 = abn_block(out_filters) def forward(self, x): @@ -48,31 +38,17 @@ def __init__( abn_block=ABN, pre_dropout_rate=0.0, post_dropout_rate=0.0, - **kwargs + **kwargs, ): super(UnetDecoderBlock, self).__init__() self.pre_drop = nn.Dropout(pre_dropout_rate, inplace=True) self.conv1 = nn.Conv2d( - in_dec_filters + in_enc_filters, - out_filters, - kernel_size=3, - stride=1, - padding=1, - bias=False, - **kwargs + in_dec_filters + in_enc_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False, **kwargs ) self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d( - out_filters, - out_filters, - kernel_size=3, - stride=1, - padding=1, - bias=False, - **kwargs - ) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False, **kwargs) self.bn2 = abn_block(out_filters) self.post_drop = nn.Dropout(post_dropout_rate, inplace=True) @@ -137,30 +113,21 @@ def forward(self, x, enc): class UNetDecoder(DecoderModule): - def __init__( - self, feature_maps: List[int], decoder_features: int, mask_channels: int - ): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int): super().__init__() if not isinstance(decoder_features, list): - decoder_features = [ - decoder_features * (2 ** i) for i in range(len(feature_maps)) - ] + decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] blocks = [] for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlock( - decoder_features[block_index + 1], - in_enc_features, - decoder_features[block_index], - mask_channels, + decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels ) ) - self.center = UnetCentralBlock( - feature_maps[-1], decoder_features[-1], mask_channels - ) + self.center = UnetCentralBlock(feature_maps[-1], decoder_features[-1], mask_channels) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features @@ -170,9 +137,7 @@ def forward(self, feature_maps): decoder_outputs = [output] dsv_list = [dsv] - for decoder_block, encoder_output in zip( - reversed(self.blocks), reversed(feature_maps[:-1]) - ): + for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): output, dsv = decoder_block(output, encoder_output) decoder_outputs.append(output) dsv_list.append(dsv) diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index b0f1053a2..a4698c0a1 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -15,13 +15,9 @@ def __init__(self, in_dec_filters, out_filters, mask_channels, abn_block=ABN): super().__init__() self.bottleneck = nn.Conv2d(in_dec_filters, out_filters, kernel_size=1) - self.conv1 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False - ) + self.conv1 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False) self.abn1 
= abn_block(out_filters) - self.conv2 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, padding=1, bias=False - ) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False) self.abn2 = abn_block(out_filters) self.dsv = nn.Conv2d(out_filters, mask_channels, kernel_size=1) @@ -54,17 +50,11 @@ def __init__( ): super(UnetDecoderBlockV2, self).__init__() - self.bottleneck = nn.Conv2d( - in_dec_filters + in_enc_filters, out_filters, kernel_size=1 - ) + self.bottleneck = nn.Conv2d(in_dec_filters + in_enc_filters, out_filters, kernel_size=1) - self.conv1 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False) self.abn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d( - out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False - ) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False) self.abn2 = abn_block(out_filters) self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) @@ -98,24 +88,17 @@ def __init__(self, features: List[int], decoder_features: int, mask_channels: in super().__init__() if not isinstance(decoder_features, list): - decoder_features = [ - decoder_features * (2 ** i) for i in range(len(features)) - ] + decoder_features = [decoder_features * (2 ** i) for i in range(len(features))] blocks = [] for block_index, in_enc_features in enumerate(features[:-1]): blocks.append( UnetDecoderBlockV2( - decoder_features[block_index + 1], - in_enc_features, - decoder_features[block_index], - mask_channels, + decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels ) ) - self.center = UnetCentralBlockV2( - features[-1], decoder_features[-1], mask_channels - ) + self.center = UnetCentralBlockV2(features[-1], decoder_features[-1], mask_channels) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features @@ -125,9 +108,7 @@ def forward(self, feature_maps): decoder_outputs = [output] dsv_list = [dsv] - for decoder_block, encoder_output in zip( - reversed(self.blocks), reversed(feature_maps[:-1]) - ): + for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): output, dsv = decoder_block(output, encoder_output) decoder_outputs.append(output) dsv_list.append(dsv) diff --git a/pytorch_toolbelt/modules/decoders/upernet.py b/pytorch_toolbelt/modules/decoders/upernet.py index 1c23c1c73..9506e3fbc 100644 --- a/pytorch_toolbelt/modules/decoders/upernet.py +++ b/pytorch_toolbelt/modules/decoders/upernet.py @@ -9,22 +9,14 @@ def conv3x3_bn_relu(in_planes, out_planes, stride=1): "3x3 convolution + BN + relu" return nn.Sequential( - nn.Conv2d( - in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False - ), + nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_planes), nn.ReLU(inplace=True), ) class UPerNet(nn.Module): - def __init__( - self, - output_filters: List[int], - num_classes=150, - pool_scales=(1, 2, 3, 6), - fpn_dim=256, - ): + def __init__(self, output_filters: List[int], num_classes=150, pool_scales=(1, 2, 3, 6), fpn_dim=256): super(UPerNet, self).__init__() last_fm_dim = output_filters[-1] @@ -37,16 +29,12 @@ def __init__( self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) self.ppm_conv.append( nn.Sequential( - nn.Conv2d(last_fm_dim, 512, kernel_size=1, bias=False), - nn.BatchNorm2d(512), - 
nn.ReLU(inplace=True), + nn.Conv2d(last_fm_dim, 512, kernel_size=1, bias=False), nn.BatchNorm2d(512), nn.ReLU(inplace=True) ) ) self.ppm_pooling = nn.ModuleList(self.ppm_pooling) self.ppm_conv = nn.ModuleList(self.ppm_conv) - self.ppm_last_conv = conv3x3_bn_relu( - last_fm_dim + len(pool_scales) * 512, fpn_dim, 1 - ) + self.ppm_last_conv = conv3x3_bn_relu(last_fm_dim + len(pool_scales) * 512, fpn_dim, 1) # FPN Module self.fpn_in = [] @@ -66,8 +54,7 @@ def __init__( self.fpn_out = nn.ModuleList(self.fpn_out) self.conv_last = nn.Sequential( - conv3x3_bn_relu(len(output_filters) * fpn_dim, fpn_dim, 1), - nn.Conv2d(fpn_dim, num_classes, kernel_size=1), + conv3x3_bn_relu(len(output_filters) * fpn_dim, fpn_dim, 1), nn.Conv2d(fpn_dim, num_classes, kernel_size=1) ) def forward(self, feature_maps): @@ -79,10 +66,7 @@ def forward(self, feature_maps): ppm_out.append( pool_conv( F.interpolate( - pool_scale(last_fm), - (input_size[2], input_size[3]), - mode="bilinear", - align_corners=False, + pool_scale(last_fm), (input_size[2], input_size[3]), mode="bilinear", align_corners=False ) ) ) @@ -94,9 +78,7 @@ def forward(self, feature_maps): conv_x = feature_maps[i] conv_x = self.fpn_in[i](conv_x) # lateral branch - f = F.interpolate( - f, size=conv_x.size()[2:], mode="bilinear", align_corners=False - ) # top-down branch + f = F.interpolate(f, size=conv_x.size()[2:], mode="bilinear", align_corners=False) # top-down branch f = conv_x + f fpn_feature_list.append(self.fpn_out[i](f)) @@ -105,14 +87,7 @@ def forward(self, feature_maps): output_size = fpn_feature_list[0].size()[2:] fusion_list = [fpn_feature_list[0]] for i in range(1, len(fpn_feature_list)): - fusion_list.append( - F.interpolate( - fpn_feature_list[i], - output_size, - mode="bilinear", - align_corners=False, - ) - ) + fusion_list.append(F.interpolate(fpn_feature_list[i], output_size, mode="bilinear", align_corners=False)) fusion_out = torch.cat(fusion_list, 1) x = self.conv_last(fusion_out) diff --git a/pytorch_toolbelt/modules/dropblock.py b/pytorch_toolbelt/modules/dropblock.py index 5a196743d..812f8584c 100644 --- a/pytorch_toolbelt/modules/dropblock.py +++ b/pytorch_toolbelt/modules/dropblock.py @@ -30,9 +30,7 @@ def __init__(self, drop_prob, block_size): def forward(self, x): # shape: (bsize, channels, height, width) - assert ( - x.dim() == 4 - ), "Expected input with 4 dimensions (bsize, channels, height, width)" + assert x.dim() == 4, "Expected input with 4 dimensions (bsize, channels, height, width)" if not self.training or self.drop_prob == 0.0: return x @@ -64,9 +62,7 @@ def _compute_block_mask(self, mask): if self.block_size % 2 == 0: block_mask = block_mask[:, :, :-1, :-1] - keeped = block_mask.numel() - block_mask.sum().to( - torch.float32 - ) # prevent overflow in float16 + keeped = block_mask.numel() - block_mask.sum().to(torch.float32) # prevent overflow in float16 block_mask = 1 - block_mask.squeeze(1) return block_mask, keeped @@ -97,9 +93,7 @@ def __init__(self, drop_prob, block_size): def forward(self, x): # shape: (bsize, channels, depth, height, width) - assert ( - x.dim() == 5 - ), "Expected input with 5 dimensions (bsize, channels, depth, height, width)" + assert x.dim() == 5, "Expected input with 5 dimensions (bsize, channels, depth, height, width)" if not self.training or self.drop_prob == 0.0: return x diff --git a/pytorch_toolbelt/modules/dsconv.py b/pytorch_toolbelt/modules/dsconv.py index 3cf254373..febdc29ec 100644 --- a/pytorch_toolbelt/modules/dsconv.py +++ b/pytorch_toolbelt/modules/dsconv.py @@ -4,17 +4,7 @@ 
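The UPerNet forward pass above repeats one top-down step per pyramid level: upsample the coarser map to the lateral's resolution and add it before the 3x3 smoothing convolution. In isolation (the helper name is assumed, not from the patch):

    import torch.nn.functional as F

    def top_down_step(coarse, lateral):
        # FPN-style merge: resize the coarse map to the lateral resolution, then sum
        coarse = F.interpolate(coarse, size=lateral.shape[2:], mode="bilinear", align_corners=False)
        return lateral + coarse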
class DepthwiseSeparableConv2d(nn.Module): - def __init__( - self, - in_channels, - out_channels, - kernel_size, - stride=1, - padding=0, - dilation=1, - groups=1, - bias=True, - ): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): super(DepthwiseSeparableConv2d, self).__init__() self.depthwise = nn.Conv2d( in_channels, @@ -26,9 +16,7 @@ def __init__( bias=bias, groups=in_channels, ) - self.pointwise = nn.Conv2d( - in_channels, out_channels, kernel_size=1, groups=groups, bias=bias - ) + self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, groups=groups, bias=bias) def forward(self, x): out = self.depthwise(x) diff --git a/pytorch_toolbelt/modules/encoders/densenet.py b/pytorch_toolbelt/modules/encoders/densenet.py index 537ea0539..9583a04e6 100644 --- a/pytorch_toolbelt/modules/encoders/densenet.py +++ b/pytorch_toolbelt/modules/encoders/densenet.py @@ -1,33 +1,16 @@ from typing import List from torch import nn -from torchvision.models import ( - densenet121, - densenet161, - densenet169, - densenet201, - DenseNet, -) +from torchvision.models import densenet121, densenet161, densenet169, densenet201, DenseNet from .common import EncoderModule, _take -__all__ = [ - "DenseNetEncoder", - "DenseNet121Encoder", - "DenseNet169Encoder", - "DenseNet161Encoder", - "DenseNet201Encoder", -] +__all__ = ["DenseNetEncoder", "DenseNet121Encoder", "DenseNet169Encoder", "DenseNet161Encoder", "DenseNet201Encoder"] class DenseNetEncoder(EncoderModule): def __init__( - self, - densenet: DenseNet, - strides: List[int], - channels: List[int], - layers: List[int], - first_avg_pool=False, + self, densenet: DenseNet, strides: List[int], channels: List[int], layers: List[int], first_avg_pool=False ): if layers is None: layers = [1, 2, 3, 4] @@ -38,24 +21,16 @@ def except_pool(block: nn.Module): del block.pool return block - self.layer0 = nn.Sequential( - densenet.features.conv0, densenet.features.norm0, densenet.features.relu0 - ) + self.layer0 = nn.Sequential(densenet.features.conv0, densenet.features.norm0, densenet.features.relu0) self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) self.pool0 = self.avg_pool if first_avg_pool else densenet.features.pool0 - self.layer1 = nn.Sequential( - densenet.features.denseblock1, except_pool(densenet.features.transition1) - ) + self.layer1 = nn.Sequential(densenet.features.denseblock1, except_pool(densenet.features.transition1)) - self.layer2 = nn.Sequential( - densenet.features.denseblock2, except_pool(densenet.features.transition2) - ) + self.layer2 = nn.Sequential(densenet.features.denseblock2, except_pool(densenet.features.transition2)) - self.layer3 = nn.Sequential( - densenet.features.denseblock3, except_pool(densenet.features.transition3) - ) + self.layer3 = nn.Sequential(densenet.features.denseblock3, except_pool(densenet.features.transition3)) self.layer4 = nn.Sequential(densenet.features.denseblock4) @@ -94,9 +69,7 @@ def forward(self, x): class DenseNet121Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False - ): + def __init__(self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False): densenet = densenet121(pretrained=pretrained, memory_efficient=memory_efficient) strides = [2, 4, 8, 16, 32] channels = [64, 128, 256, 512, 1024] @@ -104,9 +77,7 @@ def __init__( class DenseNet161Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, 
first_avg_pool=False - ): + def __init__(self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False): densenet = densenet161(pretrained=pretrained, memory_efficient=memory_efficient) strides = [2, 4, 8, 16, 32] channels = [96, 192, 384, 1056, 2208] @@ -114,9 +85,7 @@ def __init__( class DenseNet169Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False - ): + def __init__(self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False): densenet = densenet169(pretrained=pretrained, memory_efficient=memory_efficient) strides = [2, 4, 8, 16, 32] channels = [64, 128, 256, 640, 1664] @@ -124,9 +93,7 @@ def __init__( class DenseNet201Encoder(DenseNetEncoder): - def __init__( - self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False - ): + def __init__(self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False): densenet = densenet201(pretrained=pretrained, memory_efficient=memory_efficient) strides = [2, 4, 8, 16, 32] channels = [64, 128, 256, 896, 1920] diff --git a/pytorch_toolbelt/modules/encoders/efficientnet.py b/pytorch_toolbelt/modules/encoders/efficientnet.py index a8b02dc7d..6767063a4 100644 --- a/pytorch_toolbelt/modules/encoders/efficientnet.py +++ b/pytorch_toolbelt/modules/encoders/efficientnet.py @@ -43,15 +43,7 @@ def __init__(self, efficientnet, filters, strides, layers): @property def encoder_layers(self): - return [ - self.block0, - self.block1, - self.block2, - self.block3, - self.block4, - self.block5, - self.block6, - ] + return [self.block0, self.block1, self.block2, self.block3, self.block4, self.block5, self.block6] def forward(self, x): input = self.stem(x) diff --git a/pytorch_toolbelt/modules/encoders/mobilenet.py b/pytorch_toolbelt/modules/encoders/mobilenet.py index b78eb0bbe..f94660bf5 100644 --- a/pytorch_toolbelt/modules/encoders/mobilenet.py +++ b/pytorch_toolbelt/modules/encoders/mobilenet.py @@ -7,9 +7,7 @@ class MobilenetV2Encoder(EncoderModule): def __init__(self, layers=[2, 3, 5, 7], activation="relu6"): - super().__init__( - [32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers - ) + super().__init__([32, 16, 24, 32, 64, 96, 160, 320], [2, 2, 4, 8, 16, 16, 32, 32], layers) encoder = MobileNetV2(activation=activation) self.layer0 = encoder.layer0 @@ -23,30 +21,13 @@ def __init__(self, layers=[2, 3, 5, 7], activation="relu6"): @property def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - self.layer7, - ] + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.layer5, self.layer6, self.layer7] class MobilenetV3Encoder(EncoderModule): - def __init__( - self, input_channels=3, small=False, drop_prob=0.0, layers=[1, 2, 3, 4] - ): - super().__init__( - [24, 24, 40, 96, 96] if small else [16, 40, 80, 160, 160], - [4, 8, 16, 32, 32], - layers, - ) - encoder = MobileNetV3( - in_channels=input_channels, small=small, drop_prob=drop_prob - ) + def __init__(self, input_channels=3, small=False, drop_prob=0.0, layers=[1, 2, 3, 4]): + super().__init__([24, 24, 40, 96, 96] if small else [16, 40, 80, 160, 160], [4, 8, 16, 32, 32], layers) + encoder = MobileNetV3(in_channels=input_channels, small=small, drop_prob=drop_prob) self.conv1 = encoder.conv1 self.bn1 = encoder.bn1 diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py index 
e32af6efd..33a56eb0b 100644 --- a/pytorch_toolbelt/modules/encoders/resnet.py +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -26,11 +26,7 @@ def __init__(self, resnet, filters, strides, layers=None): layers = [1, 2, 3, 4] super().__init__(filters, strides, layers) - self.layer0 = nn.Sequential( - OrderedDict( - [("conv1", resnet.conv1), ("bn1", resnet.bn1), ("relu", resnet.relu)] - ) - ) + self.layer0 = nn.Sequential(OrderedDict([("conv1", resnet.conv1), ("bn1", resnet.bn1), ("relu", resnet.relu)])) self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 @@ -60,49 +56,24 @@ def forward(self, x): class Resnet18Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet18(pretrained=pretrained), - [64, 64, 128, 256, 512], - [2, 4, 8, 16, 32], - layers, - ) + super().__init__(resnet18(pretrained=pretrained), [64, 64, 128, 256, 512], [2, 4, 8, 16, 32], layers) class Resnet34Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet34(pretrained=pretrained), - [64, 64, 128, 256, 512], - [2, 4, 8, 16, 32], - layers, - ) + super().__init__(resnet34(pretrained=pretrained), [64, 64, 128, 256, 512], [2, 4, 8, 16, 32], layers) class Resnet50Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet50(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) + super().__init__(resnet50(pretrained=pretrained), [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) class Resnet101Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet101(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) + super().__init__(resnet101(pretrained=pretrained), [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) class Resnet152Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): - super().__init__( - resnet152(pretrained=pretrained), - [64, 256, 512, 1024, 2048], - [2, 4, 8, 16, 32], - layers, - ) + super().__init__(resnet152(pretrained=pretrained), [64, 256, 512, 1024, 2048], [2, 4, 8, 16, 32], layers) diff --git a/pytorch_toolbelt/modules/encoders/squeezenet.py b/pytorch_toolbelt/modules/encoders/squeezenet.py index 9b66e5baf..7e5814a4a 100644 --- a/pytorch_toolbelt/modules/encoders/squeezenet.py +++ b/pytorch_toolbelt/modules/encoders/squeezenet.py @@ -46,10 +46,7 @@ def __init__(self, pretrained=True, layers=[1, 2, 3]): # Fire(384, 64, 256, 256), # Fire(512, 64, 256, 256), self.layer3 = nn.Sequential( - squeezenet.features[9], - squeezenet.features[10], - squeezenet.features[11], - squeezenet.features[12], + squeezenet.features[9], squeezenet.features[10], squeezenet.features[11], squeezenet.features[12] ) @property diff --git a/pytorch_toolbelt/modules/encoders/unet.py b/pytorch_toolbelt/modules/encoders/unet.py index 28703834b..fd04cf97e 100644 --- a/pytorch_toolbelt/modules/encoders/unet.py +++ b/pytorch_toolbelt/modules/encoders/unet.py @@ -11,25 +11,9 @@ class UnetEncoderBlock(nn.Module): def __init__(self, in_dec_filters, out_filters, abn_block=ABN, stride=1, **kwargs): super().__init__() - self.conv1 = nn.Conv2d( - in_dec_filters, - out_filters, - kernel_size=3, - padding=1, - stride=1, - bias=False, - **kwargs, - ) + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=1, bias=False, **kwargs) self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d( - out_filters, - out_filters, - kernel_size=3, - 
padding=1, - stride=stride, - bias=False, - **kwargs, - ) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=stride, bias=False, **kwargs) self.bn2 = abn_block(out_filters) def forward(self, x): @@ -41,14 +25,7 @@ def forward(self, x): class UnetEncoder(EncoderModule): - def __init__( - self, - input_channels=3, - features=32, - num_layers=4, - growth_factor=2, - abn_block=ABN, - ): + def __init__(self, input_channels=3, features=32, num_layers=4, growth_factor=2, abn_block=ABN): feature_maps = [features * growth_factor * (i + 1) for i in range(num_layers)] strides = [2 * (i + 1) for i in range(num_layers)] super().__init__(feature_maps, strides, layers=list(range(num_layers))) diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py index 8ccdb1ad0..d9d611ff3 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -1,7 +1,7 @@ from typing import List -from ..modules.abn import ABN -from pytorch_toolbelt.modules.backbone.wider_resnet import WiderResNet, WiderResNetA2 +from ..abn import ABN +from ..backbone.wider_resnet import WiderResNet, WiderResNetA2 from .common import EncoderModule, _take @@ -19,9 +19,7 @@ class WiderResnetEncoder(EncoderModule): def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): - super().__init__( - [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers - ) + super().__init__([64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers) encoder = WiderResNet(structure, classes=0, norm_act=norm_act) self.layer0 = encoder.mod1 @@ -40,15 +38,7 @@ def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): @property def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - ] + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.layer5, self.layer6] def forward(self, input): output_features = [] @@ -101,9 +91,7 @@ def __init__(self, layers=None): class WiderResnetA2Encoder(EncoderModule): def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): - super().__init__( - [64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers - ) + super().__init__([64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers) encoder = WiderResNetA2(structure=structure, classes=0, norm_act=norm_act) self.layer0 = encoder.mod1 @@ -119,15 +107,7 @@ def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): @property def encoder_layers(self): - return [ - self.layer0, - self.layer1, - self.layer2, - self.layer3, - self.layer4, - self.layer5, - self.layer6, - ] + return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.layer5, self.layer6] def forward(self, input): output_features = [] diff --git a/pytorch_toolbelt/modules/fpn.py b/pytorch_toolbelt/modules/fpn.py index 91b722fbe..43fc91cbd 100644 --- a/pytorch_toolbelt/modules/fpn.py +++ b/pytorch_toolbelt/modules/fpn.py @@ -28,9 +28,7 @@ def forward(self, x): class FPNBottleneckBlockBN(nn.Module): def __init__(self, input_channels, output_channels): super().__init__() - self.conv = nn.Conv2d( - input_channels, output_channels, kernel_size=1, bias=False - ) + self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=1, bias=False) self.bn = nn.BatchNorm2d(output_channels) def forward(self, x): @@ -39,15 +37,11 @@ def forward(self, 
x): class FPNPredictionBlock(nn.Module): - def __init__(self, input_channels, output_channels, mode="nearest", align_corners=None): + def __init__(self, input_channels: int, output_channels: int): super().__init__() self.input_channels = input_channels self.output_channels = output_channels - self.conv = nn.Conv2d( - self.input_channels, self.output_channels, kernel_size=3, padding=1 - ) - self.mode = mode - self.align_corners = align_corners + self.conv = nn.Conv2d(self.input_channels, self.output_channels, kernel_size=3, padding=1) def forward(self, x): x = self.conv(x) @@ -59,9 +53,7 @@ class UpsampleAdd(nn.Module): Compute pixelwise sum of first tensor and upsampled second tensor. """ - def __init__( - self, filters: int, upsample_scale=None, mode="nearest", align_corners=None - ): + def __init__(self, filters: int, upsample_scale=None, mode="nearest", align_corners=None): super().__init__() self.interpolation_mode = mode self.upsample_scale = upsample_scale @@ -71,17 +63,11 @@ def forward(self, x, y=None): if y is not None: if self.upsample_scale is not None: y = F.interpolate( - y, - scale_factor=self.upsample_scale, - mode=self.interpolation_mode, - align_corners=self.align_corners, + y, scale_factor=self.upsample_scale, mode=self.interpolation_mode, align_corners=self.align_corners ) else: y = F.interpolate( - y, - size=(x.size(2), x.size(3)), - mode=self.interpolation_mode, - align_corners=self.align_corners, + y, size=(x.size(2), x.size(3)), mode=self.interpolation_mode, align_corners=self.align_corners ) x = x + y @@ -95,9 +81,7 @@ class UpsampleAddConv(nn.Module): to smooth aliasing artifacts """ - def __init__( - self, filters: int, upsample_scale=None, mode="nearest", align_corners=None - ): + def __init__(self, filters: int, upsample_scale=None, mode="nearest", align_corners=None): super().__init__() self.interpolation_mode = mode self.upsample_scale = upsample_scale @@ -108,17 +92,11 @@ def forward(self, x, y=None): if y is not None: if self.upsample_scale is not None: y = F.interpolate( - y, - scale_factor=self.upsample_scale, - mode=self.interpolation_mode, - align_corners=self.align_corners, + y, scale_factor=self.upsample_scale, mode=self.interpolation_mode, align_corners=self.align_corners ) else: y = F.interpolate( - y, - size=(x.size(2), x.size(3)), - mode=self.interpolation_mode, - align_corners=self.align_corners, + y, size=(x.size(2), x.size(3)), mode=self.interpolation_mode, align_corners=self.align_corners ) x = x + y @@ -138,11 +116,7 @@ def forward(self, features): dst_size = features[0].size()[-2:] for f in features: - layers.append( - F.interpolate( - f, size=dst_size, mode=self.mode, align_corners=self.align_corners - ) - ) + layers.append(F.interpolate(f, size=dst_size, mode=self.mode, align_corners=self.align_corners)) return torch.cat(layers, dim=1) @@ -160,9 +134,7 @@ def forward(self, features): dst_size = features[0].size()[-2:] for f in features[1:]: - output = output + F.interpolate( - f, size=dst_size, mode=self.mode, align_corners=self.align_corners - ) + output = output + F.interpolate(f, size=dst_size, mode=self.mode, align_corners=self.align_corners) return output @@ -180,9 +152,7 @@ class HFF(nn.Module): >>> feature_map = feature_map_0 + up(feature_map[1] + up(feature_map[2] + up(feature_map[3] + ...)))) """ - def __init__( - self, sizes=None, upsample_scale=2, mode="nearest", align_corners=None - ): + def __init__(self, sizes=None, upsample_scale=2, mode="nearest", align_corners=None): super().__init__() self.sizes = sizes 
self.interpolation_mode = mode @@ -195,9 +165,7 @@ def forward(self, features): current_map = features[-1] for feature_map_index in reversed(range(num_feature_maps - 1)): if self.sizes is not None: - prev_upsampled = self._upsample( - current_map, self.sizes[feature_map_index] - ) + prev_upsampled = self._upsample(current_map, self.sizes[feature_map_index]) else: prev_upsampled = self._upsample(current_map) @@ -215,9 +183,6 @@ def _upsample(self, x, output_size=None): ) else: x = F.interpolate( - x, - scale_factor=self.upsample_scale, - mode=self.interpolation_mode, - align_corners=self.align_corners, + x, scale_factor=self.upsample_scale, mode=self.interpolation_mode, align_corners=self.align_corners ) return x diff --git a/pytorch_toolbelt/modules/pooling.py b/pytorch_toolbelt/modules/pooling.py index ec845e569..47b7d405c 100644 --- a/pytorch_toolbelt/modules/pooling.py +++ b/pytorch_toolbelt/modules/pooling.py @@ -5,13 +5,7 @@ import torch.nn as nn import torch.nn.functional as F -__all__ = [ - "GlobalAvgPool2d", - "GlobalMaxPool2d", - "GWAP", - "RMSPool", - "MILCustomPoolingModule", -] +__all__ = ["GlobalAvgPool2d", "GlobalMaxPool2d", "GWAP", "RMSPool", "MILCustomPoolingModule"] class GlobalAvgPool2d(nn.Module): @@ -97,7 +91,5 @@ def __init__(self, in_channels, out_channels, reduction=4): def forward(self, x): weight = self.weight_generator(x) loss = self.classifier(x) - logits = torch.sum(weight * loss, dim=[2, 3]) / ( - torch.sum(weight, dim=[2, 3]) + 1e-6 - ) + logits = torch.sum(weight * loss, dim=[2, 3]) / (torch.sum(weight, dim=[2, 3]) + 1e-6) return logits diff --git a/pytorch_toolbelt/modules/scse.py b/pytorch_toolbelt/modules/scse.py index d81bbdc16..742e11fd5 100644 --- a/pytorch_toolbelt/modules/scse.py +++ b/pytorch_toolbelt/modules/scse.py @@ -6,13 +6,7 @@ from torch import nn, Tensor from torch.nn import functional as F -__all__ = [ - "ChannelGate2d", - "SpatialGate2d", - "ChannelSpatialGate2d", - "SpatialGate2dV2", - "ChannelSpatialGate2dV2", -] +__all__ = ["ChannelGate2d", "SpatialGate2d", "ChannelSpatialGate2d", "SpatialGate2dV2", "ChannelSpatialGate2dV2"] class ChannelGate2d(nn.Module): @@ -45,12 +39,8 @@ def __init__(self, channels, reduction=None, squeeze_channels=None): :param squeeze_channels: Number of channels in squeeze block. 
""" super().__init__() - assert ( - reduction or squeeze_channels - ), "One of 'reduction' and 'squeeze_channels' must be set" - assert not ( - reduction and squeeze_channels - ), "'reduction' and 'squeeze_channels' are mutually exclusive" + assert reduction or squeeze_channels, "One of 'reduction' and 'squeeze_channels' must be set" + assert not (reduction and squeeze_channels), "'reduction' and 'squeeze_channels' are mutually exclusive" if squeeze_channels is None: squeeze_channels = max(1, channels // reduction) @@ -91,9 +81,7 @@ def __init__(self, channels, reduction=4): super().__init__() squeeze_channels = max(1, channels // reduction) self.squeeze = nn.Conv2d(channels, squeeze_channels, kernel_size=1, padding=0) - self.conv = nn.Conv2d( - squeeze_channels, squeeze_channels, kernel_size=7, dilation=3, padding=3 * 3 - ) + self.conv = nn.Conv2d(squeeze_channels, squeeze_channels, kernel_size=7, dilation=3, padding=3 * 3) self.expand = nn.Conv2d(squeeze_channels, channels, kernel_size=1, padding=0) def forward(self, x: Tensor): diff --git a/pytorch_toolbelt/modules/srm.py b/pytorch_toolbelt/modules/srm.py index 9ce462107..82545cd53 100644 --- a/pytorch_toolbelt/modules/srm.py +++ b/pytorch_toolbelt/modules/srm.py @@ -12,9 +12,7 @@ def __init__(self, channels: int): super(SRMLayer, self).__init__() # Equal to torch.einsum('bck,ck->bc', A, B) - self.cfc = nn.Conv1d( - channels, channels, kernel_size=2, bias=False, groups=channels - ) + self.cfc = nn.Conv1d(channels, channels, kernel_size=2, bias=False, groups=channels) self.bn = nn.BatchNorm1d(channels) def forward(self, x): diff --git a/pytorch_toolbelt/optimization/functional.py b/pytorch_toolbelt/optimization/functional.py index 975347210..5f55e02ed 100644 --- a/pytorch_toolbelt/optimization/functional.py +++ b/pytorch_toolbelt/optimization/functional.py @@ -1,7 +1,6 @@ def get_lr_decay_parameters(parameters, learning_rate, groups: dict): custom_lr_parameters = dict( - (group_name, {"params": [], "lr": learning_rate * lr_factor}) - for (group_name, lr_factor) in groups.items() + (group_name, {"params": [], "lr": learning_rate * lr_factor}) for (group_name, lr_factor) in groups.items() ) custom_lr_parameters["default"] = {"params": [], "lr": learning_rate} diff --git a/pytorch_toolbelt/optimization/lr_schedules.py b/pytorch_toolbelt/optimization/lr_schedules.py index e37a37a28..913d15abc 100644 --- a/pytorch_toolbelt/optimization/lr_schedules.py +++ b/pytorch_toolbelt/optimization/lr_schedules.py @@ -22,9 +22,7 @@ def __init__(self, optimizer, epochs, min_lr_factor=0.05, max_lr=1.0): super().__init__(optimizer) def get_lr(self): - return [ - base_lr * self.learning_rates[self.last_epoch] for base_lr in self.base_lrs - ] + return [base_lr * self.learning_rates[self.last_epoch] for base_lr in self.base_lrs] class CosineAnnealingLRWithDecay(_LRScheduler): diff --git a/pytorch_toolbelt/utils/catalyst/criterions.py b/pytorch_toolbelt/utils/catalyst/criterions.py index f1bbc49a7..e5fd3a83b 100644 --- a/pytorch_toolbelt/utils/catalyst/criterions.py +++ b/pytorch_toolbelt/utils/catalyst/criterions.py @@ -67,15 +67,11 @@ def get_multiplier(self, training_progress, schedule, start, end): def on_loader_start(self, state: RunnerState): self.is_needed = not self.on_train_only or state.loader_name.startswith("train") if self.is_needed: - state.metrics.epoch_values[state.loader_name][ - f"l{self.p}_weight_decay" - ] = self.multiplier + state.metrics.epoch_values[state.loader_name][f"l{self.p}_weight_decay"] = self.multiplier def on_epoch_start(self, 
state: RunnerState): training_progress = float(state.epoch) / float(state.num_epochs) - self.multiplier = self.get_multiplier( - training_progress, self.schedule, self.start_wd, self.end_wd - ) + self.multiplier = self.get_multiplier(training_progress, self.schedule, self.start_wd, self.end_wd) def on_batch_end(self, state: RunnerState): if not self.is_needed: diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 3a0eacb04..8cd1ad017 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -5,10 +5,7 @@ from catalyst.dl import Callback, RunnerState, MetricCallback, CallbackOrder from .visualization import get_tensorboard_logger from ..torch_utils import to_numpy -from pytorch_toolbelt.utils.visualization import ( - render_figure_to_tensor, - plot_confusion_matrix, -) +from pytorch_toolbelt.utils.visualization import render_figure_to_tensor, plot_confusion_matrix from sklearn.metrics import f1_score, confusion_matrix __all__ = [ @@ -50,11 +47,7 @@ class PixelAccuracyCallback(MetricCallback): """ def __init__( - self, - input_key: str = "targets", - output_key: str = "logits", - prefix: str = "accuracy", - ignore_index=None, + self, input_key: str = "targets", output_key: str = "logits", prefix: str = "accuracy", ignore_index=None ): """ :param input_key: input key to use for iou calculation; @@ -150,11 +143,7 @@ class MacroF1Callback(Callback): """ def __init__( - self, - input_key: str = "targets", - output_key: str = "logits", - prefix: str = "macro_f1", - ignore_index=None, + self, input_key: str = "targets", output_key: str = "logits", prefix: str = "macro_f1", ignore_index=None ): """ :param input_key: input key to use for precision calculation; @@ -163,9 +152,7 @@ def __init__( specifies our `y_pred`. 
""" super().__init__(CallbackOrder.Metric) - self.metric_fn = lambda outputs, targets: f1_score( - targets, outputs, average="macro" - ) + self.metric_fn = lambda outputs, targets: f1_score(targets, outputs, average="macro") self.prefix = prefix self.output_key = output_key self.input_key = input_key @@ -208,12 +195,7 @@ def on_loader_end(self, state): def binary_dice_iou_score( - y_pred: torch.Tensor, - y_true: torch.Tensor, - mode="dice", - threshold=None, - nan_score_on_empty=False, - eps=1e-7, + y_pred: torch.Tensor, y_true: torch.Tensor, mode="dice", threshold=None, nan_score_on_empty=False, eps=1e-7 ) -> float: """ Compute IoU score between two image tensors @@ -360,10 +342,7 @@ def __init__( if self.mode == BINARY_MODE: self.score_fn = partial( - binary_dice_iou_score, - threshold=0.0, - nan_score_on_empty=nan_score_on_empty, - mode=metric, + binary_dice_iou_score, threshold=0.0, nan_score_on_empty=nan_score_on_empty, mode=metric ) if self.mode == MULTICLASS_MODE: @@ -395,9 +374,7 @@ def on_batch_end(self, state: RunnerState): batch_size = targets.size(0) score_per_image = [] for image_index in range(batch_size): - score_per_class = self.score_fn( - y_pred=outputs[image_index], y_true=targets[image_index] - ) + score_per_class = self.score_fn(y_pred=outputs[image_index], y_true=targets[image_index]) score_per_image.append(score_per_class) mean_score = np.nanmean(score_per_image) @@ -419,6 +396,4 @@ def on_loader_end(self, state): scores_per_class = np.nanmean(scores, axis=0) for class_name, score_per_class in zip(class_names, scores_per_class): - state.metrics.epoch_values[state.loader_name][ - self.prefix + "_" + class_name - ] = float(score_per_class) + state.metrics.epoch_values[state.loader_name][self.prefix + "_" + class_name] = float(score_per_class) diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 40640d78a..3163fa461 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -96,9 +96,7 @@ def on_loader_start(self, state): def on_batch_end(self, state: RunnerState): value = state.metrics.batch_values.get(self.target_metric, None) if value is None: - warnings.warn( - f"Metric value for {self.target_metric} is not available in state.metrics.batch_values" - ) + warnings.warn(f"Metric value for {self.target_metric} is not available in state.metrics.batch_values") return if self.best_score is None or self.is_better(value, self.best_score): @@ -125,11 +123,7 @@ def on_loader_end(self, state: RunnerState) -> None: def _log_samples(self, samples, name, logger, step): if "tensorboard" in self.targets: for i, image in enumerate(samples): - logger.add_image( - f"{self.target_metric}/{name}/{i}", - tensor_from_rgb_image(image), - step, - ) + logger.add_image(f"{self.target_metric}/{name}/{i}", tensor_from_rgb_image(image), step) if "matplotlib" in self.targets: for i, image in enumerate(samples): @@ -151,11 +145,7 @@ def draw_binary_segmentation_predictions( std=(0.229, 0.224, 0.225), ): images = [] - image_id_input = ( - input[image_id_key] - if image_id_key is not None - else [None] * len(input[image_key]) - ) + image_id_input = input[image_id_key] if image_id_key is not None else [None] * len(input[image_key]) for image, target, image_id, logits in zip( input[image_key], input[targets_key], image_id_input, output[outputs_key] @@ -171,23 +161,14 @@ def draw_binary_segmentation_predictions( overlay[true_mask & pred_mask] = np.array( [0, 250, 0], 
dtype=overlay.dtype ) # Correct predictions (Hits) painted with green - overlay[true_mask & ~pred_mask] = np.array( - [250, 0, 0], dtype=overlay.dtype - ) # Misses painted with red + overlay[true_mask & ~pred_mask] = np.array([250, 0, 0], dtype=overlay.dtype) # Misses painted with red overlay[~true_mask & pred_mask] = np.array( [250, 250, 0], dtype=overlay.dtype ) # False alarm painted with yellow overlay = cv2.addWeighted(image, 0.5, overlay, 0.5, 0, dtype=cv2.CV_8U) if image_id is not None: - cv2.putText( - overlay, - str(image_id), - (10, 15), - cv2.FONT_HERSHEY_PLAIN, - 1, - (250, 250, 250), - ) + cv2.putText(overlay, str(image_id), (10, 15), cv2.FONT_HERSHEY_PLAIN, 1, (250, 250, 250)) images.append(overlay) return images @@ -208,11 +189,7 @@ def draw_semantic_segmentation_predictions( assert mode in {"overlay", "side-by-side"} images = [] - image_id_input = ( - input[image_id_key] - if image_id_key is not None - else [None] * len(input[image_key]) - ) + image_id_input = input[image_id_key] if image_id_key is not None else [None] * len(input[image_key]) for image, target, image_id, logits in zip( input[image_key], input[targets_key], image_id_input, output[outputs_key] @@ -229,14 +206,7 @@ def draw_semantic_segmentation_predictions( overlay = cv2.addWeighted(image, 0.5, overlay, 0.5, 0, dtype=cv2.CV_8U) if image_id is not None: - cv2.putText( - overlay, - str(image_id), - (10, 15), - cv2.FONT_HERSHEY_PLAIN, - 1, - (250, 250, 250), - ) + cv2.putText(overlay, str(image_id), (10, 15), cv2.FONT_HERSHEY_PLAIN, 1, (250, 250, 250)) elif mode == "side-by-side": true_mask = np.zeros_like(image) diff --git a/pytorch_toolbelt/utils/dataset_utils.py b/pytorch_toolbelt/utils/dataset_utils.py index d19d17906..ee849ab96 100644 --- a/pytorch_toolbelt/utils/dataset_utils.py +++ b/pytorch_toolbelt/utils/dataset_utils.py @@ -2,27 +2,16 @@ from pytorch_toolbelt.inference.tiles import ImageSlicer from pytorch_toolbelt.utils.fs import id_from_fname -from pytorch_toolbelt.utils.torch_utils import ( - tensor_from_rgb_image, - tensor_from_mask_image, -) +from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, tensor_from_mask_image from torch.utils.data import Dataset, ConcatDataset class ImageMaskDataset(Dataset): def __init__( - self, - image_filenames, - target_filenames, - image_loader, - target_loader, - transform=None, - keep_in_mem=False, + self, image_filenames, target_filenames, image_loader, target_loader, transform=None, keep_in_mem=False ): if len(image_filenames) != len(target_filenames): - raise ValueError( - "Number of images does not corresponds to number of targets" - ) + raise ValueError("Number of images does not corresponds to number of targets") self.image_ids = [id_from_fname(fname) for fname in image_filenames] @@ -97,8 +86,7 @@ def __init__( self.transform = transform self.image_ids = [ - id_from_fname(image_fname) + f" [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]" - for crop in self.slicer.crops + id_from_fname(image_fname) + f" [{crop[0]};{crop[1]};{crop[2]};{crop[3]};]" for crop in self.slicer.crops ] def _get_image(self, index): @@ -142,14 +130,10 @@ def __init__( **kwargs, ): if len(image_filenames) != len(target_filenames): - raise ValueError( - "Number of images does not corresponds to number of targets" - ) + raise ValueError("Number of images does not corresponds to number of targets") datasets = [] for image, mask in zip(image_filenames, target_filenames): - dataset = TiledSingleImageDataset( - image, mask, image_loader, target_loader, **kwargs - ) + dataset = 
TiledSingleImageDataset(image, mask, image_loader, target_loader, **kwargs) datasets.append(dataset) super().__init__(datasets) diff --git a/pytorch_toolbelt/utils/fs.py b/pytorch_toolbelt/utils/fs.py index c77ff7446..59443d2b7 100644 --- a/pytorch_toolbelt/utils/fs.py +++ b/pytorch_toolbelt/utils/fs.py @@ -51,14 +51,11 @@ def auto_file(filename: str, where: str = ".") -> str: files = list(glob.iglob(os.path.join(where, "**", filename), recursive=True)) if len(files) == 0: - raise FileNotFoundError( - "Given file could not be found with recursive search:" + filename - ) + raise FileNotFoundError("Given file could not be found with recursive search:" + filename) if len(files) > 1: raise FileNotFoundError( - "More than one file matches given filename. Please specify it explicitly:\n" - + "\n".join(files) + "More than one file matches given filename. Please specify it explicitly:\n" + "\n".join(files) ) return files[0] diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py index 840bfae4d..c0950d870 100644 --- a/pytorch_toolbelt/utils/torch_utils.py +++ b/pytorch_toolbelt/utils/torch_utils.py @@ -108,9 +108,7 @@ def tensor_from_mask_image(mask: np.ndarray) -> torch.Tensor: return tensor_from_rgb_image(mask) -def rgb_image_from_tensor( - image: torch.Tensor, mean, std, max_pixel_value=255.0, dtype=np.uint8 -) -> np.ndarray: +def rgb_image_from_tensor(image: torch.Tensor, mean, std, max_pixel_value=255.0, dtype=np.uint8) -> np.ndarray: image = np.moveaxis(to_numpy(image), 0, -1) mean = to_numpy(mean) std = to_numpy(std) @@ -144,8 +142,6 @@ def transfer_weights(model: nn.Module, model_state_dict: collections.OrderedDict """ for name, value in model_state_dict.items(): try: - model.load_state_dict( - collections.OrderedDict([(name, value)]), strict=False - ) + model.load_state_dict(collections.OrderedDict([(name, value)]), strict=False) except Exception as e: print(e) diff --git a/pytorch_toolbelt/utils/visualization.py b/pytorch_toolbelt/utils/visualization.py index df2fabaf5..e1c45e42e 100644 --- a/pytorch_toolbelt/utils/visualization.py +++ b/pytorch_toolbelt/utils/visualization.py @@ -7,13 +7,7 @@ def plot_confusion_matrix( - cm, - class_names, - figsize=(16, 16), - normalize=False, - title="Confusion matrix", - fname=None, - noshow=False, + cm, class_names, figsize=(16, 16), normalize=False, title="Confusion matrix", fname=None, noshow=False ): """Render the confusion matrix and return matplotlib's figure with it. Normalization can be applied by setting `normalize=True`. 
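For reference, a minimal usage sketch of the confusion-matrix helpers touched in the hunk above; the label arrays and class names below are invented for illustration, everything else follows the signature shown in this patch:

    import numpy as np
    from sklearn.metrics import confusion_matrix
    from pytorch_toolbelt.utils.visualization import plot_confusion_matrix, render_figure_to_tensor

    # Invented toy labels, purely to exercise the helper.
    y_true = np.array([0, 1, 2, 2, 1, 0])
    y_pred = np.array([0, 2, 2, 2, 0, 0])

    cm = confusion_matrix(y_true, y_pred)
    figure = plot_confusion_matrix(
        cm,
        class_names=["cat", "dog", "bird"],  # invented class names
        normalize=True,
        noshow=True,  # return the matplotlib figure without calling plt.show()
    )
    # render_figure_to_tensor (imported alongside plot_confusion_matrix in
    # utils/catalyst/metrics.py above) converts the figure into an image tensor,
    # e.g. for TensorBoard logging.
    image_tensor = render_figure_to_tensor(figure)
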
@@ -44,11 +38,7 @@ def plot_confusion_matrix( thresh = cm.max() / 2.0 for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text( - j, - i, - format(cm[i, j], fmt), - horizontalalignment="center", - color="white" if cm[i, j] > thresh else "black", + j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black" ) plt.tight_layout() diff --git a/setup.py b/setup.py index 153565584..c2ae21ef2 100644 --- a/setup.py +++ b/setup.py @@ -46,9 +46,9 @@ def load_readme(): def get_test_requirements(): - requirements = ['pytest', 'catalyst>=19.6.4'] + requirements = ["pytest", "catalyst>=19.6.4"] if sys.version_info < (3, 3): - requirements.append('mock') + requirements.append("mock") return requirements @@ -64,15 +64,15 @@ def get_test_requirements(): packages=find_packages(exclude=EXCLUDE_FROM_PACKAGES), install_requires=DEPENDENCIES, python_requires=REQUIRES_PYTHON, - extras_require={'tests': get_test_requirements()}, + extras_require={"tests": get_test_requirements()}, include_package_data=True, keywords=["PyTorch", "Kaggle", "Deep Learning", "Machine Learning", "ResNet", "VGG", "ResNext", "Unet", "Focal"], scripts=[], license="License :: OSI Approved :: MIT License", classifiers=[ - 'Development Status :: 4 - Beta', - 'Intended Audience :: Developers', - 'Intended Audience :: Science/Research', + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.6", @@ -81,8 +81,8 @@ def get_test_requirements(): "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Scientific/Engineering :: Image Recognition", "Topic :: Scientific/Engineering :: Artificial Intelligence", - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: Software Development :: Libraries :: Application Frameworks" # "Private :: Do Not Upload" ], diff --git a/tests/test_decoders.py b/tests/test_decoders.py index b850ecb7f..8c5b115cb 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -6,9 +6,8 @@ from pytorch_toolbelt.modules.decoders import FPNSumDecoder, FPNCatDecoder from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="Cuda is not available" -) +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") + @torch.no_grad() def test_fpn_sum(): diff --git a/tests/test_encoders.py b/tests/test_encoders.py index e45e402db..460fed2c6 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -5,9 +5,8 @@ from pytorch_toolbelt.modules.backbone.inceptionv4 import inceptionv4 from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="Cuda is not available" -) +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") + @pytest.mark.parametrize( ["encoder", "encoder_params"], @@ -54,9 +53,7 @@ def test_encoders(encoder: E.EncoderModule, encoder_params): net = maybe_cuda(net) output = net(input) assert len(output) == len(net.output_filters) - for feature_map, expected_stride, expected_channels in zip( - output, 
net.output_strides, net.output_filters - ): + for feature_map, expected_stride, expected_channels in zip(output, net.output_strides, net.output_filters): assert feature_map.size(1) == expected_channels assert feature_map.size(2) * expected_stride == 256 assert feature_map.size(3) * expected_stride == 256 @@ -90,3 +87,4 @@ def test_densenet(): net2.classifier = None print(count_parameters(net1), count_parameters(net2)) + diff --git a/tests/test_modules.py b/tests/test_modules.py index 5b3321e72..2840bf6c5 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -6,10 +6,7 @@ from pytorch_toolbelt.modules.fpn import HFF from pytorch_toolbelt.utils.torch_utils import maybe_cuda, count_parameters -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="Cuda is not available" -) - +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") def test_hff_dynamic_size(): @@ -39,4 +36,3 @@ def test_hff_static_size(): output = hff(feature_maps) assert output.size(2) == 512 assert output.size(3) == 512 - diff --git a/tests/test_tiles.py b/tests/test_tiles.py index 42e2083b3..1f0935f74 100644 --- a/tests/test_tiles.py +++ b/tests/test_tiles.py @@ -1,19 +1,13 @@ import numpy as np import torch from pytorch_toolbelt.inference.tiles import ImageSlicer, CudaTileMerger -from pytorch_toolbelt.utils.torch_utils import ( - tensor_from_rgb_image, - rgb_image_from_tensor, - to_numpy, -) +from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, rgb_image_from_tensor, to_numpy from torch import nn from torch.utils.data import DataLoader import pytest -skip_if_no_cuda = pytest.mark.skipif( - not torch.cuda.is_available(), reason="Cuda is not available" -) +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") def test_tiles_split_merge(): @@ -26,9 +20,7 @@ def test_tiles_split_merge(): def test_tiles_split_merge_non_dividable(): image = np.random.random((563, 512, 3)).astype(np.uint8) - tiler = ImageSlicer( - image.shape, tile_size=(128, 128), tile_step=(128, 128), weight="mean" - ) + tiler = ImageSlicer(image.shape, tile_size=(128, 128), tile_step=(128, 128), weight="mean") tiles = tiler.split(image) merged = tiler.merge(tiles, dtype=np.uint8) np.testing.assert_equal(merged, image) @@ -38,19 +30,13 @@ def test_tiles_split_merge_non_dividable(): def test_tiles_split_merge_non_dividable_cuda(): image = np.random.random((5632, 5120, 3)).astype(np.uint8) - tiler = ImageSlicer( - image.shape, tile_size=(1280, 1280), tile_step=(1280, 1280), weight="mean" - ) + tiler = ImageSlicer(image.shape, tile_size=(1280, 1280), tile_step=(1280, 1280), weight="mean") tiles = tiler.split(image) - merger = CudaTileMerger( - tiler.target_shape, channels=image.shape[2], weight=tiler.weight - ) + merger = CudaTileMerger(tiler.target_shape, channels=image.shape[2], weight=tiler.weight) for tile, coordinates in zip(tiles, tiler.crops): # Integrate as batch of size 1 - merger.integrate_batch( - tensor_from_rgb_image(tile).unsqueeze(0).float().cuda(), [coordinates] - ) + merger.integrate_batch(tensor_from_rgb_image(tile).unsqueeze(0).float().cuda(), [coordinates]) merged = merger.merge() merged = rgb_image_from_tensor(merged, mean=0, std=1, max_pixel_value=1) @@ -61,9 +47,7 @@ def test_tiles_split_merge_non_dividable_cuda(): def test_tiles_split_merge_2(): image = np.random.random((5000, 5000, 3)).astype(np.uint8) - tiler = ImageSlicer( - image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid" 
- ) + tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid") np.testing.assert_allclose(tiler.weight, tiler.weight.T) @@ -83,17 +67,13 @@ def forward(self, input): return max_channel image = np.random.random((5000, 5000, 3)).astype(np.uint8) - tiler = ImageSlicer( - image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid" - ) + tiler = ImageSlicer(image.shape, tile_size=(512, 512), tile_step=(256, 256), weight="pyramid") tiles = [tensor_from_rgb_image(tile) for tile in tiler.split(image)] model = MaxChannelIntensity().eval().cuda() merger = CudaTileMerger(tiler.target_shape, 1, tiler.weight) - for tiles_batch, coords_batch in DataLoader( - list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True - ): + for tiles_batch, coords_batch in DataLoader(list(zip(tiles, tiler.crops)), batch_size=8, pin_memory=True): tiles_batch = tiles_batch.float().cuda() pred_batch = model(tiles_batch) diff --git a/tests/test_tta.py b/tests/test_tta.py index 46de97033..4a4fcf42a 100644 --- a/tests/test_tta.py +++ b/tests/test_tta.py @@ -38,12 +38,7 @@ def test_fliplr_image2mask(): def test_d4_image2label(): - input = ( - torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]) - .unsqueeze(0) - .unsqueeze(0) - .float() - ) + input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float() model = SumAll() output = tta.d4_image2label(model, input) @@ -53,12 +48,7 @@ def test_d4_image2label(): def test_fliplr_image2label(): - input = ( - torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]) - .unsqueeze(0) - .unsqueeze(0) - .float() - ) + input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float() model = SumAll() output = tta.fliplr_image2label(model, input) @@ -68,45 +58,20 @@ def test_fliplr_image2label(): def test_fivecrop_image2label(): - input = ( - torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]) - .unsqueeze(0) - .unsqueeze(0) - .float() - ) + input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float() model = SumAll() output = tta.fivecrop_image2label(model, input, (2, 2)) - expected = ( - (1 + 2 + 5 + 6) - + (3 + 4 + 7 + 8) - + (9 + 0 + 3 + 4) - + (1 + 2 + 5 + 6) - + (6 + 7 + 0 + 1) - ) / 5 + expected = ((1 + 2 + 5 + 6) + (3 + 4 + 7 + 8) + (9 + 0 + 3 + 4) + (1 + 2 + 5 + 6) + (6 + 7 + 0 + 1)) / 5 assert int(output) == expected def test_tencrop_image2label(): - input = ( - torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]) - .unsqueeze(0) - .unsqueeze(0) - .float() - ) + input = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 0, 1, 2], [3, 4, 5, 6]]).unsqueeze(0).unsqueeze(0).float() model = SumAll() output = tta.tencrop_image2label(model, input, (2, 2)) - expected = ( - 2 - * ( - (1 + 2 + 5 + 6) - + (3 + 4 + 7 + 8) - + (9 + 0 + 3 + 4) - + (1 + 2 + 5 + 6) - + (6 + 7 + 0 + 1) - ) - ) / 10 + expected = (2 * ((1 + 2 + 5 + 6) + (3 + 4 + 7 + 8) + (9 + 0 + 3 + 4) + (1 + 2 + 5 + 6) + (6 + 7 + 0 + 1))) / 10 assert int(output) == expected From 65ecc2a28736e683ae77a4fe918efc8faffbc418 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Sun, 27 Oct 2019 21:08:33 +0200 Subject: [PATCH 05/79] Fix wrong channels size --- pytorch_toolbelt/modules/decoders/fpn_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 
febf629e2..5cd7283a6 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -74,7 +74,7 @@ def __init__( nn.Conv2d(features, features // 2, kernel_size=1), abn_block(features // 2), nn.Conv2d(features // 2, features // 4, kernel_size=3, padding=1, bias=False), - abn_block(features // 2), + abn_block(features // 4), nn.Conv2d(features // 4, features // 4, kernel_size=3, padding=1, bias=False), abn_block(features // 4), nn.Conv2d(features // 4, num_classes, kernel_size=1, bias=True), From 6a01dc1c3f4d22b4a30e0bf2499bf781d27cd43d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 9 Nov 2019 21:33:47 +0200 Subject: [PATCH 06/79] Adding flexible input channels support --- pytorch_toolbelt/losses/jaccard.py | 2 +- pytorch_toolbelt/modules/backbone/hrnet.py | 4 +- pytorch_toolbelt/modules/encoders/common.py | 43 ++++++++++++++++++- pytorch_toolbelt/modules/encoders/densenet.py | 16 ++++++- pytorch_toolbelt/modules/encoders/hrnet.py | 2 +- .../modules/encoders/inception.py | 5 ++- pytorch_toolbelt/modules/encoders/resnet.py | 9 +++- pytorch_toolbelt/modules/encoders/seresnet.py | 5 ++- 8 files changed, 75 insertions(+), 11 deletions(-) diff --git a/pytorch_toolbelt/losses/jaccard.py b/pytorch_toolbelt/losses/jaccard.py index 1207d9138..864a4c3e1 100644 --- a/pytorch_toolbelt/losses/jaccard.py +++ b/pytorch_toolbelt/losses/jaccard.py @@ -8,7 +8,7 @@ from .functional import soft_jaccard_score -__all__ = ["JaccardLoss"] +__all__ = ["JaccardLoss", "BINARY_MODE", "MULTICLASS_MODE", "MULTILABEL_MODE"] BINARY_MODE = "binary" MULTICLASS_MODE = "multiclass" diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py index cd0a91c20..39f7c2bcc 100644 --- a/pytorch_toolbelt/modules/backbone/hrnet.py +++ b/pytorch_toolbelt/modules/backbone/hrnet.py @@ -224,7 +224,7 @@ def forward(self, x): class HRNetV2(nn.Module): - def __init__(self, width=48, **kwargs): + def __init__(self, input_channels=3, width=48, **kwargs): super(HRNetV2, self).__init__() blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} @@ -260,7 +260,7 @@ def __init__(self, width=48, **kwargs): self.layer0 = nn.Sequential( OrderedDict( [ - ("conv1", nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)), + ("conv1", nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1, bias=False)), ("bn1", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), ("relu", nn.ReLU(inplace=True)), ("conv2", nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)), diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 7f35c20a5..91e4bfae2 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -7,11 +7,43 @@ from torch import nn +import warnings +import torch.nn.functional as F + +__all__ = ["EncoderModule", "_take", "make_n_channel_input"] + def _take(elements, indexes): return list([elements[i] for i in indexes]) +def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): + if conv.in_channels == in_channels: + warnings.warn("make_n_channel_input call is spurious") + return conv + + new_conv = nn.Conv2d( + in_channels, + out_channels=conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + dilation=conv.dilation, + groups=conv.groups, + bias=conv.bias is not None, + padding_mode=conv.padding_mode, + ) + + w = conv.weight + if in_channels > 
conv.in_channels: + w = F.pad(w, pad=[0, 0, 0, in_channels-conv.in_channels], mode='circular') + else: + w = w[:, 0:in_channels, ...] + + new_conv.weight = nn.Parameter(w[:, 0:1, ...], requires_grad=True) + return new_conv + + class EncoderModule(nn.Module): def __init__(self, channels: List[int], strides: List[int], layers: List[int]): super().__init__() @@ -41,9 +73,18 @@ def output_filters(self) -> List[int]: return self._output_filters @property - def encoder_layers(self): + def encoder_layers(self) -> List[nn.Module]: raise NotImplementedError def set_trainable(self, trainable): for param in self.parameters(): param.requires_grad = bool(trainable) + + def change_input_channels(self, input_channels: int, mode="auto"): + """ + Change number of channels expected in the input tensor. By default, + all encoders assume 3-channel image in BCHW notation with C=3. + This method changes first convolution to have user-defined number of + channels as input. + """ + raise NotImplementedError diff --git a/pytorch_toolbelt/modules/encoders/densenet.py b/pytorch_toolbelt/modules/encoders/densenet.py index 9583a04e6..9098b38b4 100644 --- a/pytorch_toolbelt/modules/encoders/densenet.py +++ b/pytorch_toolbelt/modules/encoders/densenet.py @@ -1,9 +1,10 @@ +from collections import OrderedDict from typing import List from torch import nn from torchvision.models import densenet121, densenet161, densenet169, densenet201, DenseNet -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input __all__ = ["DenseNetEncoder", "DenseNet121Encoder", "DenseNet169Encoder", "DenseNet161Encoder", "DenseNet201Encoder"] @@ -21,7 +22,15 @@ def except_pool(block: nn.Module): del block.pool return block - self.layer0 = nn.Sequential(densenet.features.conv0, densenet.features.norm0, densenet.features.relu0) + self.layer0 = nn.Sequential( + OrderedDict( + [ + ("conv0", densenet.features.conv0), + ("bn0", densenet.features.norm0), + ("act0", densenet.features.relu0), + ] + ) + ) self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2) self.pool0 = self.avg_pool if first_avg_pool else densenet.features.pool0 @@ -67,6 +76,9 @@ def forward(self, x): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv0 = make_n_channel_input(self.layer0.conv0, input_channels, mode) + class DenseNet121Encoder(DenseNetEncoder): def __init__(self, layers=None, pretrained=True, memory_efficient=False, first_avg_pool=False): diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index a639bcb93..c05ec1e69 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -26,7 +26,7 @@ def forward(self, x): class HRNetV2Encoder48(EncoderModule): def __init__(self, pretrained=False): super().__init__([720], [4], [0]) - self.hrnet = hrnetv2(pretrained=False) + self.hrnet = hrnetv2(width=48, pretrained=False) def forward(self, x): return self.hrnet(x) diff --git a/pytorch_toolbelt/modules/encoders/inception.py b/pytorch_toolbelt/modules/encoders/inception.py index d6c91122c..90a2de8e9 100644 --- a/pytorch_toolbelt/modules/encoders/inception.py +++ b/pytorch_toolbelt/modules/encoders/inception.py @@ -1,4 +1,4 @@ -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input from ..backbone.inceptionv4 import inceptionv4 __all__ = ["InceptionV4Encoder"] @@ -37,3 +37,6 
@@ def forward(self, x): @property def encoder_layers(self): return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] + + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0[0] = make_n_channel_input(self.layer0[0], input_channels, mode) diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py index 33a56eb0b..3c86a2552 100644 --- a/pytorch_toolbelt/modules/encoders/resnet.py +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -8,7 +8,7 @@ from torch import nn from torchvision.models import resnet50, resnet34, resnet18, resnet101, resnet152 -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input __all__ = [ "ResnetEncoder", @@ -26,7 +26,9 @@ def __init__(self, resnet, filters, strides, layers=None): layers = [1, 2, 3, 4] super().__init__(filters, strides, layers) - self.layer0 = nn.Sequential(OrderedDict([("conv1", resnet.conv1), ("bn1", resnet.bn1), ("relu", resnet.relu)])) + self.layer0 = nn.Sequential(OrderedDict([("conv0", resnet.conv1), + ("bn0", resnet.bn1), + ("act0", resnet.relu)])) self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 @@ -53,6 +55,9 @@ def forward(self, x): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv0 = make_n_channel_input(self.layer0.conv0, input_channels, mode) + class Resnet18Encoder(ResnetEncoder): def __init__(self, pretrained=True, layers=None): diff --git a/pytorch_toolbelt/modules/encoders/seresnet.py b/pytorch_toolbelt/modules/encoders/seresnet.py index b18dec903..0de69c3f7 100644 --- a/pytorch_toolbelt/modules/encoders/seresnet.py +++ b/pytorch_toolbelt/modules/encoders/seresnet.py @@ -3,7 +3,7 @@ Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. 
""" -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input from ..backbone.senet import ( SENet, @@ -76,6 +76,9 @@ def forward(self, x): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + class SEResnet50Encoder(SEResnetEncoder): def __init__(self, pretrained=True, layers=None): From ea6b0a70bb8478b0c1b8db683ed1d024a49f6f32 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 16 Nov 2019 21:19:23 +0200 Subject: [PATCH 07/79] Global code cleanup/refactoring --- CREDITS.md | 7 + pytorch_toolbelt/__init__.py | 2 +- pytorch_toolbelt/modules/__init__.py | 3 +- pytorch_toolbelt/modules/abn.py | 101 ----------- .../modules/activated_batch_norm.py | 168 ++++++++++++++++++ .../{agn.py => activated_group_norm.py} | 35 ++-- pytorch_toolbelt/modules/activations.py | 136 +++++++++----- .../modules/backbone/efficient_net.py | 4 +- .../modules/backbone/wider_resnet.py | 3 +- pytorch_toolbelt/modules/decoders/fpn.py | 6 +- pytorch_toolbelt/modules/decoders/fpn_cat.py | 18 +- pytorch_toolbelt/modules/decoders/fpn_sum.py | 15 +- pytorch_toolbelt/modules/decoders/unet.py | 107 +---------- pytorch_toolbelt/modules/decoders/unet_v2.py | 2 +- pytorch_toolbelt/modules/encoders/common.py | 2 +- .../modules/encoders/mobilenet.py | 8 +- pytorch_toolbelt/modules/encoders/resnet.py | 4 +- pytorch_toolbelt/modules/encoders/seresnet.py | 4 +- .../modules/encoders/squeezenet.py | 18 +- pytorch_toolbelt/modules/encoders/unet.py | 34 ++-- .../modules/encoders/wide_resnet.py | 12 +- pytorch_toolbelt/modules/pooling.py | 46 ++++- pytorch_toolbelt/modules/unet.py | 93 ++++++++++ tests/test_activations.py | 31 ++++ 24 files changed, 522 insertions(+), 337 deletions(-) create mode 100644 CREDITS.md delete mode 100644 pytorch_toolbelt/modules/abn.py create mode 100644 pytorch_toolbelt/modules/activated_batch_norm.py rename pytorch_toolbelt/modules/{agn.py => activated_group_norm.py} (75%) create mode 100644 pytorch_toolbelt/modules/unet.py create mode 100644 tests/test_activations.py diff --git a/CREDITS.md b/CREDITS.md new file mode 100644 index 000000000..e6e5f87c7 --- /dev/null +++ b/CREDITS.md @@ -0,0 +1,7 @@ +This file contains links to repositories, source code of which may be partially used in this repository. Mind giving them kudos on GitHub! + +1. https://github.com/Cadene/pretrained-models.pytorch +1. https://blog.ceshine.net/post/pytorch-memory-swish/ +1. https://github.com/digantamisra98/Mish +1. 
https://github.com/mapillary/inplace_abn + diff --git a/pytorch_toolbelt/__init__.py b/pytorch_toolbelt/__init__.py index 5f013188d..348e738b4 100644 --- a/pytorch_toolbelt/__init__.py +++ b/pytorch_toolbelt/__init__.py @@ -1,3 +1,3 @@ from __future__ import absolute_import -__version__ = "0.2.2-alpha" +__version__ = "0.3.0" diff --git a/pytorch_toolbelt/modules/__init__.py b/pytorch_toolbelt/modules/__init__.py index bc41cd69e..37d819a7b 100644 --- a/pytorch_toolbelt/modules/__init__.py +++ b/pytorch_toolbelt/modules/__init__.py @@ -1,6 +1,7 @@ from __future__ import absolute_import -from .abn import * +from .activated_batch_norm import * +from .activated_group_norm import * from .dsconv import * from .fpn import * from .hypercolumn import * diff --git a/pytorch_toolbelt/modules/abn.py b/pytorch_toolbelt/modules/abn.py deleted file mode 100644 index edc6c6d03..000000000 --- a/pytorch_toolbelt/modules/abn.py +++ /dev/null @@ -1,101 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as functional - -from .activations import ( - ACT_LEAKY_RELU, - ACT_NONE, - ACT_HARD_SIGMOID, - ACT_HARD_SWISH, - ACT_SWISH, - ACT_SELU, - ACT_ELU, - ACT_RELU6, - ACT_RELU, - hard_swish, - hard_sigmoid, - swish, -) - -__all__ = ["ABN"] - - -class ABN(nn.Module): - """Activated Batch Normalization - This gathers a `BatchNorm2d` and an activation function in a single module - """ - - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): - """Create an Activated Batch Normalization module - Parameters - ---------- - num_features : int - Number of feature channels in the input and output. - eps : float - Small constant to prevent numerical issues. - momentum : float - Momentum factor applied to compute running statistics as. - affine : bool - If `True` apply learned scale and shift transformation after normalization. - activation : str - Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. - slope : float - Negative slope for the `leaky_relu` activation. 
- """ - super(ABN, self).__init__() - self.num_features = num_features - self.affine = affine - self.eps = eps - self.momentum = momentum - self.activation = activation - self.slope = slope - if self.affine: - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) - else: - self.register_parameter("weight", None) - self.register_parameter("bias", None) - self.register_buffer("running_mean", torch.zeros(num_features)) - self.register_buffer("running_var", torch.ones(num_features)) - self.reset_parameters() - - def reset_parameters(self): - nn.init.zeros_(self.running_mean) - nn.init.ones_(self.running_var) - if self.affine: - nn.init.ones_(self.weight) - nn.init.zeros_(self.bias) - - def forward(self, x): - x = functional.batch_norm( - x, self.running_mean, self.running_var, self.weight, self.bias, self.training, self.momentum, self.eps - ) - - if self.activation == ACT_RELU: - return functional.relu(x, inplace=True) - elif self.activation == ACT_RELU6: - return functional.relu6(x, inplace=True) - elif self.activation == ACT_LEAKY_RELU: - return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) - elif self.activation == ACT_ELU: - return functional.elu(x, inplace=True) - elif self.activation == ACT_SELU: - return functional.selu(x, inplace=True) - elif self.activation == ACT_SWISH: - return swish(x) - elif self.activation == ACT_HARD_SWISH: - return hard_swish(x, inplace=True) - elif self.activation == ACT_HARD_SIGMOID: - return hard_sigmoid(x, inplace=True) - elif self.activation == ACT_NONE: - return x - else: - raise KeyError(self.activation) - - def __repr__(self): - rep = "{name}({num_features}, eps={eps}, momentum={momentum}," " affine={affine}, activation={activation}" - if self.activation == "leaky_relu": - rep += ", slope={slope})" - else: - rep += ")" - return rep.format(name=self.__class__.__name__, **self.__dict__) diff --git a/pytorch_toolbelt/modules/activated_batch_norm.py b/pytorch_toolbelt/modules/activated_batch_norm.py new file mode 100644 index 000000000..9897fb680 --- /dev/null +++ b/pytorch_toolbelt/modules/activated_batch_norm.py @@ -0,0 +1,168 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter +from torch.nn import init +from .activations import * + +__all__ = ["ABN"] + + +class ABN(nn.Module): + _version = 2 + __constants__ = [ + "track_running_stats", + "momentum", + "eps", + "weight", + "bias", + "running_mean", + "running_var", + "num_batches_tracked", + "num_features", + "affine", + ] + + """Activated Batch Normalization + This gathers a `BatchNorm` and an activation function in a single module + """ + + def __init__( + self, + num_features: int, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + activation="leaky_relu", + slope=0.01, + ): + """Create an Activated Batch Normalization module + Parameters + ---------- + num_features : int + Number of feature channels in the input and output. + eps : float + Small constant to prevent numerical issues. + momentum : float + Momentum factor applied to compute running statistics as. + affine : bool + If `True` apply learned scale and shift transformation after normalization. + activation : str + Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. + slope : float + Negative slope for the `leaky_relu` activation. 
+ """ + super(ABN, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + if self.track_running_stats: + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + self.register_buffer("num_batches_tracked", torch.tensor(0, dtype=torch.long)) + else: + self.register_parameter("running_mean", None) + self.register_parameter("running_var", None) + self.register_parameter("num_batches_tracked", None) + self.reset_parameters() + + self.activation = activation + self.slope = slope + + self.reset_parameters() + + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() + + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + + # exponential_average_factor is set to self.momentum + # (when it is available) only so that if gets updated + # in ONNX graph when this node is exported to ONNX. + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum + + if self.training and self.track_running_stats: + # TODO: if statement only here to tell the jit to skip emitting this when it is None + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float(self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum + + x = F.batch_norm( + input, + self.running_mean, + self.running_var, + self.weight, + self.bias, + self.training or not self.track_running_stats, + exponential_average_factor, + self.eps, + ) + + if self.activation == ACT_RELU: + return F.relu(x, inplace=True) + elif self.activation == ACT_RELU6: + return F.relu6(x, inplace=True) + elif self.activation == ACT_LEAKY_RELU: + return F.leaky_relu(x, negative_slope=self.slope, inplace=True) + elif self.activation == ACT_ELU: + return F.elu(x, inplace=True) + elif self.activation == ACT_SELU: + return F.selu(x, inplace=True) + elif self.activation == ACT_SWISH: + return swish(x) + elif self.activation == ACT_MISH: + return mish(x) + elif self.activation == ACT_HARD_SWISH: + return hard_swish(x, inplace=True) + elif self.activation == ACT_HARD_SIGMOID: + return hard_sigmoid(x, inplace=True) + elif self.activation == ACT_NONE: + return x + else: + raise KeyError(self.activation) + + def __repr__(self): + return ( + "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, " + "track_running_stats={track_running_stats}, activation={activation}".format(**self.__dict__) + ) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + + if (version is None or version < 2) and self.track_running_stats: + # at version 2: added num_batches_tracked buffer + # this should have a default value of 0 + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key not in 
state_dict: + state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) diff --git a/pytorch_toolbelt/modules/agn.py b/pytorch_toolbelt/modules/activated_group_norm.py similarity index 75% rename from pytorch_toolbelt/modules/agn.py rename to pytorch_toolbelt/modules/activated_group_norm.py index d23b116f1..c0dbed846 100644 --- a/pytorch_toolbelt/modules/agn.py +++ b/pytorch_toolbelt/modules/activated_group_norm.py @@ -1,21 +1,8 @@ import torch import torch.nn as nn -import torch.nn.functional as functional +import torch.nn.functional as F -from .activations import ( - ACT_LEAKY_RELU, - ACT_NONE, - ACT_HARD_SIGMOID, - ACT_HARD_SWISH, - ACT_SWISH, - ACT_SELU, - ACT_ELU, - ACT_RELU6, - ACT_RELU, - hard_swish, - hard_sigmoid, - swish, -) +from .activations import * __all__ = ["AGN"] @@ -54,8 +41,8 @@ def __init__( self.activation = activation self.slope = slope - self.weight = nn.Parameter(torch.ones(num_features)) - self.bias = nn.Parameter(torch.zeros(num_features)) + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) self.reset_parameters() def reset_parameters(self): @@ -63,20 +50,22 @@ def reset_parameters(self): nn.init.zeros_(self.bias) def forward(self, x): - x = functional.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) if self.activation == ACT_RELU: - return functional.relu(x, inplace=True) + return F.relu(x, inplace=True) elif self.activation == ACT_RELU6: - return functional.relu6(x, inplace=True) + return F.relu6(x, inplace=True) elif self.activation == ACT_LEAKY_RELU: - return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) + return F.leaky_relu(x, negative_slope=self.slope, inplace=True) elif self.activation == ACT_ELU: - return functional.elu(x, inplace=True) + return F.elu(x, inplace=True) elif self.activation == ACT_SELU: - return functional.selu(x, inplace=True) + return F.selu(x, inplace=True) elif self.activation == ACT_SWISH: return swish(x) + elif self.activation == ACT_MISH: + return mish(x) elif self.activation == ACT_HARD_SWISH: return hard_swish(x, inplace=True) elif self.activation == ACT_HARD_SIGMOID: diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index b0feec1d4..600be993f 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -1,7 +1,9 @@ from functools import partial +import torch from torch import nn from torch.nn import functional as F +from .identity import Identity __all__ = [ "ACT_ELU", @@ -13,6 +15,8 @@ "ACT_RELU6", "ACT_SELU", "ACT_SWISH", + "ACT_MISH", + "mish", "swish", "hard_sigmoid", "hard_swish", @@ -20,7 +24,7 @@ "HardSwish", "Swish", "get_activation_module", - "sanitize_activation_name" + "sanitize_activation_name", ] # Activation names @@ -31,12 +35,72 @@ ACT_NONE = "none" ACT_SELU = "selu" ACT_SWISH = "swish" +ACT_MISH = "mish" ACT_HARD_SWISH = "hard_swish" ACT_HARD_SIGMOID = "hard_sigmoid" +class SwishFunction(torch.autograd.Function): + """ + Memory efficient Swish implementation. 
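+    Only the input tensor is saved for the backward pass and the sigmoid is
+    recomputed there, so the intermediate activation does not have to be kept in memory.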
+ + Credit: https://blog.ceshine.net/post/pytorch-memory-swish/ + """ + + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +def mish(input): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + See additional documentation for mish class. + Credit: https://github.com/digantamisra98/Mish + """ + return input * torch.tanh(F.softplus(input)) + + +class Mish(nn.Module): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + Examples: + >>> m = Mish() + >>> input = torch.randn(2) + >>> output = m(input) + + Credit: https://github.com/digantamisra98/Mish + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() + + def forward(self, input): + """ + Forward pass of the function. + """ + return mish(input) + + def swish(x): - return x * x.sigmoid() + return SwishFunction.apply(x) def hard_sigmoid(x, inplace=False): @@ -57,11 +121,8 @@ def forward(self, x): class Swish(nn.Module): - def __init__(self, inplace=False): - super(Swish, self).__init__() - - def forward(self, x): - return swish(x) + def forward(self, input_tensor): + return SwishFunction.apply(input_tensor) class HardSwish(nn.Module): @@ -74,47 +135,30 @@ def forward(self, x): def get_activation_module(activation_name: str, **kwargs) -> nn.Module: - if activation_name.lower() == "relu": - return partial(nn.ReLU, **kwargs) - - if activation_name.lower() == "relu6": - return partial(nn.ReLU6, **kwargs) - - if activation_name.lower() == "leaky_relu": - return partial(nn.LeakyReLU, **kwargs) - - if activation_name.lower() == "elu": - return partial(nn.ELU, **kwargs) - - if activation_name.lower() == "selu": - return partial(nn.SELU, **kwargs) - - if activation_name.lower() == "celu": - return partial(nn.CELU, **kwargs) - - if activation_name.lower() == "glu": - return partial(nn.GLU, **kwargs) - - if activation_name.lower() == "prelu": - return partial(nn.PReLU, **kwargs) - - if activation_name.lower() == "hard_sigmoid": - return partial(HardSigmoid, **kwargs) - - if activation_name.lower() == "swish": - return partial(Swish, **kwargs) - - if activation_name.lower() == "hard_swish": - return partial(HardSwish, **kwargs) - - raise ValueError(f"Activation '{activation_name}' is not supported") - - -def sanitize_activation_name(activation_name): + ACTIVATIONS = { + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "leaky_rely": nn.LeakyReLU, + "elu": nn.ELU, + "selu": nn.SELU, + "celu": nn.CELU, + "glu": nn.GLU, + "prelu": nn.PReLU, + "swish": Swish, + "mish": Mish, + "hard_sigmoid": HardSigmoid, + "hard_swish": HardSwish, + "none": Identity, + } + + return ACTIVATIONS[activation_name.lower()](**kwargs) + + +def sanitize_activation_name(activation_name: str) -> str: """ Return reasonable activation name for initialization in `kaiming_uniform_` for hipster activations """ - if activation_name in {"swish", "mish"}: - return "leaky_relu" + if activation_name in {ACT_MISH, ACT_SWISH}: + return ACT_LEAKY_RELU return activation_name diff --git a/pytorch_toolbelt/modules/backbone/efficient_net.py b/pytorch_toolbelt/modules/backbone/efficient_net.py index 
ba5245bd6..c09a92971 100644 --- a/pytorch_toolbelt/modules/backbone/efficient_net.py +++ b/pytorch_toolbelt/modules/backbone/efficient_net.py @@ -8,8 +8,8 @@ from torch.nn import functional as F from torch.nn.init import kaiming_normal_, kaiming_uniform_ -from ..abn import ABN -from ..agn import AGN +from ..activated_batch_norm import ABN +from ..activated_group_norm import AGN from ..activations import ACT_HARD_SWISH, sanitize_activation_name from ..scse import SpatialGate2d diff --git a/pytorch_toolbelt/modules/backbone/wider_resnet.py b/pytorch_toolbelt/modules/backbone/wider_resnet.py index 9647894a4..393d34ed1 100644 --- a/pytorch_toolbelt/modules/backbone/wider_resnet.py +++ b/pytorch_toolbelt/modules/backbone/wider_resnet.py @@ -2,7 +2,7 @@ from functools import partial import torch -from ..abn import ABN +from ..activated_batch_norm import ABN from ..pooling import GlobalAvgPool2d from torch import nn @@ -282,4 +282,3 @@ def wider_resnet_20_a2(num_classes=0, norm_act=ABN): def wider_resnet_38_a2(num_classes=0, norm_act=ABN): return WiderResNetA2(structure=[3, 3, 6, 3, 1, 1], norm_act=norm_act, classes=num_classes) - diff --git a/pytorch_toolbelt/modules/decoders/fpn.py b/pytorch_toolbelt/modules/decoders/fpn.py index d04030ec8..1a58e4b01 100644 --- a/pytorch_toolbelt/modules/decoders/fpn.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -1,4 +1,6 @@ from torch import nn +from typing import List + from .common import DecoderModule from ..fpn import FPNBottleneckBlock, UpsampleAdd, FPNPredictionBlock @@ -6,7 +8,7 @@ class FPNDecoder(DecoderModule): def __init__( self, - features, + features: List[int], bottleneck=FPNBottleneckBlock, upsample_add_block=UpsampleAdd, prediction_block=FPNPredictionBlock, @@ -18,7 +20,7 @@ def __init__( ): """ - :param features: + :param features: Number of features for feature maps from encoder :param prediction_block: :param bottleneck: :param fpn_features: diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 5cd7283a6..779a319d9 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -4,7 +4,7 @@ from .common import SegmentationDecoderModule from .fpn import FPNDecoder -from ..abn import ABN +from ..activated_batch_norm import ABN from ..fpn import FPNFuse, UpsampleAdd __all__ = ["FPNCatDecoder"] @@ -38,14 +38,14 @@ class FPNCatDecoder(SegmentationDecoderModule): """ def __init__( - self, - feature_maps: List[int], - num_classes: int, - fpn_channels=128, - dropout=0.0, - abn_block=ABN, - upsample_add=UpsampleAdd, - prediction_block=FPNSumDecoderBlock, + self, + feature_maps: List[int], + num_classes: int, + fpn_channels=128, + dropout=0.0, + abn_block=ABN, + upsample_add=UpsampleAdd, + prediction_block=FPNSumDecoderBlock, ): super().__init__() diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index ab5126ecb..5bd927a84 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -2,7 +2,7 @@ from typing import List, Tuple import torch -from ..abn import ABN +from ..activated_batch_norm import ABN from ..identity import Identity from .common import SegmentationDecoderModule @@ -115,9 +115,16 @@ class FPNSumDecoder(SegmentationDecoderModule): """ - def __init__(self, feature_maps: List[int], num_classes: int, fpn_channels=256, dropout=0.0, abn_block=ABN, - center_block=FPNSumCenterBlock, - decoder_block=FPNSumDecoderBlock): + def 
__init__( + self, + feature_maps: List[int], + num_classes: int, + fpn_channels=256, + dropout=0.0, + abn_block=ABN, + center_block=FPNSumCenterBlock, + decoder_block=FPNSumDecoderBlock, + ): super().__init__() self.center = center_block( diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index b8137feb2..7fb6dd364 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -4,112 +4,11 @@ import torch.nn.functional as F from torch import nn -from ..abn import ABN +from ..activated_batch_norm import ABN from .common import DecoderModule +from ..unet import UnetCentralBlock, UnetDecoderBlock -__all__ = ["UnetCentralBlock", "UnetDecoderBlock", "UNetDecoder"] - - -class UnetCentralBlock(nn.Module): - def __init__(self, in_dec_filters, out_filters, abn_block=ABN, **kwargs): - super().__init__() - self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) - self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) - self.bn2 = abn_block(out_filters) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - return x - - -class UnetDecoderBlock(nn.Module): - """ - """ - - def __init__( - self, - in_dec_filters, - in_enc_filters, - out_filters, - abn_block=ABN, - pre_dropout_rate=0.0, - post_dropout_rate=0.0, - **kwargs, - ): - super(UnetDecoderBlock, self).__init__() - - self.pre_drop = nn.Dropout(pre_dropout_rate, inplace=True) - - self.conv1 = nn.Conv2d( - in_dec_filters + in_enc_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False, **kwargs - ) - self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False, **kwargs) - self.bn2 = abn_block(out_filters) - - self.post_drop = nn.Dropout(post_dropout_rate, inplace=True) - - def forward(self, x, enc): - lat_size = enc.size()[2:] - x = F.interpolate(x, size=lat_size, mode="bilinear", align_corners=False) - - x = torch.cat([x, enc], 1) - - x = self.pre_drop(x) - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - x = self.post_drop(x) - return x - - -# class UNetDecoder(DecoderModule): -# def __init__( -# self, features, start_features: int, dilation_factors=[1, 1, 1, 1], **kwargs -# ): -# super().__init__() -# decoder_features = start_features -# reversed_features = list(reversed(features)) -# -# output_filters = [decoder_features] -# self.center = UnetCentralBlock(reversed_features[0], decoder_features) -# -# if dilation_factors is None: -# dilation_factors = [1] * len(reversed_features) -# -# blocks = [] -# for block_index, encoder_features in enumerate(reversed_features): -# blocks.append( -# UnetDecoderBlock( -# output_filters[-1], -# encoder_features, -# decoder_features, -# dilation=dilation_factors[block_index], -# ) -# ) -# output_filters.append(decoder_features) -# # print(block_index, decoder_features, encoder_features, decoder_features) -# decoder_features = decoder_features // 2 -# -# self.blocks = nn.ModuleList(blocks) -# self.output_filters = output_filters -# -# def forward(self, features): -# reversed_features = list(reversed(features)) -# decoder_outputs = [self.center(reversed_features[0])] -# -# for block_index, decoder_block, encoder_output in zip( -# range(len(self.blocks)), self.blocks, reversed_features -# ): -# # print(block_index, 
decoder_outputs[-1].size(), encoder_output.size()) -# decoder_outputs.append(decoder_block(decoder_outputs[-1], encoder_output)) -# -# return decoder_outputs +__all__ = ["UNetDecoder"] class UNetDecoder(DecoderModule): diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index a4698c0a1..33068226f 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch import nn -from ..abn import ABN +from ..activated_batch_norm import ABN from .common import DecoderModule __all__ = ["UNetDecoderV2", "UnetCentralBlockV2", "UnetDecoderBlockV2"] diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 91e4bfae2..829ffee45 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -36,7 +36,7 @@ def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): w = conv.weight if in_channels > conv.in_channels: - w = F.pad(w, pad=[0, 0, 0, in_channels-conv.in_channels], mode='circular') + w = F.pad(w, pad=[0, 0, 0, in_channels - conv.in_channels], mode="circular") else: w = w[:, 0:in_channels, ...] diff --git a/pytorch_toolbelt/modules/encoders/mobilenet.py b/pytorch_toolbelt/modules/encoders/mobilenet.py index f94660bf5..769770a61 100644 --- a/pytorch_toolbelt/modules/encoders/mobilenet.py +++ b/pytorch_toolbelt/modules/encoders/mobilenet.py @@ -1,4 +1,4 @@ -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input from ..backbone.mobilenet import MobileNetV2 from ..backbone.mobilenetv3 import MobileNetV3 @@ -23,6 +23,9 @@ def __init__(self, layers=[2, 3, 5, 7], activation="relu6"): def encoder_layers(self): return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4, self.layer5, self.layer6, self.layer7] + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + class MobilenetV3Encoder(EncoderModule): def __init__(self, input_channels=3, small=False, drop_prob=0.0, layers=[1, 2, 3, 4]): @@ -67,3 +70,6 @@ def forward(self, x): @property def encoder_layers(self): return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] + + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py index 3c86a2552..e75f61a60 100644 --- a/pytorch_toolbelt/modules/encoders/resnet.py +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -26,9 +26,7 @@ def __init__(self, resnet, filters, strides, layers=None): layers = [1, 2, 3, 4] super().__init__(filters, strides, layers) - self.layer0 = nn.Sequential(OrderedDict([("conv0", resnet.conv1), - ("bn0", resnet.bn1), - ("act0", resnet.relu)])) + self.layer0 = nn.Sequential(OrderedDict([("conv0", resnet.conv1), ("bn0", resnet.bn1), ("act0", resnet.relu)])) self.maxpool = resnet.maxpool self.layer1 = resnet.layer1 diff --git a/pytorch_toolbelt/modules/encoders/seresnet.py b/pytorch_toolbelt/modules/encoders/seresnet.py index 0de69c3f7..370abfd5b 100644 --- a/pytorch_toolbelt/modules/encoders/seresnet.py +++ b/pytorch_toolbelt/modules/encoders/seresnet.py @@ -2,6 +2,8 @@ Encodes listed here provides easy way to swap backbone of 
classification/segmentation/detection model. """ +from torch import Tensor +from typing import List from .common import EncoderModule, _take, make_n_channel_input @@ -60,7 +62,7 @@ def output_strides(self): def output_filters(self): return self._output_filters - def forward(self, x): + def forward(self, x: Tensor) -> List[Tensor]: input = x output_features = [] for layer in self.encoder_layers: diff --git a/pytorch_toolbelt/modules/encoders/squeezenet.py b/pytorch_toolbelt/modules/encoders/squeezenet.py index 7e5814a4a..ba09a7a7d 100644 --- a/pytorch_toolbelt/modules/encoders/squeezenet.py +++ b/pytorch_toolbelt/modules/encoders/squeezenet.py @@ -1,7 +1,9 @@ +from collections import OrderedDict + from torch import nn from torchvision.models import squeezenet1_1 -from .common import EncoderModule +from .common import EncoderModule, make_n_channel_input __all__ = ["SqueezenetEncoder"] @@ -15,10 +17,13 @@ def __init__(self, pretrained=True, layers=[1, 2, 3]): # nn.ReLU(inplace=True), # nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True), self.layer0 = nn.Sequential( - squeezenet.features[0], - squeezenet.features[1], - # squeezenet.features[2], - nn.MaxPool2d(kernel_size=3, stride=2, padding=1), + OrderedDict( + [ + ("conv1", squeezenet.features[0]), + ("relu1", nn.ReLU(inplace=True)), + ("pool1", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), + ] + ) ) # Fire(64, 16, 64, 64), @@ -52,3 +57,6 @@ def __init__(self, pretrained=True, layers=[1, 2, 3]): @property def encoder_layers(self): return [self.layer0, self.layer1, self.layer2, self.layer3] + + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) diff --git a/pytorch_toolbelt/modules/encoders/unet.py b/pytorch_toolbelt/modules/encoders/unet.py index fd04cf97e..5f810c7ca 100644 --- a/pytorch_toolbelt/modules/encoders/unet.py +++ b/pytorch_toolbelt/modules/encoders/unet.py @@ -1,30 +1,15 @@ -from torch import nn +from .common import EncoderModule, make_n_channel_input +from ..activated_batch_norm import ABN +from ..unet import UnetEncoderBlock - -from ..abn import ABN - -from .common import EncoderModule, _take - -__all__ = ["UnetEncoderBlock", "UnetEncoder"] - - -class UnetEncoderBlock(nn.Module): - def __init__(self, in_dec_filters, out_filters, abn_block=ABN, stride=1, **kwargs): - super().__init__() - self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=1, bias=False, **kwargs) - self.bn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=stride, bias=False, **kwargs) - self.bn2 = abn_block(out_filters) - - def forward(self, x): - x = self.conv1(x) - x = self.bn1(x) - x = self.conv2(x) - x = self.bn2(x) - return x +__all__ = ["UnetEncoder"] class UnetEncoder(EncoderModule): + """ + Vanilla U-Net encoder + """ + def __init__(self, input_channels=3, features=32, num_layers=4, growth_factor=2, abn_block=ABN): feature_maps = [features * growth_factor * (i + 1) for i in range(num_layers)] strides = [2 * (i + 1) for i in range(num_layers)] @@ -41,3 +26,6 @@ def __init__(self, input_channels=3, features=32, num_layers=4, growth_factor=2, @property def encoder_layers(self): return [self[f"layer{layer}"] for layer in range(self.num_layers)] + + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py 
b/pytorch_toolbelt/modules/encoders/wide_resnet.py index d9d611ff3..fa733438e 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -1,9 +1,9 @@ from typing import List -from ..abn import ABN +from ..activated_batch_norm import ABN from ..backbone.wider_resnet import WiderResNet, WiderResNetA2 -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input __all__ = [ "WiderResnetEncoder", @@ -21,7 +21,7 @@ class WiderResnetEncoder(EncoderModule): def __init__(self, structure: List[int], layers: List[int], norm_act=ABN): super().__init__([64, 128, 256, 512, 1024, 2048, 4096], [1, 2, 4, 8, 16, 32, 32], layers) - encoder = WiderResNet(structure, classes=0, norm_act=norm_act) + encoder: WiderResNet = WiderResNet(structure, classes=0, norm_act=norm_act) self.layer0 = encoder.mod1 self.layer1 = encoder.mod2 self.layer2 = encoder.mod3 @@ -67,6 +67,9 @@ def forward(self, input): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + class WiderResnet16Encoder(WiderResnetEncoder): def __init__(self, layers=None): @@ -136,6 +139,9 @@ def forward(self, input): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + class WiderResnet16A2Encoder(WiderResnetA2Encoder): def __init__(self, layers=None): diff --git a/pytorch_toolbelt/modules/pooling.py b/pytorch_toolbelt/modules/pooling.py index 47b7d405c..755c90846 100644 --- a/pytorch_toolbelt/modules/pooling.py +++ b/pytorch_toolbelt/modules/pooling.py @@ -5,7 +5,15 @@ import torch.nn as nn import torch.nn.functional as F -__all__ = ["GlobalAvgPool2d", "GlobalMaxPool2d", "GWAP", "RMSPool", "MILCustomPoolingModule"] +__all__ = [ + "GlobalAvgPool2d", + "GlobalMaxPool2d", + "GlobalWeightedAvgPool2d", + "GWAP", + "RMSPool", + "MILCustomPoolingModule", + "GlobalRankPooling", +] class GlobalAvgPool2d(nn.Module): @@ -34,15 +42,16 @@ def forward(self, x): return x -class GWAP(nn.Module): +class GlobalWeightedAvgPool2d(nn.Module): """ Global Weighted Average Pooling from paper "Global Weighted Average Pooling Bridges Pixel-level Localization and Image-level Classification" """ - def __init__(self, features): + def __init__(self, features: int, flatten=False): super().__init__() self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True) + self.flatten = flatten def fscore(self, x): m = self.conv(x) @@ -57,10 +66,13 @@ def forward(self, x): x = self.fscore(x) x = self.norm(x) x = x * input_x - x = x.sum(dim=[2, 3], keepdim=True) + x = x.sum(dim=[2, 3], keepdim=not self.flatten) return x +GWAP = GlobalWeightedAvgPool2d + + class RMSPool(nn.Module): """ Root mean square pooling @@ -93,3 +105,29 @@ def forward(self, x): loss = self.classifier(x) logits = torch.sum(weight * loss, dim=[2, 3]) / (torch.sum(weight, dim=[2, 3]) + 1e-6) return logits + + +class GlobalRankPooling(nn.Module): + """ + https://arxiv.org/abs/1704.02112 + """ + + def __init__(self, num_features, spatial_size, flatten=False): + super().__init__() + self.conv = nn.Conv1d(num_features, num_features, spatial_size, groups=num_features) + self.flatten = flatten + + def forward(self, x: torch.Tensor): + spatial_size = x.size(2) 
* x.size(3) + assert spatial_size == self.conv.kernel_size[0], ( + f"Expected spatial size {self.conv.kernel_size[0]}, " f"got {x.size(2)}x{x.size(3)}" + ) + + x = x.view(x.size(0), x.size(1), -1) # Flatten spatial dimensions + x_sorted, index = x.topk(spatial_size, dim=2) + + x = self.conv(x_sorted) # [B, C, 1] + + if self.flatten: + x = x.squeeze(2) + return x diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py new file mode 100644 index 000000000..c95141791 --- /dev/null +++ b/pytorch_toolbelt/modules/unet.py @@ -0,0 +1,93 @@ +import torch +from torch import nn +import torch.nn.functional as F + +from .activated_batch_norm import ABN + +__all__ = ["UnetEncoderBlock", "UnetCentralBlock", "UnetDecoderBlock"] + + +class UnetEncoderBlock(nn.Module): + def __init__(self, in_dec_filters, out_filters, abn_block=ABN, **kwargs): + super().__init__() + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) + self.abn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=1, bias=False, **kwargs) + self.abn2 = abn_block(out_filters) + + def forward(self, x): + x = self.conv1(x) + x = self.abn1(x) + x = self.conv2(x) + x = self.abn2(x) + return x + + +class UnetCentralBlock(nn.Module): + def __init__(self, in_dec_filters, out_filters, abn_block=ABN, **kwargs): + super().__init__() + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) + self.abn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) + self.abn2 = abn_block(out_filters) + + def forward(self, x): + x = self.conv1(x) + x = self.abn1(x) + x = self.conv2(x) + x = self.abn2(x) + return x + + +class UnetDecoderBlock(nn.Module): + """ + """ + + def __init__( + self, + in_dec_filters, + in_enc_filters, + out_filters, + abn_block=ABN, + pre_dropout_rate=0.0, + post_dropout_rate=0.0, + scale_factor=None, + scale_mode="nearest", + align_corners=None, + **kwargs, + ): + super(UnetDecoderBlock, self).__init__() + + self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) + + self.conv1 = nn.Conv2d( + in_dec_filters + in_enc_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs + ) + self.abn1 = abn_block(out_filters) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) + self.abn2 = abn_block(out_filters) + + self.post_drop = nn.Dropout2d(post_dropout_rate, inplace=False) + + self.scale_factor = scale_factor + self.scale_mode = scale_mode + self.align_corners = align_corners + + def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: + if self.scale_factor is not None: + x = F.interpolate( + x, scale_factor=self.scale_factor, mode=self.scale_mode, align_corners=self.align_corners + ) + else: + lat_size = enc.size()[2:] + x = F.interpolate(x, size=lat_size, mode=self.scale_mode, align_corners=self.align_corners) + + x = torch.cat([x, enc], 1) + + x = self.pre_drop(x) + x = self.conv1(x) + x = self.abn1(x) + x = self.conv2(x) + x = self.abn2(x) + x = self.post_drop(x) + return x diff --git a/tests/test_activations.py b/tests/test_activations.py new file mode 100644 index 000000000..8c86456e0 --- /dev/null +++ b/tests/test_activations.py @@ -0,0 +1,31 @@ +import torch +import pytest + +from pytorch_toolbelt.modules.activations import get_activation_module + + +@pytest.mark.parametrize( + "activation_name", + 
["none", "relu", "relu6", "leaky_relu", "elu", "selu", "celu", "mish", "swish", "hard_sigmoid", "hard_swish"], +) +def test_activations(activation_name): + act = get_activation_module(activation_name) + x = torch.randn(128).float() + y = act(x) + assert y.dtype == torch.float32 + + +@pytest.mark.parametrize( + "activation_name", + ["none", "relu", "relu6", "leaky_relu", "elu", "selu", "celu", "mish", "swish", "hard_sigmoid", "hard_swish"], +) +def test_activations_cuda(activation_name): + act = get_activation_module(activation_name) + x = torch.randn(128).float().cuda() + y = act(x) + assert y.dtype == torch.float32 + + act = get_activation_module(activation_name) + x = torch.randn(128).half().cuda() + y = act(x) + assert y.dtype == torch.float16 From 5137fd6f55ccccef055761b8a832688c166017b2 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 16 Nov 2019 21:25:38 +0200 Subject: [PATCH 08/79] Fix typo --- pytorch_toolbelt/modules/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index 600be993f..a5f32c5b1 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -138,7 +138,7 @@ def get_activation_module(activation_name: str, **kwargs) -> nn.Module: ACTIVATIONS = { "relu": nn.ReLU, "relu6": nn.ReLU6, - "leaky_rely": nn.LeakyReLU, + "leaky_relu": nn.LeakyReLU, "elu": nn.ELU, "selu": nn.SELU, "celu": nn.CELU, From 682c055e75e2368536c7b6831736395263116c4d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 16 Nov 2019 21:25:54 +0200 Subject: [PATCH 09/79] Do not use pretrained weights in tests --- tests/test_encoders.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 460fed2c6..5017781c2 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -35,10 +35,10 @@ [E.EfficientNetB5Encoder, {}], [E.EfficientNetB6Encoder, {}], [E.EfficientNetB7Encoder, {}], - [E.DenseNet121Encoder, {}], - [E.DenseNet161Encoder, {}], - [E.DenseNet169Encoder, {}], - [E.DenseNet201Encoder, {}], + [E.DenseNet121Encoder, {"pretrained": False}], + [E.DenseNet161Encoder, {"pretrained": False}], + [E.DenseNet169Encoder, {"pretrained": False}], + [E.DenseNet201Encoder, {"pretrained": False}], ], ) @torch.no_grad() @@ -65,7 +65,8 @@ def test_inceptionv4_encoder(): backbone = inceptionv4(pretrained=False) backbone.last_linear = None - net = E.InceptionV4Encoder(backbone, layers=[0, 1, 2, 3, 4]).cuda() + net = E.InceptionV4Encoder(pretrained=False, + layers=[0, 1, 2, 3, 4]).cuda() print(count_parameters(backbone)) print(count_parameters(net)) @@ -82,7 +83,7 @@ def test_inceptionv4_encoder(): def test_densenet(): from torchvision.models import densenet121 - net1 = E.DenseNet121Encoder() + net1 = E.DenseNet121Encoder(pretrained=False) net2 = densenet121(pretrained=False) net2.classifier = None From bd4005002f20edf47a0b2a8ece59b7eb7e6f6d27 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Mon, 18 Nov 2019 16:37:45 +0200 Subject: [PATCH 10/79] Updates for Deeplab decoder --- pytorch_toolbelt/modules/decoders/deeplab.py | 86 ++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 79ae89c81..9b11830b8 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -4,10 +4,96 @@ import torch.nn as nn import 
torch.nn.functional as F from .common import DecoderModule +from ..activated_batch_norm import ABN +from ..encoders import EncoderModule __all__ = ["DeeplabV3Decoder"] +class ASPPModule(nn.Module): + def __init__(self, inplanes, planes, kernel_size, padding, dilation): + super(ASPPModule, self).__init__() + self.atrous_conv = nn.Conv2d( + inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False + ) + self.bn = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + + self.reset_parameters() + + def forward(self, x): + x = self.atrous_conv(x) + x = self.bn(x) + + return self.relu(x) + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +class ASPP(nn.Module): + def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5): + super(ASPP, self).__init__() + + if output_stride == 32: + dilations = [1, 3, 6, 9] + elif output_stride == 16: + dilations = [1, 6, 12, 18] + elif output_stride == 8: + dilations = [1, 12, 24, 36] + else: + raise NotImplementedError + + self.aspp1 = ASPPModule(inplanes, output_features, 1, padding=0, dilation=dilations[0]) + self.aspp2 = ASPPModule(inplanes, output_features, 3, padding=dilations[1], dilation=dilations[1]) + self.aspp3 = ASPPModule(inplanes, output_features, 3, padding=dilations[2], dilation=dilations[2]) + self.aspp4 = ASPPModule(inplanes, output_features, 3, padding=dilations[3], dilation=dilations[3]) + + self.global_avg_pool = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(inplanes, output_features, 1, stride=1, bias=False), + nn.BatchNorm2d(output_features), + nn.ReLU(inplace=True), + ) + self.conv1 = nn.Conv2d(1280, output_features, 1, bias=False) + self.bn1 = nn.BatchNorm2d(output_features) + self.relu = nn.ReLU(inplace=True) + self.dropout = nn.Dropout(dropout) + self.reset_parameters() + + def forward(self, x): + x1 = self.aspp1(x) + x2 = self.aspp2(x) + x3 = self.aspp3(x) + x4 = self.aspp4(x) + x5 = self.global_avg_pool(x) + x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=False) + x = torch.cat((x1, x2, x3, x4, x5), dim=1) + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + return self.dropout(x) + + def reset_parameters(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + + + class DeeplabV3Decoder(DecoderModule): def __init__(self, feature_maps: List[int], num_classes: int, dropout=0.5): super(DeeplabV3Decoder, self).__init__() From ffd0bd2bbb202e72072d6ed5591cf33fbaf8dc80 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 19 Nov 2019 13:59:25 +0200 Subject: [PATCH 11/79] Fix import of SummaryWriter --- pytorch_toolbelt/utils/catalyst/visualization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 3163fa461..436c4e83d 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -7,7 +7,7 @@ import torch from catalyst.dl import Callback, RunnerState, CallbackOrder from catalyst.dl.callbacks import TensorboardLogger -from tensorboardX import SummaryWriter +from catalyst.utils.tensorboard import SummaryWriter from pytorch_toolbelt.utils.torch_utils import rgb_image_from_tensor, to_numpy from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image From d2c95dd620a62b859e276bb2f1640f847bc9a79d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 19 Nov 2019 16:49:43 +0200 Subject: [PATCH 12/79] Add more callbacks --- pytorch_toolbelt/utils/catalyst/__init__.py | 1 + pytorch_toolbelt/utils/catalyst/criterions.py | 75 ++++++++++++- pytorch_toolbelt/utils/catalyst/opl.py | 105 ++++++++++++++++++ .../utils/catalyst/visualization.py | 46 ++++++++ 4 files changed, 225 insertions(+), 2 deletions(-) create mode 100644 pytorch_toolbelt/utils/catalyst/opl.py diff --git a/pytorch_toolbelt/utils/catalyst/__init__.py b/pytorch_toolbelt/utils/catalyst/__init__.py index 65c35e98e..13e769b10 100644 --- a/pytorch_toolbelt/utils/catalyst/__init__.py +++ b/pytorch_toolbelt/utils/catalyst/__init__.py @@ -3,3 +3,4 @@ from .metrics import * from .visualization import * from .criterions import * +from .opl import * \ No newline at end of file diff --git a/pytorch_toolbelt/utils/catalyst/criterions.py b/pytorch_toolbelt/utils/catalyst/criterions.py index e5fd3a83b..f569be7ef 100644 --- a/pytorch_toolbelt/utils/catalyst/criterions.py +++ b/pytorch_toolbelt/utils/catalyst/criterions.py @@ -1,10 +1,12 @@ import math +import torch from catalyst.dl import RunnerState, CriterionCallback from catalyst.dl.callbacks.criterion import _add_loss_to_state from torch import nn +from torch.nn import functional as F -__all__ = ["LPRegularizationCallback"] +__all__ = ["LPRegularizationCallback", "TSACriterionCallback"] class LPRegularizationCallback(CriterionCallback): @@ -91,4 +93,73 @@ def on_batch_end(self, state: RunnerState): lp_reg = param.norm(self.p) * self.multiplier + lp_reg state.metrics.add_batch_value(metrics_dict={self.prefix: lp_reg.item()}) - _add_loss_to_state(state, lp_reg) + _add_loss_to_state(self.prefix, state, lp_reg) + + + +class TSACriterionCallback(CriterionCallback): + """ + Criterion callback with training signal annealing support. 
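+
+    The annealing threshold grows from 1 / num_classes towards 1.0 over training
+    (exponential schedule); samples whose correct-class probability already exceeds
+    the threshold are masked out of the loss, so supervision is released gradually.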
+ + This callback requires that criterion key returns loss per each element in batch + + Reference: + Unsupervised Data Augmentation for Consistency Training + https://arxiv.org/abs/1904.12848 + """ + def __init__(self, num_classes, num_epochs, + input_key: str = "targets", + output_key: str = "logits", + prefix: str = "loss", + criterion_key: str = None, + loss_key: str = None, multiplier: float = 1.0, unsupervised_label=-100): + super().__init__(input_key, output_key, prefix, criterion_key, loss_key, multiplier) + self.num_epochs = num_epochs + self.num_classes = num_classes + self.tsa_threshold = None + self.unsupervised_label = unsupervised_label + + def get_tsa_threshold(self, current_epoch, schedule, start, end) -> float: + training_progress = float(current_epoch) / float(self.num_epochs) + + if schedule == "linear_schedule": + threshold = training_progress + elif schedule == "exp_schedule": + scale = 5 + threshold = math.exp((training_progress - 1) * scale) + # [exp(-5), exp(0)] = [1e-2, 1] + elif schedule == "log_schedule": + scale = 5 + # [1 - exp(0), 1 - exp(-5)] = [0, 0.99] + threshold = 1 - math.exp((-training_progress) * scale) + return threshold * (end - start) + start + + def on_epoch_start(self, state: RunnerState): + if state.loader_name == "train": + self.tsa_threshold = self.get_tsa_threshold(state.epoch, 'exp_schedule', 1. / self.num_classes, 1.0) + state.metrics.epoch_values['train']['tsa_threshold'] = self.tsa_threshold + + def _compute_loss(self, state: RunnerState, criterion): + + logits = state.output[self.output_key] + targets = state.input[self.input_key] + supervised_mask = targets != self.unsupervised_label # Mask indicating labeled samples + + targets = targets[supervised_mask] + logits = logits[supervised_mask] + + if not len(targets): + return torch.tensor(0, dtype=logits.dtype, device=logits.device) + + with torch.no_grad(): + one_hot_targets = F.one_hot(targets, num_classes=self.num_classes).float() + sup_probs = logits.detach().softmax(dim=1) + correct_label_probs = torch.sum(one_hot_targets * sup_probs, dim=1) + larger_than_threshold = correct_label_probs > self.tsa_threshold + loss_mask = 1. - larger_than_threshold.float() + + loss = criterion(logits, targets) + loss = loss * loss_mask + + loss = loss.sum() / loss_mask.sum().clamp_min(1) + return loss \ No newline at end of file diff --git a/pytorch_toolbelt/utils/catalyst/opl.py b/pytorch_toolbelt/utils/catalyst/opl.py new file mode 100644 index 000000000..31ae0d7ea --- /dev/null +++ b/pytorch_toolbelt/utils/catalyst/opl.py @@ -0,0 +1,105 @@ +from catalyst.dl import Callback, CallbackOrder, RunnerState +import numpy as np + +from ..torch_utils import to_numpy + +__all__ = ["MulticlassOnlinePseudolabelingCallback", "BCEOnlinePseudolabelingCallback", "PseudolabelDatasetMixin"] + + +class PseudolabelDatasetMixin: + def set_target(self, index: int, value): + raise NotImplementedError + + +class MulticlassOnlinePseudolabelingCallback(Callback): + """ + Online pseudo-labeling callback for multi-class problem. 
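+
+    After every pass over `pseudolabel_loader`, samples whose top class probability
+    exceeds `prob_threshold` have that class written back into the unlabeled dataset
+    via `set_target`; all other samples keep `unlabeled_class` as their target.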
+ + >>> unlabeled_train = get_test_dataset( + >>> data_dir, image_size=image_size, augmentation=augmentations + >>> ) + >>> unlabeled_eval = get_test_dataset( + >>> data_dir, image_size=image_size + >>> ) + >>> + >>> callbacks += [ + >>> MulticlassOnlinePseudolabelingCallback( + >>> unlabeled_train.targets, + >>> pseudolabel_loader="label", + >>> prob_threshold=0.9) + >>> ] + >>> train_ds = train_ds + unlabeled_train + >>> + >>> loaders = collections.OrderedDict() + >>> loaders["train"] = DataLoader(train_ds) + >>> loaders["valid"] = DataLoader(valid_ds) + >>> loaders["label"] = DataLoader(unlabeled_eval, shuffle=False) # ! shuffle=False is important ! + """ + + def __init__( + self, + unlabeled_ds: PseudolabelDatasetMixin, + pseudolabel_loader="label", + prob_threshold=0.9, + prob_ratio=None, + output_key="logits", + unlabeled_class=-100, + ): + super().__init__(CallbackOrder.Other) + self.unlabeled_ds = unlabeled_ds + self.pseudolabel_loader = pseudolabel_loader + self.prob_threshold = prob_threshold + self.prob_ratio = prob_ratio + self.predictions = [] + self.output_key = output_key + self.unlabeled_class = unlabeled_class + + def on_epoch_start(self, state: RunnerState): + pass + + def on_loader_start(self, state: RunnerState): + if state.loader_name == self.pseudolabel_loader: + self.predictions = [] + + def get_probabilities(self, state: RunnerState): + probs = state.output[self.output_key].detach().softmax(dim=1) + return to_numpy(probs) + + def on_batch_end(self, state: RunnerState): + if state.loader_name == self.pseudolabel_loader: + probs = self.get_probabilities(state) + self.predictions.extend(probs) + + def on_loader_end(self, state: RunnerState): + if state.loader_name == self.pseudolabel_loader: + predictions = np.array(self.predictions) + max_pred = np.argmax(predictions, axis=1) + max_score = np.amax(predictions, axis=1) + confident_mask = max_score > self.prob_threshold + num_samples = len(predictions) + + for index, predicted_target, score in zip(range(num_samples, max_pred, max_score)): + target = predicted_target if score > self.prob_threshold else self.unlabeled_class + self.unlabeled_ds.set_target(index, target) + + num_confident_samples = confident_mask.sum() + state.metrics.epoch_values[state.loader_name]["pseudolabeling/confident_samples"] = num_confident_samples + state.metrics.epoch_values[state.loader_name]["pseudolabeling/confident_samples_mean_score"] = max_score[ + confident_mask + ].mean() + + state.metrics.epoch_values[state.loader_name]["pseudolabeling/unconfident_samples"] = ( + len(predictions) - num_confident_samples + ) + state.metrics.epoch_values[state.loader_name]["pseudolabeling/unconfident_samples_mean_score"] = max_score[ + ~confident_mask + ].mean() + + def on_epoch_end(self, state: RunnerState): + pass + + +class BCEOnlinePseudolabelingCallback(MulticlassOnlinePseudolabelingCallback): + def get_probabilities(self, state: RunnerState): + probs = state.output[self.output_key].detach().sigmoid() + return to_numpy(probs) diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 436c4e83d..9a46a2da5 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -15,6 +15,7 @@ __all__ = [ "get_tensorboard_logger", "ShowPolarBatchesCallback", + "ShowEmbeddingsCallback", "draw_binary_segmentation_predictions", "draw_semantic_segmentation_predictions", ] @@ -134,6 +135,51 @@ def _log_samples(self, samples, name, logger, step): 
plt.show() + +class ShowEmbeddingsCallback(Callback): + def __init__(self, embedding_key, input_key, targets_key, prefix='embedding', + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225) + ): + super().__init__(CallbackOrder.Other) + self.prefix = prefix + self.embedding_key = embedding_key + self.input_key = input_key + self.targets_key = targets_key + self.mean = torch.tensor(mean).view((1, 3, 1, 1)) + self.std = torch.tensor(std).view((1, 3, 1, 1)) + + self.embeddings = [] + self.images = [] + self.targets = [] + + def on_loader_start(self, state: RunnerState): + self.embeddings = [] + self.images = [] + self.targets = [] + + def on_loader_end(self, state: RunnerState): + logger = get_tensorboard_logger(state) + logger.add_embedding(mat=torch.cat(self.embeddings, dim=0), + metadata=self.targets, + label_img=torch.cat(self.images, dim=0), + global_step=state.epoch, + tag=self.prefix + ) + + def on_batch_end(self, state: RunnerState): + embedding = state.output[self.embedding_key].detach().cpu() + image = state.input[self.input_key].detach().cpu() + targets = state.input[self.targets_key].detach().cpu().tolist() + + image = F.interpolate(image, size=(256, 256)) + image = image * self.std + self.mean + + self.images.append(image) + self.embeddings.append(embedding) + self.targets.extend(targets) + + def draw_binary_segmentation_predictions( input: dict, output: dict, From 49f114a67e4c67cb3bf1380ae8dcef33cf5e861e Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 20 Nov 2019 21:22:18 +0200 Subject: [PATCH 13/79] Soft CE & BCE losses --- pytorch_toolbelt/losses/__init__.py | 2 ++ pytorch_toolbelt/losses/soft_bce.py | 30 +++++++++++++++++++++++++++++ pytorch_toolbelt/losses/soft_ce.py | 24 +++++++++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 pytorch_toolbelt/losses/soft_bce.py create mode 100644 pytorch_toolbelt/losses/soft_ce.py diff --git a/pytorch_toolbelt/losses/__init__.py b/pytorch_toolbelt/losses/__init__.py index 8335d8574..35b3b33b6 100644 --- a/pytorch_toolbelt/losses/__init__.py +++ b/pytorch_toolbelt/losses/__init__.py @@ -6,3 +6,5 @@ from .lovasz import * from .joint_loss import * from .wing_loss import * +from .soft_bce import * +from .soft_ce import * \ No newline at end of file diff --git a/pytorch_toolbelt/losses/soft_bce.py b/pytorch_toolbelt/losses/soft_bce.py new file mode 100644 index 000000000..75d65d7d5 --- /dev/null +++ b/pytorch_toolbelt/losses/soft_bce.py @@ -0,0 +1,30 @@ +import torch +from torch import nn +import torch.nn.functional as F + +__all__ = ["SoftBCELoss"] + + +class SoftBCELoss(nn.Module): + def __init__(self, smooth_factor=1e-4, ignore_index=None, reduction="mean"): + super().__init__() + self.smooth_factor = smooth_factor + self.ignore_index = ignore_index + self.reduction = reduction + + def forward(self, label_input, label_target): + not_ignored_mask = label_target != self.ignore_index + + label_target = (1 - label_target) * self.smooth_factor + label_target * (1 - self.smooth_factor) + + loss = F.binary_cross_entropy_with_logits(label_input, label_target, reduction="none") + + loss = loss * not_ignored_mask.float() + + if self.reduction == "mean": + loss = loss.mean() + + if self.reduction == "sum": + loss = loss.sum() + + return loss diff --git a/pytorch_toolbelt/losses/soft_ce.py b/pytorch_toolbelt/losses/soft_ce.py new file mode 100644 index 000000000..b7711fd7b --- /dev/null +++ b/pytorch_toolbelt/losses/soft_ce.py @@ -0,0 +1,24 @@ +import torch +import torch.nn.functional as F +from torch import nn + 
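+
+# Multi-class cross entropy with label smoothing; positions where the target equals
+# `ignore_index` contribute zero to the averaged loss.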
+__all__ = ["SoftCrossEntropyLoss"] + + +class SoftCrossEntropyLoss(nn.Module): + def __init__(self, smooth_factor=1e-4, ignore_index=None): + super().__init__() + self.smooth_factor = smooth_factor + self.ignore_index = ignore_index + + def forward(self, label_input, label_target): + not_ignored = label_target != self.ignore_index + + num_classes = label_input.size(1) + one_hot_target = F.one_hot(label_target.masked_fill(~not_ignored, 0), num_classes).float() + one_hot_target = one_hot_target * (1 - self.smooth_factor) + (1 - one_hot_target) * self.smooth_factor / ( + num_classes - 1 + ) + log_prb = F.log_softmax(label_input, dim=1) + loss = -(one_hot_target * log_prb).sum(dim=1) + return torch.mean(loss * not_ignored.float()) From b94fad7cf6ae68bf2a5a831460c193f0ac269212 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 21 Nov 2019 13:48:21 +0200 Subject: [PATCH 14/79] Fix SoftBCELoss --- pytorch_toolbelt/losses/soft_bce.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pytorch_toolbelt/losses/soft_bce.py b/pytorch_toolbelt/losses/soft_bce.py index 75d65d7d5..d815d6033 100644 --- a/pytorch_toolbelt/losses/soft_bce.py +++ b/pytorch_toolbelt/losses/soft_bce.py @@ -6,20 +6,23 @@ class SoftBCELoss(nn.Module): - def __init__(self, smooth_factor=1e-4, ignore_index=None, reduction="mean"): + def __init__(self, smooth_factor=None, ignore_index=None, reduction="mean"): super().__init__() self.smooth_factor = smooth_factor self.ignore_index = ignore_index self.reduction = reduction def forward(self, label_input, label_target): - not_ignored_mask = label_target != self.ignore_index + if self.ignore_index is not None: + not_ignored_mask = (label_target != self.ignore_index).float() - label_target = (1 - label_target) * self.smooth_factor + label_target * (1 - self.smooth_factor) + if self.smooth_factor is not None: + label_target = (1 - label_target) * self.smooth_factor + label_target * (1 - self.smooth_factor) loss = F.binary_cross_entropy_with_logits(label_input, label_target, reduction="none") - loss = loss * not_ignored_mask.float() + if self.ignore_index is not None: + loss = loss * not_ignored_mask.float() if self.reduction == "mean": loss = loss.mean() From a9bb29257f75ec36162f41893813760c0e3a6d2f Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 21 Nov 2019 14:49:55 +0200 Subject: [PATCH 15/79] Code formatting --- pytorch_toolbelt/losses/__init__.py | 2 +- pytorch_toolbelt/modules/decoders/deeplab.py | 2 -- pytorch_toolbelt/utils/catalyst/__init__.py | 2 +- pytorch_toolbelt/utils/catalyst/criterions.py | 28 +++++++++++-------- .../utils/catalyst/visualization.py | 27 ++++++++++-------- tests/test_encoders.py | 4 +-- 6 files changed, 36 insertions(+), 29 deletions(-) diff --git a/pytorch_toolbelt/losses/__init__.py b/pytorch_toolbelt/losses/__init__.py index 35b3b33b6..47a28d4bc 100644 --- a/pytorch_toolbelt/losses/__init__.py +++ b/pytorch_toolbelt/losses/__init__.py @@ -7,4 +7,4 @@ from .joint_loss import * from .wing_loss import * from .soft_bce import * -from .soft_ce import * \ No newline at end of file +from .soft_ce import * diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 9b11830b8..509049a29 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -92,8 +92,6 @@ def reset_parameters(self): m.bias.data.zero_() - - class DeeplabV3Decoder(DecoderModule): def __init__(self, feature_maps: List[int], num_classes: int, 
dropout=0.5): super(DeeplabV3Decoder, self).__init__() diff --git a/pytorch_toolbelt/utils/catalyst/__init__.py b/pytorch_toolbelt/utils/catalyst/__init__.py index 13e769b10..7ac9ec9af 100644 --- a/pytorch_toolbelt/utils/catalyst/__init__.py +++ b/pytorch_toolbelt/utils/catalyst/__init__.py @@ -3,4 +3,4 @@ from .metrics import * from .visualization import * from .criterions import * -from .opl import * \ No newline at end of file +from .opl import * diff --git a/pytorch_toolbelt/utils/catalyst/criterions.py b/pytorch_toolbelt/utils/catalyst/criterions.py index f569be7ef..8867c53b4 100644 --- a/pytorch_toolbelt/utils/catalyst/criterions.py +++ b/pytorch_toolbelt/utils/catalyst/criterions.py @@ -96,7 +96,6 @@ def on_batch_end(self, state: RunnerState): _add_loss_to_state(self.prefix, state, lp_reg) - class TSACriterionCallback(CriterionCallback): """ Criterion callback with training signal annealing support. @@ -107,12 +106,19 @@ class TSACriterionCallback(CriterionCallback): Unsupervised Data Augmentation for Consistency Training https://arxiv.org/abs/1904.12848 """ - def __init__(self, num_classes, num_epochs, - input_key: str = "targets", - output_key: str = "logits", - prefix: str = "loss", - criterion_key: str = None, - loss_key: str = None, multiplier: float = 1.0, unsupervised_label=-100): + + def __init__( + self, + num_classes, + num_epochs, + input_key: str = "targets", + output_key: str = "logits", + prefix: str = "loss", + criterion_key: str = None, + loss_key: str = None, + multiplier: float = 1.0, + unsupervised_label=-100, + ): super().__init__(input_key, output_key, prefix, criterion_key, loss_key, multiplier) self.num_epochs = num_epochs self.num_classes = num_classes @@ -136,8 +142,8 @@ def get_tsa_threshold(self, current_epoch, schedule, start, end) -> float: def on_epoch_start(self, state: RunnerState): if state.loader_name == "train": - self.tsa_threshold = self.get_tsa_threshold(state.epoch, 'exp_schedule', 1. / self.num_classes, 1.0) - state.metrics.epoch_values['train']['tsa_threshold'] = self.tsa_threshold + self.tsa_threshold = self.get_tsa_threshold(state.epoch, "exp_schedule", 1.0 / self.num_classes, 1.0) + state.metrics.epoch_values["train"]["tsa_threshold"] = self.tsa_threshold def _compute_loss(self, state: RunnerState, criterion): @@ -156,10 +162,10 @@ def _compute_loss(self, state: RunnerState, criterion): sup_probs = logits.detach().softmax(dim=1) correct_label_probs = torch.sum(one_hot_targets * sup_probs, dim=1) larger_than_threshold = correct_label_probs > self.tsa_threshold - loss_mask = 1. 
- larger_than_threshold.float() + loss_mask = 1.0 - larger_than_threshold.float() loss = criterion(logits, targets) loss = loss * loss_mask loss = loss.sum() / loss_mask.sum().clamp_min(1) - return loss \ No newline at end of file + return loss diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 9a46a2da5..c5d87e039 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -135,12 +135,16 @@ def _log_samples(self, samples, name, logger, step): plt.show() - class ShowEmbeddingsCallback(Callback): - def __init__(self, embedding_key, input_key, targets_key, prefix='embedding', - mean=(0.485, 0.456, 0.406), - std=(0.229, 0.224, 0.225) - ): + def __init__( + self, + embedding_key, + input_key, + targets_key, + prefix="embedding", + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + ): super().__init__(CallbackOrder.Other) self.prefix = prefix self.embedding_key = embedding_key @@ -160,12 +164,13 @@ def on_loader_start(self, state: RunnerState): def on_loader_end(self, state: RunnerState): logger = get_tensorboard_logger(state) - logger.add_embedding(mat=torch.cat(self.embeddings, dim=0), - metadata=self.targets, - label_img=torch.cat(self.images, dim=0), - global_step=state.epoch, - tag=self.prefix - ) + logger.add_embedding( + mat=torch.cat(self.embeddings, dim=0), + metadata=self.targets, + label_img=torch.cat(self.images, dim=0), + global_step=state.epoch, + tag=self.prefix, + ) def on_batch_end(self, state: RunnerState): embedding = state.output[self.embedding_key].detach().cpu() diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 5017781c2..8e32213a5 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -65,8 +65,7 @@ def test_inceptionv4_encoder(): backbone = inceptionv4(pretrained=False) backbone.last_linear = None - net = E.InceptionV4Encoder(pretrained=False, - layers=[0, 1, 2, 3, 4]).cuda() + net = E.InceptionV4Encoder(pretrained=False, layers=[0, 1, 2, 3, 4]).cuda() print(count_parameters(backbone)) print(count_parameters(net)) @@ -88,4 +87,3 @@ def test_densenet(): net2.classifier = None print(count_parameters(net1), count_parameters(net2)) - From bacc9d5c05511bffa4e825f363af267f55bbbd0f Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 21 Nov 2019 14:58:27 +0200 Subject: [PATCH 16/79] Code formatting --- .../modules/activated_batch_norm.py | 17 ++++++++++++++++- .../modules/activated_group_norm.py | 17 ++++++++++++++++- pytorch_toolbelt/modules/activations.py | 2 +- pytorch_toolbelt/modules/decoders/upernet.py | 4 +++- tests/test_activations.py | 3 +++ 5 files changed, 39 insertions(+), 4 deletions(-) diff --git a/pytorch_toolbelt/modules/activated_batch_norm.py b/pytorch_toolbelt/modules/activated_batch_norm.py index 9897fb680..d6c16f8f9 100644 --- a/pytorch_toolbelt/modules/activated_batch_norm.py +++ b/pytorch_toolbelt/modules/activated_batch_norm.py @@ -3,7 +3,22 @@ import torch.nn.functional as F from torch.nn.parameter import Parameter from torch.nn import init -from .activations import * +from .activations import ( + ACT_LEAKY_RELU, + ACT_HARD_SWISH, + ACT_MISH, + ACT_SWISH, + ACT_SELU, + ACT_ELU, + ACT_RELU6, + ACT_RELU, + ACT_HARD_SIGMOID, + ACT_NONE, + hard_sigmoid, + hard_swish, + mish, + swish, +) __all__ = ["ABN"] diff --git a/pytorch_toolbelt/modules/activated_group_norm.py b/pytorch_toolbelt/modules/activated_group_norm.py index c0dbed846..6604b8d1d 100644 --- 
a/pytorch_toolbelt/modules/activated_group_norm.py +++ b/pytorch_toolbelt/modules/activated_group_norm.py @@ -2,7 +2,22 @@ import torch.nn as nn import torch.nn.functional as F -from .activations import * +from .activations import ( + ACT_LEAKY_RELU, + ACT_HARD_SWISH, + ACT_MISH, + ACT_SWISH, + ACT_SELU, + ACT_ELU, + ACT_RELU6, + ACT_RELU, + ACT_HARD_SIGMOID, + ACT_NONE, + hard_sigmoid, + hard_swish, + mish, + swish, +) __all__ = ["AGN"] diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index a5f32c5b1..530b2926a 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -62,7 +62,7 @@ def backward(ctx, grad_output): def mish(input): """ - Applies the mish function element-wise: + Apply the mish function element-wise: mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) See additional documentation for mish class. Credit: https://github.com/digantamisra98/Mish diff --git a/pytorch_toolbelt/modules/decoders/upernet.py b/pytorch_toolbelt/modules/decoders/upernet.py index 9506e3fbc..200e8027f 100644 --- a/pytorch_toolbelt/modules/decoders/upernet.py +++ b/pytorch_toolbelt/modules/decoders/upernet.py @@ -7,7 +7,9 @@ def conv3x3_bn_relu(in_planes, out_planes, stride=1): - "3x3 convolution + BN + relu" + """ + 3x3 convolution + BN + relu + """ return nn.Sequential( nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False), nn.BatchNorm2d(out_planes), diff --git a/tests/test_activations.py b/tests/test_activations.py index 8c86456e0..fafde46a2 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -3,6 +3,8 @@ from pytorch_toolbelt.modules.activations import get_activation_module +skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") + @pytest.mark.parametrize( "activation_name", @@ -19,6 +21,7 @@ def test_activations(activation_name): "activation_name", ["none", "relu", "relu6", "leaky_relu", "elu", "selu", "celu", "mish", "swish", "hard_sigmoid", "hard_swish"], ) +@skip_if_no_cuda def test_activations_cuda(activation_name): act = get_activation_module(activation_name) x = torch.randn(128).float().cuda() From 36bfd82cbef8144da3072df1f45412317c13d541 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 21 Nov 2019 15:04:36 +0200 Subject: [PATCH 17/79] Fix missing import --- pytorch_toolbelt/utils/catalyst/visualization.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index c5d87e039..26c6afced 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -5,6 +5,8 @@ import matplotlib.pyplot as plt import numpy as np import torch +import torch.nn.functional as F + from catalyst.dl import Callback, RunnerState, CallbackOrder from catalyst.dl.callbacks import TensorboardLogger from catalyst.utils.tensorboard import SummaryWriter From 6777415fb9fdee55b378c731d525ad30ab05eb07 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 21 Nov 2019 16:47:41 +0200 Subject: [PATCH 18/79] Use inplace accumulation --- pytorch_toolbelt/inference/tta.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/inference/tta.py b/pytorch_toolbelt/inference/tta.py index 4e8380172..84a3c4b7d 100644 --- a/pytorch_toolbelt/inference/tta.py +++ b/pytorch_toolbelt/inference/tta.py @@ -188,10 +188,11 @@ def d4_image2mask(model: 
nn.Module, image: Tensor) -> Tensor: output = model(image) for aug, deaug in zip( - [F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90] + [F.torch_rot90, F.torch_rot180, F.torch_rot270], + [F.torch_rot270, F.torch_rot180, F.torch_rot90] ): x = deaug(model(aug(image))) - output = output + x + output += x image = F.torch_transpose(image) @@ -200,7 +201,7 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor: [F.torch_none, F.torch_rot270, F.torch_rot180, F.torch_rot90], ): x = deaug(model(aug(image))) - output = output + F.torch_transpose(x) + output += F.torch_transpose(x) one_over_8 = float(1.0 / 8.0) return output * one_over_8 From b37146842724faf5ff39743f4a7233400ba11ec3 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 22 Nov 2019 13:31:39 +0200 Subject: [PATCH 19/79] Fix UNet implementation --- pytorch_toolbelt/losses/soft_bce.py | 33 ++++++++++++++--- pytorch_toolbelt/modules/decoders/unet.py | 24 ++++++------- pytorch_toolbelt/modules/decoders/unet_v2.py | 37 +++++++++++++------- pytorch_toolbelt/modules/unet.py | 11 +++--- 4 files changed, 71 insertions(+), 34 deletions(-) diff --git a/pytorch_toolbelt/losses/soft_bce.py b/pytorch_toolbelt/losses/soft_bce.py index d815d6033..a1b9650bd 100644 --- a/pytorch_toolbelt/losses/soft_bce.py +++ b/pytorch_toolbelt/losses/soft_bce.py @@ -1,14 +1,38 @@ import torch from torch import nn import torch.nn.functional as F +from typing import Optional -__all__ = ["SoftBCELoss"] +__all__ = ["BCELoss", "SoftBCELoss"] + + +class BCELoss(nn.Module): + def __init__(self, ignore_index: Optional[int] = -100, reduction="mean"): + super().__init__() + self.ignore_index = ignore_index + self.reduction = reduction + + def forward(self, label_input, label_target): + if self.ignore_index is not None: + not_ignored_mask = (label_target != self.ignore_index).float() + + loss = F.binary_cross_entropy_with_logits(label_input, label_target, reduction="none") + if self.ignore_index is not None: + loss = loss * not_ignored_mask.float() + + if self.reduction == "mean": + loss = loss.mean() + + if self.reduction == "sum": + loss = loss.sum() + + return loss class SoftBCELoss(nn.Module): - def __init__(self, smooth_factor=None, ignore_index=None, reduction="mean"): + def __init__(self, smooth_factor=0, ignore_index: Optional[int] = -100, reduction="mean"): super().__init__() - self.smooth_factor = smooth_factor + self.smooth_factor = float(smooth_factor) self.ignore_index = ignore_index self.reduction = reduction @@ -16,8 +40,7 @@ def forward(self, label_input, label_target): if self.ignore_index is not None: not_ignored_mask = (label_target != self.ignore_index).float() - if self.smooth_factor is not None: - label_target = (1 - label_target) * self.smooth_factor + label_target * (1 - self.smooth_factor) + label_target = (1 - label_target) * self.smooth_factor + label_target * (1 - self.smooth_factor) loss = F.binary_cross_entropy_with_logits(label_input, label_target, reduction="none") diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 7fb6dd364..ab7c2b746 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -18,30 +18,28 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels if not isinstance(decoder_features, list): decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] + self.center = UnetCentralBlock(in_dec_filters=feature_maps[-1], 
out_filters=decoder_features[-1]) + blocks = [] for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlock( - decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels + in_dec_filters=decoder_features[block_index + 1], + in_enc_filters=in_enc_features, + out_filters=decoder_features[block_index], ) ) - self.center = UnetCentralBlock(feature_maps[-1], decoder_features[-1], mask_channels) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features - def forward(self, feature_maps): + self.final = nn.Conv2d(decoder_features[0], mask_channels, kernel_size=1) - output, dsv = self.center(feature_maps[-1]) - decoder_outputs = [output] - dsv_list = [dsv] + def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor: + output = self.center(feature_maps[-1]) for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): - output, dsv = decoder_block(output, encoder_output) - decoder_outputs.append(output) - dsv_list.append(dsv) - - dsv_list = list(reversed(dsv_list)) - decoder_outputs = list(reversed(decoder_outputs)) + output = decoder_block(output, encoder_output) - return decoder_outputs, dsv_list + output = self.final(output) + return output diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index 33068226f..e3bccb8a2 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Tuple, List import torch import torch.nn.functional as F @@ -47,9 +47,16 @@ def __init__( abn_block=ABN, pre_dropout_rate=0.0, post_dropout_rate=0.0, + scale_factor=None, + scale_mode="nearest", + align_corners=None, ): super(UnetDecoderBlockV2, self).__init__() + self.scale_factor = scale_factor + self.scale_mode = scale_mode + self.align_corners = align_corners + self.bottleneck = nn.Conv2d(in_dec_filters + in_enc_filters, out_filters, kernel_size=1) self.conv1 = nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=1, padding=1, bias=False) @@ -63,11 +70,17 @@ def __init__( self.dsv = nn.Conv2d(out_filters, mask_channels, kernel_size=1) - def forward(self, x, enc): - lat_size = enc.size()[2:] - x = F.interpolate(x, size=lat_size, mode="bilinear", align_corners=True) + def forward(self, x: torch.Tensor, enc: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: + if self.scale_factor is not None: + x = F.interpolate( + x, scale_factor=self.scale_factor, mode=self.scale_mode, align_corners=self.align_corners + ) + else: + lat_size = enc.size()[2:] + x = F.interpolate(x, size=lat_size, mode=self.scale_mode, align_corners=self.align_corners) x = torch.cat([x, enc], 1) + x = self.bottleneck(x) x = self.pre_drop(x) @@ -84,36 +97,36 @@ def forward(self, x, enc): class UNetDecoderV2(DecoderModule): - def __init__(self, features: List[int], decoder_features: int, mask_channels: int): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int): super().__init__() if not isinstance(decoder_features, list): - decoder_features = [decoder_features * (2 ** i) for i in range(len(features))] + decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] blocks = [] - for block_index, in_enc_features in enumerate(features[:-1]): + for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlockV2( decoder_features[block_index + 1], in_enc_features, 
decoder_features[block_index], mask_channels ) ) - self.center = UnetCentralBlockV2(features[-1], decoder_features[-1], mask_channels) + self.center = UnetCentralBlockV2(feature_maps[-1], decoder_features[-1], mask_channels) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features + self.final = nn.Conv2d(decoder_features[0], mask_channels, kernel_size=1) + def forward(self, feature_maps): output, dsv = self.center(feature_maps[-1]) - decoder_outputs = [output] dsv_list = [dsv] for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): output, dsv = decoder_block(output, encoder_output) - decoder_outputs.append(output) dsv_list.append(dsv) dsv_list = list(reversed(dsv_list)) - decoder_outputs = list(reversed(decoder_outputs)) - return decoder_outputs, dsv_list + output = self.final(output) + return output, dsv_list diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py index c95141791..728f9b3cf 100644 --- a/pytorch_toolbelt/modules/unet.py +++ b/pytorch_toolbelt/modules/unet.py @@ -58,6 +58,10 @@ def __init__( ): super(UnetDecoderBlock, self).__init__() + self.scale_factor = scale_factor + self.scale_mode = scale_mode + self.align_corners = align_corners + self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) self.conv1 = nn.Conv2d( @@ -69,10 +73,6 @@ def __init__( self.post_drop = nn.Dropout2d(post_dropout_rate, inplace=False) - self.scale_factor = scale_factor - self.scale_mode = scale_mode - self.align_corners = align_corners - def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: if self.scale_factor is not None: x = F.interpolate( @@ -85,9 +85,12 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: x = torch.cat([x, enc], 1) x = self.pre_drop(x) + x = self.conv1(x) x = self.abn1(x) + x = self.conv2(x) x = self.abn2(x) + x = self.post_drop(x) return x From 25e8a0216c0f0ac7663f0850b2fcd121cea70904 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 22 Nov 2019 18:13:26 +0200 Subject: [PATCH 20/79] Fix FPN sum --- pytorch_toolbelt/modules/decoders/fpn_sum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index 5bd927a84..3c6203452 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -94,7 +94,7 @@ def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Tuple[Tensor, Tenso :param encoder_fm: :return: """ - decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=True) + decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=False) encoder_fm = self.skip(encoder_fm) x = decoder_fm + encoder_fm @@ -140,7 +140,7 @@ def __init__( ] ) - self.final_block = nn.Sequential(nn.Conv2d(fpn_channels, num_classes, kernel_size=1)) + self.final_block = nn.Conv2d(fpn_channels, num_classes, kernel_size=1) def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: last_feature_map = feature_maps[-1] From 8eb43cc6b94eba67389ed75b3406a8dc83b28085 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 22 Nov 2019 22:04:44 +0200 Subject: [PATCH 21/79] DeeplabV3+ decoder --- pytorch_toolbelt/modules/decoders/deeplab.py | 95 +++++++------------- 1 file changed, 31 insertions(+), 64 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 
509049a29..4eac99c48 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -5,39 +5,26 @@ import torch.nn.functional as F from .common import DecoderModule from ..activated_batch_norm import ABN -from ..encoders import EncoderModule __all__ = ["DeeplabV3Decoder"] class ASPPModule(nn.Module): - def __init__(self, inplanes, planes, kernel_size, padding, dilation): + def __init__(self, inplanes, planes, kernel_size, padding, dilation, abn_block=ABN): super(ASPPModule, self).__init__() self.atrous_conv = nn.Conv2d( inplanes, planes, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, bias=False ) - self.bn = nn.BatchNorm2d(planes) - self.relu = nn.ReLU(inplace=True) - - self.reset_parameters() + self.abn = abn_block(planes) def forward(self, x): x = self.atrous_conv(x) - x = self.bn(x) - - return self.relu(x) - - def reset_parameters(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - torch.nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() + x = self.abn(x) + return x class ASPP(nn.Module): - def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5): + def __init__(self, inplanes: int, output_stride: int, output_features: int, dropout=0.5, abn_block=ABN): super(ASPP, self).__init__() if output_stride == 32: @@ -57,14 +44,11 @@ def __init__(self, inplanes: int, output_stride: int, output_features: int, drop self.global_avg_pool = nn.Sequential( nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(inplanes, output_features, 1, stride=1, bias=False), - nn.BatchNorm2d(output_features), - nn.ReLU(inplace=True), + abn_block(output_features), ) - self.conv1 = nn.Conv2d(1280, output_features, 1, bias=False) - self.bn1 = nn.BatchNorm2d(output_features) - self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(output_features * 5, output_features, 1, bias=False) + self.abn1 = abn_block(output_features) self.dropout = nn.Dropout(dropout) - self.reset_parameters() def forward(self, x): x1 = self.aspp1(x) @@ -76,66 +60,49 @@ def forward(self, x): x = torch.cat((x1, x2, x3, x4, x5), dim=1) x = self.conv1(x) - x = self.bn1(x) - x = self.relu(x) + x = self.abn1(x) return self.dropout(x) - def reset_parameters(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels - # m.weight.data.normal_(0, math.sqrt(2. 
/ n)) - torch.nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() - class DeeplabV3Decoder(DecoderModule): - def __init__(self, feature_maps: List[int], num_classes: int, dropout=0.5): + def __init__(self, + feature_maps: List[int], + num_classes: int, + output_stride=32, + high_level_bottleneck=256, + low_level_bottleneck=32, + dropout=0.5, + abn_block=ABN): super(DeeplabV3Decoder, self).__init__() - low_level_features = feature_maps[0] - high_level_features = feature_maps[-1] + self.aspp = ASPP(feature_maps[-1], output_stride, high_level_bottleneck, dropout=dropout, abn_block=abn_block) - self.conv1 = nn.Conv2d(low_level_features, 48, 1, bias=False) - self.bn1 = nn.BatchNorm2d(48) - self.relu = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d(feature_maps[0], low_level_bottleneck, 1, bias=False) + self.abn1 = abn_block(48) self.last_conv = nn.Sequential( - nn.Conv2d(high_level_features + 48, 256, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), + nn.Conv2d(high_level_bottleneck + low_level_bottleneck, high_level_bottleneck, kernel_size=3, stride=1, padding=1, bias=False), + abn_block(high_level_bottleneck), nn.Dropout(dropout), - nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), + nn.Conv2d(high_level_bottleneck, high_level_bottleneck, kernel_size=3, stride=1, padding=1, bias=False), + abn_block(high_level_bottleneck), nn.Dropout(dropout * 0.2), # 5 times smaller dropout rate - nn.Conv2d(256, num_classes, kernel_size=1, stride=1), + nn.Conv2d(high_level_bottleneck, num_classes, kernel_size=1, stride=1), ) - self.reset_parameters() def forward(self, feature_maps): - high_level_features = feature_maps[-1] low_level_feat = feature_maps[0] - low_level_feat = self.conv1(low_level_feat) - low_level_feat = self.bn1(low_level_feat) - low_level_feat = self.relu(low_level_feat) + low_level_feat = self.abn1(low_level_feat) + + high_level_features = feature_maps[-1] + high_level_features = self.aspp(high_level_features) high_level_features = F.interpolate( - high_level_features, size=low_level_feat.size()[2:], mode="bilinear", align_corners=True + high_level_features, size=low_level_feat.size()[2:], mode="bilinear", align_corners=False ) - high_level_features = torch.cat((high_level_features, low_level_feat), dim=1) + high_level_features = torch.cat([high_level_features, low_level_feat], dim=1) high_level_features = self.last_conv(high_level_features) return high_level_features - - def reset_parameters(self): - for m in self.modules(): - if isinstance(m, nn.Conv2d): - torch.nn.init.kaiming_normal_(m.weight) - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1) - m.bias.data.zero_() From 21ae4d97c83a0a0dac4d9c85fa6070a6058d4d2b Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 22 Nov 2019 22:09:35 +0200 Subject: [PATCH 22/79] DeeplabV3+ decoder --- pytorch_toolbelt/modules/decoders/deeplab.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 4eac99c48..6765495f4 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -82,15 +82,17 @@ def __init__(self, self.abn1 = abn_block(48) self.last_conv = nn.Sequential( - nn.Conv2d(high_level_bottleneck + low_level_bottleneck, high_level_bottleneck, kernel_size=3, stride=1, padding=1, 
bias=False), + nn.Conv2d(high_level_bottleneck + low_level_bottleneck, high_level_bottleneck, kernel_size=3, padding=1, bias=False), abn_block(high_level_bottleneck), nn.Dropout(dropout), - nn.Conv2d(high_level_bottleneck, high_level_bottleneck, kernel_size=3, stride=1, padding=1, bias=False), + nn.Conv2d(high_level_bottleneck, high_level_bottleneck, kernel_size=3, padding=1, bias=False), abn_block(high_level_bottleneck), nn.Dropout(dropout * 0.2), # 5 times smaller dropout rate - nn.Conv2d(high_level_bottleneck, num_classes, kernel_size=1, stride=1), + nn.Conv2d(high_level_bottleneck, num_classes, kernel_size=1), ) + self.dsv = nn.Conv2d(high_level_bottleneck, num_classes, kernel_size=1) + def forward(self, feature_maps): low_level_feat = feature_maps[0] low_level_feat = self.conv1(low_level_feat) @@ -99,10 +101,12 @@ def forward(self, feature_maps): high_level_features = feature_maps[-1] high_level_features = self.aspp(high_level_features) + dsv = self.dsv(high_level_features) + high_level_features = F.interpolate( high_level_features, size=low_level_feat.size()[2:], mode="bilinear", align_corners=False ) high_level_features = torch.cat([high_level_features, low_level_feat], dim=1) high_level_features = self.last_conv(high_level_features) - return high_level_features + return high_level_features, dsv From 0daa4ddbd0ce4b37ccd22fe339c01861980ede55 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 22 Nov 2019 22:21:14 +0200 Subject: [PATCH 23/79] Fix deeplab decoder --- pytorch_toolbelt/modules/decoders/deeplab.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 6765495f4..c985dff59 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -79,7 +79,7 @@ def __init__(self, self.aspp = ASPP(feature_maps[-1], output_stride, high_level_bottleneck, dropout=dropout, abn_block=abn_block) self.conv1 = nn.Conv2d(feature_maps[0], low_level_bottleneck, 1, bias=False) - self.abn1 = abn_block(48) + self.abn1 = abn_block(low_level_bottleneck) self.last_conv = nn.Sequential( nn.Conv2d(high_level_bottleneck + low_level_bottleneck, high_level_bottleneck, kernel_size=3, padding=1, bias=False), From 90935c76ff13ea3ad40e81703c9155d671e28ca3 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 11:13:26 +0200 Subject: [PATCH 24/79] Fix typo in proj blocks --- pytorch_toolbelt/modules/decoders/fpn_sum.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index 3c6203452..d6701b4cb 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -39,8 +39,8 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: x = self.bottleneck(x) p2 = self.proj2(self.pool2(x)) - p4 = self.proj2(self.pool4(x)) - p8 = self.proj2(self.pool8(x)) + p4 = self.proj4(self.pool4(x)) + p8 = self.proj8(self.pool8(x)) x_size = x.size()[2:] x = torch.cat( From 50c62500f4a565cbc6f25414432da9d6357b0f47 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 11:13:54 +0200 Subject: [PATCH 25/79] Fix block names --- pytorch_toolbelt/modules/decoders/fpn_cat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 779a319d9..43c0777a8 100644 --- 
a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -10,7 +10,7 @@ __all__ = ["FPNCatDecoder"] -class FPNSumDecoderBlock(nn.Module): +class FPNCatDecoderBlock(nn.Module): """ Simple prediction block composed of (Conv + BN + Activation) repeated twice """ @@ -45,7 +45,7 @@ def __init__( dropout=0.0, abn_block=ABN, upsample_add=UpsampleAdd, - prediction_block=FPNSumDecoderBlock, + prediction_block=FPNCatDecoderBlock, ): super().__init__() From 8b32cae742ebe018b4f7bec9d5eb814006deb53e Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 11:14:07 +0200 Subject: [PATCH 26/79] Use Dropout2d --- pytorch_toolbelt/modules/decoders/fpn_cat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 43c0777a8..f57bfb139 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -21,7 +21,7 @@ def __init__(self, input_features: int, output_features: int, abn_block=ABN, dro self.abn1 = abn_block(output_features) self.conv2 = nn.Conv2d(output_features, output_features, kernel_size=3, padding=1, bias=False) self.abn2 = abn_block(output_features) - self.drop2 = nn.Dropout(dropout) + self.drop2 = nn.Dropout2d(dropout) def forward(self, x: Tensor) -> Tensor: x = self.conv1(x) From 8ae94adaf8a37536d55caaea85cded67034b7406 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 16:40:22 +0200 Subject: [PATCH 27/79] Refactor FPN decoders --- pytorch_toolbelt/modules/decoders/common.py | 2 + pytorch_toolbelt/modules/decoders/fpn_cat.py | 64 ++++++---- pytorch_toolbelt/modules/decoders/fpn_sum.py | 124 ++++++++++++++----- pytorch_toolbelt/modules/fpn.py | 70 +++++++++++ 4 files changed, 202 insertions(+), 58 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/common.py b/pytorch_toolbelt/modules/decoders/common.py index d42e61134..45319f9e6 100644 --- a/pytorch_toolbelt/modules/decoders/common.py +++ b/pytorch_toolbelt/modules/decoders/common.py @@ -2,6 +2,8 @@ __all__ = ["DecoderModule", "SegmentationDecoderModule"] +from typing import List + class DecoderModule(nn.Module): def __init__(self): diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index f57bfb139..3fe5b89b2 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -1,4 +1,5 @@ -from typing import List, Tuple +from functools import partial +from typing import List, Tuple, Optional, Union from torch import nn, Tensor @@ -7,7 +8,7 @@ from ..activated_batch_norm import ABN from ..fpn import FPNFuse, UpsampleAdd -__all__ = ["FPNCatDecoder"] +__all__ = ["FPNCatDecoderBlock", "FPNCatDecoder"] class FPNCatDecoderBlock(nn.Module): @@ -40,13 +41,28 @@ class FPNCatDecoder(SegmentationDecoderModule): def __init__( self, feature_maps: List[int], - num_classes: int, + output_channels: int, fpn_channels=128, + dsv_channels: Optional[int] = None, dropout=0.0, abn_block=ABN, upsample_add=UpsampleAdd, prediction_block=FPNCatDecoderBlock, + final_block=partial(nn.Conv2d, kernel_size=1), ): + """ + + Args: + feature_maps: + output_channels: + fpn_channels: + dsv_channels: + dropout: + abn_block: + upsample_add: + prediction_block: + final_block: + """ super().__init__() self.fpn = FPNDecoder( @@ -61,35 +77,33 @@ def __init__( self.dropout = nn.Dropout2d(dropout, inplace=True) # dsv blocks are for deep supervision 
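Note on the deep-supervision (dsv) heads this refactoring makes optional: each FPN level gets its own 1x1 projection to dsv_channels, and the per-level masks are typically upsampled and used as auxiliary loss terms next to the main mask. A hedged sketch of how such (mask, dsv_masks) outputs could be consumed; the function, the loss weight, and the assumption that the target is a float mask with matching channels are illustrative and not part of pytorch_toolbelt:

from typing import List
import torch
import torch.nn.functional as F

# Illustrative only: one way to combine the main mask with deep-supervision
# outputs returned by a decoder when dsv_channels is set.
# Assumes target is a float mask with the same channel count as the predictions.
def loss_with_deep_supervision(mask: torch.Tensor,
                               dsv_masks: List[torch.Tensor],
                               target: torch.Tensor,
                               dsv_weight: float = 0.4) -> torch.Tensor:
    if mask.shape[2:] != target.shape[2:]:
        mask = F.interpolate(mask, size=target.shape[2:], mode="bilinear", align_corners=False)
    # Main loss on the (upsampled) full prediction
    loss = F.binary_cross_entropy_with_logits(mask, target)
    # Auxiliary losses on each deep-supervision head, resized to the target
    for dsv in dsv_masks:
        dsv = F.interpolate(dsv, size=target.shape[2:], mode="bilinear", align_corners=False)
        loss = loss + dsv_weight * F.binary_cross_entropy_with_logits(dsv, target)
    return loss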
- self.dsv = nn.ModuleList( - [ - nn.Conv2d(fpn_features, num_classes, kernel_size=1) - for fpn_features in [fpn_channels] * len(feature_maps) - ] - ) + if dsv_channels is not None: + self.dsv = nn.ModuleList( + [ + nn.Conv2d(fpn_features, dsv_channels, kernel_size=1) + for fpn_features in [fpn_channels] * len(feature_maps) + ] + ) + else: + self.dsv = None features = sum(self.fpn.output_filters) - self.final_block = nn.Sequential( - nn.Conv2d(features, features // 2, kernel_size=1), - abn_block(features // 2), - nn.Conv2d(features // 2, features // 4, kernel_size=3, padding=1, bias=False), - abn_block(features // 4), - nn.Conv2d(features // 4, features // 4, kernel_size=3, padding=1, bias=False), - abn_block(features // 4), - nn.Conv2d(features // 4, num_classes, kernel_size=1, bias=True), - ) + self.final_block = final_block(features, output_channels) - def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, List[Tensor]]: + def forward(self, feature_maps: List[Tensor]) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: fpn_maps = self.fpn(feature_maps) fused = self.fuse(fpn_maps) fused = self.dropout(fused) + x = self.final_block(fused) - dsv_masks = [] - for dsv_block, fpn in zip(self.dsv, fpn_maps): - dsv = dsv_block(fpn) - dsv_masks.append(dsv) + if self.dsv is not None: + dsv_masks = [] + for dsv_block, fpn in zip(self.dsv, fpn_maps): + dsv = dsv_block(fpn) + dsv_masks.append(dsv) - x = self.final_block(fused) - return x, dsv_masks + return x, dsv_masks + + return x diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index d6701b4cb..5d67f5a53 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -1,5 +1,6 @@ +from functools import partial from itertools import repeat -from typing import List, Tuple +from typing import List, Tuple, Optional, Union import torch from ..activated_batch_norm import ABN @@ -14,7 +15,24 @@ class FPNSumCenterBlock(nn.Module): - def __init__(self, encoder_features: int, decoder_features: int, num_classes: int, abn_block=ABN, dropout=0.0): + def __init__( + self, + encoder_features: int, + decoder_features: int, + dsv_channels: Optional[int] = None, + abn_block=ABN, + dropout=0.0, + ): + """ + Center FPN block that aggregates multi-scale context using strided average poolings + + Args: + encoder_features: Number of input features + decoder_features: Number of output features + dsv_channels: Number of output features for deep supervision (usually number of channels in final mask) + abn_block: Block for Activation + BatchNorm2d + dropout: Dropout rate after context fusion + """ super().__init__() self.bottleneck = nn.Conv2d(encoder_features, encoder_features // 2, kernel_size=1) @@ -33,9 +51,12 @@ def __init__(self, encoder_features: int, decoder_features: int, num_classes: in self.conv1 = nn.Conv2d(decoder_features, decoder_features, kernel_size=3, padding=1, bias=False) self.abn1 = abn_block(decoder_features) - self.dsv = nn.Conv2d(decoder_features, num_classes, kernel_size=1) + if dsv_channels is not None: + self.dsv = nn.Conv2d(decoder_features, dsv_channels, kernel_size=1) + else: + self.dsv = None - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: x = self.bottleneck(x) p2 = self.proj2(self.pool2(x)) @@ -59,9 +80,11 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: x = self.conv1(x) x = self.abn1(x) - dsv = self.dsv(x) + if self.dsv is not None: + dsv = 
self.dsv(x) + return x, dsv - return x, dsv + return x class FPNSumDecoderBlock(nn.Module): @@ -70,10 +93,20 @@ def __init__( encoder_features: int, decoder_features: int, output_features: int, - num_classes: int, + dsv_channels: Optional[int] = None, abn_block=ABN, dropout=0.0, ): + """ + + Args: + encoder_features: + decoder_features: + output_features: + dsv_channels: + abn_block: + dropout: + """ super().__init__() self.skip = nn.Conv2d(encoder_features, decoder_features, kernel_size=1) if decoder_features == output_features: @@ -81,79 +114,104 @@ def __init__( else: self.reduction = nn.Conv2d(decoder_features, output_features, kernel_size=1) - self.dropout = nn.Dropout2d(dropout, inplace=True) self.conv1 = nn.Conv2d(output_features, output_features, kernel_size=3, padding=1, bias=False) self.abn1 = abn_block(output_features) + self.drop1 = nn.Dropout2d(dropout, inplace=True) - self.dsv = nn.Conv2d(output_features, num_classes, kernel_size=1) - - def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Tuple[Tensor, Tensor]: - """ + if dsv_channels is not None: + self.dsv = nn.Conv2d(decoder_features, dsv_channels, kernel_size=1) + else: + self.dsv = None - :param decoder_fm: - :param encoder_fm: - :return: - """ + def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=False) encoder_fm = self.skip(encoder_fm) x = decoder_fm + encoder_fm x = self.reduction(x) - x = self.dropout(x) x = self.conv1(x) x = self.abn1(x) + x = self.drop1(x) - dsv = self.dsv(x) + if self.dsv is not None: + dsv = self.dsv(x) + return x, dsv - return x, dsv + return x class FPNSumDecoder(SegmentationDecoderModule): - """ - - """ - def __init__( self, feature_maps: List[int], - num_classes: int, + output_channels: int, + dsv_channels: Optional[int] = None, fpn_channels=256, dropout=0.0, abn_block=ABN, center_block=FPNSumCenterBlock, decoder_block=FPNSumDecoderBlock, + final_block=partial(nn.Conv2d, kernel_size=1), ): + """ + + Args: + feature_maps: + output_channels: + dsv_channels: + fpn_channels: + dropout: + abn_block: + center_block: + decoder_block: + final_block: + """ super().__init__() self.center = center_block( - feature_maps[-1], fpn_channels, num_classes=num_classes, dropout=dropout, abn_block=abn_block + feature_maps[-1], fpn_channels, dsv_channels=dsv_channels, dropout=dropout, abn_block=abn_block ) self.fpn_modules = nn.ModuleList( [ decoder_block( - encoder_fm, decoder_fm, decoder_fm, num_classes=num_classes, dropout=dropout, abn_block=abn_block + encoder_fm, decoder_fm, decoder_fm, dsv_channels=dsv_channels, dropout=dropout, abn_block=abn_block ) for decoder_fm, encoder_fm in zip(repeat(fpn_channels), reversed(feature_maps[:-1])) ] ) - self.final_block = nn.Conv2d(fpn_channels, num_classes, kernel_size=1) + self.final_block = final_block(fpn_channels, output_channels) + self.dsv_channels = dsv_channels - def forward(self, feature_maps: List[Tensor]) -> Tuple[Tensor, Tensor]: + def forward(self, feature_maps: List[Tensor]) -> Union[Tensor, Tuple[Tensor, List[Tensor]]]: last_feature_map = feature_maps[-1] feature_maps = reversed(feature_maps[:-1]) dsv_masks = [] - x, dsv = self.center(last_feature_map) - dsv_masks.append(dsv) + output = self.center(last_feature_map) - for transition_unit, encoder_fm in zip(self.fpn_modules, feature_maps): - x, dsv = transition_unit(x, encoder_fm) + if self.dsv_channels: + x, dsv = output dsv_masks.append(dsv) + else: + x 
= output + + for fpn_block, encoder_fm in zip(self.fpn_modules, feature_maps): + output = fpn_block(x, encoder_fm) + + if self.dsv_channels: + x, dsv = output + dsv_masks.append(dsv) + else: + x = output x = self.final_block(x) - return x, dsv_masks + + if self.dsv_channels: + return x, dsv_masks + + return x diff --git a/pytorch_toolbelt/modules/fpn.py b/pytorch_toolbelt/modules/fpn.py index 43fc91cbd..86eea6a24 100644 --- a/pytorch_toolbelt/modules/fpn.py +++ b/pytorch_toolbelt/modules/fpn.py @@ -4,12 +4,16 @@ from torch import nn from torch.nn import functional as F +from ..modules.activated_batch_norm import ABN + __all__ = [ "FPNBottleneckBlock", "FPNBottleneckBlockBN", "FPNPredictionBlock", "FPNFuse", "FPNFuseSum", + "FPNFinalBottleneckBlock", + "FPNFinalTransposeConvBlock", "UpsampleAdd", "UpsampleAddConv", ] @@ -48,6 +52,72 @@ def forward(self, x): return x +class FPNFinalBottleneckBlock(nn.Module): + def __init__(self, input_channels: int, output_channels: int, reduction=4, abn_block=ABN): + super().__init__() + + features = input_channels // reduction + + self.bottleneck = nn.Conv2d(input_channels, features, kernel_size=1, bias=False) + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False) + self.abn1 = abn_block(features) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False) + self.abn2 = abn_block(features) + + self.final = nn.Conv2d(features, output_channels, kernel_size=1, bias=True) + + def forward(self, x): + x = self.bottleneck(x) + + x = self.conv1(x) + x = self.abn1(x) + + x = self.conv2(x) + x = self.abn2(x) + + x = self.final(x) + return x + + +class FPNFinalTransposeConvBlock(nn.Module): + def __init__(self, input_channels: int, output_channels: int, reduction=4, abn_block=ABN): + """ + + Args: + input_channels: + output_channels: + reduction: + abn_block: + """ + super().__init__() + + features = input_channels // reduction + + self.bottleneck = nn.Conv2d(input_channels, features, kernel_size=1, bias=False) + + self.conv1 = nn.ConvTranspose2d(features, features, kernel_size=3, stride=2, padding=1, bias=False) + self.abn1 = abn_block(features) + + self.conv2 = nn.ConvTranspose2d(features, features, kernel_size=3, stride=2, padding=1, bias=False) + self.abn2 = abn_block(features) + + self.final = nn.Conv2d(features, output_channels, kernel_size=1, bias=True) + + def forward(self, x): + x = self.bottleneck(x) + + x = self.conv1(x) + x = self.abn1(x) + + x = self.conv2(x) + x = self.abn2(x) + + x = self.final(x) + return x + + class UpsampleAdd(nn.Module): """ Compute pixelwise sum of first tensor and upsampled second tensor. From 6cfbc36354104df2ae510a69563b5e5fadeedf3d Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 17:09:58 +0200 Subject: [PATCH 28/79] Bugfix --- pytorch_toolbelt/modules/encoders/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 829ffee45..93eab773c 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -40,7 +40,7 @@ def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): else: w = w[:, 0:in_channels, ...] 
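The one-line change below fixes make_n_channel_input: previously the adapted weight was collapsed back to a single input channel (w[:, 0:1, ...]), whereas the intent is to keep the first in_channels slices of the pretrained kernel. A standalone sketch of that intent for the narrowing case only; the helper name and layer sizes here are illustrative, not the library function:

import torch
from torch import nn

# Hypothetical demo helper: adapt a pretrained first conv to fewer input
# channels by slicing its kernel, mirroring what make_n_channel_input intends.
def narrow_input_channels(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
    assert in_channels <= conv.in_channels, "only the narrowing case is sketched here"
    new_conv = nn.Conv2d(
        in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    # Keep the first in_channels slices of the pretrained kernel instead of
    # collapsing everything to a single channel.
    new_conv.weight = nn.Parameter(conv.weight[:, :in_channels, ...].clone(), requires_grad=True)
    return new_conv

rgb_conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
gray_conv = narrow_input_channels(rgb_conv, in_channels=1)
assert gray_conv.weight.shape == (64, 1, 7, 7)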
- new_conv.weight = nn.Parameter(w[:, 0:1, ...], requires_grad=True) + new_conv.weight = nn.Parameter(w, requires_grad=True) return new_conv From 4dbb219b7794839807bfd03b5021ab29714bf281 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 27 Nov 2019 17:22:34 +0200 Subject: [PATCH 29/79] change_input_channels returns self --- pytorch_toolbelt/modules/encoders/common.py | 6 ++++-- pytorch_toolbelt/modules/encoders/densenet.py | 1 + pytorch_toolbelt/modules/encoders/inception.py | 1 + pytorch_toolbelt/modules/encoders/mobilenet.py | 2 ++ pytorch_toolbelt/modules/encoders/resnet.py | 1 + pytorch_toolbelt/modules/encoders/seresnet.py | 1 + pytorch_toolbelt/modules/encoders/squeezenet.py | 1 + pytorch_toolbelt/modules/encoders/unet.py | 1 + pytorch_toolbelt/modules/encoders/wide_resnet.py | 2 ++ 9 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 93eab773c..df2639dd3 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -36,11 +36,13 @@ def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): w = conv.weight if in_channels > conv.in_channels: - w = F.pad(w, pad=[0, 0, 0, in_channels - conv.in_channels], mode="circular") + # TODO: Figure out padding scheme + # w = F.pad(w, pad=[0, in_channels - conv.in_channels, 0, 0], mode="circular") + pass else: w = w[:, 0:in_channels, ...] + new_conv.weight = nn.Parameter(w, requires_grad=True) - new_conv.weight = nn.Parameter(w, requires_grad=True) return new_conv diff --git a/pytorch_toolbelt/modules/encoders/densenet.py b/pytorch_toolbelt/modules/encoders/densenet.py index 9098b38b4..b762ce3d3 100644 --- a/pytorch_toolbelt/modules/encoders/densenet.py +++ b/pytorch_toolbelt/modules/encoders/densenet.py @@ -78,6 +78,7 @@ def forward(self, x): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv0 = make_n_channel_input(self.layer0.conv0, input_channels, mode) + return self class DenseNet121Encoder(DenseNetEncoder): diff --git a/pytorch_toolbelt/modules/encoders/inception.py b/pytorch_toolbelt/modules/encoders/inception.py index 90a2de8e9..dbd842394 100644 --- a/pytorch_toolbelt/modules/encoders/inception.py +++ b/pytorch_toolbelt/modules/encoders/inception.py @@ -40,3 +40,4 @@ def encoder_layers(self): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0[0] = make_n_channel_input(self.layer0[0], input_channels, mode) + return self diff --git a/pytorch_toolbelt/modules/encoders/mobilenet.py b/pytorch_toolbelt/modules/encoders/mobilenet.py index 769770a61..bb74e6500 100644 --- a/pytorch_toolbelt/modules/encoders/mobilenet.py +++ b/pytorch_toolbelt/modules/encoders/mobilenet.py @@ -25,6 +25,7 @@ def encoder_layers(self): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self class MobilenetV3Encoder(EncoderModule): @@ -73,3 +74,4 @@ def encoder_layers(self): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self diff --git a/pytorch_toolbelt/modules/encoders/resnet.py b/pytorch_toolbelt/modules/encoders/resnet.py index e75f61a60..88be2b952 100644 --- a/pytorch_toolbelt/modules/encoders/resnet.py +++ b/pytorch_toolbelt/modules/encoders/resnet.py @@ -55,6 +55,7 @@ def forward(self, x): def 
change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv0 = make_n_channel_input(self.layer0.conv0, input_channels, mode) + return self class Resnet18Encoder(ResnetEncoder): diff --git a/pytorch_toolbelt/modules/encoders/seresnet.py b/pytorch_toolbelt/modules/encoders/seresnet.py index 370abfd5b..a71ddeb85 100644 --- a/pytorch_toolbelt/modules/encoders/seresnet.py +++ b/pytorch_toolbelt/modules/encoders/seresnet.py @@ -80,6 +80,7 @@ def forward(self, x: Tensor) -> List[Tensor]: def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self class SEResnet50Encoder(SEResnetEncoder): diff --git a/pytorch_toolbelt/modules/encoders/squeezenet.py b/pytorch_toolbelt/modules/encoders/squeezenet.py index ba09a7a7d..f10ea850a 100644 --- a/pytorch_toolbelt/modules/encoders/squeezenet.py +++ b/pytorch_toolbelt/modules/encoders/squeezenet.py @@ -60,3 +60,4 @@ def encoder_layers(self): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self diff --git a/pytorch_toolbelt/modules/encoders/unet.py b/pytorch_toolbelt/modules/encoders/unet.py index 5f810c7ca..6f8558356 100644 --- a/pytorch_toolbelt/modules/encoders/unet.py +++ b/pytorch_toolbelt/modules/encoders/unet.py @@ -29,3 +29,4 @@ def encoder_layers(self): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py index fa733438e..a77c48032 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -69,6 +69,7 @@ def forward(self, input): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self class WiderResnet16Encoder(WiderResnetEncoder): @@ -141,6 +142,7 @@ def forward(self, input): def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) + return self class WiderResnet16A2Encoder(WiderResnetA2Encoder): From 605f97c9a3cb36828ca817fb74a519b504741905 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 28 Nov 2019 17:24:52 +0200 Subject: [PATCH 30/79] Add HRNet model --- pytorch_toolbelt/modules/backbone/hrnet.py | 7 +----- pytorch_toolbelt/modules/decoders/hrnet.py | 7 ++++-- pytorch_toolbelt/modules/encoders/hrnet.py | 19 +++++++++++----- tests/test_encoders.py | 26 ++++++++++++++++++++++ 4 files changed, 46 insertions(+), 13 deletions(-) diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py index 39f7c2bcc..7e154a74c 100644 --- a/pytorch_toolbelt/modules/backbone/hrnet.py +++ b/pytorch_toolbelt/modules/backbone/hrnet.py @@ -409,12 +409,7 @@ def forward(self, x, return_feature_maps=False): x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=False) x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=False) - x = torch.cat([x[0], x1, x2, x3], 1) + x = torch.cat([x[0], x1, x2, x3], dim=1) - # x = self.last_layer(x) return [x] - -def hrnetv2(**kwargs): - model = HRNetV2(**kwargs) - return model diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py 
b/pytorch_toolbelt/modules/decoders/hrnet.py index 909606659..ac68a0635 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -1,4 +1,5 @@ from torch import nn +from typing import List from .common import DecoderModule from ..backbone.hrnet import HRNETV2_BN_MOMENTUM @@ -7,15 +8,17 @@ class HRNetDecoder(DecoderModule): - def __init__(self, features: int, num_classes: int, dropout=0.0): + def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): super().__init__() + features = feature_maps[-1] + self.last_layer = nn.Sequential( nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0), nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Dropout(dropout), - nn.Conv2d(in_channels=features, out_channels=num_classes, kernel_size=3, stride=1, padding=1), + nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=3, stride=1, padding=1), ) def forward(self, features): diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index c05ec1e69..708f3ce37 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -1,6 +1,6 @@ -from pytorch_toolbelt.modules.backbone.hrnet import hrnetv2 +from pytorch_toolbelt.modules.backbone.hrnet import HRNetV2 -from .common import EncoderModule +from .common import EncoderModule, make_n_channel_input __all__ = ["HRNetV2Encoder48", "HRNetV2Encoder18", "HRNetV2Encoder34"] @@ -8,25 +8,34 @@ class HRNetV2Encoder18(EncoderModule): def __init__(self, pretrained=False): super().__init__([144 + 72 + 36 + 18], [4], [0]) - self.hrnet = hrnetv2(width=18, pretrained=False) + self.hrnet = HRNetV2(width=18, pretrained=False) def forward(self, x): return self.hrnet(x) + def change_input_channels(self, input_channels: int, mode="auto"): + self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + class HRNetV2Encoder34(EncoderModule): def __init__(self, pretrained=False): super().__init__([34 * 8 + 34 * 4 + 34 * 2 + 34], [4], [0]) - self.hrnet = hrnetv2(width=34, pretrained=False) + self.hrnet = HRNetV2(width=34, pretrained=False) def forward(self, x): return self.hrnet(x) + def change_input_channels(self, input_channels: int, mode="auto"): + self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + class HRNetV2Encoder48(EncoderModule): def __init__(self, pretrained=False): super().__init__([720], [4], [0]) - self.hrnet = hrnetv2(width=48, pretrained=False) + self.hrnet = HRNetV2(width=48, pretrained=False) def forward(self, x): return self.hrnet(x) + + def change_input_channels(self, input_channels: int, mode="auto"): + self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 8e32213a5..84aeb2307 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -87,3 +87,29 @@ def test_densenet(): net2.classifier = None print(count_parameters(net1), count_parameters(net2)) + + +@pytest.mark.parametrize( + ["encoder", "encoder_params"], + [ + [E.HRNetV2Encoder18, {"pretrained": False}], + [E.HRNetV2Encoder34, {"pretrained": False}], + [E.HRNetV2Encoder48, {"pretrained": False}], + ], +) +@torch.no_grad() +@skip_if_no_cuda +def test_hrnet_encoder(encoder: E.EncoderModule, encoder_params): + net = encoder(**encoder_params).eval() + print(net.__class__.__name__, 
count_parameters(net)) + print(net.output_strides) + print(net.output_filters) + input = torch.rand((4, 3, 256, 256)) + input = maybe_cuda(input) + net = maybe_cuda(net) + output = net(input) + assert len(output) == len(net.output_filters) + for feature_map, expected_stride, expected_channels in zip(output, net.output_strides, net.output_filters): + assert feature_map.size(1) == expected_channels + assert feature_map.size(2) * expected_stride == 256 + assert feature_map.size(3) * expected_stride == 256 From 009bfc57d1f1ac6d7a92abd29e6f1d9b38270b91 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 29 Nov 2019 21:25:31 +0200 Subject: [PATCH 31/79] Add LR schedules --- pytorch_toolbelt/optimization/lr_schedules.py | 43 +++++++++++++++---- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/pytorch_toolbelt/optimization/lr_schedules.py b/pytorch_toolbelt/optimization/lr_schedules.py index 913d15abc..a841187a9 100644 --- a/pytorch_toolbelt/optimization/lr_schedules.py +++ b/pytorch_toolbelt/optimization/lr_schedules.py @@ -2,7 +2,10 @@ import numpy as np from torch import nn -from torch.optim.lr_scheduler import _LRScheduler +from torch.optim.lr_scheduler import _LRScheduler, LambdaLR +from torch.optim.optimizer import Optimizer + +__all__ = ["OnceCycleLR", "CosineAnnealingLRWithDecay", "PolyLR"] def set_learning_rate(optimizer, lr): @@ -69,25 +72,49 @@ def compute_lr(base_lr): return [compute_lr(base_lr) for base_lr in self.base_lrs] +class PolyLR(LambdaLR): + def __init__(self, optimizer: Optimizer, max_epoch, gamma=0.9): + def poly_lr(epoch): + return (1.0 - float(epoch) / max_epoch) ** gamma + + super().__init__(optimizer, poly_lr) + + if __name__ == "__main__": import matplotlib as mpl mpl.use("module://backend_interagg") import matplotlib.pyplot as plt - from torch.optim import SGD + from torch.optim import SGD, Optimizer net = nn.Conv2d(1, 1, 1) - opt = SGD(net.parameters(), lr=1e-2) + opt = SGD(net.parameters(), lr=1e-3) - scheduler = OnceCycleLR(opt, 800, min_lr_factor=0.01) - # scheduler = CosineAnnealingLRWithDecay(opt, 80, gamma=0.999) + epochs = 100 + plt.figure() + + scheduler = OnceCycleLR(opt, epochs, min_lr_factor=0.01) lrs = [] - for epoch in range(800): + for epoch in range(epochs): scheduler.step(epoch) lrs.append(scheduler.get_lr()[0]) + plt.plot(range(epochs), lrs, label="1cycle") - plt.figure() - plt.plot(range(800), lrs) + scheduler = CosineAnnealingLRWithDecay(opt, epochs / 5, gamma=0.99) + lrs = [] + for epoch in range(epochs): + scheduler.step(epoch) + lrs.append(scheduler.get_lr()[0]) + plt.plot(range(epochs), lrs, label="cosine") + + scheduler = PolyLR(opt, epochs, gamma=0.9) + lrs = [] + for epoch in range(epochs): + scheduler.step(epoch) + lrs.append(scheduler.get_lr()[0]) + plt.plot(range(epochs), lrs, label="poly") + + plt.legend() plt.show() From b4fdaa25e3eb14400a3445579e7f59d423244dd2 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Sun, 1 Dec 2019 15:35:40 +0200 Subject: [PATCH 32/79] Add object context blocks --- CREDITS.md | 2 +- pytorch_toolbelt/modules/ocnet.py | 361 ++++++++++++++++++++++++++++++ 2 files changed, 362 insertions(+), 1 deletion(-) create mode 100644 pytorch_toolbelt/modules/ocnet.py diff --git a/CREDITS.md b/CREDITS.md index e6e5f87c7..a273204d6 100644 --- a/CREDITS.md +++ b/CREDITS.md @@ -4,4 +4,4 @@ This file contains links to repositories, source code of which may be partially 1. https://blog.ceshine.net/post/pytorch-memory-swish/ 1. https://github.com/digantamisra98/Mish 1. 
https://github.com/mapillary/inplace_abn - +1. https://github.com/PkuRainBow/OCNet.pytorch diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py new file mode 100644 index 000000000..ad1b7e31e --- /dev/null +++ b/pytorch_toolbelt/modules/ocnet.py @@ -0,0 +1,361 @@ +# Credit: https://github.com/PkuRainBow/OCNet.pytorch/blob/master/oc_module/asp_oc_block.py +import torch +from torch import nn +from .activated_batch_norm import ABN +import torch.nn.functional as F + +__all__ = ["ObjectContextBlock", "ASPObjectContextBlock", "PyramidObjectContextBlock"] + + +class _SelfAttentionBlock(nn.Module): + """ + The basic implementation for self-attention block/non-local block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + value_channels : the dimension after the value transform + scale : choose the scale to downsample the input feature maps (save memory cost) + Return: + N X C X H X W + position-aware context features.(w/o concate or add with the input) + """ + + def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1, abn_block=ABN): + super(_SelfAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.out_channels = out_channels + self.key_channels = key_channels + self.value_channels = value_channels + if out_channels is None: + self.out_channels = in_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_key = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0 + ), + abn_block(self.key_channels), + ) + self.f_query = self.f_key + self.f_value = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.value_channels, kernel_size=1, stride=1, padding=0 + ) + self.W = nn.Conv2d( + in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0 + ) + nn.init.constant(self.W.weight, 0) + nn.init.constant(self.W.bias, 0) + + def forward(self, x): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + value = self.f_value(x).view(batch_size, self.value_channels, -1) + value = value.permute(0, 2, 1) + query = self.f_query(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_key(x).view(batch_size, self.key_channels, -1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels ** -0.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.value_channels, *x.size()[2:]) + context = self.W(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), mode="bilinear") + return context + + +class SelfAttentionBlock2D(_SelfAttentionBlock): + def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1): + super(SelfAttentionBlock2D, self).__init__(in_channels, key_channels, value_channels, out_channels, scale) + + +class BaseOC_Module(nn.Module): + """ + Implementation of the BaseOC module + Parameters: + in_features / out_features: the channels of the input / output feature maps. + dropout: we choose 0.05 as the default value. + size: you can apply multiple sizes. Here we only use one size. + Return: + features fused with Object context information. 
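+    Example (illustrative sketch; the channel sizes are assumptions, and out_channels
+    must equal in_channels so the concatenated context matches the 2 * in_channels fusion conv):
+        >>> import torch
+        >>> ocp = BaseOC_Module(in_channels=512, out_channels=512, key_channels=256,
+        ...                     value_channels=256, dropout=0.05, sizes=([1]))
+        >>> y = ocp(torch.rand(2, 512, 32, 32))  # output keeps the input shape: (2, 512, 32, 32)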
+ """ + + def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1]), abn_block=ABN): + super(BaseOC_Module, self).__init__() + self.stages = [] + self.stages = nn.ModuleList( + [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes] + ) + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(2 * in_channels, out_channels, kernel_size=1, padding=0), + abn_block(out_channels), + nn.Dropout2d(dropout), + ) + + def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size): + return SelfAttentionBlock2D(in_channels, key_channels, value_channels, output_channels, size) + + def forward(self, feats): + priors = [stage(feats) for stage in self.stages] + context = priors[0] + for i in range(1, len(priors)): + context += priors[i] + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + return output + + +class ObjectContextBlock(nn.Module): + """ + Output only the context features. + Parameters: + in_features / out_features: the channels of the input / output feature maps. + dropout: specify the dropout ratio + fusion: We provide two different fusion method, "concat" or "add" + size: we find that directly learn the attention weights on even 1/8 feature maps is hard. + Return: + features after "concat" or "add" + """ + + def __init__(self, in_channels, out_channels, key_channels, value_channels, dropout, sizes=([1]), abn_block=ABN): + super(ObjectContextBlock, self).__init__() + self.stages = [] + self.stages = nn.ModuleList( + [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes] + ) + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0), abn_block(out_channels) + ) + + def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size): + return SelfAttentionBlock2D(in_channels, key_channels, value_channels, output_channels, size) + + def forward(self, feats): + priors = [stage(feats) for stage in self.stages] + context = priors[0] + for i in range(1, len(priors)): + context += priors[i] + output = self.conv_bn_dropout(context) + return output + + +class ASPObjectContextBlock(nn.Module): + def __init__(self, features, out_features=256, dilations=(12, 24, 36), abn_block=ABN): + super(ASPObjectContextBlock, self).__init__() + self.context = nn.Sequential( + nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=True), + abn_block(out_features), + ObjectContextBlock( + in_channels=out_features, + out_channels=out_features, + key_channels=out_features // 2, + value_channels=out_features, + dropout=0, + sizes=([2]), + ), + ) + self.conv2 = nn.Sequential( + nn.Conv2d(features, out_features, kernel_size=1, padding=0, dilation=1, bias=False), + abn_block(out_features), + ) + self.conv3 = nn.Sequential( + nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[0], dilation=dilations[0], bias=False), + abn_block(out_features), + ) + self.conv4 = nn.Sequential( + nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[1], dilation=dilations[1], bias=False), + abn_block(out_features), + ) + self.conv5 = nn.Sequential( + nn.Conv2d(features, out_features, kernel_size=3, padding=dilations[2], dilation=dilations[2], bias=False), + abn_block(out_features), + ) + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(out_features * 5, out_features * 2, kernel_size=1, padding=0, dilation=1, bias=False), + abn_block(out_features * 2), + 
nn.Dropout2d(0.1), + ) + + def _cat_each(self, feat1, feat2, feat3, feat4, feat5): + assert len(feat1) == len(feat2) + z = [] + for i in range(len(feat1)): + z.append(torch.cat((feat1[i], feat2[i], feat3[i], feat4[i], feat5[i]), dim=1)) + return z + + def forward(self, x): + if isinstance(x, torch.Tensor): + _, _, h, w = x.size() + elif isinstance(x, tuple) or isinstance(x, list): + _, _, h, w = x[0].size() + else: + raise RuntimeError("unknown input type") + + feat1 = self.context(x) + feat2 = self.conv2(x) + feat3 = self.conv3(x) + feat4 = self.conv4(x) + feat5 = self.conv5(x) + + if isinstance(x, torch.Tensor): + out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) + elif isinstance(x, tuple) or isinstance(x, list): + out = self._cat_each(feat1, feat2, feat3, feat4, feat5) + else: + raise RuntimeError("unknown input type") + + output = self.conv_bn_dropout(out) + return output + + +class _PyramidSelfAttentionBlock(nn.Module): + """ + The basic implementation for self-attention block/non-local block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + value_channels : the dimension after the value transform + scale : choose the scale to downsample the input feature maps + Return: + N X C X H X W + position-aware context features.(w/o concate or add with the input) + """ + + def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1, abn_block=ABN): + super(_PyramidSelfAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.out_channels = out_channels + self.key_channels = key_channels + self.value_channels = value_channels + if out_channels == None: + self.out_channels = in_channels + self.f_key = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0 + ), + abn_block(self.key_channels), + ) + self.f_query = self.f_key + self.f_value = nn.Conv2d( + in_channels=self.in_channels, out_channels=self.value_channels, kernel_size=1, stride=1, padding=0 + ) + self.W = nn.Conv2d( + in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0 + ) + nn.init.constant(self.W.weight, 0) + nn.init.constant(self.W.bias, 0) + + def forward(self, x): + batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) + + local_x = [] + local_y = [] + step_h, step_w = h // self.scale, w // self.scale + for i in range(0, self.scale): + for j in range(0, self.scale): + start_x, start_y = i * step_h, j * step_w + end_x, end_y = min(start_x + step_h, h), min(start_y + step_w, w) + if i == (self.scale - 1): + end_x = h + if j == (self.scale - 1): + end_y = w + local_x += [start_x, end_x] + local_y += [start_y, end_y] + + value = self.f_value(x) + query = self.f_query(x) + key = self.f_key(x) + + local_list = [] + local_block_cnt = 2 * self.scale * self.scale + for i in range(0, local_block_cnt, 2): + value_local = value[:, :, local_x[i] : local_x[i + 1], local_y[i] : local_y[i + 1]] + query_local = query[:, :, local_x[i] : local_x[i + 1], local_y[i] : local_y[i + 1]] + key_local = key[:, :, local_x[i] : local_x[i + 1], local_y[i] : local_y[i + 1]] + + h_local, w_local = value_local.size(2), value_local.size(3) + value_local = value_local.contiguous().view(batch_size, self.value_channels, -1) + value_local = value_local.permute(0, 2, 1) + + query_local = query_local.contiguous().view(batch_size, self.key_channels, -1) + 
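+            # Attention is computed independently inside each spatial window: the flattened
+            # (B, C, h_local * w_local) query/key/value yield a per-window affinity matrix of
+            # size (h_local * w_local) x (h_local * w_local) rather than one over the full image.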
query_local = query_local.permute(0, 2, 1) + key_local = key_local.contiguous().view(batch_size, self.key_channels, -1) + + sim_map = torch.matmul(query_local, key_local) + sim_map = (self.key_channels ** -0.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + context_local = torch.matmul(sim_map, value_local) + context_local = context_local.permute(0, 2, 1).contiguous() + context_local = context_local.view(batch_size, self.value_channels, h_local, w_local) + local_list.append(context_local) + + context_list = [] + for i in range(0, self.scale): + row_tmp = [] + for j in range(0, self.scale): + row_tmp.append(local_list[j + i * self.scale]) + context_list.append(torch.cat(row_tmp, 3)) + + context = torch.cat(context_list, 2) + context = self.W(context) + + return context + + +class PyramidSelfAttentionBlock2D(_PyramidSelfAttentionBlock): + def __init__(self, in_channels, key_channels, value_channels, out_channels=None, scale=1): + super(PyramidSelfAttentionBlock2D, self).__init__( + in_channels, key_channels, value_channels, out_channels, scale + ) + + +class PyramidObjectContextBlock(nn.Module): + """ + Output the combination of the context features and the original features. + Parameters: + in_features / out_features: the channels of the input / output feature maps. + dropout: specify the dropout ratio + size: we find that directly learn the attention weights on even 1/8 feature maps is hard. + Return: + features after "concat" or "add" + """ + + def __init__(self, in_channels, out_channels, dropout=0.05, sizes=([1, 2, 3, 6]), abn_block=ABN): + super(PyramidObjectContextBlock, self).__init__() + self.group = len(sizes) + self.stages = [] + self.stages = nn.ModuleList( + [self._make_stage(in_channels, out_channels, in_channels // 2, in_channels, size) for size in sizes] + ) + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(2 * in_channels * self.group, out_channels, kernel_size=1, padding=0), + abn_block(out_channels), + nn.Dropout2d(dropout), + ) + self.up_dr = nn.Sequential( + nn.Conv2d(in_channels, in_channels * self.group, kernel_size=1, padding=0), + abn_block(in_channels * self.group), + ) + + def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size): + return PyramidSelfAttentionBlock2D(in_channels, key_channels, value_channels, output_channels, size) + + def forward(self, feats): + priors = [stage(feats) for stage in self.stages] + context = [self.up_dr(feats)] + for i in range(0, len(priors)): + context += [priors[i]] + output = self.conv_bn_dropout(torch.cat(context, 1)) + return output From 9c3b9746864b70f7a6389ca4a01ce5f9704bf15e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 1 Dec 2019 22:34:05 +0200 Subject: [PATCH 33/79] Update OCNet --- pytorch_toolbelt/modules/ocnet.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index ad1b7e31e..061fc9c63 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -45,8 +45,9 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, self.W = nn.Conv2d( in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0 ) - nn.init.constant(self.W.weight, 0) - nn.init.constant(self.W.bias, 0) + # Eugene Khvedchenya: Original implementation initialized weight of context convolution with zeros, which does not make sense to me + # nn.init.constant(self.W.weight, 0) + nn.init.constant_(self.W.bias, 0) def 
forward(self, x): batch_size, h, w = x.size(0), x.size(2), x.size(3) @@ -68,7 +69,7 @@ def forward(self, x): context = context.view(batch_size, self.value_channels, *x.size()[2:]) context = self.W(context) if self.scale > 1: - context = F.interpolate(input=context, size=(h, w), mode="bilinear") + context = F.interpolate(input=context, size=(h, w), mode="bilinear", align_corners=False) return context From 668f93bdf6a22285941bbd51d81501a46ce8a498 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Mon, 2 Dec 2019 23:02:51 +0200 Subject: [PATCH 34/79] Add dropout to unet decoder --- pytorch_toolbelt/modules/decoders/unet_v2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index e3bccb8a2..cb746205f 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -66,7 +66,7 @@ def __init__( self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) - self.post_drop = nn.Dropout2d(post_dropout_rate, inplace=True) + self.post_drop = nn.Dropout2d(post_dropout_rate) self.dsv = nn.Conv2d(out_filters, mask_channels, kernel_size=1) @@ -97,7 +97,7 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> Tuple[torch.Tensor, Lis class UNetDecoderV2(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0.): super().__init__() if not isinstance(decoder_features, list): @@ -107,7 +107,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlockV2( - decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels + decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels, post_dropout_rate=dropout ) ) From 2a0c0910cdcd06029a25d1f5c734defade8d2950 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Mon, 2 Dec 2019 23:18:48 +0200 Subject: [PATCH 35/79] Use nn.Dropout2d --- pytorch_toolbelt/modules/decoders/hrnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index ac68a0635..ef5c810e6 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -17,7 +17,7 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0), nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), - nn.Dropout(dropout), + nn.Dropout2d(dropout), nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=3, stride=1, padding=1), ) From 6bb150c069d6bf150770643415379557806210ed Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 3 Dec 2019 15:45:35 +0200 Subject: [PATCH 36/79] Add Dropout --- pytorch_toolbelt/modules/decoders/unet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index ab7c2b746..e580568ca 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -12,7 +12,7 @@ class UNetDecoder(DecoderModule): - def __init__(self, feature_maps: List[int], 
decoder_features: int, mask_channels: int): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0.): super().__init__() if not isinstance(decoder_features, list): @@ -33,6 +33,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features + self.final_drop = nn.Dropout2d(dropout) self.final = nn.Conv2d(decoder_features[0], mask_channels, kernel_size=1) def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor: @@ -41,5 +42,6 @@ def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor: for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): output = decoder_block(output, encoder_output) + output = self.final_drop(output) output = self.final(output) return output From 4ae28b983cdd63658eebe7f415c54daea2c45c69 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 5 Dec 2019 16:21:41 +0200 Subject: [PATCH 37/79] Explicitly specify y_true/y_pred --- pytorch_toolbelt/utils/catalyst/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 8cd1ad017..fc20a72b4 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -122,7 +122,7 @@ def on_loader_end(self, state): class_names = self.class_names num_classes = len(class_names) - cm = confusion_matrix(targets, outputs, labels=range(num_classes)) + cm = confusion_matrix(y_true=targets, y_pred=outputs, labels=range(num_classes)) fig = plot_confusion_matrix( cm, From 6a1bfec729c1fb07925ee4b2b188dcf329513aa2 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 5 Dec 2019 16:24:28 +0200 Subject: [PATCH 38/79] Fix value of CallbackOrder --- pytorch_toolbelt/utils/catalyst/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 8cd1ad017..8e6b9f8e6 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -85,7 +85,7 @@ def __init__( specifies our `y_pred`. 
:param ignore_index: same meaning as in nn.CrossEntropyLoss """ - super().__init__(CallbackOrder.Logger) + super().__init__(CallbackOrder.Metric) self.prefix = prefix self.class_names = class_names self.output_key = output_key From 1446faaf414136eba7f799ae5bf234f5013a18d0 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 5 Dec 2019 16:24:55 +0200 Subject: [PATCH 39/79] Do not apply bias to conv layers followed by batchnorm --- pytorch_toolbelt/modules/ocnet.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index 061fc9c63..8a17eaa6b 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -132,7 +132,7 @@ def __init__(self, in_channels, out_channels, key_channels, value_channels, drop [self._make_stage(in_channels, out_channels, key_channels, value_channels, size) for size in sizes] ) self.conv_bn_dropout = nn.Sequential( - nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0), abn_block(out_channels) + nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0, bias=False), abn_block(out_channels) ) def _make_stage(self, in_channels, output_channels, key_channels, value_channels, size): @@ -148,18 +148,18 @@ def forward(self, feats): class ASPObjectContextBlock(nn.Module): - def __init__(self, features, out_features=256, dilations=(12, 24, 36), abn_block=ABN): + def __init__(self, features, out_features=256, dilations=(12, 24, 36), abn_block=ABN, dropout=0.1): super(ASPObjectContextBlock, self).__init__() self.context = nn.Sequential( - nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=True), + nn.Conv2d(features, out_features, kernel_size=3, padding=1, dilation=1, bias=False), abn_block(out_features), ObjectContextBlock( in_channels=out_features, out_channels=out_features, key_channels=out_features // 2, value_channels=out_features, - dropout=0, - sizes=([2]), + dropout=dropout, + sizes=([1]), ), ) self.conv2 = nn.Sequential( @@ -182,7 +182,7 @@ def __init__(self, features, out_features=256, dilations=(12, 24, 36), abn_block self.conv_bn_dropout = nn.Sequential( nn.Conv2d(out_features * 5, out_features * 2, kernel_size=1, padding=0, dilation=1, bias=False), abn_block(out_features * 2), - nn.Dropout2d(0.1), + nn.Dropout2d(dropout), ) def _cat_each(self, feat1, feat2, feat3, feat4, feat5): @@ -243,7 +243,7 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, self.out_channels = in_channels self.f_key = nn.Sequential( nn.Conv2d( - in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0 + in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False ), abn_block(self.key_channels), ) @@ -254,7 +254,7 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, self.W = nn.Conv2d( in_channels=self.value_channels, out_channels=self.out_channels, kernel_size=1, stride=1, padding=0 ) - nn.init.constant(self.W.weight, 0) + # nn.init.constant(self.W.weight, 0) nn.init.constant(self.W.bias, 0) def forward(self, x): @@ -341,12 +341,12 @@ def __init__(self, in_channels, out_channels, dropout=0.05, sizes=([1, 2, 3, 6]) [self._make_stage(in_channels, out_channels, in_channels // 2, in_channels, size) for size in sizes] ) self.conv_bn_dropout = nn.Sequential( - nn.Conv2d(2 * in_channels * self.group, out_channels, kernel_size=1, padding=0), + nn.Conv2d(2 * in_channels * 
self.group, out_channels, kernel_size=1, padding=0, bias=False), abn_block(out_channels), nn.Dropout2d(dropout), ) self.up_dr = nn.Sequential( - nn.Conv2d(in_channels, in_channels * self.group, kernel_size=1, padding=0), + nn.Conv2d(in_channels, in_channels * self.group, kernel_size=1, padding=0, bias=False), abn_block(in_channels * self.group), ) From 89fe46270beaa05f844763cb0ab7002e7bdcdc57 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 5 Dec 2019 16:27:53 +0200 Subject: [PATCH 40/79] Put back size=2 --- pytorch_toolbelt/modules/ocnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index 8a17eaa6b..c1cab6603 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -159,7 +159,7 @@ def __init__(self, features, out_features=256, dilations=(12, 24, 36), abn_block key_channels=out_features // 2, value_channels=out_features, dropout=dropout, - sizes=([1]), + sizes=([2]), ), ) self.conv2 = nn.Sequential( From 8a0c4d84a611a48f5bd6f63e9ce566484cd72366 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 5 Dec 2019 16:45:44 +0200 Subject: [PATCH 41/79] Memory optimization for confusion meter callback --- pytorch_toolbelt/utils/catalyst/metrics.py | 32 ++++++++++++---------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index a0ca3a1fb..4f6984f18 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -3,6 +3,9 @@ import numpy as np import torch from catalyst.dl import Callback, RunnerState, MetricCallback, CallbackOrder +from torchnet.meter import ConfusionMeter +from typing import List + from .visualization import get_tensorboard_logger from ..torch_utils import to_numpy from pytorch_toolbelt.utils.visualization import render_figure_to_tensor, plot_confusion_matrix @@ -75,7 +78,8 @@ def __init__( input_key: str = "targets", output_key: str = "logits", prefix: str = "confusion_matrix", - class_names=None, + class_names: List[str] = None, + num_classes: int = None, ignore_index=None, ): """ @@ -88,41 +92,41 @@ def __init__( super().__init__(CallbackOrder.Metric) self.prefix = prefix self.class_names = class_names + self.num_classes = num_classes \ + if class_names is None \ + else len(class_names) self.output_key = output_key self.input_key = input_key - self.outputs = [] - self.targets = [] self.ignore_index = ignore_index + self.confusion_matrix = None def on_loader_start(self, state): - self.outputs = [] - self.targets = [] + self.confusion_matrix = ConfusionMeter(self.num_classes) def on_batch_end(self, state: RunnerState): - outputs = to_numpy(state.output[self.output_key]) - targets = to_numpy(state.input[self.input_key]) + outputs = state.output[self.output_key].detach().argmax(dim=1).cpu() + targets = state.input[self.input_key].detach().cpu() - outputs = np.argmax(outputs, axis=1) + # Flatten + outputs = outputs.view(-1) + targets = targets.view(-1) if self.ignore_index is not None: mask = targets != self.ignore_index outputs = outputs[mask] targets = targets[mask] - self.outputs.extend(outputs) - self.targets.extend(targets) + self.confusion_matrix.add(predicted=outputs, target=targets) def on_loader_end(self, state): - targets = np.array(self.targets) - outputs = np.array(self.outputs) if self.class_names is None: - class_names = [str(i) for i in range(targets.shape[1])] + class_names = 
[str(i) for i in range(self.num_classes)] else: class_names = self.class_names num_classes = len(class_names) - cm = confusion_matrix(y_true=targets, y_pred=outputs, labels=range(num_classes)) + cm = self.confusion_matrix.value() fig = plot_confusion_matrix( cm, From ce042ecc4c63f26cfb3e2bdfa38af4f1294d725c Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 6 Dec 2019 12:56:13 +0200 Subject: [PATCH 42/79] Add weights initialization --- pytorch_toolbelt/modules/encoders/common.py | 11 ++++++----- tests/test_utils_functional.py | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index df2639dd3..4654dc350 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -2,13 +2,13 @@ Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. """ - +import math from typing import List +import torch from torch import nn import warnings -import torch.nn.functional as F __all__ = ["EncoderModule", "_take", "make_n_channel_input"] @@ -36,9 +36,10 @@ def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): w = conv.weight if in_channels > conv.in_channels: - # TODO: Figure out padding scheme - # w = F.pad(w, pad=[0, in_channels - conv.in_channels, 0, 0], mode="circular") - pass + n = math.ceil(in_channels / float(conv.in_channels)) + w = torch.cat([w] * n, dim=1) + w = w[:, :in_channels, ...] + new_conv.weight = nn.Parameter(w, requires_grad=True) else: w = w[:, 0:in_channels, ...] new_conv.weight = nn.Parameter(w, requires_grad=True) diff --git a/tests/test_utils_functional.py b/tests/test_utils_functional.py index f5c0c2376..ef4a7e679 100644 --- a/tests/test_utils_functional.py +++ b/tests/test_utils_functional.py @@ -1,5 +1,8 @@ import torch +from torch import nn +import torch.nn.functional as F from pytorch_toolbelt.inference.functional import unpad_xyxy_bboxes +from pytorch_toolbelt.modules.encoders import make_n_channel_input def test_unpad_xyxy_bboxes(): @@ -17,3 +20,19 @@ def test_unpad_xyxy_bboxes(): assert bboxes2_unpad.size(1) == 32 assert bboxes2_unpad.size(2) == 4 assert bboxes2_unpad.size(3) == 20 + + +def test_make_n_channel_input(): + conv = nn.Conv2d(3, 16, kernel_size=3, padding=1) + + conv6 = make_n_channel_input(conv, in_channels=6) + assert conv6.weight.size(0) == conv.weight.size(0) + assert conv6.weight.size(1) == 6 + assert conv6.weight.size(2) == conv.weight.size(2) + assert conv6.weight.size(3) == conv.weight.size(3) + + conv5 = make_n_channel_input(conv, in_channels=5) + assert conv5.weight.size(0) == conv.weight.size(0) + assert conv5.weight.size(1) == 5 + assert conv5.weight.size(2) == conv.weight.size(2) + assert conv5.weight.size(3) == conv.weight.size(3) From 97bdc657468daf24eda95db7dd552e63af1b1c6d Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 6 Dec 2019 15:37:56 +0200 Subject: [PATCH 43/79] Remove unneeded bias parameter --- pytorch_toolbelt/modules/__init__.py | 7 +++++++ pytorch_toolbelt/modules/decoders/hrnet.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/__init__.py b/pytorch_toolbelt/modules/__init__.py index 37d819a7b..66accfcd8 100644 --- a/pytorch_toolbelt/modules/__init__.py +++ b/pytorch_toolbelt/modules/__init__.py @@ -2,8 +2,15 @@ from .activated_batch_norm import * from .activated_group_norm import * +from .activations import * +from .coord_conv
import * +from .dropblock import * from .dsconv import * from .fpn import * from .hypercolumn import * from .identity import * +from .ocnet import * +from .pooling import * from .scse import * +from .srm import * +from .unet import * diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index ef5c810e6..ad522423a 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -14,7 +14,7 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): features = feature_maps[-1] self.last_layer = nn.Sequential( - nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0), + nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0, bias=False), nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), nn.ReLU(inplace=True), nn.Dropout2d(dropout), From 083261706ae9c9c1a81b76a5e639d97c8f9d3b2b Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 6 Dec 2019 22:33:22 +0200 Subject: [PATCH 44/79] Fix commentary --- pytorch_toolbelt/losses/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/losses/dice.py b/pytorch_toolbelt/losses/dice.py index 6df861060..68da1b7e9 100644 --- a/pytorch_toolbelt/losses/dice.py +++ b/pytorch_toolbelt/losses/dice.py @@ -86,7 +86,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: else: loss = 1 - scores - # IoU loss is defined for non-empty classes + # Dice loss is undefined for non-empty classes # So we zero contribution of channel that does not have true pixels # NOTE: A better workaround would be to use loss term `mean(y_pred)` # for this case, however it will be a modified jaccard loss From a9af77698e42779d9a61b45b46afa20e5600412d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 6 Dec 2019 23:36:26 +0200 Subject: [PATCH 45/79] Refactor HRNet --- pytorch_toolbelt/modules/decoders/hrnet.py | 31 +- pytorch_toolbelt/modules/encoders/hrnet.py | 461 ++++++++++++++++++++- tests/test_encoders.py | 6 +- 3 files changed, 468 insertions(+), 30 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index ad522423a..a242376c0 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -1,3 +1,5 @@ +from collections import OrderedDict + from torch import nn from typing import List @@ -13,13 +15,28 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): features = feature_maps[-1] - self.last_layer = nn.Sequential( - nn.Conv2d(in_channels=features, out_channels=features, kernel_size=1, stride=1, padding=0, bias=False), - nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM), - nn.ReLU(inplace=True), - nn.Dropout2d(dropout), - nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=3, stride=1, padding=1), + self.embedding = nn.Sequential( + OrderedDict( + [ + ( + "conv1", + nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False), + ), + ("bn1", nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM)), + ("relu", nn.ReLU(inplace=True)), + ] + ) + ) + + self.logits = nn.Sequential( + OrderedDict( + [ + ("drop", nn.Dropout2d(dropout)), + ("final", nn.Conv2d(in_channels=features, out_channels=output_channels, kernel_size=1)), + ] + ) ) def forward(self, features): - return self.last_layer(features[-1]) + embedding = self.embedding(features) + return 
self.logits(embedding) diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index 708f3ce37..837b8962d 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -1,41 +1,462 @@ -from pytorch_toolbelt.modules.backbone.hrnet import HRNetV2 +from collections import OrderedDict + +import torch +from torch import nn +import torch.nn.functional as F +from typing import List from .common import EncoderModule, make_n_channel_input __all__ = ["HRNetV2Encoder48", "HRNetV2Encoder18", "HRNetV2Encoder34"] -class HRNetV2Encoder18(EncoderModule): - def __init__(self, pretrained=False): - super().__init__([144 + 72 + 36 + 18], [4], [0]) - self.hrnet = HRNetV2(width=18, pretrained=False) +HRNETV2_BN_MOMENTUM = 0.1 + + +def hrnet_conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) + + +class HRNetBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(HRNetBasicBlock, self).__init__() + self.conv1 = hrnet_conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = hrnet_conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.downsample = downsample + self.stride = stride def forward(self, x): - return self.hrnet(x) + residual = x - def change_input_channels(self, input_channels: int, mode="auto"): - self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out += residual + out = self.relu(out) -class HRNetV2Encoder34(EncoderModule): - def __init__(self, pretrained=False): - super().__init__([34 * 8 + 34 * 4 + 34 * 2 + 34], [4], [0]) - self.hrnet = HRNetV2(width=34, pretrained=False) + return out + + +class HRNetBottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(HRNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=HRNETV2_BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride def forward(self, x): - return self.hrnet(x) + residual = x - def change_input_channels(self, input_channels: int, mode="auto"): - self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__( + self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, 
multi_scale_output=True + ): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks)) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(num_branches, len(num_channels)) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(num_branches, len(num_inchannels)) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): + downsample = None + if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=HRNETV2_BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), + nn.BatchNorm2d(num_inchannels[i], momentum=HRNETV2_BN_MOMENTUM), + ) + ) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), + ) + ) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + ) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + def forward(self, x): + if self.num_branches == 1: + return 
[self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode="bilinear", + align_corners=False, + ) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +class HRNetEncoderBase(EncoderModule): + def __init__(self, input_channels=3, width=48, layers: List[int] = None): + if layers is None: + # By default return only last feature map + layers = [4] + + channels = [ + 64, + 256, + width * 2 + width, + width * 4 + width * 2 + width, + width * 8 + width * 4 + width * 2 + width, + ] + + strides = [4, 4, 4, 4, 4] + + super().__init__(channels=channels, strides=strides, layers=layers) + + blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} + + extra = { + "STAGE2": { + "NUM_MODULES": 1, + "NUM_BRANCHES": 2, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4), + "NUM_CHANNELS": (width, width * 2), + "FUSE_METHOD": "SUM", + }, + "STAGE3": { + "NUM_MODULES": 4, + "NUM_BRANCHES": 3, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4, 4), + "NUM_CHANNELS": (width, width * 2, width * 4), + "FUSE_METHOD": "SUM", + }, + "STAGE4": { + "NUM_MODULES": 3, + "NUM_BRANCHES": 4, + "BLOCK": "BASIC", + "NUM_BLOCKS": (4, 4, 4, 4), + "NUM_CHANNELS": (width, width * 2, width * 4, width * 8), + "FUSE_METHOD": "SUM", + }, + "FINAL_CONV_KERNEL": 1, + } + + # stem net + self.layer0 = nn.Sequential( + OrderedDict( + [ + ("conv1", nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1, bias=False)), + ("bn1", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), + ("relu", nn.ReLU(inplace=True)), + ("conv2", nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)), + ("bn2", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), + ("relu2", nn.ReLU(inplace=True)), + ] + ) + ) + + self.layer1 = self._make_layer(HRNetBottleneck, 64, 64, 4) + + self.stage2_cfg = extra["STAGE2"] + num_channels = self.stage2_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage2_cfg["BLOCK"]] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer([256], num_channels) + self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) + + self.stage3_cfg = extra["STAGE3"] + num_channels = self.stage3_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage3_cfg["BLOCK"]] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) + + self.stage4_cfg = extra["STAGE4"] + num_channels = self.stage4_cfg["NUM_CHANNELS"] + block = blocks_dict[self.stage4_cfg["BLOCK"]] + num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) + + def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) 
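+        # Branches shared with the previous stage get a 3x3 conv only where the channel count
+        # changes; each newly added, lower-resolution branch is produced from the last branch of
+        # the previous stage by a chain of stride-2 3x3 convs.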
+ + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), + nn.BatchNorm2d(num_channels_cur_layer[i], momentum=HRNETV2_BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + ) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=HRNETV2_BN_MOMENTUM), + nn.ReLU(inplace=True), + ) + ) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) -class HRNetV2Encoder48(EncoderModule): - def __init__(self, pretrained=False): - super().__init__([720], [4], [0]) - self.hrnet = HRNetV2(width=48, pretrained=False) + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=HRNETV2_BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): + + blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} + + num_modules = layer_config["NUM_MODULES"] + num_branches = layer_config["NUM_BRANCHES"] + num_blocks = layer_config["NUM_BLOCKS"] + num_channels = layer_config["NUM_CHANNELS"] + block = blocks_dict[layer_config["BLOCK"]] + fuse_method = layer_config["FUSE_METHOD"] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule( + num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + ) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels def forward(self, x): - return self.hrnet(x) + outputs = [] + x = self.layer0(x) + if 0 in self._layers: + outputs.append(x) + + x = self.layer1(x) + if 1 in self._layers: + outputs.append(x) + + x_list = [] + for i in range(self.stage2_cfg["NUM_BRANCHES"]): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + if 2 in self._layers: + x = self.resize_and_concatenate_input(y_list) + outputs.append(x) + + x_list = [] + for i in range(self.stage3_cfg["NUM_BRANCHES"]): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + if 3 in self._layers: + x = self.resize_and_concatenate_input(y_list) + outputs.append(x) + + x_list = [] + for i in range(self.stage4_cfg["NUM_BRANCHES"]): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + 
x_list.append(y_list[i]) + y_list = self.stage4(x_list) + if 4 in self._layers: + x = self.resize_and_concatenate_input(y_list) + outputs.append(x) + + return outputs + + @staticmethod + def resize_and_concatenate_input(x: List[torch.Tensor]) -> torch.Tensor: + x0_h, x0_w = x[0].size(2), x[0].size(3) + x = [x[0]] + [F.interpolate(xi, size=(x0_h, x0_w), mode="bilinear", align_corners=False) for xi in x[1:]] + x = torch.cat(x, dim=1) + return x def change_input_channels(self, input_channels: int, mode="auto"): self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + + +class HRNetV2Encoder18(HRNetEncoderBase): + def __init__(self, pretrained=None, layers=None): + super().__init__(width=18, layers=layers) + + +class HRNetV2Encoder34(HRNetEncoderBase): + def __init__(self, pretrained=None, layers=None): + super().__init__(width=34, layers=layers) + + +class HRNetV2Encoder48(HRNetEncoderBase): + def __init__(self, pretrained=None, layers=None): + super().__init__(width=48, layers=layers) diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 84aeb2307..7f28860a7 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -92,9 +92,9 @@ def test_densenet(): @pytest.mark.parametrize( ["encoder", "encoder_params"], [ - [E.HRNetV2Encoder18, {"pretrained": False}], - [E.HRNetV2Encoder34, {"pretrained": False}], - [E.HRNetV2Encoder48, {"pretrained": False}], + [E.HRNetV2Encoder18, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], + [E.HRNetV2Encoder34, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], + [E.HRNetV2Encoder48, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], ], ) @torch.no_grad() From bf0223f1dc5a50a3246a6f3c4f0104708029bc7e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 6 Dec 2019 23:53:02 +0200 Subject: [PATCH 46/79] Refactor HRNet --- pytorch_toolbelt/modules/backbone/hrnet.py | 415 --------------------- pytorch_toolbelt/modules/decoders/hrnet.py | 6 +- 2 files changed, 3 insertions(+), 418 deletions(-) delete mode 100644 pytorch_toolbelt/modules/backbone/hrnet.py diff --git a/pytorch_toolbelt/modules/backbone/hrnet.py b/pytorch_toolbelt/modules/backbone/hrnet.py deleted file mode 100644 index 7e154a74c..000000000 --- a/pytorch_toolbelt/modules/backbone/hrnet.py +++ /dev/null @@ -1,415 +0,0 @@ -""" -This HRNet implementation is modified from the following repository: -https://github.com/HRNet/HRNet-Semantic-Segmentation -""" - -from collections import OrderedDict - -import torch -import torch.nn as nn -import torch.nn.functional as F - -HRNETV2_BN_MOMENTUM = 0.1 - - -def hrnet_conv3x3(in_planes, out_planes, stride=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) - - -class HRNetBasicBlock(nn.Module): - expansion = 1 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(HRNetBasicBlock, self).__init__() - self.conv1 = hrnet_conv3x3(inplanes, planes, stride) - self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) - self.relu = nn.ReLU(inplace=True) - self.conv2 = hrnet_conv3x3(planes, planes) - self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - 
-class HRNetBottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, downsample=None): - super(HRNetBottleneck, self).__init__() - self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) - self.bn1 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) - self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes, momentum=HRNETV2_BN_MOMENTUM) - self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion, momentum=HRNETV2_BN_MOMENTUM) - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - - def forward(self, x): - residual = x - - out = self.conv1(x) - out = self.bn1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.bn2(out) - out = self.relu(out) - - out = self.conv3(out) - out = self.bn3(out) - - if self.downsample is not None: - residual = self.downsample(x) - - out += residual - out = self.relu(out) - - return out - - -class HighResolutionModule(nn.Module): - def __init__( - self, num_branches, blocks, num_blocks, num_inchannels, num_channels, fuse_method, multi_scale_output=True - ): - super(HighResolutionModule, self).__init__() - self._check_branches(num_branches, blocks, num_blocks, num_inchannels, num_channels) - - self.num_inchannels = num_inchannels - self.fuse_method = fuse_method - self.num_branches = num_branches - - self.multi_scale_output = multi_scale_output - - self.branches = self._make_branches(num_branches, blocks, num_blocks, num_channels) - self.fuse_layers = self._make_fuse_layers() - self.relu = nn.ReLU(inplace=True) - - def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, num_channels): - if num_branches != len(num_blocks): - error_msg = "NUM_BRANCHES({}) <> NUM_BLOCKS({})".format(num_branches, len(num_blocks)) - raise ValueError(error_msg) - - if num_branches != len(num_channels): - error_msg = "NUM_BRANCHES({}) <> NUM_CHANNELS({})".format(num_branches, len(num_channels)) - raise ValueError(error_msg) - - if num_branches != len(num_inchannels): - error_msg = "NUM_BRANCHES({}) <> NUM_INCHANNELS({})".format(num_branches, len(num_inchannels)) - raise ValueError(error_msg) - - def _make_one_branch(self, branch_index, block, num_blocks, num_channels, stride=1): - downsample = None - if stride != 1 or self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: - downsample = nn.Sequential( - nn.Conv2d( - self.num_inchannels[branch_index], - num_channels[branch_index] * block.expansion, - kernel_size=1, - stride=stride, - bias=False, - ), - nn.BatchNorm2d(num_channels[branch_index] * block.expansion, momentum=HRNETV2_BN_MOMENTUM), - ) - - layers = [] - layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index], stride, downsample)) - self.num_inchannels[branch_index] = num_channels[branch_index] * block.expansion - for i in range(1, num_blocks[branch_index]): - layers.append(block(self.num_inchannels[branch_index], num_channels[branch_index])) - - return nn.Sequential(*layers) - - def _make_branches(self, num_branches, block, num_blocks, num_channels): - branches = [] - - for i in range(num_branches): - branches.append(self._make_one_branch(i, block, num_blocks, num_channels)) - - return nn.ModuleList(branches) - - def _make_fuse_layers(self): - if self.num_branches == 1: - return None - - num_branches = self.num_branches - num_inchannels = self.num_inchannels - 
fuse_layers = [] - for i in range(num_branches if self.multi_scale_output else 1): - fuse_layer = [] - for j in range(num_branches): - if j > i: - fuse_layer.append( - nn.Sequential( - nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), - nn.BatchNorm2d(num_inchannels[i], momentum=HRNETV2_BN_MOMENTUM), - ) - ) - elif j == i: - fuse_layer.append(None) - else: - conv3x3s = [] - for k in range(i - j): - if k == i - j - 1: - num_outchannels_conv3x3 = num_inchannels[i] - conv3x3s.append( - nn.Sequential( - nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), - ) - ) - else: - num_outchannels_conv3x3 = num_inchannels[j] - conv3x3s.append( - nn.Sequential( - nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), - nn.BatchNorm2d(num_outchannels_conv3x3, momentum=HRNETV2_BN_MOMENTUM), - nn.ReLU(inplace=True), - ) - ) - fuse_layer.append(nn.Sequential(*conv3x3s)) - fuse_layers.append(nn.ModuleList(fuse_layer)) - - return nn.ModuleList(fuse_layers) - - def get_num_inchannels(self): - return self.num_inchannels - - def forward(self, x): - if self.num_branches == 1: - return [self.branches[0](x[0])] - - for i in range(self.num_branches): - x[i] = self.branches[i](x[i]) - - x_fuse = [] - for i in range(len(self.fuse_layers)): - y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) - for j in range(1, self.num_branches): - if i == j: - y = y + x[j] - elif j > i: - width_output = x[i].shape[-1] - height_output = x[i].shape[-2] - y = y + F.interpolate( - self.fuse_layers[i][j](x[j]), - size=(height_output, width_output), - mode="bilinear", - align_corners=False, - ) - else: - y = y + self.fuse_layers[i][j](x[j]) - x_fuse.append(self.relu(y)) - - return x_fuse - - -class HRNetV2(nn.Module): - def __init__(self, input_channels=3, width=48, **kwargs): - super(HRNetV2, self).__init__() - blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} - - extra = { - "STAGE2": { - "NUM_MODULES": 1, - "NUM_BRANCHES": 2, - "BLOCK": "BASIC", - "NUM_BLOCKS": (4, 4), - "NUM_CHANNELS": (width, width * 2), - "FUSE_METHOD": "SUM", - }, - "STAGE3": { - "NUM_MODULES": 4, - "NUM_BRANCHES": 3, - "BLOCK": "BASIC", - "NUM_BLOCKS": (4, 4, 4), - "NUM_CHANNELS": (width, width * 2, width * 4), - "FUSE_METHOD": "SUM", - }, - "STAGE4": { - "NUM_MODULES": 3, - "NUM_BRANCHES": 4, - "BLOCK": "BASIC", - "NUM_BLOCKS": (4, 4, 4, 4), - "NUM_CHANNELS": (width, width * 2, width * 4, width * 8), - "FUSE_METHOD": "SUM", - }, - "FINAL_CONV_KERNEL": 1, - } - - # stem net - self.layer0 = nn.Sequential( - OrderedDict( - [ - ("conv1", nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1, bias=False)), - ("bn1", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), - ("relu", nn.ReLU(inplace=True)), - ("conv2", nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)), - ("bn2", nn.BatchNorm2d(64, momentum=HRNETV2_BN_MOMENTUM)), - ("relu2", nn.ReLU(inplace=True)), - ] - ) - ) - - self.layer1 = self._make_layer(HRNetBottleneck, 64, 64, 4) - - self.stage2_cfg = extra["STAGE2"] - num_channels = self.stage2_cfg["NUM_CHANNELS"] - block = blocks_dict[self.stage2_cfg["BLOCK"]] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] - self.transition1 = self._make_transition_layer([256], num_channels) - self.stage2, pre_stage_channels = self._make_stage(self.stage2_cfg, num_channels) - - self.stage3_cfg = extra["STAGE3"] - num_channels = 
self.stage3_cfg["NUM_CHANNELS"] - block = blocks_dict[self.stage3_cfg["BLOCK"]] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] - self.transition2 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage3, pre_stage_channels = self._make_stage(self.stage3_cfg, num_channels) - - self.stage4_cfg = extra["STAGE4"] - num_channels = self.stage4_cfg["NUM_CHANNELS"] - block = blocks_dict[self.stage4_cfg["BLOCK"]] - num_channels = [num_channels[i] * block.expansion for i in range(len(num_channels))] - self.transition3 = self._make_transition_layer(pre_stage_channels, num_channels) - self.stage4, pre_stage_channels = self._make_stage(self.stage4_cfg, num_channels, multi_scale_output=True) - - def _make_transition_layer(self, num_channels_pre_layer, num_channels_cur_layer): - num_branches_cur = len(num_channels_cur_layer) - num_branches_pre = len(num_channels_pre_layer) - - transition_layers = [] - for i in range(num_branches_cur): - if i < num_branches_pre: - if num_channels_cur_layer[i] != num_channels_pre_layer[i]: - transition_layers.append( - nn.Sequential( - nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), - nn.BatchNorm2d(num_channels_cur_layer[i], momentum=HRNETV2_BN_MOMENTUM), - nn.ReLU(inplace=True), - ) - ) - else: - transition_layers.append(None) - else: - conv3x3s = [] - for j in range(i + 1 - num_branches_pre): - inchannels = num_channels_pre_layer[-1] - outchannels = num_channels_cur_layer[i] if j == i - num_branches_pre else inchannels - conv3x3s.append( - nn.Sequential( - nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), - nn.BatchNorm2d(outchannels, momentum=HRNETV2_BN_MOMENTUM), - nn.ReLU(inplace=True), - ) - ) - transition_layers.append(nn.Sequential(*conv3x3s)) - - return nn.ModuleList(transition_layers) - - def _make_layer(self, block, inplanes, planes, blocks, stride=1): - downsample = None - if stride != 1 or inplanes != planes * block.expansion: - downsample = nn.Sequential( - nn.Conv2d(inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), - nn.BatchNorm2d(planes * block.expansion, momentum=HRNETV2_BN_MOMENTUM), - ) - - layers = [] - layers.append(block(inplanes, planes, stride, downsample)) - inplanes = planes * block.expansion - for i in range(1, blocks): - layers.append(block(inplanes, planes)) - - return nn.Sequential(*layers) - - def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): - - blocks_dict = {"BASIC": HRNetBasicBlock, "BOTTLENECK": HRNetBottleneck} - - num_modules = layer_config["NUM_MODULES"] - num_branches = layer_config["NUM_BRANCHES"] - num_blocks = layer_config["NUM_BLOCKS"] - num_channels = layer_config["NUM_CHANNELS"] - block = blocks_dict[layer_config["BLOCK"]] - fuse_method = layer_config["FUSE_METHOD"] - - modules = [] - for i in range(num_modules): - # multi_scale_output is only used last module - if not multi_scale_output and i == num_modules - 1: - reset_multi_scale_output = False - else: - reset_multi_scale_output = True - modules.append( - HighResolutionModule( - num_branches, - block, - num_blocks, - num_inchannels, - num_channels, - fuse_method, - reset_multi_scale_output, - ) - ) - num_inchannels = modules[-1].get_num_inchannels() - - return nn.Sequential(*modules), num_inchannels - - def forward(self, x, return_feature_maps=False): - x = self.layer0(x) - x = self.layer1(x) - - x_list = [] - for i in range(self.stage2_cfg["NUM_BRANCHES"]): - if self.transition1[i] is not None: - 
x_list.append(self.transition1[i](x)) - else: - x_list.append(x) - y_list = self.stage2(x_list) - - x_list = [] - for i in range(self.stage3_cfg["NUM_BRANCHES"]): - if self.transition2[i] is not None: - x_list.append(self.transition2[i](y_list[-1])) - else: - x_list.append(y_list[i]) - y_list = self.stage3(x_list) - - x_list = [] - for i in range(self.stage4_cfg["NUM_BRANCHES"]): - if self.transition3[i] is not None: - x_list.append(self.transition3[i](y_list[-1])) - else: - x_list.append(y_list[i]) - x = self.stage4(x_list) - - # Upsampling - x0_h, x0_w = x[0].size(2), x[0].size(3) - x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode="bilinear", align_corners=False) - x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode="bilinear", align_corners=False) - x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode="bilinear", align_corners=False) - - x = torch.cat([x[0], x1, x2, x3], dim=1) - - return [x] - diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index a242376c0..74a9273f8 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -1,6 +1,6 @@ from collections import OrderedDict -from torch import nn +from torch import nn, Tensor from typing import List from .common import DecoderModule @@ -37,6 +37,6 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): ) ) - def forward(self, features): - embedding = self.embedding(features) + def forward(self, features: List[Tensor]): + embedding = self.embedding(features[-1]) return self.logits(embedding) From dbc6a503d4f8c2efc717d598f5555b38fe12cc5e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sat, 7 Dec 2019 00:01:13 +0200 Subject: [PATCH 47/79] Refactor HRNet --- pytorch_toolbelt/modules/decoders/hrnet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index 74a9273f8..e36331841 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -4,7 +4,6 @@ from typing import List from .common import DecoderModule -from ..backbone.hrnet import HRNETV2_BN_MOMENTUM __all__ = ["HRNetDecoder"] @@ -22,7 +21,7 @@ def __init__(self, feature_maps: List[int], output_channels: int, dropout=0.0): "conv1", nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=False), ), - ("bn1", nn.BatchNorm2d(features, momentum=HRNETV2_BN_MOMENTUM)), + ("bn1", nn.BatchNorm2d(features)), ("relu", nn.ReLU(inplace=True)), ] ) From 268e85d8d5548f276a0165bef7102b3d0d0821f2 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 8 Dec 2019 13:36:56 +0200 Subject: [PATCH 48/79] Fix non-existing module name --- pytorch_toolbelt/modules/encoders/hrnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index 837b8962d..d946d4014 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -444,7 +444,7 @@ def resize_and_concatenate_input(x: List[torch.Tensor]) -> torch.Tensor: return x def change_input_channels(self, input_channels: int, mode="auto"): - self.hrnet.layer0.conv1 = make_n_channel_input(self.hrnet.layer0.conv1, input_channels, mode) + self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) class HRNetV2Encoder18(HRNetEncoderBase):
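
The one-line fix above is what makes change_input_channels() usable on HRNet encoders. A minimal usage sketch of that method (the encoder class is defined in the file touched above; the exact constructor arguments and the re-export under pytorch_toolbelt.modules.encoders are assumptions):

    import torch
    from pytorch_toolbelt.modules import encoders as E

    # Build an HRNet encoder and adapt its stem to 6-channel input (e.g. two stacked RGB frames).
    encoder = E.HRNetV2Encoder18(pretrained=False)
    encoder.change_input_channels(6)

    # Encoders in this package return a list of feature maps, one per output stride.
    feature_maps = encoder(torch.randn(2, 6, 256, 256))
    print([fm.shape for fm in feature_maps])

From 3bd9e2ce950c1d6d0733bf6851532d8af7959823 Mon Sep 17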
00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 8 Dec 2019 22:54:09 +0200 Subject: [PATCH 49/79] Fix dice & jaccard losses --- pytorch_toolbelt/losses/dice.py | 8 ++++---- pytorch_toolbelt/losses/functional.py | 4 ++-- pytorch_toolbelt/losses/jaccard.py | 6 +++--- pytorch_toolbelt/utils/catalyst/metrics.py | 20 +++++++++++--------- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pytorch_toolbelt/losses/dice.py b/pytorch_toolbelt/losses/dice.py index 68da1b7e9..9ff362bd0 100644 --- a/pytorch_toolbelt/losses/dice.py +++ b/pytorch_toolbelt/losses/dice.py @@ -79,10 +79,10 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: y_true = y_true.view(bs, num_classes, -1) y_pred = y_pred.view(bs, num_classes, -1) - scores = soft_dice_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims) + scores = soft_dice_score(y_pred, y_true.type_as(y_pred), self.smooth, self.eps, dims=dims) if self.log_loss: - loss = -torch.log(scores) + loss = -torch.log(scores.clamp_min(self.eps)) else: loss = 1 - scores @@ -91,8 +91,8 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: # NOTE: A better workaround would be to use loss term `mean(y_pred)` # for this case, however it will be a modified jaccard loss - mask = (y_true.sum(dims) > 0).float() - loss = loss * mask + mask = y_true.sum(dims) > 0 + loss *= mask.float() if self.classes is not None: loss = loss[self.classes] diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 1d9d3a00a..f21dd8d7c 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -102,7 +102,7 @@ def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, e cardinality = torch.sum(y_pred + y_true) union = cardinality - intersection - jaccard_score = (intersection + smooth) / (union + smooth + eps) + jaccard_score = (intersection + smooth) / (union.clamp_min(eps) + smooth) return jaccard_score @@ -129,7 +129,7 @@ def soft_dice_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0, eps=1e else: intersection = torch.sum(y_pred * y_true) cardinality = torch.sum(y_pred + y_true) - dice_score = (2.0 * intersection + smooth) / (cardinality + smooth + eps) + dice_score = (2.0 * intersection + smooth) / (cardinality.clamp_min(eps) + smooth) return dice_score diff --git a/pytorch_toolbelt/losses/jaccard.py b/pytorch_toolbelt/losses/jaccard.py index 864a4c3e1..eeaed2986 100644 --- a/pytorch_toolbelt/losses/jaccard.py +++ b/pytorch_toolbelt/losses/jaccard.py @@ -82,7 +82,7 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: scores = soft_jaccard_score(y_pred, y_true.type(y_pred.dtype), self.smooth, self.eps, dims=dims) if self.log_loss: - loss = -torch.log(scores) + loss = -torch.log(scores.clamp_min(self.eps)) else: loss = 1 - scores @@ -91,8 +91,8 @@ def forward(self, y_pred: Tensor, y_true: Tensor) -> Tensor: # NOTE: A better workaround would be to use loss term `mean(y_pred)` # for this case, however it will be a modified jaccard loss - mask = (y_true.sum(dims) > 0).float() - loss = loss * mask + mask = y_true.sum(dims) > 0 + loss *= mask.float() if self.classes is not None: loss = loss[self.classes] diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 4f6984f18..2d410ceff 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -1,15 +1,15 @@ from functools import partial +from typing import List import numpy as np import 
torch from catalyst.dl import Callback, RunnerState, MetricCallback, CallbackOrder +from sklearn.metrics import f1_score from torchnet.meter import ConfusionMeter -from typing import List +from pytorch_toolbelt.utils.visualization import render_figure_to_tensor, plot_confusion_matrix from .visualization import get_tensorboard_logger from ..torch_utils import to_numpy -from pytorch_toolbelt.utils.visualization import render_figure_to_tensor, plot_confusion_matrix -from sklearn.metrics import f1_score, confusion_matrix __all__ = [ "pixel_accuracy", @@ -81,6 +81,7 @@ def __init__( class_names: List[str] = None, num_classes: int = None, ignore_index=None, + activation_fn=partial(torch.argmax, dim=1), ): """ :param input_key: input key to use for precision calculation; @@ -92,20 +93,21 @@ def __init__( super().__init__(CallbackOrder.Metric) self.prefix = prefix self.class_names = class_names - self.num_classes = num_classes \ - if class_names is None \ - else len(class_names) + self.num_classes = num_classes if class_names is None else len(class_names) self.output_key = output_key self.input_key = input_key self.ignore_index = ignore_index self.confusion_matrix = None + self.activation_fn = activation_fn def on_loader_start(self, state): self.confusion_matrix = ConfusionMeter(self.num_classes) def on_batch_end(self, state: RunnerState): - outputs = state.output[self.output_key].detach().argmax(dim=1).cpu() - targets = state.input[self.input_key].detach().cpu() + outputs: torch.Tensor = state.output[self.output_key].detach().cpu() + outputs: torch.Tensor = self.activation_fn(outputs) + + targets: torch.Tensor = state.input[self.input_key].detach().cpu() # Flatten outputs = outputs.view(-1) @@ -116,10 +118,10 @@ def on_batch_end(self, state: RunnerState): outputs = outputs[mask] targets = targets[mask] + targets = targets.type_as(outputs) self.confusion_matrix.add(predicted=outputs, target=targets) def on_loader_end(self, state): - if self.class_names is None: class_names = [str(i) for i in range(self.num_classes)] else: From f2e437751cce5bcc8f0b830787cb9c6353870512 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 9 Dec 2019 12:23:43 +0200 Subject: [PATCH 50/79] Add support of ABN block --- pytorch_toolbelt/modules/decoders/unet.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index e580568ca..0773a8081 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -12,13 +12,17 @@ class UNetDecoder(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0.): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, abn_block=ABN, dropout=0.): super().__init__() if not isinstance(decoder_features, list): decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] - self.center = UnetCentralBlock(in_dec_filters=feature_maps[-1], out_filters=decoder_features[-1]) + self.center = UnetCentralBlock( + in_dec_filters=feature_maps[-1], + out_filters=decoder_features[-1], + abn_block=abn_block + ) blocks = [] for block_index, in_enc_features in enumerate(feature_maps[:-1]): @@ -27,6 +31,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels in_dec_filters=decoder_features[block_index + 1], in_enc_filters=in_enc_features, out_filters=decoder_features[block_index], + abn_block=abn_block ) ) From 
264f99211de1e195e2f41e9bbd73374a9cba60cf Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 9 Dec 2019 12:28:30 +0200 Subject: [PATCH 51/79] Add support of changing number of channels --- pytorch_toolbelt/modules/encoders/efficientnet.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/encoders/efficientnet.py b/pytorch_toolbelt/modules/encoders/efficientnet.py index 6767063a4..07331d930 100644 --- a/pytorch_toolbelt/modules/encoders/efficientnet.py +++ b/pytorch_toolbelt/modules/encoders/efficientnet.py @@ -9,7 +9,7 @@ efficient_net_b7, ) -from .common import EncoderModule, _take +from .common import EncoderModule, _take, make_n_channel_input __all__ = [ "EfficientNetEncoder", @@ -57,6 +57,10 @@ def forward(self, x): # Return only features that were requested return _take(output_features, self._layers) + def change_input_channels(self, input_channels: int, mode="auto"): + self.stem.conv = make_n_channel_input(self.stem.conv, input_channels, mode) + return self + class EfficientNetB0Encoder(EfficientNetEncoder): def __init__(self, layers=None, **kwargs): From 2f7a277a38b9d595521efdf6295dd7f6a8d00302 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Mon, 9 Dec 2019 12:48:20 +0200 Subject: [PATCH 52/79] Add support of abn block in UNets --- pytorch_toolbelt/modules/decoders/unet_v2.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index cb746205f..cae96a25e 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -97,7 +97,7 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> Tuple[torch.Tensor, Lis class UNetDecoderV2(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0.): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0., abn_block=ABN): super().__init__() if not isinstance(decoder_features, list): @@ -107,11 +107,13 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlockV2( - decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels, post_dropout_rate=dropout + decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels, + abn_block=abn_block, + post_dropout_rate=dropout ) ) - self.center = UnetCentralBlockV2(feature_maps[-1], decoder_features[-1], mask_channels) + self.center = UnetCentralBlockV2(feature_maps[-1], decoder_features[-1], mask_channels, abn_block=abn_block) self.blocks = nn.ModuleList(blocks) self.output_filters = decoder_features From cdae8e3732da05afb2198f5feaea3477a4638090 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Mon, 9 Dec 2019 16:44:54 +0200 Subject: [PATCH 53/79] Improve focal loss (support normalized focal loss) --- pytorch_toolbelt/inference/tta.py | 3 +- pytorch_toolbelt/losses/focal.py | 45 +++-- pytorch_toolbelt/losses/functional.py | 18 +- pytorch_toolbelt/zoo/segmentation.py | 250 ++++++++++++++++++++++++++ tests/test_decoders.py | 30 ++++ 5 files changed, 319 insertions(+), 27 deletions(-) create mode 100644 pytorch_toolbelt/zoo/segmentation.py diff --git a/pytorch_toolbelt/inference/tta.py b/pytorch_toolbelt/inference/tta.py index 84a3c4b7d..f78a4bbc7 100644 --- 
a/pytorch_toolbelt/inference/tta.py +++ b/pytorch_toolbelt/inference/tta.py @@ -204,7 +204,8 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor: output += F.torch_transpose(x) one_over_8 = float(1.0 / 8.0) - return output * one_over_8 + output *= one_over_8 + return output class TTAWrapper(nn.Module): diff --git a/pytorch_toolbelt/losses/focal.py b/pytorch_toolbelt/losses/focal.py index 2e3b4d965..f8da7d11f 100644 --- a/pytorch_toolbelt/losses/focal.py +++ b/pytorch_toolbelt/losses/focal.py @@ -8,25 +8,27 @@ class BinaryFocalLoss(_Loss): - def __init__(self, alpha=0.5, gamma=2, ignore_index=None, reduction="mean", reduced=False, threshold=0.5): + def __init__( + self, alpha=None, gamma=2, ignore_index=None, reduction="mean", normalized=False, reduced_threshold=None + ): """ - :param alpha: - :param gamma: - :param ignore_index: + :param alpha: Prior probability of having positive value in target. + :param gamma: Power factor for dampening weight (focal strength). + :param ignore_index: If not None, targets may contain values to be ignored. Target values equal to ignore_index will be excluded from the loss computation. :param reduced: :param threshold: """ super().__init__() - self.alpha = alpha - self.gamma = gamma self.ignore_index = ignore_index - if reduced: - self.focal_loss = partial( - focal_loss_with_logits, alpha=None, gamma=gamma, threshold=threshold, reduction=reduction - ) - else: - self.focal_loss = partial(focal_loss_with_logits, alpha=alpha, gamma=gamma, reduction=reduction) + self.focal_loss_fn = partial( + focal_loss_with_logits, + alpha=alpha, + gamma=gamma, + reduced_threshold=reduced_threshold, + reduction=reduction, + normalized=normalized, + ) def forward(self, label_input, label_target): """Compute focal loss for binary classification problem. @@ -40,23 +42,32 @@ def forward(self, label_input, label_target): label_input = label_input[not_ignored] label_target = label_target[not_ignored] - loss = self.focal_loss(label_input, label_target) + loss = self.focal_loss_fn(label_input, label_target) return loss class FocalLoss(_Loss): - def __init__(self, alpha=0.5, gamma=2, ignore_index=None): + def __init__( + self, alpha=None, gamma=2, ignore_index=None, reduction="mean", normalized=False, reduced_threshold=None + ): """ Focal loss for multi-class problem.
:param alpha: :param gamma: :param ignore_index: If not None, targets with given index are ignored + :param reduced_threshold: A threshold factor for computing reduced focal loss """ super().__init__() - self.alpha = alpha - self.gamma = gamma self.ignore_index = ignore_index + self.focal_loss_fn = partial( + focal_loss_with_logits, + alpha=alpha, + gamma=gamma, + reduced_threshold=reduced_threshold, + reduction=reduction, + normalized=normalized, + ) def forward(self, label_input, label_target): num_classes = label_input.size(1) @@ -74,5 +85,5 @@ def forward(self, label_input, label_target): cls_label_target = cls_label_target[not_ignored] cls_label_input = cls_label_input[not_ignored] - loss += focal_loss_with_logits(cls_label_input, cls_label_target, gamma=self.gamma, alpha=self.alpha) + loss += self.focal_loss_fn(cls_label_input, cls_label_target) return loss diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index f21dd8d7c..43beec60f 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -14,7 +14,7 @@ def focal_loss_with_logits( alpha: Optional[float] = 0.25, reduction="mean", normalized=False, - threshold: Optional[float] = None, + reduced_threshold: Optional[float] = None, ) -> torch.Tensor: """Compute binary focal loss between target and output logits. @@ -31,24 +31,24 @@ def focal_loss_with_logits( specifying either of those two args will override :attr:`reduction`. 'batchwise_mean' computes mean loss per sample in batch. Default: 'mean' normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf). - threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). + reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347). 
References:: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py """ target = target.type(input.type()) - logpt = -F.binary_cross_entropy_with_logits(input, target, reduction="none") - pt = torch.exp(logpt) + logpt = F.binary_cross_entropy_with_logits(input, target, reduction="none") + pt = torch.exp(-logpt) # compute the loss - if threshold is None: + if reduced_threshold is None: focal_term = (1 - pt).pow(gamma) else: - focal_term = ((1.0 - pt) / threshold).pow(gamma) - focal_term[pt < threshold] = 1 + focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) + focal_term[pt < reduced_threshold] = 1 - loss = -focal_term * logpt + loss = focal_term * logpt if alpha is not None: loss = loss * (alpha * target + (1 - alpha) * (1 - target)) @@ -73,7 +73,7 @@ # TODO: Mark as deprecated and emit warning def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"): - return focal_loss_with_logits(input, target, alpha=None, gamma=gamma, reduction=reduction, threshold=threshold) + return focal_loss_with_logits(input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold) def soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: diff --git a/pytorch_toolbelt/zoo/segmentation.py b/pytorch_toolbelt/zoo/segmentation.py new file mode 100644 index 000000000..6af13eff2 --- /dev/null +++ b/pytorch_toolbelt/zoo/segmentation.py @@ -0,0 +1,250 @@ +from ..modules import ABN +from ..modules import encoders as E +from ..modules import decoders as D +from torch import nn, Tensor +from torch.nn import functional as F + + +__all__ = [ + "FPNSumSegmentationModel", + "FPNCatSegmentationModel", + "DeeplabV3SegmentationModel", + "resnet34_fpncat128", + "seresnext50_fpncat128", + "seresnext101_fpncat256", + "seresnext101_fpnsum256", + "seresnext101_deeplab256", + "efficientnetb4_fpncat128", + "OUTPUT_MASK_KEY", + "OUTPUT_MASK_4_KEY", + "OUTPUT_MASK_8_KEY", + "OUTPUT_MASK_16_KEY", + "OUTPUT_MASK_32_KEY", +] + +OUTPUT_MASK_KEY = "mask" +OUTPUT_MASK_4_KEY = "mask_4" +OUTPUT_MASK_8_KEY = "mask_8" +OUTPUT_MASK_16_KEY = "mask_16" +OUTPUT_MASK_32_KEY = "mask_32" + + +class FPNSumSegmentationModel(nn.Module): + def __init__( + self, + encoder: E.EncoderModule, + num_classes: int, + dropout=0.25, + abn_block=ABN, + fpn_channels=256, + return_full_size_mask=True, + use_deep_supervision=False, + ): + """ + Create a segmentation model of encoder-decoder architecture where the encoder is an arbitrary architecture + capable of providing 4 feature maps of 4,8,16,32 stride and the decoder is an FPN with summation.
+ + Args: + encoder: Encoder model + num_classes: Number of channels in final mask + dropout: Dropout rate to apply before computing final output mask + abn_block: Activated batch-norm block used in decoder + fpn_channels: Number of FPN channels computed for each encoder's feature map + return_full_size_mask: If True, returns mask of same size as input image; + otherwise returns mask that is 4 times smaller than original image + use_deep_supervision: If True, model also predicts mask of strides 4, 8, 16, 32 from intermediate layers + to enforce model learn mask representation at each level of encoder's feature maps + """ + super().__init__() + self.encoder = encoder + + self.decoder = D.FPNSumDecoder( + feature_maps=encoder.output_filters, + output_channels=num_classes, + dsv_channels=num_classes if use_deep_supervision else None, + fpn_channels=fpn_channels, + abn_block=abn_block, + dropout=dropout, + ) + self.deep_supervision = use_deep_supervision + self.full_size_mask = return_full_size_mask + + def forward(self, x): + enc_features = self.encoder(x) + output = self.decoder(enc_features) + + if self.deep_supervision: + mask, dsv = output + else: + mask = output + + if self.full_size_mask: + mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False) + + output = {OUTPUT_MASK_KEY: mask} + + if self.deep_supervision: + output[OUTPUT_MASK_4_KEY] = dsv[3] + output[OUTPUT_MASK_8_KEY] = dsv[2] + output[OUTPUT_MASK_16_KEY] = dsv[1] + output[OUTPUT_MASK_32_KEY] = dsv[0] + + return output + + +class FPNCatSegmentationModel(nn.Module): + def __init__( + self, + encoder: E.EncoderModule, + num_classes: int, + dropout=0.0, + abn_block=ABN, + fpn_channels=256, + return_full_size_mask=True, + use_deep_supervision=False, + ): + """ + Create a segmentation model of encoder-decoder architecture where the encoder is an arbitrary architecture + capable of providing 4 feature maps of 4,8,16,32 stride and the decoder is an FPN with concatenation.
+ + Args: + encoder: Encoder model + num_classes: Number of channels in final mask + dropout: Dropout rate to apply before computing final output mask + abn_block: Activated batch-norm block used in decoder + fpn_channels: Number of FPN channels computed for each encoder's feature map + return_full_size_mask: If True, returns mask of same size as input image; + otherwise returns mask that is 4 times smaller than original image + use_deep_supervision: If True, model also predicts mask of strides 4, 8, 16, 32 from intermediate layers + to enforce model learn mask representation at each level of encoder's feature maps + """ + super().__init__() + self.encoder = encoder + + self.decoder = D.FPNCatDecoder( + feature_maps=encoder.output_filters, + output_channels=num_classes, + dsv_channels=num_classes if use_deep_supervision else None, + fpn_channels=fpn_channels, + abn_block=abn_block, + dropout=dropout, + ) + + self.deep_supervision = use_deep_supervision + self.full_size_mask = return_full_size_mask + + def forward(self, x: Tensor): + features = self.encoder(x) + output = self.decoder(features) + + if self.deep_supervision: + mask, dsv = output + else: + mask = output + + if self.full_size_mask: + mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False) + + output = {OUTPUT_MASK_KEY: mask} + + if self.deep_supervision: + output[OUTPUT_MASK_4_KEY] = dsv[3] + output[OUTPUT_MASK_8_KEY] = dsv[2] + output[OUTPUT_MASK_16_KEY] = dsv[1] + output[OUTPUT_MASK_32_KEY] = dsv[0] + + return output + + +class DeeplabV3SegmentationModel(nn.Module): + def __init__( + self, + encoder: E.EncoderModule, + num_classes: int, + dropout=0.25, + abn_block=ABN, + high_level_bottleneck=256, + low_level_bottleneck=32, + return_full_size_mask=True, + ): + super().__init__() + self.encoder = encoder + + self.decoder = D.DeeplabV3Decoder( + feature_maps=encoder.output_filters, + output_stride=encoder.output_strides[-1], + num_classes=num_classes, + high_level_bottleneck=high_level_bottleneck, + low_level_bottleneck=low_level_bottleneck, + abn_block=abn_block, + dropout=dropout, + ) + + self.return_full_size_mask = return_full_size_mask + + def forward(self, x): + enc_features = self.encoder(x) + + # Decode mask + mask, dsv = self.decoder(enc_features) + + if self.return_full_size_mask: + mask = F.interpolate(mask, size=x.size()[2:], mode="bilinear", align_corners=False) + + output = {OUTPUT_MASK_KEY: mask, OUTPUT_MASK_32_KEY: dsv} + + return output + + +# resnet34-backbone models + + +def resnet34_fpncat128(input_channels=3, num_classes=1, dropout=0.0, pretrained=None): + encoder = E.Resnet34Encoder(pretrained=pretrained) + if input_channels != 3: + encoder.change_input_channels(input_channels) + return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) + + +# seresnext50-backbone models + + +def seresnext50_fpncat128(input_channels=3, num_classes=1, dropout=0.0, pretrained=None): + encoder = E.SEResNeXt50Encoder(pretrained=pretrained) + if input_channels != 3: + encoder.change_input_channels(input_channels) + return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) + + +# seresnext101-backbone models + + +def seresnext101_fpncat256(input_channels=3, num_classes=1, dropout=0.0, pretrained=None): + encoder = E.SEResNeXt101Encoder(pretrained=pretrained) + if input_channels != 3: + encoder.change_input_channels(input_channels) + return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=256, 
dropout=dropout) + + +def seresnext101_fpnsum256(input_channels=3, num_classes=1, dropout=0.0, pretrained=None): + encoder = E.SEResNeXt101Encoder(pretrained=pretrained) + if input_channels != 3: + encoder.change_input_channels(input_channels) + return FPNSumSegmentationModel(encoder, num_classes=num_classes, fpn_channels=256, dropout=dropout) + + +def seresnext101_deeplab256(input_channels=3, num_classes=1, dropout=0.0): + encoder = E.SEResNeXt101Encoder() + if input_channels != 3: + encoder.change_input_channels(input_channels) + return DeeplabV3SegmentationModel(encoder, num_classes=num_classes, high_level_bottleneck=256, dropout=dropout) + + +# efficientnet-backbone models + + +def efficientnetb4_fpncat128(input_channels=3, num_classes=1, dropout=0.0, pretrained=None): + encoder = E.EfficientNetB4Encoder(abn_params={"activation": "swish"}, pretrained=pretrained) + if input_channels != 3: + encoder.change_input_channels(input_channels) + return FPNCatSegmentationModel(encoder, num_classes=num_classes, fpn_channels=128, dropout=dropout) diff --git a/tests/test_decoders.py b/tests/test_decoders.py index 8c5b115cb..df1206ec7 100644 --- a/tests/test_decoders.py +++ b/tests/test_decoders.py @@ -16,12 +16,27 @@ def test_fpn_sum(): net = FPNSumDecoder(channels, 5).eval() + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output = net(input) + + print(output.size(), output.mean(), output.std()) + print(count_parameters(net)) + + +@torch.no_grad() +def test_fpn_sum_with_dsv(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNSumDecoder(channels, output_channels=5, dsv_channels=5).eval() + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] output, dsv_masks = net(input) print(output.size(), output.mean(), output.std()) for dsv in dsv_masks: print(dsv.size(), dsv.mean(), dsv.std()) + assert dsv.size(1) == 5 print(count_parameters(net)) @@ -32,10 +47,25 @@ def test_fpn_cat(): net = FPNCatDecoder(channels, 5).eval() + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] + output = net(input) + + print(output.size(), output.mean(), output.std()) + print(count_parameters(net)) + + +@torch.no_grad() +def test_fpn_cat_with_dsv(): + channels = [256, 512, 1024, 2048] + sizes = [64, 32, 16, 8] + + net = FPNCatDecoder(channels, output_channels=5, dsv_channels=5).eval() + input = [torch.randn(4, ch, sz, sz) for sz, ch in zip(sizes, channels)] output, dsv_masks = net(input) print(output.size(), output.mean(), output.std()) for dsv in dsv_masks: print(dsv.size(), dsv.mean(), dsv.std()) + assert dsv.size(1) == 5 print(count_parameters(net)) From c589757df5a7f6e5c41546b3ddd11cd24f3ad0e9 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 10 Dec 2019 23:59:17 +0200 Subject: [PATCH 54/79] Prevent too small normalization term in focal loss --- pytorch_toolbelt/losses/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 43beec60f..642554184 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -54,7 +54,7 @@ def focal_loss_with_logits( loss = loss * (alpha * target + (1 - alpha) * (1 - target)) if normalized: - norm_factor = focal_term.sum() + norm_factor = focal_term.sum() + 1e-5 loss = loss / norm_factor if reduction == "mean":
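
Patches 53 and 54 together rework focal_loss_with_logits: the focal term is now derived from a positive BCE value, reduced focal loss is selected via reduced_threshold, and the normalization factor is guarded against a near-zero sum. A minimal sketch of calling the loss through its class wrapper after these changes (the top-level import path is an assumption; argument names follow the new BinaryFocalLoss signature shown above):

    import torch
    from pytorch_toolbelt.losses import BinaryFocalLoss

    # Normalized, reduced focal loss for a binary segmentation head;
    # both options were introduced in patch 53, and patch 54 adds the 1e-5 guard.
    criterion = BinaryFocalLoss(alpha=None, gamma=2, normalized=True, reduced_threshold=0.5)

    logits = torch.randn(4, 1, 64, 64, requires_grad=True)
    targets = (torch.rand(4, 1, 64, 64) > 0.5).float()

    loss = criterion(logits, targets)
    loss.backward()

From 8539b305fe5847b1fd253873587bd76467a20b2a Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Wed, 11 Dec 2019 17:33:29 +0200 Subject: [PATCH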
55/79] Change interpolation in Unet --- pytorch_toolbelt/modules/decoders/fpn_sum.py | 8 +++---- pytorch_toolbelt/modules/unet.py | 25 +++++++++----------- setup.py | 2 +- 3 files changed, 16 insertions(+), 19 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index 5d67f5a53..bfa29bab3 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -67,9 +67,9 @@ def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: x = torch.cat( [ x, - F.interpolate(p2, size=x_size, mode="bilinear", align_corners=False), - F.interpolate(p4, size=x_size, mode="bilinear", align_corners=False), - F.interpolate(p8, size=x_size, mode="bilinear", align_corners=False), + F.interpolate(p2, size=x_size, mode="nearest"), + F.interpolate(p4, size=x_size, mode="nearest"), + F.interpolate(p8, size=x_size, mode="nearest"), ], dim=1, ) @@ -124,7 +124,7 @@ def __init__( self.dsv = None def forward(self, decoder_fm: Tensor, encoder_fm: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: - decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="bilinear", align_corners=False) + decoder_fm = F.interpolate(decoder_fm, size=encoder_fm.size()[2:], mode="nearest") encoder_fm = self.skip(encoder_fm) x = decoder_fm + encoder_fm diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py index 728f9b3cf..0f87fe5ac 100644 --- a/pytorch_toolbelt/modules/unet.py +++ b/pytorch_toolbelt/modules/unet.py @@ -45,12 +45,11 @@ class UnetDecoderBlock(nn.Module): def __init__( self, - in_dec_filters, - in_enc_filters, - out_filters, + in_dec_filters: int, + in_enc_filters: int, + out_filters: int, abn_block=ABN, - pre_dropout_rate=0.0, - post_dropout_rate=0.0, + dropout_rate=0.0, scale_factor=None, scale_mode="nearest", align_corners=None, @@ -62,17 +61,16 @@ def __init__( self.scale_mode = scale_mode self.align_corners = align_corners - self.pre_drop = nn.Dropout2d(pre_dropout_rate, inplace=True) - self.conv1 = nn.Conv2d( in_dec_filters + in_enc_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs ) self.abn1 = abn_block(out_filters) + + self.drop = nn.Dropout2d(dropout_rate, inplace=False) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) self.abn2 = abn_block(out_filters) - self.post_drop = nn.Dropout2d(post_dropout_rate, inplace=False) - def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: if self.scale_factor is not None: x = F.interpolate( @@ -82,15 +80,14 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: lat_size = enc.size()[2:] x = F.interpolate(x, size=lat_size, mode=self.scale_mode, align_corners=self.align_corners) - x = torch.cat([x, enc], 1) - - x = self.pre_drop(x) + x = torch.cat([x, enc], dim=1) x = self.conv1(x) x = self.abn1(x) + x = self.drop(x) + x = self.conv2(x) x = self.abn2(x) - x = self.post_drop(x) - return x + return x \ No newline at end of file diff --git a/setup.py b/setup.py index c2ae21ef2..6622cc6a8 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def load_readme(): def get_test_requirements(): - requirements = ["pytest", "catalyst>=19.6.4"] + requirements = ["pytest", "catalyst>=19.6.4", "black-19.3b0"] if sys.version_info < (3, 3): requirements.append("mock") return requirements From d42971fea89a152905bb73d89aaad05fb9f247a4 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 11 Dec 2019 22:37:06 +0200 Subject: [PATCH 
56/79] Improve Unet --- pytorch_toolbelt/modules/decoders/unet.py | 19 +++++++++++------ pytorch_toolbelt/modules/unet.py | 26 +++++++++++------------ pytorch_toolbelt/utils/torch_utils.py | 10 +++++++-- 3 files changed, 33 insertions(+), 22 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 0773a8081..b19b06812 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -11,17 +11,21 @@ __all__ = ["UNetDecoder"] +def conv1x1(input, output): + return nn.Conv2d(input, output, kernel_size=1) + + class UNetDecoder(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, abn_block=ABN, dropout=0.): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, abn_block=ABN, dropout=0.0, final_block=conv1x1): super().__init__() if not isinstance(decoder_features, list): decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] + else: + assert len(decoder_features) == len(feature_maps) self.center = UnetCentralBlock( - in_dec_filters=feature_maps[-1], - out_filters=decoder_features[-1], - abn_block=abn_block + in_dec_filters=feature_maps[-1], out_filters=decoder_features[-1], abn_block=abn_block ) blocks = [] @@ -31,7 +35,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels in_dec_filters=decoder_features[block_index + 1], in_enc_filters=in_enc_features, out_filters=decoder_features[block_index], - abn_block=abn_block + abn_block=abn_block, ) ) @@ -39,13 +43,14 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels self.output_filters = decoder_features self.final_drop = nn.Dropout2d(dropout) - self.final = nn.Conv2d(decoder_features[0], mask_channels, kernel_size=1) + self.final = final_block(decoder_features[0], mask_channels) def forward(self, feature_maps: List[torch.Tensor]) -> torch.Tensor: output = self.center(feature_maps[-1]) for decoder_block, encoder_output in zip(reversed(self.blocks), reversed(feature_maps[:-1])): - output = decoder_block(output, encoder_output) + x = decoder_block(output, encoder_output) + output = x output = self.final_drop(output) output = self.final(output) diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py index 0f87fe5ac..9cac1fdc5 100644 --- a/pytorch_toolbelt/modules/unet.py +++ b/pytorch_toolbelt/modules/unet.py @@ -1,18 +1,18 @@ import torch from torch import nn import torch.nn.functional as F - +from typing import Optional from .activated_batch_norm import ABN __all__ = ["UnetEncoderBlock", "UnetCentralBlock", "UnetDecoderBlock"] class UnetEncoderBlock(nn.Module): - def __init__(self, in_dec_filters, out_filters, abn_block=ABN, **kwargs): + def __init__(self, in_dec_filters: int, out_filters: int, abn_block=ABN): super().__init__() - self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False) self.abn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=1, bias=False, **kwargs) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, stride=1, bias=False) self.abn2 = abn_block(out_filters) def forward(self, x): @@ -24,11 +24,11 @@ def forward(self, x): class UnetCentralBlock(nn.Module): - def __init__(self, in_dec_filters, 
out_filters, abn_block=ABN, **kwargs): + def __init__(self, in_dec_filters: int, out_filters: int, abn_block=ABN): super().__init__() - self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False, **kwargs) + self.conv1 = nn.Conv2d(in_dec_filters, out_filters, kernel_size=3, padding=1, stride=2, bias=False) self.abn1 = abn_block(out_filters) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False) self.abn2 = abn_block(out_filters) def forward(self, x): @@ -53,7 +53,6 @@ def __init__( scale_factor=None, scale_mode="nearest", align_corners=None, - **kwargs, ): super(UnetDecoderBlock, self).__init__() @@ -62,16 +61,16 @@ def __init__( self.align_corners = align_corners self.conv1 = nn.Conv2d( - in_dec_filters + in_enc_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs + in_dec_filters + in_enc_filters, out_filters, kernel_size=3, padding=1, bias=False ) self.abn1 = abn_block(out_filters) self.drop = nn.Dropout2d(dropout_rate, inplace=False) - self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False, **kwargs) + self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=3, padding=1, bias=False) self.abn2 = abn_block(out_filters) - def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: + def forward(self, x: torch.Tensor, enc: Optional[torch.Tensor] = None) -> torch.Tensor: if self.scale_factor is not None: x = F.interpolate( x, scale_factor=self.scale_factor, mode=self.scale_mode, align_corners=self.align_corners @@ -80,7 +79,8 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: lat_size = enc.size()[2:] x = F.interpolate(x, size=lat_size, mode=self.scale_mode, align_corners=self.align_corners) - x = torch.cat([x, enc], dim=1) + if enc is not None: + x = torch.cat([x, enc], dim=1) x = self.conv1(x) x = self.abn1(x) @@ -90,4 +90,4 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> torch.Tensor: x = self.conv2(x) x = self.abn2(x) - return x \ No newline at end of file + return x diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py index c0950d870..f611d8b30 100644 --- a/pytorch_toolbelt/utils/torch_utils.py +++ b/pytorch_toolbelt/utils/torch_utils.py @@ -48,7 +48,7 @@ def logit(x: torch.Tensor, eps=1e-5): return torch.log(x / (1.0 - x)) -def count_parameters(model: nn.Module) -> Tuple[int, int]: +def count_parameters(model: nn.Module) -> dict: """ Count number of total and trainable parameters of a model :param model: A model @@ -56,7 +56,13 @@ def count_parameters(model: nn.Module) -> Tuple[int, int]: """ total = sum(p.numel() for p in model.parameters()) trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) - return total, trainable + parameters = {"total": total, "trainable": trainable} + + for key in ["encoder", "decoder"]: + if hasattr(model, key): + parameters[key] = sum(p.numel() for p in model.__getattr__(key).parameters()) + + return parameters def to_numpy(x) -> np.ndarray: From f739698c27c104848c7a17c5cd3b5e4c49237bad Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 11 Dec 2019 22:50:11 +0200 Subject: [PATCH 57/79] Add assertion message --- pytorch_toolbelt/modules/decoders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 
b19b06812..92a4564c3 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -22,7 +22,7 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels if not isinstance(decoder_features, list): decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] else: - assert len(decoder_features) == len(feature_maps) + assert len(decoder_features) == len(feature_maps), f"Incorrect number of decoder features: {decoder_features}, {feature_maps}" self.center = UnetCentralBlock( in_dec_filters=feature_maps[-1], out_filters=decoder_features[-1], abn_block=abn_block From 4be370a9d384f1d6be305813a5d52e2d3b66514d Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 11 Dec 2019 23:14:52 +0200 Subject: [PATCH 58/79] Reformat --- pytorch_toolbelt/modules/decoders/deeplab.py | 26 +++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index c985dff59..dd1afebbf 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -66,14 +66,16 @@ def forward(self, x): class DeeplabV3Decoder(DecoderModule): - def __init__(self, - feature_maps: List[int], - num_classes: int, - output_stride=32, - high_level_bottleneck=256, - low_level_bottleneck=32, - dropout=0.5, - abn_block=ABN): + def __init__( + self, + feature_maps: List[int], + num_classes: int, + output_stride=32, + high_level_bottleneck=256, + low_level_bottleneck=32, + dropout=0.5, + abn_block=ABN, + ): super(DeeplabV3Decoder, self).__init__() self.aspp = ASPP(feature_maps[-1], output_stride, high_level_bottleneck, dropout=dropout, abn_block=abn_block) @@ -82,7 +84,13 @@ def __init__(self, self.abn1 = abn_block(low_level_bottleneck) self.last_conv = nn.Sequential( - nn.Conv2d(high_level_bottleneck + low_level_bottleneck, high_level_bottleneck, kernel_size=3, padding=1, bias=False), + nn.Conv2d( + high_level_bottleneck + low_level_bottleneck, + high_level_bottleneck, + kernel_size=3, + padding=1, + bias=False, + ), abn_block(high_level_bottleneck), nn.Dropout(dropout), nn.Conv2d(high_level_bottleneck, high_level_bottleneck, kernel_size=3, padding=1, bias=False), From efb097eda509549b45de99e64c448c206f289755 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Wed, 11 Dec 2019 23:16:10 +0200 Subject: [PATCH 59/79] Reformat --- pytorch_toolbelt/modules/decoders/deeplab.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index dd1afebbf..16339dfcd 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -109,12 +109,12 @@ def forward(self, feature_maps): high_level_features = feature_maps[-1] high_level_features = self.aspp(high_level_features) - dsv = self.dsv(high_level_features) + mask_dsv = self.dsv(high_level_features) high_level_features = F.interpolate( high_level_features, size=low_level_feat.size()[2:], mode="bilinear", align_corners=False ) high_level_features = torch.cat([high_level_features, low_level_feat], dim=1) - high_level_features = self.last_conv(high_level_features) + mask = self.last_conv(high_level_features) - return high_level_features, dsv + return mask, mask_dsv From 2511f2971879ce91a6f993dd7d3687d5023fd9a1 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 15 Dec 2019 16:38:14 +0200 Subject: 
[PATCH 60/79] Add support of setting interpolation for FPN --- pytorch_toolbelt/modules/decoders/fpn.py | 4 ++-- pytorch_toolbelt/modules/decoders/fpn_cat.py | 6 +++++- pytorch_toolbelt/modules/encoders/common.py | 1 + pytorch_toolbelt/modules/encoders/inception.py | 2 +- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/fpn.py b/pytorch_toolbelt/modules/decoders/fpn.py index 1a58e4b01..aa15d21d1 100644 --- a/pytorch_toolbelt/modules/decoders/fpn.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -1,5 +1,5 @@ from torch import nn -from typing import List +from typing import List, Optional from .common import DecoderModule from ..fpn import FPNBottleneckBlock, UpsampleAdd, FPNPredictionBlock @@ -15,7 +15,7 @@ def __init__( fpn_features=128, prediction_features=128, mode="bilinear", - align_corners=False, + align_corners: Optional[bool] = False, upsample_scale=None, ): """ diff --git a/pytorch_toolbelt/modules/decoders/fpn_cat.py b/pytorch_toolbelt/modules/decoders/fpn_cat.py index 3fe5b89b2..05e2a58a9 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_cat.py +++ b/pytorch_toolbelt/modules/decoders/fpn_cat.py @@ -49,6 +49,8 @@ def __init__( upsample_add=UpsampleAdd, prediction_block=FPNCatDecoderBlock, final_block=partial(nn.Conv2d, kernel_size=1), + interpolation_mode: str = "bilinear", + align_corners: Optional[bool] = False, ): """ @@ -71,9 +73,11 @@ def __init__( prediction_block=prediction_block, fpn_features=fpn_channels, prediction_features=fpn_channels, + mode=interpolation_mode, + align_corners=align_corners, ) - self.fuse = FPNFuse() + self.fuse = FPNFuse(mode=interpolation_mode, align_corners=align_corners) self.dropout = nn.Dropout2d(dropout, inplace=True) # dsv blocks are for deep supervision diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 4654dc350..79badf387 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -18,6 +18,7 @@ def _take(elements, indexes): def make_n_channel_input(conv: nn.Conv2d, in_channels: int, mode="auto"): + assert isinstance(conv, nn.Conv2d) if conv.in_channels == in_channels: warnings.warn("make_n_channel_input call is spurious") return conv diff --git a/pytorch_toolbelt/modules/encoders/inception.py b/pytorch_toolbelt/modules/encoders/inception.py index dbd842394..bf3028183 100644 --- a/pytorch_toolbelt/modules/encoders/inception.py +++ b/pytorch_toolbelt/modules/encoders/inception.py @@ -39,5 +39,5 @@ def encoder_layers(self): return [self.layer0, self.layer1, self.layer2, self.layer3, self.layer4] def change_input_channels(self, input_channels: int, mode="auto"): - self.layer0[0] = make_n_channel_input(self.layer0[0], input_channels, mode) + self.layer0[0].conv = make_n_channel_input(self.layer0[0].conv, input_channels, mode) return self From c591dfd14846ded27b063e7d52f1c0e6f789139a Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 19 Dec 2019 16:05:47 +0100 Subject: [PATCH 61/79] Fix implementation of MobilenetV2 --- pytorch_toolbelt/modules/activations.py | 25 ++++++++++++-- .../modules/backbone/mobilenet.py | 33 ++++++++----------- tests/test_activations.py | 8 ++--- tests/test_encoders.py | 2 ++ 4 files changed, 43 insertions(+), 25 deletions(-) diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index 530b2926a..99ce1ddc0 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -23,7 
+23,8 @@ "HardSigmoid", "HardSwish", "Swish", - "get_activation_module", + "instantiate_activation_block", + "get_activation_block", "sanitize_activation_name", ] @@ -134,7 +135,27 @@ def forward(self, x): return hard_swish(x, inplace=self.inplace) -def get_activation_module(activation_name: str, **kwargs) -> nn.Module: +def get_activation_block(activation_name: str): + ACTIVATIONS = { + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "leaky_relu": nn.LeakyReLU, + "elu": nn.ELU, + "selu": nn.SELU, + "celu": nn.CELU, + "glu": nn.GLU, + "prelu": nn.PReLU, + "swish": Swish, + "mish": Mish, + "hard_sigmoid": HardSigmoid, + "hard_swish": HardSwish, + "none": Identity, + } + + return ACTIVATIONS[activation_name.lower()] + + +def instantiate_activation_block(activation_name: str, **kwargs) -> nn.Module: ACTIVATIONS = { "relu": nn.ReLU, "relu6": nn.ReLU6, diff --git a/pytorch_toolbelt/modules/backbone/mobilenet.py b/pytorch_toolbelt/modules/backbone/mobilenet.py index b7006fe93..b0ca1a4a1 100644 --- a/pytorch_toolbelt/modules/backbone/mobilenet.py +++ b/pytorch_toolbelt/modules/backbone/mobilenet.py @@ -4,19 +4,19 @@ import torch.nn as nn -from ..activations import get_activation_module +from ..activations import get_activation_block -def conv_bn(inp, oup, stride, activation: nn.Module): - return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), activation(inplace=True)) +def conv_bn(inp, oup, stride, activation): + return nn.Sequential(nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup), activation()) -def conv_1x1_bn(inp, oup, activation: nn.Module): - return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), activation(inplace=True)) +def conv_1x1_bn(inp, oup, activation): + return nn.Sequential(nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), activation()) class InvertedResidual(nn.Module): - def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module): + def __init__(self, inp, oup, stride, expand_ratio, activation): super(InvertedResidual, self).__init__() self.stride = stride assert stride in [1, 2] @@ -29,7 +29,7 @@ def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module): # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), - activation(inplace=True), + activation(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), @@ -39,11 +39,11 @@ def __init__(self, inp, oup, stride, expand_ratio, activation: nn.Module): # pw nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), nn.BatchNorm2d(hidden_dim), - activation(inplace=True), + activation(), # dw nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False), nn.BatchNorm2d(hidden_dim), - activation(inplace=True), + activation(), # pw-linear nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup), @@ -60,7 +60,7 @@ class MobileNetV2(nn.Module): def __init__(self, n_class=1000, input_size=224, width_mult=1.0, activation="relu6"): super(MobileNetV2, self).__init__() - act = get_activation_module(activation) + activation_block = get_activation_block(activation) block = InvertedResidual input_channel = 32 @@ -80,7 +80,7 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.0, activation="rel assert input_size % 32 == 0 input_channel = int(input_channel * width_mult) self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel - self.layer0 = conv_bn(3, input_channel, 2, act) + 
self.layer0 = conv_bn(3, input_channel, 2, activation_block) # building inverted residual blocks for layer_index, (t, c, n, s) in enumerate(interverted_residual_setting): @@ -89,16 +89,16 @@ def __init__(self, n_class=1000, input_size=224, width_mult=1.0, activation="rel blocks = [] for i in range(n): if i == 0: - blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=act)) + blocks.append(block(input_channel, output_channel, s, expand_ratio=t, activation=activation_block)) else: - blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=act)) + blocks.append(block(input_channel, output_channel, 1, expand_ratio=t, activation=activation_block)) input_channel = output_channel self.add_module(f"layer{layer_index + 1}", nn.Sequential(*blocks)) # building last several layers - self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=act) + self.final_layer = conv_1x1_bn(input_channel, self.last_channel, activation=activation_block) # building classifier self.classifier = nn.Sequential(nn.Dropout(0.2), nn.Linear(self.last_channel, n_class)) @@ -134,8 +134,3 @@ def _initialize_weights(self): n = m.weight.size(1) m.weight.data.normal_(0, 0.01) m.bias.data.zero_() - - -def test(): - model = MobileNetV2().eval() - print(model) diff --git a/tests/test_activations.py b/tests/test_activations.py index fafde46a2..880657f87 100644 --- a/tests/test_activations.py +++ b/tests/test_activations.py @@ -1,7 +1,7 @@ import torch import pytest -from pytorch_toolbelt.modules.activations import get_activation_module +from pytorch_toolbelt.modules.activations import instantiate_activation_block skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="Cuda is not available") @@ -11,7 +11,7 @@ ["none", "relu", "relu6", "leaky_relu", "elu", "selu", "celu", "mish", "swish", "hard_sigmoid", "hard_swish"], ) def test_activations(activation_name): - act = get_activation_module(activation_name) + act = instantiate_activation_block(activation_name) x = torch.randn(128).float() y = act(x) assert y.dtype == torch.float32 @@ -23,12 +23,12 @@ def test_activations(activation_name): ) @skip_if_no_cuda def test_activations_cuda(activation_name): - act = get_activation_module(activation_name) + act = instantiate_activation_block(activation_name) x = torch.randn(128).float().cuda() y = act(x) assert y.dtype == torch.float32 - act = get_activation_module(activation_name) + act = instantiate_activation_block(activation_name) x = torch.randn(128).half().cuda() y = act(x) assert y.dtype == torch.float16 diff --git a/tests/test_encoders.py b/tests/test_encoders.py index 7f28860a7..450ee0b29 100644 --- a/tests/test_encoders.py +++ b/tests/test_encoders.py @@ -11,6 +11,8 @@ @pytest.mark.parametrize( ["encoder", "encoder_params"], [ + [E.MobilenetV2Encoder, {}], + [E.MobilenetV3Encoder, {}], [E.Resnet34Encoder, {"pretrained": False}], [E.Resnet50Encoder, {"pretrained": False}], [E.SEResNeXt50Encoder, {"pretrained": False, "layers": [0, 1, 2, 3, 4]}], From 69a8e86aa860e0f5c4e06974daea90fbfe57fc09 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 27 Dec 2019 22:12:26 +0100 Subject: [PATCH 62/79] Update black dependencies --- .appveyor.yml | 2 -- black.toml | 2 +- setup.cfg | 2 +- setup.py | 2 +- 4 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.appveyor.yml b/.appveyor.yml index df7403430..0ee55d4f6 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -6,8 +6,6 @@ cache: environment: matrix: - - PYTHON: 'C:\Python27-x64' - - PYTHON: 
'C:\Python35-x64' - PYTHON: 'C:\Python36-x64' - PYTHON: 'C:\Python37-x64' diff --git a/black.toml b/black.toml index d6600cc5b..06d00f2f6 100644 --- a/black.toml +++ b/black.toml @@ -7,7 +7,7 @@ [tool.black] line-length = 119 -target-version = ['py35', 'py36', 'py37', 'py38'] +target-version = ['py36', 'py37', 'py38'] include = '\.pyi?$' exclude = ''' /( diff --git a/setup.cfg b/setup.cfg index 518be03d4..4617f37c1 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,4 +1,4 @@ [flake8] -max-line-length = 179 +max-line-length = 119 exclude =.git,__pycache__,docs/source/conf.py,build,dist ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,D413,W504,E127,E203,W503 diff --git a/setup.py b/setup.py index 6622cc6a8..73d2479c7 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ def load_readme(): def get_test_requirements(): - requirements = ["pytest", "catalyst>=19.6.4", "black-19.3b0"] + requirements = ["pytest", "catalyst>=19.6.4", "black==19.3b0"] if sys.version_info < (3, 3): requirements.append("mock") return requirements From 038b93d00b10a89b6a89e07ba24cf7fb8f76b27a Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Sat, 28 Dec 2019 23:09:30 +0200 Subject: [PATCH 63/79] Simplify swish activation module --- pytorch_toolbelt/modules/activations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index 99ce1ddc0..d9dd53a6e 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -123,7 +123,7 @@ def forward(self, x): class Swish(nn.Module): def forward(self, input_tensor): - return SwishFunction.apply(input_tensor) + return swish(input_tensor) class HardSwish(nn.Module): From d17e748b7a429b181e0f1c621b1e3099ad15e094 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 2 Jan 2020 11:30:43 +0200 Subject: [PATCH 64/79] Black reformatting --- pytorch_toolbelt/inference/tta.py | 3 +-- pytorch_toolbelt/losses/functional.py | 4 +++- pytorch_toolbelt/modules/decoders/unet.py | 14 ++++++++++++-- pytorch_toolbelt/modules/decoders/unet_v2.py | 9 ++++++--- pytorch_toolbelt/modules/ocnet.py | 7 ++++++- pytorch_toolbelt/modules/unet.py | 4 +--- 6 files changed, 29 insertions(+), 12 deletions(-) diff --git a/pytorch_toolbelt/inference/tta.py b/pytorch_toolbelt/inference/tta.py index f78a4bbc7..6ce096cf1 100644 --- a/pytorch_toolbelt/inference/tta.py +++ b/pytorch_toolbelt/inference/tta.py @@ -188,8 +188,7 @@ def d4_image2mask(model: nn.Module, image: Tensor) -> Tensor: output = model(image) for aug, deaug in zip( - [F.torch_rot90, F.torch_rot180, F.torch_rot270], - [F.torch_rot270, F.torch_rot180, F.torch_rot90] + [F.torch_rot90, F.torch_rot180, F.torch_rot270], [F.torch_rot270, F.torch_rot180, F.torch_rot90] ): x = deaug(model(aug(image))) output += x diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 642554184..8c4b81ac4 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -73,7 +73,9 @@ def focal_loss_with_logits( # TODO: Mark as deprecated and emit warning def reduced_focal_loss(input: torch.Tensor, target: torch.Tensor, threshold=0.5, gamma=2.0, reduction="mean"): - return focal_loss_with_logits(input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold) + return focal_loss_with_logits( + input, target, alpha=None, gamma=gamma, reduction=reduction, reduced_threshold=threshold + ) def 
soft_jaccard_score(y_pred: torch.Tensor, y_true: torch.Tensor, smooth=0.0, eps=1e-7, dims=None) -> torch.Tensor: diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 92a4564c3..85853c1fa 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -16,13 +16,23 @@ def conv1x1(input, output): class UNetDecoder(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, abn_block=ABN, dropout=0.0, final_block=conv1x1): + def __init__( + self, + feature_maps: List[int], + decoder_features: int, + mask_channels: int, + abn_block=ABN, + dropout=0.0, + final_block=conv1x1, + ): super().__init__() if not isinstance(decoder_features, list): decoder_features = [decoder_features * (2 ** i) for i in range(len(feature_maps))] else: - assert len(decoder_features) == len(feature_maps), f"Incorrect number of decoder features: {decoder_features}, {feature_maps}" + assert len(decoder_features) == len( + feature_maps + ), f"Incorrect number of decoder features: {decoder_features}, {feature_maps}" self.center = UnetCentralBlock( in_dec_filters=feature_maps[-1], out_filters=decoder_features[-1], abn_block=abn_block diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index cae96a25e..8b8d92153 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -97,7 +97,7 @@ def forward(self, x: torch.Tensor, enc: torch.Tensor) -> Tuple[torch.Tensor, Lis class UNetDecoderV2(DecoderModule): - def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0., abn_block=ABN): + def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels: int, dropout=0.0, abn_block=ABN): super().__init__() if not isinstance(decoder_features, list): @@ -107,9 +107,12 @@ def __init__(self, feature_maps: List[int], decoder_features: int, mask_channels for block_index, in_enc_features in enumerate(feature_maps[:-1]): blocks.append( UnetDecoderBlockV2( - decoder_features[block_index + 1], in_enc_features, decoder_features[block_index], mask_channels, + decoder_features[block_index + 1], + in_enc_features, + decoder_features[block_index], + mask_channels, abn_block=abn_block, - post_dropout_rate=dropout + post_dropout_rate=dropout, ) ) diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index c1cab6603..97f343bc4 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -243,7 +243,12 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, self.out_channels = in_channels self.f_key = nn.Sequential( nn.Conv2d( - in_channels=self.in_channels, out_channels=self.key_channels, kernel_size=1, stride=1, padding=0, bias=False + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False, ), abn_block(self.key_channels), ) diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py index 9cac1fdc5..3908fc36b 100644 --- a/pytorch_toolbelt/modules/unet.py +++ b/pytorch_toolbelt/modules/unet.py @@ -60,9 +60,7 @@ def __init__( self.scale_mode = scale_mode self.align_corners = align_corners - self.conv1 = nn.Conv2d( - in_dec_filters + in_enc_filters, out_filters, kernel_size=3, padding=1, bias=False - ) + self.conv1 = nn.Conv2d(in_dec_filters + in_enc_filters, out_filters, kernel_size=3, 
padding=1, bias=False) self.abn1 = abn_block(out_filters) self.drop = nn.Dropout2d(dropout_rate, inplace=False) From d47040b318d04c344aad4173732560be4f0cf150 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 2 Jan 2020 11:32:18 +0200 Subject: [PATCH 65/79] Black reformatting --- pytorch_toolbelt/modules/decoders/unet.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_toolbelt/modules/decoders/unet.py b/pytorch_toolbelt/modules/decoders/unet.py index 85853c1fa..73fc750e3 100644 --- a/pytorch_toolbelt/modules/decoders/unet.py +++ b/pytorch_toolbelt/modules/decoders/unet.py @@ -1,11 +1,10 @@ from typing import List import torch -import torch.nn.functional as F from torch import nn -from ..activated_batch_norm import ABN from .common import DecoderModule +from ..activated_batch_norm import ABN from ..unet import UnetCentralBlock, UnetDecoderBlock __all__ = ["UNetDecoder"] From 20eb21d9217e48fa646ce2931b2c48b519dec01a Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 2 Jan 2020 11:34:16 +0200 Subject: [PATCH 66/79] Black reformatting --- pytorch_toolbelt/modules/decoders/unet_v2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index 8b8d92153..006eb120f 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, List +from typing import Tuple, List import torch import torch.nn.functional as F From 3f830cdc9d2f6613124f8ad8d7f2927dd4cd7966 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 2 Jan 2020 11:38:37 +0200 Subject: [PATCH 67/79] Optimize imports --- pytorch_toolbelt/inference/tiles.py | 4 ++-- pytorch_toolbelt/losses/__init__.py | 6 +++--- pytorch_toolbelt/losses/soft_bce.py | 6 +++--- pytorch_toolbelt/modules/activated_batch_norm.py | 3 ++- pytorch_toolbelt/modules/activations.py | 3 +-- pytorch_toolbelt/modules/backbone/efficient_net.py | 2 +- pytorch_toolbelt/modules/backbone/wider_resnet.py | 4 ++-- pytorch_toolbelt/modules/decoders/common.py | 2 -- pytorch_toolbelt/modules/decoders/deeplab.py | 1 + pytorch_toolbelt/modules/decoders/fpn.py | 3 ++- pytorch_toolbelt/modules/decoders/fpn_sum.py | 9 ++++----- pytorch_toolbelt/modules/decoders/hrnet.py | 2 +- pytorch_toolbelt/modules/decoders/unet_v2.py | 2 +- pytorch_toolbelt/modules/encoders/common.py | 3 +-- pytorch_toolbelt/modules/encoders/efficientnet.py | 3 +-- pytorch_toolbelt/modules/encoders/hrnet.py | 4 ++-- pytorch_toolbelt/modules/encoders/seresnet.py | 4 ++-- pytorch_toolbelt/modules/encoders/wide_resnet.py | 3 +-- pytorch_toolbelt/modules/ocnet.py | 3 ++- pytorch_toolbelt/modules/unet.py | 6 ++++-- pytorch_toolbelt/optimization/lr_schedules.py | 2 +- pytorch_toolbelt/utils/catalyst/__init__.py | 4 ++-- pytorch_toolbelt/utils/catalyst/metrics.py | 2 +- pytorch_toolbelt/utils/catalyst/opl.py | 2 +- pytorch_toolbelt/utils/catalyst/utils.py | 0 pytorch_toolbelt/utils/catalyst/visualization.py | 5 ++--- pytorch_toolbelt/utils/catalyst_utils.py | 1 - pytorch_toolbelt/utils/dataset_utils.py | 7 ++++--- pytorch_toolbelt/utils/torch_utils.py | 1 - pytorch_toolbelt/utils/visualization.py | 3 ++- pytorch_toolbelt/zoo/segmentation.py | 6 +++--- 31 files changed, 52 insertions(+), 54 deletions(-) delete mode 100644 pytorch_toolbelt/utils/catalyst/utils.py diff --git a/pytorch_toolbelt/inference/tiles.py b/pytorch_toolbelt/inference/tiles.py index a4d307473..ea86bfb6e 
100644 --- a/pytorch_toolbelt/inference/tiles.py +++ b/pytorch_toolbelt/inference/tiles.py @@ -1,11 +1,11 @@ """Implementation of tile-based inference allowing to predict huge images that does not fit into GPU memory entirely in a sliding-window fashion and merging prediction mask back to full-resolution. """ +import math from typing import List -import numpy as np import cv2 -import math +import numpy as np import torch diff --git a/pytorch_toolbelt/losses/__init__.py b/pytorch_toolbelt/losses/__init__.py index 47a28d4bc..2b85ac17b 100644 --- a/pytorch_toolbelt/losses/__init__.py +++ b/pytorch_toolbelt/losses/__init__.py @@ -1,10 +1,10 @@ from __future__ import absolute_import +from .dice import * from .focal import * from .jaccard import * -from .dice import * -from .lovasz import * from .joint_loss import * -from .wing_loss import * +from .lovasz import * from .soft_bce import * from .soft_ce import * +from .wing_loss import * diff --git a/pytorch_toolbelt/losses/soft_bce.py b/pytorch_toolbelt/losses/soft_bce.py index a1b9650bd..6840e3040 100644 --- a/pytorch_toolbelt/losses/soft_bce.py +++ b/pytorch_toolbelt/losses/soft_bce.py @@ -1,8 +1,8 @@ -import torch -from torch import nn -import torch.nn.functional as F from typing import Optional +import torch.nn.functional as F +from torch import nn + __all__ = ["BCELoss", "SoftBCELoss"] diff --git a/pytorch_toolbelt/modules/activated_batch_norm.py b/pytorch_toolbelt/modules/activated_batch_norm.py index d6c16f8f9..3f62bc0e5 100644 --- a/pytorch_toolbelt/modules/activated_batch_norm.py +++ b/pytorch_toolbelt/modules/activated_batch_norm.py @@ -1,8 +1,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch.nn.parameter import Parameter from torch.nn import init +from torch.nn.parameter import Parameter + from .activations import ( ACT_LEAKY_RELU, ACT_HARD_SWISH, diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index d9dd53a6e..3856b6ab5 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -1,8 +1,7 @@ -from functools import partial - import torch from torch import nn from torch.nn import functional as F + from .identity import Identity __all__ = [ diff --git a/pytorch_toolbelt/modules/backbone/efficient_net.py b/pytorch_toolbelt/modules/backbone/efficient_net.py index c09a92971..5f0041756 100644 --- a/pytorch_toolbelt/modules/backbone/efficient_net.py +++ b/pytorch_toolbelt/modules/backbone/efficient_net.py @@ -6,7 +6,7 @@ import torch from torch import nn from torch.nn import functional as F -from torch.nn.init import kaiming_normal_, kaiming_uniform_ +from torch.nn.init import kaiming_uniform_ from ..activated_batch_norm import ABN from ..activated_group_norm import AGN diff --git a/pytorch_toolbelt/modules/backbone/wider_resnet.py b/pytorch_toolbelt/modules/backbone/wider_resnet.py index 393d34ed1..68677baba 100644 --- a/pytorch_toolbelt/modules/backbone/wider_resnet.py +++ b/pytorch_toolbelt/modules/backbone/wider_resnet.py @@ -1,10 +1,10 @@ from collections import OrderedDict from functools import partial -import torch +from torch import nn + from ..activated_batch_norm import ABN from ..pooling import GlobalAvgPool2d -from torch import nn class IdentityResidualBlock(nn.Module): diff --git a/pytorch_toolbelt/modules/decoders/common.py b/pytorch_toolbelt/modules/decoders/common.py index 45319f9e6..d42e61134 100644 --- a/pytorch_toolbelt/modules/decoders/common.py +++ b/pytorch_toolbelt/modules/decoders/common.py @@ -2,8 
+2,6 @@ __all__ = ["DecoderModule", "SegmentationDecoderModule"] -from typing import List - class DecoderModule(nn.Module): def __init__(self): diff --git a/pytorch_toolbelt/modules/decoders/deeplab.py b/pytorch_toolbelt/modules/decoders/deeplab.py index 16339dfcd..7c9e29185 100644 --- a/pytorch_toolbelt/modules/decoders/deeplab.py +++ b/pytorch_toolbelt/modules/decoders/deeplab.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from .common import DecoderModule from ..activated_batch_norm import ABN diff --git a/pytorch_toolbelt/modules/decoders/fpn.py b/pytorch_toolbelt/modules/decoders/fpn.py index aa15d21d1..a8b562e15 100644 --- a/pytorch_toolbelt/modules/decoders/fpn.py +++ b/pytorch_toolbelt/modules/decoders/fpn.py @@ -1,6 +1,7 @@ -from torch import nn from typing import List, Optional +from torch import nn + from .common import DecoderModule from ..fpn import FPNBottleneckBlock, UpsampleAdd, FPNPredictionBlock diff --git a/pytorch_toolbelt/modules/decoders/fpn_sum.py b/pytorch_toolbelt/modules/decoders/fpn_sum.py index bfa29bab3..0eef6bc57 100644 --- a/pytorch_toolbelt/modules/decoders/fpn_sum.py +++ b/pytorch_toolbelt/modules/decoders/fpn_sum.py @@ -3,13 +3,12 @@ from typing import List, Tuple, Optional, Union import torch +import torch.nn.functional as F +from torch import Tensor, nn + +from .common import SegmentationDecoderModule from ..activated_batch_norm import ABN from ..identity import Identity -from .common import SegmentationDecoderModule - - -from torch import Tensor, nn -import torch.nn.functional as F __all__ = ["FPNSumDecoder", "FPNSumDecoderBlock", "FPNSumCenterBlock"] diff --git a/pytorch_toolbelt/modules/decoders/hrnet.py b/pytorch_toolbelt/modules/decoders/hrnet.py index e36331841..e2b3e077d 100644 --- a/pytorch_toolbelt/modules/decoders/hrnet.py +++ b/pytorch_toolbelt/modules/decoders/hrnet.py @@ -1,7 +1,7 @@ from collections import OrderedDict +from typing import List from torch import nn, Tensor -from typing import List from .common import DecoderModule diff --git a/pytorch_toolbelt/modules/decoders/unet_v2.py b/pytorch_toolbelt/modules/decoders/unet_v2.py index 006eb120f..3547ddb26 100644 --- a/pytorch_toolbelt/modules/decoders/unet_v2.py +++ b/pytorch_toolbelt/modules/decoders/unet_v2.py @@ -4,8 +4,8 @@ import torch.nn.functional as F from torch import nn -from ..activated_batch_norm import ABN from .common import DecoderModule +from ..activated_batch_norm import ABN __all__ = ["UNetDecoderV2", "UnetCentralBlockV2", "UnetDecoderBlockV2"] diff --git a/pytorch_toolbelt/modules/encoders/common.py b/pytorch_toolbelt/modules/encoders/common.py index 79badf387..646dca43e 100644 --- a/pytorch_toolbelt/modules/encoders/common.py +++ b/pytorch_toolbelt/modules/encoders/common.py @@ -3,13 +3,12 @@ Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. 
""" import math +import warnings from typing import List import torch from torch import nn -import warnings - __all__ = ["EncoderModule", "_take", "make_n_channel_input"] diff --git a/pytorch_toolbelt/modules/encoders/efficientnet.py b/pytorch_toolbelt/modules/encoders/efficientnet.py index 07331d930..a2a2ab5a6 100644 --- a/pytorch_toolbelt/modules/encoders/efficientnet.py +++ b/pytorch_toolbelt/modules/encoders/efficientnet.py @@ -1,3 +1,4 @@ +from .common import EncoderModule, _take, make_n_channel_input from ..backbone.efficient_net import ( efficient_net_b0, efficient_net_b6, @@ -9,8 +10,6 @@ efficient_net_b7, ) -from .common import EncoderModule, _take, make_n_channel_input - __all__ = [ "EfficientNetEncoder", "EfficientNetB0Encoder", diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index d946d4014..49c8e4b0a 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -1,9 +1,9 @@ from collections import OrderedDict +from typing import List import torch -from torch import nn import torch.nn.functional as F -from typing import List +from torch import nn from .common import EncoderModule, make_n_channel_input diff --git a/pytorch_toolbelt/modules/encoders/seresnet.py b/pytorch_toolbelt/modules/encoders/seresnet.py index a71ddeb85..084dc6653 100644 --- a/pytorch_toolbelt/modules/encoders/seresnet.py +++ b/pytorch_toolbelt/modules/encoders/seresnet.py @@ -2,11 +2,11 @@ Encodes listed here provides easy way to swap backbone of classification/segmentation/detection model. """ -from torch import Tensor from typing import List -from .common import EncoderModule, _take, make_n_channel_input +from torch import Tensor +from .common import EncoderModule, _take, make_n_channel_input from ..backbone.senet import ( SENet, se_resnext50_32x4d, diff --git a/pytorch_toolbelt/modules/encoders/wide_resnet.py b/pytorch_toolbelt/modules/encoders/wide_resnet.py index a77c48032..129e18a86 100644 --- a/pytorch_toolbelt/modules/encoders/wide_resnet.py +++ b/pytorch_toolbelt/modules/encoders/wide_resnet.py @@ -1,10 +1,9 @@ from typing import List +from .common import EncoderModule, _take, make_n_channel_input from ..activated_batch_norm import ABN from ..backbone.wider_resnet import WiderResNet, WiderResNetA2 -from .common import EncoderModule, _take, make_n_channel_input - __all__ = [ "WiderResnetEncoder", "WiderResnet16A2Encoder", diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index 97f343bc4..8db4e2dd3 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -1,8 +1,9 @@ # Credit: https://github.com/PkuRainBow/OCNet.pytorch/blob/master/oc_module/asp_oc_block.py import torch +import torch.nn.functional as F from torch import nn + from .activated_batch_norm import ABN -import torch.nn.functional as F __all__ = ["ObjectContextBlock", "ASPObjectContextBlock", "PyramidObjectContextBlock"] diff --git a/pytorch_toolbelt/modules/unet.py b/pytorch_toolbelt/modules/unet.py index 3908fc36b..a978d3231 100644 --- a/pytorch_toolbelt/modules/unet.py +++ b/pytorch_toolbelt/modules/unet.py @@ -1,7 +1,9 @@ +from typing import Optional + import torch -from torch import nn import torch.nn.functional as F -from typing import Optional +from torch import nn + from .activated_batch_norm import ABN __all__ = ["UnetEncoderBlock", "UnetCentralBlock", "UnetDecoderBlock"] diff --git a/pytorch_toolbelt/optimization/lr_schedules.py 
b/pytorch_toolbelt/optimization/lr_schedules.py index a841187a9..5f2ae2881 100644 --- a/pytorch_toolbelt/optimization/lr_schedules.py +++ b/pytorch_toolbelt/optimization/lr_schedules.py @@ -1,7 +1,7 @@ import math + import numpy as np from torch import nn - from torch.optim.lr_scheduler import _LRScheduler, LambdaLR from torch.optim.optimizer import Optimizer diff --git a/pytorch_toolbelt/utils/catalyst/__init__.py b/pytorch_toolbelt/utils/catalyst/__init__.py index 7ac9ec9af..1d3f92f37 100644 --- a/pytorch_toolbelt/utils/catalyst/__init__.py +++ b/pytorch_toolbelt/utils/catalyst/__init__.py @@ -1,6 +1,6 @@ from __future__ import absolute_import -from .metrics import * -from .visualization import * from .criterions import * +from .metrics import * from .opl import * +from .visualization import * diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 2d410ceff..811c1dc2f 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -7,9 +7,9 @@ from sklearn.metrics import f1_score from torchnet.meter import ConfusionMeter -from pytorch_toolbelt.utils.visualization import render_figure_to_tensor, plot_confusion_matrix from .visualization import get_tensorboard_logger from ..torch_utils import to_numpy +from ..visualization import render_figure_to_tensor, plot_confusion_matrix __all__ = [ "pixel_accuracy", diff --git a/pytorch_toolbelt/utils/catalyst/opl.py b/pytorch_toolbelt/utils/catalyst/opl.py index 31ae0d7ea..920eb0598 100644 --- a/pytorch_toolbelt/utils/catalyst/opl.py +++ b/pytorch_toolbelt/utils/catalyst/opl.py @@ -1,5 +1,5 @@ -from catalyst.dl import Callback, CallbackOrder, RunnerState import numpy as np +from catalyst.dl import Callback, CallbackOrder, RunnerState from ..torch_utils import to_numpy diff --git a/pytorch_toolbelt/utils/catalyst/utils.py b/pytorch_toolbelt/utils/catalyst/utils.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 26c6afced..9c668590d 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -6,13 +6,12 @@ import numpy as np import torch import torch.nn.functional as F - from catalyst.dl import Callback, RunnerState, CallbackOrder from catalyst.dl.callbacks import TensorboardLogger from catalyst.utils.tensorboard import SummaryWriter -from pytorch_toolbelt.utils.torch_utils import rgb_image_from_tensor, to_numpy -from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image +from ..torch_utils import rgb_image_from_tensor, to_numpy +from ..torch_utils import tensor_from_rgb_image __all__ = [ "get_tensorboard_logger", diff --git a/pytorch_toolbelt/utils/catalyst_utils.py b/pytorch_toolbelt/utils/catalyst_utils.py index b94f700e9..4433cc0ca 100644 --- a/pytorch_toolbelt/utils/catalyst_utils.py +++ b/pytorch_toolbelt/utils/catalyst_utils.py @@ -1,4 +1,3 @@ import warnings -from pytorch_toolbelt.utils.catalyst import * warnings.warn("Please use 'from pytorch_toolbelt.utils.catalyst import *' instead") diff --git a/pytorch_toolbelt/utils/dataset_utils.py b/pytorch_toolbelt/utils/dataset_utils.py index ee849ab96..c56fe2072 100644 --- a/pytorch_toolbelt/utils/dataset_utils.py +++ b/pytorch_toolbelt/utils/dataset_utils.py @@ -1,10 +1,11 @@ from typing import Callable, List -from pytorch_toolbelt.inference.tiles import ImageSlicer -from pytorch_toolbelt.utils.fs import 
id_from_fname -from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image, tensor_from_mask_image from torch.utils.data import Dataset, ConcatDataset +from .fs import id_from_fname +from .torch_utils import tensor_from_rgb_image, tensor_from_mask_image +from ..inference.tiles import ImageSlicer + class ImageMaskDataset(Dataset): def __init__( diff --git a/pytorch_toolbelt/utils/torch_utils.py b/pytorch_toolbelt/utils/torch_utils.py index f611d8b30..60fcb9fe0 100644 --- a/pytorch_toolbelt/utils/torch_utils.py +++ b/pytorch_toolbelt/utils/torch_utils.py @@ -3,7 +3,6 @@ """ import collections import warnings -from typing import Tuple import numpy as np import torch diff --git a/pytorch_toolbelt/utils/visualization.py b/pytorch_toolbelt/utils/visualization.py index e1c45e42e..4742a1afd 100644 --- a/pytorch_toolbelt/utils/visualization.py +++ b/pytorch_toolbelt/utils/visualization.py @@ -1,9 +1,10 @@ from __future__ import absolute_import import itertools + import numpy as np -from pytorch_toolbelt.utils.torch_utils import tensor_from_rgb_image +from .torch_utils import tensor_from_rgb_image def plot_confusion_matrix( diff --git a/pytorch_toolbelt/zoo/segmentation.py b/pytorch_toolbelt/zoo/segmentation.py index 6af13eff2..56aa1154e 100644 --- a/pytorch_toolbelt/zoo/segmentation.py +++ b/pytorch_toolbelt/zoo/segmentation.py @@ -1,9 +1,9 @@ -from ..modules import ABN -from ..modules import encoders as E -from ..modules import decoders as D from torch import nn, Tensor from torch.nn import functional as F +from ..modules import ABN +from ..modules import decoders as D +from ..modules import encoders as E __all__ = [ "FPNSumSegmentationModel", From 2e71f3c91c95ec3d13fc42719e350c2216253332 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Thu, 2 Jan 2020 12:57:34 +0200 Subject: [PATCH 68/79] Added naive swish implementation that is tracing-friendly --- .../modules/activated_batch_norm.py | 4 +- pytorch_toolbelt/modules/activations.py | 52 +++++++++---------- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/pytorch_toolbelt/modules/activated_batch_norm.py b/pytorch_toolbelt/modules/activated_batch_norm.py index 3f62bc0e5..c9e71316d 100644 --- a/pytorch_toolbelt/modules/activated_batch_norm.py +++ b/pytorch_toolbelt/modules/activated_batch_norm.py @@ -19,7 +19,7 @@ hard_swish, mish, swish, -) + ACT_SWISH_NAIVE, swish_naive) __all__ = ["ABN"] @@ -150,6 +150,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return F.selu(x, inplace=True) elif self.activation == ACT_SWISH: return swish(x) + elif self.activation == ACT_SWISH_NAIVE: + return swish_naive(x) elif self.activation == ACT_MISH: return mish(x) elif self.activation == ACT_HARD_SWISH: diff --git a/pytorch_toolbelt/modules/activations.py b/pytorch_toolbelt/modules/activations.py index 3856b6ab5..2a378ab6a 100644 --- a/pytorch_toolbelt/modules/activations.py +++ b/pytorch_toolbelt/modules/activations.py @@ -38,6 +38,7 @@ ACT_MISH = "mish" ACT_HARD_SWISH = "hard_swish" ACT_HARD_SIGMOID = "hard_sigmoid" +ACT_SWISH_NAIVE = "swish_naive" class SwishFunction(torch.autograd.Function): @@ -103,6 +104,10 @@ def swish(x): return SwishFunction.apply(x) +def swish_naive(x): + return x * x.sigmoid() + + def hard_sigmoid(x, inplace=False): return F.relu6(x + 3, inplace) / 6 @@ -120,6 +125,11 @@ def forward(self, x): return hard_sigmoid(x, inplace=self.inplace) +class SwishNaive(nn.Module): + def forward(self, input_tensor): + return swish_naive(input_tensor) + + class Swish(nn.Module): def forward(self, input_tensor): 
return swish(input_tensor) @@ -136,49 +146,35 @@ def forward(self, x): def get_activation_block(activation_name: str): ACTIVATIONS = { - "relu": nn.ReLU, - "relu6": nn.ReLU6, - "leaky_relu": nn.LeakyReLU, - "elu": nn.ELU, - "selu": nn.SELU, + ACT_RELU: nn.ReLU, + ACT_RELU6: nn.ReLU6, + ACT_LEAKY_RELU: nn.LeakyReLU, + ACT_ELU: nn.ELU, + ACT_SELU: nn.SELU, "celu": nn.CELU, "glu": nn.GLU, "prelu": nn.PReLU, - "swish": Swish, - "mish": Mish, - "hard_sigmoid": HardSigmoid, - "hard_swish": HardSwish, - "none": Identity, + ACT_SWISH: Swish, + ACT_SWISH_NAIVE: SwishNaive, + ACT_MISH: Mish, + ACT_HARD_SIGMOID: HardSigmoid, + ACT_HARD_SWISH: HardSwish, + ACT_NONE: Identity, } return ACTIVATIONS[activation_name.lower()] def instantiate_activation_block(activation_name: str, **kwargs) -> nn.Module: - ACTIVATIONS = { - "relu": nn.ReLU, - "relu6": nn.ReLU6, - "leaky_relu": nn.LeakyReLU, - "elu": nn.ELU, - "selu": nn.SELU, - "celu": nn.CELU, - "glu": nn.GLU, - "prelu": nn.PReLU, - "swish": Swish, - "mish": Mish, - "hard_sigmoid": HardSigmoid, - "hard_swish": HardSwish, - "none": Identity, - } - - return ACTIVATIONS[activation_name.lower()](**kwargs) + block = get_activation_block(activation_name) + return block(**kwargs) def sanitize_activation_name(activation_name: str) -> str: """ Return reasonable activation name for initialization in `kaiming_uniform_` for hipster activations """ - if activation_name in {ACT_MISH, ACT_SWISH}: + if activation_name in {ACT_MISH, ACT_SWISH, ACT_SWISH_NAIVE}: return ACT_LEAKY_RELU return activation_name From a8917c0f821423bb66a279bd287220e1d62fd0f7 Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 3 Jan 2020 10:30:14 +0200 Subject: [PATCH 69/79] Add Pillow and torchnet dependencies --- setup.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 73d2479c7..c2b93d8db 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ AUTHOR = "Eugene Khvedchenya" REQUIRES_PYTHON = ">=3.6.0" -DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.0"] +DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.1", "Pillow>=6.2", "torchnet>=0.0.5.1"] EXCLUDE_FROM_PACKAGES = ["contrib", "docs", "tests", "examples"] CURDIR = os.path.abspath(os.path.dirname(__file__)) @@ -66,7 +66,18 @@ def get_test_requirements(): python_requires=REQUIRES_PYTHON, extras_require={"tests": get_test_requirements()}, include_package_data=True, - keywords=["PyTorch", "Kaggle", "Deep Learning", "Machine Learning", "ResNet", "VGG", "ResNext", "Unet", "Focal"], + keywords=[ + "PyTorch", + "Kaggle", + "Deep Learning", + "Machine Learning", + "ResNet", + "VGG", + "ResNext", + "Unet", + "Focal", + "FPN", + ], scripts=[], license="License :: OSI Approved :: MIT License", classifiers=[ From 2c884736cc13f5e777421e87361fd10b7aa643dd Mon Sep 17 00:00:00 2001 From: Ievgen Khvedchenia Date: Fri, 3 Jan 2020 12:30:31 +0200 Subject: [PATCH 70/79] Relax torchnet dependency version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index c2b93d8db..7b4ac85a1 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ AUTHOR = "Eugene Khvedchenya" REQUIRES_PYTHON = ">=3.6.0" -DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.1", "Pillow>=6.2", "torchnet>=0.0.5.1"] +DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.1", "Pillow>=6.2", "torchnet>=0.0.4"] EXCLUDE_FROM_PACKAGES = ["contrib", "docs", "tests", "examples"] CURDIR = 
os.path.abspath(os.path.dirname(__file__)) From 4077a3b5d670b3c728b04e30aa7e88f1abec9259 Mon Sep 17 00:00:00 2001 From: RG Date: Sat, 4 Jan 2020 03:02:03 +0100 Subject: [PATCH 71/79] performance of ImageSlicer weight=pyramid --- pytorch_toolbelt/inference/tiles.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pytorch_toolbelt/inference/tiles.py b/pytorch_toolbelt/inference/tiles.py index ea86bfb6e..d9ff05841 100644 --- a/pytorch_toolbelt/inference/tiles.py +++ b/pytorch_toolbelt/inference/tiles.py @@ -28,14 +28,18 @@ def compute_pyramid_patch_weight_loss(width, height) -> np.ndarray: Dc = np.zeros((width, height)) De = np.zeros((width, height)) - for i in range(width): - for j in range(height): - Dc[i, j] = np.sqrt(np.square(i - xc + 0.5) + np.square(j - yc + 0.5)) - De_l = np.sqrt(np.square(i - xl + 0.5) + np.square(j - j + 0.5)) - De_r = np.sqrt(np.square(i - xr + 0.5) + np.square(j - j + 0.5)) - De_b = np.sqrt(np.square(i - i + 0.5) + np.square(j - yb + 0.5)) - De_t = np.sqrt(np.square(i - i + 0.5) + np.square(j - yt + 0.5)) - De[i, j] = np.min([De_l, De_r, De_b, De_t]) + Dcx = np.square(np.arange(width) - xc + 0.5) + Dcy = np.square(np.arange(height) - yc + 0.5) + Dc = np.sqrt(Dcx[np.newaxis].transpose() + Dcy) + + De_l = np.square(np.arange(width) - xl + 0.5) + np.square(0.5) + De_r = np.square(np.arange(width) - xr + 0.5) + np.square(0.5) + De_b = np.square(0.5) + np.square(np.arange(height) - yb + 0.5) + De_t = np.square(0.5) + np.square(np.arange(height) - yt + 0.5) + + De_x = np.sqrt(np.minimum(De_l, De_r)) + De_y = np.sqrt(np.minimum(De_b, De_t)) + De = np.minimum(De_x[np.newaxis].transpose(), De_y) alpha = (width * height) / np.sum(np.divide(De, np.add(Dc, De))) W = alpha * np.divide(De, np.add(Dc, De)) From d9cea7018a4d96e2d1672e1706169b3923e1f52e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 5 Jan 2020 20:12:10 +0100 Subject: [PATCH 72/79] Set Pillow version >=6, <7 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 7b4ac85a1..0218f7617 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,7 @@ AUTHOR = "Eugene Khvedchenya" REQUIRES_PYTHON = ">=3.6.0" -DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.1", "Pillow>=6.2", "torchnet>=0.0.4"] +DEPENDENCIES = ["torch>=1.1", "torchvision>=0.3", "opencv-python>=4.1", "Pillow>=6.0,<7.0", "torchnet>=0.0.4"] EXCLUDE_FROM_PACKAGES = ["contrib", "docs", "tests", "examples"] CURDIR = os.path.abspath(os.path.dirname(__file__)) From 50f24e21bd862ed5f5512db09c7225ece12c864c Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 5 Jan 2020 20:35:00 +0100 Subject: [PATCH 73/79] Remove upsampling to produce final feature map of stride 4. 
Instead, return feature maps of stride 4/8/16/32 --- pytorch_toolbelt/modules/encoders/hrnet.py | 42 ++++++---------------- 1 file changed, 10 insertions(+), 32 deletions(-) diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index 49c8e4b0a..34f13a478 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch import nn -from .common import EncoderModule, make_n_channel_input +from .common import EncoderModule, make_n_channel_input, _take __all__ = ["HRNetV2Encoder48", "HRNetV2Encoder18", "HRNetV2Encoder34"] @@ -227,18 +227,17 @@ def forward(self, x): class HRNetEncoderBase(EncoderModule): def __init__(self, input_channels=3, width=48, layers: List[int] = None): if layers is None: - # By default return only last feature map - layers = [4] + layers = [1, 2, 3, 4] channels = [ 64, - 256, - width * 2 + width, - width * 4 + width * 2 + width, - width * 8 + width * 4 + width * 2 + width, + width, + width * 2, + width * 4, + width * 8, ] - strides = [4, 4, 4, 4, 4] + strides = [4, 4, 8, 16, 32] super().__init__(channels=channels, strides=strides, layers=layers) @@ -392,14 +391,8 @@ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True): return nn.Sequential(*modules), num_inchannels def forward(self, x): - outputs = [] - x = self.layer0(x) - if 0 in self._layers: - outputs.append(x) - - x = self.layer1(x) - if 1 in self._layers: - outputs.append(x) + layer0 = self.layer0(x) + x = self.layer1(layer0) x_list = [] for i in range(self.stage2_cfg["NUM_BRANCHES"]): @@ -408,9 +401,6 @@ def forward(self, x): else: x_list.append(x) y_list = self.stage2(x_list) - if 2 in self._layers: - x = self.resize_and_concatenate_input(y_list) - outputs.append(x) x_list = [] for i in range(self.stage3_cfg["NUM_BRANCHES"]): @@ -419,9 +409,6 @@ def forward(self, x): else: x_list.append(y_list[i]) y_list = self.stage3(x_list) - if 3 in self._layers: - x = self.resize_and_concatenate_input(y_list) - outputs.append(x) x_list = [] for i in range(self.stage4_cfg["NUM_BRANCHES"]): @@ -430,19 +417,10 @@ def forward(self, x): else: x_list.append(y_list[i]) y_list = self.stage4(x_list) - if 4 in self._layers: - x = self.resize_and_concatenate_input(y_list) - outputs.append(x) + outputs = _take([layer0] + y_list, self._layers) return outputs - @staticmethod - def resize_and_concatenate_input(x: List[torch.Tensor]) -> torch.Tensor: - x0_h, x0_w = x[0].size(2), x[0].size(3) - x = [x[0]] + [F.interpolate(xi, size=(x0_h, x0_w), mode="bilinear", align_corners=False) for xi in x[1:]] - x = torch.cat(x, dim=1) - return x - def change_input_channels(self, input_channels: int, mode="auto"): self.layer0.conv1 = make_n_channel_input(self.layer0.conv1, input_channels, mode) From e15d020fb1e80a4f79077d728ba92929331f7a02 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 5 Jan 2020 21:22:10 +0100 Subject: [PATCH 74/79] Black8 reformat --- pytorch_toolbelt/modules/activated_batch_norm.py | 4 +++- pytorch_toolbelt/modules/encoders/hrnet.py | 8 +------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/pytorch_toolbelt/modules/activated_batch_norm.py b/pytorch_toolbelt/modules/activated_batch_norm.py index c9e71316d..d55c338e6 100644 --- a/pytorch_toolbelt/modules/activated_batch_norm.py +++ b/pytorch_toolbelt/modules/activated_batch_norm.py @@ -19,7 +19,9 @@ hard_swish, mish, swish, - ACT_SWISH_NAIVE, swish_naive) + ACT_SWISH_NAIVE, + swish_naive, +) 
__all__ = ["ABN"] diff --git a/pytorch_toolbelt/modules/encoders/hrnet.py b/pytorch_toolbelt/modules/encoders/hrnet.py index 34f13a478..9e04a2dd3 100644 --- a/pytorch_toolbelt/modules/encoders/hrnet.py +++ b/pytorch_toolbelt/modules/encoders/hrnet.py @@ -229,13 +229,7 @@ def __init__(self, input_channels=3, width=48, layers: List[int] = None): if layers is None: layers = [1, 2, 3, 4] - channels = [ - 64, - width, - width * 2, - width * 4, - width * 8, - ] + channels = [64, width, width * 2, width * 4, width * 8] strides = [4, 4, 8, 16, 32] From 2d0bf7e00f544cf92aab4e59b35ea422ea0cfab8 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Sun, 5 Jan 2020 21:25:13 +0100 Subject: [PATCH 75/79] Black reformat --- pytorch_toolbelt/losses/dice.py | 3 ++- pytorch_toolbelt/losses/focal.py | 3 ++- pytorch_toolbelt/losses/jaccard.py | 3 ++- pytorch_toolbelt/modules/ocnet.py | 2 +- pytorch_toolbelt/modules/scse.py | 3 ++- setup.cfg | 2 +- 6 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pytorch_toolbelt/losses/dice.py b/pytorch_toolbelt/losses/dice.py index 9ff362bd0..2d06819e4 100644 --- a/pytorch_toolbelt/losses/dice.py +++ b/pytorch_toolbelt/losses/dice.py @@ -25,7 +25,8 @@ def __init__(self, mode: str, classes: List[int] = None, log_loss=False, from_lo """ :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} - :param classes: Optional list of classes that contribute in loss computation; By default, all channels are included. + :param classes: Optional list of classes that contribute in loss computation; + By default, all channels are included. :param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` :param from_logits: If True assumes input is raw logits :param smooth: diff --git a/pytorch_toolbelt/losses/focal.py b/pytorch_toolbelt/losses/focal.py index f8da7d11f..1ebc49eb1 100644 --- a/pytorch_toolbelt/losses/focal.py +++ b/pytorch_toolbelt/losses/focal.py @@ -15,7 +15,8 @@ def __init__( :param alpha: Prior probability of having positive value in target. :param gamma: Power factor for dampening weight (focal strenght). - :param ignore_index: If not None, targets may contain values to be ignored. Target values equal to ignore_index will be ignored from loss computation. + :param ignore_index: If not None, targets may contain values to be ignored. + Target values equal to ignore_index will be ignored from loss computation. :param reduced: :param threshold: """ diff --git a/pytorch_toolbelt/losses/jaccard.py b/pytorch_toolbelt/losses/jaccard.py index eeaed2986..cdce18aff 100644 --- a/pytorch_toolbelt/losses/jaccard.py +++ b/pytorch_toolbelt/losses/jaccard.py @@ -25,7 +25,8 @@ def __init__(self, mode: str, classes: List[int] = None, log_loss=False, from_lo """ :param mode: Metric mode {'binary', 'multiclass', 'multilabel'} - :param classes: Optional list of classes that contribute in loss computation; By default, all channels are included. + :param classes: Optional list of classes that contribute in loss computation; + By default, all channels are included. 
:param log_loss: If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard` :param from_logits: If True assumes input is raw logits :param smooth: diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index 8db4e2dd3..e219ea8a7 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -240,7 +240,7 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, self.out_channels = out_channels self.key_channels = key_channels self.value_channels = value_channels - if out_channels == None: + if out_channels is None: self.out_channels = in_channels self.f_key = nn.Sequential( nn.Conv2d( diff --git a/pytorch_toolbelt/modules/scse.py b/pytorch_toolbelt/modules/scse.py index 742e11fd5..f83f6130c 100644 --- a/pytorch_toolbelt/modules/scse.py +++ b/pytorch_toolbelt/modules/scse.py @@ -1,4 +1,5 @@ -"""Implementation of the CoordConv modules from "Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks" +"""Implementation of the CoordConv modules from +"Concurrent Spatial and Channel Squeeze & Excitation in Fully Convolutional Networks" Original paper: https://arxiv.org/abs/1803.02579 """ diff --git a/setup.cfg b/setup.cfg index 4617f37c1..12a522bd9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,4 +1,4 @@ [flake8] max-line-length = 119 exclude =.git,__pycache__,docs/source/conf.py,build,dist -ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,D413,W504,E127,E203,W503 +ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,D413,W504,E127,E203,W503,E501 From 374916371caa2a4d940ca53c64b1fea72983246e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 9 Jan 2020 15:12:14 +0200 Subject: [PATCH 76/79] Fixes case when data contains non-tensor types --- pytorch_toolbelt/utils/catalyst/visualization.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytorch_toolbelt/utils/catalyst/visualization.py b/pytorch_toolbelt/utils/catalyst/visualization.py index 9c668590d..e69db15c3 100644 --- a/pytorch_toolbelt/utils/catalyst/visualization.py +++ b/pytorch_toolbelt/utils/catalyst/visualization.py @@ -82,9 +82,7 @@ def to_cpu(self, data): return data.detach().cpu() if isinstance(data, (list, tuple)): return [self.to_cpu(value) for value in data] - if isinstance(data, str): - return data - raise ValueError("Unsupported type", type(data)) + return data def on_loader_start(self, state): self.best_score = None From f2d88aa68aa02fdc40906ac882da236ecacb116e Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Thu, 9 Jan 2020 15:12:33 +0200 Subject: [PATCH 77/79] Bugfix of edge case when all elements ignored --- pytorch_toolbelt/utils/catalyst/metrics.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_toolbelt/utils/catalyst/metrics.py b/pytorch_toolbelt/utils/catalyst/metrics.py index 811c1dc2f..fa0a30fee 100644 --- a/pytorch_toolbelt/utils/catalyst/metrics.py +++ b/pytorch_toolbelt/utils/catalyst/metrics.py @@ -118,8 +118,9 @@ def on_batch_end(self, state: RunnerState): outputs = outputs[mask] targets = targets[mask] - targets = targets.type_as(outputs) - self.confusion_matrix.add(predicted=outputs, target=targets) + if len(targets): + targets = targets.type_as(outputs) + self.confusion_matrix.add(predicted=outputs, target=targets) def on_loader_end(self, state): if self.class_names is None: From 225cc289b0e6aae535cf6d4e562ac90a949ff6e7 Mon Sep 17 00:00:00 2001 From: Ievgen 
Khvedchenia Date: Fri, 17 Jan 2020 17:07:10 +0200 Subject: [PATCH 78/79] Fix missing batchnorm call for FPNBottleneckBlockBN --- pytorch_toolbelt/modules/fpn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_toolbelt/modules/fpn.py b/pytorch_toolbelt/modules/fpn.py index 86eea6a24..3a440ec02 100644 --- a/pytorch_toolbelt/modules/fpn.py +++ b/pytorch_toolbelt/modules/fpn.py @@ -37,6 +37,7 @@ def __init__(self, input_channels, output_channels): def forward(self, x): x = self.conv(x) + x = self.bn(x) return x From 80e10634c20e8a06fb1309daa10d1948c97647d7 Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Fri, 17 Jan 2020 21:14:05 +0100 Subject: [PATCH 79/79] Remove unused variable --- pytorch_toolbelt/modules/ocnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/modules/ocnet.py b/pytorch_toolbelt/modules/ocnet.py index e219ea8a7..c55c6687f 100644 --- a/pytorch_toolbelt/modules/ocnet.py +++ b/pytorch_toolbelt/modules/ocnet.py @@ -264,7 +264,7 @@ def __init__(self, in_channels, key_channels, value_channels, out_channels=None, nn.init.constant(self.W.bias, 0) def forward(self, x): - batch_size, c, h, w = x.size(0), x.size(1), x.size(2), x.size(3) + batch_size, _, h, w = x.size(0), x.size(1), x.size(2), x.size(3) local_x = [] local_y = []
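
Editorial note (not part of the patch series): the patches above split the old `get_activation_module` helper into `get_activation_block`, which returns the activation *class*, and `instantiate_activation_block`, which constructs the module, and they add a `swish_naive` variant whose plain `x * sigmoid(x)` form avoids the custom autograd `Function` and is therefore friendlier to tracing/export, per the commit message. The sketch below shows how that refactored API is meant to be used; the function and module names are taken from the diffs, everything else (tensor shapes, the specific activations chosen) is illustrative only.

```python
# A minimal usage sketch, assuming pytorch_toolbelt at this point in the
# patch series is installed. Names come from the patches above; the rest
# is hypothetical example data.
import torch
from pytorch_toolbelt.modules.activations import (
    get_activation_block,
    instantiate_activation_block,
)

# get_activation_block returns the activation class, so it can be passed
# around and instantiated later (as MobileNetV2 now does with `activation_block()`).
block = get_activation_block("swish_naive")
act = block()

# instantiate_activation_block constructs the module directly and forwards
# keyword arguments to its constructor (e.g. inplace=True for ReLU-like blocks).
relu = instantiate_activation_block("relu", inplace=True)

x = torch.randn(4, 8)
# Both are elementwise activations, so shapes are preserved.
assert act(x).shape == relu(x).shape == x.shape
```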