Optimize performance of depthwise_conv_bwd #46362

Merged · 4 commits · Sep 30, 2022
paddle/phi/kernels/gpu/depthwise_conv.h: 107 changes (60 additions, 47 deletions)
@@ -469,60 +469,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
       const int dilate_height, const int dilate_width, \
       T *const input_grad_data

-template <typename T, bool fuse_relu_before_conv>
+template <typename T, int c_filter, bool fuse_relu_before_conv>
 __device__ __inline__ void KernelDepthwiseConvInputGradNCHW(
     ARG_DEFINE_KernelDepthwiseConvInputGrad) {
-  const int batch = blockIdx.y;
-  const int c_in = blockIdx.x;
-  for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) {
-    for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) {
-      const int c_out_start = c_in * filter_multiplier;
-      int h_out_start =
-          h_in - (filter_height - 1) * dilate_height + padding_height;
-      int h_out_end = h_in + padding_height;
-      int w_out_start =
-          w_in - (filter_width - 1) * dilate_width + padding_width;
-      int w_out_end = w_in + padding_width;
-
-      T value(0);
-      int index =
-          ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
-          w_in;
-
-      if (fuse_relu_before_conv) {
-        if (input_data[index] <= T(0)) {
-          input_grad_data[index] = 0;
-          continue;
-        }
-      }
-
-      for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier;
-           c_out++) {
-        int filter_offset = (c_out + 1) * filter_height * filter_width;
-        for (int h_out = h_out_start; h_out <= h_out_end;
-             h_out += dilate_height) {
-          for (int w_out = w_out_start; w_out <= w_out_end;
-               w_out += dilate_width) {
-            filter_offset--;
-            int s_h_out = h_out / stride_height;
-            int s_w_out = w_out / stride_width;
-            if (h_out % stride_height == 0 && w_out % stride_width == 0 &&
-                s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
-                s_w_out < output_width) {
-              int output_grad_offset =
-                  ((batch * output_channels + c_out) * output_height +
-                   s_h_out) *
-                      output_width +
-                  s_w_out;
-              value += output_grad_data[output_grad_offset] *
-                       filter_data[filter_offset];
-            }
-          }
-        }
-      }
-      input_grad_data[index] = value;
-    }
-  }
-}
+  const int fw_size = c_filter != -1 ? c_filter : filter_width;
+  const int fh_size = c_filter != -1 ? c_filter : filter_height;
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= batch_size * input_channels * input_height * input_width) {
+    return;
+  }
+  if (fuse_relu_before_conv) {
+    if (input_data[idx] <= static_cast<T>(0.0f)) {
+      input_grad_data[idx] = 0;
+      return;
+    }
+  }
+
+  int tmp_1 = idx / input_width;
+  const int w_in = idx - tmp_1 * input_width;
+  int tmp_2 = tmp_1 / input_height;
+  const int h_in = tmp_1 - tmp_2 * input_height;
+  tmp_1 = tmp_2;
+  tmp_2 = tmp_1 / input_channels;
+  const int c_in = tmp_1 - tmp_2 * input_channels;
+  const int batch = tmp_2;
+
+  T value(0);
+  for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) {
+    int c_out = c_in * filter_multiplier + c_mul;
+    int filter_offset = c_out * filter_height * filter_width;
+
+#pragma unroll
+    for (int fh = 0; fh < fh_size; ++fh) {
+#pragma unroll
+      for (int fw = 0; fw < fw_size; ++fw) {
+        int h_out = h_in + padding_height - fh * dilate_height;
+        int w_out = w_in + padding_width - fw * dilate_width;
+        if ((h_out - h_out / stride_height * stride_height == 0) &&
+            (w_out - w_out / stride_width * stride_width == 0)) {
+          h_out /= stride_height;
+          w_out /= stride_width;
+
+          if (h_out >= 0 && h_out < output_height && w_out >= 0 &&
+              w_out < output_width) {
+            int output_grad_offset =
+                ((batch * output_channels + c_out) * output_height + h_out) *
+                    output_width +
+                w_out;
+            value += output_grad_data[output_grad_offset] *
+                     filter_data[filter_offset];
+          }
+        }
+        filter_offset++;
+      }
+    }
+  }
+  input_grad_data[idx] = value;
+}

 template <typename T, bool fuse_relu_before_conv>
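Note on the hunk above: the old kernel launched one 2D thread block per (channel, batch) pair and strided its threads over spatial positions, while the new kernel assigns exactly one thread per input-gradient element and decodes the NCHW coordinates from a flat global index; the new c_filter template parameter additionally gives the #pragma unroll loops a compile-time trip count. A minimal standalone sketch of the same indexing scheme, with illustrative names only (this is not Paddle code):

// Illustrative sketch: decode (n, c, h, w) from a flat NCHW index,
// one thread per element, as the rewritten kernel does.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void DecodeNCHW(int N, int C, int H, int W) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= N * C * H * W) return;  // same guard style as the new kernel

  // idx == ((n * C + c) * H + h) * W + w, so peel dimensions off with
  // division; "idx - tmp * W" is idx % W without a second divide.
  int tmp = idx / W;
  const int w = idx - tmp * W;
  int tmp2 = tmp / H;
  const int h = tmp - tmp2 * H;
  const int c = tmp2 % C;
  const int n = tmp2 / C;

  if (idx == 13)  // spot-check one thread
    printf("idx 13 -> n=%d c=%d h=%d w=%d\n", n, c, h, w);
}

int main() {
  const int n = 2 * 3 * 4 * 5;  // elements of a 2x3x4x5 tensor
  const int block = 128;
  DecodeNCHW<<<(n + block - 1) / block, block>>>(2, 3, 4, 5);
  return cudaDeviceSynchronize() != cudaSuccess;
}

Because w is the fastest-varying dimension of idx, consecutive threads write consecutive input_grad_data elements, which should coalesce well.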
@@ -735,7 +737,7 @@ __global__ void KernelDepthwiseConvInputGradSp(

   if (c_filter_multiplier == 0 || c_filter == -1) {
     if (data_layout != DataLayout::kNHWC) {
-      KernelDepthwiseConvInputGradNCHW<T, fuse_relu_before_conv>(
+      KernelDepthwiseConvInputGradNCHW<T, c_filter, fuse_relu_before_conv>(
           input_data,
           output_grad_data,
           filter_data,
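Propagating c_filter through this call site is what makes the unrolling in the kernel effective: when the dispatcher instantiates a concrete filter size, fh_size and fw_size become compile-time constants. A hedged sketch of that mechanism, with illustrative names:

// Illustrative sketch: with a concrete template extent the trip count is a
// compile-time constant and #pragma unroll can fully unroll the loop;
// c_filter == -1 falls back to the runtime size.
template <int c_filter>
__device__ float SumTaps(const float* taps, int runtime_size) {
  const int size = (c_filter != -1) ? c_filter : runtime_size;
  float acc = 0.0f;
#pragma unroll
  for (int i = 0; i < size; ++i) acc += taps[i];  // unrolled when size is constant
  return acc;
}

// SumTaps<3>(taps, /*ignored*/ 3) -> fully unrolled, three adds, no loop
// SumTaps<-1>(taps, k)            -> ordinary loop over k taps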
@@ -1247,8 +1249,7 @@ class DepthwiseConvFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
           batch_size);
     }
     int filter_multiplier = output_channels / input_channels;
-    int nums_output =
-        batch_size * output_channels * output_height * output_width;
+    int nums_output = output->numel();
 #ifdef __HIPCC__
     int block_size = 256;
 #else
@@ -1421,6 +1422,13 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
           batch_size);
     }
     int filter_multiplier = output_channels / input_channels;
+    int nums_input = input_grad->numel();
+#ifdef __HIPCC__
+    int block_size = 256;
+#else
+    int block_size = 512;
+#endif
+    int grid_size = (nums_input + block_size - 1) / block_size;

 #define check_case(c_filter_multiplier, c_stride, c_filter)                  \
   if (c_filter_multiplier == 0 ||                                            \
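The grid_size expression above is integer ceiling division, so the 1D launch covers every element counted by input_grad->numel(), with at most block_size - 1 surplus threads that hit the kernel's early-return guard. A small host-side sketch with illustrative numbers:

// Ceil-division launch math, illustrative values only.
#include <cstdio>

int main() {
  const int nums_input = 1000;  // stand-in for input_grad->numel()
  const int block_size = 512;   // 256 under __HIPCC__, per the hunk above
  const int grid_size = (nums_input + block_size - 1) / block_size;
  // 2 blocks * 512 threads = 1024; the 24 threads with idx >= nums_input
  // return immediately inside the kernel.
  printf("grid_size = %d\n", grid_size);  // prints 2
  return 0;
}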
@@ -1429,6 +1437,11 @@ class DepthwiseConvInputGradFunctor<phi::GPUContext, T, fuse_relu_before_conv> {
       (ksize_height == ksize_width && ksize_height == c_filter ||            \
        c_filter == -1)) {                                                    \
     if (data_layout != DataLayout::kNHWC) {                                  \
+      if (c_filter == -1) {                                                  \
+        threads.x = block_size;                                              \
+        grid.x = grid_size;                                                  \
+        threads.y = threads.z = grid.y = grid.z = 1;                         \
+      }                                                                      \
       KernelDepthwiseConvInputGradSp<T,                                      \
                                      c_filter_multiplier,                    \
                                      c_stride,                               \
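A sketch of what the new c_filter == -1 branch does: the dim3 threads/grid values computed earlier for the 2D per-(channel, batch) mapping are collapsed into the flat 1D shape the rewritten NCHW kernel expects. Illustrative only, assuming dim3 variables like those in the macro:

#include <cuda_runtime.h>

// Collapse a previously 2D/3D launch shape into the 1D
// one-thread-per-element shape used by the rewritten kernel.
void FlattenLaunch(dim3& threads, dim3& grid, int block_size, int grid_size) {
  threads.x = block_size;
  threads.y = threads.z = 1;
  grid.x = grid_size;
  grid.y = grid.z = 1;
}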