Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Differentiable rotated IoU #1854

Merged
merged 13 commits into from
Apr 15, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d
from .deprecated_wrappers import Linear_deprecated as Linear
from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d
from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d
from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss,
sigmoid_focal_loss, softmax_focal_loss)
from .furthest_point_sample import (furthest_point_sample,
Expand Down Expand Up @@ -96,5 +97,5 @@
'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd',
'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all',
'points_in_polygons', 'min_area_polygons', 'active_rotated_filter',
'convex_iou', 'convex_giou'
'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d'
]
155 changes: 155 additions & 0 deletions mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include <ATen/ATen.h>
filaPro marked this conversation as resolved.
Show resolved Hide resolved
#include <ATen/cuda/CUDAContext.h>
#include <cmath>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
#include <stdio.h>
#include <stdlib.h>

#define TOTAL_THREADS 512
filaPro marked this conversation as resolved.
Show resolved Hide resolved
#define MAX_NUM_VERT_IDX 9
#define INTERSECTION_OFFSET 8
#define EPSILON 1e-8
grimoire marked this conversation as resolved.
Show resolved Hide resolved


// Choose a 1-D thread count for `work_size` independent work items.
//
// Returns the largest power of two that does not exceed `work_size`, clamped
// to the range [1, TOTAL_THREADS]. Non-positive sizes yield 1.
//
// The exponent is found with integer arithmetic only. The earlier log()-based
// form converted log(0) = -inf (or NaN for negative sizes) to int and then
// shifted by that value, both of which are undefined behaviour, and it relied
// on floating-point rounding to recover an exact exponent.
inline int opt_n_thread(int work_size){
  int n_thread = 1;
  // Double while the next power of two still fits into work_size. Comparing
  // against work_size / 2 keeps the test overflow-free for any int input.
  while (n_thread < TOTAL_THREADS && n_thread <= work_size / 2) {
    n_thread <<= 1;
  }
  // Final clamp keeps the result valid even if TOTAL_THREADS is not itself a
  // power of two.
  return n_thread < TOTAL_THREADS ? n_thread : TOTAL_THREADS;
}

// Build a 2-D block shape for an (x, y) work grid.
//
// The x dimension receives opt_n_thread(x) threads; the y dimension receives
// opt_n_thread(y) threads limited to what is left of the TOTAL_THREADS budget
// after x has been served, and never fewer than 1. The z dimension is 1.
inline dim3 opt_block_config(int x, int y){
  const int threads_x = opt_n_thread(x);
  const int budget_y = TOTAL_THREADS / threads_x;
  const int threads_y = max(min(opt_n_thread(y), budget_y), 1);
  return dim3(threads_x, threads_y, 1);
}


/*
compare normalized vertices (vertices around (0,0))
returns true if vertex1 < vertex2.
order: minimum at the x-axis, becoming larger in the anti-clockwise direction

- vertices closer than EPSILON in both coordinates count as equal -> false
- a vertex with y > 0 always sorts before a vertex with y < 0
- inside the same half plane the vertices are ranked by
  |x| * x / (x^2 + y^2 + EPSILON), i.e. the sign-carrying squared cosine of
  the polar angle
- if either vertex lies exactly on the x-axis (y == 0) none of the half-plane
  cases applies; the function then returns false instead of running off the
  end of a non-void function (undefined behaviour)
*/
__device__ bool compare_vertices(float x1, float y1, float x2, float y2){

  if (fabs(x1-x2)<EPSILON && fabs(y2-y1)<EPSILON)
    return false;  // if equal, return false

  // different half planes: the upper one (y > 0) comes first
  if (y1 > 0 && y2 < 0)
    return true;
  if (y1 < 0 && y2 > 0)
    return false;

  float n1 = x1*x1 + y1*y1 + EPSILON;
  float n2 = x2*x2 + y2*y2 + EPSILON;
  // `auto` keeps exactly the type the inline expression had before
  const auto diff = fabs(x1)*x1/n1 - fabs(x2)*x2/n2;

  // upper half plane: the angle grows while the signed cos^2 shrinks
  if (y1 > 0 && y2 > 0)
    return diff > EPSILON;

  // lower half plane: the angle grows while the signed cos^2 grows
  if (y1 < 0 && y2 < 0)
    return diff < EPSILON;

  // y1 == 0 or y2 == 0: no ordering defined above, report "not smaller"
  return false;
}

/*
Sort the valid vertices of every candidate intersection polygon in
anti-clockwise order and write the resulting vertex indices to `idx`.

Launch layout (as used by the indexing below): one block per batch element
(blockIdx.x); the threads of a block stride over the n polygons of that
element (threadIdx.x / blockDim.x). No shared memory is used and every thread
only touches the slots of its own polygons.

Args:
  b:         batch size (not read in the body; the batch index comes from
             blockIdx.x)
  n:         number of polygons per batch element
  m:         number of candidate vertices per polygon; indices
             >= INTERSECTION_OFFSET are treated as intersection points
  vertices:  contiguous, indexed as [batch][polygon][vertex][x, y]
  mask:      contiguous, indexed as [batch][polygon][vertex]; true marks a
             valid vertex
  num_valid: contiguous, indexed as [batch][polygon]; number of valid vertices
  idx:       output, indexed as [batch][polygon][MAX_NUM_VERT_IDX]

Output per polygon: the sorted indices, followed by a copy of the first index
(closing the polygon), followed by the index of an invalid intersection point
as padding. Polygons with fewer than 3 valid vertices are filled with the
padding index only.
*/
__global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel(
    int b, int n, int m, const float *__restrict__ vertices,
    const bool *__restrict__ mask, const int *__restrict__ num_valid,
    int *__restrict__ idx){
  int batch_idx = blockIdx.x;
  vertices += batch_idx * n * m * 2;
  mask += batch_idx * n * m;
  num_valid += batch_idx * n;
  idx += batch_idx * n * MAX_NUM_VERT_IDX;

  int index = threadIdx.x;  // index of polygon
  int stride = blockDim.x;
  for (int i = index; i < n; i += stride){
    const float *poly_vertices = vertices + i * m * 2;
    const bool *poly_mask = mask + i * m;
    int *poly_idx = idx + i * MAX_NUM_VERT_IDX;

    // index of arbitrary invalid intersection point (not box corner!)
    // Defaults to 0 so that `pad` is never read uninitialized when every
    // intersection slot happens to be valid.
    // NOTE(review): callers presumably always leave at least one invalid
    // intersection slot, in which case the default is never used - confirm.
    int pad = 0;
    for (int j = INTERSECTION_OFFSET; j < m; ++j){
      if (!poly_mask[j]){
        pad = j;
        break;
      }
    }

    if (num_valid[i] < 3){
      // not enough vertices, take an invalid intersection point
      // (zero padding)
      for (int j = 0; j < MAX_NUM_VERT_IDX; ++j){
        poly_idx[j] = pad;
      }
      continue;
    }

    // A polygon owns MAX_NUM_VERT_IDX slots and one of them is needed for the
    // closing duplicate, so at most MAX_NUM_VERT_IDX - 1 vertices can be
    // stored. Clamp to keep malformed counts from writing out of bounds.
    const int n_valid = num_valid[i] < MAX_NUM_VERT_IDX
                            ? num_valid[i]
                            : MAX_NUM_VERT_IDX - 1;

    // sort the valid vertices (selection sort)
    // note the number of valid vertices is known
    for (int j = 0; j < n_valid; ++j){
      // initialize with a "big" value
      float x_min = 1;
      float y_min = -EPSILON;
      int i_take = 0;

      // vertex selected in the previous round; constant over the k-loop
      float x_prev = 0;
      float y_prev = 0;
      if (j > 0){
        const int i_prev = poly_idx[j - 1];
        x_prev = poly_vertices[i_prev * 2 + 0];
        y_prev = poly_vertices[i_prev * 2 + 1];
      }

      for (int k = 0; k < m; ++k){
        if (!poly_mask[k])
          continue;
        const float x = poly_vertices[k * 2 + 0];
        const float y = poly_vertices[k * 2 + 1];
        // take the smallest vertex that is still larger than the previously
        // selected one (no lower bound in the first round)
        if (compare_vertices(x, y, x_min, y_min) &&
            (j == 0 || compare_vertices(x_prev, y_prev, x, y))){
          x_min = x;
          y_min = y;
          i_take = k;
        }
      }
      poly_idx[j] = i_take;
    }

    // duplicate the first idx
    poly_idx[n_valid] = poly_idx[0];

    // pad the remaining slots
    for (int j = n_valid + 1; j < MAX_NUM_VERT_IDX; ++j){
      poly_idx[j] = pad;
    }

    // for corner case: the two boxes are exactly the same.
    // in this case, idx would have duplicate elements, which makes the
    // shoelace formula broken because of the definition, the duplicate
    // elements only appear in the first 8 positions
    // (they are "corners in box", not "intersection of edges")
    if (n_valid == 8){
      int counter = 0;
      for (int j = 0; j < 4; ++j){
        int check = poly_idx[j];
        for (int k = 4; k < INTERSECTION_OFFSET; ++k){
          if (poly_idx[k] == check)
            counter++;
        }
      }
      if (counter == 4){
        // keep the first four indices, close the polygon, pad the rest
        poly_idx[4] = poly_idx[0];
        for (int j = 5; j < MAX_NUM_VERT_IDX; ++j){
          poly_idx[j] = pad;
        }
      }
    }

    // TODO: still might need to cover some other corner cases :(
  }
}
13 changes: 13 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1699,3 +1699,16 @@ void convex_giou_impl(const Tensor pointsets, const Tensor polygons,

REGISTER_DEVICE_IMPL(convex_iou_impl, CUDA, convex_iou_cuda);
REGISTER_DEVICE_IMPL(convex_giou_impl, CUDA, convex_giou_cuda);

// Kernel launcher, defined in diff_iou_rotated_cuda.cu.
Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(Tensor vertices,
                                                    Tensor mask,
                                                    Tensor num_valid);

// CUDA implementation of the vertex-sorting op: hands all three tensors to the
// kernel launcher and returns the index tensor it produces.
Tensor diff_iou_rotated_sort_vertices_forward_cuda(Tensor vertices,
                                                   Tensor mask,
                                                   Tensor num_valid) {
  Tensor sorted_idx =
      DiffIoURotatedSortVerticesCUDAKernelLauncher(vertices, mask, num_valid);
  return sorted_idx;
}

// Declaration of the dispatching entry point (its definition in
// diff_iou_rotated.cpp goes through DISPATCH_DEVICE_IMPL); it is named here as
// the key of the registration below.
Tensor diff_iou_rotated_sort_vertices_forward_impl(
Tensor vertices, Tensor mask, Tensor num_valid);

// Associates the CUDA function above with that entry point.
// NOTE(review): REGISTER_DEVICE_IMPL is defined outside this view; presumably
// it makes the dispatcher pick this function for CUDA tensors - confirm.
REGISTER_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl,
CUDA, diff_iou_rotated_sort_vertices_forward_cuda);
36 changes: 36 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) OpenMMLab. All rights reserved
// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa
#include "diff_iou_rotated_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"
#include "pytorch_cpp_helper.hpp"

// #define MAX_NUM_VERT_IDX 9
filaPro marked this conversation as resolved.
Show resolved Hide resolved

// Host-side launcher for diff_iou_rotated_sort_vertices_forward_cuda_kernel.
//
// Args:
//   vertices:  contiguous float32 CUDA tensor; its first three sizes are read
//              as (b, n, m) and the kernel reads an (x, y) pair per vertex
//   mask:      contiguous bool CUDA tensor, read by the kernel as (b, n, m)
//   num_valid: contiguous int32 CUDA tensor, read by the kernel as (b, n)
//
// Returns:
//   int32 tensor of shape (b, n, MAX_NUM_VERT_IDX) on the device of
//   `vertices`, zero-initialized and then filled by the kernel.
//
// The kernel is launched with one block per batch element on the current
// PyTorch CUDA stream.
at::Tensor DiffIoURotatedSortVerticesCUDAKernelLauncher(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid) {
  // Validate before touching the device so that bad input yields a clear error.
  CHECK_CUDA(vertices);
  CHECK_CUDA(mask);
  CHECK_CUDA(num_valid);
  CHECK_CONTIGUOUS(vertices);
  CHECK_CONTIGUOUS(mask);
  CHECK_CONTIGUOUS(num_valid);
  // The kernel is float-only (it is not templated), so a floating-type
  // dispatch macro would merely defer this failure to data_ptr<float>().
  TORCH_CHECK(vertices.scalar_type() == at::ScalarType::Float,
              "vertices must be a float32 tensor");

  at::cuda::CUDAGuard device_guard(vertices.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  int b = vertices.size(0);
  int n = vertices.size(1);
  int m = vertices.size(2);
  at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX},
                                at::device(vertices.device()).dtype(at::ScalarType::Int));

  // Nothing to sort; a grid of 0 blocks would be an invalid launch config.
  if (b == 0 || n == 0) {
    return idx;
  }

  // Launch on the current stream (not the legacy default stream) so the kernel
  // is ordered with the surrounding PyTorch work.
  diff_iou_rotated_sort_vertices_forward_cuda_kernel
      <<<b, opt_n_thread(n), 0, stream>>>(
          b, n, m, vertices.data_ptr<float>(), mask.data_ptr<bool>(),
          num_valid.data_ptr<int>(), idx.data_ptr<int>());
  AT_CUDA_CHECK(cudaGetLastError());

  return idx;
}
11 changes: 11 additions & 0 deletions mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// Copyright (c) OpenMMLab. All rights reserved
#include "pytorch_cpp_helper.hpp"
#include "pytorch_device_registry.hpp"

// Dispatching entry point of the vertex-sorting op: passes the three tensors
// to DISPATCH_DEVICE_IMPL under this function's own name and returns its
// result.
// NOTE(review): DISPATCH_DEVICE_IMPL is defined outside this view; presumably
// it selects the implementation registered with REGISTER_DEVICE_IMPL for the
// tensors' device - confirm.
Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, Tensor num_valid){
return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, vertices, mask, num_valid);
}

// Public op (this is the function bound in pybind.cpp): forwards the tensors
// unchanged to the dispatching *_impl function and returns its result.
Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask,
                                              Tensor num_valid) {
  Tensor sorted_idx =
      diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid);
  return sorted_idx;
}
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,9 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious);

void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output);

at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, at::Tensor mask,
at::Tensor num_valid);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"),
py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"),
Expand Down Expand Up @@ -809,4 +812,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("polygons"), py::arg("ious"));
m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"),
py::arg("polygons"), py::arg("output"));
m.def("diff_iou_rotated_sort_vertices_forward",
&diff_iou_rotated_sort_vertices_forward,
"diff_iou_rotated_sort_vertices_forward", py::arg("vertices"),
py::arg("mask"), py::arg("num_valid"));
}
Loading