From be9efb41d9c294cbb344bcb8bf7f77c84b60371d Mon Sep 17 00:00:00 2001
From: PuQing
Date: Fri, 12 Apr 2024 13:33:35 +0000
Subject: [PATCH 1/3] Add support for value parameter in thresholded_relu op

---
 paddle/phi/api/yaml/backward.yaml             |  4 +-
 paddle/phi/api/yaml/ops.yaml                  |  2 +-
 paddle/phi/kernels/activation_grad_kernel.h   |  2 +-
 paddle/phi/kernels/activation_kernel.h        |  2 +-
 .../phi/kernels/cpu/activation_grad_kernel.cc |  7 ++--
 paddle/phi/kernels/cpu/activation_kernel.cc   |  7 ++--
 paddle/phi/kernels/funcs/activation_functor.h | 19 +++++----
 .../phi/kernels/gpu/activation_grad_kernel.cu |  8 ++--
 paddle/phi/kernels/gpu/activation_kernel.cu   |  7 ++--
 python/paddle/nn/functional/activation.py     | 13 +++---
 python/paddle/nn/layer/activation.py          | 10 +++--
 test/legacy_test/test_activation_op.py        | 41 +++++++++++--------
 test/legacy_test/test_imperative_layers.py    |  4 +-
 13 files changed, 74 insertions(+), 52 deletions(-)

diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml
index 3937464fbce49..002457a8584f4 100644
--- a/paddle/phi/api/yaml/backward.yaml
+++ b/paddle/phi/api/yaml/backward.yaml
@@ -2548,8 +2548,8 @@
     func : tensor_unfold_grad
 
 - backward_op : thresholded_relu_grad
-  forward : thresholded_relu (Tensor x, float threshold) -> Tensor(out)
-  args : (Tensor x, Tensor out_grad, float threshold)
+  forward : thresholded_relu (Tensor x, float threshold, float value) -> Tensor(out)
+  args : (Tensor x, Tensor out_grad, float threshold, float value)
   output : Tensor(x_grad)
   infer_meta :
     func : UnchangedInferMeta
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 668d98a8cdef3..43a9bb2f6a36d 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -2854,7 +2854,7 @@
   no_need_buffer : input
 
 - op : thresholded_relu
-  args : (Tensor x, float threshold = 1.0)
+  args : (Tensor x, float threshold = 1.0, float value = 0.0)
   output : Tensor(out)
   infer_meta :
     func : UnchangedInferMeta
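
The updated op signature boils down to a simple elementwise rule. A minimal NumPy sketch of the intended forward semantics (illustrative only, not part of the patch):

    import numpy as np

    def thresholded_relu_ref(x, threshold=1.0, value=0.0):
        # Keep x where it exceeds the threshold, otherwise emit `value`.
        return np.where(x > threshold, x, value)

    thresholded_relu_ref(np.array([0.5, 1.5]), threshold=1.0, value=-1.0)
    # -> array([-1. ,  1.5])
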
diff --git a/paddle/phi/kernels/activation_grad_kernel.h b/paddle/phi/kernels/activation_grad_kernel.h
index b2fae7b0406e0..8aed27bb59ea9 100644
--- a/paddle/phi/kernels/activation_grad_kernel.h
+++ b/paddle/phi/kernels/activation_grad_kernel.h
@@ -307,7 +307,6 @@ DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Floor);
 DECLARE_ACTIVATION_GRAD_KERNEL_NODEP(Ceil);
 
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, alpha);
-DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, lambda);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(HardShrink, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Logit, eps);
@@ -318,5 +317,6 @@ DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(HardTanh, t_min, t_max);
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(STanh, scale_a, scale_b);
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus, beta, threshold);
 DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, slope, offset);
+DECLARE_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, threshold, value);
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/activation_kernel.h b/paddle/phi/kernels/activation_kernel.h
index 70c0187e68865..bf3cb325160d3 100644
--- a/paddle/phi/kernels/activation_kernel.h
+++ b/paddle/phi/kernels/activation_kernel.h
@@ -74,7 +74,6 @@ DECLARE_ACTIVATION_KERNEL(Ceil)
 DECLARE_ACTIVATION_KERNEL(Negative)
 
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu, alpha)
-DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(SoftShrink, lambda)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_ONE_ATTRS(HardShrink, threshold)
@@ -87,6 +86,7 @@ DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardTanh, t_min, t_max)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(STanh, scale_a, scale_b)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus, beta, threshold)
 DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset)
+DECLARE_ACTIVATION_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, threshold, value)
 
 template <typename T, typename Context>
 void HardSwishKernel(const Context& dev_ctx,
diff --git a/paddle/phi/kernels/cpu/activation_grad_kernel.cc b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
index 3f26f8c388e66..b8ced8d4defe2 100644
--- a/paddle/phi/kernels/cpu/activation_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -155,9 +155,6 @@ DEFINE_CPU_ACTIVATION_GRAD_KERNEL_NODEP(Ceil, ZeroGradFunctor);
 DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
                                                LeakyReluGradFunctor,
                                                alpha);
-DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu,
-                                               ThresholdedReluGradFunctor,
-                                               threshold);
 DEFINE_CPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink,
                                                SoftShrinkGradFunctor,
                                                lambda);
@@ -188,6 +185,10 @@ DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid,
                                                  HardSigmoidGradFunctor,
                                                  slope,
                                                  offset);
+DEFINE_CPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu,
+                                               ThresholdedReluGradFunctor,
+                                               threshold,
+                                               value);
 
 template <typename T, typename Context>
 void SiluGradKernel(const Context& dev_ctx,
diff --git a/paddle/phi/kernels/cpu/activation_kernel.cc b/paddle/phi/kernels/cpu/activation_kernel.cc
index 92acf104fedcf..fda8493c9f452 100644
--- a/paddle/phi/kernels/cpu/activation_kernel.cc
+++ b/paddle/phi/kernels/cpu/activation_kernel.cc
@@ -106,9 +106,6 @@ DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Exp, ExpFunctor)
 DEFINE_CPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, Expm1Functor)
 
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, LeakyReluFunctor, alpha)
-DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu,
-                                     ThresholdedReluFunctor,
-                                     threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(Mish, MishFunctor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, HardShrinkFunctor, threshold)
 DEFINE_CPU_ACT_KERNEL_WITH_ONE_ATTRS(SoftShrink, SoftShrinkFunctor, lambda)
@@ -122,6 +119,10 @@ DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid,
                                      HardSigmoidFunctor,
                                      slope,
                                      offset)
+DEFINE_CPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu,
+                                     ThresholdedReluFunctor,
+                                     threshold,
+                                     value)
 
 template <typename T, typename Context>
 void HardSwishKernel(const Context& dev_ctx,
return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } template struct CudaThresholdedReluFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } - // thresholded_relu(x) = x > threshold ? x : 0 + // thresholded_relu(x, threshold, value) = x > threshold ? x : value __device__ __forceinline__ T operator()(const T x) const { - return x > static_cast(threshold) ? x : zero; + return x > static_cast(threshold) ? x : static_cast(value); } }; @@ -4247,9 +4251,10 @@ template struct CudaThresholdedReluGradFunctor : public BaseActivationFunctor { T zero = static_cast(0.0f); float threshold; + float value; typename BaseActivationFunctor::AttrPair GetAttrs() { - return {{"threshold", &threshold}}; + return {{"threshold", &threshold}, {"value", &value}}; } // dx = x > threshold ? dout : 0 diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 594eefe5b8de1..ecfd46852c134 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -209,9 +209,6 @@ DEFINE_GPU_ACTIVATION_GRAD_KERNEL_DEPX(Swish, CudaSwishGradFunctor); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu, CudaLeakyReluGradFunctor, alpha); -DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(ThresholdedRelu, - CudaThresholdedReluGradFunctor, - threshold); DEFINE_GPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(SoftShrink, CudaSoftShrinkGradFunctor, lambda); @@ -247,7 +244,10 @@ DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT(HardSigmoid, CudaHardSigmoidGradFunctor, slope, offset); - +DEFINE_GPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(ThresholdedRelu, + CudaThresholdedReluGradFunctor, + threshold, + value); template void SiluGradKernel(const Context& dev_ctx, const DenseTensor& x, diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 1bf3d92d80620..aa874c5e0dd81 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -123,9 +123,6 @@ DEFINE_GPU_ACTIVATION_KERNEL_WITH_INT_IN_FLOAT_OUT(Expm1, CudaExpm1Functor) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LeakyRelu, CudaLeakyReluFunctor, alpha) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(LogitCUDA, CudaLogitFunctor, eps) -DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(ThresholdedRelu, - CudaThresholdedReluFunctor, - threshold) DEFINE_GPU_ACT_KERNEL_WITH_ONE_ATTRS(HardShrink, CudaHardShrinkFunctor, threshold) @@ -148,6 +145,10 @@ DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(HardSigmoid, slope, offset) DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(Selu, CudaSeluFunctor, scale, alpha) +DEFINE_GPU_ACT_KERNEL_WITH_TWO_ATTRS(ThresholdedRelu, + CudaThresholdedReluFunctor, + threshold, + value) template void HardSwishKernel(const Context& dev_ctx, diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 3dd30afeec986..ddfb04d8530a1 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -1543,7 +1543,7 @@ def tanhshrink(x, name=None): return out -def thresholded_relu(x, threshold=1.0, name=None): +def thresholded_relu(x, threshold=1.0, value=0.0, name=None): r""" thresholded relu activation. 
diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py
index 3dd30afeec986..ddfb04d8530a1 100644
--- a/python/paddle/nn/functional/activation.py
+++ b/python/paddle/nn/functional/activation.py
@@ -1543,7 +1543,7 @@ def tanhshrink(x, name=None):
     return out
 
 
-def thresholded_relu(x, threshold=1.0, name=None):
+def thresholded_relu(x, threshold=1.0, value=0.0, name=None):
     r"""
     thresholded relu activation.
 
@@ -1553,7 +1553,7 @@ def thresholded_relu(x, threshold=1.0, name=None):
         \left\{
             \begin{array}{rl}
                 x,& \text{if } \ x > threshold \\
-                0,& \text{otherwise}
+                value,& \text{otherwise}
             \end{array}
         \right.
 
@@ -1561,6 +1561,7 @@
     Parameters:
         x (Tensor): The input Tensor with data type float32, float64.
         threshold (float, optional): The value of threshold for thresholded_relu. Default is 1.0
+        value (float, optional): The value that replaces x when x is less than or equal to threshold. Default is 0.0
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
 
     Returns:
@@ -1580,7 +1581,7 @@
     """
 
     if in_dynamic_or_pir_mode():
-        return _C_ops.thresholded_relu(x, threshold)
+        return _C_ops.thresholded_relu(x, threshold, value)
     else:
         check_variable_and_dtype(
             x,
@@ -1594,19 +1595,19 @@
             type='thresholded_relu',
             inputs={'X': x},
             outputs={'Out': out},
-            attrs={'threshold': threshold},
+            attrs={'threshold': threshold, 'value': value},
         )
         return out
 
 
 @inplace_apis_in_dygraph_only
-def thresholded_relu_(x, threshold=1.0, name=None):
+def thresholded_relu_(x, threshold=1.0, value=0.0, name=None):
     r"""
     Inplace version of ``thresholded_relu`` API, the output Tensor will be inplaced with input ``x``.
     Please refer to :ref:`api_paddle_nn_functional_thresholded_relu`.
     """
     if in_dynamic_mode():
-        return _C_ops.thresholded_relu_(x, threshold)
+        return _C_ops.thresholded_relu_(x, threshold, value)
 
 
 def log_softmax(x, axis=-1, dtype=None, name=None):
diff --git a/python/paddle/nn/layer/activation.py b/python/paddle/nn/layer/activation.py
index c1234c28bc47d..b08f5f9ca8bbb 100644
--- a/python/paddle/nn/layer/activation.py
+++ b/python/paddle/nn/layer/activation.py
@@ -1164,13 +1164,14 @@ class ThresholdedReLU(Layer):
         \left\{
             \begin{array}{rl}
                 x,& \text{if } \ x > threshold \\
-                0,& \text{otherwise}
+                value,& \text{otherwise}
             \end{array}
         \right.
 
     Parameters:
         threshold (float, optional): The value of threshold for ThresholdedReLU. Default is 1.0
+        value (float, optional): The value that replaces x when x is less than or equal to threshold. Default is 0.0
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.
 
@@ -1191,17 +1192,18 @@ class ThresholdedReLU(Layer):
         [2., 0., 0.])
     """
 
-    def __init__(self, threshold=1.0, name=None):
+    def __init__(self, threshold=1.0, value=0.0, name=None):
         super().__init__()
         self._threshold = threshold
+        self._value = value
         self._name = name
 
     def forward(self, x):
-        return F.thresholded_relu(x, self._threshold, self._name)
+        return F.thresholded_relu(x, self._threshold, self._value, self._name)
 
     def extra_repr(self):
         name_str = f', name={self._name}' if self._name else ''
-        return f'threshold={self._threshold}{name_str}'
+        return f'threshold={self._threshold}, value={self._value}{name_str}'
 
 
 class Silu(Layer):
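
Assuming a build that includes this patch, the updated Python API would be exercised along these lines (tensor values are illustrative):

    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor([2.0, 0.0, 1.0])
    out = F.thresholded_relu(x, threshold=1.0, value=0.5)        # [2.0, 0.5, 0.5]

    layer = paddle.nn.ThresholdedReLU(threshold=1.0, value=0.5)
    out2 = layer(x)                                              # same result
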
diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index f3d7e5a1fa517..136e75b6c29f5 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -3287,26 +3287,34 @@ def test_check_grad(self):
         self.check_grad(
             ['X'],
             'Out',
-            check_prim=True
-            if self.dtype not in [np.complex64, np.complex128]
-            else False,
+            check_prim=(
+                True
+                if self.dtype not in [np.complex64, np.complex128]
+                else False
+            ),
             only_check_prim=self.if_only_check_prim(),
             check_pir=True,
-            check_prim_pir=True
-            if self.dtype not in [np.complex64, np.complex128]
-            else False,
+            check_prim_pir=(
+                True
+                if self.dtype not in [np.complex64, np.complex128]
+                else False
+            ),
             check_pir_onednn=self.check_pir_onednn,
         )
 
     def test_check_output(self):
         self.check_output(
-            check_prim=True
-            if self.dtype not in [np.complex64, np.complex128]
-            else False,
+            check_prim=(
+                True
+                if self.dtype not in [np.complex64, np.complex128]
+                else False
+            ),
             check_pir=True,
-            check_prim_pir=True
-            if self.dtype not in [np.complex64, np.complex128]
-            else False,
+            check_prim_pir=(
+                True
+                if self.dtype not in [np.complex64, np.complex128]
+                else False
+            ),
             check_pir_onednn=self.check_pir_onednn,
         )
 
@@ -4864,8 +4872,8 @@ def test_errors(self):
             F.softsign(x_fp16)
 
 
-def ref_thresholded_relu(x, threshold=1.0):
-    out = (x > threshold) * x
+def ref_thresholded_relu(x, threshold=1.0, value=0.0):
+    out = (x > threshold) * x + (x <= threshold) * value
     return out
 
 
@@ -4877,15 +4885,16 @@ def setUp(self):
         self.python_api = paddle.nn.functional.thresholded_relu
 
         threshold = 15
+        value = 5
         np.random.seed(1024)
         x = np.random.uniform(-20, 20, self.shape).astype(self.dtype)
         x[np.abs(x) < 0.005] = 0.02
-        out = ref_thresholded_relu(x, threshold)
+        out = ref_thresholded_relu(x, threshold, value)
 
         self.inputs = {'X': OpTest.np_dtype_to_base_dtype(x)}
         self.outputs = {'Out': out}
-        self.attrs = {"threshold": threshold}
+        self.attrs = {"threshold": threshold, "value": value}
         self.convert_input_output()
 
     def init_shape(self):
diff --git a/test/legacy_test/test_imperative_layers.py b/test/legacy_test/test_imperative_layers.py
index 9906d3ba0ede0..947ab037ee89b 100644
--- a/test/legacy_test/test_imperative_layers.py
+++ b/test/legacy_test/test_imperative_layers.py
@@ -85,7 +85,9 @@ def test_layer_str(self):
         self.assertEqual(str(module), 'Tanhshrink()')
 
         module = nn.ThresholdedReLU()
-        self.assertEqual(str(module), 'ThresholdedReLU(threshold=1.0)')
+        self.assertEqual(
+            str(module), 'ThresholdedReLU(threshold=1.0, value=0.0)'
+        )
 
         module = nn.LogSigmoid()
         self.assertEqual(str(module), 'LogSigmoid()')
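
Since `value` defaults to 0.0, call sites that pass only `threshold` keep the old behaviour; the op_version entry added in the next patch is what lets programs serialized before this change load once the new attribute exists. A quick sanity check of the default, assuming a patched build (illustrative only):

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(np.random.uniform(-20, 20, [10, 12]).astype('float32'))
    old_style = F.thresholded_relu(x, 2.0)       # value omitted, defaults to 0.0
    new_style = F.thresholded_relu(x, 2.0, 0.0)  # explicit default
    np.testing.assert_allclose(old_style.numpy(), new_style.numpy())
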
From d7cac777ce0f997def8b9d80fb3e81dd933ae8fc Mon Sep 17 00:00:00 2001
From: PuQing
Date: Tue, 4 Jun 2024 04:36:11 +0000
Subject: [PATCH 2/3] set thresholded_relu op_version

---
 paddle/phi/kernels/funcs/activation_functor.h | 1 -
 paddle/phi/ops/yaml/op_version.yaml           | 8 ++++++++
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h
index 12168e0e5e3cc..27223dad0c1de 100644
--- a/paddle/phi/kernels/funcs/activation_functor.h
+++ b/paddle/phi/kernels/funcs/activation_functor.h
@@ -4233,7 +4233,6 @@ struct CudaHardTanhGradFunctor : public BaseActivationFunctor<T> {
 
 template <typename T>
 struct CudaThresholdedReluFunctor : public BaseActivationFunctor<T> {
-  T zero = static_cast<T>(0.0f);
   float threshold;
   float value;
 
diff --git a/paddle/phi/ops/yaml/op_version.yaml b/paddle/phi/ops/yaml/op_version.yaml
index 7ef9a6f83e84d..a41a67e9ded17 100644
--- a/paddle/phi/ops/yaml/op_version.yaml
+++ b/paddle/phi/ops/yaml/op_version.yaml
@@ -486,6 +486,14 @@
         comment : A flag to indicate whether to do softmax
         default : "true"
 
+- op : thresholded_relu
+  version :
+    - checkpoint : Upgrade thresholded_relu, add a new attribute [value]
+      action :
+        - add_attr : value
+          comment : The value that replaces x when x is less than or equal to threshold.
+          default : 0.0
+
 - op : trace
   version :
     - checkpoint : Upgrade trace add a new attribute [axis2]

From fcf314d5832f832a6be93cc9128b2b5b28da0254 Mon Sep 17 00:00:00 2001
From: PuQing
Date: Tue, 4 Jun 2024 10:57:12 +0000
Subject: [PATCH 3/3] additional value parameter in thresholded_relu

---
 test/legacy_test/test_activation_op.py | 21 +++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py
index 214c16687c737..4de793c943265 100644
--- a/test/legacy_test/test_activation_op.py
+++ b/test/legacy_test/test_activation_op.py
@@ -4938,6 +4938,7 @@ class TestThresholdedReluAPI(unittest.TestCase):
     # test paddle.nn.ThresholdedReLU, paddle.nn.functional.thresholded_relu
     def setUp(self):
         self.threshold = 15
+        self.value = 5
         np.random.seed(1024)
         self.x_np = np.random.uniform(-20, 20, [10, 12]).astype(np.float64)
         self.x_np[np.abs(self.x_np) < 0.005] = 0.02
@@ -4952,22 +4953,30 @@ def test_static_api(self):
         with static_guard():
             with paddle.static.program_guard(paddle.static.Program()):
                 x = paddle.static.data('X', self.x_np.shape, self.x_np.dtype)
-                out1 = F.thresholded_relu(x, self.threshold)
-                thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
+                out1 = F.thresholded_relu(x, self.threshold, self.value)
+                thresholded_relu = paddle.nn.ThresholdedReLU(
+                    self.threshold, self.value
+                )
                 out2 = thresholded_relu(x)
                 exe = paddle.static.Executor(self.place)
                 res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
-            out_ref = ref_thresholded_relu(self.x_np, self.threshold)
+            out_ref = ref_thresholded_relu(
+                self.x_np, self.threshold, self.value
+            )
             for r in res:
                 np.testing.assert_allclose(out_ref, r, rtol=1e-05)
 
     def test_dygraph_api(self):
         with dynamic_guard():
             x = paddle.to_tensor(self.x_np)
-            out1 = F.thresholded_relu(x, self.threshold)
-            thresholded_relu = paddle.nn.ThresholdedReLU(self.threshold)
+            out1 = F.thresholded_relu(x, self.threshold, self.value)
+            thresholded_relu = paddle.nn.ThresholdedReLU(
+                self.threshold, self.value
+            )
             out2 = thresholded_relu(x)
-            out_ref = ref_thresholded_relu(self.x_np, self.threshold)
+            out_ref = ref_thresholded_relu(
+                self.x_np, self.threshold, self.value
+            )
             for r in [out1, out2]:
                 np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
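
The updated TestThresholdedReluAPI cases compare the op against the NumPy reference with threshold=15 and value=5. A condensed, standalone version of that check, assuming a build with the patched op, might look like:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    def ref_thresholded_relu(x, threshold=1.0, value=0.0):
        # Same reference rule the tests use: keep x above the threshold, else use value.
        return (x > threshold) * x + (x <= threshold) * value

    np.random.seed(1024)
    x_np = np.random.uniform(-20, 20, [10, 12]).astype('float64')
    out = F.thresholded_relu(paddle.to_tensor(x_np), 15, 5)
    np.testing.assert_allclose(
        out.numpy(), ref_thresholded_relu(x_np, 15, 5), rtol=1e-05
    )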