diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst index a36f8e6c71cf..9a4cfb280e90 100644 --- a/docs/langref/relay_op.rst +++ b/docs/langref/relay_op.rst @@ -76,6 +76,7 @@ This level enables additional math and transform operators. tvm.relay.ones tvm.relay.ones_like tvm.relay.reshape + tvm.relay.reshape_like tvm.relay.copy tvm.relay.transpose tvm.relay.floor @@ -167,6 +168,7 @@ Level 3 Definitions .. autofunction:: tvm.relay.abs .. autofunction:: tvm.relay.negative .. autofunction:: tvm.relay.reshape +.. autofunction:: tvm.relay.reshape_like .. autofunction:: tvm.relay.copy .. autofunction:: tvm.relay.transpose .. autofunction:: tvm.relay.take diff --git a/include/tvm/relay/type.h b/include/tvm/relay/type.h index 2bb9b3070270..6612cfaea88a 100644 --- a/include/tvm/relay/type.h +++ b/include/tvm/relay/type.h @@ -82,6 +82,9 @@ class TensorTypeNode : public BaseTensorTypeNode { v->Visit("span", &span); } + /*! \brief Return product of elements in the shape */ + TVM_DLL IndexExpr Size() const; + TVM_DLL static TensorType make(Array<IndexExpr> shape, DataType dtype); /*! \brief Construct an scalar containing elements of dtype. */ diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index c2036f509133..cfacce899493 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -141,6 +141,29 @@ def reshape(data, newshape): return _make.reshape(data, list(newshape)) +def reshape_like(data, shape_like): + """Reshapes the input array by the size of another array. + For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes + the input array into an output array with the same shape as the second input array. + .. note:: + Sizes for both arrays should be compatible. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + shape_like : relay.Expr + The tensor to reshape data like. Should be compatible with the original shape.
+ + Returns + ------- + ret : relay.Expr + The computed result. + """ + return _make.reshape_like(data, shape_like) + + def take(data, indices, axis=None): """Take elements from an array along an axis. diff --git a/src/relay/ir/type.cc b/src/relay/ir/type.cc index 39347adced92..f9e3e5d274eb 100644 --- a/src/relay/ir/type.cc +++ b/src/relay/ir/type.cc @@ -22,6 +22,14 @@ TensorType TensorTypeNode::Scalar(DataType dtype) { return TensorTypeNode::make({}, dtype); } +IndexExpr TensorTypeNode::Size() const { + IndexExpr size = make_const(Int(64), 1); + for (IndexExpr i : shape) { + size *= i; + } + return size; +} + TVM_REGISTER_NODE_TYPE(TensorTypeNode); TVM_REGISTER_API("relay._make.TensorType") diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 61ee2778d0a2..6650288561a1 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -320,6 +320,62 @@ Example:: .set_support_level(3) .add_type_rel("Reshape", ReshapeRel); + +/*! +* \brief ReshapeLikeRel Output type and shape relation evaluation function. +* \param types The input and output types: [data, shape_like, result]. +* \param num_inputs Number of input types in the args. +* \param attrs The additional attributes of the operator. +* \param reporter The reporter to report solution to. +* \return false if this relation cannot be resolved. true if this relation has been resolved.
*/ +bool ReshapeLikeRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as<TensorTypeNode>(); + if (data == nullptr) { + return false; + } + const auto* reshape_like = types[1].as<TensorTypeNode>(); + if (reshape_like == nullptr) { + return false; + } + CHECK(reporter->AssertEQ(data->Size(), reshape_like->Size())) + << "Reshape inputs size should be compatible"; + reporter->Assign(types[2], TensorTypeNode::make(reshape_like->shape, data->dtype)); + return true; +} + + +Expr MakeReshapeLike(Expr data, + Expr shape_like) { + static const Op& op = Op::Get("reshape_like"); + return CallNode::make(op, {data, shape_like}, Attrs(), {}); +} + + +TVM_REGISTER_API("relay.op._make.reshape_like") +.set_body([](const TVMArgs& args, TVMRetValue* rv) { + runtime::detail::unpack_call<Expr, 2>(MakeReshapeLike, args, rv); +}); + + +RELAY_REGISTER_OP("reshape_like") +.describe(R"code(Reshapes the input array by the size of another array. +For an input array with shape ``(d1, d2, ..., dk)``, `reshape_like` operation reshapes
the input array into an output array with the same shape as the second input array. +.. note:: + Sizes for both arrays should be compatible.
+)code" TVM_ADD_FILELINE) +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("shape_like", "Tensor", "Shape tensor.") +.set_support_level(3) +.add_type_rel("ReshapeLike", ReshapeLikeRel); + + // Take TVM_REGISTER_NODE_TYPE(TakeAttrs); diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py index d1bff2940457..e2feab9128b7 100644 --- a/tests/python/relay/test_op_level3.py +++ b/tests/python/relay/test_op_level3.py @@ -80,6 +80,22 @@ def test_reshape_infer_type(): (n, t, 2000), "float32") +def test_reshape_like(): + # concrete shape + x = relay.var("x", relay.TensorType((1, 2, 3), "float32")) + y = relay.var("y", relay.TensorType((1,6), "float32")) + z = relay.reshape_like(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((1, 6), "float32") + + # symbolic shape + n, c, h, w = tvm.var("n"), 2, 3, tvm.var("w") + x = relay.var("x", relay.TensorType((n, c, h, w), "float32")) + y = relay.var("y", relay.TensorType((1, 8, 8), "float32")) + z = relay.reshape_like(x, y) + zz = relay.ir_pass.infer_type(z) + assert zz.checked_type == relay.TensorType((1, 8, 8), "float32") + def test_take_infer_type(): def verify_take(dshape, indices_shape, oshape, axis=None): @@ -145,6 +161,7 @@ def test_infer_type_leaky_relu(): test_clip_type() test_transpose_infer_type() test_reshape_infer_type() + test_reshape_like() test_take_infer_type() test_full() test_full_like()