Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix BatchNorm backward synchronization (#18644)
Browse files Browse the repository at this point in the history
* Add test for BatchNorm running variables synchronization

* Fix BatchNorm backward synchronization

It fixes issue #18610
  • Loading branch information
anko-intel authored Jul 1, 2020
1 parent 2158106 commit 37bed6e
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ then set ``gamma`` to 1 and its gradient to 0.
NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_inputs(8)
.set_num_outputs(3)
.set_attr<nnvm::FMutateInputs>("FMutateInputs", [](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{6, 7}; // moving_mean, moving_var
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,32 @@ def transpose(shape):
assert (layer(x).shape==ceil_out_shape)


@with_seed()
@pytest.mark.parametrize('variable', ['running_var', 'running_mean'])
def test_batchnorm_backward_synchronization(variable):
"""
Tests if synchronization of BatchNorm running variables is done correctly.
If not, the test sometimes fails - depending on the timing.
"""
ctx = mx.test_utils.default_context()

for _ in range(20):
layer = nn.BatchNorm()
layer.initialize(ctx=ctx)
for _ in range(3):
data = mx.nd.random.normal(loc=10, scale=2, shape=(1, 3, 10, 10), ctx=ctx)
with mx.autograd.record():
out = layer(data)
out.backward()

# check if each read give the same value
var1 = getattr(layer, variable).data().asnumpy()
for _ in range(10):
var2 = getattr(layer, variable).data().asnumpy()
if (var1 != var2).any():
raise AssertionError("Two consecutive reads of " + variable + " give different results")


@with_seed()
def test_batchnorm():
layer = nn.BatchNorm(in_channels=10)
Expand Down

0 comments on commit 37bed6e

Please sign in to comment.