Skip to content

Commit 5f5e8e8

Browse files
authored
[Refactoring] Add Caffe2Xavier Initializer (#902)
* [Refactoring] Add Caffe2Xavier Initializer * fix lint
1 parent 933b052 commit 5f5e8e8

File tree

4 files changed

+38
-11
lines changed

4 files changed

+38
-11
lines changed

mmcv/cnn/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
build_upsample_layer, conv_ws_2d, is_norm)
1414
# yapf: enable
1515
from .resnet import ResNet, make_res_layer
16-
from .utils import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
17-
PretrainedInit, UniformInit, XavierInit,
16+
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
17+
NormalInit, PretrainedInit, UniformInit, XavierInit,
1818
bias_init_with_prob, caffe2_xavier_init, constant_init,
1919
fuse_conv_bn, get_model_complexity_info, initialize,
2020
kaiming_init, normal_init, uniform_init, xavier_init)
@@ -33,5 +33,6 @@
3333
'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
3434
'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
3535
'MaxPool3d', 'Conv3d', 'initialize', 'INITIALIZERS', 'ConstantInit',
36-
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit'
36+
'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
37+
'Caffe2XavierInit'
3738
]

mmcv/cnn/utils/__init__.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Copyright (c) Open-MMLab. All rights reserved.
22
from .flops_counter import get_model_complexity_info
33
from .fuse_conv_bn import fuse_conv_bn
4-
from .weight_init import (INITIALIZERS, ConstantInit, KaimingInit, NormalInit,
5-
PretrainedInit, UniformInit, XavierInit,
6-
bias_init_with_prob, caffe2_xavier_init,
4+
from .weight_init import (INITIALIZERS, Caffe2XavierInit, ConstantInit,
5+
KaimingInit, NormalInit, PretrainedInit, UniformInit,
6+
XavierInit, bias_init_with_prob, caffe2_xavier_init,
77
constant_init, initialize, kaiming_init, normal_init,
88
uniform_init, xavier_init)
99

@@ -12,5 +12,5 @@
1212
'constant_init', 'kaiming_init', 'normal_init', 'uniform_init',
1313
'xavier_init', 'fuse_conv_bn', 'initialize', 'INITIALIZERS',
1414
'ConstantInit', 'XavierInit', 'NormalInit', 'UniformInit', 'KaimingInit',
15-
'PretrainedInit'
15+
'PretrainedInit', 'Caffe2XavierInit'
1616
]

mmcv/cnn/utils/weight_init.py

+16
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,22 @@ def init(m):
298298
module.apply(init)
299299

300300

301+
@INITIALIZERS.register_module(name='Caffe2Xavier')
class Caffe2XavierInit(KaimingInit):
    """Caffe2-style Xavier initializer, registered as ``'Caffe2Xavier'``.

    `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch, so
    this class is a ``KaimingInit`` whose Kaiming settings are pinned to
    ``a=1``, ``mode='fan_in'``, ``nonlinearity='leaky_relu'`` and
    ``distribution='uniform'``.
    Acknowledgment to FAIR's internal code.

    Args:
        **kwargs: Remaining keyword arguments, forwarded unchanged to
            ``KaimingInit.__init__``. They must not repeat any of the four
            pinned names above, otherwise the call raises ``TypeError``.
    """

    def __init__(self, **kwargs):
        # Settings that make the Kaiming initializer match Caffe2's XavierFill.
        pinned = dict(
            a=1,
            mode='fan_in',
            nonlinearity='leaky_relu',
            distribution='uniform')
        super().__init__(**pinned, **kwargs)

    def __call__(self, module):
        """Initialize ``module`` by delegating to the parent ``__call__``."""
        super().__call__(module)
315+
316+
301317
@INITIALIZERS.register_module(name='Pretrained')
302318
class PretrainedInit(object):
303319
"""Initialize module by loading a pretrained model.

tests/test_cnn/test_weight_init.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import torch
77
from torch import nn
88

9-
from mmcv.cnn import (ConstantInit, KaimingInit, NormalInit, PretrainedInit,
10-
UniformInit, XavierInit, bias_init_with_prob,
11-
caffe2_xavier_init, constant_init, initialize,
12-
kaiming_init, normal_init, uniform_init, xavier_init)
9+
from mmcv.cnn import (Caffe2XavierInit, ConstantInit, KaimingInit, NormalInit,
10+
PretrainedInit, UniformInit, XavierInit,
11+
bias_init_with_prob, caffe2_xavier_init, constant_init,
12+
initialize, kaiming_init, normal_init, uniform_init,
13+
xavier_init)
1314

1415

1516
def test_constant_init():
@@ -219,6 +220,15 @@ def test_kaiminginit():
219220
assert torch.equal(model[2].bias, torch.full(model[2].bias.shape, 10.))
220221

221222

223+
def test_caffe2xavierinit():
    """Check that Caffe2XavierInit sets the bias only on the chosen layer."""
    # Build the parameterised layers in the same order as they appear in the
    # Sequential so random initialisation is consumed identically.
    conv, linear = nn.Conv2d(3, 1, 3), nn.Linear(1, 2)
    model = nn.Sequential(conv, nn.ReLU(), linear)

    initializer = Caffe2XavierInit(bias=0.1, layer='Conv2d')
    initializer(model)

    # The Conv2d bias must now hold the requested constant ...
    assert torch.equal(conv.bias, torch.full(conv.bias.shape, 0.1))
    # ... while the Linear bias must not have been set to that constant.
    assert not torch.equal(linear.bias, torch.full(linear.bias.shape, 0.1))
230+
231+
222232
class FooModule(nn.Module):
223233

224234
def __init__(self):

0 commit comments

Comments
 (0)