From c2cfa6c2edad9fa68c755953bfa295d98b783970 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Sun, 9 Jun 2019 16:43:33 -0700
Subject: [PATCH 1/3] Fix MxNet RNN without giving states as input

---
 python/tvm/relay/frontend/mxnet.py          | 48 ++++++++++++++++++---
 tests/python/frontend/mxnet/test_forward.py | 48 +++++++++++++--------
 2 files changed, 72 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 0975a33450c8..f48b0f74f266 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -93,6 +93,15 @@ def impl(inputs, attrs):
     return impl
 
 
+def _mx_zeros(inputs, attrs):
+    assert len(inputs) == 0
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_str("dtype", "float32")
+    if 0 in shape:
+        return None
+    return _op.zeros(shape=shape, dtype=dtype)
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
@@ -753,9 +762,30 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
 
     seq_data = inputs[0]
     concat_weight = inputs[1]
-    concat_states = inputs[2:]
-    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    init_states = inputs[2:]
+
+    data_shape = ir_pass.infer_type(seq_data).checked_type.shape
+    seq_len = int(data_shape[0])
     assert len(concat_weight) == num_layers * 4
+    output_states = True
+    for idx, state in enumerate(init_states[:]):
+        if isinstance(state, tuple) and isinstance(state[1], dict):
+            nid, node = state
+            attrs = StrAttrsDict(node.get("attrs", {}))
+            op_name = node["op"]
+            assert op_name == "_zeros"
+            shape = attrs.get_int_tuple("shape")
+            dtype = attrs.get_str("dtype", "float32")
+            init_layout = attrs.get_str("__layout__")
+            print(shape)
+            new_shape = list(shape)
+            for i, dim in enumerate(shape):
+                if dim == 0:
+                    axis = layout.find(init_layout[i])
+                    assert axis >= 0
+                    new_shape[i] = int(data_shape[axis])
+            init_states[idx] = _op.zeros(new_shape, dtype)
+            output_states = False
 
     weights = []
     bias = []
@@ -767,7 +797,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
         for j in range(2):
             w.append(concat_weight[i*2 + j].args[0])
             b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
-        for state in concat_states:
+        for state in init_states:
             s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
         weights.append(w)
         bias.append(b)
@@ -788,8 +818,9 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
         seq_output.append(out)
 
     outputs = [_op.stack(seq_output, axis=0)]
-    for i in range(num_states):
-        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    if output_states:
+        for i in range(num_states):
+            outputs.append(_op.stack([s[i] for s in states], axis=0))
     return outputs
 
 
@@ -880,7 +911,6 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     "argmin"        : _arg_reduce(_op.argmin),
     # init ops
     "_ones"         : _init_op(_op.ones),
-    "_zeros"        : _init_op(_op.zeros),
     # softmax
     "softmax"       : _softmax_op(_op.nn.softmax),
    "log_softmax"   : _softmax_op(_op.nn.log_softmax),
@@ -894,6 +924,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     "UpSampling"    : _upsampling,
     "add_n"         : _elemwise_sum,
     # MXNet specific implementations
+    "_zeros"        : _mx_zeros,
     "FullyConnected": _mx_fully_connected,
     "Activation"    : _mx_activations,
     "Convolution"   : _mx_conv2d,
@@ -1001,7 +1032,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         elif op_name in _convert_map:
             res = _convert_map[op_name](children, attrs)
-            if isinstance(res, (_expr.TupleWrapper, tuple, list)):
+            if res is None:
+                # defer conversion, used in RNN state initialization
+                res = [(nid, node)]
+            elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
                 pass
             elif isinstance(res, _expr.Expr):
                 res = [res]
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 7569257830af..8d7c15bb0be5 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
         if mode == "rnn":
             layer = gluon.rnn.RNN(hidden_size, num_layers)
         elif mode == "gru":
@@ -545,23 +545,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
             layer = gluon.rnn.LSTM(hidden_size, num_layers)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
+        layer.hybridize()
         dtype = "float32"
+        batch = 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
-        states_np = []
-        states_mx = []
-        shape_dict = {'data0': data_np.shape}
-        inputs = {'data0': data_np}
-        for i in range(num_states):
-            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
-            states_np.append(s)
-            states_mx.append(mx.nd.array(s))
-            shape_dict['data%s' % (i+1)] = s.shape
-            inputs['data%s' % (i+1)] = s
+        data_mx = mx.nd.array(data_np)
+
+        if init_states:
+            shape_dict = {'data0': data_np.shape}
+            inputs = {'data0': data_np}
+            states_np = []
+            states_mx = []
+            for i in range(num_states):
+                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                states_np.append(s)
+                states_mx.append(mx.nd.array(s))
+                shape_dict['data%s' % (i+1)] = s.shape
+                inputs['data%s' % (i+1)] = s
+            mx_out, mx_states = layer(data_mx, states_mx)
+            mx_res = [mx_out] + mx_states
+        else:
+            shape_dict = {'data': data_np.shape}
+            inputs = {'data': data_np}
+            mx_res = layer(data_mx)
 
-        layer.hybridize()
-        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
-        mx_res = [mx_out] + mx_states
 
         mx_sym = layer._cached_graph[1]
         mx_params = {}
         for name, param in layer.collect_params().items():
@@ -574,14 +582,20 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
         for kind in ["graph"]:
             intrp = relay.create_executor(kind, ctx=ctx, target=target)
             op_res = intrp.evaluate(new_sym)(**inputs, **params)
-            assert len(op_res) == len(mx_res)
-            for i, val in enumerate(op_res):
-                tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            if init_states:
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(
+                        val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            else:
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
         verify(mode, 64, 10, 64, 1)
         verify(mode, 64, 10, 64, 2)
         verify(mode, 64, 10, 32, 2)
+        verify(mode, 64, 10, 64, 2, init_states=False)
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
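A note on the mechanism, since it spans two files: _mx_zeros returns None whenever the _zeros node carries a 0 dimension (MXNet's placeholder for a shape that is only known once real data flows through, typically the batch axis). _from_mxnet_impl then records the raw JSON graph node instead of a Relay expression, and the RNN converter materializes the state later by matching the layout symbol recorded in the node's __layout__ attribute (e.g. "LNC") against the data layout (e.g. "TNC") and copying the missing dimension from the inferred data shape. Below is a minimal standalone sketch of that resolution step; the function name and sample values are illustrative only, not code from the patch:

    # Sketch of the 0-dim resolution done inside the RNN converter above.
    # resolve_state_shape is a hypothetical name; in the patch this logic is
    # inlined, and data_shape comes from ir_pass.infer_type on the sequence.
    def resolve_state_shape(state_shape, state_layout, data_shape, data_layout):
        new_shape = list(state_shape)
        for i, dim in enumerate(state_shape):
            if dim == 0:
                # find the data axis carrying the same layout symbol, e.g. 'N'
                axis = data_layout.find(state_layout[i])
                assert axis >= 0, "state layout symbol missing from data layout"
                new_shape[i] = int(data_shape[axis])
        return new_shape

    # LSTM state placeholder (num_layers, batch, hidden) with unknown batch:
    print(resolve_state_shape((2, 0, 64), "LNC", (10, 1, 32), "TNC"))  # [2, 1, 64]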
From aa1d1e0109fe80b544c6efbced99f14ea759a318 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Sun, 9 Jun 2019 19:47:54 -0700
Subject: [PATCH 2/3] lint

---
 python/tvm/relay/frontend/mxnet.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index f48b0f74f266..dcde515ef3d7 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -769,8 +769,8 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     assert len(concat_weight) == num_layers * 4
     output_states = True
     for idx, state in enumerate(init_states[:]):
-        if isinstance(state, tuple) and isinstance(state[1], dict):
-            nid, node = state
+        if isinstance(state, dict):
+            node = state
             attrs = StrAttrsDict(node.get("attrs", {}))
             op_name = node["op"]
             assert op_name == "_zeros"
@@ -1034,7 +1034,7 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
             res = _convert_map[op_name](children, attrs)
             if res is None:
                 # defer conversion, used in RNN state initialization
-                res = [(nid, node)]
+                res = [node]
             elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
                 pass
             elif isinstance(res, _expr.Expr):
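The lint patch above also tightens the deferred-node contract: a deferred state now reaches the RNN converter as the raw JSON node itself, a plain dict, while an explicitly supplied state arrives as a converted Relay expression, so a single isinstance check tells the two apart. A small sketch of the two forms, with hypothetical attribute values (MXNet stores attrs as strings in the graph JSON):

    # Hypothetical examples of the two forms a state input can take.
    deferred_state = {               # raw MXNet JSON node; conversion deferred
        "op": "_zeros",
        "attrs": {"shape": "(1, 0, 64)", "dtype": "float32", "__layout__": "LNC"},
    }
    explicit_state = object()        # stand-in for a converted Relay expression

    for state in (deferred_state, explicit_state):
        if isinstance(state, dict):          # the check introduced by this series
            print("defer:", state["op"])     # zeros get materialized later
        else:
            print("use state as given")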
From da2579ba244a7252fae7d13c8125e283f716e9a5 Mon Sep 17 00:00:00 2001
From: Haichen Shen
Date: Mon, 10 Jun 2019 13:49:38 -0700
Subject: [PATCH 3/3] remove print

---
 python/tvm/relay/frontend/mxnet.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index dcde515ef3d7..239f19f486dc 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -773,11 +773,11 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
             node = state
             attrs = StrAttrsDict(node.get("attrs", {}))
             op_name = node["op"]
+            # by default, RNN layer uses zeros to initialize states
             assert op_name == "_zeros"
             shape = attrs.get_int_tuple("shape")
             dtype = attrs.get_str("dtype", "float32")
             init_layout = attrs.get_str("__layout__")
-            print(shape)
             new_shape = list(shape)
             for i, dim in enumerate(shape):
                 if dim == 0:
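Taken together, the series lets a hybridized Gluon RNN be imported without feeding initial states. A sketch of that end-to-end path, mirroring the new no-states branch of test_forward_rnn_layer; the private Gluon attributes _cached_graph and _reduce are used exactly as the test uses them, and the shapes are sample values:

    import numpy as np
    import mxnet as mx
    from mxnet import gluon
    from tvm import relay

    layer = gluon.rnn.LSTM(64, 2)          # hidden_size=64, num_layers=2
    layer.initialize()
    layer.hybridize()

    data_np = np.random.uniform(size=(10, 1, 64)).astype("float32")
    layer(mx.nd.array(data_np))            # run once so the graph gets cached

    mx_sym = layer._cached_graph[1]        # states default to _zeros nodes here
    mx_params = {}
    for name, param in layer.collect_params().items():
        mx_params[name] = param._reduce()

    new_sym, params = relay.frontend.from_mxnet(
        mx_sym, shape={"data": data_np.shape}, arg_params=mx_params)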