Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY]reshape_like op #1951

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/langref/relay_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ This level enables additional math and transform operators.
tvm.relay.ones
tvm.relay.ones_like
tvm.relay.reshape
tvm.relay.reshape_like
tvm.relay.copy
tvm.relay.transpose
tvm.relay.floor
Expand Down Expand Up @@ -167,6 +168,7 @@ Level 3 Definitions
.. autofunction:: tvm.relay.abs
.. autofunction:: tvm.relay.negative
.. autofunction:: tvm.relay.reshape
.. autofunction:: tvm.relay.reshape_like
.. autofunction:: tvm.relay.copy
.. autofunction:: tvm.relay.transpose
.. autofunction:: tvm.relay.take
Expand Down
3 changes: 3 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ class TensorTypeNode : public BaseTensorTypeNode {
v->Visit("span", &span);
}

/*! \brief Return product of elements in the shape */
TVM_DLL IndexExpr Size() const;

TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype);

/*! \brief Construct an scalar containing elements of dtype. */
Expand Down
23 changes: 23 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,29 @@ def reshape(data, newshape):
return _make.reshape(data, list(newshape))


def reshape_like(data, shape_like):
    """Reshapes the input array to the shape of another array.

    For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation
    reshapes the input array into an output array with the same shape as the
    second input array.

    .. note::
        Sizes for both arrays should be compatible.

    Parameters
    ----------
    data : relay.Expr
        The input data to the operator.

    shape_like : relay.Expr
        The tensor whose shape is used as the output shape. Its total number
        of elements must match that of `data`.

    Returns
    -------
    ret : relay.Expr
        The computed result.
    """
    return _make.reshape_like(data, shape_like)


def take(data, indices, axis=None):
"""Take elements from an array along an axis.

Expand Down
8 changes: 8 additions & 0 deletions src/relay/ir/type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) {
return TensorTypeNode::make({}, dtype);
}

/*!
 * \brief Compute the total number of elements of this tensor type.
 * \return The product of all dimension extents as an Int(64) expression;
 *         a rank-0 (scalar) shape yields the constant 1.
 */
IndexExpr TensorTypeNode::Size() const {
  IndexExpr size = make_const(Int(64), 1);
  // Iterate by const reference: IndexExpr is a reference wrapper, and copying
  // it on every iteration needlessly churns the underlying node's refcount.
  for (const IndexExpr& dim : shape) {
    size *= dim;
  }
  return size;
}

TVM_REGISTER_NODE_TYPE(TensorTypeNode);

TVM_REGISTER_API("relay._make.TensorType")
Expand Down
55 changes: 55 additions & 0 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,61 @@ Example::
.set_support_level(3)
.add_type_rel("Reshape", ReshapeRel);


/*!
 * \brief Type relation for the reshape_like operator.
 * \param types The type list [data, shape_like, output].
 * \param num_inputs Number of input types in the args.
 * \param attrs The additional attributes of the operator.
 * \param reporter The reporter to report solution to.
 * \return false if this relation cannot be resolved yet, true once resolved.
 */
bool ReshapeLikeRel(const Array<Type>& types,
                    int num_inputs,
                    const Attrs& attrs,
                    const TypeReporter& reporter) {
  CHECK_EQ(types.size(), 3);
  const auto* input = types[0].as<TensorTypeNode>();
  const auto* target = types[1].as<TensorTypeNode>();
  if (input == nullptr || target == nullptr) {
    // At least one input has not been inferred to a tensor type yet.
    return false;
  }
  // reshape_like only rearranges elements, so total element counts must agree.
  CHECK(reporter->AssertEQ(input->Size(), target->Size()))
      << "Reshape inputs size should be compatible";
  // Output takes the shape of `shape_like` and the dtype of `data`.
  reporter->Assign(types[2], TensorTypeNode::make(target->shape, input->dtype));
  return true;
}


// Build a Call node that invokes the reshape_like operator on (data, shape_like).
Expr MakeReshapeLike(Expr data, Expr shape_like) {
  // Look the op up once and cache the reference for subsequent calls.
  static const Op& reshape_like_op = Op::Get("reshape_like");
  return CallNode::make(reshape_like_op, {data, shape_like}, Attrs(), {});
}


// FFI hook: exposes MakeReshapeLike to the frontend as
// relay.op._make.reshape_like, unpacking the two packed args (data, shape_like).
TVM_REGISTER_API("relay.op._make.reshape_like")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
    runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv);
  });


// Operator registration: two tensor inputs, support level 3, typed by the
// ReshapeLikeRel relation (output = shape of `shape_like`, dtype of `data`).
RELAY_REGISTER_OP("reshape_like")
.describe(R"code(Reshapes the input array by the size of another array.
For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array.
.. note::
    Sizes for both array should be compatible.
)code" TVM_ADD_FILELINE)
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("shape_like", "Tensor", "Shape tensor.")
.set_support_level(3)
.add_type_rel("ReshapeLike", ReshapeLikeRel);


// Take
TVM_REGISTER_NODE_TYPE(TakeAttrs);

Expand Down
17 changes: 17 additions & 0 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,22 @@ def test_reshape_infer_type():
(n, t, 2000), "float32")


def test_reshape_like():
    """Type inference for reshape_like: output adopts the shape of `shape_like`."""
    # Fully concrete shapes: (1, 2, 3) has 6 elements, matching (1, 6).
    data = relay.var("x", relay.TensorType((1, 2, 3), "float32"))
    shape_like = relay.var("y", relay.TensorType((1, 6), "float32"))
    inferred = relay.ir_pass.infer_type(relay.reshape_like(data, shape_like))
    assert inferred.checked_type == relay.TensorType((1, 6), "float32")

    # Symbolic input shape: (n, 2, 3, w) reshaped to the static (1, 8, 8).
    n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w")
    data = relay.var("x", relay.TensorType((n, c, h, w), "float32"))
    shape_like = relay.var("y", relay.TensorType((1, 8, 8), "float32"))
    inferred = relay.ir_pass.infer_type(relay.reshape_like(data, shape_like))
    assert inferred.checked_type == relay.TensorType((1, 8, 8), "float32")


def test_take_infer_type():
def verify_take(dshape, indices_shape, oshape, axis=None):
Expand Down Expand Up @@ -145,6 +161,7 @@ def test_infer_type_leaky_relu():
test_clip_type()
test_transpose_infer_type()
test_reshape_infer_type()
test_reshape_like()
test_take_infer_type()
test_full()
test_full_like()
Expand Down