From 4d8bee05b9e2fdcc2df041eb5c686d92452e20d6 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 3 Apr 2025 15:24:58 +0000 Subject: [PATCH 01/17] add softplus op to exported program and fx graph frontends --- include/tvm/relax/attrs/nn.h | 11 ++++++++ .../torch/base_fx_graph_translator.py | 6 +++++ .../torch/exported_program_translator.py | 1 + .../tvm/relax/frontend/torch/fx_translator.py | 9 +++++++ python/tvm/relax/op/nn/__init__.py | 1 + python/tvm/relax/op/nn/nn.py | 26 +++++++++++++++++++ python/tvm/relax/transform/legalize_ops/nn.py | 5 ++++ python/tvm/topi/nn/elemwise.py | 26 +++++++++++++++++++ src/relax/op/nn/nn.cc | 21 +++++++++++++++ src/relax/op/nn/nn.h | 3 +++ 10 files changed, 109 insertions(+) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 8f63012e095a..ce87a32e1988 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -455,6 +455,17 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> { } }; +/*! \brief Attributes used in softplus operators */ +struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> { + double beta; + double threshold; + + TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") { + TVM_ATTR_FIELD(beta).describe("It controls the curvature; higher values make the transition sharper."); + TVM_ATTR_FIELD(threshold).describe("It defines when to approximate the function linearly for numerical stability."); + } +}; + /*! \brief Attributes used in batch_norm operator */ struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> { int axis; diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 890f925079e0..17d965493751 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -307,6 +307,12 @@ def _softmax(self, node: fx.Node) -> relax.Var: dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1) return self.block_builder.emit(relax.op.nn.softmax(x, dim)) + def _softplus(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + beta = node.args[1] if len(node.args) > 1 else node.kwargs.get("beta", 1.0) + threshold = node.args[2] if len(node.args) > 2 else node.kwargs.get("threshold", 20.0) + return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold)) + def _softshrink(self, node: fx.Node) -> relax.Var: """ Applies the Softshrink activation function in Relax. 
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py index 2e7c682aa34b..a896cf5e7148 100644 --- a/python/tvm/relax/frontend/torch/exported_program_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -287,6 +287,7 @@ def create_convert_map( "sin.default": self._unary_op(relax.op.sin), "sinh.default": self._unary_op(relax.op.sinh), "softmax.int": self._softmax, + "softplus.default": self._softplus, "softshrink.default": self._softshrink, "sqrt.default": self._unary_op(relax.op.sqrt), "square.default": self._unary_op(relax.op.square), diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 3ddf919c2ed1..5f68a871cc45 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -71,6 +71,13 @@ def _leakyrelu_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] alpha = module.negative_slope return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) + + def _softplus_module(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + module = self.named_modules[node.target] + beta = module.beta + threshold = module.threshold + return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold)) def _log2(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] @@ -653,6 +660,7 @@ def create_convert_map( nn.SELU: self._unary_op(relax.op.nn.selu), nn.SiLU: self._unary_op(relax.op.nn.silu), nn.Softmax: self._softmax_module, + nn.Softplus: self._softplus_module, nn.Tanh: self._unary_op(relax.op.tanh), # neural network nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module, @@ -717,6 +725,7 @@ def create_convert_map( "sin": self._unary_op(relax.op.sin), "sinh": self._unary_op(relax.op.sinh), "softmax": self._softmax, + "softplus":self._softplus, "sqrt": self._unary_op(relax.op.sqrt), "square": self._unary_op(relax.op.square), "tan": self._unary_op(relax.op.tan), diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index e45982a0fed2..20676fd4cb98 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -48,4 +48,5 @@ selu, silu, softmax, + softplus ) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 5232eea047cf..57c1d0f0ff6d 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1378,6 +1378,32 @@ def softmax(data: Expr, axis: int = -1) -> Expr: return _ffi_api.softmax(data, axis) # type: ignore +def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr: + """Softplus activation function. + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x}) + + Parameters + ---------- + data : relax.Expr + The input data. + + beta : float, optional + Controls the smoothness of the transition. Default is 1.0. + + threshold : float, optional + The value beyond which the function is approximated as linear + to avoid numerical instability. Default is 20.0. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.softplus(data, beta, threshold) + + def log_softmax(data: Expr, axis: int = -1) -> Expr: r"""Computes log softmax. 
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index fd3db841e646..7d7dd47b36f7 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -531,6 +531,11 @@ def te_silu(x: te.Tensor): return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu") +@register_legalize("relax.nn.softplus") +def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.softplus, call.args[0], call.attrs.beta) + + @register_legalize("relax.nn.softmax") def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis) diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index a80047d900f3..c6043c5ed174 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -65,6 +65,32 @@ def _compute(*indices): return te.compute(x.shape, _compute) +@tvm.te.tag_scope(tag=tag.ELEMWISE) +def softplus(x, beta=1.0): + """Compute Softplus activation for input x. + + Parameters + ---------- + x : tvm.te.Tensor + Input tensor. + + beta : float, optional + The scaling factor β in the Softplus formula (default is 1.0). + + Returns + ------- + y : tvm.te.Tensor + The result. + """ + + def _compute(*indices): + value = x(*indices) + b = tvm.tir.const(beta, value.dtype) + return (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) + + return te.compute(x.shape, _compute) + + @tvm.te.tag_scope(tag=tag.BROADCAST) def prelu(x, slope, axis=1): """PReLU. diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 4a5a9a701612..7f545af1301d 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -60,6 +60,27 @@ TVM_REGISTER_OP("relax.nn.leakyrelu") InferStructInfoUnaryArith<false>) .set_attr<Bool>("FPurity", Bool(true)); +/* relax.nn.softplus */ +TVM_REGISTER_NODE_TYPE(SoftplusAttrs); + +Expr softplus(Expr data, double beta, double threshold) { + auto attrs = make_object<SoftplusAttrs>(); + attrs->beta = beta; + attrs->threshold = threshold; + static const Op& op = Op::Get("relax.nn.softplus"); + return Call(op, {data}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus); + +TVM_REGISTER_OP("relax.nn.softplus") + .set_num_inputs(1) + .add_argument("data", "Tensor", "The input tensor.") + .set_attrs_type<SoftplusAttrs>() + .set_attr<FInferStructInfo>("FInferStructInfo", + InferStructInfoUnaryArith<false>) + .set_attr<Bool>("FPurity", Bool(true)); + /* relax.nn.softmax */ TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index d6db36aba50a..3f5571af8207 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -66,6 +66,9 @@ Expr silu(Expr data); /*! \brief Softmax function. */ Expr softmax(Expr data, int axis); +/*! \brief Softplus function. */ +Expr softplus(Expr data, double beta, double threshold); + /*! \brief LogSoftmax function. 
*/ Expr log_softmax(Expr data, int axis); From a4415a676e3d8ba8a09615e81cba06d18536b009 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 3 Apr 2025 15:33:35 +0000 Subject: [PATCH 02/17] fixing trailing whitespace issue --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- python/tvm/topi/nn/elemwise.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 5f68a871cc45..8b5bb1963fdf 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -71,7 +71,7 @@ def _leakyrelu_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] alpha = module.negative_slope return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha)) - + def _softplus_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index c6043c5ed174..2cc6a61f75a1 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -88,7 +88,7 @@ def _compute(*indices): b = tvm.tir.const(beta, value.dtype) return (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) - return te.compute(x.shape, _compute) + return te.compute(x.shape, _compute) @tvm.te.tag_scope(tag=tag.BROADCAST) From c7b294dec07279744480d777336519f766475dea Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Thu, 3 Apr 2025 15:55:25 +0000 Subject: [PATCH 03/17] fixing lint issues --- python/tvm/relax/frontend/torch/fx_translator.py | 2 +- python/tvm/relax/op/nn/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 8b5bb1963fdf..0f860a965833 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -725,7 +725,7 @@ def create_convert_map( "sin": self._unary_op(relax.op.sin), "sinh": self._unary_op(relax.op.sinh), "softmax": self._softmax, - "softplus":self._softplus, + "softplus": self._softplus, "sqrt": self._unary_op(relax.op.sqrt), "square": self._unary_op(relax.op.square), "tan": self._unary_op(relax.op.tan), diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py index 20676fd4cb98..9d56058e4649 100644 --- a/python/tvm/relax/op/nn/__init__.py +++ b/python/tvm/relax/op/nn/__init__.py @@ -48,5 +48,5 @@ selu, silu, softmax, - softplus + softplus, ) From c2dbfe84c0d602fde378740d0352da934b8e2730 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 03:25:46 +0000 Subject: [PATCH 04/17] fix lint issue on docs --- python/tvm/relax/op/nn/nn.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py index 57c1d0f0ff6d..17197b010ef6 100644 --- a/python/tvm/relax/op/nn/nn.py +++ b/python/tvm/relax/op/nn/nn.py @@ -1379,10 +1379,9 @@ def softmax(data: Expr, axis: int = -1) -> Expr: def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr: - """Softplus activation function. + r"""Softplus activation function. - .. math:: - \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x}) + .. 
math:: \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x}) Parameters ---------- From 9b6300b111980386a173c2f1ab94f04e74010ec0 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 03:44:38 +0000 Subject: [PATCH 05/17] modify description to avoid cpplint issue --- include/tvm/relax/attrs/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index ce87a32e1988..0daec53bcbe8 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -461,8 +461,8 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> { double threshold; TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") { - TVM_ATTR_FIELD(beta).describe("It controls the curvature; higher values make the transition sharper."); - TVM_ATTR_FIELD(threshold).describe("It defines when to approximate the function linearly for numerical stability."); + TVM_ATTR_FIELD(beta).describe("It controls the curvature"); + TVM_ATTR_FIELD(threshold).describe("It specifies when to use a linear approximation"); } }; From 172ee81a321ee27789f64f01f99d9ad16e92e593 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 07:58:35 +0000 Subject: [PATCH 06/17] update softplus function with threshold attr --- include/tvm/relax/attrs/nn.h | 6 ++++-- python/tvm/relax/transform/legalize_ops/nn.py | 7 ++++++- python/tvm/topi/nn/elemwise.py | 15 ++++++++++++--- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 0daec53bcbe8..3088429534a2 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -461,8 +461,10 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> { double threshold; TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") { - TVM_ATTR_FIELD(beta).describe("It controls the curvature"); - TVM_ATTR_FIELD(threshold).describe("It specifies when to use a linear approximation"); + TVM_ATTR_FIELD(beta).describe( + "Scaling factor controlling the sharpness of the Softplus transition."); + TVM_ATTR_FIELD(threshold).describe( + "Value determining when to use linear approximation for numerical stability."); } }; diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 7d7dd47b36f7..64ee3f8a11d6 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -533,7 +533,12 @@ def te_silu(x: te.Tensor): @register_legalize("relax.nn.softplus") def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr: - return bb.call_te(topi.nn.softplus, call.args[0], call.attrs.beta) + return bb.call_te( + topi.nn.softplus, + call.args[0], + call.attrs.beta, + call.attrs.threshold, + ) @register_legalize("relax.nn.softmax") diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 2cc6a61f75a1..d9b23bd0e5ff 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -66,8 +66,8 @@ def _compute(*indices): @tvm.te.tag_scope(tag=tag.ELEMWISE) -def softplus(x, beta=1.0): - """Compute Softplus activation for input x. +def softplus(x, beta=1.0, threshold=20.0): + """Compute Softplus activation for input x with numerical stability. Parameters ---------- x : tvm.te.Tensor Input tensor. beta : float, optional The scaling factor β in the Softplus formula (default is 1.0). + threshold : float, optional + The threshold value for numerical stability (default is 20.0). 
+ Returns ------- y : tvm.te.Tensor The result. @@ -86,7 +89,13 @@ def _compute(*indices): value = x(*indices) b = tvm.tir.const(beta, value.dtype) - return (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) + t = tvm.tir.const(threshold, value.dtype) + + return tvm.tir.Select( + b * value > t, + value, + (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) + ) return te.compute(x.shape, _compute) From acaa231b48443da09b4fc4e4dc9dfb5e13debf78 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 07:59:37 +0000 Subject: [PATCH 07/17] remove trailing spaces in softplus func --- python/tvm/relax/transform/legalize_ops/nn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 64ee3f8a11d6..98fa3ef1ea5e 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -534,9 +534,9 @@ def te_silu(x: te.Tensor): @register_legalize("relax.nn.softplus") def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te( - topi.nn.softplus, - call.args[0], - call.attrs.beta, + topi.nn.softplus, + call.args[0], + call.attrs.beta, call.attrs.threshold, ) From bdc2fa62c68361aa6c6774e2e2ca434ccfbeeac3 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 08:13:02 +0000 Subject: [PATCH 08/17] fix lint issues in legalize func --- python/tvm/topi/nn/elemwise.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index d9b23bd0e5ff..2b174f8f1ed5 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -92,9 +92,7 @@ def _compute(*indices): t = tvm.tir.const(threshold, value.dtype) return tvm.tir.Select( - b * value > t, - value, - (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) + b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value)) ) return te.compute(x.shape, _compute) From f9f1d84ea16fd25c38376b76b3aa9e2237a4edde Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 08:15:59 +0000 Subject: [PATCH 09/17] fixing cpplint issue --- include/tvm/relax/attrs/nn.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h index 3088429534a2..0adcf29772cd 100644 --- a/include/tvm/relax/attrs/nn.h +++ b/include/tvm/relax/attrs/nn.h @@ -462,9 +462,9 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> { TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") { TVM_ATTR_FIELD(beta).describe( - "Scaling factor controlling the sharpness of the Softplus transition."); + "Scaling factor controlling the sharpness of the Softplus transition."); TVM_ATTR_FIELD(threshold).describe( - "Value determining when to use linear approximation for numerical stability."); + "Value determining when to use linear approximation for numerical stability."); } };
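The Select above is the heart of the numerical-stability fix: once beta * x exceeds the threshold, exp(beta * x) would overflow (float32 overflows near exp(88)), while softplus(x) is already indistinguishable from x. A standalone NumPy sketch of the same guard, illustrative only and not part of the patch series:

import numpy as np

def softplus_naive(x, beta=1.0):
    # Direct evaluation of (1/beta) * log(1 + exp(beta * x)).
    return (1.0 / beta) * np.log1p(np.exp(beta * x))

def softplus_stable(x, beta=1.0, threshold=20.0):
    # Mirrors the tvm.tir.Select in topi.nn.softplus: return x unchanged in
    # the saturated region, the exact formula elsewhere. Like tir.Select,
    # np.where evaluates both branches; only the selected value is kept.
    return np.where(beta * x > threshold, x, softplus_naive(x, beta))

x = np.float32(100.0)
print(softplus_naive(x))   # inf: exp(100) overflows float32
print(softplus_stable(x))  # 100.0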
From c2dae61f03843f1b31b4bd55a05223741690e465 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 09:21:20 +0000 Subject: [PATCH 10/17] test script for both exported and fx graph --- .../test_frontend_from_exported_program.py | 38 ++++++++++++++++++ tests/python/relax/test_frontend_from_fx.py | 40 +++++++++++++++++++ 2 files changed, 78 insertions(+) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 2175f9aa391c..474067d18ed0 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -409,6 +409,9 @@ def main( # leakyrelu test_leakyrelu() + # softplus + test_softplus() + # log2 class Log2(Module): def forward(self, x): @@ -655,6 +658,41 @@ def main( verify_model(Hardtanh2(), example_args, {}, expected1) +def test_softplus(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class Softplus0(torch.nn.Module): + def __init__(self): + super().__init__() + self.softplus = torch.nn.Softplus(1.0, 20.0) + + def forward(self, x): + return self.softplus(x) + + class Softplus1(Module): + def forward(self, input): + return torch.nn.functional.softplus(input, 1.0, 20.0) + + @tvm.script.ir_module + class expected: + @R.function + def main( + x: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(x, beta=1.0, threshold=20.0) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(Softplus0(), example_args, {}, expected) + verify_model(Softplus1(), example_args, {}, expected) + + def test_leakyrelu(): import torch from torch.nn import Module diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index d913baf13a0d..e1bcccf6d1a6 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -749,6 +749,43 @@ def main( verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2) +@tvm.testing.requires_gpu +def test_softplus(): + import torch + from torch.nn import Module + + torch.set_grad_enabled(False) + + class Softplus0(torch.nn.Module): + def __init__(self): + super().__init__() + self.softplus = torch.nn.Softplus(1.0, 20.0) + + def forward(self, x): + return self.softplus(x) + + class Softplus1(Module): + def forward(self, input): + return torch.nn.functional.softplus(input, 1.0, 20.0) + + @tvm.script.ir_module + class expected: + @R.function + def main( + inp_0: R.Tensor((10, 10), dtype="float32") + ) -> R.Tensor((10, 10), dtype="float32"): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(inp_0, beta=1.0, threshold=20.0) + gv: R.Tensor((10, 10), dtype="float32") = lv + R.output(gv) + return gv + + input_info = [([10, 10], "float32")] + verify_model(Softplus0(), input_info, {}, expected) + verify_model(Softplus1(), input_info, {}, expected) + + @tvm.testing.requires_gpu def test_leakyrelu(): import torch from torch.nn import Module @@ -2226,6 +2263,9 @@ def main( # leaky_relu test_leakyrelu() + # softplus + test_softplus() + # log2 class Log2(Module): def forward(self, x):
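The tests above go through verify_model; the same torch-to-Relax path can be exercised directly with from_exported_program. A minimal sketch, assuming PyTorch 2.x (for torch.export) and this patch series applied:

import torch
from tvm.relax.frontend.torch import from_exported_program

class M(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)

# Export the PyTorch program and translate it to Relax; the resulting
# module's main function should contain R.nn.softplus(x, beta=1.0, threshold=20.0).
exported = torch.export.export(M(), (torch.randn(1, 3, 10, 10),))
mod = from_exported_program(exported)
mod.show()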
From eb634b7c97f71971fbfc4d83f2301345242c2860 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 09:22:18 +0000 Subject: [PATCH 11/17] trim trailing spaces in test script --- tests/python/relax/test_frontend_from_exported_program.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 474067d18ed0..159075782134 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -687,7 +687,7 @@ def main( gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv - + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) verify_model(Softplus0(), example_args, {}, expected) verify_model(Softplus1(), example_args, {}, expected) From 0b242e21b52fa038b2af9c89e8957c9df0f43e75 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Fri, 4 Apr 2025 10:33:42 +0000 Subject: [PATCH 12/17] fix lint issues in test script --- tests/python/relax/test_frontend_from_exported_program.py | 6 ++++-- tests/python/relax/test_frontend_from_fx.py | 8 ++++---- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 159075782134..cd1201a30e00 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -681,9 +681,11 @@ class expected: @R.function def main( x: R.Tensor((1, 3, 10, 10), dtype="float32") - ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): with R.dataflow(): - lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(x, beta=1.0, threshold=20.0) + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus( + x, beta=1.0, threshold=20.0 + ) gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) R.output(gv) return gv diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py index e1bcccf6d1a6..0c11672e3e71 100644 --- a/tests/python/relax/test_frontend_from_fx.py +++ b/tests/python/relax/test_frontend_from_fx.py @@ -771,12 +771,12 @@ def forward(self, input): @tvm.script.ir_module class expected: @R.function - def main( - inp_0: R.Tensor((10, 10), dtype="float32") - ) -> R.Tensor((10, 10), dtype="float32"): + def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"): # block 0 with R.dataflow(): - lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(inp_0, beta=1.0, threshold=20.0) + lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus( + inp_0, beta=1.0, threshold=20.0 + ) gv: R.Tensor((10, 10), dtype="float32") = lv R.output(gv) return gv From caf6310ad377099fe6e32e8ce0caab4ff1c6a95b Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 7 Apr 2025 06:36:41 +0000 Subject: [PATCH 13/17] add unit tests in frontend op test files --- tests/python/relax/test_frontend_nn_op.py | 1 + tests/python/relax/test_op_nn.py | 6 ++++++ 2 files changed, 7 insertions(+) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 6e63b0e4c069..7dad152b686c 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -391,6 +391,7 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor): tanh_out = op.tanh(x) exp_out = op.exp(x) negative_out = op.negative(x) + softplus_out = op.softplus(x, beta=1.0, threshold=20.0) softmax_out = op.softmax(x, axis=2) rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1]) rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1]) diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index ec4551872fc8..bf70f9bf58d9 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -27,6 +27,7 @@ def test_op_correctness(): x = relax.Var("x", R.Tensor((2, 3), "float32")) assert relax.op.nn.relu(x).op ==
Op.get("relax.nn.relu") assert relax.op.nn.leakyrelu(x).op == Op.get("relax.nn.leakyrelu") + assert relax.op.nn.softplus(x).op == Op.get("relax.nn.softplus") assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu") assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu") assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax") @@ -75,6 +76,9 @@ def test_linear_unit_infer_struct_info(): _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype="")) _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype="")) + _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 4), dtype="")) + def test_linear_unit_infer_struct_info_shape_symbolic(): @@ -87,6 +91,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic(): _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32")) _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32")) + _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32")) def test_linear_unit_infer_struct_info_shape_var(): @@ -99,6 +104,7 @@ def test_linear_unit_infer_struct_info_shape_var(): _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32")) _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32")) _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32")) + _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32")) def test_linear_unit_infer_struct_info_more_input_dtype(): From 019d34580335c85581ff0b6335edaff8aa55ced6 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 7 Apr 2025 07:04:43 +0000 Subject: [PATCH 14/17] fixing lint issues in test_op_nn file --- tests/python/relax/test_op_nn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py index bf70f9bf58d9..2401153c61de 100644 --- a/tests/python/relax/test_op_nn.py +++ b/tests/python/relax/test_op_nn.py @@ -80,7 +80,6 @@ def test_linear_unit_infer_struct_info(): _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 4), dtype="")) - def test_linear_unit_infer_struct_info_shape_symbolic(): bb = relax.BlockBuilder() m = tir.Var("m", "int64") From 6f18839faacf2ce8c526171b87d98199c2638b95 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 7 Apr 2025 08:58:55 +0000 Subject: [PATCH 15/17] fixing attribute error in test script --- tests/python/relax/test_frontend_nn_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 7dad152b686c..51f92a44e508 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -414,6 +414,7 @@ def test( tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x) exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x) negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x) + softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(x, beta=1.0, threshold=20.0) softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2) rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm( x, 
weight, axes=[-2, -1], epsilon=1.0000000000000001e-05 From f384a2df1957224a6e997e06d083b7048ca5a537 Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 7 Apr 2025 09:42:54 +0000 Subject: [PATCH 16/17] fixing lint issues in test script functions --- tests/python/relax/test_frontend_nn_op.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py index 51f92a44e508..ed81aa49ed34 100644 --- a/tests/python/relax/test_frontend_nn_op.py +++ b/tests/python/relax/test_frontend_nn_op.py @@ -414,7 +414,9 @@ def test( tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x) exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x) negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x) - softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(x, beta=1.0, threshold=20.0) + softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus( + x, beta=1.0, threshold=20.0 + ) softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2) rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm( x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05 From 3612a18592a7d53f405577f8cd20d9ffe5e2011e Mon Sep 17 00:00:00 2001 From: deivanayakisankaralingam Date: Mon, 7 Apr 2025 11:36:03 +0000 Subject: [PATCH 17/17] adding softplus wrapper function in op file --- python/tvm/relax/frontend/nn/op.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py index 23045f7c4ebf..e81ff7c5ad2c 100644 --- a/python/tvm/relax/frontend/nn/op.py +++ b/python/tvm/relax/frontend/nn/op.py @@ -1046,6 +1046,32 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor: return wrap_nested(_op.nn.softmax(x._expr, axis), name) +def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = "softplus") -> Tensor: + r"""Softplus activation function. + + .. math:: + \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x}) + + Parameters + ---------- + x : Tensor + The input data. + + beta : float, optional + Controls the smoothness of the transition. Default is 1.0. + + threshold : float, optional + The value beyond which the function is approximated as linear + to avoid numerical instability. Default is 20.0. + + Returns + ------- + result : Tensor + The computed result. + """ + return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name) + + def tanh(x: Tensor, name: str = "tanh") -> Tensor: r"""Applies the hyperbolic tangent function.
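To round off the series, a short usage sketch of the new nn frontend wrapper. This is illustrative only: it assumes the series is applied, and the export_tvm/spec layout follows the existing nn frontend tests:

from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec

class Model(nn.Module):
    def forward(self, x: nn.Tensor):
        return op.softplus(x, beta=1.0, threshold=20.0)

# export_tvm traces the module into a Relax IRModule whose "forward"
# function calls R.nn.softplus(x, beta=1.0, threshold=20.0).
mod, _ = Model().export_tvm(spec={"forward": {"x": spec.Tensor([2, 3], "float32")}})
mod.show()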