diff --git a/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc index a8889d09aa757e..d69558bdc670da 100644 --- a/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/deformable_conv_grad_kernel.cc @@ -20,52 +20,52 @@ namespace phi { -template +template inline void ModulatedDeformableCol2imCPUKernel( - const int num_kernels, + const IndexT num_kernels, const T* data_col, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT channels, + const IndexT height, + const IndexT width, + const IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* grad_im) { - for (int thread = 0; thread < num_kernels; thread++) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = + for (IndexT thread = 0; thread < num_kernels; thread++) { + const IndexT j = (thread / width_col / height_col / batch_size) % kernel_w; + const IndexT i = (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = + const IndexT c = thread / width_col / height_col / batch_size / kernel_w / kernel_h; - const int deformable_group_index = c / channel_per_deformable_group; + const IndexT deformable_group_index = c / channel_per_deformable_group; - int w_out = thread % width_col; - int h_out = 
(thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; + IndexT w_out = thread % width_col; + IndexT h_out = (thread / width_col) % height_col; + IndexT b = (thread / width_col / height_col) % batch_size; + IndexT w_in = w_out * stride_w - pad_w; + IndexT h_in = h_out * stride_h - pad_h; const T* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = + const IndexT data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = + const IndexT data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const T offset_h = data_offset_ptr[data_offset_h_ptr]; const T offset_w = data_offset_ptr[data_offset_w_ptr]; @@ -80,14 +80,14 @@ inline void ModulatedDeformableCol2imCPUKernel( const T mask = data_mask_ptr[data_mask_hw_ptr]; cur_top_grad *= mask; } - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { + const IndexT cur_h = static_cast(cur_inv_h_data); + const IndexT cur_w = static_cast(cur_inv_w_data); + for (IndexT dy = -2; dy <= 2; dy++) { + for (IndexT dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = + IndexT cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; T weight = DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, @@ -104,7 +104,7 @@ inline void ModulatedDeformableCol2imCPUKernel( } } -template 
+template void ModulatedDeformableCol2im(const Context& dev_ctx, const T* data_col, const T* data_offset, @@ -117,70 +117,69 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, const std::vector& dilation, const int deformable_group, T* grad_im) { - int channel_per_deformable_group = - static_cast(im_shape[0] / deformable_group); - int num_kernels = static_cast(col_shape[0] * col_shape[1] * - col_shape[2] * col_shape[3]); + int64_t channel_per_deformable_group = im_shape[0] / deformable_group; + int64_t num_kernels = + col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - ModulatedDeformableCol2imCPUKernel(num_kernels, - data_col, - data_offset, - data_mask, - im_shape[0], - im_shape[1], - im_shape[2], - kernel_shape[2], - kernel_shape[3], - pad[0], - pad[1], - stride[0], - stride[1], - dilation[0], - dilation[1], - channel_per_deformable_group, - col_shape[1], - deformable_group, - col_shape[2], - col_shape[3], - grad_im); + ModulatedDeformableCol2imCPUKernel(num_kernels, + data_col, + data_offset, + data_mask, + im_shape[0], + im_shape[1], + im_shape[2], + kernel_shape[2], + kernel_shape[3], + pad[0], + pad[1], + stride[0], + stride[1], + dilation[0], + dilation[1], + channel_per_deformable_group, + col_shape[1], + deformable_group, + col_shape[2], + col_shape[3], + grad_im); } -template +template void ModulatedDeformableCol2imCoordCPUKernel( - const int num_kernels, + const IndexT num_kernels, const T* data_col, const T* data_im, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int offset_channels, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT channels, + const IndexT height, + const IndexT width, + const 
IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT offset_channels, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* grad_offset, T* grad_mask) { - for (int i = 0; i < num_kernels; i++) { + for (IndexT i = 0; i < num_kernels; i++) { T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; + const IndexT w = i % width_col; + const IndexT h = (i / width_col) % height_col; + const IndexT c = (i / width_col / height_col) % offset_channels; + const IndexT b = (i / width_col / height_col) / offset_channels; - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; + const IndexT deformable_group_index = c / (2 * kernel_h * kernel_w); + const IndexT col_step = kernel_h * kernel_w; + IndexT cnt = 0; const T* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; @@ -197,24 +196,25 @@ void ModulatedDeformableCol2imCoordCPUKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + const IndexT offset_c = + c - deformable_group_index * 2 * kernel_h * kernel_w; - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + for (IndexT col_c = offset_c / 2; col_c < channel_per_deformable_group; col_c += col_step) { - const int col_pos = + const IndexT col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + const IndexT bp_dir = offset_c % 2; - int j = (col_pos / width_col / 
height_col / batch_size) % kernel_w; - int i = + IndexT j = (col_pos / width_col / height_col / batch_size) % kernel_w; + IndexT i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = + IndexT w_out = col_pos % width_col; + IndexT h_out = (col_pos / width_col) % height_col; + IndexT w_in = w_out * stride_w - pad_w; + IndexT h_in = h_out * stride_h - pad_h; + const IndexT data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const T offset_h = data_offset_ptr[data_offset_h_ptr]; @@ -241,7 +241,7 @@ void ModulatedDeformableCol2imCoordCPUKernel( width, bp_dir); if (data_mask_ptr) { - const int data_mask_hw_ptr = + const IndexT data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const T mask = data_mask_ptr[data_mask_hw_ptr]; val += weight * data_col_ptr[col_pos] * mask; @@ -262,7 +262,7 @@ void ModulatedDeformableCol2imCoordCPUKernel( } } -template +template void ModulatedDeformableCol2imCoord(const Context& dev_ctx, const T* data_col, const T* data_im, @@ -277,13 +277,11 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, const int deformable_groups, T* grad_offset, T* grad_mask) { - int num_kernels = - static_cast(2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups); - int channel_per_deformable_group = - static_cast(col_shape[0] / deformable_groups); + int64_t num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int64_t channel_per_deformable_group = col_shape[0] / deformable_groups; - 
ModulatedDeformableCol2imCoordCPUKernel( + ModulatedDeformableCol2imCoordCPUKernel( num_kernels, data_col, data_im, @@ -310,15 +308,15 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, grad_mask); } -template +template void FilterGradAddup(const Context& dev_ctx, - const int nthreads, - const int n, - const int height, - const int width, + const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad) { - for (int i = 0; i < nthreads; i++) { + for (IndexT i = 0; i < nthreads; i++) { filter_grad[i] = filter_grad[i] + dweight_3d[i]; } } diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cc b/paddle/phi/kernels/funcs/deformable_conv_functor.cc index e028b51e3ce7b9..879c3b3a1ddc9d 100644 --- a/paddle/phi/kernels/funcs/deformable_conv_functor.cc +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cc @@ -18,40 +18,40 @@ namespace phi::funcs { -template +template inline void ModulatedDeformableIm2colCPUKernel( - const int num_kernels, + const IndexT num_kernels, const T* data_im, const T* data_offset, const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT height, + const IndexT width, + const IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT num_channels, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* data_col) { - for (int i = 0; i < num_kernels; i++) { - const 
int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + for (IndexT i = 0; i < num_kernels; i++) { + const IndexT w_col = i % width_col; + const IndexT h_col = (i / width_col) % height_col; + const IndexT b_col = (i / width_col) / height_col % batch_size; + const IndexT c_im = (i / width_col / height_col) / batch_size; + const IndexT c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const IndexT deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const IndexT h_in = h_col * stride_h - pad_h; + const IndexT w_in = w_col * stride_w - pad_w; T* data_col_ptr = data_col + @@ -67,11 +67,11 @@ inline void ModulatedDeformableIm2colCPUKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (IndexT i = 0; i < kernel_h; ++i) { + for (IndexT j = 0; j < kernel_w; ++j) { + const IndexT data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; @@ -97,7 +97,7 @@ inline void ModulatedDeformableIm2colCPUKernel( } } -template +template void ModulatedDeformableIm2col(const Context& dev_ctx UNUSED, const T* data_im, const T* data_offset, @@ -110,36 +110,35 @@ void ModulatedDeformableIm2col(const Context& dev_ctx UNUSED, const std::vector& dilations, const int deformable_groups, T* data_col) { - int channel_per_deformable_group = - static_cast(im_shape[0] / deformable_groups); - int num_kernels = static_cast(im_shape[0] * col_shape[1] * 
col_shape[2] * - col_shape[3]); + int64_t channel_per_deformable_group = im_shape[0] / deformable_groups; + int64_t num_kernels = + im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; // get outputs of im2col with offset by bilinear interpolation - ModulatedDeformableIm2colCPUKernel(num_kernels, - data_im, - data_offset, - data_mask, - im_shape[1], - im_shape[2], - filter_shape[2], - filter_shape[3], - paddings[0], - paddings[1], - strides[0], - strides[1], - dilations[0], - dilations[1], - channel_per_deformable_group, - col_shape[1], - im_shape[0], - deformable_groups, - col_shape[2], - col_shape[3], - data_col); + ModulatedDeformableIm2colCPUKernel(num_kernels, + data_im, + data_offset, + data_mask, + im_shape[1], + im_shape[2], + filter_shape[2], + filter_shape[3], + paddings[0], + paddings[1], + strides[0], + strides[1], + dilations[0], + dilations[1], + channel_per_deformable_group, + col_shape[1], + im_shape[0], + deformable_groups, + col_shape[2], + col_shape[3], + data_col); } -template void ModulatedDeformableIm2col( +template void ModulatedDeformableIm2col( const phi::CPUContext& dev_ctx, const float* data_im, const float* data_offset, @@ -153,7 +152,35 @@ template void ModulatedDeformableIm2col( const int deformable_groups, float* data_col); -template void ModulatedDeformableIm2col( +template void ModulatedDeformableIm2col( + const phi::CPUContext& dev_ctx, + const float* data_im, + const float* data_offset, + const float* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + float* data_col); + +template void ModulatedDeformableIm2col( + const phi::CPUContext& dev_ctx, + const double* data_im, + const double* data_offset, + const double* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& 
paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + double* data_col); + +template void ModulatedDeformableIm2col( const phi::CPUContext& dev_ctx, const double* data_im, const double* data_offset, diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.cu b/paddle/phi/kernels/funcs/deformable_conv_functor.cu index 48105d1f517e9b..47ffc5fb31d055 100644 --- a/paddle/phi/kernels/funcs/deformable_conv_functor.cu +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.cu @@ -21,47 +21,47 @@ namespace funcs { static constexpr int kNumCUDAThreads = 512; static constexpr int kNumMaximumNumBlocks = 4096; -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaximumNumBlocks); +static inline int64_t NumBlocks(const int64_t N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks); } -template +template __global__ void ModulatedDeformableIm2colGpuKernel( - const int nthreads, + const IndexT nthreads, const T* data_im, const T* data_offset, const T* data_mask, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int num_channels, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT height, + const IndexT width, + const IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT num_channels, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* data_col) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int 
offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { - const int w_col = i % width_col; - const int h_col = (i / width_col) % height_col; - const int b_col = (i / width_col) / height_col % batch_size; - const int c_im = (i / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; + IndexT index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + IndexT offset = blockDim.x * static_cast(gridDim.x); + for (IndexT i = index; i < nthreads; i += offset) { + const IndexT w_col = i % width_col; + const IndexT h_col = (i / width_col) % height_col; + const IndexT b_col = (i / width_col) / height_col % batch_size; + const IndexT c_im = (i / width_col / height_col) / batch_size; + const IndexT c_col = c_im * kernel_h * kernel_w; - const int deformable_group_index = c_im / channel_per_deformable_group; + const IndexT deformable_group_index = c_im / channel_per_deformable_group; - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; + const IndexT h_in = h_col * stride_h - pad_h; + const IndexT w_in = w_col * stride_w - pad_w; T* data_col_ptr = data_col + @@ -77,11 +77,11 @@ __global__ void ModulatedDeformableIm2colGpuKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - for (int i = 0; i < kernel_h; ++i) { - for (int j = 0; j < kernel_w; ++j) { - const int data_offset_h_ptr = + for (IndexT i = 0; i < kernel_h; ++i) { + for (IndexT j = 0; j < kernel_w; ++j) { + const IndexT data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; @@ -96,7 +96,7 @@ __global__ void ModulatedDeformableIm2colGpuKernel( } *data_col_ptr = val; if (data_mask_ptr) { - const int data_mask_hw_ptr = + const IndexT data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; const T mask = 
data_mask_ptr[data_mask_hw_ptr]; *data_col_ptr *= mask; @@ -107,7 +107,7 @@ __global__ void ModulatedDeformableIm2colGpuKernel( } } -template +template void ModulatedDeformableIm2col(const Context& dev_ctx, const T* data_im, const T* data_offset, @@ -120,13 +120,13 @@ void ModulatedDeformableIm2col(const Context& dev_ctx, const std::vector& dilations, const int deformable_groups, T* data_col) { - int channel_per_deformable_group = im_shape[0] / deformable_groups; - int num_kernels = im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int64_t channel_per_deformable_group = im_shape[0] / deformable_groups; + int64_t num_kernels = + im_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableIm2colGpuKernel + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; + ModulatedDeformableIm2colGpuKernel <<>>(num_kernels, data_im, data_offset, @@ -150,7 +150,21 @@ void ModulatedDeformableIm2col(const Context& dev_ctx, data_col); } -template void ModulatedDeformableIm2col( +template void ModulatedDeformableIm2col( + const phi::GPUContext& dev_ctx, + const float* data_im, + const float* data_offset, + const float* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + float* data_col); + +template void ModulatedDeformableIm2col( const phi::GPUContext& dev_ctx, const float* data_im, const float* data_offset, @@ -164,7 +178,21 @@ template void ModulatedDeformableIm2col( const int deformable_groups, float* data_col); -template void ModulatedDeformableIm2col( +template void ModulatedDeformableIm2col( + const phi::GPUContext& dev_ctx, + const double* data_im, + const double* data_offset, + const double* data_mask, + const std::vector& im_shape, + const std::vector& col_shape, + 
const std::vector& filter_shape, + const std::vector& paddings, + const std::vector& strides, + const std::vector& dilations, + const int deformable_groups, + double* data_col); + +template void ModulatedDeformableIm2col( const phi::GPUContext& dev_ctx, const double* data_im, const double* data_offset, diff --git a/paddle/phi/kernels/funcs/deformable_conv_functor.h b/paddle/phi/kernels/funcs/deformable_conv_functor.h index eecda72927510d..5f1296d12b425d 100644 --- a/paddle/phi/kernels/funcs/deformable_conv_functor.h +++ b/paddle/phi/kernels/funcs/deformable_conv_functor.h @@ -56,7 +56,7 @@ HOSTDEVICE T DmcnIm2colBilinear(const T* bottom_data, return w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; } -template +template void ModulatedDeformableIm2col(const Context& dev_ctx, const T* data_im, const T* data_offset, diff --git a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu index 3f45fefade40ea..6596da3a4ed142 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_bias_dropout_residual_layer_norm_kernel.cu @@ -62,7 +62,6 @@ void FusedBiasDropoutResidualLnKernel( dropout_mask_out, dropout_mask_out->numel() * sizeof(uint8_t)); auto* y_data = dev_ctx.template Alloc(y, y->numel() * sizeof(T)); if (y->numel() == 0) return; - const auto input_x_dims = x.dims(); int bsz_seq = 1; for (int i = 0; i < input_x_dims.size() - 1; i++) { diff --git a/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu index 55c8a9f96fd818..f38747bfbac784 100644 --- a/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/deformable_conv_grad_kernel.cu @@ -29,54 +29,54 @@ static inline int NumBlocks(const int N) { kNumMaximumNumBlocks); } -template +template __global__ void ModulatedDeformableCol2imGpuKernel( - const int nthreads, + 
const IndexT nthreads, const T* data_col, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT channels, + const IndexT height, + const IndexT width, + const IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* grad_im) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t thread = index; thread < nthreads; thread += offset) { - const int j = (thread / width_col / height_col / batch_size) % kernel_w; - const int i = + IndexT index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + IndexT offset = blockDim.x * static_cast(gridDim.x); + for (IndexT thread = index; thread < nthreads; thread += offset) { + const IndexT j = (thread / width_col / height_col / batch_size) % kernel_w; + const IndexT i = (thread / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = + const IndexT c = thread / width_col / height_col / batch_size / kernel_w / kernel_h; - const int deformable_group_index = c / channel_per_deformable_group; + const IndexT deformable_group_index = c / channel_per_deformable_group; - int w_out = thread % width_col; - int h_out = (thread / width_col) % height_col; - int b = (thread / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; + IndexT w_out = 
thread % width_col; + IndexT h_out = (thread / width_col) % height_col; + IndexT b = (thread / width_col / height_col) % batch_size; + IndexT w_in = w_out * stride_w - pad_w; + IndexT h_in = h_out * stride_h - pad_h; const T* data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = + const IndexT data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = + const IndexT data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; const T offset_h = data_offset_ptr[data_offset_h_ptr]; const T offset_w = data_offset_ptr[data_offset_w_ptr]; @@ -91,14 +91,14 @@ __global__ void ModulatedDeformableCol2imGpuKernel( const T mask = data_mask_ptr[data_mask_hw_ptr]; cur_top_grad *= mask; } - const int cur_h = static_cast(cur_inv_h_data); - const int cur_w = static_cast(cur_inv_w_data); - for (int dy = -2; dy <= 2; dy++) { - for (int dx = -2; dx <= 2; dx++) { + const IndexT cur_h = static_cast(cur_inv_h_data); + const IndexT cur_w = static_cast(cur_inv_w_data); + for (IndexT dy = -2; dy <= 2; dy++) { + for (IndexT dx = -2; dx <= 2; dx++) { if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && abs(cur_inv_w_data - (cur_w + dx)) < 1) { - int cur_bottom_grad_pos = + IndexT cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; T weight = DmcnGetGradientWeight(cur_inv_h_data, cur_inv_w_data, @@ -115,7 +115,7 @@ __global__ void ModulatedDeformableCol2imGpuKernel( } } -template +template void ModulatedDeformableCol2im(const Context& dev_ctx, const T* data_col, const T* data_offset, @@ -128,12 +128,12 @@ void ModulatedDeformableCol2im(const Context& 
dev_ctx, const std::vector& dilation, const int deformable_group, T* grad_im) { - int channel_per_deformable_group = im_shape[0] / deformable_group; - int num_kernels = col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; - - ModulatedDeformableCol2imGpuKernel + int64_t channel_per_deformable_group = im_shape[0] / deformable_group; + int64_t num_kernels = + col_shape[0] * col_shape[1] * col_shape[2] * col_shape[3]; + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; + ModulatedDeformableCol2imGpuKernel <<>>(num_kernels, data_col, data_offset, @@ -157,44 +157,44 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, grad_im); } -template +template __global__ void ModulatedDeformableCol2imCoordGpuKernel( - const int nthreads, + const IndexT nthreads, const T* data_col, const T* data_im, const T* data_offset, const T* data_mask, - const int channels, - const int height, - const int width, - const int kernel_h, - const int kernel_w, - const int pad_h, - const int pad_w, - const int stride_h, - const int stride_w, - const int dilation_h, - const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, - const int offset_channels, - const int deformable_group, - const int height_col, - const int width_col, + const IndexT channels, + const IndexT height, + const IndexT width, + const IndexT kernel_h, + const IndexT kernel_w, + const IndexT pad_h, + const IndexT pad_w, + const IndexT stride_h, + const IndexT stride_w, + const IndexT dilation_h, + const IndexT dilation_w, + const IndexT channel_per_deformable_group, + const IndexT batch_size, + const IndexT offset_channels, + const IndexT deformable_group, + const IndexT height_col, + const IndexT width_col, T* grad_offset, T* grad_mask) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int offset = blockDim.x * gridDim.x; - for (size_t i = index; i < nthreads; i += offset) { + IndexT 
index = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + IndexT offset = blockDim.x * static_cast(gridDim.x); + for (IndexT i = index; i < nthreads; i += offset) { T val = 0, mval = 0; - const int w = i % width_col; - const int h = (i / width_col) % height_col; - const int c = (i / width_col / height_col) % offset_channels; - const int b = (i / width_col / height_col) / offset_channels; + const IndexT w = i % width_col; + const IndexT h = (i / width_col) % height_col; + const IndexT c = (i / width_col / height_col) % offset_channels; + const IndexT b = (i / width_col / height_col) / offset_channels; - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; + const IndexT deformable_group_index = c / (2 * kernel_h * kernel_w); + const IndexT col_step = kernel_h * kernel_w; + IndexT cnt = 0; const T* data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; @@ -211,24 +211,25 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( kernel_h * kernel_w * height_col * width_col : nullptr; - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + const IndexT offset_c = + c - deformable_group_index * 2 * kernel_h * kernel_w; - for (int col_c = offset_c / 2; col_c < channel_per_deformable_group; + for (IndexT col_c = offset_c / 2; col_c < channel_per_deformable_group; col_c += col_step) { - const int col_pos = + const IndexT col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; + const IndexT bp_dir = offset_c % 2; - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = + IndexT j = (col_pos / width_col / height_col / batch_size) % kernel_w; + IndexT i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - 
pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = + IndexT w_out = col_pos % width_col; + IndexT h_out = (col_pos / width_col) % height_col; + IndexT w_in = w_out * stride_w - pad_w; + IndexT h_in = h_out * stride_h - pad_h; + const IndexT data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = + const IndexT data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); const T offset_h = data_offset_ptr[data_offset_h_ptr]; @@ -255,7 +256,7 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( width, bp_dir); if (data_mask_ptr) { - const int data_mask_hw_ptr = + const IndexT data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); const T mask = data_mask_ptr[data_mask_hw_ptr]; val += weight * data_col_ptr[col_pos] * mask; @@ -276,7 +277,7 @@ __global__ void ModulatedDeformableCol2imCoordGpuKernel( } } -template +template void ModulatedDeformableCol2imCoord(const Context& dev_ctx, const T* data_col, const T* data_im, @@ -291,13 +292,13 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, const int deformable_groups, T* grad_offset, T* grad_mask) { - int num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * - col_shape[2] * col_shape[3] * deformable_groups; - int channel_per_deformable_group = col_shape[0] / deformable_groups; - int blocks = NumBlocks(num_kernels); - int threads = kNumCUDAThreads; + int64_t num_kernels = 2 * kernel_shape[2] * kernel_shape[3] * col_shape[1] * + col_shape[2] * col_shape[3] * deformable_groups; + int64_t channel_per_deformable_group = col_shape[0] / deformable_groups; + int64_t blocks = NumBlocks(num_kernels); + int64_t threads = kNumCUDAThreads; - ModulatedDeformableCol2imCoordGpuKernel + ModulatedDeformableCol2imCoordGpuKernel <<>>( num_kernels, data_col, @@ -325,30 +326,34 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx, grad_mask); } 
-template <typename T>
-__global__ void FilterGradAddupGpuKernel(const int nthreads,
-                                         const int n,
-                                         const int height,
-                                         const int width,
+template <typename T, typename IndexT>
+__global__ void FilterGradAddupGpuKernel(const IndexT nthreads,
+                                         const IndexT n,
+                                         const IndexT height,
+                                         const IndexT width,
                                          const T* dweight_3d,
                                          T* filter_grad) {
-  int index = blockIdx.x * blockDim.x + threadIdx.x;
-  int offset = blockDim.x * gridDim.x;
-  for (size_t i = index; i < nthreads; i += offset) {
+  IndexT index = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
+  IndexT offset = blockDim.x * static_cast<IndexT>(gridDim.x);
+  for (IndexT i = index; i < nthreads; i += offset) {
     filter_grad[i] = filter_grad[i] + dweight_3d[i];
   }
 }
 
-template <typename T, typename Context>
+template <typename T, typename Context, typename IndexT>
 void FilterGradAddup(const Context& dev_ctx,
-                     const int nthreads,
-                     const int n,
-                     const int height,
-                     const int width,
+                     const int64_t nthreads,
+                     const int64_t n,
+                     const int64_t height,
+                     const int64_t width,
                      const T* dweight_3d,
                      T* filter_grad) {
-  FilterGradAddupGpuKernel<T>
-      <<<NumBlocks(nthreads), kNumCUDAThreads, 0, dev_ctx.stream()>>>(
+  const int64_t max_grid_x = dev_ctx.GetCUDAMaxGridDimSize()[0];
+  const int64_t grid_size = std::min(
+      (nthreads + kNumCUDAThreads - 1) / kNumCUDAThreads, max_grid_x);
+
+  FilterGradAddupGpuKernel<T, IndexT>
+      <<<grid_size, kNumCUDAThreads, 0, dev_ctx.stream()>>>(
           nthreads, n, height, width, dweight_3d, filter_grad);
 }
 
diff --git a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h
index fe2107e52af7f6..4459a931dafd78 100644
--- a/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h
+++ b/paddle/phi/kernels/impl/deformable_conv_grad_kernel_impl.h
@@ -119,7 +119,7 @@ HOSTDEVICE T DmcnGetCoordinateWeight(T argmax_h,
   return weight;
 }
 
-template <typename T, typename Context>
+template <typename T, typename Context, typename IndexT>
 void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
                                     const T* data_col,
                                     const T* data_im,
@@ -135,7 +135,7 @@ void ModulatedDeformableCol2imCoord(const Context& dev_ctx,
                                     T* grad_offset,
                                     T* grad_mask);
 
-template <typename T, typename Context>
+template <typename T, typename Context, typename IndexT>
 void ModulatedDeformableCol2im(const Context& dev_ctx,
                                const T* data_col,
                                const T* data_offset,
@@ -149,12 
+149,12 @@ void ModulatedDeformableCol2im(const Context& dev_ctx, const int deformable_group, T* grad_im); -template +template void FilterGradAddup(const Context& dev_ctx, - const int nthreads, - const int n, - const int height, - const int width, + const int64_t nthreads, + const int64_t n, + const int64_t height, + const int64_t width, const T* dweight_3d, T* filter_grad); @@ -241,9 +241,9 @@ void DeformableConvGradKernel(const Context& dev_ctx, phi::funcs::SetConstant set_zero; auto blas = phi::funcs::GetBlas(dev_ctx); - int input_dim = x.numel() / x.dims()[0]; - int input_offset_dim = offset.numel() / offset.dims()[0]; - int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; + int64_t input_dim = x.numel() / x.dims()[0]; + int64_t input_offset_dim = offset.numel() / offset.dims()[0]; + int64_t input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0; if (filter_grad) { Full(dev_ctx, @@ -267,6 +267,13 @@ void DeformableConvGradKernel(const Context& dev_ctx, } } + bool using_int32_index = + (x.numel() <= std::numeric_limits::max()) && + (offset.numel() <= std::numeric_limits::max()) && + (filter.numel() <= std::numeric_limits::max()) && + (mask ? mask->numel() <= std::numeric_limits::max() : true) && + (out_grad.numel() <= std::numeric_limits::max()); + for (int i = 0; i < batch_size / im2col_step; ++i) { DenseTensor out_grad_3d = out_grad_4d.Slice(i, i + 1).Resize( common::slice_ddim(out_grad_4d.dims(), 1, out_grad_4d.dims().size())); @@ -299,9 +306,78 @@ void DeformableConvGradKernel(const Context& dev_ctx, mask_grad ? 
mask_grad->data() + i * im2col_step * input_mask_dim : nullptr; // get grad of offset and mask - ModulatedDeformableCol2imCoord( + if (using_int32_index) { + ModulatedDeformableCol2imCoord( + dev_ctx, + col_buffer_ptr, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_data_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + offset_grad_ptr + i * im2col_step * input_offset_dim, + mask_grad_data_ptr); + } else { + ModulatedDeformableCol2imCoord( + dev_ctx, + col_buffer_ptr, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_data_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + offset_grad_ptr + i * im2col_step * input_offset_dim, + mask_grad_data_ptr); + } + } + if (dx) { + T* dx_ptr = dx->data(); + // get grad of input + if (using_int32_index) { + ModulatedDeformableCol2im( + dev_ctx, + col_buffer_ptr, + offset_ptr + i * im2col_step * input_offset_dim, + mask_data_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + dx_ptr + i * im2col_step * input_dim); + } else { + ModulatedDeformableCol2im( + dev_ctx, + col_buffer_ptr, + offset_ptr + i * im2col_step * input_offset_dim, + mask_data_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + dx_ptr + i * im2col_step * input_dim); + } + + dx->Resize(x.dims()); + } + if (using_int32_index) { + funcs::ModulatedDeformableIm2col( dev_ctx, - col_buffer_ptr, input_ptr + i * im2col_step * input_dim, offset_ptr + i * im2col_step * input_offset_dim, mask_data_ptr, @@ -312,41 +388,23 @@ void DeformableConvGradKernel(const Context& dev_ctx, strides, dilations, deformable_groups, - offset_grad_ptr + i * im2col_step * input_offset_dim, - 
mask_grad_data_ptr); - } - if (dx) { - T* dx_ptr = dx->data(); - // get grad of input - ModulatedDeformableCol2im(dev_ctx, - col_buffer_ptr, - offset_ptr + i * im2col_step * input_offset_dim, - mask_data_ptr, - input_shape_vec, - col_buffer_shape_vec, - filter_shape_vec, - paddings, - strides, - dilations, - deformable_groups, - dx_ptr + i * im2col_step * input_dim); - dx->Resize(x.dims()); + col_buffer_ptr); + } else { + funcs::ModulatedDeformableIm2col( + dev_ctx, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + mask_data_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + col_buffer_ptr); } - funcs::ModulatedDeformableIm2col( - dev_ctx, - input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - mask_data_ptr, - input_shape_vec, - col_buffer_shape_vec, - filter_shape_vec, - paddings, - strides, - dilations, - deformable_groups, - col_buffer_ptr); - col_buffer_3d.Resize(col_buffer_3d_shape); if (filter_grad) { @@ -372,13 +430,23 @@ void DeformableConvGradKernel(const Context& dev_ctx, } // update grad of weights - FilterGradAddup(dev_ctx, - dweight_3d.numel(), - groups, - K, - M, - dweight_3d.data(), - filter_grad->data()); + if (using_int32_index) { + FilterGradAddup(dev_ctx, + dweight_3d.numel(), + groups, + K, + M, + dweight_3d.data(), + filter_grad->data()); + } else { + FilterGradAddup(dev_ctx, + dweight_3d.numel(), + groups, + K, + M, + dweight_3d.data(), + filter_grad->data()); + } } } if (filter_grad) { diff --git a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h index 8fcf4bf0f38700..ad9e9197ddd8a8 100644 --- a/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h +++ b/paddle/phi/kernels/impl/deformable_conv_kernel_impl.h @@ -43,9 +43,10 @@ void DeformableConvKernel(const Context& dev_ctx, dev_ctx, 
phi::IntArray(common::vectorize(out->dims())), 0, out);
     return;
   }
-  const int batch_size = static_cast<int>(x.dims()[0]);
-  int temp_step = std::min(64, batch_size);
+  const int64_t batch_size = static_cast<int64_t>(x.dims()[0]);
+
+  int64_t temp_step = std::min<int64_t>(64, batch_size);
   if (batch_size % temp_step == 0) {
     im2col_step = temp_step;
   }
@@ -86,9 +87,9 @@ void DeformableConvKernel(const Context& dev_ctx,
   DDim input_shape = common::slice_ddim(x.dims(), 1, x.dims().size());
   std::vector<int64_t> input_shape_vec = common::vectorize(input_shape);
 
-  int input_dim = x.numel() / x.dims()[0];
-  int input_offset_dim = offset.numel() / offset.dims()[0];
-  int input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0;
+  int64_t input_dim = x.numel() / x.dims()[0];
+  int64_t input_offset_dim = offset.numel() / offset.dims()[0];
+  int64_t input_mask_dim = mask ? mask->numel() / mask->dims()[0] : 0;
 
   const T* input_ptr = x.data<T>();
   const T* offset_ptr = offset.data<T>();
@@ -97,22 +98,46 @@ void DeformableConvKernel(const Context& dev_ctx,
 
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
 
-  for (int i = 0; i < batch_size / im2col_step; ++i) {
+  bool using_int32_index =
+      (x.numel() <= std::numeric_limits<int32_t>::max()) &&
+      (offset.numel() <= std::numeric_limits<int32_t>::max()) &&
+      (filter.numel() <= std::numeric_limits<int32_t>::max()) &&
+      (mask ? mask->numel() <= std::numeric_limits<int32_t>::max() : true) &&
+      (out->numel() <= std::numeric_limits<int32_t>::max());
+
+  for (int64_t i = 0; i < batch_size / im2col_step; ++i) {
     const T* temp_mask_ptr =
         mask_ptr ?
mask_ptr + i * im2col_step * input_mask_dim : nullptr; - funcs::ModulatedDeformableIm2col( - dev_ctx, - input_ptr + i * im2col_step * input_dim, - offset_ptr + i * im2col_step * input_offset_dim, - temp_mask_ptr, - input_shape_vec, - col_buffer_shape_vec, - filter_shape_vec, - paddings, - strides, - dilations, - deformable_groups, - col_buffer_ptr); + if (using_int32_index) { + funcs::ModulatedDeformableIm2col( + dev_ctx, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + temp_mask_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + col_buffer_ptr); + } else { + funcs::ModulatedDeformableIm2col( + dev_ctx, + input_ptr + i * im2col_step * input_dim, + offset_ptr + i * im2col_step * input_offset_dim, + temp_mask_ptr, + input_shape_vec, + col_buffer_shape_vec, + filter_shape_vec, + paddings, + strides, + dilations, + deformable_groups, + col_buffer_ptr); + } + DenseTensor output_3d = output_4d.Slice(i, i + 1).Resize(common::slice_ddim( output_4d.dims(), 1,