diff --git a/include/tvm/relax/attrs/index.h b/include/tvm/relax/attrs/index.h new file mode 100644 index 000000000000..c95395a80376 --- /dev/null +++ b/include/tvm/relax/attrs/index.h @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/index.h + * \brief Attributes for indexing operators. + */ +#ifndef TVM_RELAX_ATTRS_INDEX_H_ +#define TVM_RELAX_ATTRS_INDEX_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in take operator */ +struct TakeAttrs : public tvm::AttrsNode { + Optional axis; + + TVM_DECLARE_ATTRS(TakeAttrs, "relax.attrs.TakeAttrs") { + TVM_ATTR_FIELD(axis).describe("The axis over which to select values."); + } +}; // struct TakeAttrs + +/*! \brief Attributes used in strided_slice operator */ +struct StridedSliceAttrs : public tvm::AttrsNode { + Array axes; + Array begin; + Array end; + Optional> strides; + + TVM_DECLARE_ATTRS(StridedSliceAttrs, "relax.attrs.StridedSliceAttrs") { + TVM_ATTR_FIELD(axes).describe("Axes along which slicing is applied."); + TVM_ATTR_FIELD(begin).describe("The indices to begin with in the slicing, inclusive."); + TVM_ATTR_FIELD(end).describe("The indices indicating end of the slice, exclusive."); + TVM_ATTR_FIELD(strides).describe( + "Specifies the stride values, it can be negative in that case, the input tensor will be " + "reversed in that particular axis. If not specified, it by default is an list of ones of " + "the same length as `axes`."); + } +}; // struct StridedSliceAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_INDEX_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 9a131cdf957f..3393a5dcae67 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,8 @@ # Operators from .base import * from .binary import * +from .index import * from .manipulate import * +from .op_attrs import * from . import builtin from . import memory diff --git a/python/tvm/relax/op/index.py b/python/tvm/relax/op/index.py new file mode 100644 index 000000000000..2a7afa5ba0f9 --- /dev/null +++ b/python/tvm/relax/op/index.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Indexing operators.""" +from typing import List, Optional, Union + +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr + +PrimExprLike = Union[int, PrimExpr] + + +def take(x: Expr, indices: Expr, axis: Optional[int] = None) -> Expr: + """Take elements from a tensor along an axis. + + Parameters + ---------- + x : relax.Expr + The source tensor. + + indices : relax.Expr + The indices of the values to extract. + It is required to be a one-dimensional tensor which has integer dtype. + + axis : Optional[int] + The axis over which to select values. + If it is none, the input tensor is required to be one-dimensional. + + Returns + ------- + ret : relax.Expr + The taken result. + """ + return _ffi_api.take(x, indices, axis) # type: ignore + + +def strided_slice( + x: Expr, + axes: List[int], + begin: List[PrimExprLike], + end: List[PrimExprLike], + strides: Optional[List[PrimExprLike]] = None, +) -> Expr: + """Strided slice of a tensor. + + Parameters + ---------- + x : relax.Expr + The source tensor to be sliced. + + axes : List[int] + Axes along which slicing is applied. + + begin : List[PrimExprLike] + The indices to begin with in the slicing, inclusive. + + end : List[PrimExprLike] + The indices indicating end of the slice, exclusive. + + strides : Optional[List[PrimExprLike]] + Specifies the stride values, it can be negative in that case, + the input tensor will be reversed in that particular axis. + If not specified, it by default is an list of ones of the same length as `axes`. + + Returns + ------- + ret : relax.Expr + The sliced result. + + Note + ---- + strided_slice require the input `begin`, `end` and `strides` to have the + same length as `axes`. + """ + return _ffi_api.strided_slice(x, axes, begin, end, strides) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py new file mode 100644 index 000000000000..44cb2cf3a5b4 --- /dev/null +++ b/python/tvm/relax/op/op_attrs.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""The attributes node used for Relax operators""" +from tvm.ir import Attrs +import tvm._ffi + + +@tvm._ffi.register_object("relax.attrs.TakeAttrs") +class TakeAttrs(Attrs): + """Attributes used in take operator""" + + +@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs") +class StridedSliceAttrs(Attrs): + """Attributes used in strided_slice operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 0e6595cb4514..75a00ea04985 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -42,6 +42,8 @@ print, reshape, shape_of, + strided_slice, + take, ) from tvm.relax.struct_info import StructInfo from tvm.relax.utils import args_converter @@ -427,5 +429,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "shape", "shape_of", "str", + "strided_slice", + "take", "tuple", ] diff --git a/src/relax/op/tensor/index.cc b/src/relax/op/tensor/index.cc new file mode 100644 index 000000000000..246abef9084b --- /dev/null +++ b/src/relax/op/tensor/index.cc @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.cc + * \brief indexing operators. + */ + +#include "index.h" + +#include +#include + +namespace tvm { +namespace relax { + +/* relax.take */ +TVM_REGISTER_NODE_TYPE(TakeAttrs); + +Expr take(Expr x, Expr indices, Optional axis) { + ObjectPtr attrs = make_object(); + attrs->axis = std::move(axis); + + static const Op& op = Op::Get("relax.take"); + return Call(op, {std::move(x), std::move(indices)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.take").set_body_typed(take); + +StructInfo InferStructInfoTake(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo indices_sinfo = input_sinfo[1]; + if (indices_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to be 1-dimensional tensor. However, " + "the given indices ndim is " + << indices_sinfo->ndim); + } else if (!indices_sinfo->IsUnknownDtype() && + !(indices_sinfo->dtype.is_int() || indices_sinfo->dtype.is_uint())) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op requires the input indices to have integer dtype. However, the " + "given indices dtype is " + << indices_sinfo->dtype); + } + + const auto* attrs = call->attrs.as(); + if (!attrs->axis.defined() && data_sinfo->ndim != 1) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Take op expects the input data to be 1-dimensional tensor when the axis " + "is not specified. However, the given data tensor has ndim " + << data_sinfo->ndim); + } + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + int axis = attrs->axis.defined() + ? NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis.value()->value) + : 0; + const auto* data_shape = data_sinfo->shape.as(); + const auto* indices_shape = indices_sinfo->shape.as(); + if (data_shape == nullptr || indices_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + Array output_shape = data_shape->values; + output_shape.Set(axis, indices_shape->values[0]); + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.take") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The source tensor.") + .add_argument("indices", "Tensor", "The indices of the values to extract.") + .set_attr("FInferStructInfo", InferStructInfoTake); + +/* relax.strided_slice */ +TVM_REGISTER_NODE_TYPE(StridedSliceAttrs); + +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides) { + int n_axis = axes.size(); + CHECK_EQ(static_cast(begin.size()), n_axis) + << "StridedSlice requires the number of begin indices to equal the number of axes."; + CHECK_EQ(static_cast(end.size()), n_axis) + << "StridedSlice requires the number of end indices to equal the number of axes."; + if (strides.defined()) { + CHECK_EQ(static_cast(strides.value().size()), n_axis) + << "StridedSlice requires the number of strides to equal the number of axes."; + } + + // Todo(relax-team): We are going to support dynamic strided slice, where + // begin/end/stride can be not static at compile time. Therefore, begin/end/stride + // should not be part of StridedSliceAttrs, as we only allow static values to + // reside in attributes. However, using ShapeExpr to represent these + // arrays is not conceptually right, because they are not describing a + // concrete shape. The proper way to support dynamic strided slice is to use + // Tuple of PrimValue to represent begin/end/stride. Since at this moment + // we have no support for PrimValue, we store begin/end/stride as attribute + // fields as a workaround. + // Will switch to Tuple of PrimValue after introducing PrimValue. + auto f_convert_to_int64 = [](const PrimExpr& value) { + if (value->IsInstance()) { + return cast(DataType::Int(64), value); + } + CHECK(value.dtype() == DataType::Int(64)) << "strided_slice expects the input begin/end/stride " + "values to be all int64. However, the given " + << value << " has dtype " << value->dtype; + return value; + }; + + ObjectPtr attrs = make_object(); + attrs->axes = std::move(axes); + attrs->begin = begin.Map(f_convert_to_int64); + attrs->end = end.Map(f_convert_to_int64); + attrs->strides = strides.defined() ? strides.value().Map(f_convert_to_int64) : strides; + + static const Op& op = Op::Get("relax.strided_slice"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.strided_slice").set_body_typed(strided_slice); + +StructInfo InferStructInfoStridedSlice(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->axes.empty()) { + return data_sinfo; + } + + if (data_sinfo->IsUnknownNdim()) { + return TensorStructInfo(data_sinfo->dtype, kUnknownNDim); + } + + std::vector axes = NormalizeAxes(call, ctx, data_sinfo->ndim, attrs->axes); + const auto* data_shape = data_sinfo->shape.as(); + if (data_shape == nullptr) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + + int n_axis = axes.size(); + Array strides = attrs->strides.defined() + ? attrs->strides.value() + : Array(n_axis, IntImm(DataType::Int(64), 1)); + std::vector int_strides; + int_strides.reserve(n_axis); + // Only do output shape inference when all the begin/end/stride values are integers. + for (int i = 0; i < n_axis; ++i) { + const auto* int_begin = attrs->begin[i].as(); + const auto* int_end = attrs->end[i].as(); + const auto* int_stride = strides[i].as(); + if (!int_begin || !int_end || !int_stride) { + return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim); + } + int_strides.push_back(int_stride->value); + } + + Array output_shape = data_shape->values; + for (int i = 0; i < n_axis; ++i) { + PrimExpr len = int_strides[i] < 0 ? ceildiv(attrs->begin[i] - attrs->end[i], -int_strides[i]) + : ceildiv(attrs->end[i] - attrs->begin[i], int_strides[i]); + output_shape.Set(axes[i], len); + } + return TensorStructInfo(ShapeExpr(output_shape), data_sinfo->dtype); +} + +TVM_REGISTER_OP("relax.strided_slice") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The source tensor to be sliced.") + .set_attr("FInferStructInfo", InferStructInfoStridedSlice); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/index.h b/src/relax/op/tensor/index.h new file mode 100644 index 000000000000..6944493a0fd6 --- /dev/null +++ b/src/relax/op/tensor/index.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file index.h + * \brief The functions to make Relax tensor indexing operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_INDEX_H_ +#define TVM_RELAX_OP_TENSOR_INDEX_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Take elements from a tensor along an axis. + * \param x The source tensor. + * \param indices The indices of the values to extract. + * It is required to be a one-dimensional tensor which has integer dtype. + * \param axis The axis over which to select values. + * If it is `NullOpt`, the input tensor is required to be one-dimensional. + * \return The taken result. + */ +Expr take(Expr x, Expr indices, Optional axis); + +/*! + * \brief Strided slice of a tensor. + * \param x The source tensor to be sliced. + * \param axes Axes along which slicing is applied. + * \param begin The indices to begin with in the slicing, inclusive. + * \param end The indices indicating end of the slice, exclusive. + * \param strides Specifies the stride values, it can be negative in that case, + * the input tensor will be reversed in that particular axis. + * If it is `NullOpt`, it by default is an list of ones of the same length as `axes`. + * \return The sliced result + */ +Expr strided_slice(Expr x, // + Array axes, // + Array begin, // + Array end, // + Optional> strides); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_INDEX_H_ diff --git a/tests/python/relax/test_op_index.py b/tests/python/relax/test_op_index.py new file mode 100644 index 000000000000..77a04b1a1aab --- /dev/null +++ b/tests/python/relax/test_op_index.py @@ -0,0 +1,593 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + idx = relax.Var("idx", R.Tensor((2,), "float32")) + assert relax.op.take(x, idx, axis=1).op == Op.get("relax.take") + assert relax.op.strided_slice(x, axes=[0], begin=[0], end=[2]).op == Op.get( + "relax.strided_slice" + ) + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_take_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((4, 10))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((10,), "float32")) + y1 = relax.Var("y", R.Tensor("float32", ndim=1)) + y2 = relax.Var("y", R.Tensor((10,))) + y3 = relax.Var("y", R.Tensor(ndim=1)) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + idx2 = relax.Var("idx", R.Tensor((6,))) + idx3 = relax.Var("idx", R.Tensor(ndim=1)) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x0, idx0, axis=-1), relax.TensorStructInfo((4, 6), "float32") + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx0, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx0, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx1, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx1, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float32")) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx2, axis=1), relax.TensorStructInfo((4, 6), dtype="")) + _check_inference(bb, relax.op.take(x4, idx2, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx2, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.take(x0, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx3, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx3, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x3, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x4, idx3, axis=1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.take(x5, idx3, axis=1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx0), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx0), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx1), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx2), relax.TensorStructInfo((6,), "float32")) + _check_inference(bb, relax.op.take(y1, idx2), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx2), relax.TensorStructInfo((6,), dtype="")) + _check_inference(bb, relax.op.take(y3, idx2), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y0, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y1, idx3), relax.TensorStructInfo(dtype="float32", ndim=1)) + _check_inference(bb, relax.op.take(y2, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + _check_inference(bb, relax.op.take(y3, idx3), relax.TensorStructInfo(dtype="", ndim=1)) + + +def test_take_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + i = tir.Var("i", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + y0 = relax.Var("y", R.Tensor((n,), "float32")) + y1 = relax.Var("y", R.Tensor((n,))) + idx0 = relax.Var("idx", R.Tensor((i,), "int64")) + idx1 = relax.Var( + "idx", + R.Tensor( + (i,), + ), + ) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((m, i), "float32")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((m, i), dtype="")) + _check_inference(bb, relax.op.take(y0, idx0), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx0), relax.TensorStructInfo((i,), dtype="")) + _check_inference(bb, relax.op.take(y0, idx1), relax.TensorStructInfo((i,), "float32")) + _check_inference(bb, relax.op.take(y1, idx1), relax.TensorStructInfo((i,), dtype="")) + + +def test_take_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + sx0 = relax.Var("sx", relax.ShapeStructInfo((4, 10))) + sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=2)) + sx2 = relax.Var("sx", relax.ShapeStructInfo()) + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6,))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(ndim=1)) + x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32")) + x3 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx1 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx2 = relax.Var("idx", R.Tensor((6,), "int64")) + + _check_inference( + bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo(dtype="float32")) + _check_inference( + bb, relax.op.take(x3, idx0, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.take(x3, idx1, axis=1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_take_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float16")) + x1 = relax.Var("x", R.Tensor((4, 10), "int16")) + x2 = relax.Var("x", R.Tensor((4, 10), "int32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int32")) + idx1 = relax.Var("idx", R.Tensor((6,), "int8")) + idx2 = relax.Var("idx", R.Tensor((6,), "uint32")) + + _check_inference(bb, relax.op.take(x0, idx0, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx0, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx0, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx1, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx1, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx1, axis=1), relax.TensorStructInfo((4, 6), "int32")) + _check_inference(bb, relax.op.take(x0, idx2, axis=1), relax.TensorStructInfo((4, 6), "float16")) + _check_inference(bb, relax.op.take(x1, idx2, axis=1), relax.TensorStructInfo((4, 6), "int16")) + _check_inference(bb, relax.op.take(x2, idx2, axis=1), relax.TensorStructInfo((4, 6), "int32")) + + +def test_take_infer_struct_info_indices_not_one_dimensional(): + bb = relax.BlockBuilder() + sidx0 = relax.Var("sidx", relax.ShapeStructInfo((6, 6))) + sidx1 = relax.Var("sidx", relax.ShapeStructInfo(())) + sidx2 = relax.Var("sidx", relax.ShapeStructInfo(ndim=2)) + sidx3 = relax.Var("sidx", relax.ShapeStructInfo(ndim=0)) + sidx4 = relax.Var("sidx", relax.ShapeStructInfo()) + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "int64")) + idx1 = relax.Var("idx", R.Tensor((), "int64")) + idx2 = relax.Var("idx", R.Tensor("int64", ndim=2)) + idx3 = relax.Var("idx", R.Tensor("int64", ndim=0)) + idx4 = relax.Var("idx", R.Tensor("int64")) + idx5 = relax.Var("idx", relax.TensorStructInfo(sidx0, "int64")) + idx6 = relax.Var("idx", relax.TensorStructInfo(sidx1, "int64")) + idx7 = relax.Var("idx", relax.TensorStructInfo(sidx2, "int64")) + idx8 = relax.Var("idx", relax.TensorStructInfo(sidx3, "int64")) + idx9 = relax.Var("idx", relax.TensorStructInfo(sidx4, "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx2, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx3, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx4, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx5, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx6, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx7, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx8, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx9, axis=1)) + + +def test_take_infer_struct_info_indices_not_integer_dtype(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", R.Tensor((6, 6), "float32")) + idx1 = relax.Var("idx", R.Tensor((6, 6), "float64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx0, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx1, axis=1)) + + +def test_take_infer_struct_info_multi_dimensional_without_axis(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((4, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + idx0 = relax.Var("idx", R.Tensor((6,), "int64")) + idx1 = relax.Var("idx", R.Tensor("int64", ndim=1)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x2, idx1)) + + +def test_take_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((4, 10), "float32")) + idx = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=-3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x, idx, axis=2)) + + +def test_take_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((4, 10))) + x1 = relax.Var("x", R.Tensor((4, 10), "float32")) + idx0 = relax.Var("idx", relax.ShapeStructInfo((6,))) + idx1 = relax.Var("idx", R.Tensor((6,), "int64")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x0, idx1, axis=1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.take(x1, idx0, axis=1)) + + +def test_strided_slice_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=4)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((8, 9, 10, 10))) + x4 = relax.Var("x", R.Tensor(ndim=4)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x1, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x2, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice( + x3, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo((4, 9, 10, 3), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x4, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype="", ndim=4), + ) + _check_inference( + bb, + relax.op.strided_slice( + x5, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ), + relax.TensorStructInfo(dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice( + x0, axes=[-1, -3, -4], begin=[8, 0, 1], end=[0, 9, 8], strides=[-3, 1, 2] + ), + relax.TensorStructInfo((4, 9, 10, 3), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[1, 2], begin=[1, 0], end=[8, 9]), + relax.TensorStructInfo((8, 7, 9, 10), "float32"), + ) + + +def test_strided_slice_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((2, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo((3, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[3]), + relax.TensorStructInfo((2, n), dtype=""), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[1], end=[8], strides=[3]), + relax.TensorStructInfo((3, n), dtype=""), + ) + + +def test_strided_slice_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((8, 10))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, dtype="")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, dtype="")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, dtype="")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype="", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo(dtype=""), + ) + + +def test_strided_slice_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((8, 9), "float16")) + x1 = relax.Var("x", R.Tensor((8, 9), "int32")) + x2 = relax.Var("x", R.Tensor((8, 9), "int64")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "float16"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[0], begin=[0], end=[8]), + relax.TensorStructInfo((8, 9), "int64"), + ) + + +def test_strided_slice_infer_struct_info_symbolic_begin_end_strides(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[a], end=[8]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[a]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x, axes=[0], begin=[0], end=[8], strides=[a]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + + +def test_strided_slice_infer_struct_info_no_axis(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.Var("s", relax.ShapeStructInfo((m, n))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor(dtype="float32", ndim=2)) + x2 = relax.Var("x", R.Tensor(dtype="float32")) + x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference( + bb, + relax.op.strided_slice(x0, axes=[], begin=[], end=[]), + relax.TensorStructInfo((m, n), "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x1, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32", ndim=2), + ) + _check_inference( + bb, + relax.op.strided_slice(x2, axes=[], begin=[], end=[]), + relax.TensorStructInfo(dtype="float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x3, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s0, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x4, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s1, "float32"), + ) + _check_inference( + bb, + relax.op.strided_slice(x5, axes=[], begin=[], end=[]), + relax.TensorStructInfo(s2, "float32"), + ) + + +def test_strided_slice_begin_end_strides_int64(): + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + strided_slice = relax.op.strided_slice( + x, axes=[0, 1, 3], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + + assert strided_slice.attrs.begin[0].dtype == "int64" + assert strided_slice.attrs.begin[1].dtype == "int64" + assert strided_slice.attrs.begin[2].dtype == "int64" + assert strided_slice.attrs.end[0].dtype == "int64" + assert strided_slice.attrs.end[1].dtype == "int64" + assert strided_slice.attrs.end[2].dtype == "int64" + assert strided_slice.attrs.strides[0].dtype == "int64" + assert strided_slice.attrs.strides[1].dtype == "int64" + assert strided_slice.attrs.strides[2].dtype == "int64" + + +def test_strided_slice_inconsistent_axes_begin_end_strides_length(): + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[], end=[9]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[]) + with pytest.raises(TVMError): + relax.op.strided_slice(x, axes=[1], begin=[0], end=[9], strides=[]) + + +def test_strided_slice_infer_struct_info_repetitive_axes(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, 0], begin=[0, 0], end=[8, 8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[0, -2], begin=[0, 0], end=[8, 8])) + + +def test_strided_slice_infer_struct_info_axis_out_of_range(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[2], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x, axes=[-3], begin=[0], end=[8])) + + +def test_strided_slice_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((8, 9))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((8, 9), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x0, axes=[0], begin=[0], end=[8])) + with pytest.raises(TVMError): + bb.normalize(relax.op.strided_slice(x1, axes=[0], begin=[0], end=[8])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_index.py b/tests/python/relax/test_tvmscript_parser_op_index.py new file mode 100644 index 000000000000..b271d1a7f3bc --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_index.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_take(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), "float32"), indices: R.Tensor((3,), "int64") + ) -> R.Tensor((2, 3, 3), "float32"): + gv: R.Tensor((2, 3, 3), "float32") = R.take(x, indices, axis=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + indices = relax.Var("indices", R.Tensor((3,), "int64")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, indices]): + gv = bb.emit(relax.op.take(x, indices, axis=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_strided_slice(): + @R.function + def foo(x: R.Tensor((8, 9, 10, 10), "float32")) -> R.Tensor((4, 9, 10, 3), "float32"): + gv: R.Tensor((4, 9, 10, 3), "float32") = R.strided_slice( + x, + axes=[0, 1, -1], + begin=[1, 0, 8], + end=[8, 9, 0], + strides=[2, 1, -3], + ) + return gv + + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((8, 9, 10, 10), "float32")) + with bb.function("foo", [x]): + gv = bb.emit( + relax.op.strided_slice( + x, axes=[0, 1, -1], begin=[1, 0, 8], end=[8, 9, 0], strides=[2, 1, -3] + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()