diff --git a/include/tvm/relax/attrs/create.h b/include/tvm/relax/attrs/create.h new file mode 100644 index 000000000000..6af176a42c9d --- /dev/null +++ b/include/tvm/relax/attrs/create.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/create.h + * \brief Attributes for tensor creation operators. + */ +#ifndef TVM_RELAX_ATTRS_CREATE_H_ +#define TVM_RELAX_ATTRS_CREATE_H_ + +#include + +namespace tvm { +namespace relax { + +/*! \brief Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operators */ +struct InitAttrs : public tvm::AttrsNode { + DataType dtype; + + TVM_DECLARE_ATTRS(InitAttrs, "relax.attrs.InitAttrs") { + TVM_ATTR_FIELD(dtype).describe("The data type of the created tensor."); + } +}; // struct InitAttrs + +/*! \brief Attributes used in tril and triu operator */ +struct TriluAttrs : public tvm::AttrsNode { + int k; + + TVM_DECLARE_ATTRS(TriluAttrs, "relax.attrs.TriluAttrs") { + TVM_ATTR_FIELD(k).describe( + "The number of diagonals above or below the main diagonal to exclude or include."); + } +}; // struct TriluAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_CREATE_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 6c6fffc7c65e..97d08c0946a0 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -20,6 +20,7 @@ # Operators from .base import * from .binary import * +from .create import * from .datatype import * from .index import * from .manipulate import * diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py new file mode 100644 index 000000000000..a6643a8633e4 --- /dev/null +++ b/python/tvm/relax/op/create.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Creation operators.""" +from typing import Optional, Tuple, Union + +from tvm import DataType +from tvm.ir.expr import PrimExpr + +from . import _ffi_api +from ..expr import Expr, ShapeExpr + +PrimExprLike = Union[int, PrimExpr] + + +def full( + shape: Union[Tuple[PrimExprLike], Expr], + fill_value: Expr, + dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Fill array with scalar value. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of fill_value. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full(shape, fill_value, dtype) # type: ignore + + +def full_like(x: Expr, fill_value: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor such that + - its shape is the same as the input data tensor's shape, + - its value is filled with the input scalar fill value. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + fill_value : relax.Expr + The value to fill. Must be a scalar tensor. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.full_like(x, fill_value, dtype) # type: ignore + + +def ones(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all ones, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.ones(shape, dtype) # type: ignore + + +def ones_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all ones, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.ones_like(x, dtype) # type: ignore + + +def zeros(shape: Union[Tuple[PrimExprLike], Expr], dtype: Union[str, DataType]) -> Expr: + """Construct a tensor of all zeros, with the input shape and dtype. + + Parameters + ---------- + shape : Union[Tuple[PrimExprLike], Expr] + The shape of the created tensor. + + dtype : Union[str, DataType] + The data type of the created tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + if isinstance(shape, (tuple, list)): + shape = ShapeExpr(shape) + return _ffi_api.zeros(shape, dtype) # type: ignore + + +def zeros_like(x: Expr, dtype: Optional[Union[str, DataType]] = None) -> Expr: + """Construct a tensor with all zeros, with shape of the input tensor shape. + + Parameters + ---------- + x : relax.Expr + The input tensor, which provides the shape, and dtype + when the `dtype` field is not specified. + + dtype : Optional[Union[str, DataType]] + The data type of the created tensor. + If dtype is not given, it will by default use the dtype of the input tensor. + + Returns + ------- + result : relax.Expr + The result tensor. + """ + return _ffi_api.zeros_like(x, dtype) # type: ignore + + +def tril(x: Expr, k: int = 0) -> Expr: + """Return the lower triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that tril will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal above which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.tril(x, k) # type: ignore + + +def triu(x: Expr, k: int = 0) -> Expr: + """Return the upper triangular part of a matrix or a batch of matrices. + + Parameters + ---------- + x : relax.Expr + The tensor that triu will be applied to. + It is required to have at least two dimensions. + + k : int + The index indicating the diagonal below which to zero elements. + If k = 0, the diagonal is the main diagonal. + If k < 0, the diagonal is below the main diagonal. + If k > 0, the diagonal is above the main diagonal. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + return _ffi_api.triu(x, k) # type: ignore diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index 68f84b3514a9..ac6714d940d3 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -19,6 +19,16 @@ import tvm._ffi +@tvm._ffi.register_object("relax.attrs.InitAttrs") +class InitAttrs(Attrs): + """Attributes used in full/full_like, ones/ones_like, and zeros/zeros_like operator""" + + +@tvm._ffi.register_object("relax.attrs.TriluAttrs") +class TriluAttrs(Attrs): + """Attributes used in tril and triu operator""" + + @tvm._ffi.register_object("relax.attrs.AstypeAttrs") class AstypeAttrs(Attrs): """Attributes used in astype operator""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 1f0e31428c63..118790372a35 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -52,6 +52,8 @@ exp, floor, floor_divide, + full, + full_like, greater, greater_equal, image, @@ -71,6 +73,8 @@ negative, not_equal, null_value, + ones, + ones_like, print, prod, reshape, @@ -92,7 +96,11 @@ take, tan, tanh, + tril, + triu, unique, + zeros, + zeros_like, nn, ) from tvm.relax.struct_info import StructInfo @@ -480,6 +488,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "exp", "floor", "floor_divide", + "full", + "full_like", "func_attr", "func_name", "func_ret_struct_info", @@ -504,6 +514,8 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "negative", "not_equal", "null_value", + "ones", + "ones_like", "output", "prim_value", "print", @@ -528,8 +540,12 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "take", "tan", "tanh", + "tril", + "triu", "tuple", "variance", "unique", + "zeros", + "zeros_like", "nn", ] diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc new file mode 100644 index 000000000000..e8374d198109 --- /dev/null +++ b/src/relax/op/tensor/create.cc @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file create.cc + * \brief Creation operators. + */ + +#include "create.h" + +#include + +namespace tvm { +namespace relax { + +/* Initialization operators */ +TVM_REGISTER_NODE_TYPE(InitAttrs); + +/* relax.full */ +Expr full(ObjectRef shape, Expr fill_value, DataType dtype) { + Expr shape_in_expr{nullptr}; + if (const auto* expr = shape.as()) { + shape_in_expr = GetRef(expr); + } else if (const auto* _array = shape.as()) { + shape_in_expr = ShapeExpr(GetRef>(_array)); + } else { + LOG(FATAL) << "Full only expects the input shape to be either an Expr or an Array of PrimExpr. " + "However, the given one is " + << shape->GetTypeKey(); + } + + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.full"); + return Call(op, {std::move(shape_in_expr), std::move(fill_value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full").set_body_typed(full); + +StructInfo InferStructInfoFull(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 2) { + ctx->ReportFatal(Diagnostic::Error(call) << "Full op should have 2 arguments"); + } + const auto* shape_sinfo = GetStructInfoAs(call->args[0]); + const auto* fill_value_sinfo = GetStructInfoAs(call->args[1]); + if (shape_sinfo == nullptr) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Full requires the input shape to be a Shape. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + if (fill_value_sinfo == nullptr || fill_value_sinfo->ndim != 0) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Full requires the input fill value to be zero rank Tensor. However, the given one is " + << call->args[1]->struct_info_); + } + + const auto* attrs = call->attrs.as(); + DataType out_dtype = attrs->dtype.is_void() ? fill_value_sinfo->dtype : attrs->dtype; + return TensorStructInfo(/*shape=*/call->args[0], out_dtype); +} + +TVM_REGISTER_OP("relax.full") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .add_argument("fill_value", "Tensor", "The scalar tensor, denoting the value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFull); + +/* relax.full_like */ +Expr full_like(Expr x, Expr fill_value, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.full_like"); + return Call(op, {std::move(x), std::move(fill_value)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.full_like").set_body_typed(full_like); + +StructInfo InferStructInfoFullLike(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo data_sinfo = input_sinfo[0]; + TensorStructInfo fill_value_sinfo = input_sinfo[1]; + if (fill_value_sinfo->ndim != 0) { + ctx->ReportFatal(Diagnostic::Error(call) << "FullLike requires the input fill value to be zero " + "rank Tensor. However, the given one has ndim" + << fill_value_sinfo->ndim); + } + + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +TVM_REGISTER_OP("relax.full_like") + .set_attrs_type() + .set_num_inputs(2) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("fill_value", "Tensor", "The scalar value to fill.") + .set_attr("FInferStructInfo", InferStructInfoFullLike); + +// Structure info inference for ones and zeros +StructInfo InferStructInfoOnesZeros(const Call& call, const BlockBuilder& ctx) { + if (call->args.size() != 1) { + ctx->ReportFatal(Diagnostic::Error(call) << "Ones/Zeros should have 1 argument"); + } + + const auto* shape_sinfo = GetStructInfoAs(call->args[0]); + if (shape_sinfo == nullptr) { + ctx->ReportFatal( + Diagnostic::Error(call) + << "Ones/Zeros requires the input shape to be a Shape. However, the given one is " + << call->args[0]->struct_info_->GetTypeKey()); + } + const auto* attrs = call->attrs.as(); + return TensorStructInfo(/*shape=*/call->args[0], attrs->dtype); +} + +// Structure info inference for ones_like and zeros_like +StructInfo InferStructInfoOnesLikeZerosLike(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + const auto* attrs = call->attrs.as(); + if (attrs->dtype.is_void()) { + return data_sinfo; + } else { + auto output_sinfo = make_object(*data_sinfo.get()); + output_sinfo->dtype = attrs->dtype; + return TensorStructInfo(output_sinfo); + } +} + +/* relax.ones & relax.ones_like */ +Expr ones(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Ones op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.ones"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr ones_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.ones_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ones").set_body_typed(ones); +TVM_REGISTER_GLOBAL("relax.op.ones_like").set_body_typed(ones_like); + +TVM_REGISTER_OP("relax.ones") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros); + +TVM_REGISTER_OP("relax.ones_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.zeros & relax.zeros_like */ +Expr zeros(Expr shape, DataType dtype) { + CHECK(!dtype.is_void()) << "Zeros op expects the input dtype not to be void"; + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + + static const Op& op = Op::Get("relax.zeros"); + return Call(op, {std::move(shape)}, Attrs(attrs), {}); +} + +Expr zeros_like(Expr x, DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.zeros_like"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.zeros").set_body_typed(zeros); +TVM_REGISTER_GLOBAL("relax.op.zeros_like").set_body_typed(zeros_like); + +TVM_REGISTER_OP("relax.zeros") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("shape", "Shape", "The shape of the created tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesZeros); + +TVM_REGISTER_OP("relax.zeros_like") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoOnesLikeZerosLike); + +/* relax.tril & relax.triu */ +TVM_REGISTER_NODE_TYPE(TriluAttrs); + +Expr tril(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.tril"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +Expr triu(Expr x, int k) { + ObjectPtr attrs = make_object(); + attrs->k = k; + + static const Op& op = Op::Get("relax.triu"); + return Call(op, {std::move(x)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.tril").set_body_typed(tril); +TVM_REGISTER_GLOBAL("relax.op.triu").set_body_typed(triu); + +StructInfo InferStructInfoTrilTriu(const Call& call, const BlockBuilder& ctx) { + TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx); + if (!data_sinfo->IsUnknownNdim() && data_sinfo->ndim < 2) { + ctx->ReportFatal(Diagnostic::Error(call) << call->op + << " requires the input tensor to have at least two " + "dimensions. However, the given input has " + << data_sinfo->ndim << " dimension(s)."); + } + return data_sinfo; +} + +TVM_REGISTER_OP("relax.tril") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +TVM_REGISTER_OP("relax.triu") + .set_attrs_type() + .set_num_inputs(1) + .add_argument("x", "Tensor", "The input tensor.") + .set_attr("FInferStructInfo", InferStructInfoTrilTriu); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h new file mode 100644 index 000000000000..c1ade470b4e8 --- /dev/null +++ b/src/relax/op/tensor/create.h @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file create.h + * \brief The functions to make Relax tensor-creation operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_CREATE_H_ +#define TVM_RELAX_OP_TENSOR_CREATE_H_ + +#include + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Fill array with scalar value. + * \param shape The shape of the created tensor. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. + * If dtype is not given, it will by default use the dtype of fill_value. + * \return The result tensor. + */ +Expr full(ObjectRef shape, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor such that + * - its shape is the same as the input data tensor's shape, + * - its value is filled with the input scalar fill value. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param fill_value The value to fill. Must be a scalar tensor. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr full_like(Expr x, Expr fill_value, DataType dtype); + +/*! + * \brief Construct a tensor of all ones, with the input shape and dtype. + * \param shape The shape of the created tensor. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr ones(Expr shape, DataType dtype); + +/*! + * \brief Construct a tensor with all ones, with shape of the input tensor shape. + * \param x The input tensor, which provides the shape, and dtype + * when the input dtype is void. + * \param dtype The data type of the created tensor. If it is + * void, the input tensor's dtype will be used. + * \return The result tensor. + */ +Expr ones_like(Expr x, DataType dtype); + +/*! \brief Construct a tensor of all zeros, with the input shape and dtype. */ +Expr zeros(Expr shape, DataType dtype); + +/*! \brief Construct a tensor with all zeros, with shape of the input tensor shape. */ +Expr zeros_like(Expr x, DataType dtype); + +/*! \brief Return the lower triangular part of a matrix or a batch of matrices. */ +Expr tril(Expr x, int k); + +/*! \brief Return the upper triangular part of a matrix or a batch of matrices. */ +Expr triu(Expr x, int k); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_CREATE_H_ diff --git a/tests/python/relax/test_op_create.py b/tests/python/relax/test_op_create.py new file mode 100644 index 000000000000..6dd0a0d15ead --- /dev/null +++ b/tests/python/relax/test_op_create.py @@ -0,0 +1,638 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((3, 4, 5), "float32")) + fill_value = relax.Var("fill_value", R.Tensor((), "float32")) + assert relax.op.full((2, 3), fill_value).op == Op.get("relax.full") + assert relax.op.full_like(x, fill_value).op == Op.get("relax.full_like") + assert relax.op.ones((2, 3), "float32").op == Op.get("relax.ones") + assert relax.op.ones_like(x).op == Op.get("relax.ones_like") + assert relax.op.zeros((2, 3), "float32").op == Op.get("relax.zeros") + assert relax.op.zeros_like(x).op == Op.get("relax.zeros_like") + assert relax.op.tril(x).op == Op.get("relax.tril") + assert relax.op.triu(x).op == Op.get("relax.triu") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_full_infer_struct_info(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full(s0, v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v0, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v0, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v0, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v0), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full(s0, v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v1, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full(s2, v1, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full(s3, v1, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v1), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.full((2, 3), v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v2, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v2, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v2), relax.TensorStructInfo(s1, dtype="")) + _check_inference(bb, relax.op.full(s2, v2, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference(bb, relax.op.full(s2, v2), relax.TensorStructInfo(s2, dtype="")) + _check_inference(bb, relax.op.full(s3, v2, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v2), relax.TensorStructInfo(s3, dtype="")) + _check_inference( + bb, relax.op.full((2, 3), v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full((2, 3), v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference( + bb, relax.op.full(s0, v3, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference(bb, relax.op.full(s0, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full(s1, v3, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference( + bb, + relax.op.full( + s1, + v3, + ), + relax.TensorStructInfo(s1, dtype=""), + ) + _check_inference(bb, relax.op.full(s2, v3, "float16"), relax.TensorStructInfo(s2, "float16")) + _check_inference( + bb, + relax.op.full( + s2, + v3, + ), + relax.TensorStructInfo(s2, dtype=""), + ) + _check_inference(bb, relax.op.full(s3, v3, "float16"), relax.TensorStructInfo(s3, "float16")) + _check_inference(bb, relax.op.full(s3, v3), relax.TensorStructInfo(s3, dtype="")) + + +def test_full_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + s0 = relax.ShapeExpr((a, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((a, 3))) + + _check_inference( + bb, relax.op.full((a, 3), v, "float16"), relax.TensorStructInfo((a, 3), "float16") + ) + _check_inference(bb, relax.op.full((a, 3), v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s0, v, "float16"), relax.TensorStructInfo((a, 3), "float16")) + _check_inference(bb, relax.op.full(s0, v), relax.TensorStructInfo((a, 3), "float32")) + _check_inference(bb, relax.op.full(s1, v, "float16"), relax.TensorStructInfo(s1, "float16")) + _check_inference(bb, relax.op.full(s1, v), relax.TensorStructInfo(s1, "float32")) + + +def test_full_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(())) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v1 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full((2, 3), v1, "float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor((), "int8")) + v2 = relax.Var("v", R.Tensor((), "int32")) + + _check_inference( + bb, relax.op.full((2, 3), v0, "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.full((2, 3), v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference( + bb, relax.op.full((2, 3), v1, "int32"), relax.TensorStructInfo((2, 3), "int32") + ) + _check_inference(bb, relax.op.full((2, 3), v1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2, "int8"), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full((2, 3), v2), relax.TensorStructInfo((2, 3), "int32")) + + +def test_full_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v5)) + + +def test_full_shape_not_tuple(): + m = tir.Var("m", "int64") + v = relax.Var("v", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + relax.op.full(4, v) + with pytest.raises(TVMError): + relax.op.full(m, v) + + +def test_full_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + v0 = relax.Var("v", R.Tensor((), "float32")) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + v2 = relax.Var("v", relax.FuncStructInfo([], R.Tensor((), "float32"))) + s = relax.Var("s", R.Tensor((2, 3))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full(s, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full((2, 3), v2)) + + +def test_full_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + v0 = relax.Var("v", R.Tensor((), "float16")) + v1 = relax.Var("v", R.Tensor("float16", ndim=0)) + v2 = relax.Var("v", R.Tensor(())) + v3 = relax.Var("v", R.Tensor(ndim=0)) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x0, v3), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.full_like(x1, v3), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x2, v3), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v2), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x3, v3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.full_like(x4, v0), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v1), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v2), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x4, v3), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.full_like(x5, v0), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v1), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v2), relax.TensorStructInfo(dtype="")) + _check_inference(bb, relax.op.full_like(x5, v3), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.full_like(x0, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x0, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.full_like(x3, v2, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_full_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + v = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.full_like(x1, v), relax.TensorStructInfo((m, n), dtype="")) + + +def test_full_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x3 = relax.Var("x", R.Tensor((2, 3), "float32")) + sv0 = relax.Var("sv", relax.ShapeStructInfo(())) + sv1 = relax.Var("sv", relax.ShapeStructInfo(ndim=0)) + v0 = relax.Var("v", relax.TensorStructInfo(sv0, "float16")) + v1 = relax.Var("v", relax.TensorStructInfo(sv1, "float16")) + v2 = relax.Var("v", R.Tensor((), "float16")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x0, v2), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x1, v2), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.full_like(x2, v0), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v1), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x2, v2), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.full_like(x3, v0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.full_like(x3, v1), relax.TensorStructInfo((2, 3), "float32")) + + +def test_full_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + v0 = relax.Var("v", R.Tensor((), "int32")) + v1 = relax.Var("v", R.Tensor((), "float64")) + + _check_inference(bb, relax.op.full_like(x0, v0), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x0, v1), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.full_like(x1, v0), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.full_like(x1, v1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_full_like_infer_struct_info_fill_value_not_scalar_tensor(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + s0 = relax.Var("s", relax.ShapeStructInfo((1,))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + v0 = relax.Var("v", R.Tensor((1,), "float32")) + v1 = relax.Var("v", R.Tensor("float32", ndim=1)) + v2 = relax.Var("v", R.Tensor("float32")) + v3 = relax.Var("v", relax.TensorStructInfo(s0, "float32")) + v4 = relax.Var("v", relax.TensorStructInfo(s1, "float32")) + v5 = relax.Var("v", relax.TensorStructInfo(s2, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x, v5)) + + +def test_full_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((), "float32"))) + x2 = relax.Var("x", R.Tensor((2, 3))) + v0 = relax.Var("v", R.Tensor(())) + v1 = relax.Var("v", relax.ShapeStructInfo(())) + + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x0, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x1, v0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.full_like(x2, v1)) + + +def test_ones_zeros_infer_struct_info(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference( + bb, relax.op.ones((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.ones(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.ones(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + _check_inference( + bb, relax.op.zeros((2, 3), "float32"), relax.TensorStructInfo((2, 3), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros(s2, "float32"), relax.TensorStructInfo(s2, "float32")) + _check_inference(bb, relax.op.zeros(s3, "float32"), relax.TensorStructInfo(s3, "float32")) + + +def test_ones_zeros_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + s0 = relax.ShapeExpr((m, n)) + s1 = relax.Var("s", relax.ShapeStructInfo((m, n))) + + _check_inference( + bb, relax.op.ones((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.ones(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.ones(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + _check_inference( + bb, relax.op.zeros((m, n), "float32"), relax.TensorStructInfo((m, n), "float32") + ) + _check_inference(bb, relax.op.zeros(s0, "float32"), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros(s1, "float32"), relax.TensorStructInfo(s1, "float32")) + + +def test_ones_zeros_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + s0 = relax.ShapeExpr((2, 3)) + s1 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s3 = relax.Var("s", relax.ShapeStructInfo()) + + _check_inference(bb, relax.op.ones(s0, "float16"), relax.TensorStructInfo((2, 3), "float16")) + _check_inference(bb, relax.op.ones(s1, "int8"), relax.TensorStructInfo(s1, "int8")) + _check_inference(bb, relax.op.zeros(s2, "int32"), relax.TensorStructInfo(s2, "int32")) + _check_inference(bb, relax.op.zeros(s3, "float64"), relax.TensorStructInfo(s3, "float64")) + + +def test_ones_zeros_shape_not_tuple(): + m = tir.Var("m", "int64") + + with pytest.raises(TVMError): + relax.op.ones(10, "float32") + with pytest.raises(TVMError): + relax.op.zeros(m, "float32") + + +def test_ones_zeros_wrong_dtype(): + with pytest.raises(TypeError): + relax.op.ones((2, 3)) + with pytest.raises(TVMError): + relax.op.ones((2, 3), "") + with pytest.raises(TypeError): + relax.op.zeros((2, 3)) + with pytest.raises(TVMError): + relax.op.zeros((2, 3), "") + + +def test_ones_zeros_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", R.Tensor((2, 3))) + s1 = relax.Var("s", relax.FuncStructInfo([], R.Tensor((2, 3)))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones(s0, "float32")) + with pytest.raises(TVMError): + bb.normalize(relax.op.zeros(s1, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=2)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor(ndim=2)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.ones_like(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.zeros_like(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.ones_like(x4), relax.TensorStructInfo(dtype="", ndim=2)) + _check_inference(bb, relax.op.zeros_like(x5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, relax.op.ones_like(x0, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + _check_inference( + bb, relax.op.zeros_like(x3, dtype="float16"), relax.TensorStructInfo((2, 3), "float16") + ) + + +def test_ones_like_zeros_like_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((m, n))) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((m, n), dtype="")) + + +def test_ones_like_zeros_like_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.zeros_like(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_ones_like_zeros_like_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + + _check_inference(bb, relax.op.ones_like(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.zeros_like(x1), relax.TensorStructInfo((2, 3), "int8")) + + +def test_ones_like_zeros_like_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ones_like(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.zeros_like(x1)) + + +def test_tril_triu_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32")) + x3 = relax.Var("x", R.Tensor((2, 3, 4))) + x4 = relax.Var("x", R.Tensor(ndim=3)) + x5 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.tril(x0, k=1), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x0, k=0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((2, 3, 4), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.triu(x3), relax.TensorStructInfo((2, 3, 4), dtype="")) + _check_inference(bb, relax.op.tril(x4), relax.TensorStructInfo(dtype="", ndim=3)) + _check_inference(bb, relax.op.triu(x5), relax.TensorStructInfo(dtype="")) + + +def test_tril_triu_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((a, b, c), "float32")) + x1 = relax.Var("x", R.Tensor((a, b, c))) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo((a, b, c), "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo((a, b, c), dtype="")) + + +def test_tril_triu_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2, 3, 4))) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + + _check_inference(bb, relax.op.tril(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, relax.op.triu(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.tril(x2), relax.TensorStructInfo(s2, "float32")) + + +def test_tril_triu_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3, 4), "float16")) + x1 = relax.Var("x", R.Tensor((2, 3, 4), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3, 4), "int32")) + + _check_inference(bb, relax.op.triu(x0), relax.TensorStructInfo((2, 3, 4), "float16")) + _check_inference(bb, relax.op.tril(x1), relax.TensorStructInfo((2, 3, 4), "int8")) + _check_inference(bb, relax.op.triu(x2), relax.TensorStructInfo((2, 3, 4), "int32")) + + +def test_tril_triu_infer_struct_info_less_than_two_ndim(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo((2,))) + s1 = relax.Var("s", relax.ShapeStructInfo(())) + s2 = relax.Var("s", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s", relax.ShapeStructInfo(ndim=0)) + x0 = relax.Var("x", R.Tensor((2,), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + x2 = relax.Var("x", R.Tensor("float32", ndim=1)) + x3 = relax.Var("x", R.Tensor("float32", ndim=0)) + x4 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x5 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x6 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x7 = relax.Var("x", relax.TensorStructInfo(s3, "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x2)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x3)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x4)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x5)) + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x6)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x7)) + + +def test_tril_triu_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 4), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(relax.op.tril(x0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.triu(x1)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_create.py b/tests/python/relax/test_tvmscript_parser_op_create.py new file mode 100644 index 000000000000..6cbc0ebf906a --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_create.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_full(): + @R.function + def foo(v: R.Tensor((), "int32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.full((2, 3), v, dtype="float32") + return gv + + bb = relax.BlockBuilder() + v = relax.Var("v", R.Tensor((), "int32")) + with bb.function("foo", [v]): + gv = bb.emit(relax.op.full((2, 3), v, "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_full_like(): + @R.function + def foo( + x: R.Tensor((2, 3), "float16"), v: R.Tensor((), "float32") + ) -> R.Tensor((2, 3), "float16"): + gv: R.Tensor((2, 3), "float16") = R.full_like(x, v) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float16")) + v = relax.Var("y", R.Tensor((), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, v]): + gv = bb.emit(relax.op.full_like(x, v)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.ones((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_ones_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.ones_like(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.ones_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros(): + @R.function + def foo(dumb_param: R.Tensor()) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros((2, 3), "float32") + return gv + + bb = relax.BlockBuilder() + dumb_param = relax.Var("dumb_param", R.Tensor()) + with bb.function("foo", [dumb_param]): + gv = bb.emit(relax.op.zeros((2, 3), "float32")) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_zeros_like(): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = R.zeros_like(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.zeros_like(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_tril(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.tril(x, k=2) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.tril(x, k=2)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_triu(): + @R.function + def foo(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), "float32"): + gv: R.Tensor((2, 3, 4), "float32") = R.triu(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(relax.op.triu(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()