Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] faster but nondeterministic version of hard voxelization #904

Merged
merged 7 commits into from
Sep 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions mmdet3d/ops/voxel/src/voxelization.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ int hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
const int max_points, const int max_voxels,
const int NDim = 3);

int nondisterministic_hard_voxelize_gpu(const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3);

void dynamic_voxelize_gpu(const at::Tensor &points, at::Tensor &coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
Expand All @@ -53,12 +60,17 @@ inline int hard_voxelize(const at::Tensor &points, at::Tensor &voxels,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {
const int NDim = 3, const bool deterministic = true) {
if (points.device().is_cuda()) {
#ifdef WITH_CUDA
return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
voxel_size, coors_range, max_points, max_voxels,
NDim);
if (deterministic) {
return hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
voxel_size, coors_range, max_points, max_voxels,
NDim);
}
return nondisterministic_hard_voxelize_gpu(points, voxels, coors, num_points_per_voxel,
voxel_size, coors_range, max_points, max_voxels,
NDim);
#else
AT_ERROR("Not compiled with GPU support");
#endif
Expand Down
157 changes: 157 additions & 0 deletions mmdet3d/ops/voxel/src/voxelization_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,53 @@ __global__ void determin_voxel_num(
}
}

__global__ void nondisterministic_get_assign_pos(
const int nthreads, const int32_t *coors_map, int32_t *pts_id,
int32_t *coors_count, int32_t *reduce_count, int32_t *coors_order) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
if (coors_idx > -1) {
int32_t coors_pts_pos = atomicAdd(&reduce_count[coors_idx], 1);
pts_id[thread_idx] = coors_pts_pos;
if (coors_pts_pos == 0) {
coors_order[coors_idx] = atomicAdd(coors_count, 1);
}
}
}
}

template<typename T>
__global__ void nondisterministic_assign_point_voxel(
const int nthreads, const T *points, const int32_t *coors_map,
const int32_t *pts_id, const int32_t *coors_in,
const int32_t *reduce_count, const int32_t *coors_order,
T *voxels, int32_t *coors, int32_t *pts_count, const int max_voxels,
const int max_points, const int num_features, const int NDim) {
CUDA_1D_KERNEL_LOOP(thread_idx, nthreads) {
int coors_idx = coors_map[thread_idx];
int coors_pts_pos = pts_id[thread_idx];
if (coors_idx > -1) {
int coors_pos = coors_order[coors_idx];
if (coors_pos < max_voxels && coors_pts_pos < max_points) {
auto voxels_offset =
voxels + (coors_pos * max_points + coors_pts_pos) * num_features;
auto points_offset = points + thread_idx * num_features;
for (int k = 0; k < num_features; k++) {
voxels_offset[k] = points_offset[k];
}
if (coors_pts_pos == 0) {
pts_count[coors_pos] = min(reduce_count[coors_idx], max_points);
auto coors_offset = coors + coors_pos * NDim;
auto coors_in_offset = coors_in + coors_idx * NDim;
for (int k = 0; k < NDim; k++) {
coors_offset[k] = coors_in_offset[k];
}
}
}
}
}
}

namespace voxelization {

int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
Expand Down Expand Up @@ -325,6 +372,116 @@ int hard_voxelize_gpu(const at::Tensor& points, at::Tensor& voxels,
return voxel_num_int;
}

int nondisterministic_hard_voxelize_gpu(
const at::Tensor &points, at::Tensor &voxels,
at::Tensor &coors, at::Tensor &num_points_per_voxel,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
const int max_points, const int max_voxels,
const int NDim = 3) {

CHECK_INPUT(points);

at::cuda::CUDAGuard device_guard(points.device());

const int num_points = points.size(0);
const int num_features = points.size(1);

if (num_points == 0)
return 0;

const float voxel_x = voxel_size[0];
const float voxel_y = voxel_size[1];
const float voxel_z = voxel_size[2];
const float coors_x_min = coors_range[0];
const float coors_y_min = coors_range[1];
const float coors_z_min = coors_range[2];
const float coors_x_max = coors_range[3];
const float coors_y_max = coors_range[4];
const float coors_z_max = coors_range[5];

const int grid_x = round((coors_x_max - coors_x_min) / voxel_x);
const int grid_y = round((coors_y_max - coors_y_min) / voxel_y);
const int grid_z = round((coors_z_max - coors_z_min) / voxel_z);

// map points to voxel coors
at::Tensor temp_coors =
at::zeros({num_points, NDim}, points.options().dtype(torch::kInt32));

dim3 grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 block(512);

// 1. link point to corresponding voxel coors
AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "hard_voxelize_kernel", ([&] {
dynamic_voxelize_kernel<scalar_t, int>
<<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
points.contiguous().data_ptr<scalar_t>(),
temp_coors.contiguous().data_ptr<int>(), voxel_x, voxel_y,
voxel_z, coors_x_min, coors_y_min, coors_z_min, coors_x_max,
coors_y_max, coors_z_max, grid_x, grid_y, grid_z, num_points,
num_features, NDim);
}));

at::Tensor coors_map;
at::Tensor coors_count;
at::Tensor coors_order;
at::Tensor reduce_count;
at::Tensor pts_id;

auto coors_clean = temp_coors.masked_fill(temp_coors.lt(0).any(-1, true), -1);

std::tie(temp_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, false);

if (temp_coors.index({0, 0}).lt(0).item<bool>()) {
// the first element of temp_coors is (-1,-1,-1) and should be removed
temp_coors = temp_coors.slice(0, 1);
coors_map = coors_map - 1;
}

int num_coors = temp_coors.size(0);
temp_coors = temp_coors.to(torch::kInt32);
coors_map = coors_map.to(torch::kInt32);

coors_count = coors_map.new_zeros(1);
coors_order = coors_map.new_empty(num_coors);
reduce_count = coors_map.new_zeros(num_coors);
pts_id = coors_map.new_zeros(num_points);

dim3 cp_grid(std::min(at::cuda::ATenCeilDiv(num_points, 512), 4096));
dim3 cp_block(512);
AT_DISPATCH_ALL_TYPES(points.scalar_type(), "get_assign_pos", ([&] {
nondisterministic_get_assign_pos<<<cp_grid, cp_block, 0,
at::cuda::getCurrentCUDAStream()>>>(
num_points,
coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
coors_count.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>());
}));

AT_DISPATCH_ALL_TYPES(
points.scalar_type(), "assign_point_to_voxel", ([&] {
nondisterministic_assign_point_voxel<scalar_t>
<<<cp_grid, cp_block, 0, at::cuda::getCurrentCUDAStream()>>>(
num_points, points.contiguous().data_ptr<scalar_t>(),
coors_map.contiguous().data_ptr<int32_t>(),
pts_id.contiguous().data_ptr<int32_t>(),
temp_coors.contiguous().data_ptr<int32_t>(),
reduce_count.contiguous().data_ptr<int32_t>(),
coors_order.contiguous().data_ptr<int32_t>(),
voxels.contiguous().data_ptr<scalar_t>(),
coors.contiguous().data_ptr<int32_t>(),
num_points_per_voxel.contiguous().data_ptr<int32_t>(),
max_voxels, max_points,
num_features, NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
return max_voxels < num_coors ? max_voxels : num_coors;
}

void dynamic_voxelize_gpu(const at::Tensor& points, at::Tensor& coors,
const std::vector<float> voxel_size,
const std::vector<float> coors_range,
Expand Down
34 changes: 30 additions & 4 deletions mmdet3d/ops/voxel/voxelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def forward(ctx,
voxel_size,
coors_range,
max_points=35,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
"""convert kitti points(N, >=3) to voxels.

Args:
Expand All @@ -30,6 +31,16 @@ def forward(ctx,
max_voxels: int. indicate maximum voxels this function create.
for second, 20000 is a good choice. Users should shuffle points
before call this function because max_voxels may drop points.
deterministic: bool. whether to invoke the non-deterministic
version of hard-voxelization implementations. non-deterministic
version is considerablly fast but is not deterministic. only
affects hard voxelization. default True. for more information
of this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
it is an experimental feature and we will appreciate it if
you could share with us the failing cases.

Returns:
voxels: [M, max_points, ndim] float tensor. only contain points
Expand All @@ -50,7 +61,8 @@ def forward(ctx,
size=(max_voxels, ), dtype=torch.int)
voxel_num = hard_voxelize(points, voxels, coors,
num_points_per_voxel, voxel_size,
coors_range, max_points, max_voxels, 3)
coors_range, max_points, max_voxels, 3,
deterministic)
# select the valid voxels
voxels_out = voxels[:voxel_num]
coors_out = coors[:voxel_num]
Expand All @@ -67,7 +79,8 @@ def __init__(self,
voxel_size,
point_cloud_range,
max_num_points,
max_voxels=20000):
max_voxels=20000,
deterministic=True):
super(Voxelization, self).__init__()
"""
Args:
Expand All @@ -77,6 +90,16 @@ def __init__(self,
max_num_points (int): max number of points per voxel
max_voxels (tuple or int): max number of voxels in
(training, testing) time
deterministic: bool. whether to invoke the non-deterministic
version of hard-voxelization implementations. non-deterministic
version is considerablly fast but is not deterministic. only
affects hard voxelization. default True. for more information
of this argument and the implementation insights, please refer
to the following links:
https://github.com/open-mmlab/mmdetection3d/issues/894
https://github.com/open-mmlab/mmdetection3d/pull/904
it is an experimental feature and we will appreciate it if
you could share with us the failing cases.
"""
self.voxel_size = voxel_size
self.point_cloud_range = point_cloud_range
Expand All @@ -85,6 +108,7 @@ def __init__(self,
self.max_voxels = max_voxels
else:
self.max_voxels = _pair(max_voxels)
self.deterministic = deterministic

point_cloud_range = torch.tensor(
point_cloud_range, dtype=torch.float32)
Expand All @@ -110,13 +134,15 @@ def forward(self, input):
max_voxels = self.max_voxels[1]

return voxelization(input, self.voxel_size, self.point_cloud_range,
self.max_num_points, max_voxels)
self.max_num_points, max_voxels,
self.deterministic)

def __repr__(self):
tmpstr = self.__class__.__name__ + '('
tmpstr += 'voxel_size=' + str(self.voxel_size)
tmpstr += ', point_cloud_range=' + str(self.point_cloud_range)
tmpstr += ', max_num_points=' + str(self.max_num_points)
tmpstr += ', max_voxels=' + str(self.max_voxels)
tmpstr += ', deterministic=' + str(self.deterministic)
tmpstr += ')'
return tmpstr
81 changes: 81 additions & 0 deletions tests/test_models/test_voxel_encoder/test_voxelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,84 @@ def test_voxelization():
assert np.all(
points[indices] == expected_coors[i][:num_points_current_voxel])
assert num_points_current_voxel == expected_num_points_per_voxel[i]


def test_voxelization_nondeterministic():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')

voxel_size = [0.5, 0.5, 0.5]
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
data_path = './tests/data/kitti/training/velodyne_reduced/000000.bin'
load_points_from_file = LoadPointsFromFile(
coord_type='LIDAR', load_dim=4, use_dim=4)
results = dict()
results['pts_filename'] = data_path
results = load_points_from_file(results)
points = results['points'].tensor.numpy()

points = torch.tensor(points)
max_num_points = -1
dynamic_voxelization = Voxelization(voxel_size, point_cloud_range,
max_num_points)

max_num_points = 10
max_voxels = 50
hard_voxelization = Voxelization(
voxel_size,
point_cloud_range,
max_num_points,
max_voxels,
deterministic=False)

# test hard_voxelization (non-deterministic version) on gpu
points = torch.tensor(points).contiguous().to(device='cuda:0')
voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy().tolist()
voxels = voxels.cpu().detach().numpy().tolist()
num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy().tolist()

coors_all = dynamic_voxelization.forward(points)
coors_all = coors_all.cpu().detach().numpy().tolist()

coors_set = set([tuple(c) for c in coors])
coors_all_set = set([tuple(c) for c in coors_all])

assert len(coors_set) == len(coors)
assert len(coors_set - coors_all_set) == 0

points = points.cpu().detach().numpy().tolist()

coors_points_dict = {}
for c, ps in zip(coors_all, points):
if tuple(c) not in coors_points_dict:
coors_points_dict[tuple(c)] = set()
coors_points_dict[tuple(c)].add(tuple(ps))

for c, ps, n in zip(coors, voxels, num_points_per_voxel):
ideal_voxel_points_set = coors_points_dict[tuple(c)]
voxel_points_set = set([tuple(p) for p in ps[:n]])
assert len(voxel_points_set) == n
if n < max_num_points:
assert voxel_points_set == ideal_voxel_points_set
for p in ps[n:]:
assert max(p) == min(p) == 0
else:
assert len(voxel_points_set - ideal_voxel_points_set) == 0

# test hard_voxelization (non-deterministic version) on gpu
# with all input point in range
points = torch.tensor(points).contiguous().to(device='cuda:0')[:max_voxels]
coors_all = dynamic_voxelization.forward(points)
valid_mask = coors_all.ge(0).all(-1)
points = points[valid_mask]
coors_all = coors_all[valid_mask]
coors_all = coors_all.cpu().detach().numpy().tolist()

voxels, coors, num_points_per_voxel = hard_voxelization.forward(points)
coors = coors.cpu().detach().numpy().tolist()

coors_set = set([tuple(c) for c in coors])
coors_all_set = set([tuple(c) for c in coors_all])

assert len(coors_set) == len(coors) == len(coors_all_set)