Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Sampling points based on distance metric #667

Merged
merged 14 commits into from
Aug 5, 2021
4 changes: 2 additions & 2 deletions configs/3dssd/3dssd_4x4_kitti-3d-car.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.9, 1.1]),
dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
dict(type='IndoorPointSample', num_points=16384),
dict(type='PointSample', num_points=16384),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
Expand All @@ -77,7 +77,7 @@
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='IndoorPointSample', num_points=16384),
dict(type='PointSample', num_points=16384),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
4 changes: 2 additions & 2 deletions configs/_base_/datasets/sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
Expand All @@ -47,7 +47,7 @@
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
4 changes: 2 additions & 2 deletions configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down Expand Up @@ -225,7 +225,7 @@
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
10 changes: 5 additions & 5 deletions docs/tutorials/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), # all valid categories ids
max_cat_id=40), # max possible category id in input segmentation mask
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled
dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes
flip_ratio_yz=0.5, # Probability of being flipped along yz plane
Expand Down Expand Up @@ -232,7 +232,7 @@ test_pipeline = [ # Testing pipeline, refer to mmdet3d.datasets.pipelines for m
shift_height=True, # Whether to use shifted height
load_dim=6, # The dimension of the loaded points
use_dim=[0, 1, 2]), # Which dimensions of the points to be used
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled
dict(
type='DefaultFormatBundle3D', # Default format bundle to gather data in the pipeline, refer to mmdet3d.datasets.pipelines.formating for more details
Expand Down Expand Up @@ -286,7 +286,7 @@ data = dict(
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='IndoorFlipData',
flip_ratio_yz=0.5,
Expand Down Expand Up @@ -325,7 +325,7 @@ data = dict(
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
Expand All @@ -350,7 +350,7 @@ data = dict(
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
Expand Down
6 changes: 6 additions & 0 deletions mmdet3d/core/points/base_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def __getitem__(self, item):
Nonzero elements in the vector will be selected.
4. `new_points = points[3:11, vector]`:
return a slice of points and attribute dims.
5. `new_points = points[4:12, 2]`:
return a slice of points with single attribute.
Note that the returned Points might share storage with this Points,
subject to Pytorch's indexing semantics.

Expand All @@ -303,6 +305,10 @@ def __getitem__(self, item):
item = list(item)
item[1] = list(range(start, stop, step))
item = tuple(item)
elif isinstance(item[1], int):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add unit tests for this.

item = list(item)
item[1] = [item[1]]
item = tuple(item)
p = self.tensor[item[0], item[1]]

keep_dims = list(
Expand Down
20 changes: 12 additions & 8 deletions mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
from .nuscenes_mono_dataset import NuScenesMonoDataset
# yapf: disable
from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler)
ObjectRangeFilter, ObjectSample, PointSample,
PointShuffle, PointsRangeFilter, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
# yapf: enable
from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset
Expand All @@ -30,9 +33,10 @@
'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter',
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor', 'RandomJitterPoints', 'ObjectNameFilter'
'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset',
'ScanNetDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset',
'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
'ObjectNameFilter'
]
17 changes: 9 additions & 8 deletions mmdet3d/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
ObjectRangeFilter, ObjectSample, PointSample,
PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler)

__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
'RandomDropPointsColor', 'RandomJitterPoints'
'PointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D',
'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter',
'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
'RandomJitterPoints'
]
81 changes: 60 additions & 21 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,45 +818,60 @@ def __repr__(self):


@PIPELINES.register_module()
class IndoorPointSample(object):
"""Indoor point sample.
class PointSample(object):
"""Point sample.

Sampling data to a certain number.

Args:
name (str): Name of the dataset.
num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points.
"""

def __init__(self, num_points):
def __init__(self, num_points, sample_range=None):
self.num_points = num_points

def points_random_sampling(self,
points,
num_samples,
replace=None,
return_choices=False):
self.sample_range = sample_range

def _points_random_sampling(self,
points,
num_samples,
sample_range=None,
replace=None,
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
return_choices=False):
"""Points random sampling.

Sample points to a certain number.

Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled.
replace (bool): Whether the sample is with or without replacement.
Defaults to None.
return_choices (bool): Whether return choice. Defaults to False.

sample_range (float, optional): Indicating the range where the
points will be sampled.
Defaults to None.
wHao-Wu marked this conversation as resolved.
Show resolved Hide resolved
replace (bool, optional): Sampling with or without replacement.
Defaults to None.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Returns:
tuple[np.ndarray] | np.ndarray:

- points (np.ndarray | :obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples.
"""
if replace is None:
replace = (points.shape[0] < num_samples)
choices = np.random.choice(
points.shape[0], num_samples, replace=replace)
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0]
near_inds = np.where(depth <= sample_range)[0]
point_range = near_inds
num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace)
if sample_range is not None and not replace:
wHao-Wu marked this conversation as resolved.
Show resolved Hide resolved
choices = np.concatenate((far_inds, choices))
# Shuffle points after sampling
np.random.shuffle(choices)
if return_choices:
return points[choices], choices
else:
Expand All @@ -867,14 +882,19 @@ def __call__(self, results):

Args:
input_dict (dict): Result dict from loading pipeline.

Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
from mmdet3d.core.points import CameraPoints
points = results['points']
points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True)
# Points in Camera coord can provide the depth information.
# TODO: Need to suport distance-based sampling for other coord system.
if self.sample_range is not None:
assert isinstance(points, CameraPoints), \
'Sampling based on distance is only appliable for CAMERA coord'
points, choices = self._points_random_sampling(
points, self.num_points, self.sample_range, return_choices=True)
results['points'] = points

pts_instance_mask = results.get('pts_instance_mask', None)
Expand All @@ -893,10 +913,29 @@ def __call__(self, results):
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points})'
repr_str += f'(num_points={self.num_points},'
repr_str += f' sample_range={self.sample_range})'

return repr_str


@PIPELINES.register_module()
class IndoorPointSample(PointSample):
"""Indoor point sample.

Sampling data to a certain number.
NOTE: IndoorPointSample is deprecated in favor of PointSample

Args:
num_points (int): Number of points to be sampled.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
'IndoorPointSample is deprecated in favor of PointSample')
super(IndoorPointSample, self).__init__(*args, **kwargs)


@PIPELINES.register_module()
class IndoorPatchPointSample(object):
r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_datasets/test_scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_getitem():
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(
type='RandomFlip3D',
sync_2d=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data/test_datasets/test_sunrgbd_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _generate_sunrgbd_dataset_config():
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down Expand Up @@ -73,7 +73,7 @@ def _generate_sunrgbd_multi_modality_dataset_config():
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_multi_scale_flip_aug_3D():
'sync_2d': False,
'flip_ratio_bev_horizontal': 0.5
}, {
'type': 'IndoorPointSample',
'type': 'PointSample',
'num_points': 5
}, {
'type':
Expand Down
Loading