diff --git a/classy_vision/models/__init__.py b/classy_vision/models/__init__.py
index a0137e970d..6ac82d94e2 100644
--- a/classy_vision/models/__init__.py
+++ b/classy_vision/models/__init__.py
@@ -93,6 +93,7 @@ def build_model(config):
 from .resnet import ResNet  # isort:skip
 from .resnext import ResNeXt  # isort:skip
 from .resnext3d import ResNeXt3D  # isort:skip
+from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer  # isort:skip
 
 
 __all__ = [
@@ -107,4 +108,5 @@ def build_model(config):
     "ResNet",
     "ResNeXt",
     "ResNeXt3D",
+    "SqueezeAndExcitationLayer",
 ]
diff --git a/classy_vision/models/resnext.py b/classy_vision/models/resnext.py
index 6bd525e34b..74e9ef7a32 100644
--- a/classy_vision/models/resnext.py
+++ b/classy_vision/models/resnext.py
@@ -8,6 +8,7 @@
 Implementation of ResNeXt (https://arxiv.org/pdf/1611.05431.pdf)
 """
 
+import copy
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -16,6 +17,7 @@
 
 from . import register_model
 from .classy_model import ClassyModel
+from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer
 
 
 # global setting for in-place ReLU:
@@ -55,6 +57,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -79,6 +83,12 @@ def __init__(
                 nn.BatchNorm2d(out_planes),
             )
 
+        self.se = (
+            SqueezeAndExcitationLayer(out_planes, reduction_ratio=se_reduction_ratio)
+            if use_se
+            else None
+        )
+
     def forward(self, x):
 
         # if required, perform downsampling along shortcut connection:
@@ -92,6 +102,10 @@ def forward(self, x):
 
         if self.final_bn_relu:
             out = self.bn(out)
+
+        if self.se is not None:
+            out = self.se(out)
+
         # add residual connection, perform rely + batchnorm, and return result:
         out += residual
         if self.final_bn_relu:
@@ -101,7 +115,7 @@ def forward(self, x):
 
 class BasicLayer(GenericLayer):
     """
-    ResNeXt bottleneck layer with `in_planes` input planes and `out_planes`
+    ResNeXt layer with `in_planes` input planes and `out_planes`
     output planes.
     """
 
@@ -113,6 +127,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -128,13 +144,15 @@ def __init__(
         )
 
         # call constructor of generic layer:
-        super(BasicLayer, self).__init__(
+        super().__init__(
             convolutional_block,
             in_planes,
             out_planes,
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )
 
 
@@ -152,6 +170,8 @@ def __init__(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # assertions on inputs:
@@ -185,6 +205,8 @@ def __init__(
             stride=stride,
             reduction=reduction,
             final_bn_relu=final_bn_relu,
+            use_se=use_se,
+            se_reduction_ratio=se_reduction_ratio,
         )
 
 
@@ -236,14 +258,20 @@ def __init__(
         basic_layer: bool = False,
         final_bn_relu: bool = True,
         bn_weight_decay: Optional[bool] = False,
+        use_se: bool = False,
+        se_reduction_ratio: int = 16,
     ):
         """
             Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
 
-            Set ``small_input`` to `True` for 32x32 sized image inputs.
-
-            Set ``final_bn_relu`` to `False` to exclude the final batchnorm and
-            ReLU layers. These settings are useful when training Siamese networks.
+            Args:
+                small_input: set to `True` for 32x32 sized image inputs.
+                final_bn_relu: set to `False` to exclude the final batchnorm and
+                    ReLU layers. These settings are useful when training Siamese
+                    networks.
+                use_se: Enable squeeze and excitation
+                se_reduction_ratio: The reduction ratio to apply in the excitation
+                    stage. Only used if `use_se` is `True`.
         """
         super().__init__()
 
@@ -263,6 +291,7 @@ def __init__(
             and is_pos_int(base_width_and_cardinality[0])
             and is_pos_int(base_width_and_cardinality[1])
         )
+        assert isinstance(use_se, bool), "use_se has to be a boolean"
 
         # Chooses whether to apply weight decay to batch norm
         # parameters. This improves results in some situations,
@@ -295,6 +324,8 @@ def __init__(
                 mid_planes_and_cardinality=mid_planes_and_cardinality,
                 reduction=reduction,
                 final_bn_relu=final_bn_relu or (idx != (len(out_planes) - 1)),
+                use_se=use_se,
+                se_reduction_ratio=se_reduction_ratio,
             )
             blocks.append(nn.Sequential(*new_block))
         self.blocks = nn.Sequential(*blocks)
@@ -337,6 +368,8 @@ def _make_resolution_block(
         mid_planes_and_cardinality=None,
         reduction=4,
         final_bn_relu=True,
+        use_se=False,
+        se_reduction_ratio=16,
     ):
 
         # add the desired number of residual blocks:
@@ -352,6 +385,8 @@ def _make_resolution_block(
                         mid_planes_and_cardinality=mid_planes_and_cardinality,
                         reduction=reduction,
                         final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
+                        use_se=use_se,
+                        se_reduction_ratio=se_reduction_ratio,
                     ),
                 )
             )
@@ -379,6 +414,8 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
             "final_bn_relu": config.get("final_bn_relu", True),
             "zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
             "bn_weight_decay": config.get("bn_weight_decay", False),
+            "use_se": config.get("use_se", False),
+            "se_reduction_ratio": config.get("se_reduction_ratio", 16),
         }
         return cls(**config)
 
@@ -421,65 +458,68 @@ def model_depth(self):
         return sum(self.num_blocks)
 
 
+class _ResNeXt(ResNeXt):
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
+        config = copy.deepcopy(config)
+        config.pop("name")
+        return cls(**config)
+
+
 @register_model("resnet18")
-class ResNet18(ResNeXt):
-    def __init__(self):
+class ResNet18(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[2, 2, 2, 2],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet34")
-class ResNet34(ResNeXt):
-    def __init__(self):
+class ResNet34(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=True,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet50")
-class ResNet50(ResNeXt):
-    def __init__(self):
+class ResNet50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 6, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet101")
-class ResNet101(ResNeXt):
-    def __init__(self):
+class ResNet101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 4, 23, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnet152")
-class ResNet152(ResNeXt):
-    def __init__(self):
+class ResNet152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
-            num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
+            num_blocks=[3, 8, 36, 3],
+            basic_layer=False,
+            zero_init_bn_residuals=True,
+            **kwargs,
        )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 # Note, the ResNeXt models all have weight decay enabled for the batch
 # norm parameters. We have found empirically that this gives better
@@ -488,48 +528,39 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
 # training on other datasets, we have observed losses in accuracy (for
 # example, the dataset used in https://arxiv.org/abs/1805.00932).
 @register_model("resnext50_32x4d")
-class ResNeXt50(ResNeXt):
-    def __init__(self):
+class ResNeXt50(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 6, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnext101_32x4d")
-class ResNeXt101(ResNeXt):
-    def __init__(self):
+class ResNeXt101(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 4, 23, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
 
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
-
 
 @register_model("resnext152_32x4d")
-class ResNeXt152(ResNeXt):
-    def __init__(self):
+class ResNeXt152(_ResNeXt):
+    def __init__(self, **kwargs):
         super().__init__(
             num_blocks=[3, 8, 36, 3],
             basic_layer=False,
             zero_init_bn_residuals=True,
             base_width_and_cardinality=(4, 32),
             bn_weight_decay=True,
+            **kwargs,
         )
-
-    @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
-        return cls()
diff --git a/classy_vision/models/squeeze_and_excitation_layer.py b/classy_vision/models/squeeze_and_excitation_layer.py
new file mode 100644
index 0000000000..0bff96ace8
--- /dev/null
+++ b/classy_vision/models/squeeze_and_excitation_layer.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch.nn as nn
+
+
+class SqueezeAndExcitationLayer(nn.Module):
+    """Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""
+
+    def __init__(self, in_planes, reduction_ratio=16):
+        super().__init__()
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        reduced_planes = in_planes // reduction_ratio
+        self.excitation = nn.Sequential(
+            nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        x_squeezed = self.avgpool(x)
+        x_excited = self.excitation(x_squeezed)
+        x_scaled = x * x_excited
+        return x_scaled
diff --git a/test/models_resnext_test.py b/test/models_resnext_test.py
index c246e977d0..c2c5151b94 100644
--- a/test/models_resnext_test.py
+++ b/test/models_resnext_test.py
@@ -8,7 +8,7 @@
 from test.generic.utils import compare_model_state
 
 import torch
-from classy_vision.models import build_model
+from classy_vision.models import ResNeXt, build_model
 
 
 MODELS = {
@@ -51,6 +51,26 @@
             }
         ],
     },
+    "small_resnet_se": {
+        "name": "resnet",
+        "num_blocks": [1, 1, 1, 1],
+        "init_planes": 4,
+        "reduction": 4,
+        "small_input": True,
+        "zero_init_bn_residuals": True,
+        "basic_layer": True,
+        "final_bn_relu": True,
+        "use_se": True,
+        "heads": [
+            {
+                "name": "fully_connected",
+                "unique_id": "default_head",
+                "num_classes": 1000,
+                "fork_block": "block3-0",
+                "in_plane": 128,
+            }
+        ],
+    },
 }
 
 
@@ -78,8 +98,17 @@ def _test_model(self, model_config):
 
         compare_model_state(self, state, new_state, check_heads=True)
 
+    def test_build_preset_model(self):
+        configs = [{"name": "resnet18"}, {"name": "resnet18", "use_se": True}]
+        for config in configs:
+            model = build_model(config)
+            self.assertIsInstance(model, ResNeXt)
+
     def test_small_resnext(self):
         self._test_model(MODELS["small_resnext"])
 
     def test_small_resnet(self):
         self._test_model(MODELS["small_resnet"])
+
+    def test_small_resnet_se(self):
+        self._test_model(MODELS["small_resnet_se"])
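
Usage sketch (not part of the diff): with this patch applied, squeeze-and-excitation is enabled purely through the model config, as exercised by test_build_preset_model above. A minimal Python example follows; the reduction ratio of 8 in the second call is an arbitrary illustration, not a recommended setting.

    from classy_vision.models import build_model

    # Preset model with SE enabled: _ResNeXt.from_config pops "name" and forwards
    # the remaining keys, so "use_se" reaches ResNeXt.__init__ via **kwargs.
    se_resnet18 = build_model({"name": "resnet18", "use_se": True})

    # The excitation bottleneck can be tuned as well; "se_reduction_ratio" defaults
    # to 16 and only takes effect when "use_se" is True.
    se_resnet50 = build_model(
        {"name": "resnet50", "use_se": True, "se_reduction_ratio": 8}
    )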