diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index 82272b0249bd..e1f38c61eb1f 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -149,6 +149,7 @@ This level support backpropagation of broadcast operators. It is temporary.
    tvm.relay.layout_transform
    tvm.relay.device_copy
    tvm.relay.annotation.on_device
+   tvm.relay.reverse_reshape
 
 
 Level 1 Definitions
@@ -257,4 +258,5 @@ Level 10 Definitions
 .. autofunction:: tvm.relay.slice_like
 .. autofunction:: tvm.relay.layout_transform
 .. autofunction:: tvm.relay.device_copy
-.. autofunction:: tvm.relay.annotation.on_device
\ No newline at end of file
+.. autofunction:: tvm.relay.annotation.on_device
+.. autofunction:: tvm.relay.reverse_reshape
diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h
index 36ccd3fb6c89..fa27a4d437d2 100644
--- a/include/tvm/relay/attrs/transform.h
+++ b/include/tvm/relay/attrs/transform.h
@@ -63,9 +63,13 @@ struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
 
 /*! \brief Attributes used in reshape operators */
 struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
   Array<Integer> newshape;
+  bool reverse;
   TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
     TVM_ATTR_FIELD(newshape)
         .describe("The new shape. Should be compatible with the original shape.");
+    TVM_ATTR_FIELD(reverse)
+        .describe("Infer the special values from right to left if true")
+        .set_default(false);
   }
 };  // struct ReshapeAttrs
diff --git a/python/tvm/relay/frontend/nnvm_common.py b/python/tvm/relay/frontend/nnvm_common.py
index d75b0cac65bc..3838c3d4aa3b 100644
--- a/python/tvm/relay/frontend/nnvm_common.py
+++ b/python/tvm/relay/frontend/nnvm_common.py
@@ -22,9 +22,10 @@ def impl(inputs, _, _dtype='float32'):
 
 
 def _reshape(inputs, attrs):
-    if attrs.get_bool("reverse", False):
-        raise RuntimeError("reshape do not support option reverse")
     shape = attrs.get_int_tuple("shape")
+    reverse = attrs.get_bool("reverse", False)
+    if reverse:
+        return _op.reverse_reshape(inputs[0], newshape=shape)
     return _op.reshape(inputs[0], newshape=shape)
 
 
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index 085a8ceed5d1..abf0b5317b48 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -26,6 +26,7 @@
 _reg.register_schedule("take", schedule_injective)
 _reg.register_schedule("transpose", schedule_injective)
 _reg.register_schedule("where", schedule_broadcast)
+_reg.register_schedule("_contrib_reverse_reshape", schedule_injective)
 
 # layout_transform
 _reg.register_schedule("layout_transform", schedule_injective)
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 454d335692c6..78efd3cfd4d9 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -141,7 +141,7 @@ def reshape(data, newshape):
 
     Example::
 
-    - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape =(1,2,3,4)
+    - data.shape = (2,3,4), newshape = (-4,1,2,-2), result.shape = (1,2,3,4)
     - data.shape = (2,3,4), newshape = (2,-4,-1,3,-2), result.shape = (2,1,3,4)
 
     Parameters
@@ -449,3 +449,34 @@ def layout_transform(data, src_layout, dst_layout):
         The transformed tensor.
     """
     return _make.layout_transform(data, src_layout, dst_layout)
+
+
+def reverse_reshape(data, newshape):
+    """Reshapes the input array where the special values are inferred from
+    right to left.
+
+    Example::
+
+    The special values have the same semantics as :py:class:`tvm.relay.reshape`.
+    The difference is that special values are inferred from right to left. It
+    can be explained in the example below::
+
+    - data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
+    - data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (50,4)
+
+    Parameters
+    ----------
+    data : relay.Expr
+        The input data to the operator.
+
+    newshape : Union[int, Tuple[int], List[int]]
+        The new shape. Should be compatible with the original shape.
+
+    Returns
+    -------
+    result : relay.Expr
+        The reshaped result.
+    """
+    if isinstance(newshape, int):
+        newshape = [newshape]
+    return _make._contrib_reverse_reshape(data, list(newshape))
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index f4d625195a2f..b38999d1b1b7 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -382,20 +382,29 @@ bool ReshapeRel(const Array<Type>& types,
   }
 
   const auto* param = attrs.as<ReshapeAttrs>();
+  Array<IndexExpr> data_shape;
+  Array<Integer> newshape;
+  if (param->reverse) {
+    data_shape.assign(data->shape.rbegin(), data->shape.rend());
+    newshape.assign(param->newshape.rbegin(), param->newshape.rend());
+  } else {
+    data_shape = data->shape;
+    newshape = param->newshape;
+  }
 
   Array<IndexExpr> oshape;
   size_t src_idx = 0;
   int infer_idx = -1;
-  for (size_t i = 0; i < param->newshape.size(); ++i) {
-    int svalue = param->newshape[i]->value;
+  for (size_t i = 0; i < newshape.size(); ++i) {
+    int svalue = newshape[i]->value;
     // special flag handling for shape inference.
     if (svalue > 0) {
-      oshape.push_back(param->newshape[i]);
+      oshape.push_back(newshape[i]);
       ++src_idx;
     } else if (svalue == 0) {
       // keep same
-      CHECK_LT(src_idx, data->shape.size());
-      oshape.push_back(data->shape[src_idx++]);
+      CHECK_LT(src_idx, data_shape.size());
+      oshape.push_back(data_shape[src_idx++]);
     } else if (svalue == -1) {
       // inference based on rest
       CHECK_LT(infer_idx, 0)
@@ -405,42 +414,51 @@ bool ReshapeRel(const Array<Type>& types,
       ++src_idx;
     } else if (svalue == -2) {
       // copy all remaining dims from source
-      while (src_idx < data->shape.size()) {
-        oshape.push_back(data->shape[src_idx++]);
+      while (src_idx < data_shape.size()) {
+        oshape.push_back(data_shape[src_idx++]);
       }
     } else if (svalue == -3) {
       // merge two dims from source
-      CHECK_LT(src_idx + 1, data->shape.size());
-      IndexExpr d1 = data->shape[src_idx++];
-      IndexExpr d2 = data->shape[src_idx++];
+      CHECK_LT(src_idx + 1, data_shape.size());
+      IndexExpr d1 = data_shape[src_idx++];
+      IndexExpr d2 = data_shape[src_idx++];
       oshape.push_back(d1 * d2);
     } else if (svalue == -4) {
       // split the source dim s into two dims
       // read the left dim and then the right dim (either can be -1)
-      CHECK_LT(i + 2, param->newshape.size());
-      CHECK_LT(src_idx, data->shape.size());
-      IndexExpr d0 = data->shape[src_idx++];
-      Integer d1 = param->newshape[++i];
-      Integer d2 = param->newshape[++i];
+      CHECK_LT(i + 2, newshape.size());
+      CHECK_LT(src_idx, data_shape.size());
+      IndexExpr d0 = data_shape[src_idx++];
+      Integer d1 = newshape[++i];
+      Integer d2 = newshape[++i];
       if (d1->value == -1) {
         CHECK(d2->value != -1)
             << "Split dims cannot both be -1.";
         oshape.push_back(d0 / d2);
         oshape.push_back(d2);
       } else {
-        CHECK_EQ(d2->value, -1);
         oshape.push_back(d1);
-        oshape.push_back(d0 / d1);
+        if (d2->value == -1) {
+          oshape.push_back(d0 / d1);
+        } else {
+          oshape.push_back(d2);
+        }
       }
     }
   }
 
   if (infer_idx >= 0) {
     IndexExpr new_size = arith::ComputeReduce<tvm::ir::Mul>(oshape, 1);
-    IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data->shape, 1);
+    IndexExpr old_size = arith::ComputeReduce<tvm::ir::Mul>(data_shape, 1);
     oshape.Set(infer_idx, old_size / new_size);
   }
-  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+
+  if (param->reverse) {
+    reporter->Assign(types[1], TensorTypeNode::make(
+        Array<IndexExpr>(oshape.rbegin(), oshape.rend()), data->dtype));
+  } else {
+    reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  }
   return true;
 }
 
@@ -457,6 +475,7 @@ Expr MakeReshape(Expr data,
                  Array<Integer> newshape) {
   auto attrs = make_node<ReshapeAttrs>();
   attrs->newshape = std::move(newshape);
+  attrs->reverse = false;
   static const Op& op = Op::Get("reshape");
   return CallNode::make(op, {data}, Attrs(attrs), {});
 }
@@ -1699,5 +1718,43 @@ the input array by output[n, c, h, w, C] = data[n, C*16+c, h, w]
 .set_support_level(5)
 .set_attr<FTVMCompute>("FTVMCompute", LayoutTransformCompute);
 
+
+/* relay._contrib_reverse_reshape */
+Expr MakeReverseReshape(Expr data,
+                        Array<Integer> newshape) {
+  auto attrs = make_node<ReshapeAttrs>();
+  attrs->newshape = std::move(newshape);
+  attrs->reverse = true;
+  static const Op& op = Op::Get("_contrib_reverse_reshape");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_API("relay.op._make._contrib_reverse_reshape")
+.set_body([](const TVMArgs& args, TVMRetValue* rv) {
+    runtime::detail::unpack_call<Expr, 2>(MakeReverseReshape, args, rv);
+});
+
+RELAY_REGISTER_OP("_contrib_reverse_reshape")
+.describe(R"code(Reshapes the input array where the special values are inferred from
+right to left.
+
+Example::
+
+The special values have the same semantics as reshape. The difference is that
+special values are inferred from right to left. It can be explained in the
+example below::
+
+- data.shape = (10,5,4), newshape = (-1,0), reshape results in (40,5)
+- data.shape = (10,5,4), newshape = (-1,0), reverse_reshape results in (50,4)
+
+)code" TVM_ADD_FILELINE)
+.set_num_inputs(1)
+.set_attrs_type_key("relay.attrs.ReshapeAttrs")
+.add_argument("data", "Tensor", "The input tensor.")
+.set_support_level(10)
+.add_type_rel("Reshape", ReshapeRel)
+.set_attr<FTVMCompute>("FTVMCompute", ReshapeCompute)
+.set_attr<TOpPattern>("TOpPattern", kInjective);
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 2c0ed73a7535..a6e169e23a6c 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -121,8 +121,30 @@ def test_slice_like():
                      axes=(2, 3), output=(1, 3, 112, 112))
 
+def test_reverse_reshape():
+    def verify_reverse_reshape(shape, newshape, oshape):
+        x = relay.var("x", relay.TensorType(shape, "float32"))
+        z = relay.reverse_reshape(x, newshape=newshape)
+        zz = relay.ir_pass.infer_type(z)
+        assert "newshape=" in z.astext()
+        assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
+
+        func = relay.Function([x], z)
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        ref_res = np.reshape(x_data, oshape)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res = intrp.evaluate(func)(x_data)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
+    verify_reverse_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
+    verify_reverse_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
+    verify_reverse_reshape((2, 3, 4), (0, -1), (3, 8))
+    verify_reverse_reshape((2, 3, 4), (-1, 0), (6, 4))
+    verify_reverse_reshape((2, 3, 4), (0, -3), (2, 12))
 
 if __name__ == "__main__":
     test_collapse_sum_like()
     test_broadcast_to_like()
     test_slice_like()
+    test_reverse_reshape()
 
diff --git a/tests/python/relay/test_op_level3.py b/tests/python/relay/test_op_level3.py
index e66a73cf84cb..550637023d43 100644
--- a/tests/python/relay/test_op_level3.py
+++ b/tests/python/relay/test_op_level3.py
@@ -152,25 +152,36 @@ def test_reshape_infer_type():
         (n, t, 2000), "float32")
 
 def test_reshape():
-    def verify_reshape(shape, oshape):
-        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
-        ref_res = np.reshape(x_data, oshape)
-
+    def verify_reshape(shape, newshape, oshape):
         x = relay.var("x", relay.TensorType(shape, "float32"))
-        z = relay.reshape(x, newshape=ref_res.shape)
+        z = relay.reshape(x, newshape=newshape)
         zz = relay.ir_pass.infer_type(z)
         assert "newshape=" in z.astext()
         assert zz.checked_type == relay.ty.TensorType(oshape, "float32")
 
         func = relay.Function([x], z)
-
+        x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32")
+        ref_res = np.reshape(x_data, oshape)
         for target, ctx in ctx_list():
             for kind in ["graph", "debug"]:
                 intrp = relay.create_executor(kind, ctx=ctx, target=target)
                 op_res = intrp.evaluate(func)(x_data)
                 tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5)
-    verify_reshape((2, 3, 4), (8, 3))
-    verify_reshape((4, 7), (2, 7, 2))
+    verify_reshape((2, 3, 4), (8, 3), (8, 3))
+    verify_reshape((4, 7), (2, 7, 2), (2, 7, 2))
+    verify_reshape((2, 3, 4), (4, 0, 2), (4, 3, 2))
+    verify_reshape((2, 3, 4), (2, 0, 0), (2, 3, 4))
+    verify_reshape((2, 3, 4), (0, -1), (2, 12))
+    verify_reshape((2, 3, 4), (-1, 0), (8, 3))
+    verify_reshape((2, 3, 4), (2, -2), (2, 3, 4))
+    verify_reshape((2, 3, 4), (-2, 1, 1), (2, 3, 4, 1, 1))
+    verify_reshape((2, 3, 4), (-3, 4), (6, 4))
+    verify_reshape((2, 3, 4, 5), (-3, -3), (6, 20))
+    verify_reshape((2, 3, 4), (0, -3), (2, 12))
+    verify_reshape((2, 3, 4), (-3, -2), (6, 4))
+    verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
+    verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
+
 
 def test_reshape_like_infer_type():
     # concrete shape
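
For reference, the right-to-left inference that this patch implements can be checked end to end with a short script. This is a minimal sketch, assuming the patch above is applied; the API calls (`relay.var`, `relay.reshape`, `relay.reverse_reshape`, `relay.ir_pass.infer_type`) mirror those used in the tests in this diff::

    import tvm
    from tvm import relay

    # The docstring example: the same newshape, applied in both directions.
    x = relay.var("x", relay.TensorType((10, 5, 4), "float32"))

    # reshape consumes special values left to right:
    # -1 infers dim 0, then 0 copies data.shape[1] = 5, so -1 becomes 200 / 5 = 40.
    fwd = relay.ir_pass.infer_type(relay.reshape(x, newshape=(-1, 0)))

    # reverse_reshape reverses both shapes first: data (4, 5, 10), newshape (0, -1).
    # 0 copies the 4, -1 infers 200 / 4 = 50; un-reversing gives (50, 4).
    rev = relay.ir_pass.infer_type(relay.reverse_reshape(x, newshape=(-1, 0)))

    print(fwd.checked_type)  # Tensor[(40, 5), float32]
    print(rev.checked_type)  # Tensor[(50, 4), float32]

This is exactly what `ReshapeRel` does when `param->reverse` is set: it reverses `data->shape` and `newshape`, runs the ordinary left-to-right special-value handling, and reverses the resulting `oshape` before assigning the output type.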