Skip to content

Commit

Permalink
fix update params (apache#14218)
Browse files Browse the repository at this point in the history
  • Loading branch information
roywei authored and haohuw committed Jun 23, 2019
1 parent d6fc8ef commit a33aae2
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device,
w, g = p
updates[k].append((index*num_device+k, g, w))
for dev_updates in updates:
i, w, g = zip(*dev_updates)
updater(i, w, g)
# update params if param_arrays and grad_arrays are not empty
if dev_updates:
i, w, g = zip(*dev_updates)
updater(i, w, g)


def _multiple_callbacks(callbacks, *args, **kwargs):
Expand Down
14 changes: 14 additions & 0 deletions tests/python/unittest/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,20 @@ def sym_gen(_):
assert(mod._curr_module._exec_group.execs[0].grad_dict['a'].asscalar() == 2 * batch_size)


def test_module_update_no_pragram():
# test module to do update on layers without params
data_shape = (10, 10)
data = mx.sym.Variable('data')
out = mx.sym.Dropout(data, 0.5)
mod = mx.mod.Module(out)
mod.bind(data_shapes=[('data', data_shape)])
mod.init_params()
mod.init_optimizer()
data_batch = mx.io.DataBatch([nd.ones(data_shape)])
mod.forward_backward(data_batch)
mod.update()
assert(mod.get_outputs()[0].shape == data_shape)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit a33aae2

Please sign in to comment.