From 8999c32799b10021413adb5b180a91d930674668 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Thu, 12 Jul 2018 20:25:29 +0000
Subject: [PATCH 1/5] fix a bug in CreateEngineOpSeg: check !valid before is_async.

An invalid segment could previously still be wrapped in an engine op,
because the is_async branch was tested first. Test segment validity
first so that an invalid segment is closed without creating an op.
---
 src/imperative/imperative_utils.h | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 2331d7be155c..6daf96e60d0b 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -963,13 +963,13 @@ inline void CreateEngineOpSeg(
       seg_execs.push_back(exec);
 
       auto& seg = (*opr_segs)[nid];
-      if (is_async) {
-        seg = EngineOprSeg{false, nid + 1};
-        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
+      if (!valid) {
+        seg = EngineOprSeg{false, nid + 1, nullptr};
         seg_execs.clear();
         seg_start = nid + 1;
-      } else if (!valid) {
-        seg = EngineOprSeg{false, nid + 1, nullptr};
+      } else if (is_async) {
+        seg = EngineOprSeg{false, nid + 1};
+        seg.opr.reset(CreateEngineOp(default_ctx, seg_execs));
         seg_execs.clear();
         seg_start = nid + 1;
       }

From f8e3280e79559b25984191c4338d189a9a030db4 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Thu, 12 Jul 2018 21:06:28 +0000
Subject: [PATCH 2/5] add tests for hybridize() configs.

Run check_contrib_rnn on hybridized layers under three configs
(default, static_alloc, and static_alloc + static_shape) and compare
outputs and gradients against the non-hybridized run.
---
 tests/python/unittest/test_gluon_rnn.py | 42 ++++++++++++++----------
 1 file changed, 24 insertions(+), 18 deletions(-)

diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index 6167f660d2c1..d9e1f58a88dd 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -66,25 +66,31 @@ def check_contrib_rnn(cell_type, num_states):
         res1.backward()
     trainer.step(batch_size)
 
-    layer = TestRNNLayer(cell_type, hidden_size)
-    layer.initialize(ctx=mx.cpu(0))
-    layer.hybridize()
-    res2 = layer(rnn_data, states)
-    params2 = layer.collect_params()
-    for key, val in orig_params1.items():
-        params2[key].set_data(val.data())
-
-    trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
-    with mx.autograd.record():
+    configs = [
+            {},
+            {'static_alloc': True},
+            {'static_alloc': True, 'static_shape': True} ]
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size)
+        layer.initialize(ctx=mx.cpu(0))
+        layer.hybridize(**config)
         res2 = layer(rnn_data, states)
-    assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
-    res2.backward()
-    trainer.step(batch_size)
-
-    for key, val in params1.items():
-        weight1 = val.data()
-        weight2 = params2[key].data()
-        assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(), rtol=0.001, atol=0.0001)
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+        with mx.autograd.record():
+            res2 = layer(rnn_data, states)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            weight1 = val.data()
+            weight2 = params2[key].data()
+            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
 
 
 def test_contrib_rnn():

From 6204c5197e164a5fa7b3cf81b8bf87aae9933ef3 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Wed, 18 Jul 2018 17:28:30 -0700
Subject: [PATCH 3/5] use default context.
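
Initialize the layers in check_contrib_rnn with default_context()
instead of a hard-coded mx.cpu(0), so the test runs on whichever
context the test framework selects.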
---
 tests/python/unittest/test_gluon_rnn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index d9e1f58a88dd..cef886b95bba 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -55,7 +55,7 @@ def check_contrib_rnn(cell_type, num_states):
     state_shape = (batch_size, hidden_size)
     states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
     layer = TestRNNLayer(cell_type, hidden_size)
-    layer.initialize(ctx=mx.cpu(0))
+    layer.initialize(ctx=default_context())
     res1 = layer(rnn_data, states)
     params1 = layer.collect_params()
     orig_params1 = copy.deepcopy(params1)
@@ -72,7 +72,7 @@ def check_contrib_rnn(cell_type, num_states):
             {'static_alloc': True, 'static_shape': True} ]
     for config in configs:
         layer = TestRNNLayer(cell_type, hidden_size)
-        layer.initialize(ctx=mx.cpu(0))
+        layer.initialize(ctx=default_context())
         layer.hybridize(**config)
         res2 = layer(rnn_data, states)
         params2 = layer.collect_params()

From 2b3e6e093639ffb4df86e65454179d0dee376c10 Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Wed, 18 Jul 2018 17:30:45 -0700
Subject: [PATCH 4/5] move all tests to test_contrib_control_flow.py

---
 .../unittest/test_contrib_control_flow.py    | 532 ++++++++++++++++++
 tests/python/unittest/test_gluon_rnn.py      |  61 --
 tests/python/unittest/test_operator.py       | 471 ----------------
 3 files changed, 532 insertions(+), 532 deletions(-)

diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 1cc5b21ac86c..0d85c41c0a0c 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -975,6 +975,538 @@ def _func(*states):
     assert_almost_equal(x, y, rtol=1e-4, atol=1e-4)
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+
+    def hybrid_forward(self, F, inputs, states):
+        out, states = F.contrib.foreach(self.cell, inputs, states)
+        return out
+
+def check_contrib_rnn(cell_type, num_states):
+    batch_size = 10
+    hidden_size = 100
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
+    layer = TestRNNLayer(cell_type, hidden_size)
+    layer.initialize(ctx=default_context())
+    res1 = layer(rnn_data, states)
+    params1 = layer.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1 = layer(rnn_data, states)
+    res1.backward()
+    trainer.step(batch_size)
+
+    configs = [
+            {},
+            {'static_alloc': True},
+            {'static_alloc': True, 'static_shape': True} ]
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size)
+        layer.initialize(ctx=default_context())
+        layer.hybridize(**config)
+        res2 = layer(rnn_data, states)
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+        with mx.autograd.record():
+            res2 = layer(rnn_data, states)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            weight1 = val.data()
+            weight2 = params2[key].data()
+            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+def test_contrib_rnn():
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+                  (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_contrib_rnn(cell_type, num_states)
+
+
+@with_seed()
+def test_foreach():
+    v3 = mx.sym.var("v0")
+    v4 = mx.sym.var("v1")
+    v5 = mx.sym.var("v2")
+    v6 = mx.sym.var("v3")
+    v7 = mx.sym.var("v4")
+    v8 = mx.sym.var("v5")
+
+    def verify_foreach(step, in_syms, state_syms, free_syms,
+                       in_arrs, init_states, frees, out_grads, is_train=True,
+                       free_vars_func=None, num_iters=1):
+        step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms)
+        res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
+        out = _as_list(res)
+        num_outputs = len(out)
+        for i in range(num_outputs):
+            out[i] = out[i] * 2
+        out.extend(states)
+        out = mx.sym.Group(out)
+        js_1 = out.tojson()
+        out = mx.sym.load_json(js_1)
+        js_2 = out.tojson()
+        assert js_1 == js_2
+        arr_grads = []
+        arg_dict = {}
+        arg_grad_dict = {}
+        i = 0
+        for arr in _as_list(in_arrs):
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in init_states:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+        for arr in frees:
+            arr_grad = mx.nd.empty(arr.shape)
+            arr_grads.append(arr_grad)
+            arg_dict['v'+str(i)] = arr
+            arg_grad_dict['v'+str(i)] = arr_grad
+            i = i + 1
+
+        if is_train:
+            e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict)
+        else:
+            e = out.bind(ctx=default_context(), args=arg_dict)
+        # the inputs to forward and backward are the same so forward and backward
+        # should always return the same outputs.
+        for i in range(num_iters):
+            e.forward(is_train=is_train)
+            if (is_train):
+                # backward
+                tmp_grads = out_grads[0][:]
+                tmp_grads.extend(out_grads[1])
+                e.backward(tmp_grads)
+
+        # Below we use imperative to reimplement foreach and compute its gradients.
+        res = []
+        for i in range(len(_as_list(out_grads[0]))):
+            res.append([])
+        for arr in _as_list(in_arrs):
+            arr.attach_grad()
+        for arr in init_states:
+            arr.attach_grad()
+        for arr in frees:
+            arr.attach_grad()
+        with mx.autograd.record():
+            frees_imp = frees if free_vars_func is None else free_vars_func(frees)
+            step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp)
+            states = [mx.nd.expand_dims(s, 0) for s in init_states]
+            res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states)
+
+            res2 = _as_list(res)
+            for i in range(len(res2)):
+                res2[i] = res2[i] * 2
+            outs = []
+            outs[:] = res2[:]
+            if isinstance(states, list):
+                outs.extend(states)
+                states = [mx.nd.expand_dims(s, 0) for s in states]
+                res2.extend(states)
+            else:
+                outs.append(states)
+                states = mx.nd.expand_dims(states, 0)
+                res2.append(states)
+            if is_train:
+                res = mx.nd.concat(*res2, dim=0)
+
+        tmp_grads = out_grads[0][:]
+        tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
+        tmp_grads.extend(tmp_grads1)
+        if is_train:
+            res.backward(mx.nd.concat(*tmp_grads, dim=0))
+        for i in range(len(outs)):
+            assert e.outputs[i].shape == outs[i].shape
+            assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        if (is_train):
+            all_ins = _as_list(in_arrs)[:]
+            all_ins.extend(init_states)
+            all_ins.extend(frees)
+            size = min(len(all_ins), len(e.grad_arrays))
+            for i in range(size):
+                assert_almost_equal(all_ins[i].grad.asnumpy(),
+                                    e.grad_arrays[i].asnumpy(),
+                                    rtol=0.001, atol=0.0001)
+
+    # Test cases:
+    # * graph inputs are stored in different orders.
+    #   This is to test if foreach finds the data arrays and weight arrays
+    #   in the right location.
+    # * the number of iterations: odd or even.
+    # * multiple inputs and multiple outputs.
+    # * inference.
+    def step1(in1, states, free):
+        out = in1 * 2 + states[0] + free[0]
+        return (out, [out])
+    frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1]
+    arrs = mx.nd.arange(6).reshape(shape=(3, 2))
+    states = [mx.nd.arange(2)]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
+                   lambda frees : [frees[0] + frees[1]])
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
+                   lambda frees : [frees[0] + frees[1]])
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
+                   lambda frees : [frees[0] + frees[1]], 5)
+    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
+                   lambda frees : [frees[0] + frees[1]], 5)
+
+    # Test the even number of iterations.
+    frees = [mx.nd.random.uniform(shape=(2))]
+    arrs = mx.nd.random.uniform(shape=(2, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    # Test the odd number of iterations
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Reorder the input and state in the subgraph inputs.
+    def step2(in1, states, free):
+        out = states[0] + in1 * 2 + free[0]
+        return (out, [out])
+    # Test the even number of iterations.
+    arrs = mx.nd.random.uniform(shape=(2, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    # Test the odd number of iterations.
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    def step3(in1, states, free):
+        out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0]
+        return ([out, out], [out * 2, out * 3])
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
+    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The order of subgraph inputs doesn't match the operator inputs
+    def step4(in1, states, free):
+        out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0]
+        return ([out, out * 2], [out * 2, out * 3])
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
+    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The data inputs and states have different shapes.
+    def step5(in1, states, free):
+        if isinstance(in1[0], mx.nd.NDArray):
+            out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2)
+            out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2)
+        else:
+            out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2)
+            out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2)
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))]
+    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # The data inputs and states have different shapes and data types.
+    def step6(in1, states, free):
+        if isinstance(in1[0], mx.nd.NDArray):
+            out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'),
+                                       mx.nd.cast(in1[1], 'float32') * 2)
+            out2 = mx.nd.broadcast_add(in1[0],
+                                       free[0] + mx.nd.cast(states[1], 'float32') * 2)
+        else:
+            out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'),
+                                        mx.sym.cast(in1[1], 'float32') * 2)
+            out2 = mx.sym.broadcast_add(in1[0],
+                                        free[0] + mx.sym.cast(states[1], 'float32') * 2)
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)),
+             mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')]
+    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)),
+            mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')]
+    states = [mx.nd.random.uniform(shape=(2, 2)),
+              mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test multiple inputs and outputs.
+    # some of the inputs are used twice.
+    def step7(in1, states, free):
+        out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0]
+        out2 = in1[0] + free[0] + states[1] * 2 + in1[1]
+        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
+    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
+    verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
+
+    # Test the case that the output is the input.
+    arrs = mx.nd.random.uniform(shape=(3, 2))
+    states = [mx.nd.arange(2)]
+    frees = [mx.nd.random.uniform(shape=(2))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
+                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    def step8(in1, states, free):
+        return (in1, [states[0] * free[0]])
+    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    def step9(in1, states, free):
+        return (in1 * free[0], states)
+    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test the case that not all inputs are used.
+    def step10(in1, states, free):
+        return (in1, states)
+    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    def step11(in1, states, free):
+        return (in1, free)
+    try:
+        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads)
+        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+    except AssertionError:
+        print("the states have to be used")
+    def step12(in1, states, free):
+        return (in1, [states[0] + 1, states[0] + 2])
+    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
+    frees = []
+    try:
+        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads)
+        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False)
+    except AssertionError:
+        print("the states have to be used")
+
+    # test without free variables.
+    def step13(in1, states, free):
+        return (in1, states)
+    states = [mx.nd.random.uniform(shape=(2))]
+    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads)
+    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False)
+
+    # test when there isn't output data or output states.
+    def step14(in1, states, free):
+        return (in1 + free[0], [])
+    frees = [mx.nd.random.uniform(shape=(2))]
+    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads)
+    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False)
+    def step15(in1, states, free):
+        return ([], [in1 * states[0] * free[0]])
+    out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]]
+    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads)
+    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False)
+
+    # Test the case of iterating on a 1D data array.
+    def step16(in1, states, free):
+        return ([in1[0] * states[0]], [states[0] * 2])
+    arrs = [mx.nd.arange(3)]
+    states = [mx.nd.random.uniform(shape=(1))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
+                 [mx.nd.random.uniform(-10, 10, (1))]]
+    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads)
+    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False)
+    def step17(in1, states, free):
+        return ([in1[1] * in1[0] * states[0]], [states[0] * 2])
+    arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)]
+    states = [mx.nd.random.uniform(shape=(1))]
+    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
+                 [mx.nd.random.uniform(-10, 10, (1))]]
+    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads)
+    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False)
+
+
+@with_seed()
+def test_foreach_nested():
+    # Test nested foreach.
+    def step_in(in1, states):
+        out = in1 * 2 + states[0]
+        return (out, [out])
+
+    def step_sym(in1, states):
+        out1 = mx.sym.contrib.foreach(step_in, in1, states)
+        out = mx.sym.broadcast_add(out1[0], states[0])
+        return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))])
+    def step_nd(in1, states):
+        out1 = mx.nd.contrib.foreach(step_in, in1, states)
+        out = mx.nd.broadcast_add(out1[0], states[0])
+        return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))])
+
+    data_sym = mx.sym.var("v1")
+    state_sym = mx.sym.var("v2")
+    out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym])
+    assert isinstance(states, list)
+    assert len(states) == 1
+    out = mx.sym.broadcast_add(out, states[0])
+
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+
+    data = mx.nd.arange(8).reshape((2, 2, 2))
+    state = mx.nd.arange(2)
+    data_grad = mx.nd.empty(data.shape)
+    state_grad = mx.nd.empty(state.shape)
+    e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state},
+                 args_grad={'v1':data_grad, 'v2':state_grad})
+    e.forward(is_train=True)
+    out_grads = []
+    for out in e.outputs:
+        out_grads.append(mx.nd.random.uniform(shape=out.shape))
+    e.backward(out_grads)
+
+    data.attach_grad()
+    state.attach_grad()
+    with mx.autograd.record():
+        out, states = mx.nd.contrib.foreach(step_nd, data, [state])
+        assert isinstance(states, list)
+        assert len(states) == 1
+        res = mx.nd.broadcast_add(out, states[0])
+    assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001, atol=0.0001)
+
+    res.backward(out_grads[0])
+    assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy())
+    assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy())
+
+
+def check_foreach_rnn(cell_type, num_states):
+    data = mx.sym.var("data")
+    params = mx.rnn.RNNParams()
+    hidden_dim = 4
+    input_dim = 5
+    seq_len = 2
+    batch_size = 2
+
+    # This tests foreach with accumulation sum.
+    def step(in1, states):
+        rnn = cell_type(hidden_dim, prefix='', params=params)
+        next_h, states = rnn(in1, states)
+        return (next_h, states)
+
+    def sym_group(out):
+        if (isinstance(out[0], mx.sym.Symbol)):
+            ret = [out[0]]
+        else:
+            ret = out[0]
+        ret.extend(out[1])
+        return mx.sym.Group(ret)
+
+    rnn = cell_type(hidden_dim, prefix='', params=params)
+    if num_states == 2:
+        init_states = [mx.sym.var("h"), mx.sym.var("c")]
+    else:
+        init_states = [mx.sym.var("h")]
+    out = mx.sym.contrib.foreach(step, data, init_states)
+    out = sym_group(out)
+    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(seq_len, batch_size, input_dim),
+                                                         h=(batch_size, hidden_dim))
+    rnn_inputs = out.list_inputs()
+
+    # Inputs
+    args1 = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+    args2 = copy.deepcopy(args1)
+    # gradients for the backward of the foreach symbol
+    args_grad1 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+    # gradients for the backward of the unrolled symbol.
+    args_grad2 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+
+    # Symbol of running LSTM with foreach.
+    out = mx.sym.contrib.foreach(step, data, init_states)
+    out = sym_group(out)
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+    e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1)
+
+    # Symbol of running unrolled LSTM.
+    lstm = cell_type(hidden_dim, prefix='')
+    unroll_outs = []
+    states = init_states
+    for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True):
+        h, states = lstm(inputs, states)
+        unroll_outs.append(mx.sym.expand_dims(h, axis=0))
+    unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
+    unroll_outs.extend(states)
+    out = mx.sym.Group(unroll_outs)
+    js_1 = out.tojson()
+    out = mx.sym.load_json(js_1)
+    js_2 = out.tojson()
+    assert js_1 == js_2
+    e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2)
+
+    for i in range(5):
+        out_grads = []
+        for arr in e1.outputs:
+            out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))
+
+        args = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
+
+        e1.forward(is_train=True, **args)
+        outputs1 = e1.outputs
+        e1.backward(out_grads)
+
+        e2.forward(is_train=True, **args)
+        outputs2 = e2.outputs
+        e2.backward(out_grads)
+
+        for i in range(len(outputs2)):
+            assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        input_names = out.list_inputs()
+        for i in range(len(e1.grad_arrays)):
+            name = input_names[i]
+            assert_almost_equal(args_grad1[name].asnumpy(), args_grad2[name].asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_foreach_rnn():
+    cell_types = [(mx.rnn.LSTMCell, 2), (mx.rnn.RNNCell, 1), (mx.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_foreach_rnn(cell_type, num_states)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py
index cef886b95bba..a9a2904e1e13 100644
--- a/tests/python/unittest/test_gluon_rnn.py
+++ b/tests/python/unittest/test_gluon_rnn.py
@@ -39,67 +39,6 @@ def test_rnn():
     assert outs == [(10, 100), (10, 100), (10, 100)]
 
 
-class TestRNNLayer(gluon.HybridBlock):
-    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
-        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
-        self.cell = cell_type(hidden_size, prefix='rnn_')
-
-    def hybrid_forward(self, F, inputs, states):
-        out, states = F.contrib.foreach(self.cell, inputs, states)
-        return out
-
-def check_contrib_rnn(cell_type, num_states):
-    batch_size = 10
-    hidden_size = 100
-    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(5, batch_size, 50))
-    state_shape = (batch_size, hidden_size)
-    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for i in range(num_states)]
-    layer = TestRNNLayer(cell_type, hidden_size)
-    layer.initialize(ctx=default_context())
-    res1 = layer(rnn_data, states)
-    params1 = layer.collect_params()
-    orig_params1 = copy.deepcopy(params1)
-
-    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
-    with mx.autograd.record():
-        res1 = layer(rnn_data, states)
-    res1.backward()
-    trainer.step(batch_size)
-
-    configs = [
-            {},
-            {'static_alloc': True},
-            {'static_alloc': True, 'static_shape': True} ]
-    for config in configs:
-        layer = TestRNNLayer(cell_type, hidden_size)
-        layer.initialize(ctx=default_context())
-        layer.hybridize(**config)
-        res2 = layer(rnn_data, states)
-        params2 = layer.collect_params()
-        for key, val in orig_params1.items():
-            params2[key].set_data(copy.deepcopy(val.data()))
-
-        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
-        with mx.autograd.record():
-            res2 = layer(rnn_data, states)
-        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
-        res2.backward()
-        trainer.step(batch_size)
-
-        for key, val in params1.items():
-            weight1 = val.data()
-            weight2 = params2[key].data()
-            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
-                                rtol=0.001, atol=0.0001)
-
-
-def test_contrib_rnn():
-    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
-                  (gluon.rnn.GRUCell, 1)]
-    for cell_type, num_states in cell_types:
-        check_contrib_rnn(cell_type, num_states)
-
-
 def test_lstm():
     cell = gluon.rnn.LSTMCell(100, prefix='rnn_')
     inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)]
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 0592e5a93291..59311fccd1db 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -6016,477 +6016,6 @@ def test_float16_min_max():
     assert np.finfo('float16').max == mx.nd.max(a).asscalar()
 
 
-@with_seed()
-def test_foreach():
-    v3 = mx.sym.var("v0")
-    v4 = mx.sym.var("v1")
-    v5 = mx.sym.var("v2")
-    v6 = mx.sym.var("v3")
-    v7 = mx.sym.var("v4")
-    v8 = mx.sym.var("v5")
-
-    def verify_foreach(step, in_syms, state_syms, free_syms,
-                       in_arrs, init_states, frees, out_grads, is_train=True,
-                       free_vars_func=None, num_iters=1):
-        step_sym = lambda in_syms, state_syms : step(in_syms, state_syms, free_syms)
-        res, states = mx.sym.contrib.foreach(step_sym, in_syms, state_syms)
-        out = _as_list(res)
-        num_outputs = len(out)
-        for i in range(num_outputs):
-            out[i] = out[i] * 2
-        out.extend(states)
-        out = mx.sym.Group(out)
-        js_1 = out.tojson()
-        out = mx.sym.load_json(js_1)
-        js_2 = out.tojson()
-        assert js_1 == js_2
-        arr_grads = []
-        arg_dict = {}
-        arg_grad_dict = {}
-        i = 0
-        for arr in _as_list(in_arrs):
-            arr_grad = mx.nd.empty(arr.shape)
-            arr_grads.append(arr_grad)
-            arg_dict['v'+str(i)] = arr
-            arg_grad_dict['v'+str(i)] = arr_grad
-            i = i + 1
-        for arr in init_states:
-            arr_grad = mx.nd.empty(arr.shape)
-            arr_grads.append(arr_grad)
-            arg_dict['v'+str(i)] = arr
-            arg_grad_dict['v'+str(i)] = arr_grad
-            i = i + 1
-        for arr in frees:
-            arr_grad = mx.nd.empty(arr.shape)
-            arr_grads.append(arr_grad)
-            arg_dict['v'+str(i)] = arr
-            arg_grad_dict['v'+str(i)] = arr_grad
-            i = i + 1
-
-        if is_train:
-            e = out.bind(ctx=default_context(), args=arg_dict, args_grad=arg_grad_dict)
-        else:
-            e = out.bind(ctx=default_context(), args=arg_dict)
-        # the inputs to forward and backward are the same so forward and backward
-        # should always return the same outputs.
-        for i in range(num_iters):
-            e.forward(is_train=is_train)
-            if (is_train):
-                # backward
-                tmp_grads = out_grads[0][:]
-                tmp_grads.extend(out_grads[1])
-                e.backward(tmp_grads)
-
-        # Below we use imperative to reimplement foreach and compute its gradients.
-        res = []
-        for i in range(len(_as_list(out_grads[0]))):
-            res.append([])
-        for arr in _as_list(in_arrs):
-            arr.attach_grad()
-        for arr in init_states:
-            arr.attach_grad()
-        for arr in frees:
-            arr.attach_grad()
-        with mx.autograd.record():
-            frees_imp = frees if free_vars_func is None else free_vars_func(frees)
-            step_imp = lambda in_arrs, state_arrs : step(in_arrs, state_arrs, frees_imp)
-            states = [mx.nd.expand_dims(s, 0) for s in init_states]
-            res, states = mx.nd.contrib.foreach(step_imp, in_arrs, init_states)
-
-            res2 = _as_list(res)
-            for i in range(len(res2)):
-                res2[i] = res2[i] * 2
-            outs = []
-            outs[:] = res2[:]
-            if isinstance(states, list):
-                outs.extend(states)
-                states = [mx.nd.expand_dims(s, 0) for s in states]
-                res2.extend(states)
-            else:
-                outs.append(states)
-                states = mx.nd.expand_dims(states, 0)
-                res2.append(states)
-            if is_train:
-                res = mx.nd.concat(*res2, dim=0)
-
-        tmp_grads = out_grads[0][:]
-        tmp_grads1 = [mx.nd.expand_dims(grad, 0) for grad in out_grads[1]]
-        tmp_grads.extend(tmp_grads1)
-        if is_train:
-            res.backward(mx.nd.concat(*tmp_grads, dim=0))
-        for i in range(len(outs)):
-            assert e.outputs[i].shape == outs[i].shape
-            assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(),
-                                rtol=0.001, atol=0.0001)
-        if (is_train):
-            all_ins = _as_list(in_arrs)[:]
-            all_ins.extend(init_states)
-            all_ins.extend(frees)
-            size = min(len(all_ins), len(e.grad_arrays))
-            for i in range(size):
-                assert_almost_equal(all_ins[i].grad.asnumpy(),
-                                    e.grad_arrays[i].asnumpy(),
-                                    rtol=0.001, atol=0.0001)
-
-    # Test cases:
-    # * graph inputs are stored in different orders.
-    #   This is to test if foreach finds the data arrays and weight arrays
-    #   in the right location.
-    # * the number of iterations: odd or even.
-    # * multiple inputs and multiple outputs.
-    # * inference.
-    def step1(in1, states, free):
-        out = in1 * 2 + states[0] + free[0]
-        return (out, [out])
-    frees1 = [mx.nd.arange(2), mx.nd.arange(2) + 1]
-    arrs = mx.nd.arange(6).reshape(shape=(3, 2))
-    states = [mx.nd.arange(2)]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
-                   lambda frees : [frees[0] + frees[1]])
-    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
-                   lambda frees : [frees[0] + frees[1]])
-    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, True,
-                   lambda frees : [frees[0] + frees[1]], 5)
-    verify_foreach(step1, v3, [v4], [v5 + v6], arrs, states, frees1, out_grads, False,
-                   lambda frees : [frees[0] + frees[1]], 5)
-
-    # Test the even number of iterations.
-    frees = [mx.nd.random.uniform(shape=(2))]
-    arrs = mx.nd.random.uniform(shape=(2, 2))
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-    # Test the odd number of iterations
-    arrs = mx.nd.random.uniform(shape=(3, 2))
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step1, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-
-    # Reorder the input and state in the subgraph inputs.
-    def step2(in1, states, free):
-        out = states[0] + in1 * 2 + free[0]
-        return (out, [out])
-    # Test the even number of iterations.
-    arrs = mx.nd.random.uniform(shape=(2, 2))
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-    # Test the odd number of iterations.
-    arrs = mx.nd.random.uniform(shape=(3, 2))
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step2, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-
-    # Test multiple inputs and outputs.
-    def step3(in1, states, free):
-        out = in1[0] + in1[1] * 2 + states[0] + states[1] * 2 + free[0]
-        return ([out, out], [out * 2, out * 3])
-    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
-    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
-    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
-    verify_foreach(step3, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
-
-    # Test multiple inputs and outputs.
-    # The order of subgraph inputs doesn't match the operator inputs
-    def step4(in1, states, free):
-        out = in1[1] * 2 + states[0] + free[0] + states[1] * 2 + in1[0]
-        return ([out, out * 2], [out * 2, out * 3])
-    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
-    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[1].shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
-    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads)
-    verify_foreach(step4, [v3, v4], [v5, v6], [v7], arrs, states, frees, out_grads, False)
-
-    # Test multiple inputs and outputs.
-    # The data inputs and states have different shapes.
-    def step5(in1, states, free):
-        if isinstance(in1[0], mx.nd.NDArray):
-            out1 = mx.nd.broadcast_add(states[0] + free[1], in1[1] * 2)
-            out2 = mx.nd.broadcast_add(in1[0], free[0] + states[1] * 2)
-        else:
-            out1 = mx.sym.broadcast_add(states[0] + free[1], in1[1] * 2)
-            out2 = mx.sym.broadcast_add(in1[0], free[0] + states[1] * 2)
-        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
-    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2, 2))]
-    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)), mx.nd.random.uniform(shape=(3, 2))]
-    states = [mx.nd.random.uniform(shape=(2, 2)), mx.nd.random.uniform(shape=(2))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
-    verify_foreach(step5, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
-
-    # Test multiple inputs and outputs.
-    # The data inputs and states have different shapes and data types.
-    def step6(in1, states, free):
-        if isinstance(in1[0], mx.nd.NDArray):
-            out1 = mx.nd.broadcast_add(states[0] + mx.nd.cast(free[1], 'float32'),
-                                       mx.nd.cast(in1[1], 'float32') * 2)
-            out2 = mx.nd.broadcast_add(in1[0],
-                                       free[0] + mx.nd.cast(states[1], 'float32') * 2)
-        else:
-            out1 = mx.sym.broadcast_add(states[0] + mx.sym.cast(free[1], 'float32'),
-                                        mx.sym.cast(in1[1], 'float32') * 2)
-            out2 = mx.sym.broadcast_add(in1[0],
-                                        free[0] + mx.sym.cast(states[1], 'float32') * 2)
-        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
-    frees = [mx.nd.random.uniform(shape=(2)),
-             mx.nd.cast(mx.nd.random.uniform(shape=(2, 2)), 'float64')]
-    arrs = [mx.nd.random.uniform(shape=(3, 2, 2)),
-            mx.nd.cast(mx.nd.random.uniform(shape=(3, 2)), dtype='float16')]
-    states = [mx.nd.random.uniform(shape=(2, 2)),
-              mx.nd.cast(mx.nd.random.uniform(shape=(2)), dtype='int32')]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
-    verify_foreach(step6, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
-
-    # Test multiple inputs and outputs.
-    # some of the inputs are used twice.
-    def step7(in1, states, free):
-        out1 = states[0] + in1[0] + free[1] + in1[1] * 2 + free[0]
-        out2 = in1[0] + free[0] + states[1] * 2 + in1[1]
-        return ([out1, out2 * 2], [states[0] * 2, states[1] * 3])
-    frees = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
-    arrs = [mx.nd.random.uniform(shape=(3, 2)), mx.nd.random.uniform(shape=(3, 2))]
-    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs[0].shape), mx.nd.random.uniform(-10, 10, arrs[0].shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape), mx.nd.random.uniform(-10, 10, states[1].shape)]]
-    verify_foreach(step7, [v3, v4], [v5, v6], [v7, v8], arrs, states, frees, out_grads, False)
-
-    # Test the case that the output is the input.
-    arrs = mx.nd.random.uniform(shape=(3, 2))
-    states = [mx.nd.arange(2)]
-    frees = [mx.nd.random.uniform(shape=(2))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, arrs.shape)],
-                 [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    def step8(in1, states, free):
-        return (in1, [states[0] * free[0]])
-    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step8, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-    def step9(in1, states, free):
-        return (in1 * free[0], states)
-    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step9, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-
-    # Test the case that not all inputs are used.
-    def step10(in1, states, free):
-        return (in1, states)
-    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step10, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-    def step11(in1, states, free):
-        return (in1, free)
-    try:
-        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads)
-        verify_foreach(step11, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-    except AssertionError:
-        print("the states have to be used")
-    def step12(in1, states, free):
-        return (in1, [states[0] + 1, states[0] + 2])
-    states = [mx.nd.random.uniform(shape=(2)), mx.nd.random.uniform(shape=(2))]
-    frees = []
-    try:
-        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads)
-        verify_foreach(step12, v3, [v4, v5], [], arrs, states, frees, out_grads, False)
-    except AssertionError:
-        print("the states have to be used")
-
-    # test without free variables.
-    def step13(in1, states, free):
-        return (in1, states)
-    states = [mx.nd.random.uniform(shape=(2))]
-    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads)
-    verify_foreach(step13, v3, [v4], [], arrs, states, [], out_grads, False)
-
-    # test when there isn't output data or output states.
-    def step14(in1, states, free):
-        return (in1 + free[0], [])
-    frees = [mx.nd.random.uniform(shape=(2))]
-    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads)
-    verify_foreach(step14, v3, [], [v4], arrs, [], frees, out_grads, False)
-    def step15(in1, states, free):
-        return ([], [in1 * states[0] * free[0]])
-    out_grads = [[], [mx.nd.random.uniform(-10, 10, states[0].shape)]]
-    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads)
-    verify_foreach(step15, v3, [v4], [v5], arrs, states, frees, out_grads, False)
-
-    # Test the case of iterating on a 1D data array.
-    def step16(in1, states, free):
-        return ([in1[0] * states[0]], [states[0] * 2])
-    arrs = [mx.nd.arange(3)]
-    states = [mx.nd.random.uniform(shape=(1))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
-                 [mx.nd.random.uniform(-10, 10, (1))]]
-    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads)
-    verify_foreach(step16, [v3], [v4], [], arrs, states, [], out_grads, False)
-    def step17(in1, states, free):
-        return ([in1[1] * in1[0] * states[0]], [states[0] * 2])
-    arrs = [mx.nd.random.uniform(shape=(3, 1)), mx.nd.arange(3)]
-    states = [mx.nd.random.uniform(shape=(1))]
-    out_grads = [[mx.nd.random.uniform(-10, 10, (3, 1))],
-                 [mx.nd.random.uniform(-10, 10, (1))]]
-    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads)
-    verify_foreach(step17, [v3, v4], [v5], [], arrs, states, [], out_grads, False)
-
-
-@with_seed()
-def test_foreach_nested():
-    # Test nested foreach.
-    def step_in(in1, states):
-        out = in1 * 2 + states[0]
-        return (out, [out])
-
-    def step_sym(in1, states):
-        out1 = mx.sym.contrib.foreach(step_in, in1, states)
-        out = mx.sym.broadcast_add(out1[0], states[0])
-        return (out, [mx.sym.squeeze(mx.sym.slice(out, begin=(0, 0), end=(1, 2)))])
-    def step_nd(in1, states):
-        out1 = mx.nd.contrib.foreach(step_in, in1, states)
-        out = mx.nd.broadcast_add(out1[0], states[0])
-        return (out, [mx.nd.squeeze(mx.nd.slice(out, begin=(0, 0), end=(1, 2)))])
-
-    data_sym = mx.sym.var("v1")
-    state_sym = mx.sym.var("v2")
-    out, states = mx.sym.contrib.foreach(step_sym, data_sym, [state_sym])
-    assert isinstance(states, list)
-    assert len(states) == 1
-    out = mx.sym.broadcast_add(out, states[0])
-
-    js_1 = out.tojson()
-    out = mx.sym.load_json(js_1)
-    js_2 = out.tojson()
-    assert js_1 == js_2
-
-    data = mx.nd.arange(8).reshape((2, 2, 2))
-    state = mx.nd.arange(2)
-    data_grad = mx.nd.empty(data.shape)
-    state_grad = mx.nd.empty(state.shape)
-    e = out.bind(ctx=default_context(), args={'v1':data, 'v2':state},
-                 args_grad={'v1':data_grad, 'v2':state_grad})
-    e.forward(is_train=True)
-    out_grads = []
-    for out in e.outputs:
-        out_grads.append(mx.nd.random.uniform(shape=out.shape))
-    e.backward(out_grads)
-
-    data.attach_grad()
-    state.attach_grad()
-    with mx.autograd.record():
-        out, states = mx.nd.contrib.foreach(step_nd, data, [state])
-        assert isinstance(states, list)
-        assert len(states) == 1
-        res = mx.nd.broadcast_add(out, states[0])
-    assert_almost_equal(res.asnumpy(), e.outputs[0].asnumpy(), rtol=0.001, atol=0.0001)
-
-    res.backward(out_grads[0])
-    assert_almost_equal(data.grad.asnumpy(), data_grad.asnumpy())
-    assert_almost_equal(state.grad.asnumpy(), state_grad.asnumpy())
-
-
-def check_foreach_rnn(cell_type, num_states):
-    data = mx.sym.var("data")
-    params = mx.rnn.RNNParams()
-    hidden_dim = 4
-    input_dim = 5
-    seq_len = 2
-    batch_size = 2
-
-    # This tests foreach with accumulation sum.
-    def step(in1, states):
-        rnn = cell_type(hidden_dim, prefix='', params=params)
-        next_h, states = rnn(in1, states)
-        return (next_h, states)
-
-    def sym_group(out):
-        if (isinstance(out[0], mx.sym.Symbol)):
-            ret = [out[0]]
-        else:
-            ret = out[0]
-        ret.extend(out[1])
-        return mx.sym.Group(ret)
-
-    rnn = cell_type(hidden_dim, prefix='', params=params)
-    if num_states == 2:
-        init_states = [mx.sym.var("h"), mx.sym.var("c")]
-    else:
-        init_states = [mx.sym.var("h")]
-    out = mx.sym.contrib.foreach(step, data, init_states)
-    out = sym_group(out)
-    arg_shapes, out_shapes, aux_shapes = out.infer_shape(data=(seq_len, batch_size, input_dim),
-                                                         h=(batch_size, hidden_dim))
-    rnn_inputs = out.list_inputs()
-
-    # Inputs
-    args1 = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
-    args2 = copy.deepcopy(args1)
-    # gradients for the backward of the foreach symbol
-    args_grad1 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
-    # gradients for the backward of the unrolled symbol.
-    args_grad2 = {name:mx.nd.empty(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
-
-    # Symbol of running LSTM with foreach.
-    out = mx.sym.contrib.foreach(step, data, init_states)
-    out = sym_group(out)
-    js_1 = out.tojson()
-    out = mx.sym.load_json(js_1)
-    js_2 = out.tojson()
-    assert js_1 == js_2
-    e1 = out.bind(ctx=default_context(), args=args1, args_grad=args_grad1)
-
-    # Symbol of running unrolled LSTM.
-    lstm = cell_type(hidden_dim, prefix='')
-    unroll_outs = []
-    states = init_states
-    for inputs in mx.sym.split(data, num_outputs=seq_len, axis=0, squeeze_axis=True):
-        h, states = lstm(inputs, states)
-        unroll_outs.append(mx.sym.expand_dims(h, axis=0))
-    unroll_outs = _as_list(mx.sym.concat(*unroll_outs, dim=0))
-    unroll_outs.extend(states)
-    out = mx.sym.Group(unroll_outs)
-    js_1 = out.tojson()
-    out = mx.sym.load_json(js_1)
-    js_2 = out.tojson()
-    assert js_1 == js_2
-    e2 = out.bind(ctx=default_context(), args=args2, args_grad=args_grad2)
-
-    for i in range(5):
-        out_grads = []
-        for arr in e1.outputs:
-            out_grads.append(mx.nd.random.uniform(-10, 10, arr.shape))
-
-        args = {name:mx.nd.random.uniform(shape=arg_shapes[i]) for i, name in enumerate(rnn_inputs)}
-
-        e1.forward(is_train=True, **args)
-        outputs1 = e1.outputs
-        e1.backward(out_grads)
-
-        e2.forward(is_train=True, **args)
-        outputs2 = e2.outputs
-        e2.backward(out_grads)
-
-        for i in range(len(outputs2)):
-            assert_almost_equal(outputs1[i].asnumpy(), outputs2[i].asnumpy(),
-                                rtol=0.001, atol=0.0001)
-        input_names = out.list_inputs()
-        for i in range(len(e1.grad_arrays)):
-            name = input_names[i]
-            assert_almost_equal(args_grad1[name].asnumpy(), args_grad2[name].asnumpy(),
-                                rtol=0.001, atol=0.0001)
-
-
-@with_seed()
-def test_foreach_rnn():
-    cell_types = [(mx.rnn.LSTMCell, 2), (mx.rnn.RNNCell, 1), (mx.rnn.GRUCell, 1)]
-    for cell_type, num_states in cell_types:
-        check_foreach_rnn(cell_type, num_states)
-
-
 @with_seed()
 def test_squeeze_op():
     def check_squeeze_op(shape, axis=None):

From 0cc2c7b39406744dc7a3832f7ba9cd771a311e4e Mon Sep 17 00:00:00 2001
From: Da Zheng
Date: Fri, 20 Jul 2018 01:04:04 -0700
Subject: [PATCH 5/5] fix test: add missing copy import and with_seed.

The tests moved into test_contrib_control_flow.py use copy.deepcopy,
so import copy there, and run test_contrib_rnn under with_seed() like
the other tests in this file.
---
 tests/python/unittest/test_contrib_control_flow.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 0d85c41c0a0c..83eebecbf670 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import copy
 import numpy as np
 import mxnet as mx
 from mxnet import gluon
@@ -1029,6 +1030,7 @@ def check_contrib_rnn(cell_type, num_states):
                                 rtol=0.001, atol=0.0001)
 
 
+@with_seed()
 def test_contrib_rnn():
     cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
                   (gluon.rnn.GRUCell, 1)]