Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add knn op from mmdet3d #1354

Merged
merged 14 commits into from
Oct 14, 2021
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- Deformable Convolution v1/v2
- Deformable RoIPool
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask
Expand Down
1 change: 1 addition & 0 deletions docs_zh_CN/understand_mmcv/ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ MMCV 提供了检测、分割等任务中常用的 CUDA 算子
- Deformable Convolution v1/v2
- Deformable RoIPool
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask
Expand Down
2 changes: 2 additions & 0 deletions mmcv/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
Expand Down Expand Up @@ -52,6 +53,7 @@
'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
'upfirdn2d', 'FusedBiasLeakyReLU', 'knn', 'fused_bias_leakyrelu',
'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
Expand Down
93 changes: 93 additions & 0 deletions mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#ifndef KNN_CUDA_KERNEL_CUH
#define KNN_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
DCNSW marked this conversation as resolved.
Show resolved Hide resolved

__device__ void swap_float(float *x, float *y) {
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
float tmp = *x;
*x = *y;
*y = tmp;
}

__device__ void swap_int(int *x, int *y) {
int tmp = *x;
*x = *y;
*y = tmp;
}

__device__ void reheap(float *dist, int *idx, int k) {
int root = 0;
int child = root * 2 + 1;
while (child < k) {
if (child + 1 < k && dist[child + 1] > dist[child]) child++;
if (dist[root] > dist[child]) return;
swap_float(&dist[root], &dist[child]);
swap_int(&idx[root], &idx[child]);
root = child;
child = root * 2 + 1;
}
}

__device__ void heap_sort(float *dist, int *idx, int k) {
int i;
for (i = k - 1; i > 0; i--) {
swap_float(&dist[0], &dist[i]);
swap_int(&idx[0], &idx[i]);
reheap(dist, idx, i);
}
}

// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
template <typename T>
__global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
const T *xyz, const T *new_xyz,
int *__restrict__ idx, T *dist2) {
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;

new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
dist2 += bs_idx * m * nsample + pt_idx * nsample;

T new_x = new_xyz[0];
T new_y = new_xyz[1];
T new_z = new_xyz[2];

float best_dist[100];
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
int best_idx[100];
for (int i = 0; i < nsample; i++) {
best_dist[i] = 1e10;
best_idx[i] = 0;
}
for (int i = 0; i < n; i++) {
T x = xyz[i * 3 + 0];
T y = xyz[i * 3 + 1];
T z = xyz[i * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < best_dist[0]) {
best_dist[0] = d2;
best_idx[0] = i;
reheap(best_dist, best_idx, nsample);
}
}
heap_sort(best_dist, best_idx, nsample);
for (int i = 0; i < nsample; i++) {
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
idx[i] = best_idx[i];
dist2[i] = best_dist[i];
}
}

#endif // KNN_CUDA_KERNEL_CUH
34 changes: 34 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap

#include <cmath>
#include <cstdio>

#include "knn_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2) {
// param new_xyz: (B, m, 3)
// param xyz: (B, n, 3)
// param idx: (B, m, nsample)

at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

dim3 blocks(DIVUP(m, THREADS_PER_BLOCK),
b); // blockIdx.x(col), blockIdx.y(row)
DCNSW marked this conversation as resolved.
Show resolved Hide resolved
dim3 threads(THREADS_PER_BLOCK);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
new_xyz.scalar_type(), "knn_forward_cuda_kernel", [&] {
knn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, n, m, nsample, xyz.data_ptr<scalar_t>(),
new_xyz.data_ptr<scalar_t>(), idx.data_ptr<int>(),
dist2.data_ptr<scalar_t>());
});

AT_CUDA_CHECK(cudaGetLastError());
}
33 changes: 33 additions & 0 deletions mmcv/ops/csrc/pytorch/knn.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2);

void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2);
}
#endif

void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor,
Tensor dist2_tensor) {
if (new_xyz_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(new_xyz_tensor);
CHECK_CUDA_INPUT(xyz_tensor);

knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor,
dist2_tensor);
#else
AT_ERROR("knn is not compiled with GPU support");
#endif
} else {
AT_ERROR("knn is not implemented on CPU");
}
}
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,9 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);

void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);

void masked_im2col_forward(const Tensor im, const Tensor mask_h_idx,
const Tensor mask_w_idx, Tensor col,
const int kernel_h, const int kernel_w,
Expand Down Expand Up @@ -248,6 +251,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
py::arg("dist2_tensor"));
m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward",
py::arg("features"), py::arg("masks"), py::arg("output"),
py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor"));
Expand Down
73 changes: 73 additions & 0 deletions mmcv/ops/knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['knn_forward'])


class KNN(Function):
r"""KNN (CUDA) based on heap data structure.
Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
scene_seg/lib/pointops/src/knnquery_heap>`.
DCNSW marked this conversation as resolved.
Show resolved Hide resolved

Find k-nearest points.
"""

@staticmethod
def forward(ctx,
k: int,
xyz: torch.Tensor,
center_xyz: torch.Tensor = None,
transposed: bool = False) -> torch.Tensor:
"""
Args:
k (int): number of nearest neighbors.
xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
xyz coordinates of the features.
center_xyz (Tensor): (B, npoint, 3) if transposed == False,
else (B, 3, npoint). centers of the knn query.
transposed (bool): whether the input tensors are transposed.
defaults to False. Should not expicitly use this keyword
when calling knn (=KNN.apply), just add the fourth param.

Returns:
Tensor: (B, k, npoint) tensor with the indicies of
the features that form k-nearest neighbours.
"""
assert k > 0

if center_xyz is None:
center_xyz = xyz

if transposed:
xyz = xyz.transpose(2, 1).contiguous()
center_xyz = center_xyz.transpose(2, 1).contiguous()

assert xyz.is_contiguous() # [B, N, 3]
assert center_xyz.is_contiguous() # [B, npoint, 3]

center_xyz_device = center_xyz.get_device()
assert center_xyz_device == xyz.get_device(), \
'center_xyz and xyz should be put on the same device'
if torch.cuda.current_device() != center_xyz_device:
torch.cuda.set_device(center_xyz_device)

B, npoint, _ = center_xyz.shape
N = xyz.shape[1]

idx = center_xyz.new_zeros((B, npoint, k)).int()
dist2 = center_xyz.new_zeros((B, npoint, k)).float()

ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
# idx shape to [B, k, npoint]
idx = idx.transpose(2, 1).contiguous()
ctx.mark_non_differentiable(idx)
return idx

@staticmethod
def backward(ctx, a=None):
return None, None, None


knn = KNN.apply
54 changes: 54 additions & 0 deletions tests/test_ops/test_knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import torch

from mmcv.ops import knn


@pytest.mark.skipif(
not torch.cuda.is_available(), reason='requires CUDA support')
def test_knn():
new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
[-2.2769, 2.7817, -0.2334],
[-0.4003, 2.4666, -0.5116],
[-0.0740, 1.3147, -1.3625],
[-0.0740, 1.3147, -1.3625]],
[[-2.0289, 2.4952, -0.1708],
[-2.0668, 6.0278, -0.4875],
[0.4066, 1.4211, -0.2947],
[-2.0289, 2.4952, -0.1708],
[-2.0289, 2.4952, -0.1708]]]).cuda()

xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634],
[-0.4003, 2.4666,
-0.5116], [-0.5251, 2.4379, -0.8466],
[-0.9691, 1.1418,
-1.3733], [-0.2232, 0.9561, -1.3626],
[-2.2769, 2.7817, -0.2334],
[-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432],
[0.4917, 1.1529, -1.3496]],
[[-2.0289, 2.4952,
-0.1708], [-0.7188, 0.9956, -0.5096],
[-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610],
[0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791],
[-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947],
[0.3220, 1.4447, 0.3548], [-0.9744, 2.3856,
-1.2000]]]).cuda()

idx = knn(5, xyz, new_xyz)
new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1)
dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1)
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
assert torch.all(idx == expected_idx)

idx = knn(5,
xyz.transpose(1, 2).contiguous(),
new_xyz.transpose(1, 2).contiguous(), True)
assert torch.all(idx == expected_idx)

idx = knn(5, xyz, xyz)
xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1)
dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1)
expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
assert torch.all(idx == expected_idx)