diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 38573512691ca..344576fe13b2d 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -25,6 +25,19 @@ from .manipulate import * from .op_attrs import * from .set import * +from .ternary import * +from .unary import * from . import builtin from . import image from . import memory + + +def _register_op_make(): + # pylint: disable=import-outside-toplevel + from . import _ffi_api + from .. import expr + + expr._op_ffi_api = _ffi_api # type: ignore + + +_register_op_make() diff --git a/python/tvm/relax/op/binary.py b/python/tvm/relax/op/binary.py index eee0b6f3366a8..4042f9bbc9aad 100644 --- a/python/tvm/relax/op/binary.py +++ b/python/tvm/relax/op/binary.py @@ -49,6 +49,42 @@ def add(x1: Expr, x2: Expr) -> Expr: return _ffi_api.add(x1, x2) # type: ignore +def divide(x1: Expr, x2: Expr) -> Expr: + """Division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.divide(x1, x2) # type: ignore + + +def floor_divide(x1: Expr, x2: Expr) -> Expr: + """Floor division with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor_divide(x1, x2) # type: ignore + + def multiply(x1: Expr, x2: Expr) -> Expr: """Multiplication with numpy-style broadcasting. @@ -65,3 +101,132 @@ def multiply(x1: Expr, x2: Expr) -> Expr: The computed result. """ return _ffi_api.multiply(x1, x2) # type: ignore + + +def subtract(x1: Expr, x2: Expr) -> Expr: + """Subtraction with numpy-style broadcasting. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.subtract(x1, x2) # type: ignore + + +###################### Comparison operators ###################### + + +def equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs == rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.equal(x1, x2) # type: ignore + + +def greater(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs > rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater(x1, x2) # type: ignore + + +def greater_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs >= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.greater_equal(x1, x2) # type: ignore + + +def less(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs < rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less(x1, x2) # type: ignore + + +def less_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs <= rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.less_equal(x1, x2) # type: ignore + + +def not_equal(x1: Expr, x2: Expr) -> Expr: + """Broadcasted element-wise test for (lhs != rhs). + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + x2 : relax.Expr + The second input tensor. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.not_equal(x1, x2) # type: ignore diff --git a/python/tvm/relax/op/ternary.py b/python/tvm/relax/op/ternary.py new file mode 100644 index 0000000000000..7c320cc1ca480 --- /dev/null +++ b/python/tvm/relax/op/ternary.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax ternary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr + + +def ewise_fma(x1: Expr, x2: Expr, x3: Expr) -> Expr: + """Elementwise fused multiply-add operator + Returns elementwise result of :math:`x1 * x2 + x3` + + Parameters + ---------- + x1 : relax.Expr + The left hand operand of the multiplication + + x2 : relax.Expr + The right hand operand of the multiplication + + x3 : relax.Expr + The operand of the addition + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.ewise_fma(x1, x2, x3) # type: ignore diff --git a/python/tvm/relax/op/unary.py b/python/tvm/relax/op/unary.py new file mode 100644 index 0000000000000..866d2a8273d6e --- /dev/null +++ b/python/tvm/relax/op/unary.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=redefined-builtin, invalid-name +"""Relax unary arithmetic operators.""" +from . import _ffi_api +from ..expr import Expr +from ..utils import args_converter + +###################### Arithmetic operators ###################### + + +def abs(x: Expr) -> Expr: + """Compute element-wise absolute value of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.abs(x) # type: ignore + + +def acos(x: Expr) -> Expr: + """Compute element-wise arc cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acos(x) # type: ignore + + +def acosh(x: Expr) -> Expr: + """Compute element-wise arc cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.acosh(x) # type: ignore + + +def asin(x: Expr) -> Expr: + """Compute element-wise arc sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asin(x) # type: ignore + + +def asinh(x: Expr) -> Expr: + """Compute element-wise arc sinh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.asinh(x) # type: ignore + + +def atan(x: Expr) -> Expr: + """Compute element-wise arc tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atan(x) # type: ignore + + +def atanh(x: Expr) -> Expr: + """Compute element-wise arc tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.atanh(x) # type: ignore + + +def ceil(x: Expr) -> Expr: + """Take ceil of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.ceil(x) # type: ignore + + +def cos(x: Expr) -> Expr: + """Compute element-wise cos of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cos(x) # type: ignore + + +def cosh(x: Expr) -> Expr: + """Compute element-wise cosh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.cosh(x) # type: ignore + + +def exp(x: Expr) -> Expr: + """Compute element-wise exp of data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.exp(x) # type: ignore + + +def floor(x: Expr) -> Expr: + """Take floor of input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.floor(x) # type: ignore + + +def log(x: Expr) -> Expr: + """Compute element-wise natural logarithm of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.log(x) # type: ignore + + +def negative(x: Expr) -> Expr: + """Compute element-wise negative of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result + """ + return _ffi_api.negative(x) # type: ignore + + +def round(x: Expr) -> Expr: + """Rounds each element of the input data to nearest integer. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.round(x) # type: ignore + + +def sigmoid(x: Expr) -> Expr: + """Compute element-wise sigmoid of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sigmoid(x) # type: ignore + + +def sign(x: Expr) -> Expr: + """Returns an indication of the sign of a number for each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.sign(x) # type: ignore + + +def sin(x: Expr) -> Expr: + """Compute element-wise sin of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sin(x) # type: ignore + + +def sinh(x: Expr) -> Expr: + """Compute element-wise sinh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sinh(x) # type: ignore + + +def square(x: Expr) -> Expr: + """Squares each element of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.square(x) # type: ignore + + +def sqrt(x: Expr) -> Expr: + """Compute element-wise square root of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.sqrt(x) # type: ignore + + +def tan(x: Expr) -> Expr: + """Compute element-wise tan of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tan(x) # type: ignore + + +def tanh(x: Expr) -> Expr: + """Compute element-wise tanh of the input data. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + + Note + ---- + The input tensor is required to have float dtype + """ + return _ffi_api.tanh(x) # type: ignore + + +@args_converter.auto +def clip(x: Expr, min: Expr, max: Expr) -> Expr: + """Clips tensor values to a specified min and max. + + Parameters + ---------- + x : relax.Expr + The input data + + min : relax.Expr + The minimum value + + max : relax.Expr + The maximum value + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.clip(x, min, max) # type: ignore + + +###################### Check operators ###################### + + +def isfinite(x: Expr) -> Expr: + """Check if input value is finite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isfinite(x) # type: ignore + + +def isinf(x: Expr) -> Expr: + """Check if input value is infinite. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isinf(x) # type: ignore + + +def isnan(x: Expr) -> Expr: + """Check if input value is Nan. + + Parameters + ---------- + x : relax.Expr + The input data + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.isnan(x) # type: ignore diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 22b85f6f402f0..a5cb574a06f00 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -29,23 +29,60 @@ ############################### Operators ############################### from tvm.relax.op import ( + abs, + acos, + acosh, + asin, + asinh, + atan, + atanh, add, assert_op, astype, builtin, call_builtin_with_ctx, call_tir, + ceil, + clip, + cos, + cosh, + divide, + equal, + ewise_fma, + exp, + floor, + floor_divide, + greater, + greater_equal, image, invoke_closure, + isfinite, + isinf, + isnan, + less, + less_equal, + log, make_closure, memory, multiply, + negative, + not_equal, null_value, print, reshape, + round, shape_of, + sigmoid, + sign, + sin, + sinh, + square, + sqrt, strided_slice, + subtract, take, + tan, + tanh, unique, ) from tvm.relax.struct_info import StructInfo @@ -403,6 +440,13 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "If", "Then", "TupleGetItem", + "abs", + "acos", + "acosh", + "asin", + "asinh", + "atan", + "atanh", "add", "arg", "assert_op", @@ -411,31 +455,61 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "call_packed", "call_tir", "call_builtin_with_ctx", + "ceil", + "clip", + "cos", + "cosh", "const", "dataflow", + "divide", "dtype", "emit", "emit_match_cast", + "equal", + "ewise_fma", + "exp", + "floor", + "floor_divide", "func_attr", "func_name", "func_ret_struct_info", "func_ret_value", "function", + "greater", + "greater_equal", "image", "invoke_closure", + "isfinite", + "isinf", + "isnan", + "less", + "less_equal", + "log", "make_closure", "memory", "multiply", + "negative", + "not_equal", "null_value", "output", "prim_value", "print", "reshape", + "round", "shape", "shape_of", + "sigmoid", + "sign", + "sin", + "sinh", + "square", + "sqrt", "str", "strided_slice", + "subtract", "take", + "tan", + "tanh", "tuple", "unique", ] diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index ba167a45bc68a..f478871e218f1 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -62,7 +62,7 @@ StructInfo ReturnShapeStructInfo(const Call& call, const BlockBuilder& ctx) { StructInfo InferStructInfoCallTIR(const Call& call, const BlockBuilder& ctx) { if (call->sinfo_args.size() != 1) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "sinfo_args should have exact 1 output struct info."); } return call->sinfo_args[0]; diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc index 260f71e7bfb6d..c82c325d9ba73 100644 --- a/src/relax/op/op_common.cc +++ b/src/relax/op/op_common.cc @@ -28,7 +28,7 @@ Array GetInputTensorStructInfo(const Call& call, const BlockBu Op op = Downcast(call->op); int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); } Array input_tensor_sinfo; @@ -36,7 +36,7 @@ Array GetInputTensorStructInfo(const Call& call, const BlockBu for (int i = 0; i < n_input; ++i) { const auto* sinfo = GetStructInfoAs(call->args[i]); if (sinfo == nullptr) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " requires the input " << op->arguments[i]->name << " to be Tensor. However, the given one is " << call->args[i]->struct_info_->GetTypeKey()); @@ -70,7 +70,7 @@ Optional> InferBinaryBroadcastShape(const Call& call, const Bloc } else if (analyzer->CanProveEqual(dim0, dim1)) { output_shape.push_back(dim0); } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the first input shape at dim " << x1_ndim - i << " is " << dim0 << " and the second input shape at dim " << x2_ndim - i << " is " << dim1 << ", which are not broadcastable."); @@ -96,17 +96,16 @@ std::vector NormalizeAxes(const Call& call, const BlockBuilder& ctx, int nd for (const Integer& axis : axes) { int _axis = axis->value; if (_axis < -ndim || _axis >= ndim) { - ctx->ReportFatal(Diagnostic::Error(call->span) - << "In " << call->op << ", the input axis " << _axis - << " is out of range. The input tensor has " << ndim - << " dimensions, so axis should be in range [" << -ndim << ", " << ndim - << ")."); + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axis " << _axis + << " is out of range. The input tensor has " << ndim + << " dimensions, so axis should be in range [" + << -ndim << ", " << ndim << ")."); } else if (_axis < 0) { _axis = ndim + _axis; } if (appeared_dims_set[_axis]) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", the input axes is required to be non-repetitive. However, there are " "multiple given axes referring to axis " diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index c6d335b2a1bd4..29e02946c6d17 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -104,7 +104,7 @@ inline StructInfo InferStructInfoUnary(const Call& call, const BlockBuilder& ctx TensorStructInfo input_sinfo = GetUnaryInputTensorStructInfo(call, ctx); if (require_float_dtype && !input_sinfo->IsUnknownDtype() && !input_sinfo->dtype.is_float()) { ctx->ReportFatal( - Diagnostic::Error(call->span) + Diagnostic::Error(call) << call->op << " requires the input tensor to have float dtype. However, the given input dtype is " << input_sinfo->dtype); @@ -126,11 +126,11 @@ StructInfo ReturnStructInfoFromArg(const Call& call, const BlockBuilder& ctx) { Op op = Downcast(call->op); int n_input = op->arguments.size(); if (static_cast(call->args.size()) != n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op should have " << n_input << " arguments"); } if (arg_index >= n_input) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << op << " op has only " << n_input << "arguments, but try to get the arg with index " << arg_index); } @@ -151,8 +151,6 @@ StructInfo InferStructInfoUnaryArith(const Call& call, const BlockBuilder& ctx) call, ctx, [](const TensorStructInfo& input_sinfo) { return input_sinfo->dtype; }); } -/************ Utilities ************/ - /*! * \brief Infer the output datatype for binary arithmetic operators. * \param call The context Call to the operator. @@ -168,7 +166,7 @@ inline DataType InferBinaryArithOpOutDtype(const Call& call, const BlockBuilder& if (x1_sinfo->IsUnknownDtype() || x2_sinfo->IsUnknownDtype()) { return DataType::Void(); } else if (x1_sinfo->dtype != x2_sinfo->dtype) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "Data types " << x1_sinfo->dtype << " and " << x2_sinfo->dtype << " must be equal for binary operators"); } @@ -269,11 +267,10 @@ inline std::pair CheckTensorLayout(const Call tir::Layout _tensor_layout(tensor_layout, DataType::Int(64)); tir::BijectiveLayout tensor2tgt(_tensor_layout, tir::Layout(tgt_layout, DataType::Int(64))); if (!tensor2tgt.defined()) { - ctx->ReportFatal(Diagnostic::Error(call->span) - << call->op << " requires the given " << tensor_name - << " layout to be convertible from " << tgt_layout - << " layout. However, the given layout " << tensor_layout - << " is not convertible."); + ctx->ReportFatal(Diagnostic::Error(call) << call->op << " requires the given " << tensor_name + << " layout to be convertible from " << tgt_layout + << " layout. However, the given layout " + << tensor_layout << " is not convertible."); } return {_tensor_layout, tensor2tgt}; } @@ -291,7 +288,7 @@ inline Optional CheckNdimPerLayoutAndGetShape(const Call& call, const const TensorStructInfo& sinfo, const tir::Layout& layout) { if (!sinfo->IsUnknownNdim() && sinfo->ndim != static_cast(layout.ndim())) { - ctx->ReportFatal(Diagnostic::Error(call->span) + ctx->ReportFatal(Diagnostic::Error(call) << "In " << call->op << ", layout " << layout << " requires the input to be " << layout.ndim() << "-dim tensor. However, the given input has ndim " << sinfo->ndim); diff --git a/src/relax/op/tensor/binary.cc b/src/relax/op/tensor/binary.cc index dd61091f7aaa2..b7a07c5202089 100644 --- a/src/relax/op/tensor/binary.cc +++ b/src/relax/op/tensor/binary.cc @@ -81,7 +81,19 @@ StructInfo InferStructInfoBroadcastCMP(const Call& call, const BlockBuilder& ctx /***************** Arithmetic operators *****************/ RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(add); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(divide); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(floor_divide); RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(multiply); +RELAX_REGISTER_BINARY_BROADCAST_OP_AND_IMPL(subtract); + +/***************** Comparison operators *****************/ + +RELAX_REGISTER_CMP_OP_AND_IMPL(equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater); +RELAX_REGISTER_CMP_OP_AND_IMPL(greater_equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(less); +RELAX_REGISTER_CMP_OP_AND_IMPL(less_equal); +RELAX_REGISTER_CMP_OP_AND_IMPL(not_equal); } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/binary.h b/src/relax/op/tensor/binary.h index a7aea576b6858..b565b159bb489 100644 --- a/src/relax/op/tensor/binary.h +++ b/src/relax/op/tensor/binary.h @@ -61,9 +61,38 @@ namespace relax { /*! \brief Addition with numpy-style broadcasting. */ Expr add(Expr x1, Expr x2); +/*! \brief Division with numpy-style broadcasting. */ +Expr divide(Expr x1, Expr x2); + +/*! \brief Floor division with numpy-style broadcasting. */ +Expr floor_divide(Expr x1, Expr x2); + /*! \brief Multiplication with numpy-style broadcasting. */ Expr multiply(Expr x1, Expr x2); +/*! \brief Subtraction with numpy-style broadcasting. */ +Expr subtract(Expr x1, Expr x2); + +/***************** Comparison operators *****************/ + +/*! \brief Broadcasted element-wise test for (lhs == rhs). */ +Expr equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs > rhs). */ +Expr greater(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs >= rhs). */ +Expr greter_equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs < rhs). */ +Expr less(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs <= rhs). */ +Expr less_equal(Expr x1, Expr x2); + +/*! \brief Broadcasted element-wise test for (lhs != rhs). */ +Expr not_equal(Expr x1, Expr x2); + } // namespace relax } // namespace tvm diff --git a/src/relax/op/tensor/ternary.cc b/src/relax/op/tensor/ternary.cc new file mode 100644 index 0000000000000..8820c07afd253 --- /dev/null +++ b/src/relax/op/tensor/ternary.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.cc + * \brief ternary operators. + */ + +#include "ternary.h" + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoEwiseFMA(const Call& call, const BlockBuilder& ctx) { + Array input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo t1 = input_sinfo[0]; + TensorStructInfo t2 = input_sinfo[1]; + TensorStructInfo t3 = input_sinfo[2]; + + int ndim = kUnknownNDim; + if (!t1->IsUnknownNdim()) { + ndim = t1->ndim; + } + if (!t2->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t2->ndim; + } else if (t2->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + if (!t3->IsUnknownNdim()) { + if (ndim == kUnknownNDim) { + ndim = t3->ndim; + } else if (t3->ndim != ndim) { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same number of dimensions"); + } + } + + DataType output_dtype; + if (t1->IsUnknownDtype() || t2->IsUnknownDtype() || t3->IsUnknownDtype()) { + output_dtype = DataType::Void(); + } else if (t1->dtype != t2->dtype || t2->dtype != t3->dtype) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Data types " << t1->dtype << ", " << t2->dtype << ", and " << t3->dtype + << " must be equal for EwiseFMA"); + } else { + output_dtype = t1->dtype; + } + + auto* s1 = t1->shape.as(); + auto* s2 = t2->shape.as(); + auto* s3 = t3->shape.as(); + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + if (s1 && s2 && s3) { + Array output_shape; + for (int i = 0; i < ndim; ++i) { + PrimExpr dim1 = s1->values[i]; + PrimExpr dim2 = s2->values[i]; + PrimExpr dim3 = s3->values[i]; + if (analyzer->CanProveEqual(dim1, dim2) && analyzer->CanProveEqual(dim2, dim3)) { + output_shape.push_back(dim1); + } else { + ctx->ReportFatal(Diagnostic::Error(call) + << "The 3 arguments of EwiseFMA must have the same shape"); + } + } + return TensorStructInfo(ShapeExpr(output_shape), output_dtype); + } else if (t1->shape.defined() && t1->shape.same_as(t2->shape) && t1->shape.same_as(t3->shape)) { + return TensorStructInfo(t1->shape.value(), output_dtype); + } + + return TensorStructInfo(output_dtype, ndim); +} + +TVM_REGISTER_OP("relax.ewise_fma") + .set_num_inputs(3) + .add_argument("x1", "Tensor", "The left hand operand of the multiplication") + .add_argument("x2", "Tensor", "The right hand operand of the multiplication") + .add_argument("x3", "Tensor", "The operand of the addition") + .set_attr("FInferStructInfo", InferStructInfoEwiseFMA); + +Expr ewise_fma(Expr x1, Expr x2, Expr x3) { + static const Op& op = Op::Get("relax.ewise_fma"); + return Call(op, {x1, x2, x3}, Attrs(), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.ewise_fma").set_body_typed(ewise_fma); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/ternary.h b/src/relax/op/tensor/ternary.h new file mode 100644 index 0000000000000..ba22c56d9efd9 --- /dev/null +++ b/src/relax/op/tensor/ternary.h @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file ternary.h + * \brief The functions to make Relax ternary operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_TERNARY_H_ +#define TVM_RELAX_OP_TENSOR_TERNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Elementwise fused multiply-add operator + * Returns elementwise result of `x1 * x2 + x3` + * \param x1 The left hand operand of the multiplication + * \param x2 The right hand operand of the multiplication + * \param x3 The operand of the addition + * \return The computed result. + */ +Expr ewise_fma(Expr x1, Expr x2, Expr x3); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_TERNARY_H_ diff --git a/src/relax/op/tensor/unary.cc b/src/relax/op/tensor/unary.cc new file mode 100644 index 0000000000000..f1117c1826c5a --- /dev/null +++ b/src/relax/op/tensor/unary.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.cc + * \brief Relax unary arithmetic operators. + */ + +#include "unary.h" + +#include + +namespace tvm { +namespace relax { + +StructInfo InferStructInfoUnaryCheck(const Call& call, const BlockBuilder& ctx) { + return InferStructInfoUnary( + call, ctx, [](const TensorStructInfo& input_sinfo) { return DataType::Bool(); }); +} + +/***************** Arithmetic operators *****************/ + +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(abs, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(acosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(asinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(atanh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(ceil, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cos, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(cosh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(exp, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(floor, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(log, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(negative, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(round, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sigmoid, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sign, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sin, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sinh, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(square, /*require_float_dtype=*/false); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(sqrt, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tan, /*require_float_dtype=*/true); +RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(tanh, /*require_float_dtype=*/true); + +// relax.clip +TVM_REGISTER_OP("relax.clip") + .set_num_inputs(3) + .add_argument("x", "Tensor", "The input tensor.") + .add_argument("min", "PrimValue", "The lower-bound of the range to be clipped to") + .add_argument("max", "PrimValue", "The upper-bound of the range to be clipped to") + .set_attr("FInferStructInfo", ReturnStructInfoFromArg<0>); + +Expr clip(Expr x, Expr min, Expr max) { + CHECK(min->IsInstance()) + << "The argument `min` of relax.clip is expected to be a PrimValue, but got" + << min->GetTypeKey(); + CHECK(max->IsInstance()) + << "The argument `max` of relax.clip is expected to be a PrimValue, but got" + << max->GetTypeKey(); + static const Op& op = Op::Get("relax.clip"); + return Call(op, {std::move(x), std::move(min), std::move(max)}); +} + +TVM_REGISTER_GLOBAL("relax.op.clip").set_body_typed(clip); + +/***************** Check operators *****************/ + +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isfinite); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isinf); +RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(isnan); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/unary.h b/src/relax/op/tensor/unary.h new file mode 100644 index 0000000000000..8f6404c5d9ed2 --- /dev/null +++ b/src/relax/op/tensor/unary.h @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. Sex The NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. Sex The License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file unary.h + * \brief The functions to make Relax unary arithmetic operator calls. + */ +#ifndef TVM_RELAX_OP_TENSOR_UNARY_H_ +#define TVM_RELAX_OP_TENSOR_UNARY_H_ + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief Quick helper macro to + * - expose a make-function interface which construct the call node. + * - register op to the registry. + * \param OpName The name of operator to register. + * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype. + * (Only for unary arith operators since all check operators don't require float dtype.) + */ +#define RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName) \ + RELAX_UNARY_OP_INTERFACE(OpName, #OpName); \ + RELAX_REGISTER_UNARY_OP(#OpName) + +#define RELAX_REGISTER_UNARY_ARITH_OP_AND_IMPL(OpName, RequireFloatDtype) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryArith) + +#define RELAX_REGISTER_UNARY_CHECK_OP_AND_IMPL(OpName) \ + RELAX_REGISTER_UNARY_OP_AND_IMPL(OpName).set_attr( \ + "FInferStructInfo", InferStructInfoUnaryCheck) // require_float_dtype=false for check op + +/***************** Arithmetic operators *****************/ + +/*! + * \brief Compute element-wise absolute value of the input data. + * \param x The input data. + * \return The computed result. + */ +Expr abs(Expr x); + +/*! \brief Compute element-wise arc cos of the input data. */ +Expr acos(Expr x); + +/*! \brief Compute element-wise arc cosh of the input data. */ +Expr acosh(Expr x); + +/*! \brief Compute element-wise arc sin of the input data. */ +Expr asin(Expr x); + +/*! \brief Compute element-wise arc sinh of the input data. */ +Expr asinh(Expr x); + +/*! \brief Compute element-wise arc tan of the input data. */ +Expr atan(Expr x); + +/*! \brief Compute element-wise arc tanh of the input data. */ +Expr atanh(Expr x); + +/*! \brief Take ceil of input data. */ +Expr ceil(Expr x); + +/*! \brief Compute element-wise cos of the input data. */ +Expr cos(Expr x); + +/*! \brief Compute element-wise cosh of the input data. */ +Expr cosh(Expr x); + +/*! \brief Compute element-wise exp of data. */ +Expr exp(Expr x); + +/*! \brief Take floor of input data. */ +Expr floor(Expr x); + +/*! \brief Compute element-wise natural logarithm of data. */ +Expr log(Expr x); + +/*! \brief Compute element-wise negative value of data. */ +Expr negative(Expr x); + +/*! \brief Rounds each element of the input data to nearest integer. */ +Expr round(Expr x); + +/*! \brief Compute element-wise sigmoid of data. */ +Expr sigmoid(Expr x); + +/*! \brief Returns an indication of the sign of a number for each element of the input data. */ +Expr sign(Expr x); + +/*! \brief Compute element-wise sin of data. */ +Expr sin(Expr x); + +/*! \brief Compute element-wise sinh of data. */ +Expr sinh(Expr x); + +/*! \brief Compute element-wise square root of data. */ +Expr sqrt(Expr x); + +/*! \brief Squares each element of the input data. */ +Expr square(Expr x); + +/*! \brief Compute element-wise tan of data. */ +Expr tan(Expr x); + +/*! \brief Compute element-wise tanh of data. */ +Expr tanh(Expr x); + +/*! \brief Clips tensor values to a specified min and max. */ +Expr clip(Expr x, Expr min, Expr max); + +/***************** Check operators *****************/ + +/*! \brief Check if input value is finite. */ +Expr isfinite(Expr x); + +/*! \brief Check if input value is infinite. */ +Expr isinf(Expr x); + +/*! \brief Check if input value is Nan. */ +Expr isnan(Expr x); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_UNARY_H_ diff --git a/tests/python/relax/test_op_binary.py b/tests/python/relax/test_op_binary.py new file mode 100644 index 0000000000000..a4ae8ce31ac7c --- /dev/null +++ b/tests/python/relax/test_op_binary.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + assert relax.op.add(x, y).op == Op.get("relax.add") + assert relax.op.divide(x, y).op == Op.get("relax.divide") + assert relax.op.floor_divide(x, y).op == Op.get("relax.floor_divide") + assert relax.op.multiply(x, y).op == Op.get("relax.multiply") + assert relax.op.subtract(x, y).op == Op.get("relax.subtract") + + assert relax.op.equal(x, y).op == Op.get("relax.equal") + assert relax.op.greater(x, y).op == Op.get("relax.greater") + assert relax.op.greater_equal(x, y).op == Op.get("relax.greater_equal") + assert relax.op.less(x, y).op == Op.get("relax.less") + assert relax.op.less_equal(x, y).op == Op.get("relax.less_equal") + assert relax.op.not_equal(x, y).op == Op.get("relax.not_equal") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.subtract,), +) + + +def test_binary_arith_infer_struct_info(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((1, 3), "float32")) + x2 = relax.Var("x", R.Tensor((3, 2, 3), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, 3), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + x5 = relax.Var("x", R.Tensor()) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((4, 3, 2, 1), "float32")) + y2 = relax.Var("y", R.Tensor("float32", ndim=2)) + y3 = relax.Var("y", R.Tensor("float32", ndim=-1)) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((4, 3, 2, 3), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y2), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y1), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x5, y0), relax.TensorStructInfo(dtype="", ndim=-1)) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp_infer_struct_info(binary_cmp_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3), "int32")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y0), relax.TensorStructInfo((2, 3), "bool")) + _check_inference(bb, binary_cmp_op(x, y1), relax.TensorStructInfo((2, 3), "bool")) + + +def test_binary_infer_struct_info_shape_symbolic(binary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((1, n), "float32")) + x2 = relax.Var("x", R.Tensor((k, n, m), "float32")) + x3 = relax.Var("x", R.Tensor((3, 1, n), "float32")) + x4 = relax.Var("x", R.Tensor("float32", ndim=2)) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor((m, n + 2), "float32")) + y2 = relax.Var("y", R.Tensor((4, k, m, 1), "float32")) + y3 = relax.Var("y", R.Tensor("float32", ndim=2)) + y4 = relax.Var("y", R.Tensor("float32", ndim=-1)) + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x1, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, binary_arith_op(x1, y2), relax.TensorStructInfo((4, k, m, n), "float32")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x2, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x3, y3), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, binary_arith_op(x4, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x4, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x4, y4), relax.TensorStructInfo(dtype="float32", ndim=-1)) + + +def test_binary_infer_struct_info_shape_var(binary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=4)) + s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s4 = relax.Var("s4", relax.ShapeStructInfo()) + x = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y3 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + y4 = relax.Var("y", relax.TensorStructInfo(s4, "float32")) + + _check_inference(bb, binary_arith_op(x, y0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, binary_arith_op(x, y1), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y2), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, binary_arith_op(x, y3), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, binary_arith_op(x, y4), relax.TensorStructInfo(dtype="float32")) + + +def test_binary_arith_infer_struct_info_more_input_dtype(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + + _check_inference(bb, binary_arith_op(x0, y0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, binary_arith_op(x1, y1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, binary_arith_op(x2, y2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_binary_infer_struct_info_shape_unequal_const_int(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 4), "float32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y0)) + + +def test_binary_arith_infer_struct_info_dtype_mismatch(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "int32")) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x, y)) + + +def test_binary_wrong_input_number(binary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + binary_arith_op(x, x, x) + with pytest.raises(TypeError): + binary_arith_op(x) + with pytest.raises(TypeError): + binary_arith_op(x, x, x, x) + + +def test_binary_infer_struct_info_wrong_input_type(binary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x0, y)) + with pytest.raises(TVMError): + bb.normalize(binary_arith_op(x1, y)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_ternary.py b/tests/python/relax/test_op_ternary.py new file mode 100644 index 0000000000000..5ea7a01da7011 --- /dev/null +++ b/tests/python/relax/test_op_ternary.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 3), "float32")) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + assert relax.op.ewise_fma(x, y, z).op == Op.get("relax.ewise_fma") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_ewise_fma_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor((2, 3))) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor("float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z1), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference(bb, relax.op.ewise_fma(x1, y0, z0), relax.TensorStructInfo((2, 3), dtype="")) + + +def test_ewise_fma_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + y0 = relax.Var("y", R.Tensor((m, n), "float32")) + y1 = relax.Var("y", R.Tensor(dtype="float32", ndim=2)) + z0 = relax.Var("z", R.Tensor((m, n), "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((m, n), "float32")) + _check_inference( + bb, relax.op.ewise_fma(x0, y1, z0), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_shape_var(): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s2 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + y = relax.Var("y", relax.TensorStructInfo(s0, "float32")) + z = relax.Var("z", relax.TensorStructInfo(s0, "float32")) + + _check_inference(bb, relax.op.ewise_fma(x0, y, z), relax.TensorStructInfo(s0, "float32")) + _check_inference( + bb, relax.op.ewise_fma(x1, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + _check_inference( + bb, relax.op.ewise_fma(x2, y, z), relax.TensorStructInfo(dtype="float32", ndim=2) + ) + + +def test_ewise_fma_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + y0 = relax.Var("y", R.Tensor((2, 3), "float64")) + z0 = relax.Var("z", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + y1 = relax.Var("y", R.Tensor((2, 3), "int8")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + y2 = relax.Var("y", R.Tensor((2, 3), "int64")) + z2 = relax.Var("z", R.Tensor((2, 3), "int64")) + + _check_inference(bb, relax.op.ewise_fma(x0, y0, z0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, relax.op.ewise_fma(x1, y1, z1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, relax.op.ewise_fma(x2, y2, z2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_ewise_fma_infer_struct_info_dtype_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "int32")) + y1 = relax.Var("y", R.Tensor((2, 3), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor((2, 3), "int8")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z1)) + + +def test_ewise_fma_infer_struct_info_ndim_mismatch(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", R.Tensor((2, 3), "float32")) + y1 = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z0 = relax.Var("z", R.Tensor((2, 3), "float32")) + z1 = relax.Var("z", R.Tensor(dtype="float32", ndim=4)) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z1)) + + +def test_ewise_fma_wrong_input_number(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + relax.op.ewise_fma(x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x) + with pytest.raises(TypeError): + relax.op.ewise_fma(x, x, x, x) + + +def test_ewise_fma_infer_struct_info_wrong_input_type(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y0 = relax.Var("y", relax.ShapeStructInfo((2, 3))) + y1 = relax.Var("y", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + z = relax.Var("z", R.Tensor((2, 3), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y0, z)) + with pytest.raises(TVMError): + bb.normalize(relax.op.ewise_fma(x, y1, z)) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_op_unary.py b/tests/python/relax/test_op_unary.py new file mode 100644 index 0000000000000..45336661a1aef --- /dev/null +++ b/tests/python/relax/test_op_unary.py @@ -0,0 +1,203 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from typing import Callable +import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + assert relax.op.abs(x).op == Op.get("relax.abs") + assert relax.op.acos(x).op == Op.get("relax.acos") + assert relax.op.acosh(x).op == Op.get("relax.acosh") + assert relax.op.asin(x).op == Op.get("relax.asin") + assert relax.op.asinh(x).op == Op.get("relax.asinh") + assert relax.op.atan(x).op == Op.get("relax.atan") + assert relax.op.atanh(x).op == Op.get("relax.atanh") + assert relax.op.ceil(x).op == Op.get("relax.ceil") + assert relax.op.cos(x).op == Op.get("relax.cos") + assert relax.op.cosh(x).op == Op.get("relax.cosh") + assert relax.op.exp(x).op == Op.get("relax.exp") + assert relax.op.floor(x).op == Op.get("relax.floor") + assert relax.op.isfinite(x).op == Op.get("relax.isfinite") + assert relax.op.isinf(x).op == Op.get("relax.isinf") + assert relax.op.isnan(x).op == Op.get("relax.isnan") + assert relax.op.log(x).op == Op.get("relax.log") + assert relax.op.negative(x).op == Op.get("relax.negative") + assert relax.op.round(x).op == Op.get("relax.round") + assert relax.op.sigmoid(x).op == Op.get("relax.sigmoid") + assert relax.op.sin(x).op == Op.get("relax.sin") + assert relax.op.sinh(x).op == Op.get("relax.sinh") + assert relax.op.square(x).op == Op.get("relax.square") + assert relax.op.sqrt(x).op == Op.get("relax.sqrt") + assert relax.op.tan(x).op == Op.get("relax.tan") + assert relax.op.tanh(x).op == Op.get("relax.tanh") + assert relax.op.clip(x, 0, 6).op == Op.get("relax.clip") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +unary_arith_op, require_float_dtype = tvm.testing.parameters( + (relax.op.abs, False), + (relax.op.acos, True), + (relax.op.acosh, True), + (relax.op.asin, True), + (relax.op.asinh, True), + (relax.op.atan, True), + (relax.op.atanh, True), + (relax.op.ceil, False), + (relax.op.cos, True), + (relax.op.cosh, True), + (relax.op.exp, True), + (relax.op.floor, False), + (relax.op.log, True), + (relax.op.negative, False), + (relax.op.round, False), + (relax.op.sigmoid, True), + (relax.op.sign, False), + (relax.op.sin, True), + (relax.op.sinh, True), + (relax.op.square, False), + (relax.op.sqrt, True), + (relax.op.tan, True), + (relax.op.tanh, True), +) + + +def test_unary_arith_infer_struct_info(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, unary_arith_op(x3), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, unary_arith_op(x4), relax.TensorStructInfo(dtype="")) + + +def test_unary_arith_infer_struct_info_shape_symbolic(unary_arith_op: Callable): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x0 = relax.Var("x", R.Tensor((m, n), "float32")) + x1 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((4, n), "float32")) + + +def test_unary_arith_infer_struct_info_shape_var(unary_arith_op: Callable): + bb = relax.BlockBuilder() + s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2)) + s1 = relax.Var("s", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo(s0, "float32")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo(s1, "float32")) + + +def test_unary_arith_infer_struct_info_more_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float64")) + x1 = relax.Var("x", R.Tensor((2, 3), "int8")) + x2 = relax.Var("x", R.Tensor((2, 3), "int64")) + + _check_inference(bb, unary_arith_op(x0), relax.TensorStructInfo((2, 3), "float64")) + _check_inference(bb, unary_arith_op(x1), relax.TensorStructInfo((2, 3), "int8")) + _check_inference(bb, unary_arith_op(x2), relax.TensorStructInfo((2, 3), "int64")) + + +def test_unary_arith_infer_struct_info_invalid_input_dtype( + unary_arith_op: Callable, require_float_dtype: bool +): + if not require_float_dtype: + return + + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "int8")) + x1 = relax.Var("x", R.Tensor((2, 3), "int64")) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_unary_arith_wrong_input_number(unary_arith_op: Callable): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + + with pytest.raises(TypeError): + unary_arith_op(x, x) + with pytest.raises(TypeError): + unary_arith_op(x, x, x) + + +def test_unary_arith_infer_struct_info_wrong_input_type(unary_arith_op: Callable): + bb = relax.BlockBuilder() + x0 = relax.Var("x", relax.ShapeStructInfo((2, 3))) + x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32"))) + + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x0)) + with pytest.raises(TVMError): + bb.normalize(unary_arith_op(x1)) + + +def test_clip_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((2, 3), "float32")) + x1 = relax.Var("x", R.Tensor("float32", ndim=3)) + x2 = relax.Var("x", R.Tensor("float32", ndim=-1)) + x3 = relax.Var("x", R.Tensor((2, 3))) + x4 = relax.Var("x", R.Tensor()) + + _check_inference(bb, relax.op.clip(x0, 0, 6), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.clip(x1, 0, 6), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.clip(x2, 0, 6), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.clip(x3, 0, 6), relax.TensorStructInfo((2, 3), dtype="")) + _check_inference(bb, relax.op.clip(x4, 0, 6), relax.TensorStructInfo(dtype="")) + + # Symbolic + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + x5 = relax.Var("x", R.Tensor((m, n), "float32")) + x6 = relax.Var("x", R.Tensor((4, n), "float32")) + + _check_inference(bb, relax.op.clip(x5, 0, 6), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.clip(x6, 0, 6), relax.TensorStructInfo((4, n), "float32")) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index c9a16fbcacb71..f6d2e4c20e48e 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -31,10 +31,9 @@ def _check( parsed: Union[relax.Function, IRModule], expect: Optional[Union[relax.Function, IRModule]] = None, ): - # TODO(relax-team): enable roundtrip testing when printer is ready - # test = parsed.script(show_meta=True) - # roundtrip_mod = tvm.script.parse(test) - # tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) if expect: tvm.ir.assert_structural_equal(parsed, expect) diff --git a/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py new file mode 100644 index 0000000000000..ffb8576b27dc1 --- /dev/null +++ b/tests/python/relax/test_tvmscript_parser_op_arith_cmp.py @@ -0,0 +1,179 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union, Callable + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +(unary_arith_op,) = tvm.testing.parameters( + (relax.op.abs,), + (relax.op.acos,), + (relax.op.acosh,), + (relax.op.asin,), + (relax.op.asinh,), + (relax.op.atan,), + (relax.op.atanh,), + (relax.op.ceil,), + (relax.op.cos,), + (relax.op.cosh,), + (relax.op.exp,), + (relax.op.floor,), + (relax.op.log,), + (relax.op.negative,), + (relax.op.round,), + (relax.op.sigmoid,), + (relax.op.sign,), + (relax.op.sin,), + (relax.op.sinh,), + (relax.op.square,), + (relax.op.sqrt,), + (relax.op.tan,), + (relax.op.tanh,), +) + + +def test_unary_arith(unary_arith_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = unary_arith_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_arith_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(unary_check_op,) = tvm.testing.parameters( + (relax.op.isfinite,), + (relax.op.isinf,), + (relax.op.isnan,), +) + + +def test_unary_check(unary_check_op: Callable): + @R.function + def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = unary_check_op(x) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x]): + gv = bb.emit(unary_check_op(x)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_arith_op,) = tvm.testing.parameters( + (relax.op.add,), + (relax.op.divide,), + (relax.op.floor_divide,), + (relax.op.multiply,), + (relax.op.subtract,), +) + + +def test_binary_arith(binary_arith_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "float32"): + gv: R.Tensor((2, 3), "float32") = binary_arith_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_arith_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +(binary_cmp_op,) = tvm.testing.parameters( + (relax.op.equal,), + (relax.op.greater,), + (relax.op.greater_equal,), + (relax.op.less,), + (relax.op.less_equal,), + (relax.op.not_equal,), +) + + +def test_binary_cmp(binary_cmp_op: Callable): + @R.function + def foo( + x: R.Tensor((2, 3), "float32"), y: R.Tensor((2, 1), "float32") + ) -> R.Tensor((2, 3), "bool"): + gv: R.Tensor((2, 3), "bool") = binary_cmp_op(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((2, 1), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(binary_cmp_op(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_relax_ewise_fma(): + @R.function + def foo( + x: R.Tensor((2, 3, 4), dtype="float32"), + y: R.Tensor((2, 3, 4), dtype="float32"), + z: R.Tensor((2, 3, 4), dtype="float32"), + ) -> R.Tensor((2, 3, 4), dtype="float32"): + gv: R.Tensor((2, 3, 4), dtype="float32") = R.ewise_fma(x, y, z) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + y = relax.Var("y", R.Tensor((2, 3, 4), "float32")) + z = relax.Var("z", R.Tensor((2, 3, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y, z]): + gv = bb.emit(relax.op.ewise_fma(x, y, z)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()