diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 573cde982bea..e1f09233bcff 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -1717,7 +1717,7 @@ def _impl_v1(cls, bb, inputs, attr, params): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, attr.get("axis", 0)) + return relax.op.split(inputs[0], indices, attr.get("axis", 0)) @classmethod def _impl_v13(cls, bb, inputs, attr, params): @@ -1738,7 +1738,7 @@ def _impl_v13(cls, bb, inputs, attr, params): # When splits isnt specified divide evenly over axis. else: indices = attr["tvm_custom"]["num_outputs"] - return bb.emit_te(topi.split, inputs[0], indices, axis=attr.get("axis", 0)) + return relax.op.split(inputs[0], indices, attr.get("axis", 0)) def get_prim_value_list(values): diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py index 55bc2772bcce..c71a41dc1c2d 100644 --- a/python/tvm/relax/transform/legalize_ops/manipulate.py +++ b/python/tvm/relax/transform/legalize_ops/manipulate.py @@ -112,12 +112,13 @@ def _split(bb: BlockBuilder, call: Call) -> Expr: modulo = tvm.arith.Analyzer().simplify( call.args[0].struct_info.shape.values[call.attrs.axis] % indices_or_sections ) - if modulo != 0: - logging.info( - "Split cannot be legalized by TOPI when the axis being split has " - "length that not divisible by the input number of section." - ) - return call + if isinstance(modulo, tir.IntImm): + if modulo != 0: + logging.info( + "Split cannot be legalized by TOPI when the axis being split has " + "length that not divisible by the input number of section." + ) + return call else: indices_or_sections = call.attrs.indices_or_sections return bb.call_te(topi.split, call.args[0], indices_or_sections, call.attrs.axis) diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 452b1f223a80..cb738db363ee 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -864,7 +864,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { auto p_indices = opt_indices.value(); // When there is not index, return the input tensor's struct info. if (p_indices.size() == 0) { - return TupleStructInfo({data_sinfo}); + return data_sinfo; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { @@ -911,7 +911,7 @@ StructInfo InferStructInfoSplit(const Call& call, const BlockBuilder& ctx) { int n_section = p_n_section->value; // When the number of section is one, return the input tensor's struct info. if (n_section == 1) { - return TupleStructInfo({data_sinfo}); + return data_sinfo; } // Fall back to unknown shape when the input tensor doesn't have ShapeExpr as shape. if (data_shape == nullptr) { diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 23ab6780cf7b..28e762d9a4de 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -2108,62 +2108,62 @@ def test_split_infer_struct_info_single_output(): _check_inference( bb, relax.op.split(x0, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + relax.TensorStructInfo((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + relax.TensorStructInfo(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + relax.TensorStructInfo(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + relax.TensorStructInfo(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + relax.TensorStructInfo(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, [], axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + relax.TensorStructInfo(s2, "float32"), ) _check_inference( bb, relax.op.split(x0, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo((a, b), "float32")]), + relax.TensorStructInfo((a, b), "float32"), ) _check_inference( bb, relax.op.split(x1, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32", ndim=2)]), + relax.TensorStructInfo(dtype="float32", ndim=2), ) _check_inference( bb, relax.op.split(x2, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(dtype="float32")]), + relax.TensorStructInfo(dtype="float32"), ) _check_inference( bb, relax.op.split(x3, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s0, "float32")]), + relax.TensorStructInfo(s0, "float32"), ) _check_inference( bb, relax.op.split(x4, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s1, "float32")]), + relax.TensorStructInfo(s1, "float32"), ) _check_inference( bb, relax.op.split(x5, 1, axis=1), - relax.TupleStructInfo([relax.TensorStructInfo(s2, "float32")]), + relax.TensorStructInfo(s2, "float32"), ) @@ -2200,9 +2200,7 @@ def test_split_infer_struct_info(): _check_inference( bb, relax.op.split(x, 1), - R.Tuple( - R.Tensor([16, 4]), - ), + R.Tensor([16, 4]), ) _check_inference( bb,