[Enhancement] Optimize correlation #1814

Merged · 3 commits · Apr 10, 2022
mmcv/ops/csrc/common/cuda/correlation_cuda.cuh (172 changes: 83 additions & 89 deletions)
@@ -29,8 +29,8 @@ using namespace torch;
#define TensorAcc5R PackedTensorAccessor32<scalar_t, 5, RestrictPtrTraits>
#define WITHIN_BOUNDS(x, y, H, W) (x >= 0 && x < H && y >= 0 && y < W)

-#define THREADS_FORWARD 32
-#define THREADS_BACKWARD 16
+#define WARP_SIZE 32
+#define FULL_MASK 0xffffffff

template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(
@@ -42,8 +42,8 @@ __global__ void correlation_forward_cuda_kernel(
const int C = rInput1.size(3);

const int n = blockIdx.x;
-const int h = blockIdx.y;
-const int w = blockIdx.z;
+const int h = blockIdx.y * blockDim.y + threadIdx.y;
+const int w = blockIdx.z * blockDim.z + threadIdx.z;
const int thread = threadIdx.x;

const int start_i = -padH + h * dH;
@@ -52,13 +52,11 @@
const int patchRadH = dilation_patchH * (patchH - 1) / 2;
const int patchRadW = dilation_patchW * (patchW - 1) / 2;

-__shared__ scalar_t prod_sum[THREADS_FORWARD];
-
for (int ph = 0; ph < patchH; ++ph) {
int ph_dilated = ph * dilation_patchH - patchRadH;
for (int pw = 0; pw < patchW; ++pw) {
int pw_dilated = pw * dilation_patchW - patchRadW;
-prod_sum[thread] = 0;
+scalar_t prod_sum = 0.0f;
for (int i = 0; i < kH; ++i) {
int i1 = start_i + i * dilationH;
int i2 = i1 + ph_dilated;
@@ -69,23 +67,20 @@
int j2 = j1 + pw_dilated;
if
WITHIN_BOUNDS(j1, j2, iW, iW) {
-for (int c = thread; c < C; c += THREADS_FORWARD) {
+for (int c = thread; c < C; c += WARP_SIZE) {
scalar_t v1 = rInput1[n][i1][j1][c];
scalar_t v2 = rInput2[n][i2][j2][c];
-prod_sum[thread] += v1 * v2;
+prod_sum += v1 * v2;
}
}
}
}
}
// accumulate
-__syncthreads();
+for (int offset = 16; offset > 0; offset /= 2)
+prod_sum += __shfl_down_sync(FULL_MASK, float(prod_sum), offset);
if (thread == 0) {
-scalar_t reduce_sum = 0;
-for (int index = 0; index < THREADS_FORWARD; ++index) {
-reduce_sum += prod_sum[index];
-}
-output[n][ph][pw][h][w] = reduce_sum;
+output[n][ph][pw][h][w] = prod_sum;
}
}
}
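The forward-kernel change above swaps the shared-memory reduction (a 32-slot `prod_sum` array summed serially by thread 0) for a warp-level shuffle reduction, and each block now covers a 4×4 tile of output pixels instead of a single one. A minimal sketch of the shuffle idiom, assuming a full 32-thread warp (`WARP_SIZE` and `FULL_MASK` are the macros from this file; `warp_reduce_sum` is a hypothetical helper name):

```cpp
// Sketch only: tree reduction across one warp, assuming all 32 lanes are
// active. After log2(32) = 5 steps, lane 0 holds the sum of every lane's val.
__device__ float warp_reduce_sum(float val) {
  for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(FULL_MASK, val, offset);  // add lane (id + offset)
  return val;
}
```

Because the exchange happens entirely in registers, no `__shared__` buffer and no `__syncthreads()` are needed, which is exactly what the deleted lines provided.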
@@ -97,64 +92,64 @@ __global__ void correlation_backward_cuda_kernel_input1(
TensorAcc4R grad_input1, const int kH, const int kW, const int patchH,
const int patchW, const int padH, const int padW, const int dilationH,
const int dilationW, const int dilation_patchH, const int dilation_patchW,
-const int dH, const int dW, const int batch) {
-const int iH = input2.size(2);
-const int iW = input2.size(3);
+const int dH, const int dW) {
+const int iH = input2.size(1);
+const int iW = input2.size(2);
+const int C = input2.size(3);

const int H = grad_output.size(3);
const int W = grad_output.size(4);

const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;

-const int n = batch;
-const int c = blockIdx.x;
+const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
-const int ph_off = threadIdx.x;
-const int pw_off = threadIdx.y;

const int h_2 = h + padH;
const int w_2 = w + padW;
const int min_h = h_2 - kH * dilationH;
const int min_w = w_2 - kW * dilationW;

-__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
-prod_sum[ph_off][pw_off] = 0;
-
-for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+const int ph = i / patchW;
+const int pw = i % patchW;
int i1 = h + dilation_patchH * (ph - patchRadH);
-for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
-int j1 = w + dilation_patchW * (pw - patchRadW);
-if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
-scalar_t val = input2[n][c][i1][j1];
-for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
-int i2 = (h_3) / dH;
-if (i2 * dH != h_3) continue;
-for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
-int j2 = (w_3) / dW;
-if (j2 * dW != w_3) continue;
-if
-WITHIN_BOUNDS(i2, j2, H, W) {
-prod_sum[ph_off][pw_off] +=
-grad_output[n][ph][pw][i2][j2] * val;
-}
+int j1 = w + dilation_patchW * (pw - patchRadW);
+
+if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+scalar_t grad_val = 0.0f;
+for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+int i2 = (h_3) / dH;
+if (i2 * dH != h_3) continue;
+for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+int j2 = (w_3) / dW;
+if (j2 * dW != w_3) continue;
+if (WITHIN_BOUNDS(i2, j2, H, W)) {
+grad_val += grad_output[n][ph][pw][i2][j2];
+}
}
}
+grad_cache[i] = grad_val;
}
}

__syncthreads();

-if (ph_off == 0 && pw_off == 0) {
-scalar_t reduce_sum = 0;
-for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
-for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
-reduce_sum += prod_sum[ph][pw];
+for (int c = threadIdx.x; c < C; c += blockDim.x) {
+scalar_t grad_input_val = 0.0f;
+for (int ph = 0; ph < patchH; ++ph) {
+int i1 = h + dilation_patchH * (ph - patchRadH);
+for (int pw = 0; pw < patchW; ++pw) {
+int j1 = w + dilation_patchW * (pw - patchRadW);
+if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+grad_input_val += input2[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+}
}
}
-grad_input1[n][c][h][w] = reduce_sum;
+grad_input1[n][c][h][w] = grad_input_val;
}
}
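The rewritten kernel runs in two phases per `(n, h, w)` block: it first fills `grad_cache` (one `grad_output` sum per patch offset) cooperatively in dynamic shared memory, synchronizes, then block-strides over channels reusing the cache. The old version launched one block per channel, so the same sums were recomputed `C` times. The `unsigned char` buffer plus `reinterpret_cast` is the usual workaround for dynamic shared memory in templated kernels; a minimal sketch with illustrative names, not from the PR:

```cpp
// Sketch only: `extern __shared__ scalar_t buf[]` would redeclare one symbol
// with a different type per template instantiation, so raw bytes are aliased.
template <typename scalar_t>
__global__ void two_phase_kernel(const scalar_t *in, scalar_t *out, int n) {
  extern __shared__ __align__(sizeof(4)) unsigned char cache_char[];
  scalar_t *cache = reinterpret_cast<scalar_t *>(cache_char);
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    cache[i] = in[i];              // phase 1: cooperative fill
  __syncthreads();                 // make the cache visible block-wide
  for (int i = threadIdx.x; i < n; i += blockDim.x)
    out[i] = cache[i] * cache[i];  // phase 2: every thread reuses the cache
}
```

The buffer's byte size goes in the third `<<<>>>` launch parameter, which is what `grad_cache_size` does in the launcher below.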

@@ -163,9 +158,10 @@ __global__ void correlation_backward_cuda_kernel_input2(
const TensorAcc5R grad_output, const TensorAcc4R input1,
TensorAcc4R grad_input2, int kH, int kW, int patchH, int patchW, int padH,
int padW, int dilationH, int dilationW, int dilation_patchH,
-int dilation_patchW, int dH, int dW, int batch) {
-const int iH = input1.size(2);
-const int iW = input1.size(3);
+int dilation_patchW, int dH, int dW) {
+const int iH = input1.size(1);
+const int iW = input1.size(2);
+const int C = input1.size(3);

const int patchRadH = (patchH - 1) / 2;
const int patchRadW = (patchW - 1) / 2;
@@ -176,56 +172,54 @@
const int dilatedKH = kH * dilationH;
const int dilatedKW = kW * dilationW;

-const int n = batch;
-const int c = blockIdx.x;
+const int n = blockIdx.x;
const int h = blockIdx.y;
const int w = blockIdx.z;
-const int ph_off = threadIdx.x;
-const int pw_off = threadIdx.y;

-__shared__ scalar_t prod_sum[THREADS_BACKWARD][THREADS_BACKWARD];
-prod_sum[ph_off][pw_off] = 0;
-
-for (int ph = ph_off; ph < patchH; ph += THREADS_BACKWARD) {
+extern __shared__ __align__(sizeof(4)) unsigned char grad_cache_char[];
+scalar_t *grad_cache = reinterpret_cast<scalar_t *>(grad_cache_char);
+for (int i = threadIdx.x; i < patchH * patchW; i += blockDim.x) {
+const int ph = i / patchW;
+const int pw = i % patchW;
int i1 = h - dilation_patchH * (ph - patchRadH);
-for (int pw = pw_off; pw < patchW; pw += THREADS_BACKWARD) {
-int j1 = w - dilation_patchW * (pw - patchRadW);
-if
-WITHIN_BOUNDS(i1, j1, iH, iW) {
-scalar_t val = input1[n][c][i1][j1];
-
-const int h_2 = i1 + padH;
-const int w_2 = j1 + padW;
-const int min_h = h_2 - dilatedKH;
-const int min_w = w_2 - dilatedKW;
-
-for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
-int i2 = (h_3) / dH;
-if (i2 * dH != h_3) continue;
-for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
-int j2 = (w_3) / dW;
-if (j2 * dW != w_3) continue;
-if
-WITHIN_BOUNDS(i2, j2, H, W) {
-prod_sum[ph_off][pw_off] +=
-grad_output[n][ph][pw][i2][j2] * val;
-}
-}
+int j1 = w - dilation_patchW * (pw - patchRadW);
+
+if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+scalar_t grad_val = 0.0f;
+
+const int h_2 = i1 + padH;
+const int w_2 = j1 + padW;
+const int min_h = h_2 - dilatedKH;
+const int min_w = w_2 - dilatedKW;
+
+for (int h_3 = h_2; h_3 > min_h; h_3 -= dilationH) {
+int i2 = (h_3) / dH;
+if (i2 * dH != h_3) continue;
+for (int w_3 = w_2; w_3 > min_w; w_3 -= dilationW) {
+int j2 = (w_3) / dW;
+if (j2 * dW != w_3) continue;
+if (WITHIN_BOUNDS(i2, j2, H, W)) {
+grad_val += grad_output[n][ph][pw][i2][j2];
+}
+}
+}
+grad_cache[i] = grad_val;
+}
+}

__syncthreads();

-if (ph_off == 0 && pw_off == 0) {
-scalar_t reduce_sum = 0;
-for (int ph = 0; ph < THREADS_BACKWARD; ++ph) {
-for (int pw = 0; pw < THREADS_BACKWARD; ++pw) {
-reduce_sum += prod_sum[ph][pw];
+for (int c = threadIdx.x; c < C; c += blockDim.x) {
+scalar_t grad_input_val = 0.0f;
+for (int ph = 0; ph < patchH; ++ph) {
+int i1 = h - dilation_patchH * (ph - patchRadH);
+for (int pw = 0; pw < patchW; ++pw) {
+int j1 = w - dilation_patchW * (pw - patchRadW);
+if (WITHIN_BOUNDS(i1, j1, iH, iW)) {
+grad_input_val += input1[n][i1][j1][c] * grad_cache[ph * patchW + pw];
+}
+}
+}
-grad_input2[n][c][h][w] = reduce_sum;
+grad_input2[n][c][h][w] = grad_input_val;
}
}
#endif
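`correlation_backward_cuda_kernel_input2` mirrors the `input1` kernel with the patch displacement sign flipped. Both lean on block-stride loops: the `patchH * patchW` offsets and the `C` channels are each covered with stride `blockDim.x`, recovering 2-D coordinates from the flat index via `ph = i / patchW`, `pw = i % patchW`. Neither extent has to match the block size, unlike the old code, which tied both loops to `THREADS_BACKWARD`. A minimal sketch of the idiom (hypothetical function):

```cpp
// Sketch only: block-stride loop over a 2-D range flattened to one index;
// correct for any rows/cols, independent of blockDim.x.
__device__ void visit_all(float *grid, int rows, int cols) {
  for (int i = threadIdx.x; i < rows * cols; i += blockDim.x) {
    const int r = i / cols;      // row from flat index
    const int c = i % cols;      // column from flat index
    grid[r * cols + c] += 1.0f;  // each cell visited by exactly one thread
  }
}
```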
mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu (41 changes: 21 additions & 20 deletions)
@@ -24,8 +24,8 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();

-const int threads = THREADS_FORWARD;
-const dim3 blocks(batch_size, oH, oW);
+const dim3 threads(WARP_SIZE, 4, 4);
+const dim3 blocks(batch_size, (oH + 3) >> 2, (oW + 3) >> 2);

at::cuda::CUDAGuard device_guard(input1.device());

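`(oH + 3) >> 2` is a ceiling division by 4, pairing with the 4×4 `(y, z)` thread tile so the grid covers all of `oH × oW` even when those sizes are not multiples of 4; threads whose computed `h` or `w` lands past the edge must be masked, presumably by a bounds check in a part of the kernel outside the hunks shown here. The same arithmetic as a generic helper:

```cpp
// Sketch only: returns the smallest count with count * b >= a.
// (oH + 3) >> 2 in the diff is ceil_div(oH, 4) with the division
// strength-reduced to a shift.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }
// e.g. const dim3 blocks(batch_size, ceil_div(oH, 4), ceil_div(oW, 4));
```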
@@ -56,38 +56,39 @@ void CorrelationBackwardCUDAKernelLauncher(
const int iW = input1.size(3);
const int C = input1.size(1);

-const dim3 blocks(C, iH, iW);
-const dim3 threads(THREADS_BACKWARD, THREADS_BACKWARD);
+auto trInput1 = input1.permute({0, 2, 3, 1}).contiguous();
+auto trInput2 = input2.permute({0, 2, 3, 1}).contiguous();
+const dim3 blocks(batch_size, iH, iW);
+const dim3 threads(THREADS_PER_BLOCK);

at::cuda::CUDAGuard device_guard(input1.device());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
input1.scalar_type(), "correlation_backward_cuda", ([&] {
+const int grad_cache_size = patchH * patchW * sizeof(scalar_t);
TensorAcc4R input1_acc =
-input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+trInput1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R input2_acc =
-input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
+trInput2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input1_acc =
grad_input1.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc4R grad_input2_acc =
grad_input2.packed_accessor32<scalar_t, 4, RestrictPtrTraits>();
TensorAcc5R grad_output_acc =
grad_output.packed_accessor32<scalar_t, 5, RestrictPtrTraits>();

-for (int n = 0; n < batch_size; ++n) {
-correlation_backward_cuda_kernel_input1<scalar_t>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
-patchW, padH, padW, dilationH, dilationW, dilation_patchH,
-dilation_patchW, dH, dW, n);
-}
+correlation_backward_cuda_kernel_input1<scalar_t>
+<<<blocks, threads, grad_cache_size,
+at::cuda::getCurrentCUDAStream()>>>(
+grad_output_acc, input2_acc, grad_input1_acc, kH, kW, patchH,
+patchW, padH, padW, dilationH, dilationW, dilation_patchH,
+dilation_patchW, dH, dW);

-for (int n = 0; n < batch_size; ++n) {
-correlation_backward_cuda_kernel_input2<scalar_t>
-<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
-grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
-patchW, padH, padW, dilationH, dilationW, dilation_patchH,
-dilation_patchW, dH, dW, n);
-}
+correlation_backward_cuda_kernel_input2<scalar_t>
+<<<blocks, threads, grad_cache_size,
+at::cuda::getCurrentCUDAStream()>>>(
+grad_output_acc, input1_acc, grad_input2_acc, kH, kW, patchH,
+patchW, padH, padW, dilationH, dilationW, dilation_patchH,
+dilation_patchW, dH, dW);
}));
}
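The launcher now folds the batch loop into the grid: the old code issued `2 * batch_size` kernel launches (one per sample per direction), while the new code issues exactly two, with `blockIdx.x` carrying the sample index and `grad_cache_size` passed as the dynamic shared-memory size. `THREADS_PER_BLOCK` is assumed to be defined in mmcv's common CUDA helpers. A self-contained toy showing the same launch pattern, with hypothetical names:

```cpp
#include <cuda_runtime.h>

// Sketch only: the sample index rides on blockIdx.x, so one launch covers the
// whole batch instead of a host-side loop of per-sample launches.
__global__ void per_sample_add_one(float *data, int per_sample) {
  const int n = blockIdx.x;  // sample index, formerly a kernel argument
  for (int i = threadIdx.x; i < per_sample; i += blockDim.x)
    data[n * per_sample + i] += 1.0f;
}

int main() {
  const int batch = 8, per_sample = 1024;
  float *d = nullptr;
  cudaMalloc(&d, batch * per_sample * sizeof(float));
  cudaMemset(d, 0, batch * per_sample * sizeof(float));
  per_sample_add_one<<<batch, 256>>>(d, per_sample);  // single launch
  cudaDeviceSynchronize();
  cudaFree(d);
  return 0;
}
```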