[Fix] fix a bug that may cause compilation failure of dynamic voxelization when using GPUs with compute capability lower than 6.x #326

Merged (3 commits) on Mar 2, 2021
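
The compute capabilities in question ("lower than 6.x") are the pre-Pascal GPUs: Maxwell, Kepler and older. Those devices have no native atomicAdd for double, so scatter_points_cuda.cu keeps an atomicCAS fallback for them, and judging from the diff it was a leftover statement in that fallback (see the reduceAdd hunk below) that made builds targeting such GPUs fail. The standalone sketch below, not part of this PR, prints the compute capability of device 0 so you can tell whether a machine is affected:

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical standalone check, not from this PR: print the compute
// capability of the first visible GPU. Devices reporting major < 6 take the
// atomicCAS fallback path in scatter_points_cuda.cu.
int main() {
  cudaDeviceProp prop;
  cudaError_t err = cudaGetDeviceProperties(&prop, /*device=*/0);
  if (err != cudaSuccess) {
    std::fprintf(stderr, "cudaGetDeviceProperties failed: %s\n",
                 cudaGetErrorString(err));
    return 1;
  }
  std::printf("compute capability: %d.%d\n", prop.major, prop.minor);
  return 0;
}
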
1 change: 1 addition & 0 deletions mmdet3d/ops/voxel/scatter_points.py
@@ -30,6 +30,7 @@ def forward(ctx, feats, coors, reduce_type='max'):
ctx.reduce_type = reduce_type
ctx.save_for_backward(feats, voxel_feats, point2voxel_map,
voxel_points_count)
ctx.mark_non_differentiable(voxel_coors)
return voxel_feats, voxel_coors

@staticmethod
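
The added ctx.mark_non_differentiable(voxel_coors) declares the integer voxel coordinates as a non-differentiable output, so autograd will neither track them nor ask backward for their gradient. A minimal standalone sketch of the same pattern, built around a hypothetical ColumnMax op rather than the mmdet3d code:

import torch


class ColumnMax(torch.autograd.Function):
    """Return per-column max values and the (non-differentiable) argmax rows."""

    @staticmethod
    def forward(ctx, x):
        values, indices = x.max(dim=0)
        ctx.save_for_backward(x, indices)
        ctx.mark_non_differentiable(indices)  # integer output, carries no grad
        return values, indices

    @staticmethod
    def backward(ctx, grad_values, grad_indices):
        # grad_indices is ignored; autograd knows `indices` has no gradient.
        x, indices = ctx.saved_tensors
        grad_x = torch.zeros_like(x)
        grad_x[indices, torch.arange(x.size(1))] = grad_values
        return grad_x


x = torch.randn(4, 3, requires_grad=True)
values, indices = ColumnMax.apply(x)
values.sum().backward()  # succeeds; no gradient is requested for `indices`
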
69 changes: 40 additions & 29 deletions mmdet3d/ops/voxel/src/scatter_points_cuda.cu
@@ -1,5 +1,9 @@
#include "voxelization.h"
#include <ATen/cuda/Exceptions.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/types.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>

typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;

#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
@@ -66,7 +70,6 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
old = atomicCAS(address_as_ull, assumed,
__double_as_longlong(val + __longlong_as_double(assumed)));
} while (assumed != old);
return __longlong_as_double(old);
#else
atomicAdd(address, val);
#endif
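
Context for this hunk: on devices with __CUDA_ARCH__ below 600, reduceAdd emulates atomicAdd on double with an atomicCAS loop, because the native double-precision atomicAdd only exists from compute capability 6.0 onward. The well-known emulation in the CUDA C Programming Guide returns the previous value, but reduceAdd is declared void, so the leftover return __longlong_as_double(old); is invalid and builds targeting pre-6.0 GPUs failed to compile; the fix simply drops that return. For reference, a standalone sketch of the guide's pattern in its value-returning form (hypothetically named atomic_add_double, not part of this patch):

// Reference pattern: emulate atomicAdd(double*) with atomicCAS on devices of
// compute capability < 6.0. Unlike the void reduceAdd above, this helper
// returns the previous value, which is why `return __longlong_as_double(old);`
// is legal here but not there.
__device__ double atomic_add_double(double *address, double val) {
  unsigned long long int *address_as_ull = (unsigned long long int *)address;
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    // Reinterpret the bits, add in double precision, and try to swap the
    // result back in; retry if another thread changed *address meanwhile.
    old = atomicCAS(address_as_ull, assumed,
                    __double_as_longlong(val + __longlong_as_double(assumed)));
  } while (assumed != old);
  return __longlong_as_double(old);
}
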
@@ -103,7 +106,7 @@ __global__ void coors_map_init_kernel(const int64_t *coors_id,
} else {
coors_map[0] = 0;
}
return;
continue;
}
auto left = coors_id[coors_id_argsort[x - 1]];
coors_map[x] = (left < here) ? 1 : 0;
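
This hunk and the three that follow make the same one-word change inside grid-stride loops: an early return ends the whole thread, silently skipping every later index the thread would have handled once the grid is capped at maxGridDim, whereas continue only skips the current element. A minimal sketch of the convention, with a hypothetical masked_copy_kernel rather than one of the kernels in this file:

// Hypothetical kernel showing the grid-stride-loop rule the patch enforces:
// skip one element with `continue`; a `return` here would also abandon every
// later index assigned to this thread.
__global__ void masked_copy_kernel(const float *src, const int *mask,
                                   float *dst, int n) {
  for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < n;
       x += gridDim.x * blockDim.x) {
    if (mask[x] == -1)
      continue;  // skip this element only; the loop keeps striding
    dst[x] = src[x];
  }
}
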
@@ -121,7 +124,7 @@ feats_reduce_kernel(const T *feats, const T_int *coors, int32_t *coors_map,
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1)
return;
continue;

const T_int *coors_offset = coors + x * NDim;
T_int *out_coors_offset = out_coors + reduce_to * NDim;
@@ -155,7 +158,7 @@ __global__ void add_reduce_traceback_grad_kernel(
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) {
return;
continue;
}

const int input_offset = x * num_feats;
@@ -188,7 +191,7 @@ max_reduce_traceback_scatter_idx_kernel(
const T *feats_offset = feats + input_offset;

if (reduce_to == -1) {
return;
continue;
}

const int reduced_offset = reduce_to * num_feats;
@@ -224,9 +227,9 @@ max_reduce_scatter_grad_kernel(T *grad_feats, const T *grad_reduced_feats,

namespace voxelization {

std::vector<torch::Tensor>
dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
const torch::Tensor &coors,
std::vector<at::Tensor>
dynamic_point_to_voxel_forward_gpu(const at::Tensor &feats,
const at::Tensor &coors,
const reduce_t reduce_type) {
CHECK_INPUT(feats);
CHECK_INPUT(coors);
@@ -235,17 +238,18 @@ dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
const int num_input = feats.size(0);
const int num_feats = feats.size(1);

auto coors_id = torch::empty({num_input}, coors.options().dtype(torch::kI64));
auto coors_id = at::empty({num_input}, coors.options().dtype(torch::kInt64));
auto coor_space_dim = coors.max_values(0) + 1;
auto coors_map_sorted =
torch::empty({num_input}, coors.options().dtype(torch::kI32));
at::empty({num_input}, coors.options().dtype(torch::kInt32));
auto coors_map =
torch::empty({num_input}, coors.options().dtype(torch::kI32));
auto num_coors = at::zeros({1}, coors.options().dtype(torch::kI32));
at::empty({num_input}, coors.options().dtype(torch::kInt32));
auto num_coors = at::zeros({1}, coors.options().dtype(torch::kInt32));

AT_DISPATCH_INTEGRAL_TYPES(
coors.scalar_type(), "coors_id_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
coors_id_kernel<<<blocks, threads>>>(
coors.data_ptr<scalar_t>(), coor_space_dim.data_ptr<scalar_t>(),
@@ -257,24 +261,25 @@ dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,

AT_DISPATCH_INTEGRAL_TYPES(
coors_id_argsort.scalar_type(), "coors_map_init_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
coors_map_init_kernel<<<blocks, threads>>>(
coors_id.data_ptr<int64_t>(), coors_id_argsort.data_ptr<scalar_t>(),
coors_map_sorted.data_ptr<int32_t>(), num_input);
}));
AT_CUDA_CHECK(cudaGetLastError());

coors_map_sorted = coors_map_sorted.cumsum(0, torch::kI32);
coors_map_sorted = coors_map_sorted.cumsum(0, torch::kInt32);
coors_map.index_put_(coors_id_argsort, coors_map_sorted);

const int num_coors_cpu =
coors_map_sorted[-1].cpu().data_ptr<int32_t>()[0] + 1;
auto out_coors = torch::empty({num_coors_cpu, NDim}, coors.options());
auto out_coors = at::empty({num_coors_cpu, NDim}, coors.options());
auto reduced_feats =
torch::empty({num_coors_cpu, num_feats}, feats.options());
at::empty({num_coors_cpu, num_feats}, feats.options());
auto reduce_count =
torch::zeros({num_coors_cpu}, coors.options().dtype(torch::kI32));
at::zeros({num_coors_cpu}, coors.options().dtype(torch::kInt32));

AT_DISPATCH_FLOATING_TYPES(
feats.scalar_type(), "feats_reduce_kernel", ([&] {
@@ -289,7 +294,8 @@ dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
reduced_feats.fill_(static_cast<F_t>(0));

dim3 blocks(
std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
feats_reduce_kernel<<<blocks, threads>>>(
feats.data_ptr<F_t>(), coors.data_ptr<I_t>(),
@@ -308,9 +314,9 @@ dynamic_point_to_voxel_forward_gpu(const torch::Tensor &feats,
}

void dynamic_point_to_voxel_backward_gpu(
torch::Tensor &grad_feats, const torch::Tensor &grad_reduced_feats,
const torch::Tensor &feats, const torch::Tensor &reduced_feats,
const torch::Tensor &coors_map, const torch::Tensor &reduce_count,
at::Tensor &grad_feats, const at::Tensor &grad_reduced_feats,
const at::Tensor &feats, const at::Tensor &reduced_feats,
const at::Tensor &coors_map, const at::Tensor &reduce_count,
const reduce_t reduce_type) {
CHECK_INPUT(grad_feats);
CHECK_INPUT(grad_reduced_feats);
@@ -330,7 +336,9 @@ void dynamic_point_to_voxel_backward_gpu(
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 blocks
(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
add_reduce_traceback_grad_kernel<<<blocks, threads>>>(
grad_feats.data_ptr<scalar_t>(),
@@ -340,12 +348,14 @@ void dynamic_point_to_voxel_backward_gpu(
}));
AT_CUDA_CHECK(cudaGetLastError());
} else {
auto reduce_from = torch::full({num_reduced, num_feats}, num_input,
coors_map.options().dtype(torch::kI32));
auto reduce_from = at::full({num_reduced, num_feats}, num_input,
coors_map.options().dtype(torch::kInt32));
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(std::min(DIVUP(num_input, threadsPerBlock), maxGridDim));
dim3 blocks
(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
max_reduce_traceback_scatter_idx_kernel<<<blocks, threads>>>(
feats.data_ptr<scalar_t>(), reduced_feats.data_ptr<scalar_t>(),
@@ -358,7 +368,8 @@ void dynamic_point_to_voxel_backward_gpu(
grad_reduced_feats.scalar_type(),
"max_reduce_traceback_scatter_idx_kernel", ([&] {
dim3 blocks(
std::min(DIVUP(num_reduced, threadsPerBlock), maxGridDim));
std::min(at::cuda::ATenCeilDiv(num_reduced, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
max_reduce_scatter_grad_kernel<<<blocks, threads>>>(
grad_feats.data_ptr<scalar_t>(),
3 changes: 1 addition & 2 deletions mmdet3d/ops/voxel/src/voxelization.h
@@ -1,8 +1,7 @@
#pragma once
#include <torch/extension.h>

typedef enum { SUM, MEAN, MAX } reduce_t;
#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
typedef enum { SUM = 0, MEAN = 1, MAX = 2 } reduce_t;

namespace voxelization {

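A knock-on effect of this header change runs through scatter_points_cuda.cu: with the DIVUP macro gone, block counts are computed with at::cuda::ATenCeilDiv from the newly included ATen/cuda/CUDAApplyUtils.cuh; both are a plain ceiling division of positive integers. Below is a standalone sketch of the launch-size arithmetic after the change, with a hypothetical launch_blocks helper and assumed values for threadsPerBlock and maxGridDim rather than the file's actual constants:

#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <algorithm>
#include <cuda_runtime.h>

// Assumed placeholder values for illustration only.
constexpr int threadsPerBlock = 512;
constexpr int maxGridDim = 50000;

// ceil(num_elements / threadsPerBlock) blocks, capped at maxGridDim, mirroring
// how the dispatch lambdas in scatter_points_cuda.cu now size their grids.
dim3 launch_blocks(int num_elements) {
  return dim3(std::min(at::cuda::ATenCeilDiv(num_elements, threadsPerBlock),
                       maxGridDim));
}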