diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 4e7c0bf324d6..57a6577eaf4a 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -284,6 +284,20 @@ def _one_hot(self, node: fx.Node) -> relax.Var: return self.block_builder.emit(relax.op.one_hot(x, on_value, off_value, num_classes, axis)) + def _hamming_window(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + + window_size = args[0] + periodic = args[1] if len(args) > 1 else True + alpha = args[2] if len(args) > 2 else 0.54 + beta = args[3] if len(args) > 3 else 0.46 + dtype = node.kwargs.get("dtype", "float") + dtype = self._convert_data_type(dtype) + + return self.block_builder.emit( + relax.op.hamming_window(window_size, periodic, alpha, beta, dtype) + ) + def _zeros(self, node: fx.Node) -> relax.Var: args = self.retrieve_args(node) size = relax.ShapeExpr(args[0] if isinstance(args[0], (list, tuple)) else (args[0],)) @@ -528,6 +542,10 @@ def create_convert_map( "fill_.Scalar": self._inplace_fill, "full.default": self._full, "full_like.default": self._full_like, + "hamming_window.periodic": self._hamming_window, + "hamming_window.periodic_alpha": self._hamming_window, + "hamming_window.periodic_alpha_beta": self._hamming_window, + "hamming_window.default": self._hamming_window, "index_select.default": self._index_select, "lift_fresh_copy.default": self._to_copy, "linspace.default": self._linspace, diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index c4a5d2fd2329..9388831fce31 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -73,6 +73,7 @@ arange, full, full_like, + hamming_window, ones, ones_like, eye, diff --git a/python/tvm/relax/op/create.py b/python/tvm/relax/op/create.py index c61d9521a41d..c8526e95d817 100644 --- a/python/tvm/relax/op/create.py +++ b/python/tvm/relax/op/create.py @@ -283,6 +283,41 @@ def is_int(expr): return _ffi_api.arange(start, end, step, dtype) # type: ignore +def hamming_window(window_size, periodic, alpha, beta, dtype): + """Hamming window function. + + Parameters + ---------- + window_size : PrimExpr + The size of returned window. + + periodic : PrimExpr + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + + alpha : PrimExpr + The co-efficient alpha. + + beta : PrimExpr + The co-efficient beta. + + Returns + ------- + ret : relax.Expr + The result tensor. + """ + if not isinstance(window_size, Expr): + window_size = PrimValue(window_size) + if not isinstance(periodic, Expr): + periodic = PrimValue(periodic) + if not isinstance(alpha, Expr): + alpha = PrimValue(alpha) + if not isinstance(beta, Expr): + beta = PrimValue(beta) + + return _ffi_api.hamming_window(window_size, periodic, alpha, beta, dtype) + + def tril(x: Expr, k: Union[int, PrimExpr, Expr] = 0) -> Expr: """Return the lower triangular part of a matrix or a batch of matrices. diff --git a/python/tvm/relax/transform/legalize_ops/create.py b/python/tvm/relax/transform/legalize_ops/create.py index 8bf85e34dee8..7598cba076a6 100644 --- a/python/tvm/relax/transform/legalize_ops/create.py +++ b/python/tvm/relax/transform/legalize_ops/create.py @@ -114,3 +114,14 @@ def is_const_scalar(x: PrimValue): return const(np.arange(start.value, end.value, step.value, dtype=dtype), dtype=dtype) else: return bb.call_te(topi.arange, start, end, step, dtype) + + +@register_legalize("relax.hamming_window") +def _hamming_window(bb: BlockBuilder, call: Call) -> Expr: + assert len(call.args) == 4 + dtype = call.attrs.dtype + window_size = call.args[0].value + periodic = call.args[1].value + alpha = call.args[2].value + beta = call.args[3].value + return bb.call_te(topi.hamming_window, window_size, periodic, alpha, beta, dtype) diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 92f84ce05cc2..1e48e9ea1ad7 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -99,6 +99,7 @@ grad, greater, greater_equal, + hamming_window, hint_on_device, index_put, image, @@ -786,6 +787,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "grad", "greater", "greater_equal", + "hamming_window", "hexagon", "hint_on_device", "index_put", diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index 951944e618ab..bcb3ff95faf0 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -18,6 +18,9 @@ """Injective transformation operators""" from __future__ import absolute_import as _abs +from math import pi +import numpy as np + import tvm from tvm import te, topi @@ -1106,3 +1109,45 @@ def index_tensor(data, indices): z = topi.index_tensor(x, [row, col]) # shape (2, 3) """ return topi.adv_index(data, indices) + + +def hamming_window(window_size, periodic, alpha, beta, dtype): + """Hamming window function. + + Parameters + ---------- + window_size: tvm.Expr + The size of returned window. + + periodic: tvm.Expr + If True, returns a window to be used as periodic function. + If False, return a symmetric window. + + alpha: tvm.Expr + The co-efficient alpha. + + beta: tvm.Expr + The co-efficient beta. + + Returns + ------- + ret : tvm.te.Tensor + The result tensor. + """ + if window_size == 1: + return topi.const_vector(np.array([1], dtype=dtype)) + + periodic = topi.cast(periodic, "bool") + + if periodic: + window_size += 1 + + index = topi.arange(0, window_size, dtype=dtype) + angular_freq = 2 * pi * index / (window_size - 1) + cos_values = topi.cos(angular_freq) + window = topi.cast(alpha - beta * cos_values, dtype=dtype) + + if periodic: + return topi.strided_slice(window, [0], [window.shape[0] - 1]) + + return window diff --git a/src/relax/op/tensor/create.cc b/src/relax/op/tensor/create.cc index b2355b1af7f0..37fd84e13ed1 100644 --- a/src/relax/op/tensor/create.cc +++ b/src/relax/op/tensor/create.cc @@ -29,6 +29,8 @@ #include #include +#include "tvm/relax/expr.h" + namespace tvm { namespace relax { @@ -363,6 +365,57 @@ TVM_REGISTER_OP("relax.arange") .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) .set_attr("FPurity", Bool(true)); +/* relax.hamming_window */ +Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, + DataType dtype) { + ObjectPtr attrs = make_object(); + attrs->dtype = dtype; + static const Op& op = Op::Get("relax.hamming_window"); + return Call(op, {std::move(window_size), std::move(periodic), std::move(alpha), std::move(beta)}, + Attrs(attrs), {}); +} + +TVM_FFI_REGISTER_GLOBAL("relax.op.hamming_window").set_body_typed(hamming_window); + +StructInfo InferStructInfoHammingWindow(const Call& call, const BlockBuilder& ctx) { + DataType dtype = call->attrs.as()->dtype; + if (dtype.is_int() || dtype.is_uint() || dtype.is_uint()) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Hamming Window expects the datatype to be float but got " << dtype); + } + auto get_prim_value = [&ctx](const Expr& expr, std::string key) { + if (!expr->IsInstance()) { + ctx->ReportFatal(Diagnostic::Error(expr) + << "Hamming_window expects the `" << key << "` to be a PrimValue, but got " + << expr->GetTypeKey()); + } + return expr.as()->value; + }; + PrimExpr window_size = get_prim_value(call->args[0], "window_size"); + + arith::Analyzer analyzer; + if (analyzer.CanProveLess(window_size, 1)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Hamming_window expects the window_size must be greater than zero but got " + << window_size); + } + window_size = analyzer.Simplify(window_size); + return TensorStructInfo(ShapeExpr({window_size}), dtype); +} + +TVM_REGISTER_OP("relax.hamming_window") + .set_attrs_type() + .set_num_inputs(4) + .add_argument("window_size", "PrimValue", "The size of the window") + .add_argument("periodic", "PrimValue", + "If True, returns a window to be used as periodic function. If False, return a " + "symmetric window") + .add_argument("alpha", "PrimValue", "The coefficient alpha") + .add_argument("beta", "PrimValue", "The coefficient beta") + .set_attr("FInferStructInfo", InferStructInfoHammingWindow) + .set_attr("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow) + .set_attr("FPurity", Bool(true)); + /* relax.tril & relax.triu */ TVM_REGISTER_NODE_TYPE(TriluAttrs); diff --git a/src/relax/op/tensor/create.h b/src/relax/op/tensor/create.h index 0bf15bbd57e7..f252eebf824f 100644 --- a/src/relax/op/tensor/create.h +++ b/src/relax/op/tensor/create.h @@ -28,6 +28,7 @@ #include #include "../op_common.h" +#include "tvm/relax/expr.h" namespace tvm { namespace relax { @@ -118,6 +119,19 @@ Expr eye_like(Expr x, PrimValue k, Optional dtype); /*! \brief Construct a tensor with evenly spaced elements. */ Expr arange(PrimValue start, PrimValue stop, PrimValue step, DataType dtype); +/*! + * \brief Hamming window function. + * \param window_size The size of the returned window. + * \param periodic If True, returns a window to be used as periodic function. + * If False, return a symmetric window. + * \param alpha The co-efficient alpha. + * \param beta The co-efficient beta. + * \param dtype The data type of the created tensor. + * \return The result tensor. + */ +Expr hamming_window(PrimValue window_size, PrimValue periodic, PrimValue alpha, PrimValue beta, + DataType dtype); + /*! \brief Return the lower triangular part of a matrix or a batch of matrices. */ Expr tril(Expr x, Expr k); diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index e6f75372d1b0..dd04833e07b8 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -4349,6 +4349,33 @@ def main( verify_model(Arange(), example_args, {}, Expected) +def test_hamming_window(): + class HammingWindow(Module): + def forward(self, input): + return torch.hamming_window(20, True, dtype=torch.float32) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + input: R.Tensor((10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((20,), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((20,), dtype="float32") = R.hamming_window( + R.prim_value(20), + R.prim_value(1), + R.prim_value(T.float32(0.54000000000000004)), + R.prim_value(T.float32(0.46000000000000002)), + dtype="float32", + ) + gv: R.Tuple(R.Tensor((20,), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(10, 10, dtype=torch.float32),) + verify_model(HammingWindow(), example_args, {}, Expected) + + def test_contiguous(): class Contiguous(Module): def forward(self, input):