【PaddlePaddle Hackathon 4 No.35】Optimize the GPU compute performance of the prelu op for Paddle #51131

Merged
Merged 8 commits on Mar 15, 2023
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/index_impl.cu.h
@@ -43,7 +43,7 @@ __global__ void VectorizedIndexKernel(T *out,
out + data_offset, &result[0], BLOCK_NUM_X * VecSize);
}
size_t num = numel - data_offset;
if (num > 0) {
if (static_cast<int>(num) > 0) {
Contributor

I don't think the static_cast conversion is needed here.

Contributor Author

When I ran the benchmark earlier, it took a long time to trace the recurring failures back to this line, which is why I changed it. @JamesLim-sy

Contributor

I see, in that case keep it as is; no change needed.

kps::InitWithDataIndex<size_t, VecSize, 1>(&args[0], data_offset);
kps::ElementwiseUnary<size_t, T, VecSize, 1, Functor>(
&result[0], &args[0], func);
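One plausible reading of the exchange above (the PR itself only says the benchmark kept failing here): `num` is a size_t, so if data_offset ever exceeds numel the subtraction wraps around to a huge positive value and the unsigned check `num > 0` still passes, while `static_cast<int>(num) > 0` turns the wrapped value negative and skips the tail branch. A minimal standalone C++ sketch of that pitfall, with hypothetical values:

#include <cstdio>
#include <cstddef>

int main() {
  // Hypothetical values: the tail offset ends up past the end of the
  // buffer, so numel - data_offset would be negative in signed math.
  size_t numel = 100;
  size_t data_offset = 104;

  // Both operands are unsigned, so the subtraction wraps around to a huge
  // positive value instead of -4.
  size_t num = numel - data_offset;

  // Unsigned comparison: the wrapped value passes, so the tail-handling
  // branch would run and touch memory past the end of the buffer.
  if (num > 0) {
    std::printf("unsigned check passes, num = %zu\n", num);
  }

  // Casting to int reinterprets the wrapped value as a negative number on
  // common toolchains, so the tail branch is skipped.
  if (static_cast<int>(num) > 0) {
    std::printf("int check passes\n");
  } else {
    std::printf("int check skips the tail, num as int = %d\n",
                static_cast<int>(num));
  }
  return 0;
}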
198 changes: 62 additions & 136 deletions paddle/phi/kernels/gpu/prelu_funcs.h
@@ -28,157 +28,83 @@ inline static int PADDLE_GET_BLOCKS(const int N) {
}

template <typename T>
__global__ void PReluChannelFirstWiseKernel(const T *input,
const T *alpha,
T *output,
size_t channel_num,
size_t plane_size,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t temp = index / plane_size;
size_t channel_index = temp % channel_num;
T scale = alpha[channel_index];
T x = input[index];
struct PReluChannelFirstWiseCUDAFunctor {
const T* x_;
const T* alpha_;
size_t channel_num_;
size_t plane_size_;
int numel_;

HOSTDEVICE inline PReluChannelFirstWiseCUDAFunctor(const T* x,
const T* alpha,
int numel,
size_t channel_num,
size_t plane_size)
: x_(x),
alpha_(alpha),
numel_(numel),
channel_num_(channel_num),
plane_size_(plane_size) {}

HOSTDEVICE inline T operator()(const unsigned int n) const {
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
size_t temp = n / plane_size_;
size_t channel_index = temp % channel_num_;
T scale = alpha_[channel_index];
T x = x_[n];
return (x > zero) ? x : scale * x;
}
}
};

template <typename T>
__global__ void PReluChannelLastWiseKernel(const T *input,
const T *alpha,
T *output,
size_t channel_num,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t channel_index = index % channel_num;
T scale = alpha[channel_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
struct PReluChannelLastWiseCUDAFunctor {
const T* x_;
const T* alpha_;
size_t channel_num_;

template <typename T>
__global__ void PReluElementWiseKernel(const T *input,
const T *alpha,
T *output,
size_t spatial_size,
size_t numel) {
CUDA_KERNEL_LOOP(index, numel) {
size_t element_index = index % spatial_size;
T scale = alpha[element_index];
T x = input[index];
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
}
}
HOSTDEVICE inline PReluChannelLastWiseCUDAFunctor(const T* x,
const T* alpha,
size_t channel_num)
: x_(x), alpha_(alpha), channel_num_(channel_num) {}

template <typename T>
__global__ void PReluScalarKernel(const T *input,
const T *alpha,
T *output,
size_t numel) {
T scale = alpha[0];
CUDA_KERNEL_LOOP(index, numel) {
T x = input[index];
HOSTDEVICE inline T operator()(const unsigned int n) const {
T zero = static_cast<T>(0);
output[index] = (x > zero) ? x : scale * x;
size_t channel_index = n % channel_num_;
T scale = alpha_[channel_index];
T x = x_[n];
return (x > zero) ? x : scale * x;
}
}

template <typename T>
class PreluChannelWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t channel,
bool channel_last,
size_t numel);
};

template <typename T>
class PreluElementWiseDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t numel);
};
struct PreluElementWiseDirectCUDAFunctor {
const T* x_;
const T* alpha_;
size_t spatial_size_;

template <typename T>
class PreluScalarDirectCUDAFunctor {
public:
void operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t numel);
};
HOSTDEVICE inline PreluElementWiseDirectCUDAFunctor(const T* x,
const T* alpha,
size_t spatial_size)
: x_(x), alpha_(alpha), spatial_size_(spatial_size) {}

template <typename T>
void PreluChannelWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t channel,
bool channel_last,
size_t numel) {
if (channel_last) {
PReluChannelLastWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, channel, numel);
} else {
PReluChannelFirstWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, channel, numel / batch_size / channel, numel);
HOSTDEVICE inline T operator()(const unsigned int n) const {
T zero = static_cast<T>(0);
size_t element_index = n % spatial_size_;
T scale = alpha_[element_index];
T x = x_[n];
return (x > zero) ? x : scale * x;
}
}

template <typename T>
void PreluElementWiseDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t batch_size,
size_t numel) {
PReluElementWiseKernel<<<PADDLE_GET_BLOCKS(numel),
CUDA_NUM_THREADS,
0,
stream>>>(
input, alpha, output, numel / batch_size, numel);
}
};

template <typename T>
void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
const T *input,
const T *alpha,
T *output,
size_t numel) {
PReluScalarKernel<<<PADDLE_GET_BLOCKS(numel), CUDA_NUM_THREADS, 0, stream>>>(
input, alpha, output, numel);
}

template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<phi::dtype::float16>;
template class PreluChannelWiseDirectCUDAFunctor<double>;

template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<phi::dtype::float16>;
template class PreluElementWiseDirectCUDAFunctor<double>;

template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<phi::dtype::float16>;
template class PreluScalarDirectCUDAFunctor<double>;
struct PreluScalarDirectCUDAFunctor {
const T* scalar_;
HOSTDEVICE inline PreluScalarDirectCUDAFunctor(const T* scalar)
: scalar_(scalar) {}
HOSTDEVICE inline T operator()(const T x) const {
T zero = static_cast<T>(0);
return (x > zero) ? x : scalar_[0] * x;
}
};

} // namespace phi
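For readers skimming the diff above: the rewritten header drops the hand-written CUDA kernels and their launcher classes in favor of small index-based functors that phi::IndexKernel can launch with vectorized loads and stores. A host-side sketch (an editor's illustration, not code from the PR) of how the channel-first functor maps a flat index to its per-channel slope:

#include <cstdio>
#include <vector>

// CPU stand-in for PReluChannelFirstWiseCUDAFunctor (NCHW layout). On the
// GPU the same operator() body runs once per element via phi::IndexKernel;
// here a plain loop replaces the kernel launch so the mapping can be checked.
template <typename T>
struct ChannelFirstPRelu {
  const T* x_;
  const T* alpha_;
  size_t channel_num_;
  size_t plane_size_;  // H * W of the input tensor

  T operator()(size_t n) const {
    size_t channel_index = (n / plane_size_) % channel_num_;
    T scale = alpha_[channel_index];
    T x = x_[n];
    return (x > static_cast<T>(0)) ? x : scale * x;
  }
};

int main() {
  // Hypothetical 1x2x2x2 tensor: 2 channels, plane_size = 4.
  std::vector<float> x = {1.f, -1.f, 2.f, -2.f, 3.f, -3.f, 4.f, -4.f};
  std::vector<float> alpha = {0.1f, 0.5f};  // one learned slope per channel
  ChannelFirstPRelu<float> func{x.data(), alpha.data(), 2, 4};

  for (size_t n = 0; n < x.size(); ++n) {
    std::printf("%zu -> %.2f\n", n, func(n));
  }
  // Negative values in channel 0 are scaled by 0.1, in channel 1 by 0.5.
  return 0;
}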
43 changes: 26 additions & 17 deletions paddle/phi/kernels/gpu/prelu_kernel.cu
@@ -16,6 +16,8 @@

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#include "paddle/phi/kernels/gpu/prelu_funcs.h"

namespace phi {
@@ -27,36 +29,43 @@ void PReluKernel(const Context& dev_ctx,
const std::string& data_format,
const std::string& mode,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
const T* x_ptr = x.data<T>();
T* o_ptr = dev_ctx.template Alloc<T>(out);

const T* alpha_ptr = alpha.data<T>();

int numel = x.numel();
auto dim = x.dims();
auto x_rank = dim.size();

VLOG(4) << "dim[0]:" << dim[0] << ", dim[1]:" << dim[1] << ", dim["
<< x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel;
<< x_rank - 1 << "]:" << dim[x_rank - 1] << ", numel:" << numel
<< ", mode:" << mode << ", format:" << data_format;

if (mode == "channel") {
bool channel_last = data_format == "NHWC";
size_t channel = channel_last ? dim[x_rank - 1] : dim[1];
PreluChannelWiseDirectCUDAFunctor<T> prelu_channel_wise;
prelu_channel_wise(dev_ctx.stream(),
x_ptr,
alpha_ptr,
o_ptr,
dim[0],
channel,
channel_last,
numel);
if (channel_last) {
auto func = PReluChannelLastWiseCUDAFunctor<T>(x_ptr, alpha_ptr, channel);
phi::IndexKernel<T, PReluChannelLastWiseCUDAFunctor<T>>(
dev_ctx, out, func);
} else {
size_t plane_size = numel / dim[0] / channel;
auto func = PReluChannelFirstWiseCUDAFunctor<T>(
x_ptr, alpha_ptr, numel, channel, plane_size);
phi::IndexKernel<T, PReluChannelFirstWiseCUDAFunctor<T>>(
dev_ctx, out, func);
}
} else if (mode == "element") {
PreluElementWiseDirectCUDAFunctor<T> prelu_element_wise;
prelu_element_wise(
dev_ctx.stream(), x_ptr, alpha_ptr, o_ptr, dim[0], numel);
size_t spatial_size = numel / dim[0];
auto func =
PreluElementWiseDirectCUDAFunctor<T>(x_ptr, alpha_ptr, spatial_size);
phi::IndexKernel<T, PreluElementWiseDirectCUDAFunctor<T>>(
dev_ctx, out, func);
} else {
PreluScalarDirectCUDAFunctor<T> prelu_scalar;
prelu_scalar(dev_ctx.stream(), x_ptr, alpha_ptr, o_ptr, numel);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
auto func = PreluScalarDirectCUDAFunctor<T>(alpha_ptr);
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, func);
}
}

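In the rewritten kernel body above, the "channel" and "element" modes need the element index to look up the right slope, so they go through phi::IndexKernel, while the scalar mode uses one shared slope and only needs the value itself, which is why it can take the generic elementwise path (phi::funcs::ElementwiseKernel). A host-side sketch of that scalar branch (hypothetical values, not from the PR):

#include <algorithm>
#include <cstdio>
#include <vector>

// CPU stand-in for PreluScalarDirectCUDAFunctor: a single learned slope is
// shared by every element, so the functor takes the value, not the index.
template <typename T>
struct ScalarPRelu {
  const T* scalar_;
  T operator()(T x) const {
    return (x > static_cast<T>(0)) ? x : scalar_[0] * x;
  }
};

int main() {
  std::vector<float> x = {0.5f, -1.5f, 2.0f, -0.25f};
  std::vector<float> out(x.size());
  float alpha = 0.2f;  // hypothetical shared slope
  ScalarPRelu<float> func{&alpha};

  // std::transform stands in for the fused GPU elementwise kernel.
  std::transform(x.begin(), x.end(), out.begin(), func);
  for (float v : out) std::printf("%.3f ", v);
  std::printf("\n");  // expected: 0.500 -0.300 2.000 -0.050
  return 0;
}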