Skip to content

Commit

Permalink
Relay reshape reshape_like compute and schedule (apache#2159)
Browse files Browse the repository at this point in the history
  • Loading branch information
siju-samuel authored and tqchen committed Nov 25, 2018
1 parent 0fff6b8 commit 29928a2
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 3 deletions.
8 changes: 8 additions & 0 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,11 @@ def expand_dims_compiler(attrs, inputs, output_type, target):
# slice_like
_reg.register_schedule("slice_like", schedule_injective)
_reg.register_pattern("slice_like", OpPattern.INJECTIVE)

# reshape
_reg.register_schedule("reshape", schedule_injective)
_reg.register_pattern("reshape", OpPattern.INJECTIVE)

# reshape_like
_reg.register_schedule("reshape_like", schedule_injective)
_reg.register_pattern("reshape_like", OpPattern.INJECTIVE)
18 changes: 16 additions & 2 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,15 @@ Example::
.set_attrs_type_key("relay.attrs.ReshapeAttrs")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);
.add_type_rel("Reshape", ReshapeRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
const auto* param = attrs.as<ReshapeAttrs>();
CHECK(param != nullptr);
return Array<Tensor>{ topi::reshape(inputs[0], param->newshape) };
});


/*!
Expand Down Expand Up @@ -431,7 +439,13 @@ the input array into an output array with the same shape as the second input arr
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel);
.add_type_rel("ReshapeLike", ReshapeLikeRel)
.set_attr<FTVMCompute>("FTVMCompute", [](const Attrs& attrs,
const Array<Tensor>& inputs,
const Type& out_type,
const Target& target) {
return Array<Tensor>{ topi::reshape(inputs[0], inputs[1]->shape) };
});


// Take
Expand Down
47 changes: 46 additions & 1 deletion tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,28 @@ def test_reshape_infer_type():
assert yy.checked_type == relay.TensorType(
(n, t, 2000), "float32")

def test_reshape():
def verify_reshape(shape, oshape):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
ref_res = np.reshape(x_data, oshape)

def test_reshape_like():
x = relay.var("x", relay.TensorType(shape, "float32"))
z = relay.reshape(x, newshape=ref_res.shape)
zz = relay.ir_pass.infer_type(z)
assert "newshape=" in z.astext()
assert zz.checked_type == relay.ty.TensorType(oshape, "float32")

func = relay.Function([x], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))

def test_reshape_like_infer_type():
# concrete shape
x = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
y = relay.var("y", relay.TensorType((1,6), "float32"))
Expand All @@ -141,6 +161,29 @@ def test_reshape_like():
assert zz.checked_type == relay.TensorType((1, 8, 8), "float32")


def test_reshape_like():
def verify_reshape_like(shape, oshape):
x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
y_data = np.random.uniform(low=-1, high=1, size=oshape).astype("float32")
ref_res = np.reshape(x_data, y_data.shape)

x = relay.var("x", relay.TensorType(shape, "float32"))
y = relay.var("x", relay.TensorType(oshape, "float32"))
z = relay.reshape_like(x, y)
zz = relay.ir_pass.infer_type(z)
assert zz.checked_type == relay.ty.TensorType(ref_res.shape, "float32")

func = relay.Function([x, y], z)

for target, ctx in ctx_list():
for kind in ["graph", "debug"]:
intrp = relay.create_executor(kind, ctx=ctx, target=target)
op_res = intrp.evaluate(func)(x_data, y_data)
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)

verify_reshape_like((2, 3, 4), (1, 8, 3))
verify_reshape_like((4, 7), (2, 7, 2))

def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None):
x = relay.var("x", relay.TensorType(dshape, "float32"))
Expand Down Expand Up @@ -318,6 +361,8 @@ def test_infer_type_prelu():
test_clip()
test_transpose_infer_type()
test_reshape_infer_type()
test_reshape()
test_reshape_like_infer_type()
test_reshape_like()
test_take_infer_type()
test_full()
Expand Down

0 comments on commit 29928a2

Please sign in to comment.