From bd3c10f49ee038f0f953e520a9f60098ca6b7230 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 2 Apr 2022 19:09:31 +0300 Subject: [PATCH 01/13] diff_iou_rotated is working --- mmcv/ops/__init__.py | 1 + .../pytorch/cuda/diff_iou_rotated_cuda.cu | 160 ++++++++++++++++++ mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp | 62 +++++++ mmcv/ops/csrc/pytorch/pybind.cpp | 6 + mmcv/ops/diff_iou_rotated.py | 22 +++ 5 files changed, 251 insertions(+) create mode 100644 mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu create mode 100644 mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp create mode 100644 mmcv/ops/diff_iou_rotated.py diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index bdd39fcae7..b18fbc4e38 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -65,6 +65,7 @@ from .tin_shift import TINShift, tin_shift from .upfirdn2d import upfirdn2d from .voxelize import Voxelization, voxelization +from .diff_iou_rotated import diff_iou_rotated_sort_vertices __all__ = [ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', diff --git a/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu new file mode 100644 index 0000000000..cfb4b69a2e --- /dev/null +++ b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu @@ -0,0 +1,160 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#define TOTAL_THREADS 512 +#define MAX_NUM_VERT_IDX 9 +#define INTERSECTION_OFFSET 8 +#define EPSILON 1e-8 + + +inline int opt_n_thread(int work_size){ + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return max(min(1< 0 && y2 < 0) + return true; + if (y1 < 0 && y2 > 0) + return false; + + float n1 = x1*x1 + y1*y1 + EPSILON; + float n2 = x2*x2 + y2*y2 + EPSILON; + + if (y1 > 0 && y2 > 0){ + if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 > EPSILON) + return true; + else + return false; + } + if (y1 < 0 && y2 < 0) { + if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 < EPSILON) + return true; + else + return false; + } +} + +__global__ void sort_vertices_kernel(int b, int n, int m, + const float *__restrict__ vertices, + const bool *__restrict__ mask, + const int *__restrict__ num_valid, + int *__restrict__ idx){ + int batch_idx = blockIdx.x; + vertices += batch_idx * n * m *2; + mask += batch_idx * n * m; + num_valid += batch_idx * n; + idx += batch_idx * n * MAX_NUM_VERT_IDX; + + int index = threadIdx.x; // index of polygon + int stride = blockDim.x; + for (int i = index; i>>(b, n, m, vertices, mask, num_valid, idx); + AT_CUDA_CHECK(cudaGetLastError()); +} diff --git a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp new file mode 100644 index 0000000000..58446e9a6f --- /dev/null +++ b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp @@ -0,0 +1,62 @@ +#include +#include +#include + +#define CHECK_CUDA(x) \ + do { \ + TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ + } while (0) + +#define CHECK_CONTIGUOUS(x) \ + do { \ + TORCH_CHECK(x.is_contiguous(), #x " must ne a contiguous tensor"); \ + } while (0) + +#define CHECK_IS_INT(x) \ + do { \ + TORCH_CHECK(x.scalar_type()==at::ScalarType::Int, \ + #x " must be a int tensor"); \ + } while (0) + +#define CHECK_IS_FLOAT(x) \ + do { \ + TORCH_CHECK(x.scalar_type()==at::ScalarType::Float, \ + #x " must be a float tensor"); \ + } while (0) + +#define CHECK_IS_BOOL(x) \ + do { \ + TORCH_CHECK(x.scalar_type()==at::ScalarType::Bool, \ + #x " must be a bool tensor"); \ + } while (0) + +#define MAX_NUM_VERT_IDX 9 + +void sort_vertices_wrapper(int b, int n, int m, const float *vertices, const bool *mask, const int *num_valid, int* idx); + +at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid){ + CHECK_CONTIGUOUS(vertices); + CHECK_CONTIGUOUS(mask); + CHECK_CONTIGUOUS(num_valid); + CHECK_CUDA(vertices); + CHECK_CUDA(mask); + CHECK_CUDA(num_valid); + CHECK_IS_FLOAT(vertices); + CHECK_IS_BOOL(mask); + CHECK_IS_INT(num_valid); + + int b = vertices.size(0); + int n = vertices.size(1); + int m = vertices.size(2); + at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX}, + at::device(vertices.device()).dtype(at::ScalarType::Int)); + + // fix issue with multi-gpu (kernel only works for cuda:0) + const at::cuda::OptionalCUDAGuard device_guard(device_of(idx)); + + sort_vertices_wrapper(b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + + return idx; +} + diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 01c84c948f..46fd35a194 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -400,6 +400,9 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious); void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output); +at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, + at::Tensor num_valid); + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), py::arg("kernel"), py::arg("up_x"), py::arg("up_y"), py::arg("down_x"), @@ -809,4 +812,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("polygons"), py::arg("ious")); m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"), py::arg("polygons"), py::arg("output")); + m.def("diff_iou_rotated_sort_vertices", &diff_iou_rotated_sort_vertices, + "diff_iou_rotated_sort_vertices", py::arg("vertices"), + py::arg("mask"), py::arg("num_valid")); } diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py new file mode 100644 index 0000000000..e727794d75 --- /dev/null +++ b/mmcv/ops/diff_iou_rotated.py @@ -0,0 +1,22 @@ +import torch +from torch import nn +from torch.autograd import Function +from ..utils import ext_loader + +ext_module = ext_loader.load_ext('_ext', [ + 'diff_iou_rotated_sort_vertices' +]) + +class SortVertices(Function): + @staticmethod + def forward(ctx, vertices, mask, num_valid): + idx = ext_module.diff_iou_rotated_sort_vertices(vertices, mask, num_valid) + ctx.mark_non_differentiable(idx) + return idx + + @staticmethod + def backward(ctx, gradout): + return () + +diff_iou_rotated_sort_vertices = SortVertices.apply + From ec7002da78ad1387128c31276b0413e4ac5496f8 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 2 Apr 2022 22:46:17 +0300 Subject: [PATCH 02/13] add test; fix lint --- mmcv/ops/__init__.py | 4 +- .../pytorch/cuda/diff_iou_rotated_cuda.cu | 8 +- mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp | 35 +-- mmcv/ops/diff_iou_rotated.py | 285 +++++++++++++++++- tests/test_ops/test_diff_iou_rotated.py | 58 ++++ 5 files changed, 347 insertions(+), 43 deletions(-) create mode 100644 tests/test_ops/test_diff_iou_rotated.py diff --git a/mmcv/ops/__init__.py b/mmcv/ops/__init__.py index b18fbc4e38..fd112f049d 100644 --- a/mmcv/ops/__init__.py +++ b/mmcv/ops/__init__.py @@ -18,6 +18,7 @@ from .deprecated_wrappers import ConvTranspose2d_deprecated as ConvTranspose2d from .deprecated_wrappers import Linear_deprecated as Linear from .deprecated_wrappers import MaxPool2d_deprecated as MaxPool2d +from .diff_iou_rotated import diff_iou_rotated_2d, diff_iou_rotated_3d from .focal_loss import (SigmoidFocalLoss, SoftmaxFocalLoss, sigmoid_focal_loss, softmax_focal_loss) from .furthest_point_sample import (furthest_point_sample, @@ -65,7 +66,6 @@ from .tin_shift import TINShift, tin_shift from .upfirdn2d import upfirdn2d from .voxelize import Voxelization, voxelization -from .diff_iou_rotated import diff_iou_rotated_sort_vertices __all__ = [ 'bbox_overlaps', 'CARAFE', 'CARAFENaive', 'CARAFEPack', 'carafe', @@ -97,5 +97,5 @@ 'SparseMaxPool2d', 'SparseMaxPool3d', 'SparseConvTensor', 'scatter_nd', 'points_in_boxes_part', 'points_in_boxes_cpu', 'points_in_boxes_all', 'points_in_polygons', 'min_area_polygons', 'active_rotated_filter', - 'convex_iou', 'convex_giou' + 'convex_iou', 'convex_giou', 'diff_iou_rotated_2d', 'diff_iou_rotated_3d' ] diff --git a/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu index cfb4b69a2e..39e4f4d729 100644 --- a/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu +++ b/mmcv/ops/csrc/pytorch/cuda/diff_iou_rotated_cuda.cu @@ -1,3 +1,5 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa #include #include #include @@ -29,7 +31,7 @@ inline dim3 opt_block_config(int x, int y){ /* compare normalized vertices (vertices around (0,0)) -if vertex1 < vertex2 return ture. +if vertex1 < vertex2 return true. order: minimum at x-aixs, become larger in anti-clockwise direction */ __device__ bool compare_vertices(float x1, float y1, float x2, float y2){ @@ -73,7 +75,7 @@ __global__ void sort_vertices_kernel(int b, int n, int m, int index = threadIdx.x; // index of polygon int stride = blockDim.x; for (int i = index; i #include #include - -#define CHECK_CUDA(x) \ - do { \ - TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor"); \ - } while (0) - -#define CHECK_CONTIGUOUS(x) \ - do { \ - TORCH_CHECK(x.is_contiguous(), #x " must ne a contiguous tensor"); \ - } while (0) - -#define CHECK_IS_INT(x) \ - do { \ - TORCH_CHECK(x.scalar_type()==at::ScalarType::Int, \ - #x " must be a int tensor"); \ - } while (0) - -#define CHECK_IS_FLOAT(x) \ - do { \ - TORCH_CHECK(x.scalar_type()==at::ScalarType::Float, \ - #x " must be a float tensor"); \ - } while (0) - -#define CHECK_IS_BOOL(x) \ - do { \ - TORCH_CHECK(x.scalar_type()==at::ScalarType::Bool, \ - #x " must be a bool tensor"); \ - } while (0) +#include "pytorch_cpp_helper.hpp" #define MAX_NUM_VERT_IDX 9 @@ -41,9 +16,6 @@ at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, CHECK_CUDA(vertices); CHECK_CUDA(mask); CHECK_CUDA(num_valid); - CHECK_IS_FLOAT(vertices); - CHECK_IS_BOOL(mask); - CHECK_IS_INT(num_valid); int b = vertices.size(0); int n = vertices.size(1); @@ -59,4 +31,3 @@ at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, return idx; } - diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index e727794d75..d88b0aa251 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -1,16 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/box_intersection_2d.py # noqa +# Adapted from https://github.com/lilanxiao/Rotated_IoU/blob/master/oriented_iou_loss.py # noqa import torch -from torch import nn from torch.autograd import Function + from ..utils import ext_loader -ext_module = ext_loader.load_ext('_ext', [ - 'diff_iou_rotated_sort_vertices' -]) +EPSILON = 1e-8 +ext_module = ext_loader.load_ext('_ext', ['diff_iou_rotated_sort_vertices']) + class SortVertices(Function): + @staticmethod def forward(ctx, vertices, mask, num_valid): - idx = ext_module.diff_iou_rotated_sort_vertices(vertices, mask, num_valid) + idx = ext_module.diff_iou_rotated_sort_vertices( + vertices, mask, num_valid) ctx.mark_non_differentiable(idx) return idx @@ -18,5 +23,273 @@ def forward(ctx, vertices, mask, num_valid): def backward(ctx, gradout): return () -diff_iou_rotated_sort_vertices = SortVertices.apply +def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor): + """Find intersection points of rectangles. + Convention: if two edges are collinear, there is no intersection point + + Args: + corners1 (torch.Tensor): B, N, 4, 2 + corners2 (torch.Tensor): B, N, 4, 2 + + Returns: + intersections (torch.Tensor): B, N, 4, 4, 2 + mask (torch.Tensor) : B, N, 4, 4; bool + """ + # build edges from corners + # B, N, 4, 4: Batch, Box, edge, point + line1 = torch.cat([corners1, corners1[:, :, [1, 2, 3, 0], :]], dim=3) + line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3) + # duplicate data to pair each edges from the boxes + # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point + line1_ext = line1.unsqueeze(3).repeat([1, 1, 1, 4, 1]) + line2_ext = line2.unsqueeze(2).repeat([1, 1, 4, 1, 1]) + x1 = line1_ext[..., 0] + y1 = line1_ext[..., 1] + x2 = line1_ext[..., 2] + y2 = line1_ext[..., 3] + x3 = line2_ext[..., 0] + y3 = line2_ext[..., 1] + x4 = line2_ext[..., 2] + y4 = line2_ext[..., 3] + # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection + num = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) + den_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4) + t = den_t / num + t[num == .0] = -1. + mask_t = (t > 0) * (t < 1) # intersection on line segment 1 + den_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3) + u = -den_u / num + u[num == .0] = -1. + mask_u = (u > 0) * (u < 1) # intersection on line segment 2 + mask = mask_t * mask_u + # overwrite with EPSILON. otherwise numerically unstable + t = den_t / (num + EPSILON) + intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], + dim=-1) + intersections = intersections * mask.float().unsqueeze(-1) + return intersections, mask + + +def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor): + """Check if corners of box1 lie in box2. + Convention: if a corner is exactly on the edge of the other box, + it's also a valid point. + + Args: + corners1 (torch.Tensor): (B, N, 4, 2) + corners2 (torch.Tensor): (B, N, 4, 2) + + Returns: + c1_in_2: (B, N, 4) Bool + """ + a = corners2[:, :, 0:1, :] # (B, N, 1, 2) + b = corners2[:, :, 1:2, :] # (B, N, 1, 2) + d = corners2[:, :, 3:4, :] # (B, N, 1, 2) + ab = b - a # (B, N, 1, 2) + am = corners1 - a # (B, N, 4, 2) + ad = d - a # (B, N, 1, 2) + p_ab = torch.sum(ab * am, dim=-1) # (B, N, 4) + norm_ab = torch.sum(ab * ab, dim=-1) # (B, N, 1) + p_ad = torch.sum(ad * am, dim=-1) # (B, N, 4) + norm_ad = torch.sum(ad * ad, dim=-1) # (B, N, 1) + # NOTE: the expression looks ugly but is stable if the two boxes + # are exactly the same also stable with different scale of bboxes + cond1 = (p_ab / norm_ab > -1e-6) * (p_ab / norm_ab < 1 + 1e-6) # (B, N, 4) + cond2 = (p_ad / norm_ad > -1e-6) * (p_ad / norm_ad < 1 + 1e-6) # (B, N, 4) + return cond1 * cond2 + + +def box_in_box_th(corners1: torch.Tensor, corners2: torch.Tensor): + """Check if corners of two boxes lie in each other. + + Args: + corners1 (torch.Tensor): (B, N, 4, 2) + corners2 (torch.Tensor): (B, N, 4, 2) + + Returns: + c1_in_2: (B, N, 4) Bool. i-th corner of box1 in box2 + c2_in_1: (B, N, 4) Bool. i-th corner of box2 in box1 + """ + c1_in_2 = box1_in_box2(corners1, corners2) + c2_in_1 = box1_in_box2(corners2, corners1) + return c1_in_2, c2_in_1 + + +def build_vertices(corners1: torch.Tensor, corners2: torch.Tensor, + c1_in_2: torch.Tensor, c2_in_1: torch.Tensor, + inters: torch.Tensor, mask_inter: torch.Tensor): + """Find vertices of intersection area. + + Args: + corners1 (torch.Tensor): (B, N, 4, 2) + corners2 (torch.Tensor): (B, N, 4, 2) + c1_in_2 (torch.Tensor): Bool, (B, N, 4) + c2_in_1 (torch.Tensor): Bool, (B, N, 4) + inters (torch.Tensor): (B, N, 4, 4, 2) + mask_inter (torch.Tensor): (B, N, 4, 4) + + Returns: + vertices (torch.Tensor): (B, N, 24, 2) vertices of intersection area; + only some elements are valid + mask (torch.Tensor): (B, N, 24) indicates valid elements in vertices + """ + # NOTE: inter has elements equals zero and has zeros gradient + # (masked by multiplying with 0); can be used as trick + B = corners1.size()[0] + N = corners1.size()[1] + # (B, N, 4 + 4 + 16, 2) + vertices = torch.cat( + [corners1, corners2, inters.view([B, N, -1, 2])], dim=2) + # Bool (B, N, 4 + 4 + 16) + mask = torch.cat([c1_in_2, c2_in_1, mask_inter.view([B, N, -1])], dim=2) + return vertices, mask + + +def sort_indices(vertices: torch.Tensor, mask: torch.Tensor): + """Sort indices. + + Args: + vertices (torch.Tensor): float (B, N, 24, 2) + mask (torch.Tensor): bool (B, N, 24) + + Returns: + sorted_index: bool (B, N, 9) + + Note: + why 9? the polygon has maximal 8 vertices. + +1 to duplicate the first element. + the index should have following structure: + (A, B, C, ... , A, X, X, X) + and X indicates the index of arbitrary elements in the last + 16 (intersections not corners) with + value 0 and mask False. (cause they have zero value and zero gradient) + """ + num_valid = torch.sum(mask.int(), dim=2).int() # (B, N) + mean = torch.sum( + vertices * mask.float().unsqueeze(-1), dim=2, + keepdim=True) / num_valid.unsqueeze(-1).unsqueeze(-1) + vertices_normalized = vertices - mean # normalization makes sorting easier + return SortVertices.apply(vertices_normalized, mask, num_valid).long() + + +def calculate_area(idx_sorted: torch.Tensor, vertices: torch.Tensor): + """Calculate area of intersection. + + Args: + idx_sorted (torch.Tensor): (B, N, 9) + vertices (torch.Tensor): (B, N, 24, 2) + + return: + area: (B, N), area of intersection + selected: (B, N, 9, 2), vertices of polygon with zero padding + """ + idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2]) + selected = torch.gather(vertices, 2, idx_ext) + total = selected[:, :, 0:-1, 0] * selected[:, :, 1:, 1] \ + - selected[:, :, 0:-1, 1] * selected[:, :, 1:, 0] + total = torch.sum(total, dim=2) + area = torch.abs(total) / 2 + return area, selected + + +def oriented_box_intersection_2d(corners1: torch.Tensor, + corners2: torch.Tensor): + """Calculate intersection area of 2d rectangles. + + Args: + corners1 (torch.Tensor): (B, N, 4, 2) + corners2 (torch.Tensor): (B, N, 4, 2) + + Returns: + area: (B, N), area of intersection + selected: (B, N, 9, 2), vertices of polygon with zero padding + """ + inters, mask_inter = box_intersection_th(corners1, corners2) + c12, c21 = box_in_box_th(corners1, corners2) + vertices, mask = build_vertices(corners1, corners2, c12, c21, inters, + mask_inter) + sorted_indices = sort_indices(vertices, mask) + return calculate_area(sorted_indices, vertices) + + +def box2corners_th(box: torch.Tensor): + """Convert box coordinate to corners. + + Args: + box (torch.Tensor): (B, N, 5) with x, y, w, h, alpha + + Returns: + torch.Tensor: (B, N, 4, 2) corners + """ + B = box.size()[0] + x = box[..., 0:1] + y = box[..., 1:2] + w = box[..., 2:3] + h = box[..., 3:4] + alpha = box[..., 4:5] # (B, N, 1) + x4 = torch.FloatTensor([0.5, -0.5, + -0.5, 0.5]).unsqueeze(0).unsqueeze(0).to( + box.device) # (1,1,4) + x4 = x4 * w # (B, N, 4) + y4 = torch.FloatTensor([0.5, 0.5, -0.5, + -0.5]).unsqueeze(0).unsqueeze(0).to(box.device) + y4 = y4 * h # (B, N, 4) + corners = torch.stack([x4, y4], dim=-1) # (B, N, 4, 2) + sin = torch.sin(alpha) + cos = torch.cos(alpha) + row1 = torch.cat([cos, sin], dim=-1) + row2 = torch.cat([-sin, cos], dim=-1) # (B, N, 2) + rot_T = torch.stack([row1, row2], dim=-2) # (B, N, 2, 2) + rotated = torch.bmm(corners.view([-1, 4, 2]), rot_T.view([-1, 2, 2])) + rotated = rotated.view([B, -1, 4, 2]) # (B * N, 4, 2) -> (B, N, 4, 2) + rotated[..., 0] += x + rotated[..., 1] += y + return rotated + + +def diff_iou_rotated_2d(box1: torch.Tensor, box2: torch.Tensor): + """Calculate differentiable iou of 2d boxes. + Args: + box1 (torch.Tensor): (B, N, 5) + box2 (torch.Tensor): (B, N, 5) + + Returns: + iou (torch.Tensor): (B, N) + """ + corners1 = box2corners_th(box1) + corners2 = box2corners_th(box2) + inter_area, _ = oriented_box_intersection_2d(corners1, corners2) # (B, N) + area1 = box1[:, :, 2] * box1[:, :, 3] + area2 = box2[:, :, 2] * box2[:, :, 3] + u = area1 + area2 - inter_area + iou = inter_area / u + return iou + + +def diff_iou_rotated_3d(box3d1: torch.Tensor, box3d2: torch.Tensor): + """Calculate differentiable iou of 3d boxes. + + Args: + box3d1 (torch.Tensor): (B, N, 3+3+1), (x,y,z,w,h,l,alpha) + box3d2 (torch.Tensor): (B, N, 3+3+1), (x,y,z,w,h,l,alpha) + + Returns: + iou (torch.Tensor): (B, N) + """ + box1 = box3d1[..., [0, 1, 3, 4, 6]] # 2d box + box2 = box3d2[..., [0, 1, 3, 4, 6]] + corners1 = box2corners_th(box1) + corners2 = box2corners_th(box2) + inter_area, _ = oriented_box_intersection_2d(corners1, corners2) + zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5 + zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5 + zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5 + zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5 + z_overlap = (torch.min(zmax1, zmax2) - + torch.max(zmin1, zmin2)).clamp_min(0.) + intersection_3d = inter_area * z_overlap + v1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5] + v2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5] + u3d = v1 + v2 - intersection_3d + return intersection_3d / u3d diff --git a/tests/test_ops/test_diff_iou_rotated.py b/tests/test_ops/test_diff_iou_rotated.py new file mode 100644 index 0000000000..a48ea10cf1 --- /dev/null +++ b/tests/test_ops/test_diff_iou_rotated.py @@ -0,0 +1,58 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import pytest +import torch + +from mmcv.ops import diff_iou_rotated_2d, diff_iou_rotated_3d + + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_diff_iou_rotated_2d(): + np_boxes1 = np.asarray([[ + [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0] + ]], dtype=np.float32) + np_boxes2 = np.asarray([[ + [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., np.pi / 2], + [0.5, 0.5, 1., 1., np.pi / 4], + [1., 1., 1., 1., .0], + [1.5, 1.5, 1., 1., .0] + ]], dtype=np.float32) + + boxes1 = torch.from_numpy(np_boxes1).cuda() + boxes2 = torch.from_numpy(np_boxes2).cuda() + + np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]]) + ious = diff_iou_rotated_2d(boxes1, boxes2) + assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) + +@pytest.mark.skipif( + not torch.cuda.is_available(), reason='requires CUDA support') +def test_diff_iou_rotated_3d(): + np_boxes1 = np.asarray([[ + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0] + ]], dtype=np.float32) + np_boxes2 = np.asarray([[ + [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 2., np.pi / 2], + [.5, .5, .5, 1., 1., 1., np.pi / 4], + [1., 1., 1., 1., 1., 1., .0], + [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0] + ]], dtype=np.float32) + + boxes1 = torch.from_numpy(np_boxes1).cuda() + boxes2 = torch.from_numpy(np_boxes2).cuda() + + np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]]) + ious = diff_iou_rotated_3d(boxes1, boxes2) + assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) + From c18a8ce644d50fe0932898c7a64f0cbfbc29b9ea Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sat, 2 Apr 2022 22:57:53 +0300 Subject: [PATCH 03/13] fix lint for test --- tests/test_ops/test_diff_iou_rotated.py | 51 ++++++++++--------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/test_ops/test_diff_iou_rotated.py b/tests/test_ops/test_diff_iou_rotated.py index a48ea10cf1..01e05551b0 100644 --- a/tests/test_ops/test_diff_iou_rotated.py +++ b/tests/test_ops/test_diff_iou_rotated.py @@ -9,45 +9,37 @@ @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') def test_diff_iou_rotated_2d(): - np_boxes1 = np.asarray([[ - [0.5, 0.5, 1., 1., .0], - [0.5, 0.5, 1., 1., .0], - [0.5, 0.5, 1., 1., .0], - [0.5, 0.5, 1., 1., .0], - [0.5, 0.5, 1., 1., .0] - ]], dtype=np.float32) - np_boxes2 = np.asarray([[ - [0.5, 0.5, 1., 1., .0], - [0.5, 0.5, 1., 1., np.pi / 2], - [0.5, 0.5, 1., 1., np.pi / 4], - [1., 1., 1., 1., .0], - [1.5, 1.5, 1., 1., .0] - ]], dtype=np.float32) + np_boxes1 = np.asarray([[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., .0], + [0.5, 0.5, 1., 1., .0]]], + dtype=np.float32) + np_boxes2 = np.asarray( + [[[0.5, 0.5, 1., 1., .0], [0.5, 0.5, 1., 1., np.pi / 2], + [0.5, 0.5, 1., 1., np.pi / 4], [1., 1., 1., 1., .0], + [1.5, 1.5, 1., 1., .0]]], + dtype=np.float32) boxes1 = torch.from_numpy(np_boxes1).cuda() boxes2 = torch.from_numpy(np_boxes2).cuda() - + np_expect_ious = np.asarray([[1., 1., .7071, 1 / 7, .0]]) ious = diff_iou_rotated_2d(boxes1, boxes2) assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) + @pytest.mark.skipif( not torch.cuda.is_available(), reason='requires CUDA support') def test_diff_iou_rotated_3d(): - np_boxes1 = np.asarray([[ - [.5, .5, .5, 1., 1., 1., .0], - [.5, .5, .5, 1., 1., 1., .0], - [.5, .5, .5, 1., 1., 1., .0], - [.5, .5, .5, 1., 1., 1., .0], - [.5, .5, .5, 1., 1., 1., .0] - ]], dtype=np.float32) - np_boxes2 = np.asarray([[ - [.5, .5, .5, 1., 1., 1., .0], - [.5, .5, .5, 1., 1., 2., np.pi / 2], - [.5, .5, .5, 1., 1., 1., np.pi / 4], - [1., 1., 1., 1., 1., 1., .0], - [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0] - ]], dtype=np.float32) + np_boxes1 = np.asarray( + [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 1., .0], + [.5, .5, .5, 1., 1., 1., .0]]], + dtype=np.float32) + np_boxes2 = np.asarray( + [[[.5, .5, .5, 1., 1., 1., .0], [.5, .5, .5, 1., 1., 2., np.pi / 2], + [.5, .5, .5, 1., 1., 1., np.pi / 4], [1., 1., 1., 1., 1., 1., .0], + [-1.5, -1.5, -1.5, 2.5, 2.5, 2.5, .0]]], + dtype=np.float32) boxes1 = torch.from_numpy(np_boxes1).cuda() boxes2 = torch.from_numpy(np_boxes2).cuda() @@ -55,4 +47,3 @@ def test_diff_iou_rotated_3d(): np_expect_ious = np.asarray([[1., .5, .7071, 1 / 15, .0]]) ious = diff_iou_rotated_3d(boxes1, boxes2) assert np.allclose(ious.cpu().numpy(), np_expect_ious, atol=1e-4) - From 7a9db895b0c290c65c3ec453096da297076eb7b4 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Sun, 3 Apr 2022 22:59:34 +0300 Subject: [PATCH 04/13] disable cpu build --- mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp | 55 +++++++++++++--------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp index 4c1e60defc..8bda4f406c 100644 --- a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp @@ -1,33 +1,42 @@ // Copyright (c) OpenMMLab. All rights reserved // Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert.cpp # noqa -#include -#include -#include #include "pytorch_cpp_helper.hpp" +#ifdef MMCV_WITH_CUDA +#include + #define MAX_NUM_VERT_IDX 9 void sort_vertices_wrapper(int b, int n, int m, const float *vertices, const bool *mask, const int *num_valid, int* idx); +#endif at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid){ - CHECK_CONTIGUOUS(vertices); - CHECK_CONTIGUOUS(mask); - CHECK_CONTIGUOUS(num_valid); - CHECK_CUDA(vertices); - CHECK_CUDA(mask); - CHECK_CUDA(num_valid); - - int b = vertices.size(0); - int n = vertices.size(1); - int m = vertices.size(2); - at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX}, - at::device(vertices.device()).dtype(at::ScalarType::Int)); - - // fix issue with multi-gpu (kernel only works for cuda:0) - const at::cuda::OptionalCUDAGuard device_guard(device_of(idx)); - - sort_vertices_wrapper(b, n, m, vertices.data_ptr(), mask.data_ptr(), - num_valid.data_ptr(), idx.data_ptr()); - - return idx; + if (vertices.device().is_cuda()) { +#ifdef MMCV_WITH_CUDA + CHECK_CONTIGUOUS(vertices); + CHECK_CONTIGUOUS(mask); + CHECK_CONTIGUOUS(num_valid); + CHECK_CUDA(vertices); + CHECK_CUDA(mask); + CHECK_CUDA(num_valid); + + int b = vertices.size(0); + int n = vertices.size(1); + int m = vertices.size(2); + at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX}, + at::device(vertices.device()).dtype(at::ScalarType::Int)); + + // fix issue with multi-gpu (kernel only works for cuda:0) + const at::cuda::OptionalCUDAGuard device_guard(device_of(idx)); + + sort_vertices_wrapper(b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + + return idx; +#else + AT_ERROR("group_points is not compiled with GPU support"); +#endif + } else { + AT_ERROR("group_points is not implemented on CPU"); + } } From 3b156af68933b7948d94a79f8f5013e2359de910 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Tue, 5 Apr 2022 09:39:13 +0300 Subject: [PATCH 05/13] refactor files structure --- .../cuda/diff_iou_rotated_cuda_kernel.cuh | 155 +++++++++++++++ mmcv/ops/csrc/pytorch/cuda/cudabind.cpp | 13 ++ .../pytorch/cuda/diff_iou_rotated_cuda.cu | 184 +++--------------- mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp | 43 +--- mmcv/ops/csrc/pytorch/pybind.cpp | 9 +- mmcv/ops/diff_iou_rotated.py | 5 +- 6 files changed, 211 insertions(+), 198 deletions(-) create mode 100644 mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh new file mode 100644 index 0000000000..6808c7eaae --- /dev/null +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -0,0 +1,155 @@ +// Copyright (c) OpenMMLab. All rights reserved +// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa +#include +#include +#include +#include +#include +#include +#include +#include + +#define TOTAL_THREADS 512 +#define MAX_NUM_VERT_IDX 9 +#define INTERSECTION_OFFSET 8 +#define EPSILON 1e-8 + + +inline int opt_n_thread(int work_size){ + const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); + return max(min(1< 0 && y2 < 0) + return true; + if (y1 < 0 && y2 > 0) + return false; + + float n1 = x1*x1 + y1*y1 + EPSILON; + float n2 = x2*x2 + y2*y2 + EPSILON; + + if (y1 > 0 && y2 > 0){ + if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 > EPSILON) + return true; + else + return false; + } + if (y1 < 0 && y2 < 0) { + if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 < EPSILON) + return true; + else + return false; + } +} + +__global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( + int b, int n, int m, const float *__restrict__ vertices, + const bool *__restrict__ mask, const int *__restrict__ num_valid, + int *__restrict__ idx){ + int batch_idx = blockIdx.x; + vertices += batch_idx * n * m *2; + mask += batch_idx * n * m; + num_valid += batch_idx * n; + idx += batch_idx * n * MAX_NUM_VERT_IDX; + + int index = threadIdx.x; // index of polygon + int stride = blockDim.x; + for (int i = index; i -#include -#include -#include -#include -#include -#include -#include +#include "diff_iou_rotated_cuda_kernel.cuh" +#include "pytorch_cuda_helper.hpp" +#include "pytorch_cpp_helper.hpp" -#define TOTAL_THREADS 512 -#define MAX_NUM_VERT_IDX 9 -#define INTERSECTION_OFFSET 8 -#define EPSILON 1e-8 +// #define MAX_NUM_VERT_IDX 9 - -inline int opt_n_thread(int work_size){ - const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); - return max(min(1< 0 && y2 < 0) - return true; - if (y1 < 0 && y2 > 0) - return false; - - float n1 = x1*x1 + y1*y1 + EPSILON; - float n2 = x2*x2 + y2*y2 + EPSILON; - - if (y1 > 0 && y2 > 0){ - if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 > EPSILON) - return true; - else - return false; - } - if (y1 < 0 && y2 < 0) { - if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 < EPSILON) - return true; - else - return false; - } -} - -__global__ void sort_vertices_kernel(int b, int n, int m, - const float *__restrict__ vertices, - const bool *__restrict__ mask, - const int *__restrict__ num_valid, - int *__restrict__ idx){ - int batch_idx = blockIdx.x; - vertices += batch_idx * n * m *2; - mask += batch_idx * n * m; - num_valid += batch_idx * n; - idx += batch_idx * n * MAX_NUM_VERT_IDX; - - int index = threadIdx.x; // index of polygon - int stride = blockDim.x; - for (int i = index; i>>(b, n, m, vertices, mask, num_valid, idx); + + CHECK_CONTIGUOUS(vertices); + CHECK_CONTIGUOUS(mask); + CHECK_CONTIGUOUS(num_valid); + CHECK_CUDA(vertices); + CHECK_CUDA(mask); + CHECK_CUDA(num_valid); + + int b = vertices.size(0); + int n = vertices.size(1); + int m = vertices.size(2); + at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX}, + at::device(vertices.device()).dtype(at::ScalarType::Int)); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + vertices.scalar_type(), "diff_iou_rotated_sort_vertices_forward_cuda_kernel", ([&] { + diff_iou_rotated_sort_vertices_forward_cuda_kernel + <<>>( + b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + })); AT_CUDA_CHECK(cudaGetLastError()); + + return idx; } diff --git a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp index 8bda4f406c..5699c0e139 100644 --- a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp @@ -1,42 +1,11 @@ // Copyright (c) OpenMMLab. All rights reserved -// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert.cpp # noqa #include "pytorch_cpp_helper.hpp" +#include "pytorch_device_registry.hpp" -#ifdef MMCV_WITH_CUDA -#include - -#define MAX_NUM_VERT_IDX 9 - -void sort_vertices_wrapper(int b, int n, int m, const float *vertices, const bool *mask, const int *num_valid, int* idx); -#endif - -at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, at::Tensor num_valid){ - if (vertices.device().is_cuda()) { -#ifdef MMCV_WITH_CUDA - CHECK_CONTIGUOUS(vertices); - CHECK_CONTIGUOUS(mask); - CHECK_CONTIGUOUS(num_valid); - CHECK_CUDA(vertices); - CHECK_CUDA(mask); - CHECK_CUDA(num_valid); - - int b = vertices.size(0); - int n = vertices.size(1); - int m = vertices.size(2); - at::Tensor idx = torch::zeros({b, n, MAX_NUM_VERT_IDX}, - at::device(vertices.device()).dtype(at::ScalarType::Int)); - - // fix issue with multi-gpu (kernel only works for cuda:0) - const at::cuda::OptionalCUDAGuard device_guard(device_of(idx)); - - sort_vertices_wrapper(b, n, m, vertices.data_ptr(), mask.data_ptr(), - num_valid.data_ptr(), idx.data_ptr()); +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, Tensor num_valid){ + return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, vertices, mask, num_valid); +} - return idx; -#else - AT_ERROR("group_points is not compiled with GPU support"); -#endif - } else { - AT_ERROR("group_points is not implemented on CPU"); - } +Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask, Tensor num_valid){ + return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid); } diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 46fd35a194..18bed3ba84 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -400,8 +400,8 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious); void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output); -at::Tensor diff_iou_rotated_sort_vertices(at::Tensor vertices, at::Tensor mask, - at::Tensor num_valid); +at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, at::Tensor mask, + at::Tensor num_valid); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)", py::arg("input"), @@ -812,7 +812,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("polygons"), py::arg("ious")); m.def("convex_giou", &convex_giou, "convex_giou", py::arg("pointsets"), py::arg("polygons"), py::arg("output")); - m.def("diff_iou_rotated_sort_vertices", &diff_iou_rotated_sort_vertices, - "diff_iou_rotated_sort_vertices", py::arg("vertices"), + m.def("diff_iou_rotated_sort_vertices_forward", + &diff_iou_rotated_sort_vertices_forward, + "diff_iou_rotated_sort_vertices_forward", py::arg("vertices"), py::arg("mask"), py::arg("num_valid")); } diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index d88b0aa251..33c37c6635 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -7,14 +7,15 @@ from ..utils import ext_loader EPSILON = 1e-8 -ext_module = ext_loader.load_ext('_ext', ['diff_iou_rotated_sort_vertices']) +ext_module = ext_loader.load_ext('_ext', + ['diff_iou_rotated_sort_vertices_forward']) class SortVertices(Function): @staticmethod def forward(ctx, vertices, mask, num_valid): - idx = ext_module.diff_iou_rotated_sort_vertices( + idx = ext_module.diff_iou_rotated_sort_vertices_forward( vertices, mask, num_valid) ctx.mark_non_differentiable(idx) return idx From cca9fd1aa3e31ea21098e739b2c2a47e8bbdeddb Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Thu, 7 Apr 2022 23:27:57 +0300 Subject: [PATCH 06/13] fix comments --- .../cuda/diff_iou_rotated_cuda_kernel.cuh | 41 +++++++------------ .../pytorch/cuda/diff_iou_rotated_cuda.cu | 13 ++---- mmcv/ops/diff_iou_rotated.py | 11 ++--- 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh index 6808c7eaae..88a62d57c1 100644 --- a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -1,15 +1,11 @@ // Copyright (c) OpenMMLab. All rights reserved // Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa -#include -#include -#include -#include -#include -#include -#include -#include +#ifdef MMCV_USE_PARROTS +#include "parrots_cuda_helper.hpp" +#else +#include "pytorch_cuda_helper.hpp" +#endif -#define TOTAL_THREADS 512 #define MAX_NUM_VERT_IDX 9 #define INTERSECTION_OFFSET 8 #define EPSILON 1e-8 @@ -17,18 +13,9 @@ inline int opt_n_thread(int work_size){ const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); - return max(min(1< 0 && y2 > 0){ - if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 > EPSILON) + if (diff > EPSILON) return true; else return false; } if (y1 < 0 && y2 < 0) { - if (fabs(x1)*x1/n1 - fabs(x2)*x2/n2 < EPSILON) + if (diff < EPSILON) return true; else return false; @@ -95,6 +83,9 @@ __global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( float x_min = 1; float y_min = -EPSILON; int i_take = 0; + int i2 = idx[i*MAX_NUM_VERT_IDX + j - 1]; + float x2 = vertices[i*m*2 + i2*2 + 0]; + float y2 = vertices[i*m*2 + i2*2 + 1]; for (int k=0; k>>( - b, n, m, vertices.data_ptr(), mask.data_ptr(), - num_valid.data_ptr(), idx.data_ptr()); - })); + diff_iou_rotated_sort_vertices_forward_cuda_kernel + <<>>( + b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); AT_CUDA_CHECK(cudaGetLastError()); return idx; diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index 33c37c6635..ae3f245922 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -58,11 +58,11 @@ def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor): den_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4) t = den_t / num t[num == .0] = -1. - mask_t = (t > 0) * (t < 1) # intersection on line segment 1 + mask_t = (t > 0) & (t < 1) # intersection on line segment 1 den_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3) u = -den_u / num u[num == .0] = -1. - mask_u = (u > 0) * (u < 1) # intersection on line segment 2 + mask_u = (u > 0) & (u < 1) # intersection on line segment 2 mask = mask_t * mask_u # overwrite with EPSILON. otherwise numerically unstable t = den_t / (num + EPSILON) @@ -229,12 +229,9 @@ def box2corners_th(box: torch.Tensor): w = box[..., 2:3] h = box[..., 3:4] alpha = box[..., 4:5] # (B, N, 1) - x4 = torch.FloatTensor([0.5, -0.5, - -0.5, 0.5]).unsqueeze(0).unsqueeze(0).to( - box.device) # (1,1,4) + x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device) x4 = x4 * w # (B, N, 4) - y4 = torch.FloatTensor([0.5, 0.5, -0.5, - -0.5]).unsqueeze(0).unsqueeze(0).to(box.device) + y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device) y4 = y4 * h # (B, N, 4) corners = torch.stack([x4, y4], dim=-1) # (B, N, 4, 2) sin = torch.sin(alpha) From 950edd1c3c22657cf4ae662b777c002ea4cd7009 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Fri, 8 Apr 2022 09:44:47 +0300 Subject: [PATCH 07/13] remove extra .repeat() --- mmcv/ops/diff_iou_rotated.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index ae3f245922..f699f81530 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -43,8 +43,8 @@ def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor): line2 = torch.cat([corners2, corners2[:, :, [1, 2, 3, 0], :]], dim=3) # duplicate data to pair each edges from the boxes # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point - line1_ext = line1.unsqueeze(3).repeat([1, 1, 1, 4, 1]) - line2_ext = line2.unsqueeze(2).repeat([1, 1, 4, 1, 1]) + line1_ext = line1.unsqueeze(3) + line2_ext = line2.unsqueeze(2) x1 = line1_ext[..., 0] y1 = line1_ext[..., 1] x2 = line1_ext[..., 2] From 6dd99118c9d4cdf9b1243e05a854aae7d2c3add5 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Fri, 8 Apr 2022 20:52:37 +0300 Subject: [PATCH 08/13] add comment --- mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh index 88a62d57c1..6c63656c55 100644 --- a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -78,6 +78,7 @@ __global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( } else { // sort the valid vertices // note the number of valid vertices is known + // note: check that num_valid[i] < MAX_NUM_VERT_IDX for (int j=0; j Date: Mon, 11 Apr 2022 23:00:34 +0300 Subject: [PATCH 09/13] fix j-1 bug; update doc --- docs/en/understand_mmcv/ops.md | 1 + docs/zh_cn/understand_mmcv/ops.md | 1 + .../cuda/diff_iou_rotated_cuda_kernel.cuh | 23 ++++++++----------- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index 4ff81374ba..f100558fc4 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc. - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DiffIoURotated - DynamicScatter - GatherPoints - FurthestPointSample diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 0d0a9c9103..776b05536f 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -13,6 +13,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子 - CornerPool - Deformable Convolution v1/v2 - Deformable RoIPool +- DiffIoURotated - DynamicScatter - GatherPoints - FurthestPointSample diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh index 6c63656c55..4fc6fd6f06 100644 --- a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -84,22 +84,19 @@ __global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( float x_min = 1; float y_min = -EPSILON; int i_take = 0; - int i2 = idx[i*MAX_NUM_VERT_IDX + j - 1]; - float x2 = vertices[i*m*2 + i2*2 + 0]; - float y2 = vertices[i*m*2 + i2*2 + 1]; + int i2; + float x2, y2; + if (j != 0) { + i2 = idx[i*MAX_NUM_VERT_IDX + j - 1]; + x2 = vertices[i*m*2 + i2*2 + 0]; + y2 = vertices[i*m*2 + i2*2 + 1]; + } for (int k=0; k Date: Tue, 12 Apr 2022 13:10:11 +0300 Subject: [PATCH 10/13] fix clang lint --- .../cuda/diff_iou_rotated_cuda_kernel.cuh | 212 +++++++++--------- mmcv/ops/csrc/pytorch/cuda/cudabind.cpp | 21 +- .../pytorch/cuda/diff_iou_rotated_cuda.cu | 48 ++-- mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp | 11 +- mmcv/ops/csrc/pytorch/pybind.cpp | 3 +- 5 files changed, 151 insertions(+), 144 deletions(-) diff --git a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh index 4fc6fd6f06..3ee1814e12 100644 --- a/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh +++ b/mmcv/ops/csrc/common/cuda/diff_iou_rotated_cuda_kernel.cuh @@ -1,5 +1,6 @@ // Copyright (c) OpenMMLab. All rights reserved -// Adapted from https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa +// Adapted from +// https://github.com/lilanxiao/Rotated_IoU/cuda_op/sort_vert_kernel.cu # noqa #ifdef MMCV_USE_PARROTS #include "parrots_cuda_helper.hpp" #else @@ -10,10 +11,9 @@ #define INTERSECTION_OFFSET 8 #define EPSILON 1e-8 - -inline int opt_n_thread(int work_size){ - const int pow_2 = std::log(static_cast(work_size)) / std::log(2.0); - return max(min(1<(work_size)) / std::log(2.0); + return max(min(1 << pow_2, THREADS_PER_BLOCK), 1); } /* @@ -21,120 +21,116 @@ compare normalized vertices (vertices around (0,0)) if vertex1 < vertex2 return true. order: minimum at x-aixs, become larger in anti-clockwise direction */ -__device__ bool compare_vertices(float x1, float y1, float x2, float y2){ - - if (fabs(x1-x2) 0 && y2 < 0) - return true; - if (y1 < 0 && y2 > 0) - return false; + if (y1 > 0 && y2 < 0) return true; + if (y1 < 0 && y2 > 0) return false; - float n1 = x1*x1 + y1*y1 + EPSILON; - float n2 = x2*x2 + y2*y2 + EPSILON; - float diff = fabs(x1)*x1/n1 - fabs(x2)*x2/n2; + float n1 = x1 * x1 + y1 * y1 + EPSILON; + float n2 = x2 * x2 + y2 * y2 + EPSILON; + float diff = fabs(x1) * x1 / n1 - fabs(x2) * x2 / n2; - if (y1 > 0 && y2 > 0){ - if (diff > EPSILON) - return true; - else - return false; - } - if (y1 < 0 && y2 < 0) { - if (diff < EPSILON) - return true; - else - return false; - } + if (y1 > 0 && y2 > 0) { + if (diff > EPSILON) + return true; + else + return false; + } + if (y1 < 0 && y2 < 0) { + if (diff < EPSILON) + return true; + else + return false; + } } __global__ void diff_iou_rotated_sort_vertices_forward_cuda_kernel( - int b, int n, int m, const float *__restrict__ vertices, - const bool *__restrict__ mask, const int *__restrict__ num_valid, - int *__restrict__ idx){ - int batch_idx = blockIdx.x; - vertices += batch_idx * n * m *2; - mask += batch_idx * n * m; - num_valid += batch_idx * n; - idx += batch_idx * n * MAX_NUM_VERT_IDX; + int b, int n, int m, const float *__restrict__ vertices, + const bool *__restrict__ mask, const int *__restrict__ num_valid, + int *__restrict__ idx) { + int batch_idx = blockIdx.x; + vertices += batch_idx * n * m * 2; + mask += batch_idx * n * m; + num_valid += batch_idx * n; + idx += batch_idx * n * MAX_NUM_VERT_IDX; - int index = threadIdx.x; // index of polygon - int stride = blockDim.x; - for (int i = index; i>>( - b, n, m, vertices.data_ptr(), mask.data_ptr(), - num_valid.data_ptr(), idx.data_ptr()); - AT_CUDA_CHECK(cudaGetLastError()); + diff_iou_rotated_sort_vertices_forward_cuda_kernel<<>>( + b, n, m, vertices.data_ptr(), mask.data_ptr(), + num_valid.data_ptr(), idx.data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); - return idx; + return idx; } diff --git a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp index 5699c0e139..2361b7fbe5 100644 --- a/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/diff_iou_rotated.cpp @@ -2,10 +2,13 @@ #include "pytorch_cpp_helper.hpp" #include "pytorch_device_registry.hpp" -Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, Tensor num_valid){ - return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, vertices, mask, num_valid); +Tensor diff_iou_rotated_sort_vertices_forward_impl(Tensor vertices, Tensor mask, + Tensor num_valid) { + return DISPATCH_DEVICE_IMPL(diff_iou_rotated_sort_vertices_forward_impl, + vertices, mask, num_valid); } -Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask, Tensor num_valid){ - return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid); +Tensor diff_iou_rotated_sort_vertices_forward(Tensor vertices, Tensor mask, + Tensor num_valid) { + return diff_iou_rotated_sort_vertices_forward_impl(vertices, mask, num_valid); } diff --git a/mmcv/ops/csrc/pytorch/pybind.cpp b/mmcv/ops/csrc/pytorch/pybind.cpp index 18bed3ba84..b53ef3fb10 100644 --- a/mmcv/ops/csrc/pytorch/pybind.cpp +++ b/mmcv/ops/csrc/pytorch/pybind.cpp @@ -400,7 +400,8 @@ void convex_iou(const Tensor pointsets, const Tensor polygons, Tensor ious); void convex_giou(const Tensor pointsets, const Tensor polygons, Tensor output); -at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, at::Tensor mask, +at::Tensor diff_iou_rotated_sort_vertices_forward(at::Tensor vertices, + at::Tensor mask, at::Tensor num_valid); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { From 24b735c7abe5c648f4847f2563cb47f2a0a27079 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Wed, 13 Apr 2022 13:53:55 +0300 Subject: [PATCH 11/13] update docstrings --- mmcv/ops/diff_iou_rotated.py | 145 ++++++++++++++++++----------------- 1 file changed, 74 insertions(+), 71 deletions(-) diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index f699f81530..baf1e82c6b 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -25,17 +25,18 @@ def backward(ctx, gradout): return () -def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor): +def box_intersection(corners1, corners2): """Find intersection points of rectangles. - Convention: if two edges are collinear, there is no intersection point + Convention: if two edges are collinear, there is no intersection point. Args: - corners1 (torch.Tensor): B, N, 4, 2 - corners2 (torch.Tensor): B, N, 4, 2 + corners1 (Tensor): (B, N, 4, 2) Corners of the first box. + corners2 (Tensor): (B, N, 4, 2) Corners of the second box. Returns: - intersections (torch.Tensor): B, N, 4, 4, 2 - mask (torch.Tensor) : B, N, 4, 4; bool + Tuple: + - Tensor: (B, N, 4, 4, 2) Intersections. + - Tensor: (B, N, 4, 4) Valid intersections mask. """ # build edges from corners # B, N, 4, 4: Batch, Box, edge, point @@ -72,17 +73,17 @@ def box_intersection_th(corners1: torch.Tensor, corners2: torch.Tensor): return intersections, mask -def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor): +def box1_in_box2(corners1, corners2): """Check if corners of box1 lie in box2. Convention: if a corner is exactly on the edge of the other box, it's also a valid point. Args: - corners1 (torch.Tensor): (B, N, 4, 2) - corners2 (torch.Tensor): (B, N, 4, 2) + corners1 (Tensor): (B, N, 4, 2) Corners of the first box. + corners2 (Tensor): (B, N, 4, 2) Corners of the second box. Returns: - c1_in_2: (B, N, 4) Bool + Tensor: (B, N, 4) Intersection. """ a = corners2[:, :, 0:1, :] # (B, N, 1, 2) b = corners2[:, :, 1:2, :] # (B, N, 1, 2) @@ -101,39 +102,39 @@ def box1_in_box2(corners1: torch.Tensor, corners2: torch.Tensor): return cond1 * cond2 -def box_in_box_th(corners1: torch.Tensor, corners2: torch.Tensor): +def box_in_box(corners1, corners2): """Check if corners of two boxes lie in each other. Args: - corners1 (torch.Tensor): (B, N, 4, 2) - corners2 (torch.Tensor): (B, N, 4, 2) + corners1 (Tensor): (B, N, 4, 2) Corners of the first box. + corners2 (Tensor): (B, N, 4, 2) Corners of the second box. Returns: - c1_in_2: (B, N, 4) Bool. i-th corner of box1 in box2 - c2_in_1: (B, N, 4) Bool. i-th corner of box2 in box1 + Tuple: + - Tensor: (B, N, 4) True if i-th corner of box1 is in box2. + - Tensor: (B, N, 4) True if i-th corner of box2 is in box1. """ c1_in_2 = box1_in_box2(corners1, corners2) c2_in_1 = box1_in_box2(corners2, corners1) return c1_in_2, c2_in_1 -def build_vertices(corners1: torch.Tensor, corners2: torch.Tensor, - c1_in_2: torch.Tensor, c2_in_1: torch.Tensor, - inters: torch.Tensor, mask_inter: torch.Tensor): +def build_vertices(corners1, corners2, c1_in_2, c2_in_1, inters, mask_inter): """Find vertices of intersection area. Args: - corners1 (torch.Tensor): (B, N, 4, 2) - corners2 (torch.Tensor): (B, N, 4, 2) - c1_in_2 (torch.Tensor): Bool, (B, N, 4) - c2_in_1 (torch.Tensor): Bool, (B, N, 4) - inters (torch.Tensor): (B, N, 4, 4, 2) - mask_inter (torch.Tensor): (B, N, 4, 4) + corners1 (Tensor): (B, N, 4, 2) Corners of the first box. + corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2. + c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1. + inters (Tensor): (B, N, 4, 4, 2) Intersections. + mask_inter (Tensor): (B, N, 4, 4) Valid intersections mask. Returns: - vertices (torch.Tensor): (B, N, 24, 2) vertices of intersection area; - only some elements are valid - mask (torch.Tensor): (B, N, 24) indicates valid elements in vertices + Tuple: + - Tensor: (B, N, 24, 2) Vertices of intersection area; + only some elements are valid. + - Tensor: (B, N, 24) Mask of valid elements in vertices. """ # NOTE: inter has elements equals zero and has zeros gradient # (masked by multiplying with 0); can be used as trick @@ -147,24 +148,24 @@ def build_vertices(corners1: torch.Tensor, corners2: torch.Tensor, return vertices, mask -def sort_indices(vertices: torch.Tensor, mask: torch.Tensor): +def sort_indices(vertices, mask): """Sort indices. - - Args: - vertices (torch.Tensor): float (B, N, 24, 2) - mask (torch.Tensor): bool (B, N, 24) - - Returns: - sorted_index: bool (B, N, 9) - Note: why 9? the polygon has maximal 8 vertices. +1 to duplicate the first element. the index should have following structure: (A, B, C, ... , A, X, X, X) and X indicates the index of arbitrary elements in the last - 16 (intersections not corners) with - value 0 and mask False. (cause they have zero value and zero gradient) + 16 (intersections not corners) with value 0 and mask False. + (cause they have zero value and zero gradient) + + Args: + vertices (Tensor): (B, N, 24, 2) Box vertices. + mask (Tensor): (B, N, 24) Mask. + + Returns: + Tensor: (B, N, 9) Sorted indices. + """ num_valid = torch.sum(mask.int(), dim=2).int() # (B, N) mean = torch.sum( @@ -174,16 +175,17 @@ def sort_indices(vertices: torch.Tensor, mask: torch.Tensor): return SortVertices.apply(vertices_normalized, mask, num_valid).long() -def calculate_area(idx_sorted: torch.Tensor, vertices: torch.Tensor): +def calculate_area(idx_sorted, vertices): """Calculate area of intersection. Args: - idx_sorted (torch.Tensor): (B, N, 9) - vertices (torch.Tensor): (B, N, 24, 2) + idx_sorted (Tensor): (B, N, 9) Sorted vertex ids. + vertices (Tensor): (B, N, 24, 2) Vertices. - return: - area: (B, N), area of intersection - selected: (B, N, 9, 2), vertices of polygon with zero padding + Returns: + Tuple: + - Tensor (B, N): Area of intersection. + - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding. """ idx_ext = idx_sorted.unsqueeze(-1).repeat([1, 1, 1, 2]) selected = torch.gather(vertices, 2, idx_ext) @@ -194,34 +196,34 @@ def calculate_area(idx_sorted: torch.Tensor, vertices: torch.Tensor): return area, selected -def oriented_box_intersection_2d(corners1: torch.Tensor, - corners2: torch.Tensor): - """Calculate intersection area of 2d rectangles. +def oriented_box_intersection_2d(corners1, corners2): + """Calculate intersection area of 2d rotated boxes. Args: - corners1 (torch.Tensor): (B, N, 4, 2) - corners2 (torch.Tensor): (B, N, 4, 2) + corners1 (Tensor): (B, N, 4, 2) Corners of the first box. + corners2 (Tensor): (B, N, 4, 2) Corners of the second box. Returns: - area: (B, N), area of intersection - selected: (B, N, 9, 2), vertices of polygon with zero padding + Tuple: + - Tensor (B, N): Area of intersection. + - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding. """ - inters, mask_inter = box_intersection_th(corners1, corners2) - c12, c21 = box_in_box_th(corners1, corners2) + inters, mask_inter = box_intersection(corners1, corners2) + c12, c21 = box_in_box(corners1, corners2) vertices, mask = build_vertices(corners1, corners2, c12, c21, inters, mask_inter) sorted_indices = sort_indices(vertices, mask) return calculate_area(sorted_indices, vertices) -def box2corners_th(box: torch.Tensor): - """Convert box coordinate to corners. +def box2corners(box): + """Convert rotated 2d box coordinate to corners. Args: - box (torch.Tensor): (B, N, 5) with x, y, w, h, alpha + box (Tensor): (B, N, 5) with x, y, w, h, alpha. Returns: - torch.Tensor: (B, N, 4, 2) corners + Tensor: (B, N, 4, 2) Corners. """ B = box.size()[0] x = box[..., 0:1] @@ -246,17 +248,18 @@ def box2corners_th(box: torch.Tensor): return rotated -def diff_iou_rotated_2d(box1: torch.Tensor, box2: torch.Tensor): - """Calculate differentiable iou of 2d boxes. +def diff_iou_rotated_2d(box1, box2): + """Calculate differentiable iou of rotated 2d boxes. + Args: - box1 (torch.Tensor): (B, N, 5) - box2 (torch.Tensor): (B, N, 5) + box1 (Tensor): (B, N, 5) First box. + box2 (Tensor): (B, N, 5) Second box. Returns: - iou (torch.Tensor): (B, N) + Tensor: (B, N) IoU. """ - corners1 = box2corners_th(box1) - corners2 = box2corners_th(box2) + corners1 = box2corners(box1) + corners2 = box2corners(box2) inter_area, _ = oriented_box_intersection_2d(corners1, corners2) # (B, N) area1 = box1[:, :, 2] * box1[:, :, 3] area2 = box2[:, :, 2] * box2[:, :, 3] @@ -265,20 +268,20 @@ def diff_iou_rotated_2d(box1: torch.Tensor, box2: torch.Tensor): return iou -def diff_iou_rotated_3d(box3d1: torch.Tensor, box3d2: torch.Tensor): - """Calculate differentiable iou of 3d boxes. +def diff_iou_rotated_3d(box3d1, box3d2): + """Calculate differentiable iou of rotated 3d boxes. Args: - box3d1 (torch.Tensor): (B, N, 3+3+1), (x,y,z,w,h,l,alpha) - box3d2 (torch.Tensor): (B, N, 3+3+1), (x,y,z,w,h,l,alpha) + box3d1 (Tensor): (B, N, 3+3+1) First box (x,y,z,w,h,l,alpha). + box3d2 (Tensor): (B, N, 3+3+1) Second box (x,y,z,w,h,l,alpha). Returns: - iou (torch.Tensor): (B, N) + Tensor: (B, N) IoU. """ box1 = box3d1[..., [0, 1, 3, 4, 6]] # 2d box box2 = box3d2[..., [0, 1, 3, 4, 6]] - corners1 = box2corners_th(box1) - corners2 = box2corners_th(box2) + corners1 = box2corners(box1) + corners2 = box2corners(box2) inter_area, _ = oriented_box_intersection_2d(corners1, corners2) zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5 zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5 From a946822fa34494372fb29ca4ed662512f7cb2796 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Wed, 13 Apr 2022 20:05:47 +0300 Subject: [PATCH 12/13] fix comments --- mmcv/ops/diff_iou_rotated.py | 79 ++++++++++++++++++------------------ 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index baf1e82c6b..7026f6d049 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -30,8 +30,8 @@ def box_intersection(corners1, corners2): Convention: if two edges are collinear, there is no intersection point. Args: - corners1 (Tensor): (B, N, 4, 2) Corners of the first box. - corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. Returns: Tuple: @@ -46,27 +46,21 @@ def box_intersection(corners1, corners2): # (B, N, 4, 4) -> (B, N, 4, 4, 4) : Batch, Box, edge1, edge2, point line1_ext = line1.unsqueeze(3) line2_ext = line2.unsqueeze(2) - x1 = line1_ext[..., 0] - y1 = line1_ext[..., 1] - x2 = line1_ext[..., 2] - y2 = line1_ext[..., 3] - x3 = line2_ext[..., 0] - y3 = line2_ext[..., 1] - x4 = line2_ext[..., 2] - y4 = line2_ext[..., 3] + x1, y1, x2, y2 = line1_ext.split([1, 1, 1, 1], dim=-1) + x3, y3, x4, y4 = line2_ext.split([1, 1, 1, 1], dim=-1) # math: https://en.wikipedia.org/wiki/Line%E2%80%93line_intersection - num = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) - den_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4) - t = den_t / num - t[num == .0] = -1. + numerator = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4) + denumerator_t = (x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4) + t = denumerator_t / numerator + t[numerator == .0] = -1. mask_t = (t > 0) & (t < 1) # intersection on line segment 1 - den_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3) - u = -den_u / num - u[num == .0] = -1. + denumerator_u = (x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3) + u = -denumerator_u / numerator + u[numerator == .0] = -1. mask_u = (u > 0) & (u < 1) # intersection on line segment 2 mask = mask_t * mask_u # overwrite with EPSILON. otherwise numerically unstable - t = den_t / (num + EPSILON) + t = denumerator_t / (numerator + EPSILON) intersections = torch.stack([x1 + t * (x2 - x1), y1 + t * (y2 - y1)], dim=-1) intersections = intersections * mask.float().unsqueeze(-1) @@ -79,26 +73,30 @@ def box1_in_box2(corners1, corners2): it's also a valid point. Args: - corners1 (Tensor): (B, N, 4, 2) Corners of the first box. - corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. Returns: Tensor: (B, N, 4) Intersection. """ + # a, b, c, d - 4 vertices of box2 a = corners2[:, :, 0:1, :] # (B, N, 1, 2) b = corners2[:, :, 1:2, :] # (B, N, 1, 2) d = corners2[:, :, 3:4, :] # (B, N, 1, 2) + # ab, am, ad - vectors between corresponding vertices ab = b - a # (B, N, 1, 2) am = corners1 - a # (B, N, 4, 2) ad = d - a # (B, N, 1, 2) - p_ab = torch.sum(ab * am, dim=-1) # (B, N, 4) + prod_ab = torch.sum(ab * am, dim=-1) # (B, N, 4) norm_ab = torch.sum(ab * ab, dim=-1) # (B, N, 1) - p_ad = torch.sum(ad * am, dim=-1) # (B, N, 4) + prod_ad = torch.sum(ad * am, dim=-1) # (B, N, 4) norm_ad = torch.sum(ad * ad, dim=-1) # (B, N, 1) # NOTE: the expression looks ugly but is stable if the two boxes # are exactly the same also stable with different scale of bboxes - cond1 = (p_ab / norm_ab > -1e-6) * (p_ab / norm_ab < 1 + 1e-6) # (B, N, 4) - cond2 = (p_ad / norm_ad > -1e-6) * (p_ad / norm_ad < 1 + 1e-6) # (B, N, 4) + cond1 = (prod_ab / norm_ab > -1e-6) * (prod_ab / norm_ab < 1 + 1e-6 + ) # (B, N, 4) + cond2 = (prod_ad / norm_ad > -1e-6) * (prod_ad / norm_ad < 1 + 1e-6 + ) # (B, N, 4) return cond1 * cond2 @@ -106,8 +104,8 @@ def box_in_box(corners1, corners2): """Check if corners of two boxes lie in each other. Args: - corners1 (Tensor): (B, N, 4, 2) Corners of the first box. - corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. Returns: Tuple: @@ -123,8 +121,8 @@ def build_vertices(corners1, corners2, c1_in_2, c2_in_1, inters, mask_inter): """Find vertices of intersection area. Args: - corners1 (Tensor): (B, N, 4, 2) Corners of the first box. - corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2. c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1. inters (Tensor): (B, N, 4, 4, 2) Intersections. @@ -200,8 +198,8 @@ def oriented_box_intersection_2d(corners1, corners2): """Calculate intersection area of 2d rotated boxes. Args: - corners1 (Tensor): (B, N, 4, 2) Corners of the first box. - corners2 (Tensor): (B, N, 4, 2) Corners of the second box. + corners1 (Tensor): (B, N, 4, 2) First batch of boxes. + corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. Returns: Tuple: @@ -260,11 +258,12 @@ def diff_iou_rotated_2d(box1, box2): """ corners1 = box2corners(box1) corners2 = box2corners(box2) - inter_area, _ = oriented_box_intersection_2d(corners1, corners2) # (B, N) + intersection, _ = oriented_box_intersection_2d(corners1, + corners2) # (B, N) area1 = box1[:, :, 2] * box1[:, :, 3] area2 = box2[:, :, 2] * box2[:, :, 3] - u = area1 + area2 - inter_area - iou = inter_area / u + union = area1 + area2 - intersection + iou = intersection / union return iou @@ -282,15 +281,15 @@ def diff_iou_rotated_3d(box3d1, box3d2): box2 = box3d2[..., [0, 1, 3, 4, 6]] corners1 = box2corners(box1) corners2 = box2corners(box2) - inter_area, _ = oriented_box_intersection_2d(corners1, corners2) + intersection, _ = oriented_box_intersection_2d(corners1, corners2) zmax1 = box3d1[..., 2] + box3d1[..., 5] * 0.5 zmin1 = box3d1[..., 2] - box3d1[..., 5] * 0.5 zmax2 = box3d2[..., 2] + box3d2[..., 5] * 0.5 zmin2 = box3d2[..., 2] - box3d2[..., 5] * 0.5 z_overlap = (torch.min(zmax1, zmax2) - - torch.max(zmin1, zmin2)).clamp_min(0.) - intersection_3d = inter_area * z_overlap - v1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5] - v2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5] - u3d = v1 + v2 - intersection_3d - return intersection_3d / u3d + torch.max(zmin1, zmin2)).clamp_(min=0.) + intersection_3d = intersection * z_overlap + volume1 = box3d1[..., 3] * box3d1[..., 4] * box3d1[..., 5] + volume2 = box3d2[..., 3] * box3d2[..., 4] * box3d2[..., 5] + union_3d = volume1 + volume2 - intersection_3d + return intersection_3d / union_3d From 21db87dcec3fa76bd57a3ab24e08d95cafcf81f5 Mon Sep 17 00:00:00 2001 From: Danila Rukhovich Date: Thu, 14 Apr 2022 17:13:25 +0300 Subject: [PATCH 13/13] fix comments --- mmcv/ops/diff_iou_rotated.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/mmcv/ops/diff_iou_rotated.py b/mmcv/ops/diff_iou_rotated.py index 7026f6d049..26bdbecf6e 100644 --- a/mmcv/ops/diff_iou_rotated.py +++ b/mmcv/ops/diff_iou_rotated.py @@ -117,7 +117,8 @@ def box_in_box(corners1, corners2): return c1_in_2, c2_in_1 -def build_vertices(corners1, corners2, c1_in_2, c2_in_1, inters, mask_inter): +def build_vertices(corners1, corners2, c1_in_2, c2_in_1, intersections, + valid_mask): """Find vertices of intersection area. Args: @@ -125,8 +126,8 @@ def build_vertices(corners1, corners2, c1_in_2, c2_in_1, inters, mask_inter): corners2 (Tensor): (B, N, 4, 2) Second batch of boxes. c1_in_2 (Tensor): (B, N, 4) True if i-th corner of box1 is in box2. c2_in_1 (Tensor): (B, N, 4) True if i-th corner of box2 is in box1. - inters (Tensor): (B, N, 4, 4, 2) Intersections. - mask_inter (Tensor): (B, N, 4, 4) Valid intersections mask. + intersections (Tensor): (B, N, 4, 4, 2) Intersections. + valid_mask (Tensor): (B, N, 4, 4) Valid intersections mask. Returns: Tuple: @@ -140,9 +141,10 @@ def build_vertices(corners1, corners2, c1_in_2, c2_in_1, inters, mask_inter): N = corners1.size()[1] # (B, N, 4 + 4 + 16, 2) vertices = torch.cat( - [corners1, corners2, inters.view([B, N, -1, 2])], dim=2) + [corners1, corners2, + intersections.view([B, N, -1, 2])], dim=2) # Bool (B, N, 4 + 4 + 16) - mask = torch.cat([c1_in_2, c2_in_1, mask_inter.view([B, N, -1])], dim=2) + mask = torch.cat([c1_in_2, c2_in_1, valid_mask.view([B, N, -1])], dim=2) return vertices, mask @@ -204,12 +206,12 @@ def oriented_box_intersection_2d(corners1, corners2): Returns: Tuple: - Tensor (B, N): Area of intersection. - - Tensor: (B, N, 9, 2) Vertices of polygon with zero padding. + - Tensor (B, N, 9, 2): Vertices of polygon with zero padding. """ - inters, mask_inter = box_intersection(corners1, corners2) + intersections, valid_mask = box_intersection(corners1, corners2) c12, c21 = box_in_box(corners1, corners2) - vertices, mask = build_vertices(corners1, corners2, c12, c21, inters, - mask_inter) + vertices, mask = build_vertices(corners1, corners2, c12, c21, + intersections, valid_mask) sorted_indices = sort_indices(vertices, mask) return calculate_area(sorted_indices, vertices) @@ -224,11 +226,7 @@ def box2corners(box): Tensor: (B, N, 4, 2) Corners. """ B = box.size()[0] - x = box[..., 0:1] - y = box[..., 1:2] - w = box[..., 2:3] - h = box[..., 3:4] - alpha = box[..., 4:5] # (B, N, 1) + x, y, w, h, alpha = box.split([1, 1, 1, 1, 1], dim=-1) x4 = torch.FloatTensor([0.5, -0.5, -0.5, 0.5]).to(box.device) x4 = x4 * w # (B, N, 4) y4 = torch.FloatTensor([0.5, 0.5, -0.5, -0.5]).to(box.device)