From 79f6c6dbe3b78f1e17820c8c097825f4461d168d Mon Sep 17 00:00:00 2001 From: Joao Gomes Date: Fri, 18 Feb 2022 14:48:56 +0000 Subject: [PATCH] Rename ConvNormActivation to Conv2dNormActivation --- docs/source/ops.rst | 2 +- torchvision/models/convnext.py | 4 +-- torchvision/models/detection/ssdlite.py | 10 +++---- torchvision/models/efficientnet.py | 12 ++++---- torchvision/models/mobilenetv2.py | 14 ++++----- torchvision/models/mobilenetv3.py | 12 ++++---- torchvision/models/optical_flow/raft.py | 30 +++++++++---------- .../models/quantization/mobilenetv2.py | 4 +-- .../models/quantization/mobilenetv3.py | 4 +-- torchvision/models/regnet.py | 12 ++++---- torchvision/models/vision_transformer.py | 4 +-- torchvision/ops/__init__.py | 3 +- torchvision/ops/misc.py | 13 ++++++-- 13 files changed, 67 insertions(+), 57 deletions(-) diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 2a960474205..5b86e3443df 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -45,5 +45,5 @@ Operators FeaturePyramidNetwork StochasticDepth FrozenBatchNorm2d - ConvNormActivation + Conv2dNormActivation SqueezeExcitation diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 9067b6876fd..3a0dcdb31cd 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -6,7 +6,7 @@ from torch.nn import functional as F from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once @@ -127,7 +127,7 @@ def __init__( # Stem firstconv_output_channels = block_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=4, diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 47ecf59f1e2..1ee59e069ea 100644 --- 
a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -7,7 +7,7 @@ from torch import nn, Tensor from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from ...utils import _log_api_usage_once from .. import mobilenet from . import _utils as det_utils @@ -29,7 +29,7 @@ def _prediction_block( ) -> nn.Sequential: return nn.Sequential( # 3x3 depthwise with stride 1 and padding 1 - ConvNormActivation( + Conv2dNormActivation( in_channels, in_channels, kernel_size=kernel_size, @@ -47,11 +47,11 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., intermediate_channels = out_channels // 2 return nn.Sequential( # 1x1 projection to half output channels - ConvNormActivation( + Conv2dNormActivation( in_channels, intermediate_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), # 3x3 depthwise with stride 2 and padding 1 - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, intermediate_channels, kernel_size=3, @@ -61,7 +61,7 @@ def _extra_block(in_channels: int, out_channels: int, norm_layer: Callable[..., activation_layer=activation, ), # 1x1 projetion to output channels - ConvNormActivation( + Conv2dNormActivation( intermediate_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=activation ), ) diff --git a/torchvision/models/efficientnet.py b/torchvision/models/efficientnet.py index f7eba46cb39..f245d00cffe 100644 --- a/torchvision/models/efficientnet.py +++ b/torchvision/models/efficientnet.py @@ -8,7 +8,7 @@ from torchvision.ops import StochasticDepth from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -104,7 +104,7 @@ def 
__init__( expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) if expanded_channels != cnf.input_channels: layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, expanded_channels, kernel_size=1, @@ -115,7 +115,7 @@ def __init__( # depthwise layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, expanded_channels, kernel_size=cnf.kernel, @@ -132,7 +132,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -193,7 +193,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU ) ) @@ -224,7 +224,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 4 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index e24c5962d7e..930f68d13e9 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -6,7 +6,7 @@ from torch import nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -20,11 +20,11 @@ # necessary for backwards compatibility -class _DeprecatedConvBNAct(ConvNormActivation): +class _DeprecatedConvBNAct(Conv2dNormActivation): def __init__(self, *args, **kwargs): warnings.warn( "The ConvBNReLU/ConvBNActivation classes are deprecated since 0.12 and will be removed in 0.14. 
" - "Use torchvision.ops.misc.ConvNormActivation instead.", + "Use torchvision.ops.misc.Conv2dNormActivation instead.", FutureWarning, ) if kwargs.get("norm_layer", None) is None: @@ -56,12 +56,12 @@ def __init__( if expand_ratio != 1: # pw layers.append( - ConvNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6) ) layers.extend( [ # dw - ConvNormActivation( + Conv2dNormActivation( hidden_dim, hidden_dim, stride=stride, @@ -144,7 +144,7 @@ def __init__( input_channel = _make_divisible(input_channel * width_mult, round_nearest) self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) features: List[nn.Module] = [ - ConvNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) + Conv2dNormActivation(3, input_channel, stride=2, norm_layer=norm_layer, activation_layer=nn.ReLU6) ] # building inverted residual blocks for t, c, n, s in inverted_residual_setting: @@ -155,7 +155,7 @@ def __init__( input_channel = output_channel # building last several layers features.append( - ConvNormActivation( + Conv2dNormActivation( input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.ReLU6 ) ) diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 711888b7c8b..530467d6d53 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -6,7 +6,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation as SElayer +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation as SElayer from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -83,7 +83,7 @@ def __init__( # expand if cnf.expanded_channels != cnf.input_channels: layers.append( - 
ConvNormActivation( + Conv2dNormActivation( cnf.input_channels, cnf.expanded_channels, kernel_size=1, @@ -95,7 +95,7 @@ def __init__( # depthwise stride = 1 if cnf.dilation > 1 else cnf.stride layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, @@ -112,7 +112,7 @@ def __init__( # project layers.append( - ConvNormActivation( + Conv2dNormActivation( cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None ) ) @@ -172,7 +172,7 @@ def __init__( # building first layer firstconv_output_channels = inverted_residual_setting[0].input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( 3, firstconv_output_channels, kernel_size=3, @@ -190,7 +190,7 @@ def __init__( lastconv_input_channels = inverted_residual_setting[-1].out_channels lastconv_output_channels = 6 * lastconv_input_channels layers.append( - ConvNormActivation( + Conv2dNormActivation( lastconv_input_channels, lastconv_output_channels, kernel_size=1, diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py index 18aa25df625..03ade47247a 100644 --- a/torchvision/models/optical_flow/raft.py +++ b/torchvision/models/optical_flow/raft.py @@ -6,7 +6,7 @@ from torch import Tensor from torch.nn.modules.batchnorm import BatchNorm2d from torch.nn.modules.instancenorm import InstanceNorm2d -from torchvision.ops import ConvNormActivation +from torchvision.ops import Conv2dNormActivation from ..._internally_replaced_utils import load_state_dict_from_url from ...utils import _log_api_usage_once @@ -38,17 +38,17 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): # and frozen for the rest of the training process (i.e. set as eval()). The bias term is thus still useful # for the rest of the datasets. 
Technically, we could remove the bias for other norm layers like Instance norm # because these aren't frozen, but we don't bother (also, we woudn't be able to load the original weights). - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels, out_channels, norm_layer=norm_layer, kernel_size=3, bias=True ) if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -77,13 +77,13 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): super().__init__() # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu1 = ConvNormActivation( + self.convnormrelu1 = Conv2dNormActivation( in_channels, out_channels // 4, norm_layer=norm_layer, kernel_size=1, bias=True ) - self.convnormrelu2 = ConvNormActivation( + self.convnormrelu2 = Conv2dNormActivation( out_channels // 4, out_channels // 4, norm_layer=norm_layer, kernel_size=3, stride=stride, bias=True ) - self.convnormrelu3 = ConvNormActivation( + self.convnormrelu3 = Conv2dNormActivation( out_channels // 4, out_channels, norm_layer=norm_layer, kernel_size=1, bias=True ) self.relu = nn.ReLU(inplace=True) @@ -91,7 +91,7 @@ def __init__(self, in_channels, out_channels, *, norm_layer, stride=1): if stride == 1: self.downsample = nn.Identity() else: - self.downsample = ConvNormActivation( + self.downsample = Conv2dNormActivation( in_channels, out_channels, norm_layer=norm_layer, @@ -124,7 +124,7 @@ def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_l assert len(layers) == 5 # See note in ResidualBlock for the reason behind bias=True - self.convnormrelu = ConvNormActivation(3, layers[0], norm_layer=norm_layer, 
kernel_size=7, stride=2, bias=True) + self.convnormrelu = Conv2dNormActivation(3, layers[0], norm_layer=norm_layer, kernel_size=7, stride=2, bias=True) self.layer1 = self._make_2_blocks(block, layers[0], layers[1], norm_layer=norm_layer, first_stride=1) self.layer2 = self._make_2_blocks(block, layers[1], layers[2], norm_layer=norm_layer, first_stride=2) @@ -170,17 +170,17 @@ def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128 assert len(flow_layers) == 2 assert len(corr_layers) in (1, 2) - self.convcorr1 = ConvNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) + self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1) if len(corr_layers) == 2: - self.convcorr2 = ConvNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) + self.convcorr2 = Conv2dNormActivation(corr_layers[0], corr_layers[1], norm_layer=None, kernel_size=3) else: self.convcorr2 = nn.Identity() - self.convflow1 = ConvNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) - self.convflow2 = ConvNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) + self.convflow1 = Conv2dNormActivation(2, flow_layers[0], norm_layer=None, kernel_size=7) + self.convflow2 = Conv2dNormActivation(flow_layers[0], flow_layers[1], norm_layer=None, kernel_size=3) # out_channels - 2 because we cat the flow (2 channels) at the end - self.conv = ConvNormActivation( + self.conv = Conv2dNormActivation( corr_layers[-1] + flow_layers[-1], out_channels - 2, norm_layer=None, kernel_size=3 ) @@ -301,7 +301,7 @@ class MaskPredictor(nn.Module): def __init__(self, *, in_channels, hidden_size, multiplier=0.25): super().__init__() - self.convrelu = ConvNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) + self.convrelu = Conv2dNormActivation(in_channels, hidden_size, norm_layer=None, kernel_size=3) # 8 * 8 * 9 because the predicted flow is downsampled by 8, 
from the downsampling of the initial FeatureEncoder # and we interpolate with all 9 surrounding neighbors. See paper and appendix B. self.conv = nn.Conv2d(hidden_size, 8 * 8 * 9, 1, padding=0) diff --git a/torchvision/models/quantization/mobilenetv2.py b/torchvision/models/quantization/mobilenetv2.py index b1e38f2cbbb..8cd9f16d13e 100644 --- a/torchvision/models/quantization/mobilenetv2.py +++ b/torchvision/models/quantization/mobilenetv2.py @@ -6,7 +6,7 @@ from torchvision.models.mobilenetv2 import InvertedResidual, MobileNetV2, model_urls from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation +from ...ops.misc import Conv2dNormActivation from .utils import _fuse_modules, _replace_relu, quantize_model @@ -54,7 +54,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is Conv2dNormActivation: _fuse_modules(m, ["0", "1", "2"], is_qat, inplace=True) if type(m) is QuantizableInvertedResidual: m.fuse_model(is_qat) diff --git a/torchvision/models/quantization/mobilenetv3.py b/torchvision/models/quantization/mobilenetv3.py index 2f58cd96ace..4d7e2f7baad 100644 --- a/torchvision/models/quantization/mobilenetv3.py +++ b/torchvision/models/quantization/mobilenetv3.py @@ -5,7 +5,7 @@ from torch.ao.quantization import QuantStub, DeQuantStub from ..._internally_replaced_utils import load_state_dict_from_url -from ...ops.misc import ConvNormActivation, SqueezeExcitation +from ...ops.misc import Conv2dNormActivation, SqueezeExcitation from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, model_urls, _mobilenet_v3_conf from .utils import _fuse_modules, _replace_relu @@ -103,7 +103,7 @@ def forward(self, x: Tensor) -> Tensor: def fuse_model(self, is_qat: Optional[bool] = None) -> None: for m in self.modules(): - if type(m) is ConvNormActivation: + if type(m) is 
Conv2dNormActivation: modules_to_fuse = ["0", "1"] if len(m) == 3 and type(m[2]) is nn.ReLU: modules_to_fuse.append("2") diff --git a/torchvision/models/regnet.py b/torchvision/models/regnet.py index 3f393c8e82d..74abd20b237 100644 --- a/torchvision/models/regnet.py +++ b/torchvision/models/regnet.py @@ -12,7 +12,7 @@ from torch import nn, Tensor from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation, SqueezeExcitation +from ..ops.misc import Conv2dNormActivation, SqueezeExcitation from ..utils import _log_api_usage_once from ._utils import _make_divisible @@ -55,7 +55,7 @@ } -class SimpleStemIN(ConvNormActivation): +class SimpleStemIN(Conv2dNormActivation): """Simple stem for ImageNet: 3x3, BN, ReLU.""" def __init__( @@ -88,10 +88,10 @@ def __init__( w_b = int(round(width_out * bottleneck_multiplier)) g = w_b // group_width - layers["a"] = ConvNormActivation( + layers["a"] = Conv2dNormActivation( width_in, w_b, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=activation_layer ) - layers["b"] = ConvNormActivation( + layers["b"] = Conv2dNormActivation( w_b, w_b, kernel_size=3, stride=stride, groups=g, norm_layer=norm_layer, activation_layer=activation_layer ) @@ -105,7 +105,7 @@ def __init__( activation=activation_layer, ) - layers["c"] = ConvNormActivation( + layers["c"] = Conv2dNormActivation( w_b, width_out, kernel_size=1, stride=1, norm_layer=norm_layer, activation_layer=None ) super().__init__(layers) @@ -131,7 +131,7 @@ def __init__( self.proj = None should_proj = (width_in != width_out) or (stride != 1) if should_proj: - self.proj = ConvNormActivation( + self.proj = Conv2dNormActivation( width_in, width_out, kernel_size=1, stride=stride, norm_layer=norm_layer, activation_layer=None ) self.f = BottleneckTransform( diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py index b36658e34d8..29f756ccbe5 100644 --- 
a/torchvision/models/vision_transformer.py +++ b/torchvision/models/vision_transformer.py @@ -7,7 +7,7 @@ import torch.nn as nn from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation +from ..ops.misc import Conv2dNormActivation from ..utils import _log_api_usage_once __all__ = [ @@ -163,7 +163,7 @@ def __init__( for i, conv_stem_layer_config in enumerate(conv_stem_configs): seq_proj.add_module( f"conv_bn_relu_{i}", - ConvNormActivation( + Conv2dNormActivation( in_channels=prev_channels, out_channels=conv_stem_layer_config.out_channels, kernel_size=conv_stem_layer_config.kernel_size, diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index 8ba10080c1f..712127efb08 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -14,7 +14,7 @@ from .feature_pyramid_network import FeaturePyramidNetwork from .focal_loss import sigmoid_focal_loss from .giou_loss import generalized_box_iou_loss -from .misc import FrozenBatchNorm2d, ConvNormActivation, SqueezeExcitation +from .misc import FrozenBatchNorm2d, ConvNormActivation, Conv2dNormActivation, SqueezeExcitation from .poolers import MultiScaleRoIAlign from .ps_roi_align import ps_roi_align, PSRoIAlign from .ps_roi_pool import ps_roi_pool, PSRoIPool @@ -52,6 +52,7 @@ "StochasticDepth", "FrozenBatchNorm2d", "ConvNormActivation", + "Conv2dNormActivation", "SqueezeExcitation", "generalized_box_iou_loss", ] diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 268962e204c..3f120d65524 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -1,5 +1,5 @@ from typing import Callable, List, Optional - +import warnings import torch from torch import Tensor @@ -65,7 +65,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps})" -class ConvNormActivation(torch.nn.Sequential): +class Conv2dNormActivation(torch.nn.Sequential): """ Configurable block used for 
Convolution-Normalzation-Activation blocks. @@ -124,6 +124,15 @@ def __init__( self.out_channels = out_channels +class ConvNormActivation(Conv2dNormActivation): + def __init__(self, *args, **kwargs): + warnings.warn( + "The ConvNormActivation class is deprecated since 0.13 and will be removed in 0.15. " + "Use torchvision.ops.misc.Conv2dNormActivation instead.", + FutureWarning, + ) + super().__init__(*args, **kwargs) + class SqueezeExcitation(torch.nn.Module): """ This block implements the Squeeze-and-Excitation block from https://arxiv.org/abs/1709.01507 (see Fig. 1).