From 400742419ab69d979b6776ff122d74e2737bfb08 Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Wed, 20 Aug 2025 21:16:44 +0800 Subject: [PATCH 1/6] support inplace and input parameter for silu api --- .../same_operands_result.cc | 1 + .../same_operands_result.h | 1 + paddle/phi/ops/yaml/ops.yaml | 3 +- python/paddle/nn/functional/activation.py | 19 +- python/paddle/nn/layer/activation.py | 19 +- test/legacy_test/test_silu_op.py | 276 ++++++++++++++++++ 6 files changed, 311 insertions(+), 8 deletions(-) create mode 100644 test/legacy_test/test_silu_op.py diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc index 7b9095897cd084..eea48f2e7e2106 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc @@ -195,6 +195,7 @@ OP_SAME_OPERANDS_AND_RESULT(Polygamma_) OP_SAME_OPERANDS_AND_RESULT(EnableCheckModelNanInf) OP_SAME_OPERANDS_AND_RESULT(ViewShape) OP_SAME_OPERANDS_AND_RESULT(Silu) +OP_SAME_OPERANDS_AND_RESULT(Silu_) OP_SAME_OPERANDS_AND_RESULT(ViewDtype) OP_SAME_OPERANDS_AND_RESULT(FusedSoftmaxMaskUpperTriangle) OP_SAME_OPERANDS_AND_RESULT(Gammaln) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h index 51a6625f7473a5..6a140ecaca65ac 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h @@ -151,6 +151,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 44c5fdf0b53c58..9cefe2ee3a9814 100644 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -4954,12 +4954,13 @@ - op : silu args : (Tensor x) - output : Tensor + output : Tensor(out) infer_meta : func : UnchangedInferMeta spmd_rule : ElementwiseUnaryInferSpmd kernel : func : silu + inplace : (x -> out) backward : silu_grad interfaces : paddle::dialect::LayoutTransformationInterface, paddle::dialect::InferSymbolicShapeInterface traits: pir::UnaryElementWiseTrait diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index c3ddf5f8dd7973..ff61d8f206abd5 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -19,7 +19,7 @@ import paddle from paddle import _C_ops, in_dynamic_mode from paddle.framework import core, in_dynamic_or_pir_mode -from paddle.utils.decorator_utils import ParamAliasDecorator +from paddle.utils.decorator_utils import ParamAliasDecorator, param_one_alias from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ...base.data_feeder import check_dtype, check_variable_and_dtype @@ -1076,7 +1076,8 @@ def selu( return out -def silu(x: Tensor, name: str | None = None) -> Tensor: +@param_one_alias(["x", "input"]) +def silu(x: Tensor, inplace: bool = False, name: str | None = None) -> Tensor: 
r""" silu activation @@ -1088,6 +1089,7 @@ def silu(x: Tensor, name: str | None = None) -> Tensor: Parameters: x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, complex64, complex128. + inplace (bool, optional): Whether to use inplace operation. Default: False. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: @@ -1104,10 +1106,21 @@ def silu(x: Tensor, name: str | None = None) -> Tensor: >>> print(out) Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.73105860, 1.76159406, 2.85772228, 3.92805505]) + + >>> out = F.silu(x, True) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.73105860, 1.76159406, 2.85772228, 3.92805505]) + >>> print(x) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.73105860, 1.76159406, 2.85772228, 3.92805505]) """ if in_dynamic_or_pir_mode(): - return _C_ops.silu(x) + if inplace: + return _C_ops.silu_(x) + else: + return _C_ops.silu(x) else: check_variable_and_dtype( x, diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index bcd7369092766d..2da3e65791d68f 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -1263,6 +1263,7 @@ class Silu(Layer): Where :math:`x` is the input Tensor. Parameters: + inplace (bool, optional): Whether to use inplace operation. Default: False. name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Shape: @@ -1280,18 +1281,28 @@ class Silu(Layer): >>> print(out) Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [0.73105860, 1.76159406, 2.85772228, 3.92805505]) + + >>> m = paddle.nn.Silu(True) + >>> out = m(x) + >>> print(out) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.73105860, 1.76159406, 2.85772228, 3.92805505]) + >>> print(x) + Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, + [0.73105860, 1.76159406, 2.85772228, 3.92805505]) """ - def __init__(self, name: str | None = None) -> str: + def __init__(self, inplace: bool = False, name: str | None = None) -> str: super().__init__() self._name = name + self._inplace = inplace def forward(self, x: Tensor) -> Tensor: - return F.silu(x, self._name) + return F.silu(x, self._inplace, self._name) def extra_repr(self) -> str: - name_str = f'name={self._name}' if self._name else '' - return name_str + name_str = f', name={self._name}' if self._name else '' + return f'inplace={self._inplace}{name_str}' class LogSigmoid(Layer): diff --git a/test/legacy_test/test_silu_op.py b/test/legacy_test/test_silu_op.py new file mode 100644 index 00000000000000..3635a37577e4e3 --- /dev/null +++ b/test/legacy_test/test_silu_op.py @@ -0,0 +1,276 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle +import paddle.base.dygraph as dg +import paddle.nn.functional as F +from paddle import base, nn + + +def silu(x): + y_ref = x * (1 / (1 + np.exp(-x))) + return y_ref.astype(x.dtype) + + +class TestSiluOpClass(unittest.TestCase): + def _test_case1_cpu(self): + x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CPUPlace() + with dg.guard(place) as g: + x_var = paddle.to_tensor(x) + y_var1 = F.silu(x_var) + y_test1 = y_var1.numpy() + + func = nn.Silu() + y_var2 = func(x_var) + y_test2 = y_var2.numpy() + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + + def _test_case1_gpu(self): + x = np.random.uniform(-1, 1, size=(11, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CUDAPlace(0) + with dg.guard(place) as g: + x_var = paddle.to_tensor(x) + y_var1 = F.silu(x_var) + y_test1 = y_var1.numpy() + + func = nn.Silu() + y_var2 = func(x_var) + y_test2 = y_var2.numpy() + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + + def test_cases(self): + self._test_case1_cpu() + if base.is_compiled_with_cuda(): + self._test_case1_gpu() + + def test_fast_math(self): + if not paddle.is_compiled_with_cuda(): + return + + def use_fast_math(enabled): + paddle.set_flags({'FLAGS_use_fast_math': enabled}) + + shape = [11, 17, 8] + x_np = np.random.uniform(-1, 1, size=shape).astype(np.float16) + y_g_np = np.random.uniform(-1, 1, size=shape).astype(np.float16) + + def run_silu_op(): + with dg.guard(): + x = paddle.to_tensor(x_np) + x.stop_gradient = False + y = F.silu(x) + x_grad = paddle.grad([y], [x], [paddle.to_tensor(y_g_np)])[0] + return y.numpy(), x_grad.numpy() + + def run_silu_class(): + with dg.guard(): + x = paddle.to_tensor(x_np) + x.stop_gradient = False + func = nn.Silu() + y = func(x) + x_grad = paddle.grad([y], [x], [paddle.to_tensor(y_g_np)])[0] + return y.numpy(), x_grad.numpy() + + use_fast_math(True) + y_fast_math1, x_g_fast_math1 = run_silu_op() + y_fast_math2, x_g_fast_math2 = run_silu_class() + use_fast_math(False) + + y_ref1, x_g_ref1 = run_silu_op() + y_ref2, x_g_ref2 = run_silu_class() + np.testing.assert_allclose( + y_ref1, y_fast_math1, rtol=1e-05, atol=0.0005 + ) + + np.testing.assert_allclose( + x_g_ref1, x_g_fast_math1, rtol=1e-05, atol=0.0005 + ) + + np.testing.assert_allclose( + y_ref2, y_fast_math2, rtol=1e-05, atol=0.0005 + ) + + np.testing.assert_allclose( + x_g_ref2, x_g_fast_math2, rtol=1e-05, atol=0.0005 + ) + + +class TestSiluOpClass_ZeroSize(unittest.TestCase): + def _test_case1_cpu(self): + x = np.random.uniform(-1, 1, size=(0, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CPUPlace() + with dg.guard(place) as g: + x_var1 = paddle.to_tensor(x) + x_var2 = paddle.to_tensor(x) + + x_var1.stop_gradient = False + x_var2.stop_gradient = False + + y_var1 = F.silu(x_var1) + y_test1 = y_var1.numpy() + + func = nn.Silu() + y_var2 = func(x_var2) + y_test2 = y_var2.numpy() + + loss1 = paddle.sum(y_var1) + loss1.backward() + + loss2 = paddle.sum(y_var2) + loss2.backward() + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(x_var1.grad.shape, x_var1.shape) + + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(x_var2.grad.shape, x_var2.shape) + + def _test_case1_gpu(self): + x = np.random.uniform(-1, 1, 
size=(0, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CUDAPlace(0) + with dg.guard(place) as g: + x_var1 = paddle.to_tensor(x) + x_var2 = paddle.to_tensor(x) + + x_var1.stop_gradient = False + x_var2.stop_gradient = False + + y_var1 = F.silu(x_var1) + y_test1 = y_var1.numpy() + + func = nn.Silu() + y_var2 = func(x_var2) + y_test2 = y_var2.numpy() + + loss1 = paddle.sum(y_var1) + loss1.backward() + + loss2 = paddle.sum(y_var2) + loss2.backward() + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(x_var1.grad.shape, x_var1.shape) + + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(x_var2.grad.shape, x_var2.shape) + + def test_cases(self): + self._test_case1_cpu() + if base.is_compiled_with_cuda(): + self._test_case1_gpu() + + +class TestSiluOpClass_Inplace(unittest.TestCase): + def _test_case1_cpu(self): + x = np.random.uniform(-1, 1, size=(15, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CPUPlace() + with dg.guard(place) as g: + x_var1 = paddle.to_tensor(x) + x_var2 = paddle.to_tensor(x) + + y_var1 = F.silu(x_var1, True) + y_test1 = y_var1.numpy() + + func = nn.Silu(True) + y_var2 = func(x_var2) + y_test2 = y_var2.numpy() + + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + + np.testing.assert_allclose( + y_ref, x_var1.numpy(), rtol=1e-05, atol=1e-08 + ) + np.testing.assert_allclose( + y_ref, x_var2.numpy(), rtol=1e-05, atol=1e-08 + ) + + def _test_case1_gpu(self): + x = np.random.uniform(-1, 1, size=(15, 17)).astype(np.float32) + y_ref = silu(x) + + place = base.CUDAPlace(0) + with dg.guard(place) as g: + x_var1 = paddle.to_tensor(x) + x_var2 = paddle.to_tensor(x) + + y_var1 = F.silu(x_var1, True) + y_test1 = y_var1.numpy() + + func = nn.Silu(True) + y_var2 = func(x_var2) + y_test2 = y_var2.numpy() + + np.testing.assert_allclose(y_ref, y_test1, rtol=1e-05, atol=1e-08) + np.testing.assert_allclose(y_ref, y_test2, rtol=1e-05, atol=1e-08) + + np.testing.assert_allclose( + y_ref, x_var1.numpy(), rtol=1e-05, atol=1e-08 + ) + np.testing.assert_allclose( + y_ref, x_var2.numpy(), rtol=1e-05, atol=1e-08 + ) + + def test_cases(self): + self._test_case1_cpu() + if base.is_compiled_with_cuda(): + self._test_case1_gpu() + + +class TestSiluParamDecorator(unittest.TestCase): + def setUp(self): + paddle.disable_static() + self.x_np = np.random.random((10, 3, 4)).astype("float64") + self.test_types = ["decorator"] + + def do_test(self, test_type): + x = paddle.to_tensor(self.x_np, stop_gradient=False) + if test_type == 'raw': + result = F.silu(x, False) + result.mean().backward() + return result, x.grad + elif test_type == 'decorator': + result = F.silu(input=x, inplace=False) + result.mean().backward() + return result, x.grad + else: + raise ValueError(f"Unknown test type: {test_type}") + + def test_all(self): + out_std, grad_x_std = self.do_test('raw') + for test_type in self.test_types: + out, grad_x = self.do_test(test_type) + np.testing.assert_allclose(out.numpy(), out_std.numpy(), rtol=1e-7) + np.testing.assert_allclose( + grad_x.numpy(), grad_x_std.numpy(), rtol=1e-7 + ) + + +if __name__ == '__main__': + unittest.main() From 4720726ce7263055572682ff1a9da7b23e199789 Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Thu, 21 Aug 2025 13:34:14 +0800 Subject: [PATCH 2/6] add test --- python/paddle/nn/layer/activation.py | 6 ++-- test/legacy_test/test_silu_op.py | 50 
++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py index 2da3e65791d68f..52e8d1f7f42227 100644 --- a/python/paddle/nn/layer/activation.py +++ b/python/paddle/nn/layer/activation.py @@ -1301,8 +1301,10 @@ def forward(self, x: Tensor) -> Tensor: return F.silu(x, self._inplace, self._name) def extra_repr(self) -> str: - name_str = f', name={self._name}' if self._name else '' - return f'inplace={self._inplace}{name_str}' + name_str = f'inplace={self._inplace}' + ( + f', name={self._name}' if self._name else '' + ) + return name_str class LogSigmoid(Layer): diff --git a/test/legacy_test/test_silu_op.py b/test/legacy_test/test_silu_op.py index 3635a37577e4e3..730e5905527754 100644 --- a/test/legacy_test/test_silu_op.py +++ b/test/legacy_test/test_silu_op.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import unittest import numpy as np +from op_test import OpTest import paddle import paddle.base.dygraph as dg @@ -272,5 +274,53 @@ def test_all(self): ) +class SiluOpDefaultTest(OpTest): + """the base class of other op testcases""" + + def setUp(self): + self.initTestCase() + self.python_api = F.silu + + self.op_type = "silu" + self.inputs = {'X': self.X} + + self.target = copy.deepcopy(self.X) + self.target = silu(self.target) + self.outputs = {'Out': (self.target)} + + def test_check_output(self): + self.check_output(check_pir=True, check_symbol_infer=False) + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', check_pir=True) + + def init_dtype(self): + self.dtype = np.float64 + + def initTestCase(self): + self.init_dtype() + self.X = np.arange(1, 101, dtype=self.dtype).reshape([10, -1]) + if self.dtype == np.complex64 or self.dtype == np.complex128: + self.X = ( + np.random.uniform(-1, 1, [10, 10]) + + 1j * np.random.uniform(-1, 1, [10, 10]) + ).astype(self.dtype) + + +class SiluOpDefaultTestFP16(SiluOpDefaultTest): + def init_dtype(self): + self.dtype = np.float16 + + +class SiluOpDefaultTestComplex_64(SiluOpDefaultTest): + def init_dtype(self): + self.dtype = np.complex64 + + +class SiluOpDefaultTestComplex_128(SiluOpDefaultTest): + def init_dtype(self): + self.dtype = np.complex128 + + if __name__ == '__main__': unittest.main() From 5fa6a5cbdd925ee2a7ddbbd0f4dba9e76a913c2c Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Thu, 21 Aug 2025 17:14:29 +0800 Subject: [PATCH 3/6] change position --- python/paddle/nn/functional/activation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 2b7a31214c7649..b2d12a53a89268 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -19,7 +19,7 @@ import paddle from paddle import _C_ops, in_dynamic_mode from paddle.framework import core, in_dynamic_or_pir_mode -from paddle.utils.decorator_utils import ParamAliasDecorator, param_one_alias +from paddle.utils.decorator_utils import param_one_alias from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only from ...base.data_feeder import check_dtype, check_variable_and_dtype From 4917eaf3e066772182ad964c364fec3e445e5453 Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Fri, 22 Aug 2025 15:16:49 +0800 Subject: [PATCH 4/6] fix codestyle --- python/paddle/nn/functional/activation.py | 1 - 1 
file changed, 1 deletion(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index ba36a24eb51d99..b6cc4d9bf50363 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1098,7 +1098,6 @@ def silu(x: Tensor, inplace: bool = False, name: str | None = None) -> Tensor: x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, complex64, complex128. alias: ``input``. inplace (bool, optional): Whether to use inplace operation. Default: False. - name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None. Returns: From c92483406f87828e02b9ff8e8312db7e37e852ec Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Tue, 26 Aug 2025 17:45:43 +0800 Subject: [PATCH 5/6] add print test for silu --- test/legacy_test/test_silu_op.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/legacy_test/test_silu_op.py b/test/legacy_test/test_silu_op.py index 730e5905527754..0688223e08b8aa 100644 --- a/test/legacy_test/test_silu_op.py +++ b/test/legacy_test/test_silu_op.py @@ -274,6 +274,15 @@ def test_all(self): ) +class TestSiluPrint(unittest.TestCase): + def test_print(self): + print(nn.Silu()) + print(nn.Silu(True)) + print(nn.Silu(False)) + print(nn.Silu(inplace=True)) + print(nn.Silu(inplace=False)) + + class SiluOpDefaultTest(OpTest): """the base class of other op testcases""" From 3fffaab613d88ba05ffe969f8ebb48234c41a4a2 Mon Sep 17 00:00:00 2001 From: Starrysea996 <2462405885@qq.com> Date: Wed, 27 Aug 2025 20:56:38 +0800 Subject: [PATCH 6/6] add test --- test/legacy_test/test_silu_op.py | 52 +++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_silu_op.py b/test/legacy_test/test_silu_op.py index 0688223e08b8aa..a543da01d22bc5 100644 --- a/test/legacy_test/test_silu_op.py +++ b/test/legacy_test/test_silu_op.py @@ -16,7 +16,7 @@ import unittest import numpy as np -from op_test import OpTest +from op_test import OpTest, get_places import paddle import paddle.base.dygraph as dg @@ -331,5 +331,55 @@ def init_dtype(self): self.dtype = np.complex128 +class TestSiluAPI(unittest.TestCase): + def setUp(self): + np.random.seed(0) + self.shape = [10, 10] + self.x_np = np.random.random(self.shape).astype(np.float32) + self.place = get_places() + self.x_feed = copy.deepcopy(self.x_np) + + def test_api_static(self): + paddle.enable_static() + + def run(place, inplace): + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('X', self.shape) + out = F.silu(x, inplace) + exe = paddle.static.Executor(self.place[0]) + res = exe.run( + feed={ + 'X': self.x_feed, + }, + fetch_list=[out], + ) + target = copy.deepcopy(self.x_np) + out_ref = silu(target) + + for out in res: + np.testing.assert_allclose(out, out_ref, rtol=0.001) + + for place in self.place: + run(place, True) + run(place, False) + + def test_api_dygraph(self): + def run(place, inplace): + paddle.disable_static(place) + x_tensor = paddle.to_tensor(self.x_np) + out = F.silu(x_tensor, inplace) + + target = copy.deepcopy(self.x_np) + out_ref = silu(target) + + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) + + paddle.enable_static() + + for place in self.place: + run(place, True) + run(place, False) + + if __name__ == '__main__': unittest.main()
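
For reference, a minimal dygraph usage sketch of the behavior this series adds: the `inplace` argument on `F.silu` / `paddle.nn.Silu` and the `input` keyword alias on `F.silu`. It assumes a Paddle build with these patches applied and simply mirrors the docstring examples above; it is not part of the patch itself.

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])

    # Out-of-place (existing behavior): x is left unchanged.
    out = F.silu(x)

    # New keyword alias: `input` may be used instead of `x`.
    out = F.silu(input=x)

    # New in-place mode: the result is written back into x and the
    # returned tensor shares x's storage.
    out = F.silu(x, inplace=True)

    # The layer form accepts the same flag.
    m = paddle.nn.Silu(inplace=True)
    out = m(x)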