Skip to content

Commit 55c3605

Browse files
Merge pull request #23 from nicolas-chaulet/longtype
Longtype
2 parents 6dafff9 + a0dd45b commit 55c3605

File tree

10 files changed

+91
-75
lines changed

10 files changed

+91
-75
lines changed

cpu/include/ball_query.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
#pragma once
22
#include <torch/extension.h>
33
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor query, at::Tensor support, float radius,
4-
int max_num, int mode);
4+
int max_num, int mode, bool sorted);
55

66
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor query, at::Tensor support,
77
at::Tensor query_batch, at::Tensor support_batch,
8-
float radius, int max_num, int mode);
8+
float radius, int max_num, int mode,
9+
bool sorted);
910

1011
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor query, at::Tensor support,
11-
float radius, int max_num, int mode);
12+
float radius, int max_num, int mode,
13+
bool sorted);

cpu/include/neighbors.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ using namespace std;
1010
template <typename scalar_t>
1111
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
1212
vector<long>& neighbors_indices, vector<float>& dists, float radius,
13-
int max_num, int mode);
13+
int max_num, int mode, bool sorted);
1414

1515
template <typename scalar_t>
1616
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
1717
vector<long>& q_batches, vector<long>& s_batches,
1818
vector<long>& neighbors_indices, vector<float>& dists, float radius,
19-
int max_num, int mode);
19+
int max_num, int mode, bool sorted);
2020

2121
template <typename scalar_t>
2222
void nanoflann_knn_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,

cpu/src/ball_query.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#include <torch/extension.h>
99

1010
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor query, float radius,
11-
int max_num, int mode)
11+
int max_num, int mode, bool sorted)
1212
{
1313
CHECK_CONTIGUOUS(support);
1414
CHECK_CONTIGUOUS(query);
@@ -31,7 +31,7 @@ std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor quer
3131
std::vector<scalar_t>(data_s, data_s + support.size(0) * support.size(1));
3232

3333
max_count = nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, neighbors_indices,
34-
neighbors_dists, radius, max_num, mode);
34+
neighbors_dists, radius, max_num, mode, sorted);
3535
});
3636
auto neighbors_dists_ptr = neighbors_dists.data();
3737
long* neighbors_indices_ptr = neighbors_indices.data();
@@ -62,7 +62,7 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes)
6262

6363
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tensor query,
6464
at::Tensor support_batch, at::Tensor query_batch,
65-
float radius, int max_num, int mode)
65+
float radius, int max_num, int mode, bool sorted)
6666
{
6767
CHECK_CONTIGUOUS(support);
6868
CHECK_CONTIGUOUS(query);
@@ -97,9 +97,9 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
9797
std::vector<scalar_t> supports_stl(support.DATA_PTR<scalar_t>(),
9898
support.DATA_PTR<scalar_t>() + support.numel());
9999

100-
max_count = batch_nanoflann_neighbors<scalar_t>(queries_stl, supports_stl, query_batch_stl,
101-
support_batch_stl, neighbors_indices,
102-
neighbors_dists, radius, max_num, mode);
100+
max_count = batch_nanoflann_neighbors<scalar_t>(
101+
queries_stl, supports_stl, query_batch_stl, support_batch_stl, neighbors_indices,
102+
neighbors_dists, radius, max_num, mode, sorted);
103103
});
104104
auto neighbors_dists_ptr = neighbors_dists.data();
105105
long* neighbors_indices_ptr = neighbors_indices.data();
@@ -122,7 +122,7 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
122122
}
123123

124124
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tensor query,
125-
float radius, int max_num, int mode)
125+
float radius, int max_num, int mode, bool sorted)
126126
{
127127
CHECK_CONTIGUOUS(support);
128128
CHECK_CONTIGUOUS(query);
@@ -132,7 +132,7 @@ std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tenso
132132
vector<at::Tensor> batch_dist;
133133
for (int i = 0; i < b; i++)
134134
{
135-
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode);
135+
auto out_pair = ball_query(query[i], support[i], radius, max_num, mode, sorted);
136136
batch_idx.push_back(out_pair.first);
137137
batch_dist.push_back(out_pair.second);
138138
}

cpu/src/bindings.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
2828
"maximum number of neighbors found if mode = 0, if mode=1 return a "
2929
"tensor of size Num_edge x 2 and return a tensor containing the "
3030
"squared distance of the neighbors",
31-
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
31+
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);
3232

3333
m.def("batch_ball_query", &batch_ball_query,
3434
"compute the radius search of a point cloud for each batch using "
@@ -53,7 +53,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
5353
"tensor of size Num_edge x 2 and return a tensor containing the "
5454
"squared distance of the neighbors",
5555
"support"_a, "querry"_a, "query_batch"_a, "support_batch"_a, "radius"_a, "max_num"_a = -1,
56-
"mode"_a = 0);
56+
"mode"_a = 0, "sorted"_a = false);
5757
m.def("dense_ball_query", &dense_ball_query,
5858
"compute the radius search of a batch of point cloud using nanoflann"
5959
"- support : a pytorch tensor of size B x N1 x 3, points where the "
@@ -69,5 +69,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
6969
"maximum number of neighbors found if mode = 0, if mode=1 return a "
7070
"tensor of size Num_edge x 2 and return a tensor containing the "
7171
"squared distance of the neighbors",
72-
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0);
72+
"support"_a, "querry"_a, "radius"_a, "max_num"_a = -1, "mode"_a = 0, "sorted"_a = false);
7373
}

cpu/src/neighbors.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@
22
// Taken from https://github.com/HuguesTHOMAS/KPConv
33

44
#include "neighbors.h"
5+
#include <random>
56

67
template <typename scalar_t>
78
int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
89
vector<long>& neighbors_indices, vector<float>& dists, float radius,
9-
int max_num, int mode)
10+
int max_num, int mode, bool sorted)
1011
{
1112
// Initiate variables
1213
// ******************
14+
std::random_device rd;
15+
std::mt19937 g(rd());
1316

1417
// square radius
15-
1618
const float search_radius = static_cast<float>(radius * radius);
1719

1820
// indices
@@ -47,7 +49,7 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
4749

4850
// Search params
4951
nanoflann::SearchParams search_params;
50-
search_params.sorted = true;
52+
search_params.sorted = sorted;
5153
std::vector<std::vector<std::pair<size_t, scalar_t>>> list_matches(pcd_query.pts.size());
5254

5355
for (auto& p0 : pcd_query.pts)
@@ -62,7 +64,11 @@ int nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
6264
if (nMatches == 0)
6365
list_matches[i0] = {std::make_pair(0, -1)};
6466
else
67+
{
68+
if (!sorted)
69+
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
6570
list_matches[i0] = ret_matches;
71+
}
6672
max_count = max(max_count, nMatches);
6773
i0++;
6874
}
@@ -132,10 +138,13 @@ template <typename scalar_t>
132138
int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& supports,
133139
vector<long>& q_batches, vector<long>& s_batches,
134140
vector<long>& neighbors_indices, vector<float>& dists, float radius,
135-
int max_num, int mode)
141+
int max_num, int mode, bool sorted)
136142
{
137143
// Initiate variables
138144
// ******************
145+
std::random_device rd;
146+
std::mt19937 g(rd());
147+
139148
// indices
140149
int i0 = 0;
141150

@@ -173,7 +182,7 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
173182
// ***********************
174183
// Search params
175184
nanoflann::SearchParams search_params;
176-
search_params.sorted = true;
185+
search_params.sorted = sorted;
177186
for (auto& p0 : query_pcd.pts)
178187
{
179188
// Check if we changed batch
@@ -192,16 +201,18 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
192201
index->buildIndex();
193202
}
194203

195-
// Initial guess of neighbors size
196-
197-
all_inds_dists[i0].reserve(max_count);
198-
// Find neighbors
199-
// std::cerr << p0.x << p0.y << p0.z<<std::endl;
204+
// Find neighboors
205+
std::vector<std::pair<size_t, scalar_t>> ret_matches;
206+
ret_matches.reserve(max_count);
200207
scalar_t query_pt[3] = {p0.x, p0.y, p0.z};
208+
size_t nMatches = index->radiusSearch(query_pt, r2, ret_matches, search_params);
201209

202-
size_t nMatches = index->radiusSearch(query_pt, r2, all_inds_dists[i0], search_params);
203-
// Update max count
210+
// Shuffle if needed
211+
if (!sorted)
212+
std::shuffle(ret_matches.begin(), ret_matches.end(), g);
213+
all_inds_dists[i0] = ret_matches;
204214

215+
// Update max count
205216
if (nMatches > (size_t)max_count)
206217
max_count = nMatches;
207218
// Increment query idx

cuda/src/ball_query.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "utils.h"
44

55
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
6-
const float* new_xyz, const float* xyz, int* idx,
6+
const float* new_xyz, const float* xyz, long* idx,
77
float* dist_out);
88

99
void query_ball_point_kernel_partial_wrapper(long batch_size, int size_x, int size_y, float radius,
@@ -25,15 +25,15 @@ std::pair<at::Tensor, at::Tensor> ball_query_dense(at::Tensor new_xyz, at::Tenso
2525
}
2626

2727
at::Tensor idx = torch::zeros({new_xyz.size(0), new_xyz.size(1), nsample},
28-
at::device(new_xyz.device()).dtype(at::ScalarType::Int));
28+
at::device(new_xyz.device()).dtype(at::ScalarType::Long));
2929
at::Tensor dist = torch::full({new_xyz.size(0), new_xyz.size(1), nsample}, -1,
3030
at::device(new_xyz.device()).dtype(at::ScalarType::Float));
3131

3232
if (new_xyz.type().is_cuda())
3333
{
3434
query_ball_point_kernel_dense_wrapper(
3535
xyz.size(0), xyz.size(1), new_xyz.size(1), radius, nsample, new_xyz.DATA_PTR<float>(),
36-
xyz.DATA_PTR<float>(), idx.DATA_PTR<int>(), dist.DATA_PTR<float>());
36+
xyz.DATA_PTR<float>(), idx.DATA_PTR<long>(), dist.DATA_PTR<float>());
3737
}
3838
else
3939
{

cuda/src/ball_query_gpu.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__global__ void query_ball_point_kernel_dense(int b, int n, int m, float radius, int nsample,
1010
const float* __restrict__ new_xyz,
1111
const float* __restrict__ xyz,
12-
int* __restrict__ idx_out,
12+
long* __restrict__ idx_out,
1313
float* __restrict__ dist_out)
1414
{
1515
int batch_index = blockIdx.x;
@@ -93,7 +93,7 @@ __global__ void query_ball_point_kernel_partial_dense(
9393
}
9494

9595
void query_ball_point_kernel_dense_wrapper(int b, int n, int m, float radius, int nsample,
96-
const float* new_xyz, const float* xyz, int* idx,float* dist_out)
96+
const float* new_xyz, const float* xyz, long* idx,float* dist_out)
9797
{
9898
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
9999
query_ball_point_kernel_dense<<<b, opt_n_threads(m), 0, stream>>>(b, n, m, radius, nsample,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
setup(
4646
name="torch_points",
47-
version="0.4.0",
47+
version="0.4.1",
4848
author="Nicolas Chaulet",
4949
packages=find_packages(),
5050
install_requires=requirements,

test/test_ballquerry.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,24 @@ def test_simple_gpu(self):
1414
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float).cuda()
1515
b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float).cuda()
1616
idx, dist = ball_query(1.01, 2, a, b)
17-
torch.testing.assert_allclose(idx.long().cpu(), torch.tensor([[[0, 1]], [[2, 2]]]))
17+
torch.testing.assert_allclose(idx.cpu(), torch.tensor([[[0, 1]], [[2, 2]]]))
1818
torch.testing.assert_allclose(dist.cpu(), torch.tensor([[[0, 1]], [[1, -1]]]).float())
1919

2020
def test_simple_cpu(self):
2121
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [2, 0, 0]], [[0, 0, 0], [1, 0, 0], [2, 0, 0]]]).to(torch.float)
2222
b = torch.tensor([[[0, 0, 0]], [[3, 0, 0]]]).to(torch.float)
23-
idx, dist = ball_query(1.01, 2, a, b)
24-
torch.testing.assert_allclose(idx.long(), torch.tensor([[[0, 1]], [[2, 2]]]))
23+
idx, dist = ball_query(1.01, 2, a, b, sort=True)
24+
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1]], [[2, 2]]]))
2525
torch.testing.assert_allclose(dist, torch.tensor([[[0, 1]], [[1, -1]]]).float())
2626

2727
a = torch.tensor([[[0, 0, 0], [1, 0, 0], [1, 1, 0]]]).to(torch.float)
28-
idx, dist = ball_query(1.01, 3, a, a)
29-
torch.testing.assert_allclose(idx.long(),torch.tensor([[[0, 1, 0],[1,0,2],[2,1,2]]]))
28+
idx, dist = ball_query(1.01, 3, a, a, sort=True)
29+
torch.testing.assert_allclose(idx, torch.tensor([[[0, 1, 0], [1, 0, 2], [2, 1, 2]]]))
3030

3131
@run_if_cuda
3232
def test_larger_gpu(self):
3333
a = torch.randn(32, 4096, 3).to(torch.float).cuda()
34-
idx,dist = ball_query(1, 64, a, a)
34+
idx, dist = ball_query(1, 64, a, a)
3535
self.assertGreaterEqual(idx.min(), 0)
3636

3737
@run_if_cuda
@@ -70,7 +70,7 @@ def test_simple_gpu(self):
7070
dist2 = dist2.detach().cpu().numpy()
7171

7272
idx_answer = np.asarray([[1, -1]])
73-
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
73+
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
7474

7575
npt.assert_array_almost_equal(idx, idx_answer)
7676
npt.assert_array_almost_equal(dist2, dist2_answer)
@@ -88,7 +88,7 @@ def test_simple_cpu(self):
8888
dist2 = dist2.detach().cpu().numpy()
8989

9090
idx_answer = np.asarray([[1, -1]])
91-
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
91+
dist2_answer = np.asarray([[0.100, -1.0000]]).astype(np.float32)
9292

9393
npt.assert_array_almost_equal(idx, idx_answer)
9494
npt.assert_array_almost_equal(dist2, dist2_answer)
@@ -100,9 +100,13 @@ def test_random_cpu(self):
100100
batch_b = torch.tensor([0 for i in range(b.shape[0] // 2)] + [1 for i in range(b.shape[0] // 2, b.shape[0])])
101101
R = 1
102102

103-
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
104-
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
103+
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
104+
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=True)
105105
torch.testing.assert_allclose(idx1, idx)
106+
with self.assertRaises(AssertionError):
107+
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
108+
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b, sort=False)
109+
torch.testing.assert_allclose(idx1, idx)
106110

107111
self.assertEqual(idx.shape[0], b.shape[0])
108112
self.assertEqual(dist.shape[0], b.shape[0])

0 commit comments

Comments
 (0)