[Fix] fix dynamic_scatter 'invalid configuration argument' error triggered by empty point input

- Fix the 'invalid configuration argument' error triggered by empty point input. Test cases covering similar situations are added to test_dynamic_scatter.py as well.
Trivial changes:
- Switch to torch::unique_dim to generate the reduce mapping instead of calculating it from scratch (see the sketch below).
zhanggefan authored Apr 7, 2021
1 parent 1e7059f commit 3a5a201
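As context for the second change: the C++ at::unique_dim call corresponds to torch.unique(..., dim=0) on the Python side. Below is a minimal sketch of how a single call yields the unique voxel coordinates, the per-point reduce mapping, and the per-voxel counts; the tensors and variable names are illustrative only and not part of this patch.

import torch

# three points, two of which fall into the same voxel
coors = torch.tensor([[0, 0, 1], [2, 1, 0], [0, 0, 1]], dtype=torch.int32)
out_coors, coors_map, reduce_count = torch.unique(
    coors, dim=0, sorted=True, return_inverse=True, return_counts=True)
# out_coors:     unique voxel coordinates, sorted lexicographically
# coors_map:     for each input point, the row of out_coors it reduces into
# reduce_count:  number of points falling into each unique voxel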
Showing 2 changed files with 75 additions and 117 deletions.
157 changes: 41 additions & 116 deletions mmdet3d/ops/voxel/src/scatter_points_cuda.cu
@@ -77,71 +77,24 @@ __device__ __forceinline__ static void reduceAdd(double *address, double val) {
}
#endif

template <typename T_int>
__global__ void coors_id_kernel(const T_int *coors, const T_int *dim,
int64_t *coors_id, const int num_input,
const int NDim) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
const T_int *coor_x = coors + x * NDim;
auto coor_id = 0;
for (int i = 0; i < NDim && coor_id != -1; i++) {
coor_id *= dim[i];
auto t = static_cast<int64_t>(coor_x[i]);
coor_id = (t < 0) ? -1 : coor_id + t;
}
coors_id[x] = coor_id;
}
}

template <typename T_int>
__global__ void coors_map_init_kernel(const int64_t *coors_id,
const T_int *coors_id_argsort,
int32_t *coors_map, const int num_input) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
auto here = coors_id[coors_id_argsort[x]];
if (x == 0) {
if (here == -1) { // there are invalid points
coors_map[0] = -1;
} else {
coors_map[0] = 0;
}
continue;
}
auto left = coors_id[coors_id_argsort[x - 1]];
coors_map[x] = (left < here) ? 1 : 0;
}
}

template <typename T, typename T_int>
__global__ void feats_reduce_kernel(
const T *feats, const T_int *coors, int32_t *coors_map,
int32_t *reduce_count, // shall be 0 at initialization
T *reduced_feats, // shall be 0 at initialization
T_int *out_coors, const int num_input, const int num_feats, const int NDim,
const reduce_t reduce_type) {
template <typename T>
__global__ void
feats_reduce_kernel(const T *feats, const int32_t *coors_map,
T *reduced_feats, // shall be 0 at initialization
const int num_input, const int num_feats,
const reduce_t reduce_type) {
for (int x = blockIdx.x * blockDim.x + threadIdx.x; x < num_input;
x += gridDim.x * blockDim.x) {
int32_t reduce_to = coors_map[x];
if (reduce_to == -1) continue;

const T_int *coors_offset = coors + x * NDim;
T_int *out_coors_offset = out_coors + reduce_to * NDim;
for (int i = 0; i < NDim; i++) {
out_coors_offset[i] = coors_offset[i];
}

const T *feats_offset = feats + x * num_feats;
T *reduced_feats_offset = reduced_feats + reduce_to * num_feats;
if (reduce_type == reduce_t::MAX) {
for (int i = 0; i < num_feats; i++) {
reduceMax(&reduced_feats_offset[i], feats_offset[i]);
}
} else {
if (reduce_type == reduce_t::MEAN) {
atomicAdd(&reduce_count[reduce_to], static_cast<int32_t>(1));
}
for (int i = 0; i < num_feats; i++) {
reduceAdd(&reduced_feats_offset[i], feats_offset[i]);
}
@@ -233,78 +186,48 @@ std::vector<at::Tensor> dynamic_point_to_voxel_forward_gpu(
CHECK_INPUT(feats);
CHECK_INPUT(coors);

const int NDim = coors.size(1);
const int num_input = feats.size(0);
const int num_feats = feats.size(1);

auto coors_id = at::empty({num_input}, coors.options().dtype(torch::kInt64));
auto coor_space_dim = std::get<0>(coors.max(0)) + 1;
auto coors_map_sorted =
at::empty({num_input}, coors.options().dtype(torch::kInt32));
auto coors_map = at::empty({num_input}, coors.options().dtype(torch::kInt32));
auto num_coors = at::zeros({1}, coors.options().dtype(torch::kInt32));

AT_DISPATCH_INTEGRAL_TYPES(
coors.scalar_type(), "coors_id_kernel", ([&] {
dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
coors_id_kernel<<<blocks, threads>>>(
coors.data_ptr<scalar_t>(), coor_space_dim.data_ptr<scalar_t>(),
coors_id.data_ptr<int64_t>(), num_input, NDim);
}));
AT_CUDA_CHECK(cudaGetLastError());
if (num_input == 0)
return {feats.clone().detach(),
coors.clone().detach(),
coors.new_empty({0}, torch::kInt32),
coors.new_empty({0}, torch::kInt32)};

auto coors_id_argsort = coors_id.argsort();

AT_DISPATCH_INTEGRAL_TYPES(
coors_id_argsort.scalar_type(), "coors_map_init_kernel", ([&] {
dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
coors_map_init_kernel<<<blocks, threads>>>(
coors_id.data_ptr<int64_t>(), coors_id_argsort.data_ptr<scalar_t>(),
coors_map_sorted.data_ptr<int32_t>(), num_input);
}));
AT_CUDA_CHECK(cudaGetLastError());
at::Tensor out_coors;
at::Tensor coors_map;
at::Tensor reduce_count;

coors_map_sorted = coors_map_sorted.cumsum(0, torch::kInt32);
coors_map.index_put_({coors_id_argsort}, coors_map_sorted);
auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);

const int num_coors_cpu =
coors_map_sorted[-1].cpu().data_ptr<int32_t>()[0] + 1;
auto out_coors = at::empty({num_coors_cpu, NDim}, coors.options());
auto reduced_feats = at::empty({num_coors_cpu, num_feats}, feats.options());
auto reduce_count =
at::zeros({num_coors_cpu}, coors.options().dtype(torch::kInt32));
std::tie(out_coors, coors_map, reduce_count) =
at::unique_dim(coors_clean, 0, true, true, true);

// the first element of out_coors is always (-1,-1,-1) and should be removed
out_coors = out_coors.slice(0, 1);
reduce_count = reduce_count.slice(0, 1).to(torch::kInt32);
coors_map = coors_map.to(torch::kInt32) - 1;

auto reduced_feats =
at::empty({out_coors.size(0), num_feats}, feats.options());

AT_DISPATCH_FLOATING_TYPES(
feats.scalar_type(), "feats_reduce_kernel", ([&] {
using F_t = scalar_t;
AT_DISPATCH_INTEGRAL_TYPES(
coors.scalar_type(), "feats_reduce_kernel", ([&] {
using I_t = scalar_t;

if (reduce_type == reduce_t::MAX)
reduced_feats.fill_(-std::numeric_limits<F_t>::infinity());
else
reduced_feats.fill_(static_cast<F_t>(0));

dim3 blocks(
std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
feats_reduce_kernel<<<blocks, threads>>>(
feats.data_ptr<F_t>(), coors.data_ptr<I_t>(),
coors_map.data_ptr<int32_t>(),
reduce_count.data_ptr<int32_t>(),
reduced_feats.data_ptr<F_t>(), out_coors.data_ptr<I_t>(),
num_input, num_feats, NDim, reduce_type);
if (reduce_type == reduce_t::MEAN)
reduced_feats /=
reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
}));
}));
if (reduce_type == reduce_t::MAX)
reduced_feats.fill_(-std::numeric_limits<scalar_t>::infinity());
else
reduced_feats.fill_(static_cast<scalar_t>(0));

dim3 blocks(std::min(at::cuda::ATenCeilDiv(num_input, threadsPerBlock),
maxGridDim));
dim3 threads(threadsPerBlock);
feats_reduce_kernel<<<blocks, threads>>>(
feats.data_ptr<scalar_t>(), coors_map.data_ptr<int32_t>(),
reduced_feats.data_ptr<scalar_t>(), num_input, num_feats, reduce_type);
if (reduce_type == reduce_t::MEAN)
reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
}));
AT_CUDA_CHECK(cudaGetLastError());

return {reduced_feats, out_coors, coors_map, reduce_count};
@@ -331,6 +254,8 @@ void dynamic_point_to_voxel_backward_gpu(at::Tensor &grad_feats,
grad_feats.fill_(0);
// copy voxel grad to points

if (num_input == 0 || num_reduced == 0) return;

if (reduce_type == reduce_t::MEAN || reduce_type == reduce_t::SUM) {
AT_DISPATCH_FLOATING_TYPES(
grad_reduced_feats.scalar_type(), "add_reduce_traceback_grad_kernel",
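For reference, the rewritten forward path above can be paraphrased in Python roughly as follows for MEAN reduction. This is a sketch only, not the library API (the function name is made up): it mirrors the early return added for empty input and the trick of collapsing invalid coordinates to (-1, -1, -1) before the unique call, and, like the slice that drops the first unique row, it assumes at least one point carries a negative coordinate.

import torch

def dynamic_scatter_mean_reference(feats, coors):
    # Empty input: return the inputs untouched, mirroring the early return above.
    if feats.shape[0] == 0:
        return feats.clone(), coors.clone()
    # Collapse every row containing a negative coordinate to (-1, -1, -1)
    # so that all invalid points share a single key.
    coors_clean = coors.masked_fill(coors.lt(0).any(-1, keepdim=True), -1)
    out_coors, coors_map, reduce_count = torch.unique(
        coors_clean, dim=0, sorted=True, return_inverse=True, return_counts=True)
    # With sorted=True the all-(-1) row comes first; drop it and shift the map
    # so invalid points end up mapped to -1 (assumes such a row exists).
    out_coors = out_coors[1:]
    reduce_count = reduce_count[1:]
    coors_map = coors_map - 1
    valid = coors_map >= 0
    reduced_feats = feats.new_zeros((out_coors.shape[0], feats.shape[1]))
    reduced_feats.index_add_(0, coors_map[valid], feats[valid])
    reduced_feats = reduced_feats / reduce_count.unsqueeze(-1).to(reduced_feats.dtype)
    return reduced_feats, out_coors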
35 changes: 34 additions & 1 deletion tests/test_models/test_voxel_encoder/test_dynamic_scatter.py
@@ -13,13 +13,46 @@ def test_dynamic_scatter():
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
coors = torch.randint(
low=-1, high=20, size=(200000, 3), dtype=torch.int32, device='cuda')
coors[coors.min(dim=-1).values < 0] = -1

dsmean = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], True)
dsmax = DynamicScatter([0.32, 0.32, 6],
[-74.88, -74.88, -2, 74.88, 74.88, 4], False)

# test empty input
empty_feats = torch.empty(size=(0, 3), dtype=torch.float32, device='cuda')
empty_coors = torch.empty(size=(0, 3), dtype=torch.int32, device='cuda')

empty_feats.requires_grad_()
empty_feats_out_mean, empty_coors_out_mean = dsmean(
empty_feats, empty_coors)
empty_feats_out_mean.sum().backward()
empty_feats_out_max, empty_coors_out_max = dsmax(empty_feats, empty_coors)
empty_feats_out_max.sum().backward()

assert empty_feats_out_mean.shape == empty_feats.shape
assert empty_feats_out_max.shape == empty_feats.shape
assert empty_coors_out_mean.shape == empty_coors.shape
assert empty_coors_out_max.shape == empty_coors.shape

# test empty reduced output
empty_o_feats = torch.rand(
size=(200000, 3), dtype=torch.float32, device='cuda') * 100 - 50
empty_o_coors = torch.randint(
low=-1, high=0, size=(200000, 3), dtype=torch.int32, device='cuda')

empty_o_feats.requires_grad_()
empty_o_feats_out_mean, empty_o_coors_out_mean = dsmean(
empty_o_feats, empty_o_coors)
empty_o_feats_out_mean.sum().backward()
assert (empty_o_feats.grad == 0).all()

empty_o_feats_out_max, empty_o_coors_out_max = dsmax(
empty_o_feats, empty_o_coors)
empty_o_feats_out_max.sum().backward()
assert (empty_o_feats.grad == 0).all()

# test non-empty input
ref_voxel_coors = coors.unique(dim=0, sorted=True)
ref_voxel_coors = ref_voxel_coors[ref_voxel_coors.min(dim=-1).values >= 0]
ref_voxel_feats_mean = []
