Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RandomDropPointsColor transform #585

Merged
merged 3 commits into from
May 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)
from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset
Expand All @@ -32,5 +33,6 @@
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline'
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor'
]
6 changes: 4 additions & 2 deletions mmdet3d/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)

__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
Expand All @@ -20,5 +21,6 @@
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D'
'IndoorPatchPointSample', 'LoadImageFromFileMono3D',
'RandomDropPointsColor'
]
44 changes: 44 additions & 0 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,50 @@
from .data_augment_utils import noise_per_object_v3_


@PIPELINES.register_module()
class RandomDropPointsColor(object):
r"""Randomly set the color of points to all zeros.

Once this transform is executed, all the points' color will be dropped.
Refer to `PAConv <https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/
util/transform.py#L223>`_ for more details.

ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
Args:
drop_ratio (float): The probability of dropping point colors.
Defaults to 0.2.
"""

def __init__(self, drop_ratio=0.2):
assert isinstance(drop_ratio, (int, float)) and 0 <= drop_ratio <= 1, \
f'invalid drop_ratio value {drop_ratio}'
self.drop_ratio = drop_ratio

def __call__(self, input_dict):
"""Call function to drop point colors.

Args:
input_dict (dict): Result dict from loading pipeline.

Returns:
dict: Results after color dropping, \
'points' key is updated in the result dict.
"""
points = input_dict['points']
assert points.attribute_dims is not None and \
'color' in points.attribute_dims, \
'Expect points have color attribute'

if np.random.rand() < self.drop_ratio:
points.color = points.color * 0.0
return input_dict

def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(drop_ratio={self.drop_ratio})'
return repr_str


@PIPELINES.register_module()
class RandomFlip3D(RandomFlip):
"""Flip the points & bbox.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from mmdet3d.core.points import DepthPoints, LiDARPoints
from mmdet3d.datasets import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, ObjectNoise, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D,
PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D,
VoxelBasedPointSampler)


Expand Down Expand Up @@ -364,6 +365,41 @@ def test_global_rot_scale_trans():
atol=1e-6)


def test_random_drop_points_color():
# drop_ratio should be in [0, 1]
with pytest.raises(AssertionError):
random_drop_points_color = RandomDropPointsColor(drop_ratio=1.1)

# 100% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=1)

points = np.fromfile('tests/data/scannet/points/scene0000_00.bin',
np.float32).reshape(-1, 6)
depth_points = DepthPoints(
points.copy(), points_dim=6, attribute_dims=dict(color=[3, 4, 5]))

input_dict = dict(points=depth_points.clone())

input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.all(trans_color == trans_color.new_zeros(trans_color.shape))

# 0% drop
random_drop_points_color = RandomDropPointsColor(drop_ratio=0)
input_dict = dict(points=depth_points.clone())

input_dict = random_drop_points_color(input_dict)
trans_depth_points = input_dict['points']
trans_color = trans_depth_points.color
assert torch.allclose(trans_color, depth_points.tensor[:, 3:6])

random_drop_points_color = RandomDropPointsColor(drop_ratio=0.5)
repr_str = repr(random_drop_points_color)
expected_repr_str = 'RandomDropPointsColor(drop_ratio=0.5)'
assert repr_str == expected_repr_str


def test_random_flip_3d():
random_flip_3d = RandomFlip3D(
flip_ratio_bev_horizontal=1.0, flip_ratio_bev_vertical=1.0)
Expand Down