Skip to content

Commit 90bae82

Browse files
authored
Merge pull request #21 from nicolas-chaulet/changetoken
Change token to -1 for partial dense
2 parents c03fe63 + e6190a5 commit 90bae82

File tree

5 files changed

+24
-17
lines changed

5 files changed

+24
-17
lines changed

cpu/src/neighbors.cpp

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,43 +209,35 @@ int batch_nanoflann_neighbors(vector<scalar_t>& queries, vector<scalar_t>& suppo
209209
}
210210
// how many neighbors do we keep
211211
if (max_num > 0)
212-
{
213212
max_count = max_num;
214-
}
215-
// Reserve the memory
216213

214+
const int token = -1;
217215
if (mode == 0)
218216
{
219217
neighbors_indices.resize(query_pcd.pts.size() * max_count);
220-
221218
dists.resize(query_pcd.pts.size() * max_count);
222219
i0 = 0;
223-
224220
b = 0;
225221

226222
for (auto& inds_dists : all_inds_dists)
227223
{ // Check if we changed batch
228-
229224
if (i0 == q_batches[b + 1] && b < (int)s_batches.size() - 1 &&
230225
b < (int)q_batches.size() - 1)
231-
{
232226
b++;
233-
}
234227

235228
for (int j = 0; j < max_count; j++)
236229
{
237-
if ((unsigned int)j < inds_dists.size())
230+
if ((size_t)j < inds_dists.size())
238231
{
239232
neighbors_indices[i0 * max_count + j] = inds_dists[j].first + s_batches[b];
240233
dists[i0 * max_count + j] = (float)inds_dists[j].second;
241234
}
242235
else
243236
{
244-
neighbors_indices[i0 * max_count + j] = supports.size() / 3;
237+
neighbors_indices[i0 * max_count + j] = token;
245238
dists[i0 * max_count + j] = -1;
246239
}
247240
}
248-
249241
i0++;
250242
}
251243
index.reset();

cpu/src/torch_nearest_neighbors.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
#include "compat.h"
44
#include "neighbors.cpp"
55
#include "neighbors.h"
6+
#include "utils.h"
67
#include <iostream>
78
#include <torch/extension.h>
89

910
std::pair<at::Tensor, at::Tensor> ball_query(at::Tensor support, at::Tensor query, float radius,
1011
int max_num, int mode)
1112
{
13+
CHECK_CONTIGUOUS(support);
14+
CHECK_CONTIGUOUS(query);
15+
1216
at::Tensor out;
1317
at::Tensor out_dists;
1418
std::vector<long> neighbors_indices(query.size(0), 0);
@@ -60,6 +64,11 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
6064
at::Tensor support_batch, at::Tensor query_batch,
6165
float radius, int max_num, int mode)
6266
{
67+
CHECK_CONTIGUOUS(support);
68+
CHECK_CONTIGUOUS(query);
69+
CHECK_CONTIGUOUS(support_batch);
70+
CHECK_CONTIGUOUS(query_batch);
71+
6372
at::Tensor idx;
6473

6574
at::Tensor dist;
@@ -115,6 +124,9 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
115124
std::pair<at::Tensor, at::Tensor> dense_ball_query(at::Tensor support, at::Tensor query,
116125
float radius, int max_num, int mode)
117126
{
127+
CHECK_CONTIGUOUS(support);
128+
CHECK_CONTIGUOUS(query);
129+
118130
int b = query.size(0);
119131
vector<at::Tensor> batch_idx;
120132
vector<at::Tensor> batch_dist;

cuda/src/ball_query.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ std::pair<at::Tensor, at::Tensor> ball_query_partial_dense(at::Tensor x, at::Ten
6464
CHECK_CUDA(batch_y);
6565
}
6666

67-
at::Tensor idx = torch::full({y.size(0), nsample}, x.size(0),
68-
at::device(y.device()).dtype(at::ScalarType::Long));
67+
at::Tensor idx =
68+
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Long));
6969

7070
at::Tensor dist =
7171
torch::full({y.size(0), nsample}, -1, at::device(y.device()).dtype(at::ScalarType::Float));

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444

4545
setup(
4646
name="torch_points",
47-
version="0.2.3",
47+
version="0.3.0",
4848
author="Nicolas Chaulet",
4949
packages=find_packages(),
5050
install_requires=requirements,

test/test_ballquerry.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def test_simple_gpu(self):
6363
idx = idx.detach().cpu().numpy()
6464
dist2 = dist2.detach().cpu().numpy()
6565

66-
idx_answer = np.asarray([[1, 4]])
66+
idx_answer = np.asarray([[1, -1]])
6767
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
6868

6969
npt.assert_array_almost_equal(idx, idx_answer)
@@ -81,7 +81,7 @@ def test_simple_cpu(self):
8181
idx = idx.detach().cpu().numpy()
8282
dist2 = dist2.detach().cpu().numpy()
8383

84-
idx_answer = np.asarray([[1, 4]])
84+
idx_answer = np.asarray([[1, -1]])
8585
dist2_answer = np.asarray([[0.0100, -1.0000]]).astype(np.float32)
8686

8787
npt.assert_array_almost_equal(idx, idx_answer)
@@ -95,6 +95,9 @@ def test_random_cpu(self):
9595
R = 1
9696

9797
idx, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
98+
idx1, dist = ball_query(R, 15, a, b, mode="PARTIAL_DENSE", batch_x=batch_a, batch_y=batch_b)
99+
torch.testing.assert_allclose(idx1, idx)
100+
98101
self.assertEqual(idx.shape[0], b.shape[0])
99102
self.assertEqual(dist.shape[0], b.shape[0])
100103
self.assertLessEqual(idx.max().item(), len(batch_a))
@@ -104,7 +107,7 @@ def test_random_cpu(self):
104107
idx3_sk = tree.query_radius(b.detach().numpy(), r=R)
105108
i = np.random.randint(len(batch_b))
106109
for p in idx[i].detach().numpy():
107-
if p < len(batch_a):
110+
if p >= 0 and p < len(batch_a):
108111
assert p in idx3_sk[i]
109112

110113

0 commit comments

Comments
 (0)