diff --git a/paddle/phi/kernels/gpu/depthwise_conv.h b/paddle/phi/kernels/gpu/depthwise_conv.h index eae7b77519911..3b5aa7e61e786 100644 --- a/paddle/phi/kernels/gpu/depthwise_conv.h +++ b/paddle/phi/kernels/gpu/depthwise_conv.h @@ -87,43 +87,36 @@ class DepthwiseConvFilterGradFunctor { const DataLayout data_layout = DataLayout::kNCHW); }; +#define FINAL_MASK 0xffffffff +#define HALF_WARP 16 +#define WARP_SIZE 32 + template -static __forceinline__ __device__ T WarpReduceSum(T val, int warp_size) { - typedef cub::WarpReduce WarpReduce; - typename WarpReduce::TempStorage temp_storage; - val = WarpReduce(temp_storage).Sum(val, warp_size); +__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) { + for (int mask = HALF_WARP; mask > 0; mask >>= 1) + val += platform::CudaShuffleDownSync(lane_mask, val, mask); return val; } template -__forceinline__ __device__ T BlockReduceSum(T val) { - static __shared__ T shared[32]; - int thread_id = threadIdx.x + threadIdx.y * blockDim.x + - threadIdx.z * blockDim.x * blockDim.y; - int warp_size = min(blockDim.x * blockDim.y * blockDim.z, warpSize); - int lane = thread_id % warp_size; - int wid = thread_id / warp_size; - - val = WarpReduceSum(val, warp_size); // Each warp performs partial reduction - - if (lane == 0) shared[wid] = val; // Write reduced value to shared memory - __syncthreads(); // Wait for all partial reductions - - // read from shared memory only if that warp existed - int block_size = blockDim.x * blockDim.y * blockDim.z; - if (thread_id < (block_size - 1) / warp_size + 1) { - val = shared[lane]; - } else { - val = static_cast(0); - } +__forceinline__ __device__ T BlockReduceSum(T val, unsigned mask = FINAL_MASK) { + static __shared__ T shared[WARP_SIZE]; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int lane = tid & 0x1f; + int wid = tid >> 5; + + val = WarpReduceSum(val, mask); - if (wid == 0) { - val = WarpReduceSum(val, warp_size); // Final reduce within first warp - } __syncthreads(); - if (thread_id != 0) { - val = static_cast(0); - } + if (lane == 0) shared[wid] = val; + + __syncthreads(); + + // align block_span to WARP_SIZE + int block_span = (blockDim.x * blockDim.y + WARP_SIZE - 1) >> 5; + val = (lane < block_span) ? shared[lane] : static_cast(0.0f); + val = WarpReduceSum(val, mask); + return val; } @@ -139,55 +132,53 @@ __forceinline__ __device__ T BlockReduceSum(T val) { // A Cuda kernel to compute the depthwise convolution forward pass // in NCHW format. -template +template __device__ __inline__ void KernelDepthwiseConvNCHW( ARG_DEFINE_KernelDepthwiseConv) { + const int fw_size = c_filter != -1 ? c_filter : filter_width; + const int fh_size = c_filter != -1 ? c_filter : filter_height; int idx = threadIdx.x + blockIdx.x * blockDim.x; if (idx >= (output_channels * batch_size * output_height * output_width)) return; - const int w_out = idx % output_width; - const int h_out = (idx / output_width) % output_height; - const int c_out = (idx / output_width / output_height) % output_channels; - const int batch = idx / output_width / output_height / output_channels; + int tmp_1 = idx / output_width; + const int w_out = idx - tmp_1 * output_width; + int tmp_2 = tmp_1 / output_height; + const int h_out = tmp_1 - tmp_2 * output_height; + tmp_1 = tmp_2; + tmp_2 = tmp_1 / output_channels; + const int c_out = tmp_1 - tmp_2 * output_channels; + const int batch = tmp_2; const int c_in = c_out / filter_multiplier; - const T* weight = filter_data + c_out * filter_height * filter_width; T value(0); - const int h_in_start = -padding_height + h_out * stride_height; - const int w_in_start = -padding_width + w_out * stride_width; - const int h_in_end = h_in_start + filter_height * dilate_height; - const int w_in_end = w_in_start + filter_width * dilate_width; int in_offset = ((batch * input_channels + c_in) * input_height) * input_width; - - const int h_end = h_in_end < input_height ? h_in_end : input_height; - const int w_end = w_in_end < input_width ? w_in_end : input_width; - const int h_start = h_in_start > 0 ? h_in_start : 0; - const int w_start = w_in_start > 0 ? w_in_start : 0; - int weight_offset = 0; + int weight_offset = c_out * filter_height * filter_width; + int h_in_start = -padding_height + h_out * stride_height; + int w_in_start = -padding_width + w_out * stride_width; #pragma unroll - for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) { + for (int fh = 0, h_in = h_in_start; fh < fh_size; + fh++, h_in += dilate_height) { #pragma unroll - for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) { - if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) { + for (int fw = 0, w_in = w_in_start; fw < fw_size; + fw++, w_in += dilate_width) { + if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) { int offset = in_offset + h_in * input_width + w_in; T in_data = input_data[offset]; if (fuse_relu_before_conv) { - value += weight[weight_offset] * T(max(0.0f, double(in_data))); + value += filter_data[weight_offset] * + static_cast(max(0.0f, static_cast(in_data))); } else { - value += weight[weight_offset] * in_data; + value += filter_data[weight_offset] * in_data; } } weight_offset++; } } - int index = batch * output_channels * output_height * output_width + - c_out * output_height * output_width + h_out * output_width + - w_out; - output_data[index] = value; + output_data[idx] = value; } // A Cuda kernel to compute the depthwise convolution forward pass @@ -228,7 +219,8 @@ __device__ __inline__ void KernelDepthwiseConvNHWC( T in_data = input_data[offset]; const T* weight = filter_data + weight_offset * output_channels + c_out; if (fuse_relu_before_conv) { - value += weight[0] * T(max(0.0f, double(in_data))); + value += weight[0] * + static_cast(max(0.0f, static_cast(in_data))); } else { value += weight[0] * in_data; } @@ -281,7 +273,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW( int offset = in_offset + h_in * input_width + w_in; if (fuse_relu_before_conv) { value += r_weight[h_f * c_filter + w_f] * - T(max(0.0f, double(input_data[offset]))); + static_cast( + max(0.0f, static_cast(input_data[offset]))); } else { value += r_weight[h_f * c_filter + w_f] * input_data[offset]; } @@ -337,7 +330,8 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNHWC( in_offset + (h_in * input_width + w_in) * input_channels + c_in; if (fuse_relu_before_conv) { value += r_weight[h_f * c_filter + w_f] * - T(max(0.0, double(input_data[offset]))); + static_cast( + max(0.0, static_cast(input_data[offset]))); } else { value += r_weight[h_f * c_filter + w_f] * input_data[offset]; } @@ -367,25 +361,26 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { } if (c_filter == -1) { if (data_layout != DataLayout::kNHWC) { - KernelDepthwiseConvNCHW(input_data, - filter_data, - batch_size, - output_channels, - output_height, - output_width, - input_channels, - input_height, - input_width, - final_filter_multiplier, - filter_height, - filter_width, - h_stride, - w_stride, - padding_height, - padding_width, - dilate_height, - dilate_width, - output_data); + KernelDepthwiseConvNCHW( + input_data, + filter_data, + batch_size, + output_channels, + output_height, + output_width, + input_channels, + input_height, + input_width, + final_filter_multiplier, + filter_height, + filter_width, + h_stride, + w_stride, + padding_height, + padding_width, + dilate_height, + dilate_width, + output_data); } else { KernelDepthwiseConvNHWC(input_data, filter_data, @@ -467,60 +462,62 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) { const int dilate_height, const int dilate_width, \ T *const input_grad_data -template +template __device__ __inline__ void KernelDepthwiseConvInputGradNCHW( ARG_DEFINE_KernelDepthwiseConvInputGrad) { - const int batch = blockIdx.y; - const int c_in = blockIdx.x; - for (int w_in = threadIdx.x; w_in < input_width; w_in += blockDim.x) { - for (int h_in = threadIdx.y; h_in < input_height; h_in += blockDim.y) { - const int c_out_start = c_in * filter_multiplier; - int h_out_start = - h_in - (filter_height - 1) * dilate_height + padding_height; - int h_out_end = h_in + padding_height; - int w_out_start = - w_in - (filter_width - 1) * dilate_width + padding_width; - int w_out_end = w_in + padding_width; + const int fw_size = c_filter != -1 ? c_filter : filter_width; + const int fh_size = c_filter != -1 ? c_filter : filter_height; + int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= batch_size * input_channels * input_height * input_width) { + return; + } + if (fuse_relu_before_conv) { + if (input_data[idx] <= static_cast(0.0f)) { + input_grad_data[idx] = 0; + return; + } + } - T value(0); - int index = - ((batch * gridDim.x + c_in) * input_height + h_in) * input_width + - w_in; + int tmp_1 = idx / input_width; + const int w_in = idx - tmp_1 * input_width; + int tmp_2 = tmp_1 / input_height; + const int h_in = tmp_1 - tmp_2 * input_height; + tmp_1 = tmp_2; + tmp_2 = tmp_1 / input_channels; + const int c_in = tmp_1 - tmp_2 * input_channels; + const int batch = tmp_2; - if (fuse_relu_before_conv) { - if (input_data[index] <= T(0)) { - input_grad_data[index] = 0; - continue; - } - } + T value(0); + for (int c_mul = 0; c_mul < filter_multiplier; ++c_mul) { + int c_out = c_in * filter_multiplier + c_mul; + int filter_offset = c_out * filter_height * filter_width; - for (int c_out = c_out_start; c_out < c_out_start + filter_multiplier; - c_out++) { - int filter_offset = (c_out + 1) * filter_height * filter_width; - for (int h_out = h_out_start; h_out <= h_out_end; - h_out += dilate_height) { - for (int w_out = w_out_start; w_out <= w_out_end; - w_out += dilate_width) { - filter_offset--; - int s_h_out = h_out / stride_height; - int s_w_out = w_out / stride_width; - if (h_out % stride_height == 0 && w_out % stride_width == 0 && - s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && - s_w_out < output_width) { - int output_grad_offset = - ((batch * output_channels + c_out) * output_height + - s_h_out) * - output_width + - s_w_out; - value += output_grad_data[output_grad_offset] * - filter_data[filter_offset]; - } +#pragma unroll + for (int fh = 0; fh < fh_size; ++fh) { +#pragma unroll + for (int fw = 0; fw < fw_size; ++fw) { + int h_out = h_in + padding_height - fh * dilate_height; + int w_out = w_in + padding_width - fw * dilate_width; + if ((h_out - h_out / stride_height * stride_height == 0) && + (w_out - w_out / stride_width * stride_width == 0)) { + h_out /= stride_height; + w_out /= stride_width; + + if (h_out >= 0 && h_out < output_height && w_out >= 0 && + w_out < output_width) { + int output_grad_offset = + ((batch * output_channels + c_out) * output_height + h_out) * + output_width + + w_out; + value += output_grad_data[output_grad_offset] * + filter_data[filter_offset]; } } + filter_offset++; } - input_grad_data[index] = value; } } + input_grad_data[idx] = value; } template @@ -733,7 +730,7 @@ __global__ void KernelDepthwiseConvInputGradSp( if (c_filter_multiplier == 0 || c_filter == -1) { if (data_layout != DataLayout::kNHWC) { - KernelDepthwiseConvInputGradNCHW( + KernelDepthwiseConvInputGradNCHW( input_data, output_grad_data, filter_data, @@ -854,44 +851,81 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNCHW( const int dilate_height, const int dilate_width, T* filter_grad_data) { - T s(0); - int gbid = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; - - for (int image_w = threadIdx.x; image_w < output_width; - image_w += blockDim.x) { - for (int bid = 0; bid < num; bid++) { - for (int image_h = threadIdx.y; image_h < output_height; - image_h += blockDim.y) { - int kernel_id = blockIdx.z; - int kernel_h = blockIdx.y * dilate_height - padding_height; - int kernel_w = blockIdx.x * dilate_width - padding_width; - - int image_hk = image_h * stride_height + kernel_h; - int image_wk = image_w * stride_width + kernel_w; - if (image_hk < 0 || image_hk >= input_height) continue; - if (image_wk < 0 || image_wk >= input_width) continue; -#define gaid(N, C, H, W) \ - ((((N)*gridDim.z + (C)) * output_height + (H)) * output_width + (W)) - int input_id = ((bid * (gridDim.z / filter_multiplier) + - kernel_id / filter_multiplier) * - input_height + - image_hk) * - input_width + - image_wk; + T f_grad(0); + const bool loop_batch = output_height * output_width >= WARP_SIZE; + + int kw_id = blockIdx.x; + int kh_id = blockIdx.y; + int oc_id = blockIdx.z; + int ic_id = oc_id / filter_multiplier; + int idx = ((blockIdx.z * gridDim.y) + blockIdx.y) * gridDim.x + blockIdx.x; + + const int ohw = output_height * output_width; + const int onhw = num * ohw; + const int h_offset = kh_id * dilate_height - padding_height; + const int w_offset = kw_id * dilate_width - padding_width; + + if (loop_batch) { + for (int og_w = threadIdx.x; og_w < output_width; og_w += blockDim.x) { + for (int bid = 0; bid < num; ++bid) { + for (int og_h = threadIdx.y; og_h < output_height; og_h += blockDim.y) { + int i_h = og_h * stride_height + h_offset; + int i_w = og_w * stride_width + w_offset; + + if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) { + int input_offset = + ((bid * input_channels + ic_id) * input_height + i_h) * + input_width + + i_w; + int output_grad_offset = + ((bid * output_channels + oc_id) * output_height + og_h) * + output_width + + og_w; + if (fuse_relu_before_conv) { + f_grad += + output_grad_data[output_grad_offset] * + static_cast( + max(0.0f, static_cast(input_data[input_offset]))); + } else { + f_grad += output_grad_data[output_grad_offset] * + input_data[input_offset]; + } + } + } + } + } + } else { + for (int id = threadIdx.x; id < onhw; id += blockDim.x) { + int bid = id / ohw; + int og_hw = id - bid * ohw; + int og_h = og_hw / output_width; + int og_w = og_hw - og_h * output_width; + + int i_h = og_h * stride_height + h_offset; + int i_w = og_w * stride_width + w_offset; + + if (i_w >= 0 && i_w < input_width && i_h >= 0 && i_h < input_height) { + int input_offset = + ((bid * input_channels + ic_id) * input_height + i_h) * + input_width + + i_w; + int output_grad_offset = (bid * output_channels + oc_id) * ohw + og_hw; if (fuse_relu_before_conv) { - s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * - T(max(0.0f, double(input_data[input_id]))); + f_grad += output_grad_data[output_grad_offset] * + static_cast(max( + 0.0f, static_cast(input_data[input_offset]))); } else { - s += output_grad_data[gaid(bid, kernel_id, image_h, image_w)] * - input_data[input_id]; + f_grad += + output_grad_data[output_grad_offset] * input_data[input_offset]; } -#undef gaid } } } - T val = BlockReduceSum(s); - platform::CudaAtomicAdd(&filter_grad_data[gbid], val); + T val = BlockReduceSum(f_grad); + if (threadIdx.x == 0 && threadIdx.y == 0) { + filter_grad_data[idx] = val; + } } template @@ -941,7 +975,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradNHWC( kernel_id / filter_multiplier; if (fuse_relu_before_conv) { s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * - T(max(0.0f, double(input_data[input_id]))); + static_cast( + max(0.0f, static_cast(input_data[input_id]))); } else { s += output_grad_data[gaid(bid, image_h, image_w, kernel_id)] * input_data[input_id]; @@ -1013,7 +1048,8 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC( T s(0); if (fuse_relu_before_conv) { s = output_grad_data[output_id] * - T(max(0.0f, double(input_data[input_id]))); + static_cast( + max(0.0f, static_cast(input_data[input_id]))); } else { s = output_grad_data[output_id] * input_data[input_id]; } @@ -1242,8 +1278,7 @@ class DepthwiseConvFunctor { batch_size); } int filter_multiplier = output_channels / input_channels; - int nums_output = - batch_size * output_channels * output_height * output_width; + int nums_output = output->numel(); #ifdef __HIPCC__ int block_size = 256; #else @@ -1416,6 +1451,13 @@ class DepthwiseConvInputGradFunctor { batch_size); } int filter_multiplier = output_channels / input_channels; + int nums_input = input_grad->numel(); +#ifdef __HIPCC__ + int block_size = 256; +#else + int block_size = 512; +#endif + int grid_size = (nums_input + block_size - 1) / block_size; #define check_case(c_filter_multiplier, c_stride, c_filter) \ if (c_filter_multiplier == 0 || \ @@ -1424,6 +1466,11 @@ class DepthwiseConvInputGradFunctor { (ksize_height == ksize_width && ksize_height == c_filter || \ c_filter == -1)) { \ if (data_layout != DataLayout::kNHWC) { \ + if (c_filter == -1) { \ + threads.x = block_size; \ + grid.x = grid_size; \ + threads.y = threads.z = grid.y = grid.z = 1; \ + } \ KernelDepthwiseConvInputGradSp