R(2+1)D unit
Summary:
- Implement the R(2+1)D unit (https://arxiv.org/abs/1711.11248).
- Compared with the vanilla R3D model, in R(2+1)D we replace a 3D conv with a 2D spatial conv followed by a 1D temporal conv, while the total number of model parameters is kept the same (a worked example of the channel arithmetic follows below).
- We use the R(2+1)D unit in `ResNeXt3DStem` and `BasicTransformation`.
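
To make the parameter matching concrete, here is a worked sketch with illustrative numbers (not taken from this diff), following the formula used in `r2plus1_unit` below:

dim_in = dim_out = 64
dim_mid = (dim_out * dim_in * 3 * 3 * 3) // (dim_in * 3 * 3 + dim_out * 3)  # = 144
# weights of the 1x3x3 conv plus the 3x1x1 conv equal the weights of one 3x3x3 conv
assert dim_in * dim_mid * 9 + dim_mid * dim_out * 3 == dim_in * dim_out * 27  # 110592 == 110592

The only extra parameters are the affine weights of the intermediate BatchNorm3d (2 * dim_mid).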

Differential Revision: D19084922

fbshipit-source-id: 826111323746a268fb2a3bd890b39e08478bab7b
stephenyan1231 authored and facebook-github-bot committed Dec 16, 2019
1 parent 1489190 commit 3ed3b5a
Showing 5 changed files with 191 additions and 29 deletions.
72 changes: 72 additions & 0 deletions classy_vision/models/r2plus1_util.py
@@ -0,0 +1,72 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch.nn as nn


def r2plus1_unit(
    dim_in,
    dim_out,
    temporal_stride,
    spatial_stride,
    groups,
    inplace_relu,
    bn_eps,
    bn_mmt,
    dim_mid=None,
):
    """
    Implementation of the `R(2+1)D unit <https://arxiv.org/abs/1711.11248>`_.
    Decompose one 3D conv into one 2D spatial conv and one 1D temporal conv.
    Choose the middle dimensionality so that the total number of parameters
    in the 2D spatial conv and the 1D temporal conv is unchanged.
    Args:
        dim_in (int): the channel dimension of the input.
        dim_out (int): the channel dimension of the output.
        temporal_stride (int): the temporal stride of the bottleneck.
        spatial_stride (int): the spatial stride of the bottleneck.
        groups (int): number of groups for the convolution.
        inplace_relu (bool): calculate the relu on the original input
            without allocating new memory.
        bn_eps (float): epsilon for batch norm.
        bn_mmt (float): momentum for batch norm. Note that BN momentum in
            PyTorch = 1 - BN momentum in Caffe2.
        dim_mid (Optional[int]): If not None, use the provided channel dimension
            for the output of the 2D spatial conv. If None, compute the output
            channel dimension of the 2D spatial conv so that the total number of
            model parameters remains unchanged.
    """
    if dim_mid is None:
        dim_mid = int(dim_out * dim_in * 3 * 3 * 3 / (dim_in * 3 * 3 + dim_out * 3))
        logging.info(
            "dim_in: %d, dim_out: %d. Set dim_mid to %d" % (dim_in, dim_out, dim_mid)
        )
    # 1x3x3 group conv, BN, ReLU
    conv_middle = nn.Conv3d(
        dim_in,
        dim_mid,
        [1, 3, 3],  # kernel
        stride=[1, spatial_stride, spatial_stride],
        padding=[0, 1, 1],
        groups=groups,
        bias=False,
    )
    conv_middle_bn = nn.BatchNorm3d(dim_mid, eps=bn_eps, momentum=bn_mmt)
    conv_middle_relu = nn.ReLU(inplace=inplace_relu)
    # 3x1x1 group conv
    conv = nn.Conv3d(
        dim_mid,
        dim_out,
        [3, 1, 1],  # kernel
        stride=[temporal_stride, 1, 1],
        padding=[1, 0, 0],
        groups=groups,
        bias=False,
    )
    return nn.Sequential(conv_middle, conv_middle_bn, conv_middle_relu, conv)
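
For reference, a minimal usage sketch (not part of this commit) that builds the unit with the same kind of arguments `BasicTransformation` passes and compares it against the plain 3x3x3 conv it replaces; the import path follows the file location in this diff:

import torch
import torch.nn as nn

from classy_vision.models.r2plus1_util import r2plus1_unit

# Build an R(2+1)D unit that stands in for a 64 -> 64 channel 3x3x3 conv.
unit = r2plus1_unit(
    dim_in=64,
    dim_out=64,
    temporal_stride=1,
    spatial_stride=1,
    groups=1,
    inplace_relu=True,
    bn_eps=1e-5,
    bn_mmt=0.1,
)
conv3d = nn.Conv3d(64, 64, [3, 3, 3], stride=[1, 1, 1], padding=[1, 1, 1], bias=False)

x = torch.randn(2, 64, 8, 56, 56)  # (N, C, T, H, W) dummy clip
assert unit(x).shape == conv3d(x).shape  # both preserve the T x H x W extent here

# Conv weights match exactly (110,592 each); the unit only adds the affine
# parameters of the intermediate BatchNorm3d over dim_mid = 144 channels.
print(sum(p.numel() for p in conv3d.parameters()))  # 110592
print(sum(p.numel() for p in unit.parameters()))    # 110592 + 2 * 144
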
10 changes: 10 additions & 0 deletions classy_vision/models/resnext3d.py
@@ -35,6 +35,7 @@ def __init__(
stem_temporal_kernel,
stem_spatial_kernel,
stem_maxpool,
use_r2plus1,
):
"""
ResNeXt3DBase implements everything in ResNeXt3D model except the
@@ -55,6 +56,7 @@ def __init__(
input_planes,
stem_planes,
stem_maxpool,
use_r2plus1,
)

@classmethod
@@ -118,6 +120,8 @@ def _parse_config(cls, config):
"width_per_group": config.get("width_per_group", 64),
}
)
# Default setting for both model stem and model stages
ret_config.update({"use_r2plus1": config.get("use_r2plus1", False)})
# Default setting for model parameter initialization
ret_config.update(
{
@@ -301,6 +305,7 @@ class ResNeXt3D(ResNeXt3DBase):
The model consists of one stem, a number of stages, and one or multiple
heads that are attached to different blocks in the stage.
"""

def __init__(
self,
input_key,
@@ -323,6 +328,7 @@ def __init__(
num_groups,
width_per_group,
zero_init_residual_transform,
use_r2plus1,
):
"""
Args:
@@ -362,6 +368,8 @@ def __init__(
operation, which could be either BatchNorm3D in post-activated
transformation or Conv3D in pre-activated transformation, in the
residual transformation is initialized to zero
use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D
spatial conv and one 1D temporal conv.
"""
super(ResNeXt3D, self).__init__(
input_key,
@@ -374,6 +382,7 @@ def __init__(
stem_temporal_kernel,
stem_spatial_kernel,
stem_maxpool,
use_r2plus1,
)

num_stages = len(num_blocks)
@@ -401,6 +410,7 @@ def __init__(
block_callback=self.build_attachable_block,
disable_pre_activation=(s == 0),
final_stage=(s == (num_stages - 1)),
use_r2plus1=use_r2plus1,
)
stages.append(stage)

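The new flag is read in `ResNeXt3DBase._parse_config` via `config.get("use_r2plus1", False)`, so a hypothetical model-config fragment (the remaining ResNeXt3D keys are omitted here) that requests the R(2+1)D variant could look like:

model_config = {
    # ... the rest of the usual ResNeXt3D model config (stem and stage options)
    "use_r2plus1": True,  # decompose the 3D convs in the stem and basic transformations
}
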
81 changes: 61 additions & 20 deletions classy_vision/models/resnext3d_block.py
@@ -6,6 +6,8 @@

import torch.nn as nn

from .r2plus1_util import r2plus1_unit


class BasicTransformation(nn.Module):
"""
@@ -22,6 +24,7 @@ def __init__(
inplace_relu=True,
bn_eps=1e-5,
bn_mmt=0.1,
use_r2plus1=False,
**kwargs
):
"""
@@ -36,31 +39,65 @@ def __init__(
bn_eps (float): epsilon for batch norm.
bn_mmt (float): momentum for batch norm. Note that BN momentum in
PyTorch = 1 - BN momentum in Caffe2.
use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D
spatial conv and one 1D temporal conv.
"""
super(BasicTransformation, self).__init__()

# 3x3x3 group conv, BN, ReLU.
branch2a = nn.Conv3d(
dim_in,
dim_out,
[3, 3, 3], # kernel
stride=[temporal_stride, spatial_stride, spatial_stride],
padding=[1, 1, 1],
groups=groups,
bias=False,
)
if not use_r2plus1:
# 3x3x3 group conv, BN, ReLU.
branch2a = nn.Conv3d(
dim_in,
dim_out,
[3, 3, 3], # kernel
stride=[temporal_stride, spatial_stride, spatial_stride],
padding=[1, 1, 1],
groups=groups,
bias=False,
)
else:
# Implementation of the R(2+1)D operation <https://arxiv.org/abs/1711.11248>:
# decompose the original 3D conv into one 2D spatial conv and one
# 1D temporal conv.
branch2a = r2plus1_unit(
dim_in,
dim_out,
temporal_stride,
spatial_stride,
groups,
inplace_relu,
bn_eps,
bn_mmt,
)
branch2a_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)
branch2a_relu = nn.ReLU(inplace=inplace_relu)
# 3x3x3 group conv, BN, ReLU.
branch2b = nn.Conv3d(
dim_out,
dim_out,
[3, 3, 3], # kernel
stride=[1, 1, 1],
padding=[1, 1, 1],
groups=groups,
bias=False,
)

if not use_r2plus1:
# 3x3x3 group conv, BN, ReLU.
branch2b = nn.Conv3d(
dim_out,
dim_out,
[3, 3, 3], # kernel
stride=[1, 1, 1],
padding=[1, 1, 1],
groups=groups,
bias=False,
)
else:
# Implementation of the R(2+1)D operation <https://arxiv.org/abs/1711.11248>:
# decompose the original 3D conv into one 1x3x3 group conv and one
# 3x1x1 group conv.
branch2b = r2plus1_unit(
dim_out,
dim_out,
1, # temporal_stride
1, # spatial_stride
groups,
inplace_relu,
bn_eps,
bn_mmt,
)

branch2b_bn = nn.BatchNorm3d(dim_out, eps=bn_eps, momentum=bn_mmt)
branch2b_bn.final_transform_op = True

@@ -398,6 +435,7 @@ def __init__(
bn_eps=1e-5,
bn_mmt=0.1,
disable_pre_activation=False,
use_r2plus1=False,
):
"""
The ResBlock class constructs residual blocks. More details can be found in:
@@ -422,6 +460,8 @@ def __init__(
ResNeXt like networks.
disable_pre_activation (bool): If true, disable the preactivation,
which includes BatchNorm3D and ReLU.
use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D
spatial conv and one 1D temporal conv.
"""
super(ResBlock, self).__init__()

@@ -453,6 +493,7 @@ def __init__(
temporal_kernel_size=temporal_kernel_size,
temporal_conv_1x1=temporal_conv_1x1,
disable_pre_activation=disable_pre_activation,
use_r2plus1=use_r2plus1,
)
self.relu = nn.ReLU(inplace_relu)

4 changes: 4 additions & 0 deletions classy_vision/models/resnext3d_stage.py
@@ -97,6 +97,7 @@ def __init__(
bn_mmt=0.1,
disable_pre_activation=False,
final_stage=False,
use_r2plus1=False,
):
"""
The `__init__` method of any subclass should also contain these arguments.
@@ -132,6 +133,8 @@ def __init__(
disable_pre_activation (bool): If true, disable the preactivation,
which includes BatchNorm3D and ReLU.
final_stage (bool): If true, this is the last stage in the model.
use_r2plus1 (bool): If true, decompose the original 3D conv into one 2D
spatial conv and one 1D temporal conv.
"""
super(ResStage, self).__init__(
stage_idx,
@@ -169,6 +172,7 @@ def __init__(
bn_eps=bn_eps,
bn_mmt=bn_mmt,
disable_pre_activation=block_disable_pre_activation,
use_r2plus1=use_r2plus1,
)
block_name = self._block_name(p, stage_idx, i)
if block_callback: