Commit 93d3101

Merge pull request #40 from nicolas-chaulet/region_growing
Basic region growing clustering
2 parents 8cc7b07 + 2f21fe2 commit 93d3101

10 files changed (+256 -18 lines)


.devcontainer/Dockerfile

Lines changed: 1 addition & 1 deletion
@@ -29,5 +29,5 @@ RUN apt-get update \
     && rm -rf /var/lib/apt/lists/*
 
 RUN pip3 install -U pip
-RUN pip3 install torch numpy scikit-learn flake8 setuptools
+RUN pip3 install torch numpy scikit-learn flake8 setuptools numba
 RUN pip3 install torch_cluster torch_sparse torch_scatter torch_geometric

.github/workflows/deploy.yaml

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install torch numpy scikit-learn flake8 setuptools wheel twine
+          pip install torch numpy scikit-learn flake8 setuptools wheel twine numba
       - name: Build package
         run: |
           python setup.py build_ext --inplace

.github/workflows/tests.yaml

Lines changed: 2 additions & 2 deletions
@@ -22,14 +22,14 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install numpy scikit-learn flake8 setuptools
+          pip install numpy scikit-learn flake8 setuptools numba
 
       - name: Install torch windows + linux
         if: ${{matrix.os != 'macos-latest'}}
         run: pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Install torch macos
         if: ${{matrix.os == 'macos-latest'}}
-        run: pip install torch
+        run: pip install torch
 
       - name: Build package
         run: |

CHANGELOG.md

Lines changed: 9 additions & 1 deletion
@@ -1,4 +1,12 @@
-# UNRELEASED
+# 0.6.5
+
+## Additions
+- Clustering algorithm for [PointGroup](https://arxiv.org/pdf/2004.01658.pdf)
+
+## Change
+- Force no ninja for the compilation
+
+# 0.6.4
 
 ## Bug fix
 - CPU version works for MacOS

benchmark/region_cluster.py

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+import unittest
+import torch
+import os
+import sys
+import time
+import random
+
+ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
+sys.path.insert(0, ROOT)
+
+from torch_points_kernels.cluster import grow_proximity
+
+torch.manual_seed(0)
+
+num_points = 100000
+pos1 = torch.rand((num_points, 3))
+pos2 = torch.rand((num_points, 3)) + 2
+pos3 = torch.rand((num_points, 3)) + 4
+labels1 = torch.ones(num_points).long()
+labels2 = torch.ones(num_points).long()
+labels3 = torch.ones(num_points).long()
+pos = torch.cat([pos1, pos2, pos3], 0)
+label = torch.cat([labels1, labels2, labels3], 0)
+batch = torch.ones((3 * num_points)).long()
+cl = grow_proximity(pos, batch, radius=0.5)
+
+
+import cProfile, pstats, io
+from pstats import SortKey
+
+pr = cProfile.Profile()
+pr.enable()
+t_start = time.perf_counter()
+grow_proximity(pos, batch)
+print(time.perf_counter() - t_start)
+pr.disable()
+s = io.StringIO()
+sortby = SortKey.CUMULATIVE
+ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
+ps.print_stats()
+print(s.getvalue())

setup.py

Lines changed: 8 additions & 3 deletions
@@ -60,14 +60,19 @@ def get_ext_modules():
     return ext_modules
 
 
+class CustomBuildExtension(BuildExtension):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, no_python_abi_suffix=True, use_ninja=False, **kwargs)
+
+
 def get_cmdclass():
-    return {"build_ext": BuildExtension}
+    return {"build_ext": CustomBuildExtension}
 
 
-requirements = ["torch>=1.1.0"]
+requirements = ["torch>=1.1.0", "numba"]
 
 url = "https://github.com/nicolas-chaulet/torch-points-kernels"
-__version__ = "0.6.4"
+__version__ = "0.6.5"
 setup(
     name="torch-points-kernels",
     version=__version__,

test/test_cluster.py

Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import unittest
+import torch
+import os
+import sys
+
+ROOT = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..")
+sys.path.insert(0, ROOT)
+
+from torch_points_kernels.cluster import grow_proximity, region_grow
+
+
+class TestGrow(unittest.TestCase):
+    def setUp(self):
+        self.pos = torch.tensor(
+            [
+                [0, 0, 0],
+                [1, 0, 0],
+                [2, 0, 0],
+                [10, 0, 0],
+                [0, 0, 0],
+                [1, 0, 0],
+                [2, 0, 0],
+                [10, 0, 0],
+            ]
+        )
+        self.batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
+        self.labels = torch.tensor([0, 0, 1, 1, 0, 1, 1, 10])
+
+    def test_simple(self):
+        clusters = grow_proximity(self.pos, self.batch, radius=2, min_cluster_size=1)
+        self.assertEqual(clusters, [[0, 1, 2], [3], [4, 5, 6], [7]])
+
+        clusters = grow_proximity(self.pos, self.batch, radius=2, min_cluster_size=3)
+        self.assertEqual(clusters, [[0, 1, 2], [4, 5, 6]])
+
+    def test_region_grow(self):
+        clusters = region_grow(
+            self.pos, self.labels, self.batch, radius=2, min_cluster_size=1
+        )
+        self.assertEqual(len(clusters[0]), 2)
+        self.assertEqual(len(clusters[1]), 3)
+        self.assertEqual(len(clusters[10]), 1)
+        torch.testing.assert_allclose(clusters[0][0], torch.tensor([0, 1]))
+        torch.testing.assert_allclose(clusters[0][1], torch.tensor([4]))
+        torch.testing.assert_allclose(clusters[1][0], torch.tensor([2]))
+        torch.testing.assert_allclose(clusters[1][1], torch.tensor([3]))
+        torch.testing.assert_allclose(clusters[1][2], torch.tensor([5, 6]))
+        torch.testing.assert_allclose(clusters[10][0], torch.tensor([7]))
+
+
+if __name__ == "__main__":
+    unittest.main()

torch_points_kernels/__init__.py

Lines changed: 10 additions & 1 deletion
@@ -1,4 +1,13 @@
 from .torchpoints import *
 from .knn import knn
+from .cluster import region_grow
 
-__all__ = ["ball_query", "furthest_point_sample", "grouping_operation", "three_interpolate", "three_nn", "knn"]
+__all__ = [
+    "ball_query",
+    "furthest_point_sample",
+    "grouping_operation",
+    "three_interpolate",
+    "three_nn",
+    "knn",
+    "region_grow",
+]

torch_points_kernels/cluster.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
+import torch
+from .torchpoints import ball_query_partial_dense
+import numpy as np
+import numba
+
+
+@numba.jit(nopython=True)
+def _grow_proximity_core(neighbours, min_cluster_size):
+    num_points = int(neighbours.shape[0])
+    visited = np.zeros((num_points,), dtype=numba.types.bool_)
+    clusters = []
+    for i in range(num_points):
+        if visited[i]:
+            continue
+
+        cluster = []
+        queue = []
+        visited[i] = True
+        queue.append(i)
+        cluster.append(i)
+        while len(queue):
+            k = queue.pop()
+            k_neighbours = neighbours[k]
+            for nei in k_neighbours:
+                if nei.item() == -1:
+                    break
+
+                if not visited[nei]:
+                    visited[nei] = True
+                    queue.append(nei.item())
+                    cluster.append(nei.item())
+
+        if len(cluster) >= min_cluster_size:
+            clusters.append(cluster)
+
+    return clusters
+
+
+def grow_proximity(pos, batch, nsample=16, radius=0.02, min_cluster_size=32):
+    """ Grow based on proximity only
+    Neighbour search is done on device while the cluster assignement is done on cpu"""
+    assert pos.shape[0] == batch.shape[0]
+    neighbours = (
+        ball_query_partial_dense(radius, nsample, pos, pos, batch, batch)[0]
+        .cpu()
+        .numpy()
+    )
+    return _grow_proximity_core(neighbours, min_cluster_size)
+
+
+def region_grow(
+    pos, labels, batch, ignore_labels=[], nsample=16, radius=0.02, min_cluster_size=32
+):
+    """ Region growing clustering algorithm proposed in
+    PointGroup: Dual-Set Point Grouping for 3D Instance Segmentation
+    https://arxiv.org/pdf/2004.01658.pdf
+    for instance segmentation
+
+    Parameters
+    ----------
+    pos: torch.Tensor [N, 3]
+        Location of the points
+    labels: torch.Tensor [N,]
+        labels of each point
+    ignore_labels:
+        Labels that should be ignored, no region growing will be performed on those
+    nsample:
+        maximum number of neighbours to consider
+    radius:
+        radius for the neighbour search
+    min_cluster_size:
+        Number of points above which a cluster is considered valid
+    """
+    assert labels.dim() == 1
+    assert pos.dim() == 2
+    assert pos.shape[0] == labels.shape[0]
+
+    unique_labels = torch.unique(labels)
+    clusters = {}
+    ind = torch.arange(0, pos.shape[0])
+    for l in unique_labels:
+        if l in ignore_labels:
+            continue
+
+        # Build clusters for a given label (ignore other points)
+        label_mask = labels == l
+        local_ind = ind[label_mask]
+        label_clusters = grow_proximity(
+            pos[label_mask, :],
+            batch[label_mask],
+            nsample=nsample,
+            radius=radius,
+            min_cluster_size=min_cluster_size,
+        )
+
+        # Remap indices to original coordinates
+        if len(label_clusters):
+            remaped_clusters = []
+            for cluster in label_clusters:
+                cluster = torch.tensor(cluster).to(pos.device)
+                remaped_clusters.append(local_ind[cluster])
+            clusters[l.item()] = remaped_clusters
+
+    return clusters
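
A minimal usage sketch of the new clustering API, reusing the toy points from test/test_cluster.py above; the expected outputs in the comments are taken from that test's assertions:

import torch
from torch_points_kernels import region_grow
from torch_points_kernels.cluster import grow_proximity

# Two batches of points along the x axis (same toy data as test/test_cluster.py)
pos = torch.tensor(
    [
        [0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0],
        [0, 0, 0], [1, 0, 0], [2, 0, 0], [10, 0, 0],
    ]
)
batch = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1])
labels = torch.tensor([0, 0, 1, 1, 0, 1, 1, 10])

# Proximity-only growing: one list of point indices per cluster
print(grow_proximity(pos, batch, radius=2, min_cluster_size=1))
# -> [[0, 1, 2], [3], [4, 5, 6], [7]]

# Label-aware region growing: a dict keyed by label, each value a list of index tensors
clusters = region_grow(pos, labels, batch, radius=2, min_cluster_size=1)
print(clusters[0])
# -> [tensor([0, 1]), tensor([4])]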

torch_points_kernels/torchpoints.py

Lines changed: 28 additions & 9 deletions
@@ -30,7 +30,10 @@ def furthest_point_sample(xyz, npoint):
         (B, npoint) tensor containing the set
     """
     if npoint > xyz.shape[1]:
-        raise ValueError("caanot sample %i points from an input set of %i points" % (npoint, xyz.shape[1]))
+        raise ValueError(
+            "caanot sample %i points from an input set of %i points"
+            % (npoint, xyz.shape[1])
+        )
     if xyz.is_cuda:
         return tpcuda.furthest_point_sampling(xyz, npoint)
     else:
@@ -99,9 +102,13 @@ def backward(ctx, grad_out):
         idx, weight, m = ctx.three_interpolate_for_backward
 
         if grad_out.is_cuda:
-            grad_features = tpcuda.three_interpolate_grad(grad_out.contiguous(), idx, weight, m)
+            grad_features = tpcuda.three_interpolate_grad(
+                grad_out.contiguous(), idx, weight, m
+            )
         else:
-            grad_features = tpcpu.knn_interpolate_grad(grad_out.contiguous(), idx, weight, m)
+            grad_features = tpcpu.knn_interpolate_grad(
+                grad_out.contiguous(), idx, weight, m
+            )
 
         return grad_features, None, None
 
@@ -143,17 +150,23 @@ def grouping_operation(features, idx):
     all_idx = idx.reshape(idx.shape[0], -1)
     all_idx = all_idx.unsqueeze(1).repeat(1, features.shape[1], 1)
     grouped_features = features.gather(2, all_idx)
-    return grouped_features.reshape(idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2])
+    return grouped_features.reshape(
+        idx.shape[0], features.shape[1], idx.shape[1], idx.shape[2]
+    )
 
 
-def ball_query_dense(radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False):
+def ball_query_dense(
+    radius, nsample, xyz, new_xyz, batch_xyz=None, batch_new_xyz=None, sort=False
+):
     # type: (Any, float, int, torch.Tensor, torch.Tensor) -> torch.Tensor
     if new_xyz.is_cuda:
         if sort:
             raise NotImplementedError("CUDA version does not sort the neighbors")
         ind, dist = tpcuda.ball_query_dense(new_xyz, xyz, radius, nsample)
     else:
-        ind, dist = tpcpu.dense_ball_query(new_xyz, xyz, radius, nsample, mode=0, sorted=sort)
+        ind, dist = tpcpu.dense_ball_query(
+            new_xyz, xyz, radius, nsample, mode=0, sorted=sort
+        )
     return ind, dist
 
 
@@ -162,9 +175,13 @@ def ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=False
     if x.is_cuda:
         if sort:
             raise NotImplementedError("CUDA version does not sort the neighbors")
-        ind, dist = tpcuda.ball_query_partial_dense(x, y, batch_x, batch_y, radius, nsample)
+        ind, dist = tpcuda.ball_query_partial_dense(
+            x, y, batch_x, batch_y, radius, nsample
+        )
     else:
-        ind, dist = tpcpu.batch_ball_query(x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort)
+        ind, dist = tpcpu.batch_ball_query(
+            x, y, batch_x, batch_y, radius, nsample, mode=0, sorted=sort
+        )
     return ind, dist
 
 
@@ -207,7 +224,9 @@ def ball_query(
         assert x.size(0) == batch_x.size(0)
         assert y.size(0) == batch_y.size(0)
         assert x.dim() == 2
-        return ball_query_partial_dense(radius, nsample, x, y, batch_x, batch_y, sort=sort)
+        return ball_query_partial_dense(
+            radius, nsample, x, y, batch_x, batch_y, sort=sort
+        )
 
     elif mode.lower() == "dense":
         if (batch_x is not None) or (batch_y is not None):
