Skip to content

Commit

Permalink
[Fix] Fix arf op's write conflict when num_orientations is not 1 (ope…
Browse files: browse the repository at this point in the history
  • Loading branch information
dflhw authored and Danielmic committed Jun 30, 2023
1 parent 6e23b6a commit d0df68d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 12 deletions.
15 changes: 9 additions & 6 deletions mmcv/ops/csrc/common/cuda/active_rotated_filter_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,19 @@ __global__ void active_rotated_filter_forward_cuda_kernel(
const int nthreads, const scalar_t* weight_data, const int* indices_data,
const int num_input_planes, const int num_output_planes,
const int num_orientations, const int num_rotations, const int nEntry,
scalar_t* output_data) {
const int kH, const int kW, scalar_t* output_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int l = index % nEntry;
int j = (index / nEntry) % num_input_planes;
int i = index / nEntry / num_input_planes;
int k;
int fmIndex = (l / (kH * kW)) * kH * kW;
scalar_t val = *(weight_data + index);
for (k = 0; k < num_rotations; k++) {
int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
scalar_t* target = output_data +
i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + idx;
scalar_t* target =
output_data + i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex;
*target = val;
}
}
Expand All @@ -37,20 +38,22 @@ __global__ void active_rotated_filter_backward_cuda_kernel(
const int nthreads, const scalar_t* gradWeight_data,
const int* indices_data, const int num_input_planes,
const int num_output_planes, const int num_orientations,
const int num_rotations, const int nEntry, scalar_t* weight_data) {
const int num_rotations, const int nEntry, const int kH, const int kW,
scalar_t* weight_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int l = index % nEntry;
int j = (index / nEntry) % num_input_planes;
int i = index / nEntry / num_input_planes;
int k;
int fmIndex = (l / (kH * kW)) * kH * kW;
scalar_t* val = weight_data + index;
*val = 0;
scalar_t tmp = 0;
for (k = 0; k < num_rotations; k++) {
int idx = (int)(*(indices_data + l * num_rotations + k)) - 1;
scalar_t target =
*(gradWeight_data + i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + idx);
k * (num_input_planes * nEntry) + j * (nEntry) + idx + fmIndex);
tmp = tmp + target;
}
*val = tmp;
Expand Down
10 changes: 6 additions & 4 deletions mmcv/ops/csrc/pytorch/cpu/active_rotated_filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ void active_rotated_filter_forward_cpu_kernel(
for (l = 0; l < nEntry; l++) {
int weightIndex = i * num_input_planes * nEntry + j * nEntry + l;
T val = *(weightData + weightIndex);
int fmIndex = (l / (kH * kW)) * kH * kW;
for (k = 0; k < num_rotations; k++) {
int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
T* target = outputData +
i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + index;
T* target =
outputData + i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex;
*target = val;
}
}
Expand All @@ -48,11 +49,12 @@ void active_rotated_filter_backward_cpu_kernel(
int gradInputIndex = i * num_input_planes * nEntry + j * nEntry + l;
T* val = gradInputData + gradInputIndex;
*val = 0;
int fmIndex = (l / (kH * kW)) * kH * kW;
for (k = 0; k < num_rotations; k++) {
int index = (int)(*(indicesData + l * num_rotations + k)) - 1;
const T* target =
gradOutputData + i * (num_rotations * num_input_planes * nEntry) +
k * (num_input_planes * nEntry) + j * (nEntry) + index;
k * (num_input_planes * nEntry) + j * (nEntry) + index + fmIndex;
*val = *val + *target;
}
}
Expand Down
4 changes: 2 additions & 2 deletions mmcv/ops/csrc/pytorch/cuda/active_rotated_filter_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ void ActiveRotatedFilterForwardCUDAKernelLauncher(const Tensor input,
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, input.data_ptr<scalar_t>(),
indices.data_ptr<int>(), num_input_planes, num_output_planes,
num_orientations, num_rotations, nEntry,
num_orientations, num_rotations, nEntry, kH, kW,
output.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
Expand All @@ -51,7 +51,7 @@ void ActiveRotatedFilterBackwardCUDAKernelLauncher(const Tensor grad_out,
<<<GET_BLOCKS(output_size), THREADS_PER_BLOCK, 0, stream>>>(
output_size, grad_out.data_ptr<scalar_t>(),
indices.data_ptr<int>(), num_input_planes, num_output_planes,
num_orientations, num_rotations, nEntry,
num_orientations, num_rotations, nEntry, kH, kW,
grad_in.data_ptr<scalar_t>());
});
AT_CUDA_CHECK(cudaGetLastError());
Expand Down

0 comments on commit d0df68d

Please sign in to comment.