diff --git a/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
index 490b3e9404561..0599c7e4f06a0 100644
--- a/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
@@ -15,8 +15,16 @@
 #include "paddle/phi/kernels/logsumexp_grad_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"
 
-PD_REGISTER_KERNEL(
-    logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {}
+using float16 = phi::dtype::float16;
+
+PD_REGISTER_KERNEL(logsumexp_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::LogsumexpGradKernel,
+                   float,
+                   double,
+                   float16) {}
diff --git a/paddle/phi/kernels/gpu/logsumexp_kernel.cu b/paddle/phi/kernels/gpu/logsumexp_kernel.cu
index 18249ff3bafb0..7963808476ded 100644
--- a/paddle/phi/kernels/gpu/logsumexp_kernel.cu
+++ b/paddle/phi/kernels/gpu/logsumexp_kernel.cu
@@ -15,8 +15,11 @@
 #include "paddle/phi/kernels/logsumexp_kernel.h"
 
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"
 
+using float16 = phi::dtype::float16;
+
 PD_REGISTER_KERNEL(
-    logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {}
+    logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
diff --git a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
index 23e4414858a78..7e5b5ca4f8d4e 100644
--- a/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -16,6 +16,7 @@
 #include <type_traits>
 #include <vector>
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
@@ -23,6 +24,7 @@
 
 namespace phi {
 
+template <typename T>
 struct LogsumexpGradFunctor {
   template <typename Context,
             typename X,
@@ -37,7 +39,13 @@ struct LogsumexpGradFunctor {
                   DY* dy,
                   const Dim& dim,
                   int size) {
-    dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp();
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
+    auto x_mt = (*x).template cast<MT>();
+    auto y_mt = (*y).template cast<MT>();
+    auto dy_mt = (*dy).template cast<MT>();
+    dx->device(place) =
+        (dy_mt.broadcast(dim) * (x_mt - y_mt.broadcast(dim)).exp())
+            .template cast<T>();
   }
 };
 
@@ -62,11 +70,11 @@ void LogsumexpGradKernel(const Context& dev_ctx,
     auto dx = phi::EigenVector<T>::Flatten(*in_grad);
     auto& place = *dev_ctx.eigen_device();
     auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}});
-    LogsumexpGradFunctor()(
+    LogsumexpGradFunctor<T>()(
         place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]);
   } else {
     int rank = in.dims().size();
-    LogsumexpGradFunctor functor;
+    LogsumexpGradFunctor<T> functor;
     std::vector<int32_t> axis32;
     axis32.reserve(axis.size());
     std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) {
@@ -74,21 +82,26 @@ void LogsumexpGradKernel(const Context& dev_ctx,
     });
     switch (rank) {
       case 1:
-        phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>(
+        phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor<T>>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
         break;
       case 2:
-        phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>(
+        phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor<T>>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
         break;
       case 3:
-        phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>(
+        phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor<T>>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
         break;
       case 4:
-        phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>(
+        phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor<T>>(
            dev_ctx, in, out, out_grad, in_grad, functor, axis32);
         break;
+      default:
+        PADDLE_THROW(phi::errors::Unimplemented(
+            "Unsupported dimensions, please keep maximum dimensions of input "
+            "data less than 4."));
+        break;
     }
   }
 }
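The grad functor above now carries the storage type `T` and evaluates the Eigen expression in `MT = MPTypeTrait<T>::Type` (float32 when `T` is float16), casting back to `T` only at the end. Below is a minimal NumPy sketch of the same math, assuming the standard identity dx = dy * exp(x - logsumexp(x)) (with dy == 1 this is softmax(x), which is also what the test's `logsumexp_ref_grad` computes); the function name and tolerances are illustrative, not part of the PR.

```python
import numpy as np

def logsumexp_grad_fp16_sketch(x_fp16, dy_fp16, axis=None):
    # cast<MT>(): promote the float16 storage type to float32 for the math
    x = x_fp16.astype(np.float32)
    dy = dy_fp16.astype(np.float32)
    # y = logsumexp(x), evaluated stably in float32
    x_max = x.max(axis=axis, keepdims=True)
    y = x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))
    # dx = dy.broadcast(dim) * exp(x - y.broadcast(dim)), then cast<T>() back
    dx = dy * np.exp(x - y)
    return dx.astype(np.float16)

# With an all-ones upstream gradient, d/dx logsumexp(x) equals softmax(x).
x = np.random.uniform(-1, 1, (3, 4)).astype(np.float16)
dy = np.ones((1, 1), dtype=np.float16)
softmax = np.exp(x.astype(np.float32)) / np.exp(x.astype(np.float32)).sum()
assert np.allclose(logsumexp_grad_fp16_sketch(x, dy).astype(np.float32),
                   softmax, rtol=2e-3, atol=1e-5)
```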
diff --git a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
index 7c2eadcb7df78..30a118a1317ba 100644
--- a/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
+++ b/paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -16,6 +16,7 @@
 #include <type_traits>
 #include <vector>
 
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
@@ -23,15 +24,17 @@
 
 namespace phi {
 
-#define HANDLE_DIM(NDIM, RDIM)                                      \
-  if (ndim == NDIM && rdim == RDIM) {                               \
-    funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \
-        dev_ctx, x, out, axis, keepdim);                            \
+#define HANDLE_DIM(NDIM, RDIM)                                         \
+  if (ndim == NDIM && rdim == RDIM) {                                  \
+    funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor<T>>( \
+        dev_ctx, x, out, axis, keepdim);                               \
   }
 
+template <typename T>
 struct LogsumexpFunctor {
   template <typename Context, typename X, typename Y, typename Dim>
   void operator()(const Context& place, X* x, Y* y, const Dim& dim) {
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     auto x_dim = x->dimensions();
     auto t_dim = x_dim;
     for (int i = 0; i < static_cast<int>(dim.size()); i++) {
@@ -46,12 +49,14 @@ struct LogsumexpFunctor {
       r_dim[dim[i]] = x_dim[dim[i]];
     }
 
+    auto x_mt = (*x).template cast<MT>();
     auto y_dim = y->dimensions();
-    auto x_max = x->maximum(dim);
+    auto x_max = x_mt.maximum(dim);
     y->device(place) =
         (x_max +
-         (*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
-            .reshape(y_dim);
+         (x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
+            .reshape(y_dim)
+            .template cast<T>();
   }
 };
 
@@ -74,10 +79,16 @@ void LogsumexpKernel(const Context& dev_ctx,
     auto output = phi::EigenScalar<T>::From(*out);
     auto& place = *dev_ctx.eigen_device();
     auto reduce_dim = Eigen::array<int, 1>({{0}});
-    LogsumexpFunctor()(place, &input, &output, reduce_dim);
+    LogsumexpFunctor<T>()(place, &input, &output, reduce_dim);
   } else {
     int ndim = input_dim_size;
     int rdim = axis.size();
+    if (ndim > 4) {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Unsupported dimensions, please keep maximum dimensions of input "
+          "data less than 4."));
+    }
+    // comments for accelerating compiling temporarily.
 
     // HANDLE_DIM(6, 5);
     // HANDLE_DIM(6, 4);
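For the forward path, `LogsumexpFunctor<T>` likewise up-casts to `MT`, applies the max-subtraction rewrite logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), and casts only the result back to `T`. The hedged NumPy sketch below (illustrative only, not Paddle code) shows why the float32 promotion matters: a naive exp-then-sum in float16 overflows as soon as any element exceeds log(65504) ≈ 11.09.

```python
import numpy as np

def logsumexp_fp16_sketch(x_fp16, axis=None):
    # cast<MT>(): do the reduction in float32
    x = x_fp16.astype(np.float32)
    x_max = x.max(axis=axis, keepdims=True)
    # logsumexp(x) = max(x) + log(sum(exp(x - max(x)))): exp() only ever sees
    # non-positive arguments, so the sum cannot overflow
    out = x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))
    # cast<T>(): only the final result goes back to float16
    return out.astype(np.float16)

x = np.full((4,), 12.0, dtype=np.float16)
print(np.exp(x).sum())           # inf: exp(12) already overflows float16
print(logsumexp_fp16_sketch(x))  # finite: 12 + log(4) ~= 13.386
```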
diff --git a/python/paddle/fluid/tests/unittests/test_logsumexp.py b/python/paddle/fluid/tests/unittests/test_logsumexp.py
index 922349cab36db..ca732d6b0a766 100644
--- a/python/paddle/fluid/tests/unittests/test_logsumexp.py
+++ b/python/paddle/fluid/tests/unittests/test_logsumexp.py
@@ -15,6 +15,7 @@
 import paddle
 import unittest
 import numpy as np
+import paddle.fluid.core as core
 from op_test import OpTest
 
 
@@ -35,6 +36,22 @@ def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False):
     return paddle.logsumexp(x, axis, keepdim)
 
 
+def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False):
+    paddle.disable_static()
+    tensor_x = paddle.to_tensor(x)
+    tensor_x.stop_gradient = False
+    out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all)
+    grad = paddle.grad(out, [tensor_x])
+    x_grad = grad[0].numpy()
+    paddle.enable_static()
+    return x_grad
+
+
+def logsumexp_ref_grad(x):
+    sum = np.exp(x).sum()
+    return np.exp(x) / sum
+
+
 class TestLogsumexp(OpTest):
 
     def setUp(self):
@@ -125,6 +142,47 @@ def set_attrs_addition(self):
         self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)]
 
 
+class TestLogsumexp_FP32(TestLogsumexp):
+
+    def set_attrs(self):
+        self.dtype = 'float32'
+
+    def test_check_grad(self):
+        self.__class__.dtype = self.dtype
+        x_grad = logsumexp_op_grad(self.inputs['X'])
+        ref_x_grad = logsumexp_ref_grad(self.inputs['X'])
+        np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-08, atol=1e-08)
+
+
+@unittest.skipIf(not core.is_compiled_with_cuda(),
+                 "core is not compiled with CUDA")
+class TestLogsumexp_FP16(TestLogsumexp):
+
+    def set_attrs(self):
+        self.dtype = 'float16'
+
+    def test_check_output(self):
+        ref_x = self.inputs['X'].astype(np.float32)
+        out_ref = ref_logsumexp(ref_x)
+        paddle.disable_static()
+        x = self.inputs['X'].astype(np.float16)
+        tensor_x = paddle.to_tensor(x)
+        out_pad = logsumexp_wrapper(tensor_x)
+        paddle.enable_static()
+        np.testing.assert_allclose(out_pad.numpy(),
+                                   out_ref,
+                                   rtol=1e-03,
+                                   atol=1e-08)
+
+    def test_check_grad(self):
+        self.__class__.dtype = self.dtype
+        ref_x = self.inputs['X'].astype(np.float32)
+        ref_x_grad = logsumexp_ref_grad(ref_x)
+        x = self.inputs['X'].astype(np.float16)
+        x_grad = logsumexp_op_grad(x)
+        np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05)
+
+
 class TestLogsumexpError(unittest.TestCase):
 
     def test_errors(self):
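With the GPU kernels registered for float16, the new path can be exercised directly from the dygraph API. This is a hypothetical usage sketch, not part of the PR; it assumes a CUDA build of Paddle, since this diff registers float16 only for the GPU logsumexp kernels, and the printed dtypes are indicative.

```python
import numpy as np
import paddle

paddle.set_device('gpu')
x = paddle.to_tensor(np.random.uniform(-1, 1, (3, 4)).astype('float16'))
x.stop_gradient = False
y = paddle.logsumexp(x)       # forward math runs in float32, result is float16
y.backward()
print(y.dtype, x.grad.dtype)  # expected: paddle.float16 paddle.float16
```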