Skip to content

Commit

Permalink
[Fix] Fix Correlation op (open-mmlab#2274)
Browse files Browse the repository at this point in the history
* fix correlation

* fix lint
  • Loading branch information
q.yao authored and zhouzaida committed Nov 22, 2022
1 parent ab720ce commit 28d64d5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
6 changes: 5 additions & 1 deletion mmcv/ops/csrc/common/cuda/correlation_cuda.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ template <typename scalar_t>
__global__ void correlation_forward_cuda_kernel(
const TensorAcc4R rInput1, const TensorAcc4R rInput2, TensorAcc5R output,
int kH, int kW, int patchH, int patchW, int padH, int padW, int dilationH,
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW) {
int dilationW, int dilation_patchH, int dilation_patchW, int dH, int dW,
int oH, int oW) {
const int iH = rInput1.size(1);
const int iW = rInput1.size(2);
const int C = rInput1.size(3);

const int n = blockIdx.x;
const int h = blockIdx.y * blockDim.y + threadIdx.y;
const int w = blockIdx.z * blockDim.z + threadIdx.z;

if (h >= oH || w >= oW) return;

const int thread = threadIdx.x;

const int start_i = -padH + h * dH;
Expand Down
2 changes: 1 addition & 1 deletion mmcv/ops/csrc/pytorch/cuda/correlation_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void CorrelationForwardCUDAKernelLauncher(Tensor input1, Tensor input2,
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
trInput1_acc, trInput2_acc, output_acc, kH, kW, patchH, patchW,
padH, padW, dilationH, dilationW, dilation_patchH,
dilation_patchW, dH, dW);
dilation_patchW, dH, dW, oH, oW);
}));
}

Expand Down

0 comments on commit 28d64d5

Please sign in to comment.