Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Efficient implementation of PointSegClassMapping #489

Merged
merged 5 commits into from
Apr 27, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions configs/_base_/datasets/s3dis_seg-3d-13class.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))),
valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict(
type='IndoorPatchPointSample',
num_points=num_points,
Expand Down Expand Up @@ -65,7 +66,8 @@
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))),
valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict(
type='DefaultFormatBundle3D',
with_label=False,
Expand Down
3 changes: 2 additions & 1 deletion configs/_base_/datasets/scannet-3d-18class.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)),
36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000),
dict(
type='RandomFlip3D',
Expand Down
6 changes: 4 additions & 2 deletions configs/_base_/datasets/scannet_seg-3d-20class.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)),
33, 34, 36, 39),
max_cat_id=40),
dict(
type='IndoorPatchPointSample',
num_points=num_points,
Expand Down Expand Up @@ -67,7 +68,8 @@
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)),
33, 34, 36, 39),
max_cat_id=40),
dict(
type='DefaultFormatBundle3D',
with_label=False,
Expand Down
6 changes: 4 additions & 2 deletions docs/tutorials/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for
dict(
type='PointSegClassMapping', # Declare valid categories, refer to mmdet3d.datasets.pipelines.point_seg_class_mapping for more details
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39)),
36, 39), # all valid categories ids
max_cat_id=40), # max possible category id in input segmentation mask
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
num_points=40000), # Number of points to be sampled
dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes
Expand Down Expand Up @@ -283,7 +284,8 @@ data = dict(
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000),
dict(
type='IndoorFlipData',
Expand Down
24 changes: 14 additions & 10 deletions mmdet3d/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,10 +237,19 @@ class PointSegClassMapping(object):

Args:
valid_cat_ids (tuple[int]): A tuple of valid category.
max_cat_id (int): The max possible cat_id in input segmentation mask.
Defaults to 40.
"""

def __init__(self, valid_cat_ids):
def __init__(self, valid_cat_ids, max_cat_id=40):
self.valid_cat_ids = valid_cat_ids
self.max_cat_id = max_cat_id
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved

# build cat_id to class index mapping
neg_cls = len(valid_cat_ids)
self.cat_id2class = (np.ones(max_cat_id + 1) * neg_cls).astype(int)
for cls_idx, cat_id in enumerate(valid_cat_ids):
self.cat_id2class[cat_id] = cls_idx

def __call__(self, results):
"""Call function to map original semantic class to valid category ids.
Expand All @@ -256,22 +265,17 @@ def __call__(self, results):
"""
assert 'pts_semantic_mask' in results
pts_semantic_mask = results['pts_semantic_mask']
neg_cls = len(self.valid_cat_ids)

for i in range(pts_semantic_mask.shape[0]):
if pts_semantic_mask[i] in self.valid_cat_ids:
converted_id = self.valid_cat_ids.index(pts_semantic_mask[i])
pts_semantic_mask[i] = converted_id
else:
pts_semantic_mask[i] = neg_cls
converted_pts_sem_mask = self.cat_id2class[pts_semantic_mask]

results['pts_semantic_mask'] = pts_semantic_mask
results['pts_semantic_mask'] = converted_pts_sem_mask
return results

def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(valid_cat_ids={self.valid_cat_ids})'
repr_str += f'(valid_cat_ids={self.valid_cat_ids}, '
repr_str += f'max_cat_id={self.max_cat_id})'
return repr_str


Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/datasets/s3dis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def _build_default_pipeline(self):
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS),
valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict(
type='DefaultFormatBundle3D',
with_label=False,
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/datasets/scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ def _build_default_pipeline(self):
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=self.VALID_CLASS_IDS),
valid_cat_ids=self.VALID_CLASS_IDS,
max_cat_id=np.max(self.ALL_CLASS_IDS)),
dict(
type='DefaultFormatBundle3D',
with_label=False,
Expand Down
3 changes: 2 additions & 1 deletion tests/test_data/test_datasets/test_s3dis_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def test_seg_getitem():
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))),
valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict(
type='IndoorPatchPointSample',
num_points=5,
Expand Down
9 changes: 6 additions & 3 deletions tests/test_data/test_datasets/test_scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,8 @@ def test_seg_getitem():
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
28, 33, 34, 36, 39),
max_cat_id=40),
dict(
type='IndoorPatchPointSample',
num_points=5,
Expand Down Expand Up @@ -542,7 +543,8 @@ def test_seg_evaluate():
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
Expand Down Expand Up @@ -606,7 +608,8 @@ def test_seg_show():
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'pts_semantic_mask'])
]
Expand Down
6 changes: 4 additions & 2 deletions tests/test_data/test_pipelines/test_indoor_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def test_scannet_seg_pipeline():
dict(
type='PointSegClassMapping',
valid_cat_ids=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39)),
28, 33, 34, 36, 39),
max_cat_id=40),
dict(
type='IndoorPatchPointSample',
num_points=5,
Expand Down Expand Up @@ -197,7 +198,8 @@ def test_s3dis_seg_pipeline():
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=tuple(range(len(class_names)))),
valid_cat_ids=tuple(range(len(class_names))),
max_cat_id=13),
dict(
type='IndoorPatchPointSample',
num_points=5,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_pipelines/test_indoor_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def test_indoor_seg_sample():
scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 1.0, 20, True)
scannet_seg_class_mapping = \
PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
24, 28, 33, 34, 36, 39))
24, 28, 33, 34, 36, 39), 40)
scannet_results = dict()
scannet_points = np.fromfile(
'./tests/data/scannet/points/scene0000_00.bin',
Expand Down
33 changes: 31 additions & 2 deletions tests/test_data/test_pipelines/test_loadings/test_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def test_load_segmentation_mask():
# Convert class_id to label and assign ignore_index
scannet_seg_class_mapping = \
PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16,
24, 28, 33, 34, 36, 39))
24, 28, 33, 34, 36, 39), 40)
scannet_results = scannet_seg_class_mapping(scannet_results)
scannet_pts_semantic_mask = scannet_results['pts_semantic_mask']

Expand Down Expand Up @@ -250,7 +250,7 @@ def test_load_segmentation_mask():
assert s3dis_pts_semantic_mask.shape == (100, )

# Convert class_id to label and assign ignore_index
s3dis_seg_class_mapping = PointSegClassMapping(tuple(range(13)))
s3dis_seg_class_mapping = PointSegClassMapping(tuple(range(13)), 13)
s3dis_results = s3dis_seg_class_mapping(s3dis_results)
s3dis_pts_semantic_mask = s3dis_results['pts_semantic_mask']

Expand Down Expand Up @@ -288,6 +288,35 @@ def test_load_points_from_multi_sweeps():
assert points.shape == (403, 4)


def test_point_seg_class_mapping():
sem_mask = np.array([
16, 22, 2, 3, 7, 3, 16, 2, 16, 3, 1, 0, 6, 22, 3, 1, 2, 16, 1, 1, 1,
38, 7, 25, 16, 25, 3, 40, 38, 3, 33, 6, 16, 6, 16, 1, 38, 1, 1, 2, 8,
0, 18, 15, 0, 0, 40, 40, 1, 2, 3, 16, 33, 2, 2, 2, 7, 3, 14, 22, 4, 22,
15, 24, 2, 40, 3, 2, 8, 3, 1, 6, 40, 6, 0, 15, 4, 7, 6, 0, 1, 16, 14,
3, 0, 1, 1, 16, 38, 2, 15, 6, 4, 1, 16, 2, 3, 3, 3, 2
])
valid_cat_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)
point_seg_class_mapping = PointSegClassMapping(valid_cat_ids, 40)
input_dict = dict(pts_semantic_mask=sem_mask)
results = point_seg_class_mapping(input_dict)
mapped_sem_mask = results['pts_semantic_mask']
expected_sem_mask = np.array([
13, 20, 1, 2, 6, 2, 13, 1, 13, 2, 0, 20, 5, 20, 2, 0, 1, 13, 0, 0, 0,
20, 6, 20, 13, 20, 2, 20, 20, 2, 16, 5, 13, 5, 13, 0, 20, 0, 0, 1, 7,
20, 20, 20, 20, 20, 20, 20, 0, 1, 2, 13, 16, 1, 1, 1, 6, 2, 12, 20, 3,
20, 20, 14, 1, 20, 2, 1, 7, 2, 0, 5, 20, 5, 20, 20, 3, 6, 5, 20, 0, 13,
12, 2, 20, 0, 0, 13, 20, 1, 20, 5, 3, 0, 13, 1, 2, 2, 2, 1
])
repr_str = repr(point_seg_class_mapping)
expected_repr_str = f'PointSegClassMapping(valid_cat_ids={valid_cat_ids}'\
', max_cat_id=40)'

assert np.all(mapped_sem_mask == expected_sem_mask)
assert repr_str == expected_repr_str


def test_normalize_points_color():
coord = np.array([[68.137, 3.358, 2.516], [67.697, 3.55, 2.501],
[67.649, 3.76, 2.5], [66.414, 3.901, 2.459],
Expand Down