From 2f2242831345ef6692e84363914089ce3dea58c1 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Thu, 23 Nov 2023 21:54:47 +0800 Subject: [PATCH 01/12] =?UTF-8?q?=E3=80=90Hackathon=205th=20No.25=E3=80=91?= =?UTF-8?q?add=20gammaln=20api?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/api/yaml/backward.yaml | 10 ++ paddle/phi/api/yaml/ops.yaml | 10 ++ paddle/phi/kernels/cpu/gammaln_grad_kernel.cc | 22 +++ paddle/phi/kernels/cpu/gammaln_kernel.cc | 22 +++ paddle/phi/kernels/gammaln_grad_kernel.h | 27 ++++ paddle/phi/kernels/gammaln_kernel.h | 26 +++ paddle/phi/kernels/gpu/gammaln_grad_kernel.cu | 30 ++++ paddle/phi/kernels/gpu/gammaln_kernel.cu | 29 ++++ .../kernels/impl/gammaln_grad_kernel_impl.h | 56 +++++++ paddle/phi/kernels/impl/gammaln_kernel_impl.h | 50 ++++++ python/paddle/__init__.py | 4 + python/paddle/tensor/__init__.py | 4 + python/paddle/tensor/math.py | 45 ++++++ test/legacy_test/test_gammaln_op.py | 152 ++++++++++++++++++ 14 files changed, 487 insertions(+) create mode 100644 paddle/phi/kernels/cpu/gammaln_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/gammaln_kernel.cc create mode 100644 paddle/phi/kernels/gammaln_grad_kernel.h create mode 100644 paddle/phi/kernels/gammaln_kernel.h create mode 100644 paddle/phi/kernels/gpu/gammaln_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/gammaln_kernel.cu create mode 100644 paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h create mode 100644 paddle/phi/kernels/impl/gammaln_kernel_impl.h create mode 100644 test/legacy_test/test_gammaln_op.py diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 98b376b55f864..bfd08787a55a7 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -920,6 +920,16 @@ kernel : func : frame_grad +- backward_op : gammaln_grad + forward : gammaln(Tensor x) -> Tensor(out) + args : (Tensor x, Tensor out_grad) + output : Tensor(x_grad) + infer_meta : + func : UnchangedInferMeta + param: [x] + kernel : + func : gammaln_grad + - backward_op : gather_grad forward : gather(Tensor x, Tensor index, Scalar axis=0) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad, Scalar axis=0) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index c55e8ffc132e6..14ac499854614 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -1030,6 +1030,16 @@ data_type : dtype backend : place +- op : gammaln + args : (Tensor x) + output : Tensor(out) + infer_meta : + func : UnchangedInferMeta + kernel : + func : gammaln + inplace: (x -> out) + backward : gammaln_grad + - op : gather args : (Tensor x, Tensor index, Scalar axis=0) output : Tensor(out) diff --git a/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc new file mode 100644 index 0000000000000..c52ee8b3848e9 --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_grad_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln_grad, CPU, ALL_LAYOUT, phi::GammalnGradKernel, float, double) {} diff --git a/paddle/phi/kernels/cpu/gammaln_kernel.cc b/paddle/phi/kernels/cpu/gammaln_kernel.cc new file mode 100644 index 0000000000000..ff62f86d2522f --- /dev/null +++ b/paddle/phi/kernels/cpu/gammaln_kernel.cc @@ -0,0 +1,22 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL( + gammaln, CPU, ALL_LAYOUT, phi::GammalnKernel, float, double) {} diff --git a/paddle/phi/kernels/gammaln_grad_kernel.h b/paddle/phi/kernels/gammaln_grad_kernel.h new file mode 100644 index 0000000000000..440dca72a9d46 --- /dev/null +++ b/paddle/phi/kernels/gammaln_grad_kernel.h @@ -0,0 +1,27 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x); +} // namespace phi diff --git a/paddle/phi/kernels/gammaln_kernel.h b/paddle/phi/kernels/gammaln_kernel.h new file mode 100644 index 0000000000000..db3015c4a747d --- /dev/null +++ b/paddle/phi/kernels/gammaln_kernel.h @@ -0,0 +1,26 @@ + +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); +} // namespace phi diff --git a/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu new file mode 100644 index 0000000000000..b2513d9e3f25c --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_grad_kernel.cu @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln_grad, + GPU, + ALL_LAYOUT, + phi::GammalnGradKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gammaln_kernel.cu b/paddle/phi/kernels/gpu/gammaln_kernel.cu new file mode 100644 index 0000000000000..3d57be7b27733 --- /dev/null +++ b/paddle/phi/kernels/gpu/gammaln_kernel.cu @@ -0,0 +1,29 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/gammaln_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/impl/gammaln_kernel_impl.h" + +PD_REGISTER_KERNEL(gammaln, + GPU, + ALL_LAYOUT, + phi::GammalnKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h new file mode 100644 index 0000000000000..f588a28afbd10 --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -0,0 +1,56 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" +namespace phi { +template <typename T> +struct GammalnGradFunctor { + GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel) + : dout_(dout), x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait<T>::Type; + const MT mp_dout = static_cast<MT>(dout_[idx]); + const MT mp_x = static_cast<MT>(x_[idx]); + output_[idx] = static_cast<T>( + mp_dout * Eigen::numext::polygamma(static_cast<MT>(0), mp_x)); + } + + private: + const T* dout_; + const T* x_; + T* output_; + int64_t numel_; +}; +template <typename T, typename Context> +void GammalnGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& d_out, + DenseTensor* d_x) { + auto numel = d_out.numel(); + auto* dout_data = d_out.data<T>(); + auto* x_data = x.data<T>(); + auto* dx_data = + dev_ctx.template Alloc<T>(d_x, static_cast<size_t>(numel * sizeof(T))); + phi::funcs::ForRange<Context> for_range(dev_ctx, numel); + GammalnGradFunctor<T> functor(dout_data, x_data, dx_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h new file mode 100644 index 0000000000000..8ccd078b8ed72 --- /dev/null +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -0,0 +1,50 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#pragma once + +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/funcs/for_range.h" +#include "paddle/phi/kernels/gammaln_kernel.h" + +namespace phi { +template <typename T> +struct GammalnFunctor { + GammalnFunctor(const T* x, T* output, int64_t numel) + : x_(x), output_(output), numel_(numel) {} + + HOSTDEVICE void operator()(int64_t idx) const { + using MT = typename phi::dtype::MPTypeTrait<T>::Type; + const MT mp_x = static_cast<MT>(x_[idx]); + output_[idx] = static_cast<T>(std::lgamma(mp_x)); + } + + private: + const T* x_; + T* output_; + int64_t numel_; +}; + +template <typename T, typename Context> +void GammalnKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + auto numel = x.numel(); + auto* x_data = x.data<T>(); + auto* out_data = dev_ctx.template Alloc<T>(out); + phi::funcs::ForRange<Context> for_range(dev_ctx, numel); + GammalnFunctor<T> functor(x_data, out_data, numel); + for_range(functor); +} +} // namespace phi diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 4f7e1ce38a3ff..e2752d4c61e68 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -373,6 +373,8 @@ conj, trunc, trunc_, + gammaln, + gammaln_, digamma, digamma_, neg, @@ -764,6 +766,8 @@ 'square_', 'divide', 'divide_', + 'gammaln', + 'gammaln_', 'ceil', 'atan', 'atan_', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index b96045d35faf6..1ca93644d0755 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -336,6 +336,8 @@ from .math import i0e # noqa: F401 from .math import i1 # noqa: F401 from .math import i1e # noqa: F401 +from .math import gammaln # noqa: F401 +from .math import gammaln_ # noqa: F401 from .math import polygamma # noqa: F401 from .math import polygamma_ # noqa: F401 from .math import renorm # noqa: F401 @@ -629,6 +631,8 @@ 'real', 'imag', 'is_floating_point', + 'gammaln', + 'gammaln_', 'digamma', 'digamma_', 'diagonal', diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 3aad0a6a91a9a..9219b34e17693 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -4986,6 +4986,51 @@ def conj(x, name=None): return out + +def gammaln(x, name=None): + r""" + Calculates the logarithm of the absolute value of the gamma function elementwisely. + + Args: + x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. + + Returns: + Tensor, The values of the logarithm of the absolute value of the gamma at the given tensor x. + + Examples: + .. code-block:: python + + >>> import paddle + + >>> x = paddle.arange(1.5, 4.5, 0.5) + >>> out = paddle.gammaln(x) + >>> print(out) + Tensor(shape=[6], dtype=float32, place=Place(cpu), stop_gradient=True, + [-0.12078224, 0. , 0.28468287, 0.69314718, 1.20097363, + 1.79175949]) + """ + if in_dynamic_or_pir_mode(): + return _C_ops.gammaln(x) + else: + check_variable_and_dtype( + x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'gammaln' + ) + helper = LayerHelper('gammaln', **locals()) + out = helper.create_variable_for_type_inference(x.dtype) + helper.append_op(type='gammaln', inputs={'x': x}, outputs={'out': out}) + return out + + +@inplace_apis_in_dygraph_only +def gammaln_(x, name=None): + r""" + Inplace version of ``gammaln`` API, the output Tensor will be inplaced with input ``x``. + Please refer to :ref:`api_paddle_gammaln`.
+ """ + if in_dynamic_mode(): + return _C_ops.gammaln_(x) + + def digamma(x, name=None): r""" Calculates the digamma of the given input tensor, element-wise. diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py new file mode 100644 index 0000000000000..ac7c8a6aa1606 --- /dev/null +++ b/test/legacy_test/test_gammaln_op.py @@ -0,0 +1,152 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from op_test import OpTest, convert_float_to_uint16 +from scipy import special + +import paddle +from paddle.base import core + + +def ref_gammaln(x): + return special.gammaln(x) + + +def ref_gammaln_grad(x, dout): + return dout * special.polygamma(0, x) + + +class TestGammalnOp(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (3, 40) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_output(self): + self.check_output(check_pir=True) + + def test_check_grad(self): + self.check_grad(['x'], 'out', check_pir=True) + + +class TestGammalnOpFp32(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float32 + + +class TestGammalnFP16Op(TestGammalnOp): + def init_dtype_type(self): + self.dtype = np.float16 + + +class TestGammalnBigNumberOp(TestGammalnOp): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.init_dtype_type() + self.shape = (100, 1) + self.x = np.random.random(self.shape).astype(self.dtype) + 1 + self.x[:5, 0] = np.array([1e5, 1e10, 1e20, 1e40, 1e80]) + self.inputs = {'x': self.x} + out = ref_gammaln(self.x) + self.outputs = {'out': out} + + def init_dtype_type(self): + self.dtype = np.float64 + + def test_check_grad(self): + d_out = self.outputs['out'] + d_x = ref_gammaln_grad(self.x, d_out) + self.check_grad( + ['x'], + 'out', + user_defined_grads=[ + d_x, + ], + user_defined_grad_outputs=[ + d_out, + ], + check_pir=True, + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda() + or not core.is_bfloat16_supported(core.CUDAPlace(0)), + "core is not compiled with CUDA or not support bfloat16", +) +class TestGammalnBF16Op(OpTest): + def setUp(self): + self.op_type = 'gammaln' + self.python_api = paddle.gammaln + self.dtype = np.uint16 + self.shape = (5, 30) + x = np.random.random(self.shape).astype("float32") + 1 + self.inputs = {'x': convert_float_to_uint16(x)} + out = ref_gammaln(x) + self.outputs = {'out': convert_float_to_uint16(out)} + + def test_check_output(self): + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + + def test_check_grad(self): + self.check_grad_with_place( + core.CUDAPlace(0), ['x'], 'out', check_pir=True + ) + + +class TestGammalnOpApi(unittest.TestCase): + def setUp(self): + self.shape = [2, 3, 4, 5] + self.dtype = "float64" 
+ self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 + self.place = ( + paddle.CUDAPlace(0) + if core.is_compiled_with_cuda() + else paddle.CPUPlace() + ) + + def test_static_api(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.static.data('x', self.x_np.shape, self.x_np.dtype) + out = paddle.gammaln(x) + exe = paddle.static.Executor(self.place) + (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out]) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, res) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + out = paddle.gammaln(x) + out_ref = ref_gammaln(self.x_np) + np.testing.assert_allclose(out_ref, out.numpy()) + paddle.enable_static() + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main() From 7cfc8af203aea95b4dc3df2a9e89400d25035e55 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 28 Nov 2023 12:20:24 +0800 Subject: [PATCH 02/12] Merge branch 'develop' into add_gammaln_op --- python/paddle/tensor/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 7e0601d36319e..c698f05346356 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -265,6 +265,8 @@ frexp, gcd, gcd_, + gammaln, + gammaln_, heaviside, hypot, hypot_, From 19f61ebfb8fd376852f9a370ac13a6d9fed61376 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 18 Dec 2023 18:11:17 +0800 Subject: [PATCH 03/12] update ut --- test/legacy_test/test_gammaln_op.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py index ac7c8a6aa1606..8840c74b31bca 100644 --- a/test/legacy_test/test_gammaln_op.py +++ b/test/legacy_test/test_gammaln_op.py @@ -120,7 +120,7 @@ def test_check_grad(self): class TestGammalnOpApi(unittest.TestCase): def setUp(self): self.shape = [2, 3, 4, 5] - self.dtype = "float64" + self.init_dtype_type() self.x_np = np.random.random(self.shape).astype(self.dtype) + 1 self.place = ( paddle.CUDAPlace(0) @@ -128,6 +128,9 @@ def setUp(self): else paddle.CPUPlace() ) + def init_dtype_type(self): + self.dtype = "float64" + def test_static_api(self): paddle.enable_static() with paddle.static.program_guard(paddle.static.Program()): @@ -147,6 +150,11 @@ def test_dygraph_api(self): paddle.enable_static() +class TestGammalnOpApiFp32(TestGammalnOpApi): + def init_dtype_type(self): + self.dtype = "float32" + + if __name__ == "__main__": paddle.enable_static() unittest.main() From 35510a4f3f64d58fd997f14b50485c34350cf577 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 18 Dec 2023 18:25:51 +0800 Subject: [PATCH 04/12] fix bug --- python/paddle/__init__.py | 4 ++-- python/paddle/tensor/__init__.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 43126e07db302..4c00f893eb032 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -397,10 +397,10 @@ frac, frac_, frexp, - gcd, - gcd_, gammaln, gammaln_, + gcd, + gcd_, heaviside, hypot, hypot_, diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 1a744ea9ae3c6..32c243fee09bf 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -277,10 +277,10 @@ frac, frac_, frexp, - gcd, - gcd_, gammaln, gammaln_, + gcd, + gcd_, heaviside, hypot, hypot_, From 
73fc65bf3ce44fae5bef41a017bfac6b1ca80deb Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 18 Dec 2023 19:34:59 +0800 Subject: [PATCH 05/12] fix bug --- test/legacy_test/test_gammaln_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py index 8840c74b31bca..bc61ece0e0116 100644 --- a/test/legacy_test/test_gammaln_op.py +++ b/test/legacy_test/test_gammaln_op.py @@ -139,14 +139,14 @@ def test_static_api(self): exe = paddle.static.Executor(self.place) (res,) = exe.run(feed={'x': self.x_np}, fetch_list=[out]) out_ref = ref_gammaln(self.x_np) - np.testing.assert_allclose(out_ref, res) + np.testing.assert_allclose(out_ref, res, rtol=1e-5, atol=1e-5) def test_dygraph_api(self): paddle.disable_static(self.place) x = paddle.to_tensor(self.x_np) out = paddle.gammaln(x) out_ref = ref_gammaln(self.x_np) - np.testing.assert_allclose(out_ref, out.numpy()) + np.testing.assert_allclose(out_ref, out.numpy(), rolt=1e-5, atol=1e-5) paddle.enable_static() From 788b2584a2c2005474fbd747d1685c3da3b2499f Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 18 Dec 2023 21:31:40 +0800 Subject: [PATCH 06/12] fix bug --- test/legacy_test/test_gammaln_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/legacy_test/test_gammaln_op.py b/test/legacy_test/test_gammaln_op.py index bc61ece0e0116..50331af5c7a34 100644 --- a/test/legacy_test/test_gammaln_op.py +++ b/test/legacy_test/test_gammaln_op.py @@ -146,7 +146,7 @@ def test_dygraph_api(self): x = paddle.to_tensor(self.x_np) out = paddle.gammaln(x) out_ref = ref_gammaln(self.x_np) - np.testing.assert_allclose(out_ref, out.numpy(), rolt=1e-5, atol=1e-5) + np.testing.assert_allclose(out_ref, out.numpy(), rtol=1e-5, atol=1e-5) paddle.enable_static() From ee0acc3ad5210dffabe13436d142a1db7f6b687e Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 19 Dec 2023 08:28:51 +0800 Subject: [PATCH 07/12] add test inplace --- test/legacy_test/test_inplace.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/legacy_test/test_inplace.py b/test/legacy_test/test_inplace.py index f06edfd83206c..6041469a2648b 100644 --- a/test/legacy_test/test_inplace.py +++ b/test/legacy_test/test_inplace.py @@ -869,6 +869,14 @@ def test_leaf_inplace_var_error(self): pass +class TestDygraphInplaceGammaln(TestDygraphInplaceWithContinuous): + def inplace_api_processing(self, var): + return paddle.gammaln_(var) + + def non_inplace_api_processing(self, var): + return paddle.gammaln(var) + + class TestDygraphInplaceNeg(TestDygraphInplaceWithContinuous): def inplace_api_processing(self, var): return paddle.neg_(var) From 5575e88c904aa329fb5ed70a35333b741227a44e Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Tue, 19 Dec 2023 14:36:24 +0800 Subject: [PATCH 08/12] fix bug --- paddle/phi/kernels/impl/gammaln_kernel_impl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h index 8ccd078b8ed72..e52771f9e49f2 100644 --- a/paddle/phi/kernels/impl/gammaln_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -27,7 +27,7 @@ struct GammalnFunctor { HOSTDEVICE void operator()(int64_t idx) const { using MT = typename phi::dtype::MPTypeTrait::Type; const MT mp_x = static_cast(x_[idx]); - output_[idx] = static_cast(std::lgamma(mp_x)); + output_[idx] = static_cast(Eigen::numext::lgamma(mp_x)); } private: From cbcfa8772cae208ca522b0902b9f399c1b0a0789 Mon Sep 17 00:00:00 2001 
From: Wang Xin Date: Tue, 19 Dec 2023 15:56:24 +0800 Subject: [PATCH 09/12] fix bug --- paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h | 3 ++- paddle/phi/kernels/impl/gammaln_kernel_impl.h | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h index f588a28afbd10..706b98dbcdf8f 100644 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -15,10 +15,11 @@ #pragma once #include +#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/funcs/for_range.h" -#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" + namespace phi { template struct GammalnGradFunctor { diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h index e52771f9e49f2..20b783b17f026 100644 --- a/paddle/phi/kernels/impl/gammaln_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -14,9 +14,11 @@ #pragma once +#include +#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" + #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/funcs/for_range.h" -#include "paddle/phi/kernels/gammaln_kernel.h" namespace phi { template From 4f45937f0c8a53e4920ab35dc9d466cc62dde93d Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Wed, 20 Dec 2023 16:25:04 +0800 Subject: [PATCH 10/12] fix --- python/paddle/tensor/math.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index bc79e384daaea..acb3fced70aea 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -5004,7 +5004,7 @@ def gammaln(x, name=None): Calculates the logarithm of the absolute value of the gamma function elementwisely. Args: - x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, uint16. + x (Tensor): Input Tensor. Must be one of the following types: float16, float32, float64, bfloat16. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. 
Returns: @@ -5026,7 +5026,7 @@ def gammaln(x, name=None): return _C_ops.gammaln(x) else: check_variable_and_dtype( - x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'gammaln' + x, 'x', ['float16', 'float32', 'float64', 'bfloat16'], 'gammaln' ) helper = LayerHelper('gammaln', **locals()) out = helper.create_variable_for_type_inference(x.dtype) From e372c305d721af6cfc1567e581db17e8f7fa5a09 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 25 Dec 2023 00:17:40 +0800 Subject: [PATCH 11/12] update code --- .../kernels/impl/gammaln_grad_kernel_impl.h | 46 +++++++++++++++++-- paddle/phi/kernels/impl/gammaln_kernel_impl.h | 5 +- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h index 706b98dbcdf8f..b181cde6b2fa7 100644 --- a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -14,13 +14,49 @@ #pragma once -#include -#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" - #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/funcs/for_range.h" namespace phi { +template <typename T> +HOSTDEVICE T digamma(T x) { + static T c = T{8.5}; + static T euler_mascheroni = T{0.57721566490153286060}; + T r; + T value; + T x2; + + if (x <= T{0.0}) { + value = T{0.0}; + return value; + } + + if (x <= T{0.000001}) { + value = -euler_mascheroni - T{1.0} / x + T{1.6449340668482264365} * x; + return value; + } + + value = T{0.0}; + x2 = x; + while (x2 < c) { + value = value - T{1.0} / x2; + x2 = x2 + T{1.0}; + } + + r = T{1.0} / x2; + value = value + std::log(x2) - T{0.5} * r; + + r = r * r; + + value = value - + r * (T{1.0} / T{12.0} - + r * (T{1.0} / T{120.0} - + r * (T{1.0} / T{252.0} - + r * (T{1.0} / T{240.0} - r * (T{1.0} / T{132.0}))))); + + return value; +} + template <typename T> struct GammalnGradFunctor { GammalnGradFunctor(const T* dout, const T* x, T* output, int64_t numel) @@ -30,8 +66,8 @@ struct GammalnGradFunctor { using MT = typename phi::dtype::MPTypeTrait<T>::Type; const MT mp_dout = static_cast<MT>(dout_[idx]); const MT mp_x = static_cast<MT>(x_[idx]); - output_[idx] = static_cast<T>( - mp_dout * Eigen::numext::polygamma(static_cast<MT>(0), mp_x)); + const auto one = MT{1}; + output_[idx] = static_cast<T>(mp_dout * digamma(mp_x)); } private: diff --git a/paddle/phi/kernels/impl/gammaln_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_kernel_impl.h index 20b783b17f026..38385610de0de 100644 --- a/paddle/phi/kernels/impl/gammaln_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_kernel_impl.h @@ -14,9 +14,6 @@ #pragma once -#include -#include "unsupported/Eigen/src/SpecialFunctions/SpecialFunctionsImpl.h" - #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/kernels/funcs/for_range.h" namespace phi { @@ -29,7 +26,7 @@ struct GammalnFunctor { HOSTDEVICE void operator()(int64_t idx) const { using MT = typename phi::dtype::MPTypeTrait<T>::Type; const MT mp_x = static_cast<MT>(x_[idx]); - output_[idx] = static_cast<T>(Eigen::numext::lgamma(mp_x)); + output_[idx] = static_cast<T>(std::lgamma(mp_x)); } private: From 97f049aa90d01e55838e196e5f610346f77dc984 Mon Sep 17 00:00:00 2001 From: Wang Xin Date: Mon, 25 Dec 2023 00:46:22 +0800 Subject: [PATCH 12/12] fix bug --- paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h index b181cde6b2fa7..50c73cff27ce4 100644 ---
a/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/gammaln_grad_kernel_impl.h @@ -66,7 +66,6 @@ struct GammalnGradFunctor { using MT = typename phi::dtype::MPTypeTrait<T>::Type; const MT mp_dout = static_cast<MT>(dout_[idx]); const MT mp_x = static_cast<MT>(x_[idx]); - const auto one = MT{1}; output_[idx] = static_cast<T>(mp_dout * digamma(mp_x)); }