diff --git a/torchvision/models/video/resnet.py b/torchvision/models/video/resnet.py
index a9e59a149c0..40875a8c86f 100644
--- a/torchvision/models/video/resnet.py
+++ b/torchvision/models/video/resnet.py
@@ -1,3 +1,5 @@
+import warnings
+
 import torch
 import torch.nn as nn
 
@@ -17,7 +19,6 @@ class Conv3DSimple(nn.Conv3d):
     def __init__(self,
                  in_planes,
                  out_planes,
-                 midplanes=None,
                  stride=1,
                  padding=1):
 
@@ -39,9 +40,13 @@ class Conv2Plus1D(nn.Sequential):
     def __init__(self,
                  in_planes,
                  out_planes,
-                 midplanes,
                  stride=1,
-                 padding=1):
+                 padding=1,
+                 midplanes=None):
+
+        if midplanes is None:
+            midplanes = (in_planes * out_planes * 3 * 3 * 3) // (
+                in_planes * 3 * 3 + 3 * out_planes)
         super(Conv2Plus1D, self).__init__(
             nn.Conv3d(in_planes, midplanes, kernel_size=(1, 3, 3),
                       stride=(1, stride, stride), padding=(0, padding, padding),
@@ -62,7 +67,6 @@ class Conv3DNoTemporal(nn.Conv3d):
     def __init__(self,
                  in_planes,
                  out_planes,
-                 midplanes=None,
                  stride=1,
                  padding=1):
 
@@ -84,16 +88,15 @@ class BasicBlock(nn.Module):
     expansion = 1
 
     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
-        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
 
         super(BasicBlock, self).__init__()
         self.conv1 = nn.Sequential(
-            conv_builder(inplanes, planes, midplanes, stride),
+            conv_builder(inplanes, planes, stride),
             nn.BatchNorm3d(planes),
             nn.ReLU(inplace=True)
         )
         self.conv2 = nn.Sequential(
-            conv_builder(planes, planes, midplanes),
+            conv_builder(planes, planes),
             nn.BatchNorm3d(planes)
         )
         self.relu = nn.ReLU(inplace=True)
@@ -120,7 +123,6 @@ class Bottleneck(nn.Module):
     def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
 
         super(Bottleneck, self).__init__()
-        midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes)
 
         # 1x1x1
         self.conv1 = nn.Sequential(
@@ -130,7 +132,7 @@ def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None):
         )
         # Second kernel
         self.conv2 = nn.Sequential(
-            conv_builder(planes, planes, midplanes, stride),
+            conv_builder(planes, planes, stride),
             nn.BatchNorm3d(planes),
             nn.ReLU(inplace=True)
         )
@@ -190,6 +192,10 @@ def __init__(self):
 
 class VideoResNet(nn.Module):
 
+    # Version 2 adds updated BN params, and
+    # solves midplane computation
+    _version = 2
+
     def __init__(self, block, conv_makers, layers,
                  stem, num_classes=400,
                  zero_init_residual=False):
@@ -268,9 +274,54 @@ def _initialize_weights(self):
             elif isinstance(m, nn.BatchNorm3d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
+                # following are v2 updates for maximum reproducibility
+                # NOTE(review): PyTorch applies momentum to the new batch
+                # statistic, so 0.9 follows recent batches far more closely
+                # than the 0.1 default - confirm this is intended.
+                m.eps = 1e-3
+                m.momentum = 0.9
             elif isinstance(m, nn.Linear):
                 nn.init.normal_(m.weight, 0, 0.01)
                 nn.init.constant_(m.bias, 0)
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        """Load parameters, restoring the v1 layout for v1 R2+1D checkpoints.
+
+        The checkpoint version is read from ``local_metadata``; a state dict
+        without version metadata is loaded against the current layout.
+        """
+        version = local_metadata.get("version", None)
+        if version not in (None, 1, 2):
+            # reported through error_msgs instead of an assert, which also
+            # rejected every state dict that carries no version metadata
+            error_msgs.append(
+                "Unsupported VideoResNet state_dict version: {}".format(version))
+        # the new changes only apply to the R2+1D models
+        if version == 1 and isinstance(self.layer2[0].conv2[0], Conv2Plus1D):
+            # V1 of the models had midplanes hard coded into the blocks
+            # and default BN parameters as in Pytorch.
+            # All other layer configurations were the same.
+            self.layer2[0].conv2[0] = Conv2Plus1D(128, 128, midplanes=230)
+            self.layer3[0].conv2[0] = Conv2Plus1D(256, 256, midplanes=460)
+            self.layer4[0].conv2[0] = Conv2Plus1D(512, 512, midplanes=921)
+
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm3d):
+                    m.eps = 1e-5
+                    m.momentum = 0.1
+
+            # The model is now identical to v1, and must be saved as such.
+            self._version = 1
+            warnings.warn(
+                "This is an updated version of the R(2+1D) model that was "
+                "updated following discussion in #1265. The performance "
+                "deviations are minimal, but this might cause some BW compatibility "
+                "issues, depending on the models.", UserWarning)
+
+        super(VideoResNet, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys,
+            unexpected_keys, error_msgs)
 
 
 def _video_resnet(arch, pretrained=False, progress=True, **kwargs):