Add Squeeze and Excitation to ResNeXt models (#426)
Summary:
Pull Request resolved: #426

Added a `SqueezeAndExcitation` layer to a new sub-package, `models.common` (open to other suggestions; I didn't want a `generic.py` or `util.py`, as those names are too vague and broad).

Plugged the layer into the `ResNeXt` models.
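
As a quick sketch of what this enables (parameter names are taken from the diff below; the other constructor arguments mirror the existing `resnext50_32x4d` preset):

```python
# Minimal sketch, assuming the ResNeXt class exported from classy_vision.models.
# use_se and se_reduction_ratio are the two arguments introduced by this commit;
# with use_se=True, each residual block applies a SqueezeAndExcitationLayer to its
# output before the residual connection is added.
from classy_vision.models import ResNeXt

model = ResNeXt(
    num_blocks=[3, 4, 6, 3],
    base_width_and_cardinality=(4, 32),
    use_se=True,             # new in this commit
    se_reduction_ratio=16,   # new in this commit; defaults to 16
)
```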

Differential Revision: D20283172

fbshipit-source-id: 21d5183a61d7aa13fca094afe95ecb0aa18f1632
mannatsingh authored and facebook-github-bot committed Mar 10, 2020
1 parent 636740b commit ea6a56b
Showing 4 changed files with 150 additions and 59 deletions.
2 changes: 2 additions & 0 deletions classy_vision/models/__init__.py
@@ -93,6 +93,7 @@ def build_model(config):
from .resnet import ResNet # isort:skip
from .resnext import ResNeXt # isort:skip
from .resnext3d import ResNeXt3D # isort:skip
from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer # isort:skip


__all__ = [
@@ -107,4 +108,5 @@ def build_model(config):
"ResNet",
"ResNeXt",
"ResNeXt3D",
"SqueezeAndExcitationLayer",
]
147 changes: 89 additions & 58 deletions classy_vision/models/resnext.py
@@ -8,6 +8,7 @@
Implementation of ResNeXt (https://arxiv.org/pdf/1611.05431.pdf)
"""

import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -16,6 +17,7 @@

from . import register_model
from .classy_model import ClassyModel
from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer


# global setting for in-place ReLU:
@@ -55,6 +57,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -79,6 +83,12 @@ def __init__(
nn.BatchNorm2d(out_planes),
)

self.se = (
SqueezeAndExcitationLayer(out_planes, reduction_ratio=se_reduction_ratio)
if use_se
else None
)

def forward(self, x):

# if required, perform downsampling along shortcut connection:
@@ -92,6 +102,10 @@ def forward(self, x):

if self.final_bn_relu:
out = self.bn(out)

if self.se is not None:
out = self.se(out)

# add residual connection, perform relu + batchnorm, and return result:
out += residual
if self.final_bn_relu:
@@ -101,7 +115,7 @@

class BasicLayer(GenericLayer):
"""
ResNeXt bottleneck layer with `in_planes` input planes and `out_planes`
ResNeXt layer with `in_planes` input planes and `out_planes`
output planes.
"""

@@ -113,6 +127,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -128,13 +144,15 @@
)

# call constructor of generic layer:
super(BasicLayer, self).__init__(
super().__init__(
convolutional_block,
in_planes,
out_planes,
stride=stride,
reduction=reduction,
final_bn_relu=final_bn_relu,
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)


@@ -152,6 +170,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -185,6 +205,8 @@ def __init__(
stride=stride,
reduction=reduction,
final_bn_relu=final_bn_relu,
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)


@@ -236,14 +258,20 @@ def __init__(
basic_layer: bool = False,
final_bn_relu: bool = True,
bn_weight_decay: Optional[bool] = False,
use_se: bool = False,
se_reduction_ratio: int = 16,
):
"""
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
Set ``small_input`` to `True` for 32x32 sized image inputs.
Set ``final_bn_relu`` to `False` to exclude the final batchnorm and
ReLU layers. These settings are useful when training Siamese networks.
Args:
small_input: set to `True` for 32x32 sized image inputs.
final_bn_relu: set to `False` to exclude the final batchnorm and
ReLU layers. These settings are useful when training Siamese
networks.
use_se: Enable squeeze and excitation
se_reduction_ratio: The reduction ratio to apply in the excitation
stage. Only used if `use_se` is `True`.
"""
super().__init__()

@@ -263,6 +291,7 @@ def __init__(
and is_pos_int(base_width_and_cardinality[0])
and is_pos_int(base_width_and_cardinality[1])
)
assert isinstance(use_se, bool), "use_se has to be a boolean"

# Chooses whether to apply weight decay to batch norm
# parameters. This improves results in some situations,
@@ -295,6 +324,8 @@ def __init__(
mid_planes_and_cardinality=mid_planes_and_cardinality,
reduction=reduction,
final_bn_relu=final_bn_relu or (idx != (len(out_planes) - 1)),
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)
blocks.append(nn.Sequential(*new_block))
self.blocks = nn.Sequential(*blocks)
@@ -337,6 +368,8 @@ def _make_resolution_block(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# add the desired number of residual blocks:
@@ -352,6 +385,8 @@
mid_planes_and_cardinality=mid_planes_and_cardinality,
reduction=reduction,
final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
),
)
)
@@ -379,6 +414,8 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
"final_bn_relu": config.get("final_bn_relu", True),
"zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
"bn_weight_decay": config.get("bn_weight_decay", False),
"use_se": config.get("use_se", False),
"se_reduction_ratio": config.get("se_reduction_ratio", 16),
}
return cls(**config)

@@ -421,65 +458,68 @@ def model_depth(self):
return sum(self.num_blocks)


class _ResNeXt(ResNeXt):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
config = copy.deepcopy(config)
config.pop("name")
return cls(**config)


@register_model("resnet18")
class ResNet18(ResNeXt):
def __init__(self):
class ResNet18(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
num_blocks=[2, 2, 2, 2],
basic_layer=True,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet34")
class ResNet34(ResNeXt):
def __init__(self):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
num_blocks=[3, 4, 6, 3],
basic_layer=True,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet50")
class ResNet50(ResNeXt):
def __init__(self):
class ResNet50(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet101")
class ResNet101(ResNeXt):
def __init__(self):
class ResNet101(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet152")
class ResNet152(ResNeXt):
def __init__(self):
class ResNet152(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


# Note, the ResNeXt models all have weight decay enabled for the batch
# norm parameters. We have found empirically that this gives better
@@ -488,48 +528,39 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
# training on other datasets, we have observed losses in accuracy (for
# example, the dataset used in https://arxiv.org/abs/1805.00932).
@register_model("resnext50_32x4d")
class ResNeXt50(ResNeXt):
def __init__(self):
class ResNeXt50(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext101_32x4d")
class ResNeXt101(ResNeXt):
def __init__(self):
class ResNeXt101(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext152_32x4d")
class ResNeXt152(ResNeXt):
def __init__(self):
class ResNeXt152(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()
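
With the `_ResNeXt.from_config` passthrough added above, the registered presets can also pick up the new options from a config dict. A small sketch, using the `build_model` helper defined in `classy_vision/models/__init__.py`:

```python
# Sketch: _ResNeXt.from_config pops "name" and forwards every remaining key to
# __init__, so the new SE options can be supplied alongside the preset name.
from classy_vision.models import build_model

se_resnext50 = build_model({
    "name": "resnext50_32x4d",
    "use_se": True,
    "se_reduction_ratio": 16,
})
```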
29 changes: 29 additions & 0 deletions classy_vision/models/squeeze_and_excitation_layer.py
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn


class SqueezeAndExcitationLayer(nn.Module):
"""Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""

def __init__(self, in_planes, reduction_ratio=16):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
reduced_planes = in_planes // reduction_ratio
self.excitation = nn.Sequential(
nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True),
nn.ReLU(),
nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True),
nn.Sigmoid(),
)

def forward(self, x):
x_squeezed = self.avgpool(x)
x_excited = self.excitation(x_squeezed)
x_scaled = x * x_excited
return x_scaled
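
For reference, a standalone sketch of the new layer itself: it is shape-preserving, squeezing the spatial dimensions to compute per-channel gates and then rescaling the input.

```python
# Sketch: SqueezeAndExcitationLayer keeps the input shape and only reweights channels.
import torch

from classy_vision.models import SqueezeAndExcitationLayer

se = SqueezeAndExcitationLayer(in_planes=64, reduction_ratio=16)
x = torch.randn(2, 64, 56, 56)   # (batch, channels, height, width)
out = se(x)
assert out.shape == x.shape      # channel-wise rescaling; spatial dims untouched
```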