diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 8f63012e095a..0adcf29772cd 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -455,6 +455,19 @@ struct LeakyReluAttrs : public tvm::AttrsNode<LeakyReluAttrs> {
   }
 };
 
+/*! \brief Attributes used in softplus operators */
+struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
+  double beta;
+  double threshold;
+
+  TVM_DECLARE_ATTRS(SoftplusAttrs, "relax.attrs.SoftplusAttrs") {
+    TVM_ATTR_FIELD(beta).describe(
+        "Scaling factor controlling the sharpness of the Softplus transition.");
+    TVM_ATTR_FIELD(threshold).describe(
+        "Value determining when to use linear approximation for numerical stability.");
+  }
+};
+
 /*! \brief Attributes used in batch_norm operator */
 struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
   int axis;
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index 23045f7c4ebf..e81ff7c5ad2c 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1046,6 +1046,32 @@ def softmax(x: Tensor, axis: int = -1, name: str = "softmax") -> Tensor:
     return wrap_nested(_op.nn.softmax(x._expr, axis), name)
 
 
+def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str = "softplus") -> Tensor:
+    r"""Softplus activation function.
+
+    .. math::
+        \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    beta : float, optional
+        Controls the smoothness of the transition. Default is 1.0.
+
+    threshold : float, optional
+        The value beyond which the function is approximated as linear
+        to avoid numerical instability. Default is 20.0.
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)
+
+
 def tanh(x: Tensor, name: str = "tanh") -> Tensor:
     r"""Applies the hyperbolic tangent function.
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 890f925079e0..17d965493751 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,12 @@ def _softmax(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.softmax(x, dim))
 
+    def _softplus(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        beta = node.args[1] if len(node.args) > 1 else node.kwargs.get("beta", 1.0)
+        threshold = node.args[2] if len(node.args) > 2 else node.kwargs.get("threshold", 20.0)
+        return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold))
+
     def _softshrink(self, node: fx.Node) -> relax.Var:
         """
         Applies the Softshrink activation function in Relax.
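For reviewers who want to poke at the new frontend op, here is a minimal, hypothetical sketch (not part of this patch) of `nn.op.softplus` used inside a Relax `nn.Module`; the module name, shapes, and export spec are illustrative assumptions only.

```python
# Hypothetical usage sketch; SoftplusMLP, the shapes, and the spec are made up.
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import op, spec


class SoftplusMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        # softplus(x) = (1 / beta) * log(1 + exp(beta * x)), switching to the
        # identity once beta * x exceeds the threshold.
        return op.softplus(self.linear(x), beta=1.0, threshold=20.0)


# Exporting should yield an IRModule whose body contains R.nn.softplus(...).
mod, _ = SoftplusMLP().export_tvm(spec={"forward": {"x": spec.Tensor((1, 16), "float32")}})
mod.show()
```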
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 2e7c682aa34b..a896cf5e7148 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -287,6 +287,7 @@ def create_convert_map(
             "sin.default": self._unary_op(relax.op.sin),
             "sinh.default": self._unary_op(relax.op.sinh),
             "softmax.int": self._softmax,
+            "softplus.default": self._softplus,
             "softshrink.default": self._softshrink,
             "sqrt.default": self._unary_op(relax.op.sqrt),
             "square.default": self._unary_op(relax.op.square),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 3ddf919c2ed1..0f860a965833 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -72,6 +72,13 @@ def _leakyrelu_module(self, node: fx.Node) -> relax.Var:
         alpha = module.negative_slope
         return self.block_builder.emit(relax.op.nn.leakyrelu(x, alpha))
 
+    def _softplus_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        beta = module.beta
+        threshold = module.threshold
+        return self.block_builder.emit(relax.op.nn.softplus(x, beta, threshold))
+
     def _log2(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(
@@ -653,6 +660,7 @@ def create_convert_map(
             nn.SELU: self._unary_op(relax.op.nn.selu),
             nn.SiLU: self._unary_op(relax.op.nn.silu),
             nn.Softmax: self._softmax_module,
+            nn.Softplus: self._softplus_module,
             nn.Tanh: self._unary_op(relax.op.tanh),
             # neural network
             nn.AdaptiveAvgPool2d: self._adaptive_avg_pool2d_module,
@@ -717,6 +725,7 @@ def create_convert_map(
             "sin": self._unary_op(relax.op.sin),
             "sinh": self._unary_op(relax.op.sinh),
             "softmax": self._softmax,
+            "softplus": self._softplus,
             "sqrt": self._unary_op(relax.op.sqrt),
             "square": self._unary_op(relax.op.square),
             "tan": self._unary_op(relax.op.tan),
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index e45982a0fed2..9d56058e4649 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -48,4 +48,5 @@
     selu,
     silu,
     softmax,
+    softplus,
 )
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5232eea047cf..17197b010ef6 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1378,6 +1378,31 @@ def softmax(data: Expr, axis: int = -1) -> Expr:
     return _ffi_api.softmax(data, axis)  # type: ignore
 
 
+def softplus(data: Expr, beta: float = 1.0, threshold: float = 20.0) -> Expr:
+    r"""Softplus activation function.
+
+    .. math:: \text{Softplus}(x) = \frac{1}{\beta} \log(1 + e^{\beta x})
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data.
+
+    beta : float, optional
+        Controls the smoothness of the transition. Default is 1.0.
+
+    threshold : float, optional
+        The value beyond which the function is approximated as linear
+        to avoid numerical instability. Default is 20.0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.softplus(data, beta, threshold)
+
+
 def log_softmax(data: Expr, axis: int = -1) -> Expr:
     r"""Computes log softmax.
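A hedged sketch (not part of this patch) of the translation path these registrations enable: `torch.nn.Softplus` modules go through `_softplus_module` in the FX translator, while `torch.export` graphs hit the new `softplus.default` entry. The model, shapes, and variable names below are illustrative.

```python
# Hypothetical translation sketch; assumes a PyTorch version with torch.export.
import torch
from tvm.relax.frontend.torch import from_exported_program


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.act = torch.nn.Softplus(beta=1.0, threshold=20.0)

    def forward(self, x):
        return self.act(x)


example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
exported = torch.export.export(Net().eval(), example_args)
mod = from_exported_program(exported)
mod.show()  # expected to contain R.nn.softplus(..., beta=1.0, threshold=20.0)
```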
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index fd3db841e646..98fa3ef1ea5e 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -531,6 +531,16 @@ def te_silu(x: te.Tensor):
     return bb.call_te(te_silu, call.args[0], primfunc_name_hint="silu")
 
 
+@register_legalize("relax.nn.softplus")
+def _nn_softplus(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.softplus,
+        call.args[0],
+        call.attrs.beta,
+        call.attrs.threshold,
+    )
+
+
 @register_legalize("relax.nn.softmax")
 def _nn_softmax(bb: BlockBuilder, call: Call) -> Expr:
     return bb.call_te(topi.nn.softmax, call.args[0], call.attrs.axis)
diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py
index a80047d900f3..2b174f8f1ed5 100644
--- a/python/tvm/topi/nn/elemwise.py
+++ b/python/tvm/topi/nn/elemwise.py
@@ -65,6 +65,39 @@ def _compute(*indices):
     return te.compute(x.shape, _compute)
 
 
+@tvm.te.tag_scope(tag=tag.ELEMWISE)
+def softplus(x, beta=1.0, threshold=20.0):
+    """Compute Softplus activation for input x with numerical stability.
+
+    Parameters
+    ----------
+    x : tvm.te.Tensor
+        Input tensor.
+
+    beta : float, optional
+        The scaling factor β in the Softplus formula (default is 1.0).
+
+    threshold : float, optional
+        The threshold value for numerical stability (default is 20.0).
+
+    Returns
+    -------
+    y : tvm.te.Tensor
+        The result.
+    """
+
+    def _compute(*indices):
+        value = x(*indices)
+        b = tvm.tir.const(beta, value.dtype)
+        t = tvm.tir.const(threshold, value.dtype)
+
+        return tvm.tir.Select(
+            b * value > t, value, (1 / b) * tvm.tir.log(1 + tvm.tir.exp(b * value))
+        )
+
+    return te.compute(x.shape, _compute)
+
+
 @tvm.te.tag_scope(tag=tag.BROADCAST)
 def prelu(x, slope, axis=1):
     """PReLU.
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 4a5a9a701612..7f545af1301d 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -60,6 +60,27 @@ TVM_REGISTER_OP("relax.nn.leakyrelu")
                                 InferStructInfoUnaryArith</*require_float_dtype=*/false>)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.softplus */
+TVM_REGISTER_NODE_TYPE(SoftplusAttrs);
+
+Expr softplus(Expr data, double beta, double threshold) {
+  auto attrs = make_object<SoftplusAttrs>();
+  attrs->beta = beta;
+  attrs->threshold = threshold;
+  static const Op& op = Op::Get("relax.nn.softplus");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.softplus").set_body_typed(softplus);
+
+TVM_REGISTER_OP("relax.nn.softplus")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attrs_type<SoftplusAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo",
+                                InferStructInfoUnaryArith</*require_float_dtype=*/true>)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.softmax */
 TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index d6db36aba50a..3f5571af8207 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -66,6 +66,9 @@ Expr silu(Expr data);
 /*! \brief Softmax function. */
 Expr softmax(Expr data, int axis);
 
+/*! \brief Softplus function. */
+Expr softplus(Expr data, double beta, double threshold);
+
 /*! \brief LogSoftmax function. */
 Expr log_softmax(Expr data, int axis);
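To make the thresholding in `topi.nn.softplus` above concrete, here is a small NumPy reference (illustrative only, not part of this patch) that mirrors the `Select`: once `beta * x` exceeds `threshold`, the result is taken to be `x` itself instead of evaluating `log(1 + exp(beta * x))`, which would overflow in float32 long before such inputs.

```python
# NumPy mirror of the Select-based TOPI compute; purely illustrative.
import numpy as np


def softplus_ref(x, beta=1.0, threshold=20.0):
    bx = beta * x
    # Clamp the exponent so the smooth branch never overflows; for bx > threshold
    # the linear branch is selected anyway.
    smooth = np.log1p(np.exp(np.minimum(bx, threshold))) / beta
    return np.where(bx > threshold, x, smooth)


x = np.array([-50.0, 0.0, 5.0, 50.0], dtype="float32")
print(softplus_ref(x))  # approx. [0.0, 0.693, 5.007, 50.0]
```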
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 2175f9aa391c..cd1201a30e00 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -409,6 +409,9 @@ def main(
     # leakyrelu
     test_leakyrelu()
 
+    # softplus
+    test_softplus()
+
     # log2
     class Log2(Module):
         def forward(self, x):
@@ -655,6 +658,43 @@ def main(
     verify_model(Hardtanh2(), example_args, {}, expected1)
 
 
+def test_softplus():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class Softplus0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+        def forward(self, x):
+            return self.softplus(x)
+
+    class Softplus1(Module):
+        def forward(self, input):
+            return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.softplus(
+                    x, beta=1.0, threshold=20.0
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Softplus0(), example_args, {}, expected)
+    verify_model(Softplus1(), example_args, {}, expected)
+
+
 def test_leakyrelu():
     import torch
     from torch.nn import Module
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index d913baf13a0d..0c11672e3e71 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -749,6 +749,43 @@ def main(
     verify_model(Einsum2(), [([5], "float32"), ([4], "float32")], {}, Expected2)
 
 
+@tvm.testing.requires_gpu
+def test_softplus():
+    import torch
+    from torch.nn import Module
+
+    torch.set_grad_enabled(False)
+
+    class Softplus0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.softplus = torch.nn.Softplus(1.0, 20.0)
+
+        def forward(self, x):
+            return self.softplus(x)
+
+    class Softplus1(Module):
+        def forward(self, input):
+            return torch.nn.functional.softplus(input, 1.0, 20.0)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(inp_0: R.Tensor((10, 10), dtype="float32")) -> R.Tensor((10, 10), dtype="float32"):
+            # block 0
+            with R.dataflow():
+                lv: R.Tensor((10, 10), dtype="float32") = R.nn.softplus(
+                    inp_0, beta=1.0, threshold=20.0
+                )
+                gv: R.Tensor((10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([10, 10], "float32")]
+    verify_model(Softplus0(), input_info, {}, expected)
+    verify_model(Softplus1(), input_info, {}, expected)
+
+
 @tvm.testing.requires_gpu
 def test_leakyrelu():
     import torch
@@ -2226,6 +2263,9 @@ def main(
     # leaky_relu
     test_leakyrelu()
 
+    # softplus
+    test_softplus()
+
     # log2
     class Log2(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 6e63b0e4c069..ed81aa49ed34 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -391,6 +391,7 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
         tanh_out = op.tanh(x)
         exp_out = op.exp(x)
         negative_out = op.negative(x)
+        softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
         softmax_out = op.softmax(x, axis=2)
         rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
         rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
@@ -413,6 +414,9 @@ def test(
             tanh: R.Tensor((2, 3, 4, 5), dtype="float32") = R.tanh(x)
             exp: R.Tensor((2, 3, 4, 5), dtype="float32") = R.exp(x)
             negative: R.Tensor((2, 3, 4, 5), dtype="float32") = R.negative(x)
+            softplus: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softplus(
+                x, beta=1.0, threshold=20.0
+            )
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
             rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                 x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index ec4551872fc8..2401153c61de 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -27,6 +27,7 @@ def test_op_correctness():
     x = relax.Var("x", R.Tensor((2, 3), "float32"))
     assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu")
     assert relax.op.nn.leakyrelu(x).op == Op.get("relax.nn.leakyrelu")
+    assert relax.op.nn.softplus(x).op == Op.get("relax.nn.softplus")
     assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu")
     assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu")
     assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
@@ -75,6 +76,8 @@ def test_linear_unit_infer_struct_info():
     _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype=""))
     _check_inference(bb, relax.op.nn.leakyrelu(x0), relax.TensorStructInfo((2, 3), "float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x5), relax.TensorStructInfo((3, 4), dtype=""))
+    _check_inference(bb, relax.op.nn.softplus(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.nn.softplus(x5), relax.TensorStructInfo((3, 4), dtype=""))
 
 
 def test_linear_unit_infer_struct_info_shape_symbolic():
@@ -87,6 +90,7 @@ def test_linear_unit_infer_struct_info_shape_symbolic():
     _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo((4, n), "float32"))
+    _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo((4, n), "float32"))
 
 
 def test_linear_unit_infer_struct_info_shape_var():
@@ -99,6 +103,7 @@ def test_linear_unit_infer_struct_info_shape_var():
     _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32"))
     _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32"))
     _check_inference(bb, relax.op.nn.leakyrelu(x1), relax.TensorStructInfo(s1, "float32"))
+    _check_inference(bb, relax.op.nn.softplus(x1), relax.TensorStructInfo(s1, "float32"))
 
 
 def test_linear_unit_infer_struct_info_more_input_dtype():
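Finally, a hypothetical end-to-end check (not part of this patch) that builds a Relax function around the new op, legalizes it to the TOPI kernel, and compares the VM result against a NumPy reference; the target, shapes, and tolerance are assumptions.

```python
# Hypothetical round-trip sketch; assumes an LLVM-enabled TVM build.
import numpy as np
import tvm
import tvm.testing
from tvm import relax

x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
bb = relax.BlockBuilder()
with bb.function("main", [x]):
    with bb.dataflow():
        gv = bb.emit_output(relax.op.nn.softplus(x, beta=1.0, threshold=20.0))
    bb.emit_func_output(gv)

mod = relax.transform.LegalizeOps()(bb.get())  # lowers R.nn.softplus to the TOPI kernel
ex = relax.build(mod, target="llvm")
vm = relax.VirtualMachine(ex, tvm.cpu())

data = np.random.randn(2, 3).astype("float32")
ref = np.log1p(np.exp(data))  # beta == 1 and |data| stays far below the threshold
tvm.testing.assert_allclose(vm["main"](tvm.nd.array(data)).numpy(), ref, rtol=1e-5)
```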