Skip to content

Commit

Permalink
[Relay][Frontend] Add Crop op converter (apache#3241)
Browse files Browse the repository at this point in the history
* Add Crop op converter

* lint

* x
  • Loading branch information
icemelon authored and Wei Chen committed Jun 26, 2019
1 parent 4b8bae3 commit a6fc910
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nnvm/python/nnvm/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def _crop_like(inputs, attrs):
raise tvm.error.OpAttributeUnimplemented(
'Center crop is not supported in operator crop_like.')
if len(inputs) < 2:
raise RuntimeError("Only support crop_like pattern.")
raise tvm.error.OpAttributeUnimplemented("Only support crop_like pattern.")
new_attrs["axis"] = [2, 3]
return get_nnvm_op('slice_like')(inputs[0], inputs[1], **new_attrs)

Expand Down
32 changes: 30 additions & 2 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _mx_conv2d_transpose(inputs, attrs):
new_attrs["groups"] = attrs.get_int("num_group", 1)
new_attrs["data_layout"] = data_layout
new_attrs["kernel_layout"] = kernel_layout
use_bias = not attrs.get_bool("no_bias", False)
use_bias = not attrs.get_bool("no_bias", True)
res = _op.nn.conv2d_transpose(inputs[0], inputs[1], **new_attrs)

if use_bias:
Expand Down Expand Up @@ -277,6 +277,28 @@ def _mx_slice_axis(inputs, attrs):
return _op.strided_slice(inputs[0], begin, end)


def _mx_crop_like(inputs, attrs):
if len(inputs) < 2:
raise tvm.error.OpAttributeUnimplemented(
"Only support crop_like pattern for operator Crop.")
if attrs.get_bool("center_crop", False):
raise tvm.error.OpAttributeUnimplemented(
"Center crop is not supported in operator Crop.")
if attrs.get_int_tuple("h_w", (0, 0)) != (0, 0):
raise tvm.error.OpAttributeUnimplemented(
"Doesn't support h_w in operator Crop.")
offset = attrs.get_int_tuple("offset", (0, 0))
new_attrs = {}
if offset == (0, 0):
new_attrs["axes"] = (2, 3)
return _op.slice_like(*inputs, **new_attrs)
like_shape = ir_pass.infer_type(inputs[1]).checked_type.shape
new_attrs['begin'] = [0, 0, offset[0], offset[1]]
new_attrs['end'] = [like_shape[0], like_shape[1], offset[0]+like_shape[2],
offset[1]+like_shape[3]]
return _op.strided_slice(inputs[0], **new_attrs)


def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1)
new_attrs = {}
Expand All @@ -300,6 +322,10 @@ def _mx_softmax_output(inputs, attrs):
return _op.nn.softmax(inputs[0])


def _mx_linear_regression_output(inputs, _):
return inputs[0]


def _mx_concat(inputs, attrs):
axis = attrs.get_int("dim", 1)
return _op.concatenate(tuple(inputs), axis=axis)
Expand Down Expand Up @@ -890,6 +916,7 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
"argsort" : _mx_argsort,
"SoftmaxOutput" : _mx_softmax_output,
"SoftmaxActivation" : _mx_softmax_activation,
"LinearRegressionOutput" : _mx_linear_regression_output,
"smooth_l1" : _mx_smooth_l1,
# vision
"_contrib_BilinearResize2D" : _mx_resize,
Expand All @@ -905,11 +932,12 @@ def _lstm_cell(data, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias):
# NLP
"RNN" : _mx_rnn_layer,
"_rnn_param_concat" : _mx_rnn_param_concat,
# Depricated:
"Crop" : _mx_crop_like,
# List of missing operators that are present in NNVMv1
# TODO(tvm-tvm): support all operators.
#
# "broadcast_to",
# "Crop" : _crop_like,
}

# set identity list
Expand Down
26 changes: 26 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,31 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
verify(mode, 64, 10, 64, 2)
verify(mode, 64, 10, 32, 2)

def test_forward_Crop():
def verify(xshape, yshape, offset=None):
x_data = np.random.uniform(size=xshape).astype("float32")
y_data = np.random.uniform(size=yshape).astype("float32")
if offset is None:
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"))
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data))
else:
mx_sym = mx.sym.Crop(mx.sym.var("x"), mx.sym.var("y"), offset=offset)
ref_res = mx.nd.Crop(mx.nd.array(x_data), mx.nd.array(y_data), offset=offset)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"x": xshape, "y": yshape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
if offset is None or offset == (0, 0):
op_res = intrp.evaluate(new_sym)(x_data, y_data)
else:
op_res = intrp.evaluate(new_sym)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((1, 3, 40, 40), (1, 3, 20, 20))
verify((1, 3, 40, 40), (1, 3, 20, 20), (0, 0))
verify((1, 3, 40, 40), (1, 3, 20, 20), (10, 10))
verify((5, 32, 40, 40), (5, 32, 25, 25))
verify((5, 32, 40, 40), (5, 32, 25, 25), (5, 5))


if __name__ == '__main__':
test_forward_mlp()
Expand Down Expand Up @@ -624,3 +649,4 @@ def verify(mode, input_size, seq_len, hidden_size, num_layers, batch=1):
test_forward_gather_nd()
test_forward_bilinear_resize()
test_forward_rnn_layer()
test_forward_Crop()

0 comments on commit a6fc910

Please sign in to comment.