[Feature] Add knn op from mmdet3d #1354

Merged · 14 commits · Oct 14, 2021
1 change: 1 addition & 0 deletions docs/understand_mmcv/ops.md
@@ -13,6 +13,7 @@ We implement common CUDA ops used in detection, segmentation, etc.
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask
1 change: 1 addition & 0 deletions docs_zh_CN/understand_mmcv/ops.md
@@ -13,6 +13,7 @@ MMCV provides CUDA ops commonly used in tasks such as detection and segmentation
- FurthestPointSample
- FurthestPointSampleWithDist
- GeneralizedAttention
- KNN
- MaskedConv
- NMS
- PSAMask
11 changes: 6 additions & 5 deletions mmcv/ops/__init__.py
@@ -22,6 +22,7 @@
from .fused_bias_leakyrelu import FusedBiasLeakyReLU, fused_bias_leakyrelu
from .info import (get_compiler_version, get_compiling_cuda_version,
get_onnxruntime_op_path)
from .knn import knn
from .masked_conv import MaskedConv2d, masked_conv2d
from .modulated_deform_conv import (ModulatedDeformConv2d,
ModulatedDeformConv2dPack,
@@ -55,9 +56,9 @@
     'ConvTranspose2d', 'Linear', 'MaxPool2d', 'CrissCrossAttention', 'PSAMask',
     'point_sample', 'rel_roi_point_to_rel_img_point', 'SimpleRoIAlign',
     'SAConv2d', 'TINShift', 'tin_shift', 'box_iou_rotated', 'nms_rotated',
-    'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU', 'fused_bias_leakyrelu',
-    'RoIAlignRotated', 'roi_align_rotated', 'pixel_group', 'contour_expand',
-    'MultiScaleDeformableAttention', 'BorderAlign', 'border_align',
-    'furthest_point_sample', 'furthest_point_sample_with_dist',
-    'PointsSampler', 'Correlation'
+    'knn', 'ball_query', 'upfirdn2d', 'FusedBiasLeakyReLU',
+    'fused_bias_leakyrelu', 'RoIAlignRotated', 'roi_align_rotated',
+    'pixel_group', 'contour_expand', 'MultiScaleDeformableAttention',
+    'BorderAlign', 'border_align', 'furthest_point_sample',
+    'furthest_point_sample_with_dist', 'PointsSampler', 'Correlation'
 ]
91 changes: 91 additions & 0 deletions mmcv/ops/csrc/common/cuda/knn_cuda_kernel.cuh
@@ -0,0 +1,91 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap
#ifndef KNN_CUDA_KERNEL_CUH
#define KNN_CUDA_KERNEL_CUH

#ifdef MMCV_USE_PARROTS
#include "parrots_cuda_helper.hpp"
#else
#include "pytorch_cuda_helper.hpp"
#endif

inline __device__ void swap_float(float *x, float *y) {
float tmp = *x;
*x = *y;
*y = tmp;
}

inline __device__ void swap_int(int *x, int *y) {
int tmp = *x;
*x = *y;
*y = tmp;
}

// sift the root down to restore the max-heap property; the heap holds the
// current k best (smallest) distances, with the worst of them at the root
__device__ void reheap(float *dist, int *idx, int k) {
int root = 0;
int child = root * 2 + 1;
while (child < k) {
if (child + 1 < k && dist[child + 1] > dist[child]) child++;
if (dist[root] > dist[child]) return;
swap_float(&dist[root], &dist[child]);
swap_int(&idx[root], &idx[child]);
root = child;
child = root * 2 + 1;
}
}

// in-place heapsort: repeatedly swap the root (current maximum) to the back,
// leaving distances and their indices sorted in ascending order
__device__ void heap_sort(float *dist, int *idx, int k) {
int i;
for (i = k - 1; i > 0; i--) {
swap_float(&dist[0], &dist[i]);
swap_int(&idx[0], &idx[i]);
reheap(dist, idx, i);
}
}

// input: xyz (b, n, 3) new_xyz (b, m, 3)
// output: idx (b, m, nsample) dist2 (b, m, nsample)
template <typename T>
__global__ void knn_forward_cuda_kernel(int b, int n, int m, int nsample,
const T *xyz, const T *new_xyz,
int *__restrict__ idx, T *dist2) {
int bs_idx = blockIdx.y;
int pt_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (bs_idx >= b || pt_idx >= m) return;

new_xyz += bs_idx * m * 3 + pt_idx * 3;
xyz += bs_idx * n * 3;
idx += bs_idx * m * nsample + pt_idx * nsample;
dist2 += bs_idx * m * nsample + pt_idx * nsample;

T new_x = new_xyz[0];
T new_y = new_xyz[1];
T new_z = new_xyz[2];

// fixed-size per-thread buffers bound the neighbour count; the Python
// wrapper asserts k < 100 to stay within them
float best_dist[100];
int best_idx[100];
for (int i = 0; i < nsample; i++) {
best_dist[i] = 1e10;
best_idx[i] = 0;
}
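// scan all n reference points, keeping the k smallest squared distances
// seen so far in the heap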
for (int i = 0; i < n; i++) {
T x = xyz[i * 3 + 0];
T y = xyz[i * 3 + 1];
T z = xyz[i * 3 + 2];
T d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
(new_z - z) * (new_z - z);
if (d2 < best_dist[0]) {
best_dist[0] = d2;
best_idx[0] = i;
reheap(best_dist, best_idx, nsample);
}
}
heap_sort(best_dist, best_idx, nsample);
for (int i = 0; i < nsample; i++) {
idx[i] = best_idx[i];
dist2[i] = best_dist[i];
}
}

#endif // KNN_CUDA_KERNEL_CUH
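The kernel keeps the running k best candidates in a max-heap, so each of the n points costs at most O(log k) to process. The same selection logic in plain Python (an illustrative sketch using `heapq`, not part of the PR):

```python
import heapq


def k_smallest_indices(dists, k):
    """Indices of the k smallest values, mirroring reheap/heap_sort above."""
    heap = []  # max-heap via negated distances; root = worst kept candidate
    for i, d in enumerate(dists):
        if len(heap) < k:
            heapq.heappush(heap, (-d, i))
        elif d < -heap[0][0]:
            heapq.heapreplace(heap, (-d, i))
    # final ascending sort, like heap_sort in the kernel
    return [i for _, i in sorted((-nd, i) for nd, i in heap)]
```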
8 changes: 4 additions & 4 deletions mmcv/ops/csrc/parrots/modulated_deform_conv.cpp
@@ -99,10 +99,10 @@ void modulated_deform_conv_forward(
   const int kernel_w_ = weight.size(3);
 
   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
-    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
              kernel_h_, kernel_w, kernel_h_, kernel_w_);
   if (channels != channels_kernel * group)
-    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
              channels, channels_kernel * group);
 
   const int height_out =
@@ -220,10 +220,10 @@ void modulated_deform_conv_backward(
   const int kernel_h_ = weight.size(2);
   const int kernel_w_ = weight.size(3);
   if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
-    AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+    AT_ERROR("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
              kernel_h_, kernel_w, kernel_h_, kernel_w_);
   if (channels != channels_kernel * group)
-    AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+    AT_ERROR("Input shape and kernel channels won't match: (%d vs %d).",
              channels, channels_kernel * group);
 
   const int height_out =
34 changes: 34 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/knn_cuda.cu
@@ -0,0 +1,34 @@
// Copyright (c) OpenMMLab. All rights reserved
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap

#include <cmath>
#include <cstdio>

#include "knn_cuda_kernel.cuh"
#include "pytorch_cuda_helper.hpp"

void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2) {
// param new_xyz: (B, m, 3)
// param xyz: (B, n, 3)
// param idx: (B, m, nsample)
// param dist2: (B, m, nsample)

at::cuda::CUDAGuard device_guard(new_xyz.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// one thread per query point: blockIdx.x tiles the m queries, blockIdx.y
// indexes the batch
dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), b);
dim3 threads(THREADS_PER_BLOCK);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
new_xyz.scalar_type(), "knn_forward_cuda_kernel", [&] {
knn_forward_cuda_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
b, n, m, nsample, xyz.data_ptr<scalar_t>(),
new_xyz.data_ptr<scalar_t>(), idx.data_ptr<int>(),
dist2.data_ptr<scalar_t>());
});

AT_CUDA_CHECK(cudaGetLastError());
}
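DIVUP is ceiling division, so every query point gets a thread even when m is not a multiple of the block size. The same arithmetic in Python (illustrative values; THREADS_PER_BLOCK is defined in the mmcv CUDA helpers):

```python
def divup(a: int, b: int) -> int:
    # ceiling division, as the DIVUP macro used to size the grid above
    return (a + b - 1) // b


# e.g. m = 1000 query points with 512 threads per block -> 2 blocks in x
print(divup(1000, 512))  # 2
```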
33 changes: 33 additions & 0 deletions mmcv/ops/csrc/pytorch/knn.cpp
@@ -0,0 +1,33 @@
// Modified from
// https://github.com/CVMI-Lab/PAConv/tree/main/scene_seg/lib/pointops/src/knnquery_heap

#include "pytorch_cpp_helper.hpp"

#ifdef MMCV_WITH_CUDA
void KNNForwardCUDAKernelLauncher(int b, int n, int m, int nsample,
const Tensor xyz, const Tensor new_xyz,
Tensor idx, Tensor dist2);

void knn_forward_cuda(int b, int n, int m, int nsample, const Tensor xyz,
const Tensor new_xyz, Tensor idx, Tensor dist2) {
KNNForwardCUDAKernelLauncher(b, n, m, nsample, xyz, new_xyz, idx, dist2);
}
#endif

void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor,
Tensor dist2_tensor) {
if (new_xyz_tensor.device().is_cuda()) {
#ifdef MMCV_WITH_CUDA
CHECK_CUDA_INPUT(new_xyz_tensor);
CHECK_CUDA_INPUT(xyz_tensor);

knn_forward_cuda(b, n, m, nsample, xyz_tensor, new_xyz_tensor, idx_tensor,
dist2_tensor);
#else
AT_ERROR("knn is not compiled with GPU support");
#endif
} else {
AT_ERROR("knn is not implemented on CPU");
}
}
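Since knn_forward raises on CPU tensors, a pure-PyTorch reference (a sketch; `knn_reference` is a hypothetical helper, not part of this PR) reproduces the op's output layout on any device, up to tie-breaking order:

```python
import torch


def knn_reference(k: int, xyz: torch.Tensor,
                  center_xyz: torch.Tensor = None) -> torch.Tensor:
    # brute-force k-NN: all pairwise distances, then the k smallest per
    # query point, returned as (B, k, npoint) int32 like the CUDA op
    if center_xyz is None:
        center_xyz = xyz
    dist = torch.cdist(center_xyz, xyz)  # (B, npoint, N)
    return dist.topk(k, dim=2, largest=False)[1].transpose(2, 1).int()
```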
7 changes: 7 additions & 0 deletions mmcv/ops/csrc/pytorch/pybind.cpp
@@ -69,6 +69,9 @@ void softmax_focal_loss_backward(Tensor input, Tensor target, Tensor weight,
void bbox_overlaps(const Tensor bboxes1, const Tensor bboxes2, Tensor ious,
const int mode, const bool aligned, const int offset);

void knn_forward(int b, int n, int m, int nsample, Tensor xyz_tensor,
Tensor new_xyz_tensor, Tensor idx_tensor, Tensor dist2_tensor);

void furthest_point_sampling_forward(int b, int n, int m, Tensor points_tensor,
Tensor temp_tensor, Tensor idx_tensor);

@@ -256,6 +259,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
m.def("get_compiling_cuda_version", &get_compiling_cuda_version,
"get_compiling_cuda_version");
m.def("knn_forward", &knn_forward, "knn_forward", py::arg("b"), py::arg("n"),
py::arg("m"), py::arg("nsample"), py::arg("xyz_tensor"),
py::arg("new_xyz_tensor"), py::arg("idx_tensor"),
py::arg("dist2_tensor"));
m.def("carafe_naive_forward", &carafe_naive_forward, "carafe_naive_forward",
py::arg("features"), py::arg("masks"), py::arg("output"),
py::arg("kernel_size"), py::arg("group_size"), py::arg("scale_factor"));
75 changes: 75 additions & 0 deletions mmcv/ops/knn.py
@@ -0,0 +1,75 @@
import torch
from torch.autograd import Function

from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['knn_forward'])


class KNN(Function):
    r"""KNN (CUDA) based on heap data structure.

    Modified from `PAConv <https://github.com/CVMI-Lab/PAConv/tree/main/
    scene_seg/lib/pointops/src/knnquery_heap>`_.

    Find k-nearest points.
    """

    @staticmethod
    def forward(ctx,
                k: int,
                xyz: torch.Tensor,
                center_xyz: torch.Tensor = None,
                transposed: bool = False) -> torch.Tensor:
        """
        Args:
            k (int): number of nearest neighbors.
            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
                xyz coordinates of the features.
            center_xyz (Tensor, optional): (B, npoint, 3) if transposed ==
                False, else (B, 3, npoint). Centers of the knn query.
                Default: None.
            transposed (bool, optional): whether the input tensors are
                transposed. Keyword arguments are not supported when calling
                knn (= KNN.apply), so pass it positionally as the fourth
                argument. Default: False.

        Returns:
            Tensor: (B, k, npoint) tensor with the indices of the features
                that form the k nearest neighbours.
        """
        assert (k > 0) & (k < 100), 'k should be in range(0, 100)'

        if center_xyz is None:
            center_xyz = xyz

        if transposed:
            xyz = xyz.transpose(2, 1).contiguous()
            center_xyz = center_xyz.transpose(2, 1).contiguous()

        assert xyz.is_contiguous()  # [B, N, 3]
        assert center_xyz.is_contiguous()  # [B, npoint, 3]

        center_xyz_device = center_xyz.get_device()
        assert center_xyz_device == xyz.get_device(), \
            'center_xyz and xyz should be put on the same device'
        if torch.cuda.current_device() != center_xyz_device:
            torch.cuda.set_device(center_xyz_device)

        B, npoint, _ = center_xyz.shape
        N = xyz.shape[1]

        idx = center_xyz.new_zeros((B, npoint, k)).int()
        dist2 = center_xyz.new_zeros((B, npoint, k)).float()

        ext_module.knn_forward(B, N, npoint, k, xyz, center_xyz, idx, dist2)
        # the kernel writes idx as (B, npoint, k); transpose to (B, k, npoint)
        idx = idx.transpose(2, 1).contiguous()
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, a=None):
        # one gradient per forward input (k, xyz, center_xyz, transposed)
        return None, None, None, None


knn = KNN.apply
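For reference, a minimal usage sketch of the new op (illustrative random inputs; assumes an mmcv build with CUDA support):

```python
import torch

from mmcv.ops import knn

xyz = torch.rand(2, 256, 3).cuda()        # (B, N, 3) reference point cloud
center_xyz = torch.rand(2, 64, 3).cuda()  # (B, npoint, 3) query centers

# keyword arguments are not supported by autograd.Function.apply,
# so `transposed` is passed positionally as the fourth argument
idx = knn(16, xyz, center_xyz, False)
print(idx.shape)  # torch.Size([2, 16, 64]) -- int32 neighbour indices
```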
54 changes: 54 additions & 0 deletions tests/test_ops/test_knn.py
@@ -0,0 +1,54 @@
import pytest
import torch

from mmcv.ops import knn


@pytest.mark.skipif(
    not torch.cuda.is_available(), reason='requires CUDA support')
def test_knn():
    new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
                             [-2.2769, 2.7817, -0.2334],
                             [-0.4003, 2.4666, -0.5116],
                             [-0.0740, 1.3147, -1.3625],
                             [-0.0740, 1.3147, -1.3625]],
                            [[-2.0289, 2.4952, -0.1708],
                             [-2.0668, 6.0278, -0.4875],
                             [0.4066, 1.4211, -0.2947],
                             [-2.0289, 2.4952, -0.1708],
                             [-2.0289, 2.4952, -0.1708]]]).cuda()

    xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625],
                         [0.5555, 1.0399, -1.3634],
                         [-0.4003, 2.4666, -0.5116],
                         [-0.5251, 2.4379, -0.8466],
                         [-0.9691, 1.1418, -1.3733],
                         [-0.2232, 0.9561, -1.3626],
                         [-2.2769, 2.7817, -0.2334],
                         [-0.2822, 1.3192, -1.3645],
                         [0.1533, 1.5024, -1.0432],
                         [0.4917, 1.1529, -1.3496]],
                        [[-2.0289, 2.4952, -0.1708],
                         [-0.7188, 0.9956, -0.5096],
                         [-2.0668, 6.0278, -0.4875],
                         [-1.9304, 3.3092, 0.6610],
                         [0.0949, 1.4332, 0.3140],
                         [-1.2879, 2.0008, -0.7791],
                         [-0.7252, 0.9611, -0.6371],
                         [0.4066, 1.4211, -0.2947],
                         [0.3220, 1.4447, 0.3548],
                         [-0.9744, 2.3856, -1.2000]]]).cuda()

    # brute-force reference: all pairwise squared distances, then the
    # k smallest per query point
    idx = knn(5, xyz, new_xyz)
    new_xyz_ = new_xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
    xyz_ = xyz.unsqueeze(1).repeat(1, new_xyz.shape[1], 1, 1)
    dist = ((new_xyz_ - xyz_) * (new_xyz_ - xyz_)).sum(-1)
    expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
    assert torch.all(idx == expected_idx)

    # transposed inputs must give the same result
    idx = knn(5,
              xyz.transpose(1, 2).contiguous(),
              new_xyz.transpose(1, 2).contiguous(), True)
    assert torch.all(idx == expected_idx)

    # querying the cloud against itself
    idx = knn(5, xyz, xyz)
    xyz_ = xyz.unsqueeze(2).repeat(1, 1, xyz.shape[1], 1)
    xyz__ = xyz.unsqueeze(1).repeat(1, xyz.shape[1], 1, 1)
    dist = ((xyz_ - xyz__) * (xyz_ - xyz__)).sum(-1)
    expected_idx = dist.topk(k=5, dim=2, largest=False)[1].transpose(2, 1)
    assert torch.all(idx == expected_idx)