[Feature] Support knn gpu op #360

Merged (2 commits) on Mar 17, 2021.
3 changes: 2 additions & 1 deletion mmdet3d/ops/__init__.py
@@ -9,6 +9,7 @@
 from .group_points import (GroupAll, QueryAndGroup, group_points,
                            grouping_operation)
 from .interpolate import three_interpolate, three_nn
+from .knn import knn
 from .norm import NaiveSyncBatchNorm1d, NaiveSyncBatchNorm2d
 from .pointnet_modules import (PointFPModule, PointSAModule, PointSAModuleMSG,
                                build_sa_module)
@@ -25,7 +26,7 @@
     'dynamic_scatter', 'DynamicScatter', 'sigmoid_focal_loss',
     'SigmoidFocalLoss', 'SparseBasicBlock', 'SparseBottleneck',
     'RoIAwarePool3d', 'points_in_boxes_gpu', 'points_in_boxes_cpu',
-    'make_sparse_convmodule', 'ball_query', 'furthest_point_sample',
+    'make_sparse_convmodule', 'ball_query', 'knn', 'furthest_point_sample',
     'furthest_point_sample_with_dist', 'three_interpolate', 'three_nn',
     'gather_points', 'grouping_operation', 'group_points', 'GroupAll',
     'QueryAndGroup', 'PointSAModule', 'PointSAModuleMSG', 'PointFPModule',
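With knn added to the imports and to __all__ above, the op is re-exported at the package root. A minimal import check, as a sketch assuming mmdet3d was built with its CUDA extensions:

# Sketch: verify the new export; assumes a CUDA build of mmdet3d.
from mmdet3d.ops import knn
from mmdet3d.ops.knn import knn as knn_direct

assert knn is knn_direct  # both paths resolve to the same op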
14 changes: 7 additions & 7 deletions mmdet3d/ops/gather_points/gather_points.py
@@ -12,28 +12,28 @@ class GatherPoints(Function):
 
     @staticmethod
     def forward(ctx, features: torch.Tensor,
-                indicies: torch.Tensor) -> torch.Tensor:
+                indices: torch.Tensor) -> torch.Tensor:
         """forward.
 
         Args:
             features (Tensor): (B, C, N) features to gather.
-            indicies (Tensor): (B, M) where M is the number of points.
+            indices (Tensor): (B, M) where M is the number of points.
 
         Returns:
             Tensor: (B, C, M) where M is the number of points.
         """
         assert features.is_contiguous()
-        assert indicies.is_contiguous()
+        assert indices.is_contiguous()
 
-        B, npoint = indicies.size()
+        B, npoint = indices.size()
         _, C, N = features.size()
         output = torch.cuda.FloatTensor(B, C, npoint)
 
         gather_points_ext.gather_points_wrapper(B, C, N, npoint, features,
-                                                indicies, output)
+                                                indices, output)
 
-        ctx.for_backwards = (indicies, C, N)
-        ctx.mark_non_differentiable(indicies)
+        ctx.for_backwards = (indices, C, N)
+        ctx.mark_non_differentiable(indices)
         return output
 
     @staticmethod
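The change above is a pure rename (indicies to indices) with no behavioral effect. For reference, a usage sketch of the op with hypothetical shapes; it assumes a CUDA device and int32 indices, since the wrapper dispatches to a CUDA kernel:

import torch

from mmdet3d.ops import gather_points

# Hypothetical shapes: B=2 batches, C=16 channels, N=128 source points, M=32.
features = torch.rand(2, 16, 128).cuda()               # (B, C, N)
indices = torch.randint(0, 128, (2, 32)).int().cuda()  # (B, M), int32
gathered = gather_points(features, indices)            # (B, C, M)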
3 changes: 3 additions & 0 deletions mmdet3d/ops/knn/__init__.py
@@ -0,0 +1,3 @@
+from .knn import knn
+
+__all__ = ['knn']
68 changes: 68 additions & 0 deletions mmdet3d/ops/knn/knn.py
@@ -0,0 +1,68 @@
+import torch
+from torch.autograd import Function
+
+from . import knn_ext
+
+
+class KNN(Function):
+    """KNN (CUDA).
+
[Inline review comment from @ZwwWayne (Collaborator), Mar 17, 2021: "Need to ref the original implementation if necessary."]
+    Find k-nearest points.
+    """
+
+    @staticmethod
+    def forward(ctx,
+                k: int,
+                xyz: torch.Tensor,
+                center_xyz: torch.Tensor,
+                transposed: bool = False) -> torch.Tensor:
+        """forward.
+
+        Args:
+            k (int): number of nearest neighbors.
+            xyz (Tensor): (B, N, 3) if transposed == False, else (B, 3, N).
+                xyz coordinates of the features.
+            center_xyz (Tensor): (B, npoint, 3) if transposed == False,
+                else (B, 3, npoint). centers of the knn query.
+            transposed (bool): whether the input tensors are transposed.
+                defaults to False.
+
+        Returns:
+            Tensor: (B, k, npoint) tensor with the indices of
+                the features that form k-nearest neighbours.
+        """
+        assert k > 0
+
+        B, npoint = center_xyz.shape[:2]
+        N = xyz.shape[1]
+
+        if not transposed:
+            xyz = xyz.transpose(2, 1).contiguous()
+            center_xyz = center_xyz.transpose(2, 1).contiguous()
+
+        assert center_xyz.is_contiguous()
+        assert xyz.is_contiguous()
+
+        center_xyz_device = center_xyz.get_device()
+        assert center_xyz_device == xyz.get_device(), \
+            'center_xyz and xyz should be put on the same device'
+        if torch.cuda.current_device() != center_xyz_device:
+            torch.cuda.set_device(center_xyz_device)
+
+        idx = center_xyz.new_zeros((B, k, npoint)).long()
+
+        for bi in range(B):
+            knn_ext.knn_wrapper(xyz[bi], N, center_xyz[bi], npoint, idx[bi], k)
+
+        ctx.mark_non_differentiable(idx)
+
+        idx -= 1
+
+        return idx
+
+    @staticmethod
+    def backward(ctx, a=None):
+        return None, None
+
+
+knn = KNN.apply
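A usage sketch of the wrapper, with shapes following the docstring; it assumes a CUDA device. The idx -= 1 above suggests the underlying kernel writes 1-based indices, which the wrapper shifts to the usual 0-based convention before returning:

import torch

from mmdet3d.ops import knn

xyz = torch.rand(2, 256, 3).cuda()        # (B, N, 3) source points
center_xyz = torch.rand(2, 64, 3).cuda()  # (B, npoint, 3) query centers
idx = knn(8, xyz, center_xyz)             # (B, 8, npoint), 0-based long indices

Because idx is marked non-differentiable, backward returns None gradients and the op does not propagate gradients to its inputs.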
62 changes: 62 additions & 0 deletions mmdet3d/ops/knn/src/knn.cpp
@@ -0,0 +1,62 @@
+// Modified from https://github.com/unlimblue/KNN_CUDA
+
+#include <vector>
+#include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_TYPE(x, t) AT_ASSERTM(x.dtype() == t, #x " must be " #t)
+#define CHECK_CUDA(x) AT_ASSERTM(x.device().type() == at::Device::Type::CUDA, #x " must be on CUDA")
+#define CHECK_INPUT(x, t) CHECK_CONTIGUOUS(x); CHECK_TYPE(x, t); CHECK_CUDA(x)
+
+
+void knn_kernels_launcher(
+    const float* ref_dev,
+    int ref_nb,
+    const float* query_dev,
+    int query_nb,
+    int dim,
+    int k,
+    float* dist_dev,
+    long* ind_dev,
+    cudaStream_t stream
+);
+
+// std::vector<at::Tensor> knn_wrapper(
+void knn_wrapper(
+    at::Tensor & ref,
+    int ref_nb,
+    at::Tensor & query,
+    int query_nb,
+    at::Tensor & ind,
+    const int k
+) {
+
+    CHECK_INPUT(ref, at::kFloat);
+    CHECK_INPUT(query, at::kFloat);
+    const float * ref_dev = ref.data_ptr<float>();
+    const float * query_dev = query.data_ptr<float>();
+    int dim = query.size(0);
+    auto dist = at::empty({ref_nb, query_nb}, query.options().dtype(at::kFloat));
+    float * dist_dev = dist.data_ptr<float>();
+    long * ind_dev = ind.data_ptr<long>();
+
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    knn_kernels_launcher(
+        ref_dev,
+        ref_nb,
+        query_dev,
+        query_nb,
+        dim,
+        k,
+        dist_dev,
+        ind_dev,
+        stream
+    );
+}
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("knn_wrapper", &knn_wrapper, "knn_wrapper");
+}
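To sanity-check the extension, a brute-force reference is straightforward on the PyTorch side. A sketch (not part of this PR) using torch.cdist and topk; note that ties may be ordered differently than in the CUDA kernel:

import torch

def knn_reference(k, xyz, center_xyz):
    """Brute-force k-NN: (B, N, 3) points and (B, npoint, 3) centers
    produce (B, k, npoint) indices into xyz, nearest first."""
    dist = torch.cdist(center_xyz, xyz)               # (B, npoint, N) pairwise L2
    idx = dist.topk(k, dim=2, largest=False).indices  # (B, npoint, k)
    return idx.transpose(1, 2).contiguous()           # match the op's layout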