Skip to content

Commit

Permalink
[Feature] Add nondeterministic voxelization op from mmdet3d (#1783)
Browse files Browse the repository at this point in the history
* add nondeterministic voxelization op

* fix lint

* fix lint

* resolve comments

* fix lint
  • Loading branch information
wHao-Wu authored Mar 15, 2022
1 parent 33e14de commit b5d550f
Show file tree
Hide file tree
Showing 7 changed files with 311 additions and 10 deletions.
47 changes: 47 additions & 0 deletions mmcv/ops/csrc/common/cuda/voxelization_cuda_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,51 @@ __global__ void determin_voxel_num(
}
}

// Gives every point that belongs to a voxel a slot index inside that voxel,
// and gives every non-empty voxel an output row index.
//
// For each point index p in [0, nthreads):
//   - coors_map[p] < 0 means the point has no voxel; it is skipped and
//     pts_id[p] is not written.
//   - otherwise reduce_count[coors_map[p]] is atomically incremented and the
//     value it held before the increment is stored in pts_id[p].
//   - the point that observed 0 atomically takes the current value of
//     *coors_count (post-incrementing it) and stores it in
//     coors_order[coors_map[p]].
//
// Slot and row numbers depend on the order in which the atomics are
// executed, so they may differ between runs.
//
// NOTE(review): reduce_count and *coors_count are assumed to be zero on
// entry; confirm against the caller.
__global__ void nondeterministic_get_assign_pos(
    const int nthreads, const int32_t* coors_map, int32_t* pts_id,
    int32_t* coors_count, int32_t* reduce_count, int32_t* coors_order) {
  CUDA_1D_KERNEL_LOOP(pt, nthreads) {
    const int voxel = coors_map[pt];
    if (voxel >= 0) {
      // Pre-increment value == number of points that claimed a slot before us.
      const int32_t slot = atomicAdd(reduce_count + voxel, 1);
      pts_id[pt] = slot;
      if (slot == 0) {
        // First point to reach this voxel reserves the voxel's output row.
        const int32_t row = atomicAdd(coors_count, 1);
        coors_order[voxel] = row;
      }
    }
  }
}

// Copies each kept point into its voxel slot and, once per voxel, writes the
// voxel's point count and coordinates.
//
// For each point index p in [0, nthreads):
//   - the point is dropped when coors_map[p] < 0, when its slot pts_id[p] is
//     >= max_points, or when its voxel's row coors_order[coors_map[p]] is
//     >= max_voxels.
//   - otherwise its num_features values are copied from
//     points[p * num_features ...] to
//     voxels[(row * max_points + slot) * num_features ...].
//   - the point holding slot 0 additionally writes
//     pts_count[row] = min(reduce_count[voxel], max_points) and copies the
//     NDim coordinates coors_in[voxel * NDim ...] to coors[row * NDim ...].
template <typename T>
__global__ void nondeterministic_assign_point_voxel(
    const int nthreads, const T* points, const int32_t* coors_map,
    const int32_t* pts_id, const int32_t* coors_in, const int32_t* reduce_count,
    const int32_t* coors_order, T* voxels, int32_t* coors, int32_t* pts_count,
    const int max_voxels, const int max_points, const int num_features,
    const int NDim) {
  CUDA_1D_KERNEL_LOOP(pt, nthreads) {
    const int voxel = coors_map[pt];
    const int slot = pts_id[pt];
    const bool has_slot = voxel >= 0 && slot < max_points;
    if (has_slot) {
      // Only index coors_order once the voxel index is known to be valid.
      const int row = coors_order[voxel];
      if (row < max_voxels) {
        T* dst = voxels + (row * max_points + slot) * num_features;
        const T* src = points + pt * num_features;
        for (int f = 0; f < num_features; ++f) {
          dst[f] = src[f];
        }
        if (slot == 0) {
          // Exactly one point per voxel owns slot 0, so the per-voxel
          // metadata is written once.
          const int total = reduce_count[voxel];
          pts_count[row] = total < max_points ? total : max_points;
          int32_t* coors_dst = coors + row * NDim;
          const int32_t* coors_src = coors_in + voxel * NDim;
          for (int d = 0; d < NDim; ++d) {
            coors_dst[d] = coors_src[d];
          }
        }
      }
    }
  }
}

#endif // VOXELIZATION_CUDA_KERNEL_CUH
24 changes: 24 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/cudabind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1396,6 +1396,12 @@ int HardVoxelizeForwardCUDAKernelLauncher(
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);

// Forward declaration of the launcher for the nondeterministic
// hard-voxelization CUDA kernels; the definition lives in
// voxelization_cuda.cu, where it returns the number of voxels produced,
// capped at max_voxels. NDim defaults to 3.
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim = 3);

void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
Expand All @@ -1413,6 +1419,16 @@ int hard_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& voxels,
max_points, max_voxels, NDim);
};

// CUDA implementation of nondeterministic hard voxelization: passes every
// argument unchanged to the kernel launcher and returns the launcher's
// result.
int nondeterministic_hard_voxelize_forward_cuda(
    const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
    at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim) {
  const int voxel_count =
      NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
          points, voxels, coors, num_points_per_voxel, voxel_size,
          coors_range, max_points, max_voxels, NDim);
  return voxel_count;
};

void dynamic_voxelize_forward_cuda(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
Expand All @@ -1429,13 +1445,21 @@ int hard_voxelize_forward_impl(const at::Tensor& points, at::Tensor& voxels,
const int max_points, const int max_voxels,
const int NDim);

// Declaration of the device-agnostic entry point for nondeterministic hard
// voxelization; needed here so the REGISTER_DEVICE_IMPL call in this file can
// name it when attaching the CUDA implementation.
int nondeterministic_hard_voxelize_forward_impl(
const at::Tensor& points, at::Tensor& voxels, at::Tensor& coors,
at::Tensor& num_points_per_voxel, const std::vector<float> voxel_size,
const std::vector<float> coors_range, const int max_points,
const int max_voxels, const int NDim);

// Declaration of the device-agnostic entry point for dynamic voxelization;
// needed here so the REGISTER_DEVICE_IMPL call in this file can name it when
// attaching the CUDA implementation.
void dynamic_voxelize_forward_impl(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int NDim);

// Pair each voxelization *_impl entry point with its *_cuda function for the
// CUDA device key.
// NOTE(review): REGISTER_DEVICE_IMPL is defined outside this view; presumably
// it records the function in a per-device registry that the *_impl
// dispatchers consult at call time -- confirm against its definition.
REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, CUDA,
hard_voxelize_forward_cuda);
REGISTER_DEVICE_IMPL(nondeterministic_hard_voxelize_forward_impl, CUDA,
nondeterministic_hard_voxelize_forward_cuda);
REGISTER_DEVICE_IMPL(dynamic_voxelize_forward_impl, CUDA,
dynamic_voxelize_forward_cuda);

Expand Down
98 changes: 98 additions & 0 deletions mmcv/ops/csrc/pytorch/cuda/voxelization_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,104 @@ int HardVoxelizeForwardCUDAKernelLauncher(
return voxel_num_int;
}

// Hard voxelization whose voxel order and per-voxel point order are decided
// by atomics, so they may differ between runs.
//
// Args:
//   points: indexed as (num_points, num_features); size(0) and size(1) are
//     read here and the first three features are presumably x, y, z (TODO
//     confirm against dynamic_voxelize_kernel).
//   voxels: output, written at
//     (voxel_row * max_points + slot) * num_features + feature.
//   coors: output, written at voxel_row * NDim + dim, int32.
//   num_points_per_voxel: output, one int32 per voxel row.
//   voxel_size: {x, y, z} voxel edge lengths (indices 0..2 are read).
//   coors_range: {x_min, y_min, z_min, x_max, y_max, z_max}.
//   max_points: maximum number of points kept per voxel.
//   max_voxels: maximum number of voxels kept.
//   NDim: number of coordinate components per voxel.
//
// Returns:
//   min(max_voxels, number of distinct in-range voxel coordinates), or 0 when
//   `points` has no rows.
//
// NOTE(review): the outputs are passed as `x.contiguous().data_ptr()`. If an
// output tensor is not contiguous, `contiguous()` yields a copy and the
// caller's tensor is presumably left unmodified -- callers should pass
// contiguous outputs.
int NondeterministicHardVoxelizeForwardCUDAKernelLauncher(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  at::cuda::CUDAGuard device_guard(points.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int num_points = points.size(0);
  const int num_features = points.size(1);

  // Nothing to voxelize; also avoids a zero-sized grid launch.
  if (num_points == 0) return 0;

  dim3 blocks(
      std::min(at::cuda::ATenCeilDiv(num_points, THREADS_PER_BLOCK), 4096));
  dim3 threads(THREADS_PER_BLOCK);

  const float voxel_x = voxel_size[0];
  const float voxel_y = voxel_size[1];
  const float voxel_z = voxel_size[2];
  const float coors_x_min = coors_range[0];
  const float coors_y_min = coors_range[1];
  const float coors_z_min = coors_range[2];
  const float coors_x_max = coors_range[3];
  const float coors_y_max = coors_range[4];
  const float coors_z_max = coors_range[5];

  const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
  const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
  const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

  // Materialize the contiguous view of the input once; it is consumed by the
  // first and the last kernel.
  const at::Tensor points_contig = points.contiguous();

  // map points to voxel coors
  at::Tensor temp_coors =
      at::zeros({num_points, NDim}, points.options().dtype(at::kInt));

  // 1. link point to corresponding voxel coors
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "hard_voxelize_kernel", ([&] {
        dynamic_voxelize_kernel<scalar_t, int><<<blocks, threads, 0, stream>>>(
            points_contig.data_ptr<scalar_t>(),
            temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y, voxel_z,
            coors_x_min, coors_y_min, coors_z_min, coors_x_max, coors_y_max,
            coors_z_max, grid_x, grid_y, grid_z, num_points, num_features,
            NDim);
      }));
  // Report a failed launch here rather than at a later, unrelated call.
  AT_CUDA_CHECK(cudaGetLastError());

  // 2. collapse points to unique voxel coordinates. Any row with a negative
  // component is rewritten to all -1 so that every out-of-range point maps to
  // one shared "invalid" row.
  auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);

  at::Tensor coors_map;
  // return_counts is false, so the third output carries nothing useful.
  std::tie(temp_coors, coors_map, std::ignore) =
      at::unique_dim(coors_clean, 0, true, true, false);

  if (temp_coors[0][0].lt(0).item<bool>()) {
    // the first element of temp_coors is (-1,-1,-1) and should be removed
    temp_coors = temp_coors.slice(0, 1);
    // Shift the inverse indices so invalid points map to -1 and valid voxels
    // start at 0.
    coors_map = coors_map - 1;
  }

  int num_coors = temp_coors.size(0);
  temp_coors = temp_coors.to(at::kInt);
  coors_map = coors_map.to(at::kInt);

  // Scratch buffers for the atomics-based assignment. coors_count and
  // reduce_count must start at zero because the kernel increments them.
  at::Tensor coors_count = at::zeros({1}, coors_map.options());
  at::Tensor coors_order = at::empty({num_coors}, coors_map.options());
  at::Tensor pts_id = at::zeros({num_points}, coors_map.options());
  at::Tensor reduce_count = at::zeros({num_coors}, coors_map.options());

  // 3. give each point a slot in its voxel and each voxel an output row.
  // The kernel is not templated on the point dtype, so no dispatch is needed.
  nondeterministic_get_assign_pos<<<blocks, threads, 0, stream>>>(
      num_points, coors_map.contiguous().data_ptr<int32_t>(),
      pts_id.contiguous().data_ptr<int32_t>(),
      coors_count.contiguous().data_ptr<int32_t>(),
      reduce_count.contiguous().data_ptr<int32_t>(),
      coors_order.contiguous().data_ptr<int32_t>());
  AT_CUDA_CHECK(cudaGetLastError());

  // 4. scatter the points, point counts and coordinates into the outputs.
  AT_DISPATCH_ALL_TYPES(
      points.scalar_type(), "assign_point_to_voxel", ([&] {
        nondeterministic_assign_point_voxel<scalar_t>
            <<<blocks, threads, 0, stream>>>(
                num_points, points_contig.data_ptr<scalar_t>(),
                coors_map.contiguous().data_ptr<int32_t>(),
                pts_id.contiguous().data_ptr<int32_t>(),
                temp_coors.contiguous().data_ptr<int32_t>(),
                reduce_count.contiguous().data_ptr<int32_t>(),
                coors_order.contiguous().data_ptr<int32_t>(),
                voxels.contiguous().data_ptr<scalar_t>(),
                coors.contiguous().data_ptr<int32_t>(),
                num_points_per_voxel.contiguous().data_ptr<int32_t>(),
                max_voxels, max_points, num_features, NDim);
      }));
  AT_CUDA_CHECK(cudaGetLastError());
  return max_voxels < num_coors ? max_voxels : num_coors;
}

void DynamicVoxelizeForwardCUDAKernelLauncher(
const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size, const std::vector<float> coors_range,
Expand Down
6 changes: 4 additions & 2 deletions mmcv/ops/csrc/pytorch/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,8 @@ void hard_voxelize_forward(const at::Tensor &points,
const at::Tensor &coors_range, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
at::Tensor &voxel_num, const int max_points,
const int max_voxels, const int NDim);
const int max_voxels, const int NDim,
const bool deterministic);

void dynamic_voxelize_forward(const at::Tensor &points,
const at::Tensor &voxel_size,
Expand Down Expand Up @@ -756,7 +757,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"hard_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
py::arg("coors_range"), py::arg("voxels"), py::arg("coors"),
py::arg("num_points_per_voxel"), py::arg("voxel_num"),
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"));
py::arg("max_points"), py::arg("max_voxels"), py::arg("NDim"),
py::arg("deterministic"));
m.def("dynamic_voxelize_forward", &dynamic_voxelize_forward,
"dynamic_voxelize_forward", py::arg("points"), py::arg("voxel_size"),
py::arg("coors_range"), py::arg("coors"), py::arg("NDim"));
Expand Down
26 changes: 22 additions & 4 deletions mmcv/ops/csrc/pytorch/voxelization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels,
max_points, max_voxels, NDim);
}

// Device-agnostic entry point for nondeterministic hard voxelization: hands
// all arguments to DISPATCH_DEVICE_IMPL under this function's own name and
// returns whatever the selected implementation returns. NDim defaults to 3.
int nondeterministic_hard_voxelize_forward_impl(
    const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors,
    at::Tensor &num_points_per_voxel, const std::vector<float> voxel_size,
    const std::vector<float> coors_range, const int max_points,
    const int max_voxels, const int NDim = 3) {
  const int voxel_count = DISPATCH_DEVICE_IMPL(
      nondeterministic_hard_voxelize_forward_impl, points, voxels, coors,
      num_points_per_voxel, voxel_size, coors_range, max_points, max_voxels,
      NDim);
  return voxel_count;
}

void dynamic_voxelize_forward_impl(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
Expand All @@ -27,7 +38,8 @@ void hard_voxelize_forward(const at::Tensor &points,
const at::Tensor &coors_range, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
at::Tensor &voxel_num, const int max_points,
const int max_voxels, const int NDim = 3) {
const int max_voxels, const int NDim = 3,
const bool deterministic = true) {
int64_t *voxel_num_data = voxel_num.data_ptr<int64_t>();
std::vector<float> voxel_size_v(
voxel_size.data_ptr<float>(),
Expand All @@ -36,9 +48,15 @@ void hard_voxelize_forward(const at::Tensor &points,
coors_range.data_ptr<float>(),
coors_range.data_ptr<float>() + coors_range.numel());

*voxel_num_data = hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v, coors_range_v,
max_points, max_voxels, NDim);
if (deterministic) {
*voxel_num_data = hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v,
coors_range_v, max_points, max_voxels, NDim);
} else {
*voxel_num_data = nondeterministic_hard_voxelize_forward_impl(
points, voxels, coors, num_points_per_voxel, voxel_size_v,
coors_range_v, max_points, max_voxels, NDim);
}
}

void dynamic_voxelize_forward(const at::Tensor &points,
Expand Down
43 changes: 39 additions & 4 deletions mmcv/ops/voxelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def forward(ctx,
voxel_size,
coors_range,
max_points=35,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
"""Convert kitti points(N, >=3) to voxels.
Args:
Expand All @@ -34,6 +35,16 @@ def forward(ctx,
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
Default: 20000.
deterministic: bool. Whether to use the deterministic version of
the hard-voxelization implementation. The non-deterministic
version (deterministic=False) is considerably faster but is not
deterministic. Only affects hard voxelization. Default: True.
For more information about this argument and the implementation
insights, please refer to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
The non-deterministic version is an experimental feature and we
would appreciate it if you could share any failing cases with us.
Returns:
tuple[torch.Tensor]: A tuple contains three
Expand Down Expand Up @@ -69,7 +80,8 @@ def forward(ctx,
voxel_num,
max_points=max_points,
max_voxels=max_voxels,
NDim=3)
NDim=3,
deterministic=deterministic)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
Expand Down Expand Up @@ -102,7 +114,27 @@ def __init__(self,
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
"""
Args:
voxel_size (list): list [x, y, z] size of three dimension
point_cloud_range (list):
[x_min, y_min, z_min, x_max, y_max, z_max]
max_num_points (int): max number of points per voxel
max_voxels (tuple or int): max number of voxels in
(training, testing) time
deterministic: bool. Whether to use the deterministic version of
the hard-voxelization implementation. The non-deterministic
version (deterministic=False) is considerably faster but is not
deterministic. Only affects hard voxelization. Default: True.
For more information about this argument and the implementation
insights, please refer to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
The non-deterministic version is an experimental feature and we
would appreciate it if you could share any failing cases with us.
"""
super().__init__()

self.voxel_size = voxel_size
Expand All @@ -112,6 +144,7 @@ def __init__(self,
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic

point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
Expand All @@ -132,13 +165,15 @@ def forward(self, input):
max_voxels = self.max_voxels[1]

return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels)
self.max_num_points, max_voxels,
self.deterministic)

def __repr__(self):
s = self.__class__.__name__ + '('
s += 'voxel_size=' + str(self.voxel_size)
s += ', point_cloud_range=' + str(self.point_cloud_range)
s += ', max_num_points=' + str(self.max_num_points)
s += ', max_voxels=' + str(self.max_voxels)
s += ', deterministic=' + str(self.deterministic)
s += ')'
return s
Loading

0 comments on commit b5d550f

Please sign in to comment.