Skip to content

Commit 7d16352

Browse files
Merge pull request #32 from nicolas-chaulet/debug
Fix bug with negative index accessing random memory
2 parents c5cbbae + 7bb9aa3 commit 7d16352

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

cpu/src/ball_query.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ at::Tensor degree(at::Tensor row, int64_t num_nodes)
5757
{
5858
auto zero = at::zeros(num_nodes, row.options());
5959
auto one = at::ones(row.size(0), row.options());
60-
return zero.scatter_add_(0, row, one);
60+
auto out = zero.scatter_add_(0, row, one);
61+
return out;
6162
}
6263

6364
std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tensor query,
@@ -79,9 +80,13 @@ std::pair<at::Tensor, at::Tensor> batch_ball_query(at::Tensor support, at::Tenso
7980
auto options_dist = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU);
8081

8182
int max_count = 0;
82-
auto batch_access = query_batch.accessor<int64_t, 1>();
83+
auto q_batch_access = query_batch.accessor<int64_t, 1>();
84+
auto s_batch_access = support_batch.accessor<int64_t, 1>();
85+
86+
auto batch_size = q_batch_access[query_batch.size(0) - 1] + 1;
87+
TORCH_CHECK(batch_size == (s_batch_access[support_batch.size(0) - 1] + 1),
88+
"Both batches need to have the same number of samples.")
8389

84-
auto batch_size = batch_access[-1] + 1;
8590
query_batch = degree(query_batch, batch_size);
8691
query_batch = at::cat({at::zeros(1, query_batch.options()), query_batch.cumsum(0)}, 0);
8792
support_batch = degree(support_batch, batch_size);

test/test_ballquerry.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
import unittest
22
import torch
3-
from torch_points_kernels import ball_query
43
import numpy.testing as npt
54
import numpy as np
65
from sklearn.neighbors import KDTree
6+
import os
7+
import sys
8+
9+
ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
10+
sys.path.insert(0, ROOT)
711

8-
from . import run_if_cuda
12+
from test import run_if_cuda
13+
from torch_points_kernels import ball_query
914

1015

1116
class TestBall(unittest.TestCase):
@@ -76,10 +81,10 @@ def test_simple_gpu(self):
7681
npt.assert_array_almost_equal(dist2, dist2_answer)
7782

7883
def test_simple_cpu(self):
79-
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [0.1, 0, 0]]).to(torch.float)
84+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
8085
y = torch.tensor([[0, 0, 0]]).to(torch.float)
8186

82-
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
87+
batch_x = torch.from_numpy(np.asarray([0, 0, 0, 0])).long()
8388
batch_y = torch.from_numpy(np.asarray([0])).long()
8489

8590
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
@@ -93,6 +98,17 @@ def test_simple_cpu(self):
9398
npt.assert_array_almost_equal(idx, idx_answer)
9499
npt.assert_array_almost_equal(dist2, dist2_answer)
95100

101+
102+
def test_breaks(self):
103+
x = torch.tensor([[10, 0, 0], [0.1, 0, 0], [10, 0, 0], [10.1, 0, 0]]).to(torch.float)
104+
y = torch.tensor([[0, 0, 0]]).to(torch.float)
105+
106+
batch_x = torch.from_numpy(np.asarray([0, 0, 1, 1])).long()
107+
batch_y = torch.from_numpy(np.asarray([0])).long()
108+
109+
with self.assertRaises(RuntimeError):
110+
idx, dist2 = ball_query(1.0, 2, x, y, mode="PARTIAL_DENSE", batch_x=batch_x, batch_y=batch_y)
111+
96112
def test_random_cpu(self):
97113
a = torch.randn(100, 3).to(torch.float)
98114
b = torch.randn(50, 3).to(torch.float)

0 commit comments

Comments
 (0)