[Relay][Frontend] Fix MxNet RNN without providing state initialization as input #3326

Merged · 3 commits · Jun 12, 2019
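For context, the scenario this PR fixes is converting a hybridized Gluon RNN layer that was called without explicit begin states, so MXNet records `_zeros` ops (with a zero-sized batch dimension) as the state initializers in the cached symbol. Below is a minimal sketch of that flow, pieced together from the updated test in this PR; the exact `from_mxnet` keyword arguments and the use of `param.data()` for collecting weights are assumptions, not copied from this page.

```python
import numpy as np
import mxnet as mx
from mxnet import gluon
from tvm import relay

# Hybridized LSTM called without begin states: MXNet inserts _zeros ops
# for the initial hidden/cell states into the cached symbol.
layer = gluon.rnn.LSTM(hidden_size=64, num_layers=2)
layer.initialize()
layer.hybridize()

data_np = np.random.uniform(size=(10, 1, 64)).astype("float32")  # (seq_len, batch, input_size)
out = layer(mx.nd.array(data_np))  # single output, no state list returned

# Convert the cached symbol; before this fix the zero-sized state shapes
# broke the Relay MXNet frontend.
mx_sym = layer._cached_graph[1]
mx_params = {name: param.data() for name, param in layer.collect_params().items()}
new_sym, params = relay.frontend.from_mxnet(
    mx_sym, shape={"data": data_np.shape}, dtype="float32", arg_params=mx_params)
```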
48 changes: 41 additions & 7 deletions python/tvm/relay/frontend/mxnet.py
@@ -93,6 +93,15 @@ def impl(inputs, attrs):
return impl


def _mx_zeros(inputs, attrs):
assert len(inputs) == 0
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
if 0 in shape:
return None
return _op.zeros(shape=shape, dtype=dtype)


def _mx_conv2d(inputs, attrs):
kernel_size = attrs.get_int_tuple("kernel")
if len(kernel_size) != 2:
@@ -753,9 +762,30 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):

seq_data = inputs[0]
concat_weight = inputs[1]
concat_states = inputs[2:]
seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
init_states = inputs[2:]

data_shape = ir_pass.infer_type(seq_data).checked_type.shape
seq_len = int(data_shape[0])
assert len(concat_weight) == num_layers * 4
output_states = True
for idx, state in enumerate(init_states[:]):
if isinstance(state, dict):
node = state
attrs = StrAttrsDict(node.get("attrs", {}))
op_name = node["op"]
# by default, RNN layer uses zeros to initialize states
assert op_name == "_zeros"
shape = attrs.get_int_tuple("shape")
dtype = attrs.get_str("dtype", "float32")
init_layout = attrs.get_str("__layout__")
new_shape = list(shape)
for i, dim in enumerate(shape):
if dim == 0:
axis = layout.find(init_layout[i])
assert axis >= 0
new_shape[i] = int(data_shape[axis])
init_states[idx] = _op.zeros(new_shape, dtype)
output_states = False

weights = []
bias = []
@@ -767,7 +797,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
for j in range(2):
w.append(concat_weight[i*2 + j].args[0])
b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
for state in concat_states:
for state in init_states:
s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
weights.append(w)
bias.append(b)
@@ -788,8 +818,9 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
seq_output.append(out)

outputs = [_op.stack(seq_output, axis=0)]
for i in range(num_states):
outputs.append(_op.stack([s[i] for s in states], axis=0))
if output_states:
for i in range(num_states):
outputs.append(_op.stack([s[i] for s in states], axis=0))
return outputs


@@ -880,7 +911,6 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"argmin" : _arg_reduce(_op.argmin),
# init ops
"_ones" : _init_op(_op.ones),
"_zeros" : _init_op(_op.zeros),
# softmax
"softmax" : _softmax_op(_op.nn.softmax),
"log_softmax" : _softmax_op(_op.nn.log_softmax),
@@ -894,6 +924,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"UpSampling" : _upsampling,
"add_n" : _elemwise_sum,
# MXNet specific implementations
"_zeros" : _mx_zeros,
"FullyConnected": _mx_fully_connected,
"Activation" : _mx_activations,
"Convolution" : _mx_conv2d,
@@ -1001,7 +1032,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
elif op_name in _convert_map:
res = _convert_map[op_name](children, attrs)
if isinstance(res, (_expr.TupleWrapper, tuple, list)):
if res is None:
# defer conversion, used in RNN state initialization
res = [node]
elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
pass
elif isinstance(res, _expr.Expr):
res = [res]
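The heart of the change is the loop in the RNN converter above: a deferred `_zeros` state initializer is turned into a concrete `zeros` op by filling its zero-sized dimensions from the sequence data shape through the shared layout letters. A standalone sketch of that shape-completion step follows; the function and variable names are illustrative, not the frontend's.

```python
def fill_state_shape(state_shape, state_layout, data_shape, data_layout="TNC"):
    """Replace zero-sized dims of an RNN begin-state shape.

    When no begin state is given, MXNet emits _zeros with a shape such as
    (num_layers, 0, hidden_size) and __layout__ "LNC"; each 0 dim is looked
    up in the sequence data shape via the shared layout letter (here 'N',
    the batch axis).
    """
    new_shape = list(state_shape)
    for i, dim in enumerate(state_shape):
        if dim == 0:
            axis = data_layout.find(state_layout[i])
            assert axis >= 0, "layout letter must also appear in the data layout"
            new_shape[i] = int(data_shape[axis])
    return new_shape

# data: (seq_len=10, batch=4, input_size=64) in "TNC" layout;
# state placeholder: (num_layers=2, batch=0, hidden=32) in "LNC" layout.
print(fill_state_shape((2, 0, 32), "LNC", (10, 4, 64)))  # -> [2, 4, 32]
```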
48 changes: 31 additions & 17 deletions tests/python/frontend/mxnet/test_forward.py
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))

def test_forward_rnn_layer():
def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
if mode == "rnn":
layer = gluon.rnn.RNN(hidden_size, num_layers)
elif mode == "gru":
@@ -545,23 +545,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
layer = gluon.rnn.LSTM(hidden_size, num_layers)
num_states = 2 if mode == "lstm" else 1
layer.initialize()
layer.hybridize()

dtype = "float32"
batch = 1
data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
states_np = []
states_mx = []
shape_dict = {'data0': data_np.shape}
inputs = {'data0': data_np}
for i in range(num_states):
s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
states_np.append(s)
states_mx.append(mx.nd.array(s))
shape_dict['data%s' % (i+1)] = s.shape
inputs['data%s' % (i+1)] = s
data_mx = mx.nd.array(data_np)

if init_states:
shape_dict = {'data0': data_np.shape}
inputs = {'data0': data_np}
states_np = []
states_mx = []
for i in range(num_states):
s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
states_np.append(s)
states_mx.append(mx.nd.array(s))
shape_dict['data%s' % (i+1)] = s.shape
inputs['data%s' % (i+1)] = s
mx_out, mx_states = layer(data_mx, states_mx)
mx_res = [mx_out] + mx_states
else:
shape_dict = {'data': data_np.shape}
inputs = {'data': data_np}
mx_res = layer(data_mx)

layer.hybridize()
mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
mx_res = [mx_out] + mx_states
mx_sym = layer._cached_graph[1]
mx_params = {}
for name, param in layer.collect_params().items():
@@ -574,14 +582,20 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
for kind in ["graph"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(**inputs, **params)
assert len(op_res) == len(mx_res)
for i, val in enumerate(op_res):
tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
if init_states:
assert len(op_res) == len(mx_res)
for i, val in enumerate(op_res):
tvm.testing.assert_allclose(
val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
else:
tvm.testing.assert_allclose(
op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)

for mode in ["rnn", "gru", "lstm"]:
verify(mode, 64, 10, 64, 1)
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)
verify(mode, 64, 10, 64, 2, init_states=False)

def test_forward_Crop():
def verify(xshape, yshape, offset=None):
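One behavioral detail the new `init_states=False` case exercises: Gluon returns `(output, states)` when begin states are passed and a single output otherwise, which is why the converter only stacks and appends the final states when `output_states` is set. A rough illustration of the two call shapes, assuming the standard `begin_state` helper on the Gluon layer:

```python
import numpy as np
import mxnet as mx
from mxnet import gluon

layer = gluon.rnn.LSTM(64, num_layers=2)
layer.initialize()
data = mx.nd.array(np.random.uniform(size=(10, 1, 64)).astype("float32"))

# With explicit begin states: the output plus a list of final states comes back.
out, states = layer(data, layer.begin_state(batch_size=1))
print(out.shape, [s.shape for s in states])  # (10, 1, 64) [(2, 1, 64), (2, 1, 64)]

# Without begin states: only the output comes back, matching the converted
# Relay function when output_states is False.
out_only = layer(data)
print(out_only.shape)                        # (10, 1, 64)
```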