-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【PaddlePaddle Hackathon 3 No.47】为 Paddle 新增 logsumexp support fp16 #45817
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,8 +15,11 @@ | |
#include "paddle/phi/kernels/logsumexp_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/common/float16.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/logsumexp_kernel_impl.h" | ||
|
||
using float16 = phi::dtype::float16; | ||
|
||
PD_REGISTER_KERNEL( | ||
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double) {} | ||
logsumexp, GPU, ALL_LAYOUT, phi::LogsumexpKernel, float, double, float16) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LogsumexpFunctor实现中存在exp、log,需要使用float作为计算类型。修改方式可以参考下#45952 。 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,13 +16,15 @@ | |
#include <type_traits> | ||
#include <vector> | ||
|
||
#include "paddle/phi/common/amp_type_traits.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#include "paddle/phi/kernels/funcs/reduce_grad_functions.h" | ||
#include "paddle/phi/kernels/logsumexp_grad_kernel.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T> | ||
struct LogsumexpGradFunctor { | ||
template <typename Context, | ||
typename X, | ||
|
@@ -37,7 +39,13 @@ struct LogsumexpGradFunctor { | |
DY* dy, | ||
const Dim& dim, | ||
int size) { | ||
dx->device(place) = dy->broadcast(dim) * (*x - y->broadcast(dim)).exp(); | ||
using MT = typename phi::dtype::MPTypeTrait<T>::Type; | ||
auto x_mt = (*x).template cast<MT>(); | ||
auto y_mt = (*y).template cast<MT>(); | ||
auto dy_mt = (*dy).template cast<MT>(); | ||
dx->device(place) = | ||
(dy_mt.broadcast(dim) * (x_mt - y_mt.broadcast(dim)).exp()) | ||
.template cast<T>(); | ||
} | ||
}; | ||
|
||
|
@@ -62,33 +70,38 @@ void LogsumexpGradKernel(const Context& dev_ctx, | |
auto dx = phi::EigenVector<T>::Flatten(*in_grad); | ||
auto& place = *dev_ctx.eigen_device(); | ||
auto broadcast_dim = Eigen::array<int, 1>({{static_cast<int>(in.numel())}}); | ||
LogsumexpGradFunctor()( | ||
LogsumexpGradFunctor<T>()( | ||
place, &x, &y, &dx, &dy, broadcast_dim, broadcast_dim[0]); | ||
} else { | ||
int rank = in.dims().size(); | ||
LogsumexpGradFunctor functor; | ||
LogsumexpGradFunctor<T> functor; | ||
std::vector<int32_t> axis32; | ||
axis32.reserve(axis.size()); | ||
std::for_each(axis.begin(), axis.end(), [&axis32](const int64_t& t) { | ||
axis32.push_back(t); | ||
}); | ||
switch (rank) { | ||
case 1: | ||
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor>( | ||
phi::funcs::ReduceGradFunctor<Context, T, 1, LogsumexpGradFunctor<T>>( | ||
dev_ctx, in, out, out_grad, in_grad, functor, axis32); | ||
break; | ||
case 2: | ||
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor>( | ||
phi::funcs::ReduceGradFunctor<Context, T, 2, LogsumexpGradFunctor<T>>( | ||
dev_ctx, in, out, out_grad, in_grad, functor, axis32); | ||
break; | ||
case 3: | ||
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor>( | ||
phi::funcs::ReduceGradFunctor<Context, T, 3, LogsumexpGradFunctor<T>>( | ||
dev_ctx, in, out, out_grad, in_grad, functor, axis32); | ||
break; | ||
case 4: | ||
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor>( | ||
phi::funcs::ReduceGradFunctor<Context, T, 4, LogsumexpGradFunctor<T>>( | ||
dev_ctx, in, out, out_grad, in_grad, functor, axis32); | ||
break; | ||
default: | ||
PADDLE_THROW(phi::errors::Unimplemented( | ||
"Unsupported dimensions, please keep maximum dimensions of input " | ||
"data less than 4.")); | ||
break; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 建议在这里加一下default的处理,对于大于4维的输入使用 |
||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,22 +16,25 @@ | |
#include <type_traits> | ||
#include <vector> | ||
|
||
#include "paddle/phi/common/amp_type_traits.h" | ||
#include "paddle/phi/kernels/funcs/eigen/common.h" | ||
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" | ||
#include "paddle/phi/kernels/funcs/reduce_function.h" | ||
#include "paddle/phi/kernels/logsumexp_kernel.h" | ||
|
||
namespace phi { | ||
|
||
#define HANDLE_DIM(NDIM, RDIM) \ | ||
if (ndim == NDIM && rdim == RDIM) { \ | ||
funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor>( \ | ||
dev_ctx, x, out, axis, keepdim); \ | ||
#define HANDLE_DIM(NDIM, RDIM) \ | ||
if (ndim == NDIM && rdim == RDIM) { \ | ||
funcs::ReduceFunctor<Context, T, NDIM, RDIM, LogsumexpFunctor<T>>( \ | ||
dev_ctx, x, out, axis, keepdim); \ | ||
} | ||
|
||
template <typename T> | ||
struct LogsumexpFunctor { | ||
template <typename Context, typename X, typename Y, typename Dim> | ||
void operator()(const Context& place, X* x, Y* y, const Dim& dim) { | ||
using MT = typename phi::dtype::MPTypeTrait<T>::Type; | ||
auto x_dim = x->dimensions(); | ||
auto t_dim = x_dim; | ||
for (int i = 0; i < static_cast<int>(dim.size()); i++) { | ||
|
@@ -46,12 +49,14 @@ struct LogsumexpFunctor { | |
r_dim[dim[i]] = x_dim[dim[i]]; | ||
} | ||
|
||
auto x_mt = (*x).template cast<MT>(); | ||
auto y_dim = y->dimensions(); | ||
auto x_max = x->maximum(dim); | ||
auto x_max = x_mt.maximum(dim); | ||
y->device(place) = | ||
(x_max + | ||
(*x - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) | ||
.reshape(y_dim); | ||
(x_mt - x_max.reshape(t_dim).broadcast(r_dim)).exp().sum(dim).log()) | ||
.reshape(y_dim) | ||
.template cast<T>(); | ||
} | ||
}; | ||
|
||
|
@@ -74,10 +79,16 @@ void LogsumexpKernel(const Context& dev_ctx, | |
auto output = phi::EigenScalar<T>::From(*out); | ||
auto& place = *dev_ctx.eigen_device(); | ||
auto reduce_dim = Eigen::array<int, 1>({{0}}); | ||
LogsumexpFunctor()(place, &input, &output, reduce_dim); | ||
LogsumexpFunctor<T>()(place, &input, &output, reduce_dim); | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 对于不支持的维度,麻烦帮忙加一个报错吧。 |
||
int ndim = input_dim_size; | ||
int rdim = axis.size(); | ||
if (ndim > 4) { | ||
PADDLE_THROW(phi::errors::Unimplemented( | ||
"Unsupported dimensions, please keep maximum dimensions of input " | ||
"data less than 4.")); | ||
} | ||
|
||
// comments for accelerating compiling temporarily. | ||
// HANDLE_DIM(6, 5); | ||
// HANDLE_DIM(6, 4); | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -15,6 +15,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import paddle | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import unittest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
import paddle.fluid.core as core | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
from op_test import OpTest | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -35,6 +36,22 @@ def logsumexp_wrapper(x, axis=None, keepdim=False, allreduce=False): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return paddle.logsumexp(x, axis, keepdim) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def logsumexp_op_grad(x, axis=None, keepdim=False, reduce_all=False): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
paddle.disable_static() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tensor_x = paddle.to_tensor(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tensor_x.stop_gradient = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out = logsumexp_wrapper(tensor_x, axis, keepdim, reduce_all) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
grad = paddle.grad(out, [tensor_x]) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_grad = grad[0].numpy() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
paddle.enable_static() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return x_grad | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 目前已经不推荐使用fluid的接口,这里建议参考如下单测:
Paddle/python/paddle/fluid/tests/unittests/test_layer_norm_op.py Lines 346 to 388 in 97ec57f
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def logsumexp_ref_grad(x): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
sum = np.exp(x).sum() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return np.exp(x) / sum | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 如果输入是fp16,这里计算过程都是fp16也会损失精度,这个参考值应该不够精确。建议这里计算过程也用fp32 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class TestLogsumexp(OpTest): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def setUp(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -125,6 +142,47 @@ def set_attrs_addition(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.user_defined_grad_outputs = [np.ones(1, dtype=self.dtype)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class TestLogsumexp_FP32(TestLogsumexp): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def set_attrs(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.dtype = 'float32' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def test_check_grad(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.__class__.dtype = self.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_grad = logsumexp_op_grad(self.inputs['X']) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_x_grad = logsumexp_ref_grad(self.inputs['X']) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-08, atol=1e-08) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
@unittest.skipIf(not core.is_compiled_with_cuda(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
"core is not compiled with CUDA") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class TestLogsumexp_FP16(TestLogsumexp): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fp16单测可添加如下装饰器,跳过CPU上的执行: @unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA") 另外,若单测精度无法通过,可尝试修改单测中的atol、rtol、max_relative_error。对于float16来说,有效位数只有3位,设置成1e-3也是合理的。 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def set_attrs(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.dtype = 'float16' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
该单测中,可以重写下test_check_output和test_check_grad函数,并且指定大一些的atol、max_relative_error精度阈值。 |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def test_check_output(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_x = self.inputs['X'].astype(np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_ref = ref_logsumexp(ref_x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
paddle.disable_static() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = self.inputs['X'].astype(np.float16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
tensor_x = paddle.to_tensor(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_pad = logsumexp_wrapper(tensor_x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
paddle.enable_static() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
np.testing.assert_allclose(out_pad.numpy(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
out_ref, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
rtol=1e-03, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
atol=1e-08) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def test_check_grad(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.__class__.dtype = self.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_x = self.inputs['X'].astype(np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
ref_x_grad = logsumexp_ref_grad(ref_x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x = self.inputs['X'].astype(np.float16) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_grad = logsumexp_op_grad(x) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
np.testing.assert_allclose(x_grad, ref_x_grad, rtol=1e-03, atol=1e-05) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
class TestLogsumexpError(unittest.TestCase): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
def test_errors(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LogsumexpGradFunctor实现中存在exp,也需要使用float作为计算类型。