diff --git a/README.md b/README.md index 3d19058b..f812199e 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,7 @@ pip install flowvision==0.1.0
  • InceptionV3
  • ResNet
  • ResNeXt
  • +
  • SENet
  • DenseNet
  • ShuffleNetV2
  • MobileNetV2
  • diff --git a/docs/source/flowvision.models.rst b/docs/source/flowvision.models.rst index 72faed70..65211a9f 100644 --- a/docs/source/flowvision.models.rst +++ b/docs/source/flowvision.models.rst @@ -16,6 +16,7 @@ architectures for image classification: - `InceptionV3`_ - `ResNet`_ - `ResNeXt`_ +- `SENet`_ - `DenseNet`_ - `ShuffleNetV2`_ - `MobileNetV2`_ @@ -50,6 +51,7 @@ architectures for image classification: .. _MobileNetV2: https://arxiv.org/abs/1801.04381 .. _MobileNetV3: https://arxiv.org/abs/1905.02244 .. _ResNeXt: https://arxiv.org/abs/1611.05431 +.. _SENet: https://arxiv.org/abs/1709.01507 .. _Res2Net: https://arxiv.org/abs/1904.01169 .. _ReXNet: https://arxiv.org/abs/2007.00992 .. _MNASNet: https://arxiv.org/abs/1807.11626 @@ -220,6 +222,18 @@ ReXNet rexnet_lite_2_0, +SENet +-------- +.. automodule:: flowvision.models + :members: + senet154, + se_resnet50, + se_resnet101, + se_resnet152, + se_resnext50_32x4d, + se_resnext101_32x4d, + + ViT ------ .. automodule:: flowvision.models diff --git a/flowvision/models/__init__.py b/flowvision/models/__init__.py index f1331872..969eb492 100644 --- a/flowvision/models/__init__.py +++ b/flowvision/models/__init__.py @@ -24,6 +24,7 @@ from .efficientnet import * from .vision_transformer import * from .convnext import * +from .senet import * from . import neural_style_transfer from . 
import detection diff --git a/flowvision/models/senet.py b/flowvision/models/senet.py new file mode 100644 index 00000000..ebf42174 --- /dev/null +++ b/flowvision/models/senet.py @@ -0,0 +1,467 @@ +""" +Modified from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py +""" + +from collections import OrderedDict +import math + +import oneflow as flow +import oneflow.nn as nn +import oneflow.nn.functional as F + +from flowvision.layers import trunc_normal_, DropPath, SEModule +from .registry import ModelCreator +from .utils import load_state_dict_from_url + + +model_urls = { + "senet154": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/senet154.zip", + "se_resnet50": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/se_resnet50.zip", + "se_resnet101": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/se_resnet101.zip", + "se_resnet152": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/se_resnet152.zip", + "se_resnext50_32x4d": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/se_resnext50_32x4d.zip", + "se_resnext101_32x4d": "https://oneflow-public.oss-cn-beijing.aliyuncs.com/model_zoo/flowvision/classification/SENet/se_resnext101_32x4d.zip", +} + + +class Bottleneck(nn.Module): + """ + Base class for bottlenecks that implements `forward()` method. + """ + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out = self.se_module(out) + residual + out = self.relu(out) + + return out + + +class SEBottleneck(Bottleneck): + """ + Bottleneck for SENet154. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes * 2) + self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3, + stride=stride, padding=1, groups=groups, + bias=False) + self.bn2 = nn.BatchNorm2d(planes * 4) + self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNetBottleneck(Bottleneck): + """ + ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe + implementation and uses `stride=stride` in `conv1` and not in `conv2` + (the latter is used in the torchvision implementation of ResNet). + """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None): + super(SEResNetBottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False, + stride=stride) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, + groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SEResNeXtBottleneck(Bottleneck): + """ + ResNeXt bottleneck type C with a Squeeze-and-Excitation module. 
+ """ + expansion = 4 + + def __init__(self, inplanes, planes, groups, reduction, stride=1, + downsample=None, base_width=4): + super(SEResNeXtBottleneck, self).__init__() + width = math.floor(planes * (base_width / 64)) * groups + self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False, + stride=1) + self.bn1 = nn.BatchNorm2d(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, groups=groups, bias=False) + self.bn2 = nn.BatchNorm2d(width) + self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.se_module = SEModule(planes * 4, reduction=reduction) + self.downsample = downsample + self.stride = stride + + +class SENet(nn.Module): + + def __init__(self, block, layers, groups, reduction, dropout_p=0.2, + inplanes=128, input_3x3=True, downsample_kernel_size=3, + downsample_padding=1, num_classes=1000): + """ + Parameters + ---------- + block (nn.Module): Bottleneck class. + - For SENet154: SEBottleneck + - For SE-ResNet models: SEResNetBottleneck + - For SE-ResNeXt models: SEResNeXtBottleneck + layers (list of ints): Number of residual blocks for 4 layers of the + network (layer1...layer4). + groups (int): Number of groups for the 3x3 convolution in each + bottleneck block. + - For SENet154: 64 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 32 + reduction (int): Reduction ratio for Squeeze-and-Excitation modules. + - For all models: 16 + dropout_p (float or None): Drop probability for the Dropout layer. + If `None` the Dropout layer is not used. + - For SENet154: 0.2 + - For SE-ResNet models: None + - For SE-ResNeXt models: None + inplanes (int): Number of input channels for layer1. + - For SENet154: 128 + - For SE-ResNet models: 64 + - For SE-ResNeXt models: 64 + input_3x3 (bool): If `True`, use three 3x3 convolutions instead of + a single 7x7 convolution in layer0. 
+ - For SENet154: True + - For SE-ResNet models: False + - For SE-ResNeXt models: False + downsample_kernel_size (int): Kernel size for downsampling convolutions + in layer2, layer3 and layer4. + - For SENet154: 3 + - For SE-ResNet models: 1 + - For SE-ResNeXt models: 1 + downsample_padding (int): Padding for downsampling convolutions in + layer2, layer3 and layer4. + - For SENet154: 1 + - For SE-ResNet models: 0 + - For SE-ResNeXt models: 0 + num_classes (int): Number of outputs in `last_linear` layer. + - For all models: 1000 + """ + super(SENet, self).__init__() + self.inplanes = inplanes + if input_3x3: + layer0_modules = [ + ('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1, + bias=False)), + ('bn1', nn.BatchNorm2d(64)), + ('relu1', nn.ReLU(inplace=True)), + ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, + bias=False)), + ('bn2', nn.BatchNorm2d(64)), + ('relu2', nn.ReLU(inplace=True)), + ('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1, + bias=False)), + ('bn3', nn.BatchNorm2d(inplanes)), + ('relu3', nn.ReLU(inplace=True)), + ] + else: + layer0_modules = [ + ('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2, + padding=3, bias=False)), + ('bn1', nn.BatchNorm2d(inplanes)), + ('relu1', nn.ReLU(inplace=True)), + ] + # To preserve compatibility with Caffe weights `ceil_mode=True` + # is used instead of `padding=1`. 
+ layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2, + ceil_mode=True))) + self.layer0 = nn.Sequential(OrderedDict(layer0_modules)) + self.layer1 = self._make_layer( + block, + planes=64, + blocks=layers[0], + groups=groups, + reduction=reduction, + downsample_kernel_size=1, + downsample_padding=0 + ) + self.layer2 = self._make_layer( + block, + planes=128, + blocks=layers[1], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer3 = self._make_layer( + block, + planes=256, + blocks=layers[2], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.layer4 = self._make_layer( + block, + planes=512, + blocks=layers[3], + stride=2, + groups=groups, + reduction=reduction, + downsample_kernel_size=downsample_kernel_size, + downsample_padding=downsample_padding + ) + self.avg_pool = nn.AvgPool2d(7, stride=1) + self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None + self.last_linear = nn.Linear(512 * block.expansion, num_classes) + + def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, + downsample_kernel_size=1, downsample_padding=0): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=downsample_kernel_size, stride=stride, + padding=downsample_padding, bias=False), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, groups, reduction, stride, + downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, groups, reduction)) + + return nn.Sequential(*layers) + + def features(self, x): + x = self.layer0(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = 
self.layer4(x) + return x + + def logits(self, x): + x = self.avg_pool(x) + if self.dropout is not None: + x = self.dropout(x) + x = x.view(x.size(0), -1) + x = self.last_linear(x) + return x + + def forward(self, x): + x = self.features(x) + x = self.logits(x) + return x + + +def _create_se_resnet(arch, pretrained=False, progress=True, **model_kwargs): + model = SENet(**model_kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) + return model + + +@ModelCreator.register_model +def senet154(pretrained=False, progress=True, **kwargs): + """ + Constructs the SENet-154 model trained on ImageNet2012. + + .. note:: + SENet-154 model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` + + For example: + + .. code-block:: python + + >>> import flowvision + >>> senet154 = flowvision.models.senet154(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEBottleneck, layers=[3, 8, 36, 3], groups=64, + reduction=16, dropout_p=0.2, num_classes=1000, **kwargs) + return _create_se_resnet( + "senet154", pretrained=pretrained, progress=progress, **model_kwargs + ) + + +@ModelCreator.register_model +def se_resnet50(pretrained=False, progress=True, **kwargs): + """ + Constructs the SE-ResNet50 model trained on ImageNet2012. + + .. note:: + SE-ResNet50 model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` + + For example: + + ..
code-block:: python + + >>> import flowvision + >>> se_resnet50 = flowvision.models.se_resnet50(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEResNetBottleneck, layers=[3, 4, 6, 3], groups=1, + reduction=16, dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, num_classes=1000, **kwargs) + return _create_se_resnet( + "se_resnet50", pretrained=pretrained, progress=progress, **model_kwargs + ) + + +@ModelCreator.register_model +def se_resnet101(pretrained=False, progress=True, **kwargs): + """ + Constructs the SE-ResNet101 model trained on ImageNet2012. + + .. note:: + SE-ResNet101 model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` + + For example: + + .. code-block:: python + + >>> import flowvision + >>> se_resnet101 = flowvision.models.se_resnet101(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEResNetBottleneck, layers=[3, 4, 23, 3], groups=1, + reduction=16, dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, num_classes=1000, **kwargs) + return _create_se_resnet( + "se_resnet101", pretrained=pretrained, progress=progress, **model_kwargs + ) + + +@ModelCreator.register_model +def se_resnet152(pretrained=False, progress=True, **kwargs): + """ + Constructs the SE-ResNet152 model trained on ImageNet2012. + + .. note:: + SE-ResNet152 model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` + + For example: + + ..
code-block:: python + + >>> import flowvision + >>> se_resnet152 = flowvision.models.se_resnet152(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEResNetBottleneck, layers=[3, 8, 36, 3], groups=1, + reduction=16, dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, num_classes=1000, **kwargs) + return _create_se_resnet( + "se_resnet152", pretrained=pretrained, progress=progress, **model_kwargs + ) + + +@ModelCreator.register_model +def se_resnext50_32x4d(pretrained=False, progress=True, **kwargs): + """ + Constructs the SE-ResNeXt50-32x4d model trained on ImageNet2012. + + .. note:: + SE-ResNeXt50-32x4d model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr. Default: ``True`` + + For example: + + .. code-block:: python + + >>> import flowvision + >>> se_resnext50_32x4d = flowvision.models.se_resnext50_32x4d(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEResNeXtBottleneck, layers=[3, 4, 6, 3], groups=32, + reduction=16, dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, num_classes=1000, **kwargs) + return _create_se_resnet( + "se_resnext50_32x4d", pretrained=pretrained, progress=progress, **model_kwargs + ) + + +@ModelCreator.register_model +def se_resnext101_32x4d(pretrained=False, progress=True, **kwargs): + """ + Constructs the SE-ResNeXt101-32x4d model trained on ImageNet2012. + + .. note:: + SE-ResNeXt101-32x4d model from `Squeeze-and-Excitation Networks <https://arxiv.org/abs/1709.01507>`_. + The required input size of the model is 224x224. + + Args: + pretrained (bool): Whether to download the pre-trained model on ImageNet. Default: ``False`` + progress (bool): If True, displays a progress bar of the download to stderr.
Default: ``True`` + + For example: + + .. code-block:: python + + >>> import flowvision + >>> se_resnext101_32x4d = flowvision.models.se_resnext101_32x4d(pretrained=False, progress=True) + + """ + model_kwargs = dict(block=SEResNeXtBottleneck, layers=[3, 4, 23, 3], groups=32, + reduction=16, dropout_p=None, inplanes=64, input_3x3=False, + downsample_kernel_size=1, downsample_padding=0, num_classes=1000, **kwargs) + return _create_se_resnet( + "se_resnext101_32x4d", pretrained=pretrained, progress=progress, **model_kwargs + ) diff --git a/results/results_imagenet.md b/results/results_imagenet.md index 234f7931..9a329356 100644 --- a/results/results_imagenet.md +++ b/results/results_imagenet.md @@ -33,6 +33,12 @@ | resnext101_32x8d | 79.308 | 20.692 | 94.530 | 5.470 | 25.0M | 224 | 0.875 | bilinear | | wide_resnet50_2 | 78.480 | 21.520 | 94.084 | 5.916 | 68.9M | 224 | 0.875 | bilinear | | wide_resnet101_2 | 78.842 | 21.158 | 94.280 | 5.780 | 126.9M | 224 | 0.875 | bilinear | +| senet154 | 81.324 | 18.676 | 95.508 | 4.492 | 115.1M | 224 | 0.875 | bilinear | +| se_resnet50 | 77.642 | 22.358 | 93.748 | 6.252 | 28.1M | 224 | 0.875 | bilinear | +| se_resnet101 | 78.390 | 21.610 | 94.254 | 5.746 | 49.3M | 224 | 0.875 | bilinear | +| se_resnet152 | 78.666 | 21.334 | 94.380 | 5.620 | 66.8M | 224 | 0.875 | bilinear | +| se_resnext50_32x4d | 79.080 | 20.920 | 94.432 | 5.568 | 27.6M | 224 | 0.875 | bilinear | +| se_resnext101_32x4d | 80.236 | 19.764 | 95.034 | 4.966 | 49.0M | 224 | 0.875 | bilinear | | res2net50_14w_8s | 78.152 | 21.848 | 93.842 | 6.158 | 25.1M | 224 | 0.875 | bilinear | | res2net50_26w_4s | 77.946 | 22.054 | 93.852 | 6.148 | 25.7M | 224 | 0.875 | bilinear | | res2net50_26w_6s | 78.574 | 21.426 | 94.126 | 5.874 | 37.1M | 224 | 0.875 | bilinear |