From 1b0ccd40ad06224c591d464fe985ba6f6270323c Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Fri, 4 Jun 2021 10:24:01 +0800 Subject: [PATCH 1/6] enhance IndoorPatchPointSample --- mmdet3d/datasets/pipelines/transforms_3d.py | 48 ++++++++++++------- .../test_pipelines/test_indoor_sample.py | 45 +++++++++++++++-- 2 files changed, 72 insertions(+), 21 deletions(-) diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index ff86f5ce02..c4c046938a 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -905,8 +905,6 @@ class IndoorPatchPointSample(object): num_points (int): Number of points to be sampled. block_size (float, optional): Size of a block to sample points from. Defaults to 1.5. - sample_rate (float, optional): Stride used in sliding patch generation. - Defaults to 1.0. ignore_index (int, optional): Label index that won't be used for the segmentation task. This is set in PointSegClassMapping as neg_cls. Defaults to None. @@ -914,21 +912,29 @@ class IndoorPatchPointSample(object): additional features. Defaults to False. num_try (int, optional): Number of times to try if the patch selected is invalid. Defaults to 10. + enlarge_size (float | None, optional): Enlarge the sampled patch to + [-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as + an augmentation. If None, set it as 0.01. Defaults to 0.2. + min_unique_num (int | None, optional): Minimum number of unique points + the sampled patch should contain. If None, use PointNet++'s method + to judge uniqueness. Defaults to None. """ def __init__(self, num_points, block_size=1.5, - sample_rate=1.0, ignore_index=None, use_normalized_coord=False, - num_try=10): + num_try=10, + enlarge_size=0.2, + min_unique_num=None): self.num_points = num_points self.block_size = block_size - self.sample_rate = sample_rate self.ignore_index = ignore_index self.use_normalized_coord = use_normalized_coord self.num_try = num_try + self.enlarge_size = enlarge_size if enlarge_size is not None else 0.01 + self.min_unique_num = min_unique_num def _input_generation(self, coords, patch_center, coord_max, attributes, attribute_dims, point_type): @@ -997,7 +1003,7 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): coord_max = np.amax(coords, axis=0) coord_min = np.amin(coords, axis=0) - for i in range(self.num_try): + for _ in range(self.num_try): # random sample a point as patch center cur_center = coords[np.random.choice(coords.shape[0])] @@ -1009,7 +1015,8 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): cur_max[2] = coord_max[2] cur_min[2] = coord_min[2] cur_choice = np.sum( - (coords >= (cur_min - 0.2)) * (coords <= (cur_max + 0.2)), + (coords >= (cur_min - self.enlarge_size)) * + (coords <= (cur_max + self.enlarge_size)), axis=1) == 3 if not cur_choice.any(): # no points in this patch @@ -1024,14 +1031,20 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): (cur_coords >= (cur_min - 0.01)) * (cur_coords <= (cur_max + 0.01)), axis=1) == 3 - # not sure if 31, 31, 62 are just some big values used to transform - # coords from 3d array to 1d and then check their uniqueness - # this is used in all the ScanNet code following PointNet++ - vidx = np.ceil((cur_coords[mask, :] - cur_min) / - (cur_max - cur_min) * np.array([31.0, 31.0, 62.0])) - vidx = np.unique(vidx[:, 0] * 31.0 * 62.0 + vidx[:, 1] * 62.0 + - vidx[:, 2]) - flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02 + + if self.min_unique_num is None: + # use PointNet++'s method as default + # [31, 31, 62] are just some big values used to transform + # coords from 3d array to 1d and then check their uniqueness + # this is used in all the ScanNet code following PointNet++ + vidx = np.ceil( + (cur_coords[mask, :] - cur_min) / (cur_max - cur_min) * + np.array([31.0, 31.0, 62.0])) + vidx = np.unique(vidx[:, 0] * 31.0 * 62.0 + vidx[:, 1] * 62.0 + + vidx[:, 2]) + flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02 + else: + flag1 = mask.sum() >= self.min_unique_num # selected patch should contain enough annotated points if self.ignore_index is None: @@ -1088,10 +1101,11 @@ def __repr__(self): repr_str = self.__class__.__name__ repr_str += f'(num_points={self.num_points},' repr_str += f' block_size={self.block_size},' - repr_str += f' sample_rate={self.sample_rate},' repr_str += f' ignore_index={self.ignore_index},' repr_str += f' use_normalized_coord={self.use_normalized_coord},' - repr_str += f' num_try={self.num_try})' + repr_str += f' num_try={self.num_try},' + repr_str += f' enlarge_size={self.enlarge_size},' + repr_str += f' min_unique_num={self.min_unique_num})' return repr_str diff --git a/tests/test_data/test_pipelines/test_indoor_sample.py b/tests/test_data/test_pipelines/test_indoor_sample.py index 6d1675d8ef..8014f8f1a2 100644 --- a/tests/test_data/test_pipelines/test_indoor_sample.py +++ b/tests/test_data/test_pipelines/test_indoor_sample.py @@ -67,7 +67,7 @@ def test_indoor_sample(): def test_indoor_seg_sample(): # test the train time behavior of IndoorPatchPointSample np.random.seed(0) - scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 1.0, 20, True) + scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 20, True) scannet_seg_class_mapping = \ PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39), 40) @@ -109,15 +109,52 @@ def test_indoor_seg_sample(): repr_str = repr(scannet_patch_sample_points) expected_repr_str = 'IndoorPatchPointSample(num_points=5, ' \ 'block_size=1.5, ' \ - 'sample_rate=1.0, ' \ 'ignore_index=20, ' \ 'use_normalized_coord=True, ' \ - 'num_try=10)' + 'num_try=10, ' \ + 'enlarge_size=0.2, ' \ + 'min_unique_num=None)' assert repr_str == expected_repr_str + # when enlarge_size and min_unique_num are set + np.random.seed(0) + scannet_patch_sample_points = IndoorPatchPointSample( + 5, 1.0, 20, False, num_try=1000, enlarge_size=None, min_unique_num=5) + # this patch is within [0, 1] and has 5 unique points + # it should be selected + scannet_points = np.random.rand(5, 6) + scannet_points[0, :3] = np.array([0.5, 0.5, 0.5]) + # generate points smaller than `min_unique_num` in local patches + # they won't be sampled + for i in range(2, 11, 2): + scannet_points = np.concatenate( + [scannet_points, np.random.rand(4, 6) + i], axis=0) + scannet_results = dict( + points=DepthPoints( + scannet_points, points_dim=6, + attribute_dims=dict(color=[3, 4, 5])), + pts_semantic_mask=np.random.randint(0, 20, + (scannet_points.shape[0], ))) + scannet_results = scannet_patch_sample_points(scannet_results) + scannet_points_result = scannet_results['points'] + + # manually constructed sampled points + scannet_choices = np.array([2, 4, 3, 1, 0]) + scannet_center = np.array([0.56804454, 0.92559665, 0.07103606]) + scannet_center[2] = 0.0 + scannet_input_points = np.concatenate([ + scannet_points[scannet_choices, :3] - scannet_center, + scannet_points[scannet_choices, 3:], + ], 1) + + assert scannet_points_result.points_dim == 6 + assert scannet_points_result.attribute_dims == dict(color=[3, 4, 5]) + scannet_points_result = scannet_points_result.tensor.numpy() + assert np.allclose(scannet_input_points, scannet_points_result, atol=1e-6) + # test on S3DIS dataset np.random.seed(0) - s3dis_patch_sample_points = IndoorPatchPointSample(5, 1.0, 1.0, None, True) + s3dis_patch_sample_points = IndoorPatchPointSample(5, 1.0, None, True) s3dis_results = dict() s3dis_points = np.fromfile( './tests/data/s3dis/points/Area_1_office_2.bin', From dc26f797f113e7de0eee9d9b10f32222f5254e46 Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Fri, 4 Jun 2021 10:34:53 +0800 Subject: [PATCH 2/6] modify configs and unit test --- .../_base_/datasets/s3dis_seg-3d-13class.py | 5 +- .../_base_/datasets/scannet_seg-3d-20class.py | 5 +- ...16x2_cosine_250e_scannet_seg-3d-20class.py | 5 +- ...16x2_cosine_200e_scannet_seg-3d-20class.py | 5 +- .../test_datasets/test_s3dis_dataset.py | 46 ++++++++++++++++++- .../test_datasets/test_scannet_dataset.py | 11 +++-- .../test_pipelines/test_indoor_pipeline.py | 10 ++-- 7 files changed, 68 insertions(+), 19 deletions(-) diff --git a/configs/_base_/datasets/s3dis_seg-3d-13class.py b/configs/_base_/datasets/s3dis_seg-3d-13class.py index e2dcab8099..39bf5568e0 100644 --- a/configs/_base_/datasets/s3dis_seg-3d-13class.py +++ b/configs/_base_/datasets/s3dis_seg-3d-13class.py @@ -28,9 +28,10 @@ type='IndoorPatchPointSample', num_points=num_points, block_size=1.0, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=True), + use_normalized_coord=True, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) diff --git a/configs/_base_/datasets/scannet_seg-3d-20class.py b/configs/_base_/datasets/scannet_seg-3d-20class.py index 5d9b56f917..cf73b09c8a 100644 --- a/configs/_base_/datasets/scannet_seg-3d-20class.py +++ b/configs/_base_/datasets/scannet_seg-3d-20class.py @@ -29,9 +29,10 @@ type='IndoorPatchPointSample', num_points=num_points, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=False), + use_normalized_coord=False, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) diff --git a/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py b/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py index 9ba658623a..3f32796cc2 100644 --- a/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py +++ b/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py @@ -37,9 +37,10 @@ type='IndoorPatchPointSample', num_points=num_points, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=False), + use_normalized_coord=False, + enlarge_size=0.2, + min_unique_num=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) ] diff --git a/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py b/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py index f2266fc05a..8c05430c96 100644 --- a/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py +++ b/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py @@ -37,9 +37,10 @@ type='IndoorPatchPointSample', num_points=num_points, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=False), + use_normalized_coord=False, + enlarge_size=0.2, + min_unique_num=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) ] diff --git a/tests/test_data/test_datasets/test_s3dis_dataset.py b/tests/test_data/test_datasets/test_s3dis_dataset.py index 1c34eb7196..6ac1d68519 100644 --- a/tests/test_data/test_datasets/test_s3dis_dataset.py +++ b/tests/test_data/test_datasets/test_s3dis_dataset.py @@ -40,9 +40,10 @@ def test_seg_getitem(): type='IndoorPatchPointSample', num_points=5, block_size=1.0, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=True), + use_normalized_coord=True, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict( @@ -207,6 +208,47 @@ def test_seg_show(): mmcv.check_file_exist(gt_file_path) mmcv.check_file_exist(pred_file_path) tmp_dir.cleanup() + # test show with pipeline + tmp_dir = tempfile.TemporaryDirectory() + temp_dir = tmp_dir.name + class_names = ('ceiling', 'floor', 'wall', 'beam', 'column', 'window', + 'door', 'table', 'chair', 'sofa', 'bookcase', 'board', + 'clutter') + eval_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=False, + use_color=True, + load_dim=6, + use_dim=[0, 1, 2, 3, 4, 5]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True), + dict( + type='PointSegClassMapping', + valid_cat_ids=tuple(range(len(class_names))), + max_cat_id=13), + dict( + type='DefaultFormatBundle3D', + with_label=False, + class_names=class_names), + dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) + ] + s3dis_dataset.show(results, temp_dir, show=False, pipeline=eval_pipeline) + pts_file_path = osp.join(temp_dir, 'Area_1_office_2', + 'Area_1_office_2_points.obj') + gt_file_path = osp.join(temp_dir, 'Area_1_office_2', + 'Area_1_office_2_gt.obj') + pred_file_path = osp.join(temp_dir, 'Area_1_office_2', + 'Area_1_office_2_pred.obj') + mmcv.check_file_exist(pts_file_path) + mmcv.check_file_exist(gt_file_path) + mmcv.check_file_exist(pred_file_path) + tmp_dir.cleanup() def test_multi_areas(): diff --git a/tests/test_data/test_datasets/test_scannet_dataset.py b/tests/test_data/test_datasets/test_scannet_dataset.py index bf438d3816..b3823a5252 100644 --- a/tests/test_data/test_datasets/test_scannet_dataset.py +++ b/tests/test_data/test_datasets/test_scannet_dataset.py @@ -335,7 +335,6 @@ def test_seg_getitem(): type='IndoorPatchPointSample', num_points=5, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), use_normalized_coord=True), dict(type='NormalizePointsColor', color_mean=None), @@ -408,9 +407,10 @@ def test_seg_getitem(): type='IndoorPatchPointSample', num_points=5, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=False) + use_normalized_coord=False, + enlarge_size=0.2, + min_unique_num=None) scannet_dataset = ScanNetSegDataset( data_root=root_path, ann_file=ann_file, @@ -456,9 +456,10 @@ def test_seg_getitem(): type='IndoorPatchPointSample', num_points=5, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=False) + use_normalized_coord=False, + enlarge_size=0.2, + min_unique_num=None) new_pipelines.remove(new_pipelines[4]) scannet_dataset = ScanNetSegDataset( data_root=root_path, diff --git a/tests/test_data/test_pipelines/test_indoor_pipeline.py b/tests/test_data/test_pipelines/test_indoor_pipeline.py index 6e705e85f9..91f87a942a 100644 --- a/tests/test_data/test_pipelines/test_indoor_pipeline.py +++ b/tests/test_data/test_pipelines/test_indoor_pipeline.py @@ -142,9 +142,10 @@ def test_scannet_seg_pipeline(): type='IndoorPatchPointSample', num_points=5, block_size=1.5, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=True), + use_normalized_coord=True, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) @@ -212,9 +213,10 @@ def test_s3dis_seg_pipeline(): type='IndoorPatchPointSample', num_points=5, block_size=1.0, - sample_rate=1.0, ignore_index=len(class_names), - use_normalized_coord=True), + use_normalized_coord=True, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict(type='Collect3D', keys=['points', 'pts_semantic_mask']) From 21ede350dbfb9b62505fd30bc2174d97e7ac0bf9 Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Fri, 4 Jun 2021 10:44:22 +0800 Subject: [PATCH 3/6] add docs & comment --- docs/compatibility.md | 6 ++++++ mmdet3d/datasets/pipelines/transforms_3d.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/docs/compatibility.md b/docs/compatibility.md index 0ed8bb889e..805e3eb88d 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -2,6 +2,12 @@ This document provides detailed descriptions of the BC-breaking changes in MMDetection3D. +## MMDetection3D 0.15.0 + +### Enhance `IndoorPatchPointSample` transform + +We enhance the pipeline function `IndoorPatchPointSample` by adding more choices for patch selection and removing a useless parameter `sample_rate`. Please modify the code as well as the config files accordingly if you use this transform. + ## MMDetection3D 0.14.0 ### Dataset class for 3D segmentation task diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index c4c046938a..07bef2a3e8 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -918,6 +918,12 @@ class IndoorPatchPointSample(object): min_unique_num (int | None, optional): Minimum number of unique points the sampled patch should contain. If None, use PointNet++'s method to judge uniqueness. Defaults to None. + + Note: + This transform should only be used in the training process of point + cloud segmentation tasks. For the sliding patch generation and + inference process in testing, please refer to the `slide_inference` + function of `EncoderDecoder3D` class. """ def __init__(self, From 91e43e6b9ae1290494770207ec6a8adde64a4eb5 Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Fri, 4 Jun 2021 10:51:14 +0800 Subject: [PATCH 4/6] fix legacy bug --- ...et2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py | 3 +-- ...et2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py b/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py index 3f32796cc2..2cb7ee1857 100644 --- a/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py +++ b/configs/pointnet2/pointnet2_msg_xyz-only_16x2_cosine_250e_scannet_seg-3d-20class.py @@ -117,8 +117,7 @@ classes=class_names, test_mode=False, ignore_index=len(class_names), - scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy', - label_weight=data_root + 'seg_info/train_label_weight.npy'), + scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy'), val=dict( type=dataset_type, data_root=data_root, diff --git a/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py b/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py index 8c05430c96..9dff449c5f 100644 --- a/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py +++ b/configs/pointnet2/pointnet2_ssg_xyz-only_16x2_cosine_200e_scannet_seg-3d-20class.py @@ -117,8 +117,7 @@ classes=class_names, test_mode=False, ignore_index=len(class_names), - scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy', - label_weight=data_root + 'seg_info/train_label_weight.npy'), + scene_idxs=data_root + 'seg_info/train_resampled_scene_idxs.npy'), val=dict( type=dataset_type, data_root=data_root, From c483eb7e87d4e834f6c18cc5bd5507c46dd81076 Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Fri, 4 Jun 2021 10:56:43 +0800 Subject: [PATCH 5/6] minor fix --- tests/test_data/test_datasets/test_scannet_dataset.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_data/test_datasets/test_scannet_dataset.py b/tests/test_data/test_datasets/test_scannet_dataset.py index b3823a5252..10aac8a3dd 100644 --- a/tests/test_data/test_datasets/test_scannet_dataset.py +++ b/tests/test_data/test_datasets/test_scannet_dataset.py @@ -336,7 +336,9 @@ def test_seg_getitem(): num_points=5, block_size=1.5, ignore_index=len(class_names), - use_normalized_coord=True), + use_normalized_coord=True, + enlarge_size=0.2, + min_unique_num=None), dict(type='NormalizePointsColor', color_mean=None), dict(type='DefaultFormatBundle3D', class_names=class_names), dict( From fa2649b7792086ae39526e9a7ee87034da78a7ca Mon Sep 17 00:00:00 2001 From: Wuziyi616 Date: Sat, 12 Jun 2021 15:55:52 +0800 Subject: [PATCH 6/6] keep sample_rate and add warning --- docs/compatibility.md | 2 +- mmdet3d/datasets/pipelines/transforms_3d.py | 11 +++++++++++ .../test_data/test_pipelines/test_indoor_sample.py | 14 +++++++++++--- 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 805e3eb88d..471a770c49 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -6,7 +6,7 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet ### Enhance `IndoorPatchPointSample` transform -We enhance the pipeline function `IndoorPatchPointSample` by adding more choices for patch selection and removing a useless parameter `sample_rate`. Please modify the code as well as the config files accordingly if you use this transform. +We enhance the pipeline function `IndoorPatchPointSample` used in point cloud segmentation task by adding more choices for patch selection. Also, we plan to remove the unused parameter `sample_rate` in the future. Please modify the code as well as the config files accordingly if you use this transform. ## MMDetection3D 0.14.0 diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index 07bef2a3e8..5f7a59e87b 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -1,4 +1,5 @@ import numpy as np +import warnings from mmcv import is_tuple_of from mmcv.utils import build_from_cfg @@ -905,6 +906,10 @@ class IndoorPatchPointSample(object): num_points (int): Number of points to be sampled. block_size (float, optional): Size of a block to sample points from. Defaults to 1.5. + sample_rate (float, optional): Stride used in sliding patch generation. + This parameter is unused in `IndoorPatchPointSample` and thus has + been deprecated. We plan to remove it in the future. + Defaults to None. ignore_index (int, optional): Label index that won't be used for the segmentation task. This is set in PointSegClassMapping as neg_cls. Defaults to None. @@ -929,6 +934,7 @@ class IndoorPatchPointSample(object): def __init__(self, num_points, block_size=1.5, + sample_rate=None, ignore_index=None, use_normalized_coord=False, num_try=10, @@ -942,6 +948,11 @@ def __init__(self, self.enlarge_size = enlarge_size if enlarge_size is not None else 0.01 self.min_unique_num = min_unique_num + if sample_rate is not None: + warnings.warn( + "'sample_rate' has been deprecated and will be removed in " + 'the future. Please remove them from your code.') + def _input_generation(self, coords, patch_center, coord_max, attributes, attribute_dims, point_type): """Generating model input. diff --git a/tests/test_data/test_pipelines/test_indoor_sample.py b/tests/test_data/test_pipelines/test_indoor_sample.py index 8014f8f1a2..d529f7b10d 100644 --- a/tests/test_data/test_pipelines/test_indoor_sample.py +++ b/tests/test_data/test_pipelines/test_indoor_sample.py @@ -67,7 +67,8 @@ def test_indoor_sample(): def test_indoor_seg_sample(): # test the train time behavior of IndoorPatchPointSample np.random.seed(0) - scannet_patch_sample_points = IndoorPatchPointSample(5, 1.5, 20, True) + scannet_patch_sample_points = IndoorPatchPointSample( + 5, 1.5, ignore_index=20, use_normalized_coord=True) scannet_seg_class_mapping = \ PointSegClassMapping((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, 36, 39), 40) @@ -119,7 +120,13 @@ def test_indoor_seg_sample(): # when enlarge_size and min_unique_num are set np.random.seed(0) scannet_patch_sample_points = IndoorPatchPointSample( - 5, 1.0, 20, False, num_try=1000, enlarge_size=None, min_unique_num=5) + 5, + 1.0, + ignore_index=20, + use_normalized_coord=False, + num_try=1000, + enlarge_size=None, + min_unique_num=5) # this patch is within [0, 1] and has 5 unique points # it should be selected scannet_points = np.random.rand(5, 6) @@ -154,7 +161,8 @@ def test_indoor_seg_sample(): # test on S3DIS dataset np.random.seed(0) - s3dis_patch_sample_points = IndoorPatchPointSample(5, 1.0, None, True) + s3dis_patch_sample_points = IndoorPatchPointSample( + 5, 1.0, ignore_index=None, use_normalized_coord=True) s3dis_results = dict() s3dis_points = np.fromfile( './tests/data/s3dis/points/Area_1_office_2.bin',