Skip to content

Commit

Permalink
[Relay][Frontend] Add slice axis op in mxnet converter (#2706)
Browse files Browse the repository at this point in the history
* Add slice axis op in mxnet converter

* Fix lint
  • Loading branch information
icemelon authored Mar 4, 2019
1 parent c8373ec commit 2fa3a67
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
29 changes: 29 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,34 @@ def _mx_slice(inputs, attrs):
return _op.strided_slice(inputs[0], **new_attrs)


def _mx_slice_axis(inputs, attrs):
assert len(inputs) == 1
shape = ir_pass.infer_type(inputs[0]).checked_type.shape
axis = attrs.get_int("axis")
ax_beg = attrs.get_int("begin")
ax_end = attrs.get_str("end")
if ax_end == "None":
ax_end = int(shape[axis])
else:
ax_end = int(ax_end)
if ax_beg < 0:
ax_beg += int(shape[axis])
if ax_end < 0:
ax_end += int(shape[axis])
assert ax_beg >= 0 and ax_beg < int(shape[axis])
assert ax_end > ax_beg and ax_end <= int(shape[axis])
begin = []
end = []
for i, dim in enumerate(shape):
if i != axis:
begin.append(0)
end.append(dim)
else:
begin.append(ax_beg)
end.append(ax_end)
return _op.strided_slice(inputs[0], begin, end)


def _mx_split(inputs, attrs):
axis = attrs.get_int("axis", 1)
new_attrs = {}
Expand Down Expand Up @@ -423,6 +451,7 @@ def _mx_roi_align(inputs, attrs):
"BatchNorm_v1" : _mx_batch_norm,
"LRN" : _mx_lrn,
"slice" : _mx_slice,
"slice_axis" : _mx_slice_axis,
"SliceChannel" : _mx_split,
"split" : _mx_split,
"expand_dims" : _mx_expand_dims,
Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/mxnet/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,23 @@ def test_forward_scalar_ops():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_slice_axis():
def verify(shape, axis, begin, end):
data_np = np.random.uniform(size=shape).astype("float32")
ref_res = mx.nd.slice_axis(mx.nd.array(data_np), axis, begin, end)
mx_sym = mx.sym.slice_axis(mx.sym.var("data"), axis, begin, end)
new_sym, _ = relay.frontend.from_mxnet(mx_sym, {"data": shape})
for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(new_sym)(data_np)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())
verify((3, 4), 0, 1, 2)
verify((3, 4), 0, 1, None)
verify((3, 4), 1, 0, 2)
verify((3, 4), 1, -3, -1)


if __name__ == '__main__':
test_forward_mlp()
test_forward_vgg()
Expand All @@ -363,3 +380,4 @@ def test_forward_scalar_ops():
test_forward_broadcast_ops()
test_forward_elemwise_ops()
test_forward_scalar_ops()
test_forward_slice_axis()

0 comments on commit 2fa3a67

Please sign in to comment.