diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index bae308f435ce..fa095f9db4fa 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1099,8 +1099,7 @@ class PRelu(OnnxOpConverter):
     def _impl_v1(cls, bb, inputs, attr, params):
         x = inputs[0]
         slope = inputs[1]
-        # TODO(tvm-team): Should add a new op for this.
-        return x * slope + relax.op.nn.relu(x) * (relax.const(1.0) - slope)
+        return relax.op.nn.prelu(x, slope)


 class ThresholdedRelu(OnnxOpConverter):
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b79690d3a9bd..bc443e60d2ce 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -93,13 +93,54 @@ Expr prelu(Expr data, Expr alpha, int axis = 1) {

 TVM_FFI_REGISTER_GLOBAL("relax.op.nn.prelu").set_body_typed(prelu);

+StructInfo InferStructInfoPRelu(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  if (data_sinfo->IsUnknownNdim()) {
+    return data_sinfo;
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Prelu requires the input tensor to have float "
+                                                "dtype. However, the given input dtype is "
+                                             << data_sinfo->dtype);
+  }
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
+
+  return data_sinfo;
+}
+
+InferLayoutOutput InferLayoutPRelu(const Call& call,
+                                   const Map<String, Array<String>>& desired_layouts,
+                                   const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* attrs = call->attrs.as<PReluAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+
+  // TODO(Siva): We could handle if the axis is not the sub indexed one.
+  if (layout->layout.ndim() != layout->layout.ndim_primal()) {
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+    int ndim = tensor_sinfo->ndim;
+    layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  ObjectPtr<PReluAttrs> new_attrs = make_object<PReluAttrs>(*attrs);
+  new_attrs->axis = FindAxis(layout->layout, attrs->axis);
+
+  LayoutDecision alpha_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  return InferLayoutOutput({layout, alpha_layout}, {layout}, Attrs(new_attrs));
+}
+
 TVM_REGISTER_OP("relax.nn.prelu")
     .set_num_inputs(2)
     .add_argument("data", "Tensor", "The input tensor.")
     .add_argument("alpha", "Tensor", "The channel-wise learnable slope.")
     .set_attrs_type<PReluAttrs>()
-    .set_attr<FInferStructInfo>("FInferStructInfo",
-                                InferStructInfoUnaryArith)
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPRelu)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPRelu)
     .set_attr<Bool>("FPurity", Bool(true));

 /* relax.nn.softmax */
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 6c3334f64d12..757fddbacd27 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -948,7 +948,7 @@ def test_mish():


 def test_prelu():
-    verify_binary("PRelu", [3, 32, 32], [3, 32, 32], [3, 32, 32])
+    verify_binary("PRelu", [3, 32, 32], [1], [3, 32, 32])


 def test_thresholded_relu():
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 52986feef377..0c3f471ec9a7 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1159,6 +1159,89 @@ def leaky_relu(var_rxplaceholder: T.handle, var_compute: T.handle):
     tvm.ir.assert_structural_equal(mod, Expected)


+def test_prelu():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor((2, 3), "float32"):
+            gv: R.Tensor((2, 3), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor((2, 3), dtype="float32"):
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((2, 3), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(x: T.Buffer((T.int64(2), T.int64(3)), "float32"), y: T.Buffer((T.int64(1),), "float32"), compute: T.Buffer((T.int64(2), T.int64(3)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(3),))
+            for c in range(T.int64(3)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(3), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(T.int64(2), T.int64(3)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_prelu_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class PRelu:
+        @R.function
+        def main(x: R.Tensor(("m", 7), "float32"), y: R.Tensor((1,), "float32")) -> R.Tensor(("m", 7), "float32"):
+            m = T.int64()
+            gv: R.Tensor((m, 7), "float32") = R.nn.prelu(x, y)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(("m", 7), dtype="float32"), y: R.Tensor((1,), dtype="float32")) -> R.Tensor(("m", 7), dtype="float32"):
+            m = T.int64()
+            gv = R.call_tir(Expected.prelu, (x, y), out_sinfo=R.Tensor((m, 7), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def prelu(var_x: T.handle, y: T.Buffer((T.int64(1),), "float32"), var_compute: T.handle):
+            T.func_attr({"tir.noalias": True})
+            m = T.int64()
+            x = T.match_buffer(var_x, (m, T.int64(7)))
+            compute = T.match_buffer(var_compute, (m, T.int64(7)))
+            # with T.block("root"):
+            slope_broadcasted = T.alloc_buffer((T.int64(7),))
+            for c in range(T.int64(7)):
+                with T.block("slope_broadcasted"):
+                    v_c = T.axis.spatial(T.int64(7), c)
+                    T.reads(y[T.int64(0)])
+                    T.writes(slope_broadcasted[v_c])
+                    slope_broadcasted[v_c] = y[T.int64(0)]
+            for i0, i1 in T.grid(m, T.int64(7)):
+                with T.block("compute"):
+                    v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+                    T.reads(x[v_i0, v_i1], slope_broadcasted[v_i1])
+                    T.writes(compute[v_i0, v_i1])
+                    compute[v_i0, v_i1] = T.Select(T.float32(0.0) < x[v_i0, v_i1], x[v_i0, v_i1], x[v_i0, v_i1] * slope_broadcasted[v_i1])
+    # fmt: on
+
+    mod = LegalizeOps()(PRelu)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_gelu():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
index bba08d4d842e..4c458a7ead2a 100644
--- a/tests/python/relax/test_tvmscript_parser_op_nn.py
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -364,5 +364,24 @@ def foo(
     _check(foo, bb.get()["foo"])


+def test_prelu():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 4, 4, 5), "float32"),
+        alpha: R.Tensor((1,), "float32"),
+    ) -> R.Tensor((2, 4, 4, 5), "float32"):
+        gv: R.Tensor((2, 4, 4, 5), "float32") = R.nn.prelu(x, alpha)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 4, 4, 5), "float32"))
+    alpha = relax.Var("alpha", R.Tensor((1,), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, alpha]):
+        gv = bb.emit(relax.op.nn.prelu(x, alpha))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
 if __name__ == "__main__":
     tvm.testing.main()
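
A minimal usage sketch (not part of the patch), assuming the operator and legalization changes above: it builds R.nn.prelu through the Relax BlockBuilder and runs LegalizeOps over it, mirroring test_prelu. Shapes and variable names here are illustrative only.

    import tvm
    from tvm import relax
    from tvm.relax.transform import LegalizeOps

    # prelu(x, alpha): x where x > 0, alpha * x elsewhere, with the slope
    # broadcast along the channel axis (axis=1 by default).
    x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32"))
    alpha = relax.Var("alpha", relax.TensorStructInfo((1,), "float32"))

    bb = relax.BlockBuilder()
    with bb.function("main", [x, alpha]):
        gv = bb.emit(relax.op.nn.prelu(x, alpha))
        bb.emit_func_output(gv)

    # Lower the high-level op into a TIR prim_func, as exercised by test_prelu above.
    mod = LegalizeOps()(bb.get())
    mod.show()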