[Relay][Frontend] Fix MxNet RNN without providing state initialization as input (apache#3326)
icemelon authored and wweic committed Jun 27, 2019
1 parent 6a26307 commit de9a8fa
Showing 2 changed files with 72 additions and 24 deletions.
48 changes: 41 additions & 7 deletions python/tvm/relay/frontend/mxnet.py
@@ -93,6 +93,15 @@ def impl(inputs, attrs):
     return impl
 
 
+def _mx_zeros(inputs, attrs):
+    assert len(inputs) == 0
+    shape = attrs.get_int_tuple("shape")
+    dtype = attrs.get_str("dtype", "float32")
+    if 0 in shape:
+        return None
+    return _op.zeros(shape=shape, dtype=dtype)
+
+
 def _mx_conv2d(inputs, attrs):
     kernel_size = attrs.get_int_tuple("kernel")
     if len(kernel_size) != 2:
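Note: the `None` return feeds the deferred-conversion path added to `_from_mxnet_impl` further down. When a Gluon RNN layer runs without explicit initial states, its cached graph contains `_zeros` state nodes whose shape still carries a 0 placeholder for the batch dimension, so no concrete tensor can be built yet. A rough sketch of such a node (the concrete name and values are assumptions for illustration, not taken from a real graph):

    # Illustrative JSON node for a deferred state init (values are assumed):
    deferred_state = {
        "op": "_zeros",
        "name": "lstm0_begin_state_zeros0",  # hypothetical node name
        "attrs": {
            "shape": "(1, 0, 64)",   # (num_layers, batch, hidden); 0 = unknown batch
            "__layout__": "LNC",     # layers, batch (N), channels
            "dtype": "float32",
        },
    }
    # _mx_zeros parses the shape, sees the 0, and returns None to defer conversion.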
@@ -754,9 +763,30 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
 
     seq_data = inputs[0]
     concat_weight = inputs[1]
-    concat_states = inputs[2:]
-    seq_len = int(ir_pass.infer_type(seq_data).checked_type.shape[0])
+    init_states = inputs[2:]
+
+    data_shape = ir_pass.infer_type(seq_data).checked_type.shape
+    seq_len = int(data_shape[0])
     assert len(concat_weight) == num_layers * 4
+    output_states = True
+    for idx, state in enumerate(init_states[:]):
+        if isinstance(state, dict):
+            node = state
+            attrs = StrAttrsDict(node.get("attrs", {}))
+            op_name = node["op"]
+            # by default, RNN layer uses zeros to initialize states
+            assert op_name == "_zeros"
+            shape = attrs.get_int_tuple("shape")
+            dtype = attrs.get_str("dtype", "float32")
+            init_layout = attrs.get_str("__layout__")
+            new_shape = list(shape)
+            for i, dim in enumerate(shape):
+                if dim == 0:
+                    axis = layout.find(init_layout[i])
+                    assert axis >= 0
+                    new_shape[i] = int(data_shape[axis])
+            init_states[idx] = _op.zeros(new_shape, dtype)
+            output_states = False
 
     weights = []
     bias = []
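For a concrete view of the dimension-filling loop above: only placeholder dimensions (value 0) are rewritten, by locating the matching axis of the data layout. A runnable sketch under assumed shapes and layouts:

    # Worked example of the placeholder fill (shapes and layouts are assumptions):
    data_shape  = (10, 1, 64)  # data layout "TNC": seq_len, batch, input_size
    shape       = (1, 0, 64)   # deferred state shape, layout "LNC"
    layout      = "TNC"
    init_layout = "LNC"

    new_shape = list(shape)
    for i, dim in enumerate(shape):
        if dim == 0:                            # only the batch dim is a placeholder
            axis = layout.find(init_layout[i])  # 'N' found at axis 1 of "TNC"
            assert axis >= 0
            new_shape[i] = int(data_shape[axis])
    print(new_shape)  # [1, 1, 64]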
@@ -768,7 +798,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
         for j in range(2):
             w.append(concat_weight[i*2 + j].args[0])
             b.append(concat_weight[num_layers*2 + i*2 + j].args[0])
-        for state in concat_states:
+        for state in init_states:
             s.append(_op.take(state, _expr.const(i, "int32"), axis=0))
         weights.append(w)
         bias.append(b)
@@ -789,8 +819,9 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
         seq_output.append(out)
 
     outputs = [_op.stack(seq_output, axis=0)]
-    for i in range(num_states):
-        outputs.append(_op.stack([s[i] for s in states], axis=0))
+    if output_states:
+        for i in range(num_states):
+            outputs.append(_op.stack([s[i] for s in states], axis=0))
     return outputs
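The `output_states` flag matches Gluon's calling convention: a layer called without explicit states returns only the output tensor, while a call with states returns the output plus the new states. A short sketch of that behavior (assumes MXNet is installed; shapes are arbitrary):

    import mxnet as mx
    from mxnet import gluon

    layer = gluon.rnn.LSTM(8, 1)
    layer.initialize()
    data = mx.nd.ones((5, 1, 8))
    out = layer(data)                        # no states passed -> output only
    states = layer.begin_state(batch_size=1)
    out, new_states = layer(data, states)   # states passed -> (output, states)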


@@ -881,7 +912,6 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     "argmin"        : _arg_reduce(_op.argmin),
     # init ops
     "_ones"         : _init_op(_op.ones),
-    "_zeros"        : _init_op(_op.zeros),
     # softmax
     "softmax"       : _softmax_op(_op.nn.softmax),
     "log_softmax"   : _softmax_op(_op.nn.log_softmax),
@@ -895,6 +925,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
     "UpSampling"    : _upsampling,
     "add_n"         : _elemwise_sum,
     # MXNet specific implementations
+    "_zeros"        : _mx_zeros,
     "FullyConnected": _mx_fully_connected,
     "Activation"    : _mx_activations,
     "Convolution"   : _mx_conv2d,
@@ -1002,7 +1033,10 @@ def _from_mxnet_impl(symbol, shape_dict, dtype_info):
             node_map[nid] = [_expr.var(node_name, shape=shape, dtype=dtype)]
         elif op_name in _convert_map:
             res = _convert_map[op_name](children, attrs)
-            if isinstance(res, (_expr.TupleWrapper, tuple, list)):
+            if res is None:
+                # defer conversion, used in RNN state initialization
+                res = [node]
+            elif isinstance(res, (_expr.TupleWrapper, tuple, list)):
                 pass
             elif isinstance(res, _expr.Expr):
                 res = [res]
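Returning `None` from a converter now has a defined meaning: the raw JSON node dict itself is stored in `node_map`, so a downstream op (here the RNN converter) receives the dict among its inputs and can materialize it once the data shape is known. A minimal sketch of that contract, with illustrative values:

    # Minimal sketch of the deferral contract (names and values are illustrative):
    node = {"op": "_zeros", "attrs": {"shape": "(1, 0, 64)", "__layout__": "LNC"}}
    res = None                 # what _mx_zeros returns for a shape containing 0
    if res is None:
        res = [node]           # node_map keeps the raw dict instead of a Relay expr
    # Later, the RNN converter's isinstance(state, dict) check builds the zeros.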
48 changes: 31 additions & 17 deletions tests/python/frontend/mxnet/test_forward.py
@@ -536,7 +536,7 @@ def test_forward_bilinear_resize():
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
 def test_forward_rnn_layer():
-    def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
+    def verify(mode, input_size, seq_len, hidden_size, num_layers, init_states=True):
         if mode == "rnn":
             layer = gluon.rnn.RNN(hidden_size, num_layers)
         elif mode == "gru":
@@ -545,23 +545,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
             layer = gluon.rnn.LSTM(hidden_size, num_layers)
         num_states = 2 if mode == "lstm" else 1
         layer.initialize()
+        layer.hybridize()
 
         dtype = "float32"
+        batch = 1
         data_np = np.random.uniform(size=(seq_len, batch, input_size)).astype(dtype)
-        states_np = []
-        states_mx = []
-        shape_dict = {'data0': data_np.shape}
-        inputs = {'data0': data_np}
-        for i in range(num_states):
-            s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
-            states_np.append(s)
-            states_mx.append(mx.nd.array(s))
-            shape_dict['data%s' % (i+1)] = s.shape
-            inputs['data%s' % (i+1)] = s
+        data_mx = mx.nd.array(data_np)
+
+        if init_states:
+            shape_dict = {'data0': data_np.shape}
+            inputs = {'data0': data_np}
+            states_np = []
+            states_mx = []
+            for i in range(num_states):
+                s = np.random.uniform(size=(num_layers, batch, hidden_size)).astype(dtype)
+                states_np.append(s)
+                states_mx.append(mx.nd.array(s))
+                shape_dict['data%s' % (i+1)] = s.shape
+                inputs['data%s' % (i+1)] = s
+            mx_out, mx_states = layer(data_mx, states_mx)
+            mx_res = [mx_out] + mx_states
+        else:
+            shape_dict = {'data': data_np.shape}
+            inputs = {'data': data_np}
+            mx_res = layer(data_mx)
 
-        layer.hybridize()
-        mx_out, mx_states = layer(mx.nd.array(data_np), states_mx)
-        mx_res = [mx_out] + mx_states
         mx_sym = layer._cached_graph[1]
         mx_params = {}
         for name, param in layer.collect_params().items():
@@ -574,14 +582,20 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
         for kind in ["graph"]:
             intrp = relay.create_executor(kind, ctx=ctx, target=target)
             op_res = intrp.evaluate(new_sym)(**inputs, **params)
-            assert len(op_res) == len(mx_res)
-            for i, val in enumerate(op_res):
-                tvm.testing.assert_allclose(val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            if init_states:
+                assert len(op_res) == len(mx_res)
+                for i, val in enumerate(op_res):
+                    tvm.testing.assert_allclose(
+                        val.asnumpy(), mx_res[i].asnumpy(), rtol=1e-3)
+            else:
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), mx_res.asnumpy(), rtol=1e-3)
 
     for mode in ["rnn", "gru", "lstm"]:
         verify(mode, 64, 10, 64, 1)
         verify(mode, 64, 10, 64, 2)
         verify(mode, 64, 10, 32, 2)
+        verify(mode, 64, 10, 64, 2, init_states=False)
 
 def test_forward_Crop():
     def verify(xshape, yshape, offset=None):
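Putting it together, the new `init_states=False` path corresponds to user code like the following sketch. It mirrors this test's pattern (`layer._cached_graph[1]`, `param._reduce()`, and the `from_mxnet` signature of this TVM era), not an officially documented conversion recipe:

    import numpy as np
    import mxnet as mx
    from mxnet import gluon
    from tvm import relay

    layer = gluon.rnn.LSTM(64, 2)
    layer.initialize()
    layer.hybridize()

    data_np = np.random.uniform(size=(10, 1, 64)).astype("float32")
    mx_res = layer(mx.nd.array(data_np))     # no initial states supplied

    mx_sym = layer._cached_graph[1]          # hybridized symbol, as in the test
    mx_params = {name: param._reduce()
                 for name, param in layer.collect_params().items()}
    new_sym, params = relay.frontend.from_mxnet(
        mx_sym, {"data": data_np.shape}, arg_params=mx_params)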
