From d4a1be891815971e523aa6fd0fa35f8e89731af5 Mon Sep 17 00:00:00 2001
From: Chaosfan
Date: Mon, 9 Jan 2023 06:24:44 +0800
Subject: [PATCH] [Op][Manip] collapse_sum_like, collapse_sum_to (#87)

* collapse_sum_like & collapse_sum_to
* address comments
* test coverage
* fix some tests
* collapse_sum_to tgt shape var test
* format
* format
* reformat
* reformat
* type: ignore
* handle shape var cases and add regression tests
---
 python/tvm/relax/op/manipulate.py                |  53 +++
 python/tvm/script/ir_builder/relax/ir.py         |   4 +
 src/relax/op/tensor/manipulate.cc                | 130 ++++++
 src/relax/op/tensor/manipulate.h                 |  21 ++
 tests/python/relax/test_op_manipulate.py         | 323 +++++++++++++++
 .../test_tvmscript_parser_op_manipulate.py       |  33 ++
 6 files changed, 564 insertions(+)

diff --git a/python/tvm/relax/op/manipulate.py b/python/tvm/relax/op/manipulate.py
index 01426e4129..6facae2ff9 100644
--- a/python/tvm/relax/op/manipulate.py
+++ b/python/tvm/relax/op/manipulate.py
@@ -220,3 +220,56 @@ def squeeze(x: Expr, axis: Optional[Union[int, List[int]]] = None) -> Expr:
     if isinstance(axis, int):
         axis = [axis]
     return _ffi_api.squeeze(x, axis)  # type: ignore
+
+
+def collapse_sum_like(data: Expr, collapse_target: Expr) -> Expr:
+    """Return a summation of data to the shape of collapse_target.
+
+    For details, please see relax.op.collapse_sum_to.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    collapse_target : relax.Expr
+        The tensor whose shape is the shape to collapse to.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor after summation.
+    """
+    return _ffi_api.collapse_sum_like(data, collapse_target)  # type: ignore
+
+
+def collapse_sum_to(data: Expr, shape: Union[Tuple[PrimExprLike], Expr]) -> Expr:
+    """Return a summation of data to the given shape.
+
+    collapse_sum_to is intended as the backward operator of tvm.relax.op.broadcast_to and
+    other broadcast operators in the automatic differentiation process.
+
+    We expect that data is the result of broadcasting some tensor of the given shape in some
+    broadcast operation. Thus the given `shape` and `data.shape` must follow broadcast rules.
+
+    During computation, all axes of `data.shape` and `shape` are checked from right to left,
+    and `data` is summed over an axis if one of the following rules holds:
+    - the axis exists in `data.shape` but not in `shape`, or
+    - the axis exists in `data.shape` and its extent in `shape` is 1.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    shape : Union[Tuple[PrimExprLike], relax.Expr]
+        The shape to collapse to.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor of the given shape after summation.
+    """
+    if isinstance(shape, (tuple, list)):
+        shape = ShapeExpr(shape)
+    return _ffi_api.collapse_sum_to(data, shape)  # type: ignore
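[Editor's note, not part of the patch] The docstring above defines the collapse
rule operationally. A minimal NumPy sketch of the same semantics, assuming a
hypothetical helper name collapse_sum_ref (not TVM API):

    # Illustrative only: a NumPy reference for the two right-to-left axis rules.
    import numpy as np

    def collapse_sum_ref(data, shape):
        # Rule 1: axes present in data.shape but absent from shape are summed away.
        lead = data.ndim - len(shape)
        out = data.sum(axis=tuple(range(lead)))
        # Rule 2: axes whose extent in shape is 1 are summed with keepdims.
        for ax, dim in enumerate(shape):
            if dim == 1 and out.shape[ax] != 1:
                out = out.sum(axis=ax, keepdims=True)
        return out

    g = np.ones((2, 3, 4), dtype="float32")  # e.g. an upstream gradient of broadcast_to
    assert collapse_sum_ref(g, (3, 1)).shape == (3, 1)
    assert np.allclose(collapse_sum_ref(g, (1, 4)), np.full((1, 4), 6.0))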
+ """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.collapse_sum_to(data, shape) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 3064b3fa6d..43ec508545 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -38,6 +38,8 @@ null_value, call_tir, call_builtin, + collapse_sum_like, + collapse_sum_to, concat, cos, divide, @@ -400,6 +402,8 @@ def tuple(*fields: List[Expr]) -> Expr: "call_tir", "call_builtin", "cos", + "collapse_sum_like", + "collapse_sum_to", "null_value", "concat", "const", diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 7045bd5fd8..654050a6ff 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -799,5 +799,135 @@ TVM_REGISTER_OP("relax.squeeze") .add_argument("x", "Tensor", "The input tensor.") .set_attr("FInferStructInfo", InferStructInfoSqueeze); +void CheckCollapseShape(const Call& call, const BlockBuilder& ctx, + const Array& data_shape, const Array& target_shape) { + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + + int data_ndim = data_shape.size(); + int target_ndim = target_shape.size(); + + int data_ax = data_ndim - 1; + int target_ax = target_ndim - 1; + for (; data_ax >= 0; --data_ax) { + if (target_ax < 0) { + continue; + } + const PrimExpr& dim0 = data_shape[data_ax]; + const PrimExpr& dim1 = target_shape[target_ax]; + const auto* int_dim0 = dim0.as(); + const auto* int_dim1 = dim1.as(); + + if (analyzer->CanProveEqual(dim0, dim1) || (int_dim1 != nullptr && int_dim1->value == 1)) { + --target_ax; + } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { + ctx->ReportFatal(Diagnostic::Error(call) + << "In " << call->op << ", the data shape at dim " << data_ax << " is " + << dim0 << " and the target shape at dim " << target_ax << " is " << dim1 + << ", which do not match the rule of collapse sum."); + } else { + // Todo(relax-team): At this moment, enforcing MatchCast is fine. But we may need to revisit + // this requirement to reduce the workload of importers and better support dynamic shapes. + ctx->ReportFatal(Diagnostic::Error(call) + << call->op + << " fails to match the axes because of unknown dim or symbolic" + " shape. In this position the dim of data shape is " + << dim0 << " while the dim of target shape is " << dim1 + << ". 
diff --git a/src/relax/op/tensor/manipulate.h b/src/relax/op/tensor/manipulate.h
index 7e8a511e9e..bffa0071b2 100644
--- a/src/relax/op/tensor/manipulate.h
+++ b/src/relax/op/tensor/manipulate.h
@@ -102,6 +102,27 @@ Expr split(Expr x, ObjectRef indices_or_sections, int axis);
  */
 Expr squeeze(Expr x, Optional<Array<Integer>> axis);
 
+/*!
+ * \brief Return a summation of data to the shape of collapse_target.
+ * For details, please see the operator `relax.collapse_sum_to`.
+ * \param data The input tensor.
+ * \param collapse_target The tensor whose shape is the shape to collapse to.
+ * \return The result tensor after summation.
+ */
+Expr collapse_sum_like(Expr data, Expr collapse_target);
+
+/*!
+ * \brief Return a summation of data to the given shape.
+ * collapse_sum_to is intended as the backward operator of broadcast_to and
+ * other broadcast operators in the automatic differentiation process.
+ * We expect that data is the result of broadcasting some tensor of the given shape in some
+ * broadcast operation. Thus the given shape and data.shape must follow broadcast rules.
+ * \param data The input tensor.
+ * \param shape The shape to collapse to.
+ * \return The result tensor of the given shape after summation.
+ */
+Expr collapse_sum_to(Expr data, Expr shape);
+
 }  // namespace relax
 }  // namespace tvm
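[Editor's note, not part of the patch] The header comment above positions
collapse_sum_to as the backward of broadcast_to. A hedged TVMScript sketch of
that pairing (illustrative only; assumes R.broadcast_to accepts a tuple shape
the same way R.collapse_sum_to does in the parser tests below):

    from tvm.script import relax as R

    @R.function
    def roundtrip(x: R.Tensor((4, 5), "float32")) -> R.Tensor((4, 5), "float32"):
        y: R.Tensor((3, 4, 5), "float32") = R.broadcast_to(x, (3, 4, 5))
        # In AD, the gradient arriving at broadcast_to is collapsed back to x's shape.
        gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(y, (4, 5))
        return gv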
diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py
index 32bc50bf95..c952a5395d 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -2220,5 +2220,328 @@ def test_broadcast_to_infer_struct_info_wrong_input_type():
         bb.normalize(relax.op.broadcast_to(x1, stgt))
 
 
+def test_collapse_sum_like_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+    y0 = relax.Var("y", R.Tensor((3, 4), "float32"))
+    y1 = relax.Var("y", R.Tensor("float32", ndim=2))
+    y2 = relax.Var("y", R.Tensor("float32"))
+    y3 = relax.Var("y", R.Tensor((3, 4)))
+    y4 = relax.Var("y", R.Tensor(ndim=2))
+    y5 = relax.Var("y", R.Tensor((1, 4)))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y2), relax.TensorStructInfo(dtype="float32", ndim=-1)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y3), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x2, y0), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x2, y4), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x4, y1), relax.TensorStructInfo(dtype="", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x5, y3), relax.TensorStructInfo((3, 4), dtype="")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y5), relax.TensorStructInfo((1, 4), "float32")
+    )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+    y0 = relax.Var("y", R.Tensor((4, a), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+    y1 = relax.Var("y", R.Tensor((1, a + b), "float32"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((4, a), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((1, a + b), "float32")
+    )
+
+
+def test_collapse_sum_like_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    s3 = relax.Var("s", relax.ShapeStructInfo((3, 4)))
+    s4 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s5 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    y0 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+    y1 = relax.Var("y", relax.TensorStructInfo(s4, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(s5, "float32"))
+
+    _check_inference(bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo(s3, "float32"))
+    _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo(s4, "float32"))
+    _check_inference(bb, relax.op.collapse_sum_like(x2, y2), relax.TensorStructInfo(s5, "float32"))
+
+
+def test_collapse_sum_like_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8"))
+    y0 = relax.Var("y", R.Tensor((3, 4), "float16"))
+    y1 = relax.Var("y", R.Tensor((3, 4), "int8"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_like(x0, y0), relax.TensorStructInfo((3, 4), "float16")
+    )
+    _check_inference(bb, relax.op.collapse_sum_like(x1, y1), relax.TensorStructInfo((3, 4), "int8"))
+
+
+def test_collapse_sum_like_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((4, 5)))
+    x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x0, x1))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x2, x0))
+
+
+def test_collapse_sum_like_check_shape_failure():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    y0 = relax.Var("y", R.Tensor((3, 6, 5), "float32"))
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x1 = relax.Var("z", R.Tensor((3, a, 5), "float32"))
+    y1 = relax.Var("w", R.Tensor((3, b, 5), "float32"))
+
+    s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5)))
+    s1 = relax.Var("s", relax.ShapeStructInfo((3, 6, 5)))
+    x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(s1, "float32"))
+
+    s2 = relax.Var("s", relax.ShapeStructInfo((3, a, 5)))
+    s3 = relax.Var("s", relax.ShapeStructInfo((3, b, 5)))
+    x3 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x0, y0))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x1, y1))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x2, y2))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.collapse_sum_like(x3, y3))
+
+
+def test_collapse_sum_to_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4)))
+    x4 = relax.Var("x", R.Tensor(ndim=3))
+    x5 = relax.Var("x", R.Tensor())
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x2, (3, 4)), relax.TensorStructInfo((3, 4), "float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x3, (3, 4)), relax.TensorStructInfo((3, 4), ""))
+    _check_inference(bb, relax.op.collapse_sum_to(x4, (3, 4)), relax.TensorStructInfo((3, 4), ""))
+    _check_inference(bb, relax.op.collapse_sum_to(x5, (3, 4)), relax.TensorStructInfo((3, 4), ""))
+
+
+def test_collapse_sum_to_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    x0 = relax.Var("x", R.Tensor((3, 4, a), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32"))
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32")
+    )
"float32")) + x1 = relax.Var("x", R.Tensor((3, 4, b + a), "float32")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (4, a)), relax.TensorStructInfo((4, a), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (1, a + b)), relax.TensorStructInfo((1, a + b), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "float32") + ) + + +def test_collapse_sum_to_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + + _check_inference( + bb, relax.op.collapse_sum_to(x0, (3, 4)), relax.TensorStructInfo((3, 4), "float16") + ) + _check_inference( + bb, relax.op.collapse_sum_to(x1, (3, 4)), relax.TensorStructInfo((3, 4), "int8") + ) + + +def test_collapse_sum_to_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + x1 = relax.Var("x", relax.ShapeStructInfo((4, 5))) + x2 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x0)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, x2)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, x1)) + + +def test_collapse_sum_to_check_shape_failure(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + x1 = relax.Var("x", R.Tensor((3, a, 5), "float32")) + + s0 = relax.Var("s", relax.ShapeStructInfo((3, 4, 5))) + x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + + s1 = relax.Var("s", relax.ShapeStructInfo((3, a, 5))) + x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x0, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x1, (3, b, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x2, (4, 4, 5))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.collapse_sum_to(x3, (3, b, 5))) + + +def test_collapse_sum_to_struct_info_tgt_shape_var(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + d = tir.Var("d", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((3, a, b))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((3, a, b), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + stgt0 = relax.Var("stgt", 
+
+
+def test_collapse_sum_to_struct_info_tgt_shape_var():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d = tir.Var("d", "int64")
+    s0 = relax.Var("s", relax.ShapeStructInfo((3, a, b)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", R.Tensor((3, a, b), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor(""))
+    x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    stgt0 = relax.Var("stgt", relax.ShapeStructInfo((a, b)))
+    stgt1 = relax.Var("stgt", relax.ShapeStructInfo(ndim=2))
+    stgt2 = relax.Var("stgt", relax.ShapeStructInfo())
+
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt0), relax.TensorStructInfo(stgt0, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt0), relax.TensorStructInfo(stgt0, "float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt0), relax.TensorStructInfo(stgt0, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt0), relax.TensorStructInfo(stgt0, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt0), relax.TensorStructInfo(stgt0, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt0), relax.TensorStructInfo(stgt0, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt1), relax.TensorStructInfo(stgt1, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt1), relax.TensorStructInfo(stgt1, "float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt1), relax.TensorStructInfo(stgt1, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt1), relax.TensorStructInfo(stgt1, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt1), relax.TensorStructInfo(stgt1, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt1), relax.TensorStructInfo(stgt1, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x0, stgt2), relax.TensorStructInfo(stgt2, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x1, stgt2), relax.TensorStructInfo(stgt2, "float32")
+    )
+    _check_inference(bb, relax.op.collapse_sum_to(x2, stgt2), relax.TensorStructInfo(stgt2, ""))
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x3, stgt2), relax.TensorStructInfo(stgt2, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x4, stgt2), relax.TensorStructInfo(stgt2, "float32")
+    )
+    _check_inference(
+        bb, relax.op.collapse_sum_to(x5, stgt2), relax.TensorStructInfo(stgt2, "float32")
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index 68dcc0c820..7527462b7d 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -273,5 +273,38 @@ def foo(x: R.Tensor((2, 1, 3, 1, 1, 4), "float32")) -> R.Tensor((2, 3, 1, 4), "f
     _check(foo, bb.get()["foo"])
 
 
+def test_collapse_sum_like():
+    @R.function
+    def foo(
+        x: R.Tensor((3, 4, 5), "float32"), y: R.Tensor((4, 5), "float32")
+    ) -> R.Tensor((4, 5), "float32"):
+        gv: R.Tensor((4, 5), "float32") = R.collapse_sum_like(x, y)
+        return gv
+
+    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    y = relax.Var("y", R.Tensor((4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, y]):
+        gv = bb.emit(relax.op.collapse_sum_like(x, y))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_collapse_sum_to():
+    @R.function
+    def foo(x: R.Tensor((3, 4, 5), "float32")) -> R.Tensor((4, 5), "float32"):
+        gv: R.Tensor((4, 5), "float32") = R.collapse_sum_to(x, (4, 5))
+        return gv
+
+    x = relax.Var("x", R.Tensor((3, 4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.collapse_sum_to(x, (4, 5)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
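[Editor's note, not part of the patch] A quick way to see the struct-info
inference from manipulate.cc in action, mirroring the _check_inference tests
above (a minimal sketch, assuming a TVM build that includes this patch):

    # Illustrative only.
    import tvm
    from tvm import relax
    from tvm.script import relax as R

    bb = relax.BlockBuilder()
    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
    # normalize() runs InferStructInfoCollapseSumTo and attaches the result.
    expr = bb.normalize(relax.op.collapse_sum_to(x, (3, 4)))
    tvm.ir.assert_structural_equal(
        expr.struct_info, relax.TensorStructInfo((3, 4), "float32")
    )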