diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 0adcf29772cd..e2ce2be6a882 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -468,6 +468,15 @@ struct SoftplusAttrs : public tvm::AttrsNode<SoftplusAttrs> {
   }
 };
 
+/*! \brief Attributes used in PReLU operator */
+struct PReluAttrs : public tvm::AttrsNode<PReluAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(PReluAttrs, "relax.attrs.PReluAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis along which the alpha values are applied.");
+  }
+};
+
 /*! \brief Attributes used in batch_norm operator */
 struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
   int axis;
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index e81ff7c5ad2c..86be98cba786 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -1072,6 +1072,34 @@ def softplus(x: Tensor, beta: float = 1.0, threshold: float = 20.0, name: str =
     return wrap_nested(_op.nn.softplus(x._expr, beta=beta, threshold=threshold), name)
 
 
+def prelu(x: Tensor, alpha: Tensor, name: str = "prelu") -> Tensor:
+    r"""Parametric ReLU activation function.
+
+    .. math::
+        \text{PReLU}(x) = \begin{cases}
+        x & \text{if } x \geq 0 \\
+        \alpha \cdot x & \text{if } x < 0
+        \end{cases}
+
+    Parameters
+    ----------
+    x : Tensor
+        The input data.
+
+    alpha : Tensor
+        Slope coefficient for the negative part of the input.
+
+    name : str, optional
+        Optional name for the operation. Default is "prelu".
+
+    Returns
+    -------
+    result : Tensor
+        The computed result.
+    """
+    return wrap_nested(_op.nn.prelu(x._expr, alpha._expr), name)
+
+
 def tanh(x: Tensor, name: str = "tanh") -> Tensor:
     r"""Applies the hyperbolic tangent function.
 
diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 4c9480b58748..21cbd14d7ea9 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -307,6 +307,12 @@ def _log_softmax(self, node: fx.Node) -> relax.Var:
         dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _prelu(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        alpha = self.env[node.args[1]]
+        axis = 0 if len(x.struct_info.shape) == 1 else 1  # 1-D inputs have no channel axis
+        return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
+
     def _round(self, node: fx.Node) -> relax.Expr:
         if node.kwargs.get("decimals", 0) != 0:
             raise ValueError("specifying decimals for round is not supported yet")
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c82a5e2b1100..2c9e255f2946 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -299,6 +299,7 @@ def create_convert_map(
             "log1p.default": self._log1p,
             "log_softmax.int": self._log_softmax,
             "neg.default": self._unary_op(relax.op.negative),
+            "prelu.default": self._prelu,
             "reciprocal.default": self._reciprocal,
             "relu.default": self._unary_op(relax.op.nn.relu),
             "relu_.default": self._unary_op(relax.op.nn.relu),
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 297529e8bf29..a26185ce3caa 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -103,6 +103,14 @@ def _log_softmax_module(self, node: fx.Node) -> relax.Var:
         assert dim is not None
         return self.block_builder.emit(relax.op.nn.log_softmax(x, dim))
 
+    def _prelu_module(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        module = self.named_modules[node.target]
+        alpha_tensor = module.weight.detach().numpy()
+        alpha = relax.const(alpha_tensor, dtype="float32")
+        axis = 0 if len(x.struct_info.shape) == 1 else 1  # 1-D inputs have no channel axis
+        return self.block_builder.emit(relax.op.nn.prelu(x, alpha, axis))
+
     def _softmax_module(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         module = self.named_modules[node.target]
@@ -595,6 +603,7 @@ def create_convert_map(
             nn.Identity: lambda node: self.env[node.args[0]],
             nn.LeakyReLU: self._leakyrelu_module,
             nn.LogSoftmax: self._log_softmax_module,
+            nn.PReLU: self._prelu_module,
             nn.ReLU: self._unary_op(relax.op.nn.relu),
             nn.ReLU6: lambda node: self.block_builder.emit(
                 relax.op.clip(self.env[node.args[0]], 0, 6)
@@ -657,6 +666,7 @@
             "logical_not": self._unary_op(relax.op.logical_not),
             "log_softmax": self._log_softmax,
             "neg": self._unary_op(relax.op.negative),
+            "prelu": self._prelu,
             "reciprocal": self._reciprocal,
             "relu": self._unary_op(relax.op.nn.relu),
             "round": self._round,
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 9d56058e4649..14b5dcfc0681 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -43,6 +43,7 @@
     max_pool3d,
     nll_loss,
     pad,
+    prelu,
     relu,
     rms_norm,
     selu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 17197b010ef6..9d9eb3ef4820 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1431,6 +1431,32 @@ def log_softmax(data: Expr, axis: int = -1) -> Expr:
     return _ffi_api.log_softmax(data, axis)  # type: ignore
 
 
+def prelu(data: Expr, alpha: Expr, axis: int = 1) -> Expr:
+    r"""Parametric Rectified Linear Unit (PReLU).
+
+    .. math::
+        PReLU(x) = x \text{ if } x > 0 \text{ else } \alpha * x
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input tensor.
+
+    alpha : relax.Expr
+        The learnable slope tensor, applied channel-wise.
+
+    axis : int
+        The axis along which the `alpha` values are applied.
+        Default is 1 (assuming NCHW format).
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+ """ + return _ffi_api.prelu(data, alpha, axis) + + def batch_norm( data: Expr, gamma: Expr, diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py index 98fa3ef1ea5e..5d942e5f645d 100644 --- a/python/tvm/relax/transform/legalize_ops/nn.py +++ b/python/tvm/relax/transform/legalize_ops/nn.py @@ -469,6 +469,11 @@ def _nn_leakyrelu(bb: BlockBuilder, call: Call) -> Expr: return bb.call_te(topi.nn.leaky_relu, call.args[0], call.attrs.alpha) +@register_legalize("relax.nn.prelu") +def _nn_prelu(bb: BlockBuilder, call: Call) -> Expr: + return bb.call_te(topi.nn.prelu, call.args[0], call.args[1], call.attrs.axis) + + @register_legalize("relax.nn.gelu") def _nn_gelu(bb: BlockBuilder, call: Call) -> Expr: def te_gelu(x: te.Tensor): diff --git a/python/tvm/topi/nn/elemwise.py b/python/tvm/topi/nn/elemwise.py index 2b174f8f1ed5..59cc3598e9f2 100644 --- a/python/tvm/topi/nn/elemwise.py +++ b/python/tvm/topi/nn/elemwise.py @@ -129,6 +129,9 @@ def prelu(x, slope, axis=1): assert len(slope.shape) == 1 assert axis < len(x.shape) + slope = te.compute( + (get_const_int(x.shape[axis]),), lambda c: slope[0], name="slope_broadcasted" + ) assert get_const_int(slope.shape[0]) == get_const_int(x.shape[axis]) def _compute_channelwise(*indices): diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc index 7f545af1301d..8c0b86fe5f8e 100644 --- a/src/relax/op/nn/nn.cc +++ b/src/relax/op/nn/nn.cc @@ -81,6 +81,27 @@ TVM_REGISTER_OP("relax.nn.softplus") InferStructInfoUnaryArith) .set_attr("FPurity", Bool(true)); +/* relax.nn.prelu */ +TVM_REGISTER_NODE_TYPE(PReluAttrs); + +Expr prelu(Expr data, Expr alpha, int axis = 1) { + auto attrs = make_object(); + attrs->axis = axis; + static const Op& op = Op::Get("relax.nn.prelu"); + return Call(op, {data, alpha}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu); + +TVM_REGISTER_OP("relax.nn.prelu") + .set_num_inputs(2) + .add_argument("data", "Tensor", "The input tensor.") + .add_argument("alpha", "Tensor", "The channel-wise learnable slope.") + .set_attrs_type() + .set_attr("FInferStructInfo", + InferStructInfoUnaryArith) + .set_attr("FPurity", Bool(true)); + /* relax.nn.softmax */ TVM_REGISTER_NODE_TYPE(SoftmaxAttrs); diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h index 3f5571af8207..a9c3dd0a5767 100644 --- a/src/relax/op/nn/nn.h +++ b/src/relax/op/nn/nn.h @@ -57,6 +57,9 @@ Expr gelu(Expr data); /*! \brief Gaussian Error Linear Units function approximated by tanh. */ Expr gelu_tanh(Expr data); +/*! \brief Parametric Rectified Linear Unit function.*/ +Expr prelu(Expr data, Expr alpha, int axis); + /*! \brief Scaled Exponential Linear Unit function. 
 Expr selu(Expr data);
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 26d3d3f7bde2..e4694efa5617 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -642,6 +642,42 @@ def main(
     verify_model(LogSoftmax2(), example_args, {}, expected1)
 
 
+def test_prelu():
+    class Prelu1(Module):
+        def __init__(self, num_parameters=1, alpha=0.25):
+            super().__init__()
+            self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha)
+
+        def forward(self, x):
+            return self.prelu(x)
+
+    class Prelu2(torch.nn.Module):
+        def __init__(self):
+            super(Prelu2, self).__init__()
+            self.alpha = torch.nn.Parameter(torch.tensor([0.25]))
+
+        def forward(self, x):
+            return torch.nn.functional.prelu(x, self.alpha)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(
+                    x, R.const([0.25], dtype="float32"), axis=1
+                )
+                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
+    verify_model(Prelu1(), example_args, {}, expected)
+    verify_model(Prelu2(), example_args, {}, expected)
+
+
 def test_softmax():
     class Softmax(Module):
         def __init__(self):
diff --git a/tests/python/relax/test_frontend_from_fx.py b/tests/python/relax/test_frontend_from_fx.py
index a962de8a3237..caecce4979b5 100644
--- a/tests/python/relax/test_frontend_from_fx.py
+++ b/tests/python/relax/test_frontend_from_fx.py
@@ -823,6 +823,42 @@ def main(
     verify_model(LeakyReLU1(), input_info, {}, expected)
 
 
+def test_prelu():
+    class Prelu1(Module):
+        def __init__(self, num_parameters=1, alpha=0.25):
+            super().__init__()
+            self.prelu = torch.nn.PReLU(num_parameters=num_parameters, init=alpha)
+
+        def forward(self, x):
+            return self.prelu(x)
+
+    class Prelu2(torch.nn.Module):
+        def __init__(self):
+            super(Prelu2, self).__init__()
+            self.alpha = torch.nn.Parameter(torch.tensor([0.25]))
+
+        def forward(self, x):
+            return torch.nn.functional.prelu(x, self.alpha)
+
+    @tvm.script.ir_module
+    class expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 3, 10, 10), dtype="float32")
+        ) -> R.Tensor((1, 3, 10, 10), dtype="float32"):
+            with R.dataflow():
+                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.prelu(
+                    x, R.const([0.25], dtype="float32"), axis=1
+                )
+                gv: R.Tensor((1, 3, 10, 10), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    input_info = [([1, 3, 10, 10], "float32")]
+    verify_model(Prelu1(), input_info, {}, expected)
+    verify_model(Prelu2(), input_info, {}, expected)
+
+
 def test_maxpool2d():
     input_info = [([1, 3, 10, 10], "float32")]
 
@@ -2266,6 +2302,9 @@ def main(
     # softplus
     test_softplus()
 
+    # prelu
+    test_prelu()
+
     # log2
     class Log2(Module):
         def forward(self, x):
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index ed81aa49ed34..cc09998443b2 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -393,6 +393,7 @@ def test(self, x: Tensor, weight: Tensor, bias: Tensor):
         negative_out = op.negative(x)
         softplus_out = op.softplus(x, beta=1.0, threshold=20.0)
         softmax_out = op.softmax(x, axis=2)
+        prelu_out = op.prelu(x, alpha=bias)
         rms_norm_out = op.rms_norm(x, weight, axes=[-2, -1])
         rms_norm_with_bias_out = op.rms_norm(x, weight, axes=[-2, -1])
         group_norm_out = op.group_norm(x, num_groups=1, weight=bias, bias=bias)
@@ -418,6 +419,7 @@ def test(
                 x, beta=1.0, threshold=20.0
             )
             softmax: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.softmax(x, axis=2)
+            prelu: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.prelu(x, bias)
             rms_norm: R.Tensor((2, 3, 4, 5), dtype="float32") = R.nn.rms_norm(
                 x, weight, axes=[-2, -1], epsilon=1.0000000000000001e-05
            )
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index 2401153c61de..1c03d8fe4649 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -35,6 +35,10 @@ def test_op_correctness():
     assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout")
     assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad")
 
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    alpha = relax.Var("alpha", R.Tensor((3,), "float32"))
+    assert relax.op.nn.prelu(x, alpha, axis=1).op == Op.get("relax.nn.prelu")
+
     x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
     beta = relax.Var("beta", R.Tensor((3,), "float32"))
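
Not part of the patch: a minimal sketch of how the exported-program path touched above can be exercised locally. The model, shapes, and the final `mod.show()` are illustrative only, and exact frontend defaults may differ between TVM versions.

# Sketch only: trace torch.nn.PReLU through the ExportedProgram frontend and
# legalize, so the emitted R.nn.prelu and its topi.nn.prelu lowering can be inspected.
import torch
import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.prelu = torch.nn.PReLU(num_parameters=3, init=0.25)

    def forward(self, x):
        return self.prelu(x)


x = torch.randn(1, 3, 10, 10, dtype=torch.float32)
exported = torch.export.export(Demo().eval(), (x,))
mod = from_exported_program(exported)      # should contain a call to R.nn.prelu(..., axis=1)
mod = relax.transform.LegalizeOps()(mod)   # lowers to topi.nn.prelu via call_te
mod.show()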
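
Likewise illustrative rather than part of the patch: the new `op.prelu` in the `relax.frontend.nn` builder can be smoke-tested with `export_tvm`. The module name and spec shapes below are made up for the example.

# Sketch only: export a tiny relax.frontend.nn module that feeds op.prelu a
# per-channel alpha input, then print the resulting Relax IRModule.
from tvm.relax.frontend import nn
from tvm.relax.frontend.nn import Tensor, op, spec


class PReluDemo(nn.Module):
    def forward(self, x: Tensor, alpha: Tensor) -> Tensor:
        return op.prelu(x, alpha)


mod, _ = PReluDemo().export_tvm(
    spec={
        "forward": {
            "x": spec.Tensor([1, 3, 8, 8], "float32"),
            "alpha": spec.Tensor([3], "float32"),
        }
    }
)
mod.show()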