use CUDA_1D_KERNEL_LOOP
zhangshilong committed Apr 23, 2021
1 parent 3372178 · commit a06add4
Showing 1 changed file with 11 additions and 14 deletions.
mmcv/ops/csrc/ms_deform_attn_cuda_kernel.cuh (25 changes: 11 additions, 14 deletions)
@@ -8,15 +8,12 @@
  *https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
  **************************************************************************************************
  */
-#ifndef OPENMMLAB_PRS_ADD_MSD_ATTEN_MMCV_OPS_CSRC_MS_DEFORM_ATTN_CUDA_KERNEL_CUH_
-#define OPENMMLAB_PRS_ADD_MSD_ATTEN_MMCV_OPS_CSRC_MS_DEFORM_ATTN_CUDA_KERNEL_CUH_
+#ifndef DEFORM_ATTN_CUDA_KERNEL
+#define DEFORM_ATTN_CUDA_KERNEL
 
+#include "common_cuda_helper.hpp"
 #include "pytorch_cuda_helper.hpp"
 
-#define CUDA_KERNEL_LOOP(i, n)                                 \
-  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
-       i += blockDim.x * gridDim.x)
-
 const int CUDA_NUM_THREADS = 1024;
 inline int GET_BLOCKS(const int N, const int num_threads) {
   return (N + num_threads - 1) / num_threads;
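For context: the removed CUDA_KERNEL_LOOP macro and its replacement both implement the standard grid-stride loop, and MMCV's common_cuda_helper.hpp provides the same body under the name CUDA_1D_KERNEL_LOOP, so the commit deduplicates the macro rather than changing behavior. Below is a minimal, self-contained sketch of the pattern; scale_kernel and the host code are illustrative and not part of this file.

// Minimal sketch of the grid-stride loop. The macro body is the one removed
// above; in MMCV it now comes from common_cuda_helper.hpp as
// CUDA_1D_KERNEL_LOOP. scale_kernel and main are illustrative only.
#include <cuda_runtime.h>

#define CUDA_1D_KERNEL_LOOP(i, n)                              \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
       i += blockDim.x * gridDim.x)

const int CUDA_NUM_THREADS = 1024;
inline int GET_BLOCKS(const int N, const int num_threads) {
  return (N + num_threads - 1) / num_threads;
}

// Each thread processes i, i + blockDim.x * gridDim.x, ..., so the kernel
// stays correct even when the grid covers fewer threads than n.
__global__ void scale_kernel(const int n, const float alpha, float *x) {
  CUDA_1D_KERNEL_LOOP(i, n) { x[i] *= alpha; }
}

int main() {
  const int n = 1 << 20;
  float *x = nullptr;
  cudaMalloc(&x, n * sizeof(float));
  cudaMemset(x, 0, n * sizeof(float));
  scale_kernel<<<GET_BLOCKS(n, CUDA_NUM_THREADS), CUDA_NUM_THREADS>>>(n, 2.f, x);
  cudaDeviceSynchronize();
  cudaFree(x);
  return 0;
}

The grid-stride form also lets callers cap the grid size (for example, to stay under hardware limits) without touching kernel code.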
@@ -213,7 +210,7 @@ __global__ void ms_deformable_im2col_gpu_kernel(
     const int spatial_size, const int num_heads, const int channels,
     const int num_levels, const int num_query, const int num_point,
     scalar_t *data_col) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
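The loop body visible above peels a flat index apart with repeated modulo and division, fastest-varying dimension first. A hedged sketch of that decoding, with dimension names drawn from the surrounding kernel but simplified for illustration:

// Sketch of the flat-index decoding inside the kernel loop: index ranges
// over batch * num_query * num_heads * channels, and each coordinate is
// recovered by mod/div, innermost (channels) first. Simplified illustration,
// not copied verbatim from the kernel.
__host__ __device__ inline void decode_index(int index, const int channels,
                                             const int num_heads,
                                             const int num_query, int &b_col,
                                             int &q_col, int &m_col,
                                             int &c_col) {
  int _temp = index;
  c_col = _temp % channels;   // channel within a head (fastest-varying)
  _temp /= channels;
  m_col = _temp % num_heads;  // attention head
  _temp /= num_heads;
  q_col = _temp % num_query;  // query position
  _temp /= num_query;
  b_col = _temp;              // batch element (slowest-varying)
}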
@@ -271,7 +268,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
     __shared__ scalar_t cache_grad_attn_weight[blockSize];
     unsigned int tid = threadIdx.x;
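The _shm_blocksize_aware variants take the block size as a compile-time template parameter, so the gradient caches above can be statically sized shared arrays reduced within one block. A minimal sketch of that reduction pattern on a toy payload; block_sum_kernel is illustrative, not the MMCV kernel, and assumes blockSize is a power of two:

// Toy blocksize-aware shared-memory reduction: blockSize is a template
// parameter, so the cache is statically sized and the compiler can unroll
// the reduction loop.
template <typename scalar_t, unsigned int blockSize>
__global__ void block_sum_kernel(const scalar_t *in, scalar_t *out,
                                 const int n) {
  __shared__ scalar_t cache[blockSize];
  const unsigned int tid = threadIdx.x;
  const int i = blockIdx.x * blockSize + tid;
  cache[tid] = (i < n) ? in[i] : scalar_t(0);
  __syncthreads();
  // Tree reduction: halve the number of active threads each step.
  for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
    if (tid < s) cache[tid] += cache[tid + s];
    __syncthreads();
  }
  if (tid == 0) out[blockIdx.x] = cache[0];  // one partial sum per block
}

// Launch, e.g.: block_sum_kernel<float, 256><<<GET_BLOCKS(n, 256), 256>>>(d_in, d_out, n);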
@@ -362,7 +359,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
     __shared__ scalar_t cache_grad_attn_weight[blockSize];
     unsigned int tid = threadIdx.x;
@@ -455,7 +452,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     extern __shared__ int _s[];
     scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
     scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
@@ -547,7 +544,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     extern __shared__ int _s[];
     scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
     scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
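The _shm_reduce variants instead size shared memory at launch time: a single unsized extern buffer is requested in the kernel configuration and partitioned into the two caches seen above (two entries per thread for the sampling-location gradient, one for the attention weight). A minimal sketch of that layout and a matching launch; dyn_shmem_demo is illustrative only:

// Sketch of the dynamic shared-memory layout used by the _shm_reduce
// kernels: one unsized extern buffer, carved into two per-thread caches.
template <typename scalar_t>
__global__ void dyn_shmem_demo() {
  extern __shared__ int _s[];
  scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
  scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
  const unsigned int tid = threadIdx.x;
  cache_grad_sampling_loc[tid * 2] = 0;      // d(loc.x) slot for this thread
  cache_grad_sampling_loc[tid * 2 + 1] = 0;  // d(loc.y) slot
  cache_grad_attn_weight[tid] = 0;           // d(weight) slot
  __syncthreads();
  // ... per-thread accumulation and a block reduction would follow here.
}

// The launch must request 3 scalars per thread for the whole buffer:
//   const int threads = 256;
//   const size_t shm_bytes = size_t(threads) * 3 * sizeof(float);
//   dyn_shmem_demo<float><<<1, threads, shm_bytes>>>();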
@@ -650,7 +647,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     extern __shared__ int _s[];
     scalar_t *cache_grad_sampling_loc = reinterpret_cast<scalar_t *>(_s);
     scalar_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
@@ -753,7 +750,7 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
     const int channels, const int num_levels, const int num_query,
     const int num_point, scalar_t *grad_value, scalar_t *grad_sampling_loc,
     scalar_t *grad_attn_weight) {
-  CUDA_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP(index, n) {
     int _temp = index;
     const int c_col = _temp % channels;
     _temp /= channels;
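The _gm suffix marks the variant that skips shared memory entirely; judging from its signature and its shared-memory siblings, each thread accumulates its partial gradients straight into global memory, which on CUDA needs atomicAdd to stay race-free. A hedged sketch of that trade-off; atomic_accumulate is illustrative, not the MMCV kernel:

// Toy version of the global-memory accumulation pattern: every thread adds
// its partial result into a shared output slot with atomicAdd. No shared
// memory or block-size constraints, but contended addresses serialize.
__global__ void atomic_accumulate(const float *partial, float *grad,
                                  const int n, const int num_slots) {
  CUDA_1D_KERNEL_LOOP(i, n) {  // grid-stride loop, as defined earlier
    // Many threads may hit the same slot; atomicAdd makes the
    // read-modify-write safe at the cost of serialization when they do.
    atomicAdd(&grad[i % num_slots], partial[i]);
  }
}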
Expand Down Expand Up @@ -807,4 +804,4 @@ __global__ void ms_deformable_col2im_gpu_kernel_gm(
}
}
}
#endif // OPENMMLAB_PRS_ADD_MSD_ATTEN_MMCV_OPS_CSRC_MS_DEFORM_ATTN_CUDA_KERNEL_CUH_
#endif // DEFORM_ATTN_CUDA_KERNEL
