Skip to content

Commit

Permalink
seal more operation into template specialization
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesLim-sy committed Sep 6, 2021
1 parent 2402277 commit 2294254
Showing 1 changed file with 76 additions and 51 deletions.
127 changes: 76 additions & 51 deletions paddle/fluid/operators/math/pooling.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,25 +50,43 @@ struct FastDivModOfPool {

template <typename T, typename PoolProcess, typename Enable = void>
struct PoolingFunctor {
inline HOSTDEVICE void operator()(const T input,
const T* __restrict__ output_data,
const T* __restrict__ output_grad,
const T* __restrict__ input_data;
const T* __restrict__ output_data;
T input;

explicit PoolingFunctor(const T* __restrict__ _input_data,
const T* __restrict__ _output_data)
: input_data(_input_data), output_data(_output_data) {}

inline DEVICE void ParameterUpdate(int tid, int output_stride) {
input = input_data[tid];
output_data += output_stride;
}

inline HOSTDEVICE void operator()(const T* __restrict__ output_grad,
T* __restrict__ gradient, int pool_size,
int index) const {
*gradient += output_grad[index] * static_cast<T>(input, output_data[index]);
}
};

/*
Different from MaxPoolGrad, parameters like input_data and
output_data is unnecessary in AvgPoolGrad, individual template
specialization of AvgPoolGrad can gain more kernel performance.
*/
template <typename T, typename PoolProcess>
struct PoolingFunctor<T, PoolProcess,
typename std::enable_if<std::is_same<
PoolProcess, math::AvgPoolGrad<T>>::value>::type> {
inline HOSTDEVICE void operator()(T input, const T* __restrict__ output_data,
const T* __restrict__ output_grad,
explicit PoolingFunctor(const T* __restrict__ _input_data,
const T* __restrict__ _output_data) {}
inline DEVICE void ParameterUpdate(int tid, int output_stride) {}

inline HOSTDEVICE void operator()(const T* __restrict__ output_grad,
T* __restrict__ gradient, int pool_size,
int index) const {
printf("AVG_Pooling2D\n");
*gradient += static_cast<T>(1.0f / pool_size) * output_grad[index];
*gradient += output_grad[index] / static_cast<T>(pool_size);
}
};

Expand Down Expand Up @@ -151,21 +169,19 @@ __global__ void KernelPool2D(const int nthreads, const T* input_data,
}
}

template <typename PoolProcess, typename T, typename Function>
template <typename PoolProcess, typename T>
__global__ void KernelPool2DGrad(
const int nthreads, const T* __restrict__ input_data,
const T* __restrict__ output_data, const T* __restrict__ output_grad,
const int nthreads, const T* __restrict__ output_grad,
const int output_height, const int output_width, const int input_width,
const int input_height, const int ksize_width, const int ksize_height,
const int stride_width, const int stride_height, FastDivModOfPool divmods,
const int padding_height, const int padding_width, Function func,
bool exclusive, bool adaptive, T* __restrict__ input_grad,
bool channel_last = false) {
const int padding_height, const int padding_width,
PoolingFunctor<T, PoolProcess> functor, bool exclusive, bool adaptive,
T* __restrict__ input_grad, bool channel_last = false) {
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
T gradient = static_cast<T>(0);
T input = input_data[index];
int w_offset, h_offset, offsetC, batch_idx;
int w_offset, h_offset, offsetC;
int phstart, phend, pwstart, pwend;
int output_stride;

Expand All @@ -177,9 +193,9 @@ __global__ void KernelPool2DGrad(
w_offset = input_width_divmod.val[1] + padding_width;
h_offset = input_height_divmod.val[1] + padding_height;
offsetC = channel_divmod.val[1];
batch_idx = channel_divmod.val[0];
output_stride = (batch_idx * divmods.channel.divisor + offsetC) *
output_height * output_width;
output_stride =
(channel_divmod.val[0] * divmods.channel.divisor + offsetC) *
output_height * output_width;
} else { /* NHWC */
auto c_divmod = divmods.channel.Divmod(index);
auto input_width_divmod = divmods.input_w.Divmod(c_divmod.val[0]);
Expand All @@ -188,11 +204,10 @@ __global__ void KernelPool2DGrad(
offsetC = c_divmod.val[1];
w_offset = input_width_divmod.val[1] + padding_width;
h_offset = input_height_divmod.val[1] + padding_height;
batch_idx = input_height_divmod.val[0];
output_stride =
batch_idx * output_height * output_width * divmods.channel.divisor;
output_stride = input_height_divmod.val[0] * output_height *
output_width * divmods.channel.divisor;
}
output_data += output_stride;
functor.ParameterUpdate(index, output_stride);
output_grad += output_stride;

if (adaptive) {
Expand All @@ -217,9 +232,7 @@ __global__ void KernelPool2DGrad(
int output_sub_idx = channel_last
? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;

func(input, output_data, output_grad, &gradient, pool_size,
output_sub_idx);
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}
} else {
Expand All @@ -230,23 +243,34 @@ __global__ void KernelPool2DGrad(
phend = min(divmods.stride_h.Div(h_offset) + 1, output_height);
pwend = min(divmods.stride_w.Div(w_offset) + 1, output_width);

for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx = channel_last
? tmp_idx * divmod.channel.divisor + offsetC
: tmp_idx;

func(input, output_data, output_grad, &gradient, pool_size,
output_sub_idx);
if (exclusive) {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int hstart = ph * stride_height - padding_height;
int wstart = pw * stride_width - padding_width;
int hend = min(hstart + ksize_height, input_height);
int wend = min(wstart + ksize_width, input_width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int pool_size = exclusive ? (hend - hstart) * (wend - wstart)
: ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}
} else {
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
int pool_size = ksize_height * ksize_width;
int tmp_idx = ph * output_width + pw;
int output_sub_idx =
channel_last ? tmp_idx * divmods.channel.divisor + offsetC
: tmp_idx;
functor(output_grad, &gradient, pool_size, output_sub_idx);
}
}
}
}
Expand Down Expand Up @@ -483,12 +507,13 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
auto pool_divmod =
FastDivModOfPool(input_channels, input_width, input_height, ksize_width,
ksize_height, stride_width, stride_height);
auto pool_functor = PoolingFunctor<T, PoolProcess>(input_data, output_data);

KernelPool2DGrad<PoolProcess, T><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_height,
output_width, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, pool_divmod, padding_height, padding_width,
PoolingFunctor<T, PoolProcess>(), exclusive, adaptive, input_grad_data);
nthreads, output_grad_data, output_height, output_width, input_width,
input_height, ksize_width, ksize_height, stride_width, stride_height,
pool_divmod, padding_height, padding_width, pool_functor, exclusive,
adaptive, input_grad_data);
}
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
Expand Down Expand Up @@ -538,13 +563,13 @@ class Pool2dGradFunctor<platform::CUDADeviceContext, PoolProcess, T> {
auto pool_divmod =
FastDivModOfPool(input_channels, input_width, input_height, ksize_width,
ksize_height, stride_width, stride_height);
auto pool_functor = PoolingFunctor<T, PoolProcess>(input_data, output_data);

KernelPool2DGrad<PoolProcess, T><<<grids, blocks, 0, context.stream()>>>(
nthreads, input_data, output_data, output_grad_data, output_height,
output_width, input_width, input_height, ksize_width, ksize_height,
stride_width, stride_height, pool_divmod, padding_height, padding_width,
PoolingFunctor<T, PoolProcess>(), exclusive, adaptive, input_grad_data,
channel_last);
nthreads, output_grad_data, output_height, output_width, input_width,
input_height, ksize_width, ksize_height, stride_width, stride_height,
pool_divmod, padding_height, padding_width, pool_functor, exclusive,
adaptive, input_grad_data, channel_last);
}
};

Expand Down

0 comments on commit 2294254

Please sign in to comment.