This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Add Squeeze and Excitation to ResNeXt models #426

Closed
2 changes: 2 additions & 0 deletions classy_vision/models/__init__.py
@@ -93,6 +93,7 @@ def build_model(config):
from .resnet import ResNet # isort:skip
from .resnext import ResNeXt # isort:skip
from .resnext3d import ResNeXt3D # isort:skip
from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer # isort:skip


__all__ = [
@@ -107,4 +108,5 @@ def build_model(config):
"ResNet",
"ResNeXt",
"ResNeXt3D",
"SqueezeAndExcitationLayer",
]
147 changes: 89 additions & 58 deletions classy_vision/models/resnext.py
@@ -8,6 +8,7 @@
Implementation of ResNeXt (https://arxiv.org/pdf/1611.05431.pdf)
"""

import copy
import math
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -16,6 +17,7 @@

from . import register_model
from .classy_model import ClassyModel
from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer


# global setting for in-place ReLU:
@@ -55,6 +57,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -79,6 +83,12 @@ def __init__(
nn.BatchNorm2d(out_planes),
)

self.se = (
SqueezeAndExcitationLayer(out_planes, reduction_ratio=se_reduction_ratio)
if use_se
else None
)

def forward(self, x):

# if required, perform downsampling along shortcut connection:
@@ -92,6 +102,10 @@ def forward(self, x):

if self.final_bn_relu:
out = self.bn(out)

if self.se is not None:
out = self.se(out)

# add residual connection, perform ReLU + batchnorm, and return result:
out += residual
if self.final_bn_relu:
@@ -101,7 +115,7 @@

class BasicLayer(GenericLayer):
"""
ResNeXt bottleneck layer with `in_planes` input planes and `out_planes`
ResNeXt layer with `in_planes` input planes and `out_planes`
output planes.
"""

@@ -113,6 +127,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -128,13 +144,15 @@
)

# call constructor of generic layer:
super(BasicLayer, self).__init__(
super().__init__(
convolutional_block,
in_planes,
out_planes,
stride=stride,
reduction=reduction,
final_bn_relu=final_bn_relu,
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)


@@ -152,6 +170,8 @@ def __init__(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# assertions on inputs:
@@ -185,6 +205,8 @@ def __init__(
stride=stride,
reduction=reduction,
final_bn_relu=final_bn_relu,
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)


@@ -236,14 +258,20 @@ def __init__(
basic_layer: bool = False,
final_bn_relu: bool = True,
bn_weight_decay: Optional[bool] = False,
use_se: bool = False,
se_reduction_ratio: int = 16,
):
"""
Implementation of `ResNeXt <https://arxiv.org/pdf/1611.05431.pdf>`_.
Set ``small_input`` to `True` for 32x32 sized image inputs.
Set ``final_bn_relu`` to `False` to exclude the final batchnorm and
ReLU layers. These settings are useful when training Siamese networks.
Args:
small_input: set to `True` for 32x32 sized image inputs.
final_bn_relu: set to `False` to exclude the final batchnorm and
ReLU layers. These settings are useful when training Siamese
networks.
use_se: set to `True` to enable squeeze and excitation.
se_reduction_ratio: The reduction ratio to apply in the excitation
stage. Only used if `use_se` is `True`.
"""
super().__init__()

@@ -263,6 +291,7 @@ def __init__(
and is_pos_int(base_width_and_cardinality[0])
and is_pos_int(base_width_and_cardinality[1])
)
assert isinstance(use_se, bool), "use_se has to be a boolean"

# Chooses whether to apply weight decay to batch norm
# parameters. This improves results in some situations,
@@ -295,6 +324,8 @@ def __init__(
mid_planes_and_cardinality=mid_planes_and_cardinality,
reduction=reduction,
final_bn_relu=final_bn_relu or (idx != (len(out_planes) - 1)),
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
)
blocks.append(nn.Sequential(*new_block))
self.blocks = nn.Sequential(*blocks)
@@ -337,6 +368,8 @@ def _make_resolution_block(
mid_planes_and_cardinality=None,
reduction=4,
final_bn_relu=True,
use_se=False,
se_reduction_ratio=16,
):

# add the desired number of residual blocks:
@@ -352,6 +385,8 @@
mid_planes_and_cardinality=mid_planes_and_cardinality,
reduction=reduction,
final_bn_relu=final_bn_relu or (idx != (num_blocks - 1)),
use_se=use_se,
se_reduction_ratio=se_reduction_ratio,
),
)
)
@@ -379,6 +414,8 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
"final_bn_relu": config.get("final_bn_relu", True),
"zero_init_bn_residuals": config.get("zero_init_bn_residuals", False),
"bn_weight_decay": config.get("bn_weight_decay", False),
"use_se": config.get("use_se", False),
"se_reduction_ratio": config.get("se_reduction_ratio", 16),
}
return cls(**config)

@@ -421,65 +458,68 @@ def model_depth(self):
return sum(self.num_blocks)


class _ResNeXt(ResNeXt):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
config = copy.deepcopy(config)
config.pop("name")
return cls(**config)


@register_model("resnet18")
class ResNet18(ResNeXt):
def __init__(self):
class ResNet18(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[2, 2, 2, 2], basic_layer=True, zero_init_bn_residuals=True
num_blocks=[2, 2, 2, 2],
basic_layer=True,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet34")
class ResNet34(ResNeXt):
def __init__(self):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=True, zero_init_bn_residuals=True
num_blocks=[3, 4, 6, 3],
basic_layer=True,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet50")
class ResNet50(ResNeXt):
def __init__(self):
class ResNet50(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet101")
class ResNet101(ResNeXt):
def __init__(self):
class ResNet101(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 23, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnet152")
class ResNet152(ResNeXt):
def __init__(self):
class ResNet152(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 8, 36, 3], basic_layer=False, zero_init_bn_residuals=True
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


# Note, the ResNeXt models all have weight decay enabled for the batch
# norm parameters. We have found empirically that this gives better
@@ -488,48 +528,39 @@ def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
# training on other datasets, we have observed losses in accuracy (for
# example, the dataset used in https://arxiv.org/abs/1805.00932).
@register_model("resnext50_32x4d")
class ResNeXt50(ResNeXt):
def __init__(self):
class ResNeXt50(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 6, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext101_32x4d")
class ResNeXt101(ResNeXt):
def __init__(self):
class ResNeXt101(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 4, 23, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()


@register_model("resnext152_32x4d")
class ResNeXt152(ResNeXt):
def __init__(self):
class ResNeXt152(_ResNeXt):
def __init__(self, **kwargs):
super().__init__(
num_blocks=[3, 8, 36, 3],
basic_layer=False,
zero_init_bn_residuals=True,
base_width_and_cardinality=(4, 32),
bn_weight_decay=True,
**kwargs,
)

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "ResNeXt":
return cls()
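
With the `**kwargs` plumbing above, `use_se` and `se_reduction_ratio` can be passed through a model config to any of the registered variants. The snippet below is a minimal sketch, not part of this diff; it assumes `build_model` (from `classy_vision/models/__init__.py`) dispatches on the registered `name`, which `_ResNeXt.from_config` then pops before forwarding the remaining keys as keyword arguments.

# Sketch only: enabling SE through the config path added in this diff.
from classy_vision.models import build_model

config = {
    "name": "resnet50",        # registered variant; remaining keys reach __init__ via **kwargs
    "use_se": True,            # insert a SqueezeAndExcitationLayer in each block
    "se_reduction_ratio": 16,  # channel reduction inside the excitation stage
}
model = build_model(config)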
29 changes: 29 additions & 0 deletions classy_vision/models/squeeze_and_excitation_layer.py
@@ -0,0 +1,29 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn


class SqueezeAndExcitationLayer(nn.Module):
"""Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""

def __init__(self, in_planes, reduction_ratio=16):
super().__init__()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
reduced_planes = in_planes // reduction_ratio
self.excitation = nn.Sequential(
nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True),
nn.ReLU(),
nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True),
nn.Sigmoid(),
)

def forward(self, x):
x_squeezed = self.avgpool(x)
x_excited = self.excitation(x_squeezed)
x_scaled = x * x_excited
return x_scaled
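
For reference, a quick standalone sanity check of the new layer (a sketch, not part of the diff): the squeeze step global-average-pools to 1x1, the excitation step is a two-layer 1x1-conv bottleneck gated by a sigmoid, and the input is rescaled channel-wise, so the output shape matches the input.

# Sketch only: shape check for SqueezeAndExcitationLayer.
import torch

from classy_vision.models.squeeze_and_excitation_layer import SqueezeAndExcitationLayer

se = SqueezeAndExcitationLayer(in_planes=64, reduction_ratio=16)
x = torch.randn(2, 64, 56, 56)  # (batch, channels, height, width)
out = se(x)
assert out.shape == x.shape     # channel-wise gating preserves the shape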