diff --git a/classy_vision/models/r2plus1_util.py b/classy_vision/models/r2plus1_util.py new file mode 100644 index 0000000000..561c5973cc --- /dev/null +++ b/classy_vision/models/r2plus1_util.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch.nn as nn + + +def r2plus1_unit( + dim_in, + dim_out, + temporal_stride, + spatial_stride, + groups, + inplace_relu, + bn_eps, + bn_mmt, + dim_mid=None, +): + """ + Implementation of `R(2+1)D unit `_. + Decompose one 3D conv into one 2D spatial conv and one 1D temporal conv. + Choose the middle dimensionality so that the total No. of parameters + in 2D spatial conv and 1D temporal conv is unchanged. + + Args: + dim_in (int): the channel dimensions of the input. + dim_out (int): the channel dimension of the output. + temporal_stride (int): the temporal stride of the bottleneck. + spatial_stride (int): the spatial_stride of the bottleneck. + groups (int): number of groups for the convolution. + inplace_relu (bool): calculate the relu on the original input + without allocating new memory. + bn_eps (float): epsilon for batch norm. + bn_mmt (float): momentum for batch norm. Noted that BN momentum in + PyTorch = 1 - BN momentum in Caffe2. + dim_mid (Optional[int]): If not None, use the provided channel dimension + for the output of the 2D spatial conv. If None, compute the output + channel dimension of the 2D spatial conv so that the total No. of + model parameters remains unchanged. + """ + if dim_mid is None: + dim_mid = int(dim_out * dim_in * 3 * 3 * 3 / (dim_in * 3 * 3 + dim_out * 3)) + logging.info( + "dim_in: %d, dim_out: %d. Set dim_mid to %d" % (dim_in, dim_out, dim_mid) + ) + # 1x3x3 group conv, BN, ReLU + conv_middle = nn.Conv3d( + dim_in, + dim_mid, + [1, 3, 3], # kernel + stride=[1, spatial_stride, spatial_stride], + padding=[0, 1, 1], + groups=groups, + bias=False, + ) + conv_middle_bn = nn.BatchNorm3d(dim_mid, eps=bn_eps, momentum=bn_mmt) + conv_middle_relu = nn.ReLU(inplace=inplace_relu) + # 3x1x1 group conv + conv = nn.Conv3d( + dim_mid, + dim_out, + [3, 1, 1], # kernel + stride=[temporal_stride, 1, 1], + padding=[1, 0, 0], + groups=groups, + bias=False, + ) + return nn.Sequential(conv_middle, conv_middle_bn, conv_middle_relu, conv) diff --git a/classy_vision/models/resnext3d.py b/classy_vision/models/resnext3d.py index 4292754a1f..f77301f5c7 100644 --- a/classy_vision/models/resnext3d.py +++ b/classy_vision/models/resnext3d.py @@ -35,6 +35,7 @@ def __init__( stem_temporal_kernel, stem_spatial_kernel, stem_maxpool, + use_r2plus1, ): """ ResNeXt3DBase implements everything in ResNeXt3D model except the @@ -55,6 +56,7 @@ def __init__( input_planes, stem_planes, stem_maxpool, + use_r2plus1, ) @classmethod @@ -118,6 +120,8 @@ def _parse_config(cls, config): "width_per_group": config.get("width_per_group", 64), } ) + # Default setting for both model stem and model stages + ret_config.update({"use_r2plus1": config.get("use_r2plus1", False)}) # Default setting for model parameter initialization ret_config.update( { @@ -301,6 +305,7 @@ class ResNeXt3D(ResNeXt3DBase): The model consists of one stem, a number of stages, and one or multiple heads that are attached to different blocks in the stage. """ + def __init__( self, input_key, @@ -323,6 +328,7 @@ def __init__( num_groups, width_per_group, zero_init_residual_transform, + use_r2plus1, ): """ Args: @@ -362,6 +368,8 @@ def __init__( operation, which could be either BatchNorm3D in post-activated transformation or Conv3D in pre-activated transformation, in the residual transformation is initialized to zero + use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D + spatial conv and one 1D temporal conv """ super(ResNeXt3D, self).__init__( input_key, @@ -374,6 +382,7 @@ def __init__( stem_temporal_kernel, stem_spatial_kernel, stem_maxpool, + use_r2plus1, ) num_stages = len(num_blocks) @@ -401,6 +410,7 @@ def __init__( block_callback=self.build_attachable_block, disable_pre_activation=(s == 0), final_stage=(s == (num_stages - 1)), + use_r2plus1=use_r2plus1, ) stages.append(stage) diff --git a/classy_vision/models/resnext3d_block.py b/classy_vision/models/resnext3d_block.py index 2147507e76..409409ee85 100644 --- a/classy_vision/models/resnext3d_block.py +++ b/classy_vision/models/resnext3d_block.py @@ -6,6 +6,8 @@ import torch.nn as nn +from .r2plus1_util import r2plus1_unit + class BasicTransformation(nn.Module): """ @@ -22,6 +24,7 @@ def __init__( inplace_relu=True, bn_eps=1e-5, bn_mmt=0.1, + use_r2plus1=False, **kwargs ): """ @@ -36,31 +39,65 @@ def __init__( bn_eps (float): epsilon for batch norm. bn_mmt (float): momentum for batch norm. Noted that BN momentum in PyTorch = 1 - BN momentum in Caffe2. + use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D + spatial conv and one 1D temporal conv """ super(BasicTransformation, self).__init__() - # 3x3x3 group conv, BN, ReLU. - branch2a = nn.Conv3d( - dim_in, - dim_out, - [3, 3, 3], # kernel - stride=[temporal_stride, spatial_stride, spatial_stride], - padding=[1, 1, 1], - groups=groups, - bias=False, - ) + if not use_r2plus1: + # 3x3x3 group conv, BN, ReLU. + branch2a = nn.Conv3d( + dim_in, + dim_out, + [3, 3, 3], # kernel + stride=[temporal_stride, spatial_stride, spatial_stride], + padding=[1, 1, 1], + groups=groups, + bias=False, + ) + else: + # Implementation of R(2+1)D operation . + # decompose the original 3D conv into one 2D spatial conv and one + # 1D temporal conv + branch2a = r2plus1_unit( + dim_in, + dim_out, + temporal_stride, + spatial_stride, + groups, + inplace_relu, + bn_eps, + bn_mmt, + ) branch2a_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt) branch2a_relu = nn.ReLU(inplace=inplace_relu) - # 3x3x3 group conv, BN, ReLU. - branch2b = nn.Conv3d( - dim_out, - dim_out, - [3, 3, 3], # kernel - stride=[1, 1, 1], - padding=[1, 1, 1], - groups=groups, - bias=False, - ) + + if not use_r2plus1: + # 3x3x3 group conv, BN, ReLU. + branch2b = nn.Conv3d( + dim_out, + dim_out, + [3, 3, 3], # kernel + stride=[1, 1, 1], + padding=[1, 1, 1], + groups=groups, + bias=False, + ) + else: + # Implementation of R(2+1)D operation . + # decompose the original 3D conv into one 1x3x3 group conv and one + # 3x1x1 group conv + branch2b = r2plus1_unit( + dim_out, + dim_out, + 1, # temporal_stride + 1, # spatial_stride + groups, + inplace_relu, + bn_eps, + bn_mmt, + ) + branch2b_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt) branch2b_bn.final_transform_op = True @@ -398,6 +435,7 @@ def __init__( bn_eps=1e-5, bn_mmt=0.1, disable_pre_activation=False, + use_r2plus1=False, ): """ ResBlock class constructs redisual blocks. More details can be found in: @@ -422,6 +460,8 @@ def __init__( ResNeXt like networks. disable_pre_activation (bool): If true, disable the preactivation, which includes BatchNorm3D and ReLU. + use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D + spatial conv and one 1D temporal conv """ super(ResBlock, self).__init__() @@ -453,6 +493,7 @@ def __init__( temporal_kernel_size=temporal_kernel_size, temporal_conv_1x1=temporal_conv_1x1, disable_pre_activation=disable_pre_activation, + use_r2plus1=use_r2plus1, ) self.relu = nn.ReLU(inplace_relu) diff --git a/classy_vision/models/resnext3d_stage.py b/classy_vision/models/resnext3d_stage.py index 1597c04219..590214246c 100644 --- a/classy_vision/models/resnext3d_stage.py +++ b/classy_vision/models/resnext3d_stage.py @@ -97,6 +97,7 @@ def __init__( bn_mmt=0.1, disable_pre_activation=False, final_stage=False, + use_r2plus1=False, ): """ The `__init__` method of any subclass should also contain these arguments. @@ -132,6 +133,8 @@ def __init__( disable_pre_activation (bool): If true, disable the preactivation, which includes BatchNorm3D and ReLU. final_stage (bool): If true, this is the last stage in the model. + use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D + spatial conv and one 1D temporal conv """ super(ResStage, self).__init__( stage_idx, @@ -169,6 +172,7 @@ def __init__( bn_eps=bn_eps, bn_mmt=bn_mmt, disable_pre_activation=block_disable_pre_activation, + use_r2plus1=use_r2plus1, ) block_name = self._block_name(p, stage_idx, i) if block_callback: diff --git a/classy_vision/models/resnext3d_stem.py b/classy_vision/models/resnext3d_stem.py index 880fa8bb48..9f741c9fae 100644 --- a/classy_vision/models/resnext3d_stem.py +++ b/classy_vision/models/resnext3d_stem.py @@ -6,6 +6,8 @@ import torch.nn as nn +from .r2plus1_util import r2plus1_unit + class ResNeXt3DStemSinglePathway(nn.Module): """ @@ -25,6 +27,7 @@ def __init__( inplace_relu=True, bn_eps=1e-5, bn_mmt=0.1, + use_r2plus1=False, ): """ The `__init__` method of any subclass should also contain these arguments. @@ -49,6 +52,8 @@ def __init__( bn_eps (float): epsilon for batch norm. bn_mmt (float): momentum for batch norm. Noted that BN momentum in PyTorch = 1 - BN momentum in Caffe2. + use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D + spatial conv and one 1D temporal conv """ super(ResNeXt3DStemSinglePathway, self).__init__() self.kernel = kernel @@ -58,19 +63,36 @@ def __init__( self.bn_eps = bn_eps self.bn_mmt = bn_mmt self.maxpool = maxpool + self.use_r2plus1 = use_r2plus1 # Construct the stem layer. self._construct_stem(dim_in, dim_out) def _construct_stem(self, dim_in, dim_out): - self.conv = nn.Conv3d( - dim_in, - dim_out, - self.kernel, - stride=self.stride, - padding=self.padding, - bias=False, - ) + if not self.use_r2plus1: + self.conv = nn.Conv3d( + dim_in, + dim_out, + self.kernel, + stride=self.stride, + padding=self.padding, + bias=False, + ) + else: + assert ( + self.stride[1] == self.stride[2] + ), "Only support identical height stride and width stride" + self.conv = r2plus1_unit( + dim_in, + dim_out, + self.stride[0], # temporal_stride + self.stride[1], # spatial_stride + 1, # groups + self.inplace_relu, + self.bn_eps, + self.bn_mmt, + dim_mid=45, + ) self.bn = nn.BatchNorm3d(dim_out, eps=self.bn_eps, momentum=self.bn_mmt) self.relu = nn.ReLU(self.inplace_relu) if self.maxpool: @@ -104,6 +126,7 @@ def __init__( bn_eps=1e-5, bn_mmt=0.1, maxpool=(True,), + use_r2plus1=(False,), ): """ The `__init__` method of any subclass should also contain these @@ -131,6 +154,9 @@ def __init__( maxpool (iterable): At training time, when crop size is 224 x 224, do max pooling. When crop size is 112 x 112, skip max pooling. Default value is a (True,) + use_r2plus1 (iterable): If true for one pathway, decompose the original + 3D conv into one 2D spatial conv and one 1D temporal conv in that + pathway """ super(ResNeXt3DStemMultiPathway, self).__init__() @@ -146,6 +172,7 @@ def __init__( self.bn_eps = bn_eps self.bn_mmt = bn_mmt self.maxpool = maxpool + self.use_r2plus1 = use_r2plus1 # Construct the stem layer. self._construct_stem(dim_in, dim_out) @@ -168,6 +195,7 @@ def _construct_stem(self, dim_in, dim_out): bn_eps=self.bn_eps, bn_mmt=self.bn_mmt, maxpool=self.maxpool[p], + use_r2plus1=self.use_r2plus1[p], ) stem_name = self._stem_name(p) self.add_module(stem_name, stem) @@ -188,7 +216,13 @@ def forward(self, x): class ResNeXt3DStem(nn.Module): def __init__( - self, temporal_kernel, spatial_kernel, input_planes, stem_planes, maxpool + self, + temporal_kernel, + spatial_kernel, + input_planes, + stem_planes, + maxpool, + use_r2plus1, ): super(ResNeXt3DStem, self).__init__() self.stem = ResNeXt3DStemMultiPathway( @@ -200,6 +234,7 @@ def __init__( [temporal_kernel // 2, spatial_kernel // 2, spatial_kernel // 2] ], # padding maxpool=[maxpool], + use_r2plus1=[use_r2plus1], ) def forward(self, x):