logsumexp support fp16
xiaohemaikoo committed Oct 13, 2022
1 parent 5cd6a70 commit 5ff6557
Showing 5 changed files with 111 additions and 18 deletions.
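
The change follows the usual mixed-precision pattern for reductions: float16 is added to the GPU kernel registrations, and inside the Eigen functors the operands are upcast to the type chosen by MPTypeTrait (float for float16), reduced in that precision, and cast back to the input type. A minimal NumPy sketch of that compute-in-fp32, store-in-fp16 pattern (the function name and shapes are illustrative, not part of the patch):

import numpy as np

def logsumexp_fp16(x_fp16, axis=-1):
    # Upcast to float32, mirroring the kernel's cast<MT>().
    x = x_fp16.astype(np.float32)
    # Max-shift for numerical stability: y = max + log(sum(exp(x - max))).
    x_max = x.max(axis=axis, keepdims=True)
    y = x_max + np.log(np.exp(x - x_max).sum(axis=axis, keepdims=True))
    # Cast the result back to float16, as the functor's final cast<T>() does.
    return np.squeeze(y, axis=axis).astype(np.float16)

x = np.random.uniform(-1.0, 1.0, size=(4, 5)).astype(np.float16)
print(logsumexp_fp16(x, axis=1))
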
12 changes: 10 additions & 2 deletions paddle/phi/kernels/gpu/logsumexp_grad_kernel.cu
@@ -15,8 +15,16 @@
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h"

PD_REGISTER_KERNEL(
logsumexp_grad, GPU, ALL_LAYOUT, phi::LogsumexpGradKernel, float, double) {}
using float16 = phi::dtype::float16;

PD_REGISTER_KERNEL(logsumexp_grad,
GPU,
ALL_LAYOUT,
phi::LogsumexpGradKernel,
float,
double,
float16) {}
5 changes: 4 additions & 1 deletion paddle/phi/kernels/gpu/logsumexp_kernel.cu
@@ -15,8 +15,11 @@
#include "paddle/phi/kernels/logsumexp_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h"

using float16 = phi::dtype::float16;

PD_REGISTER_KERNEL(
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {}
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {}
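
With float16 in the registration list, an fp16 tensor on a CUDA device dispatches to these kernels directly. A short usage sketch (assumes a CUDA build of Paddle; shapes and values are illustrative):

import numpy as np
import paddle

paddle.set_device('gpu')  # the float16 kernels are registered for the GPU backend only
x = paddle.to_tensor(np.random.randn(3, 4).astype('float16'), stop_gradient=False)
y = paddle.logsumexp(x, axis=1)  # forward: float16 GPU kernel
y.sum().backward()               # backward: float16 grad kernel
print(y.dtype, x.grad.dtype)     # paddle.float16 paddle.float16
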
27 changes: 20 additions & 7 deletions paddle/phi/kernels/impl/logsumexp_grad_kernel_impl.h
@@ -16,13 +16,15 @@
#include <type_traits>
#include <vector>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h"
#include "paddle/phi/kernels/logsumexp_grad_kernel.h"

namespace phi {

template <typename T>
struct LogsumexpGradFunctor {
template <typename Context,
typename X,
@@ -37,7 +39,13 @@ struct LogsumexpGradFunctor {
DY* dy,
const Dim& dim,
int size) {
dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp();
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
auto x_mt = (*x).template cast<MT>();
auto y_mt = (*y).template cast<MT>();
auto dy_mt = (*dy).template cast<MT>();
dx->device(place) =
(dy_mt.broadcast(dim) * (x_mt - y_mt.broadcast(dim)).exp())
.template cast<T>();
}
};

@@ -62,33 +70,38 @@ void LogsumexpGradKernel(const Context& dev_ctx,
auto dx = phi::EigenVector<T>::Flatten(*in_grad);
auto& place = *dev_ctx.eigen_device();
auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}});
LogsumexpGradFunctor()(
LogsumexpGradFunctor<T>()(
place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]);
} else {
int rank = in.dims().size();
LogsumexpGradFunctor functor;
LogsumexpGradFunctor<T> functor;
std::vector<int32_t> axis32;
axis32.reserve(axis.size());
std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) {
axis32.push_back(t);
});
switch (rank) {
case 1:
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>(
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break;
case 2:
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>(
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break;
case 3:
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>(
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break;
case 4:
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>(
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor<T>>(
dev_ctx, in, out, out_grad, in_grad, functor, axis32);
break;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported dimensions, please keep maximum dimensions of input "
"data less than 4."));
break;
}
}
}
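
Since y = logsumexp(x), the elementwise gradient is dL/dx = dL/dy * exp(x - y), which is what the functor evaluates; the change performs the subtraction, exp, and multiply in the MPType and casts the product back to T. A NumPy sketch of the same formula (the reduced axis is kept with size 1 so plain broadcasting stands in for Eigen's broadcast(dim); names are illustrative):

import numpy as np

def logsumexp_grad_ref(x_fp16, y_fp16, dy_fp16):
    # Compute dx = dy * exp(x - y) in float32, then cast back to float16.
    x, y, dy = (t.astype(np.float32) for t in (x_fp16, y_fp16, dy_fp16))
    return (dy * np.exp(x - y)).astype(np.float16)

x = np.random.uniform(-1, 1, (2, 3)).astype(np.float16)
y = np.log(np.exp(x.astype(np.float32)).sum(axis=1, keepdims=True))  # fp32 logsumexp
dx = logsumexp_grad_ref(x, y.astype(np.float16), np.ones_like(y).astype(np.float16))
print(dx)  # with dy = 1 this is exp(x) / sum(exp(x)) per row

With dy all ones and a full reduction this is exactly exp(x) / sum(exp(x)), the logsumexp_ref_grad helper used in the new tests below.
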
27 changes: 19 additions & 8 deletions paddle/phi/kernels/impl/logsumexp_kernel_impl.h
@@ -16,22 +16,25 @@
#include <type_traits>
#include <vector>

#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/logsumexp_kernel.h"

namespace phi {

#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \
dev_ctx, x, out, axis, keepdim); \
#define HANDLE_DIM(NDIM, RDIM) \
if (ndim == NDIM && rdim == RDIM) { \
funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor<T>>( \
dev_ctx, x, out, axis, keepdim); \
}

template <typename T>
struct LogsumexpFunctor {
template <typename Context, typename X, typename Y, typename Dim>
void operator()(const Context& place, X* x, Y* y, const Dim& dim) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
auto x_dim = x->dimensions();
auto t_dim = x_dim;
for (int i = 0; i < static_cast<int>(dim.size()); i++) {
@@ -46,12 +49,14 @@ struct LogsumexpFunctor {
r_dim[dim[i]] = x_dim[dim[i]];
}

auto x_mt = (*x).template cast<MT>();
auto y_dim = y->dimensions();
auto x_max = x->maximum(dim);
auto x_max = x_mt.maximum(dim);
y->device(place) =
(x_max +
(*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
.reshape(y_dim);
(x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log())
.reshape(y_dim)
.template cast<T>();
}
};

@@ -74,10 +79,16 @@ void LogsumexpKernel(const Context& dev_ctx,
auto output = phi::EigenScalar<T>::From(*out);
auto& place = *dev_ctx.eigen_device();
auto reduce_dim = Eigen::array<int, 1>({{0}});
LogsumexpFunctor()(place, &input, &output, reduce_dim);
LogsumexpFunctor<T>()(place, &input, &output, reduce_dim);
} else {
int ndim = input_dim_size;
int rdim = axis.size();
if (ndim > 4) {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported dimensions, please keep maximum dimensions of input "
"data less than 4."));
}

// comments for accelerating compiling temporarily.
// HANDLE_DIM(6, 5);
// HANDLE_DIM(6, 4);
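
The functor relies on the max-shift identity logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), and with this change the shifted sum is accumulated in the MPType rather than in float16. A small sketch of why that matters for fp16, whose largest finite value is 65504 (so exp overflows to inf for arguments above roughly 11):

import numpy as np

x = np.array([12.0, 12.0, 12.0], dtype=np.float16)

naive = np.log(np.exp(x).sum())                        # exp(12) overflows float16 -> inf
shifted = x.max() + np.log(np.exp(x - x.max()).sum())  # exp arguments <= 0, no overflow
x32 = x.astype(np.float32)
mixed = x32.max() + np.log(np.exp(x32 - x32.max()).sum())  # the kernel's MPType path

print(naive, shifted, mixed)  # inf, ~13.1, 13.0986... (exact value is 12 + log(3))
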
58 changes: 58 additions & 0 deletions python/paddle/fluid/tests/unittests/test_logsumexp.py
@@ -15,6 +15,7 @@
import paddle
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest


@@ -35,6 +36,22 @@ def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False):
return paddle.logsumexp(x, axis, keepdim)


def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False):
paddle.disable_static()
tensor_x = paddle.to_tensor(x)
tensor_x.stop_gradient = False
out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all)
grad = paddle.grad(out, [tensor_x])
x_grad = grad[0].numpy()
paddle.enable_static()
return x_grad


def logsumexp_ref_grad(x):
sum = np.exp(x).sum()
return np.exp(x) / sum


class TestLogsumexp(OpTest):

def setUp(self):
@@ -125,6 +142,47 @@ def set_attrs_addition(self):
self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)]


class TestLogsumexp_FP32(TestLogsumexp):

def set_attrs(self):
self.dtype = 'float32'

def test_check_grad(self):
self.__class__.dtype = self.dtype
x_grad = logsumexp_op_grad(self.inputs['X'])
ref_x_grad = logsumexp_ref_grad(self.inputs['X'])
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-08, atol=1e-08)


@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestLogsumexp_FP16(TestLogsumexp):

def set_attrs(self):
self.dtype = 'float16'

def test_check_output(self):
ref_x = self.inputs['X'].astype(np.float32)
out_ref = ref_logsumexp(ref_x)
paddle.disable_static()
x = self.inputs['X'].astype(np.float16)
tensor_x = paddle.to_tensor(x)
out_pad = logsumexp_wrapper(tensor_x)
paddle.enable_static()
np.testing.assert_allclose(out_pad.numpy(),
out_ref,
rtol=1e-03,
atol=1e-08)

def test_check_grad(self):
self.__class__.dtype = self.dtype
ref_x = self.inputs['X'].astype(np.float32)
ref_x_grad = logsumexp_ref_grad(ref_x)
x = self.inputs['X'].astype(np.float16)
x_grad = logsumexp_op_grad(x)
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05)


class TestLogsumexpError(unittest.TestCase):

def test_errors(self):
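
The fp16 cases compare against a float32 reference with rtol=1e-03, which is in line with float16's 10-bit mantissa (roughly three significant decimal digits), while the fp32 case keeps the tighter 1e-08 tolerances. The same check can be run standalone outside the OpTest harness (a sketch, assuming a CUDA-enabled build):

import numpy as np
import paddle

if paddle.is_compiled_with_cuda():
    x = np.random.uniform(-1, 1, (2, 3, 4)).astype(np.float16)
    ref = np.log(np.exp(x.astype(np.float32)).sum())     # fp32 reference, reduce all dims
    out = paddle.logsumexp(paddle.to_tensor(x)).numpy()  # fp16 GPU kernel
    np.testing.assert_allclose(out, ref, rtol=1e-03, atol=1e-05)
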
