Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support LaserMix augmentation #2302

Merged
merged 5 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmdet3d/datasets/seg3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def prepare_data(self, idx: int) -> dict:
if not self.test_mode:
data_info = self.get_data_info(idx)
# Pass the dataset to the pipeline during training to support mixed
# data augmentation, such as polarmix.
# data augmentation, such as polarmix and lasermix.
data_info['dataset'] = self
return self.pipeline(data_info)
else:
Expand Down
6 changes: 3 additions & 3 deletions mmdet3d/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from .transforms_3d import (AffineResize, BackgroundPointsFilter,
GlobalAlignment, GlobalRotScaleTrans,
IndoorPatchPointSample, IndoorPointSample,
MultiViewWrapper, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample,
LaserMix, MultiViewWrapper, ObjectNameFilter,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PhotoMetricDistortion3D, PointSample, PointShuffle,
PointsRangeFilter, PolarMix, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomResize3D,
Expand All @@ -30,5 +30,5 @@
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader', 'PolarMix'
'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix'
]
145 changes: 145 additions & 0 deletions mmdet3d/datasets/transforms/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,3 +2521,148 @@ def __repr__(self) -> str:
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str


@TRANSFORMS.register_module()
class LaserMix(BaseTransform):
"""LaserMix data augmentation.
The lasermix transform steps are as follows:
1. Another random point cloud is picked by dataset.
2. Divide the point cloud into several regions according to pitch
angles and combine the areas crossly.
Required Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
- dataset (:obj:`BaseDataset`)
Modified Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
Args:
num_areas (List[int]): A list of area numbers will be divided into.
pitch_angles (Sequence[float]): Pitch angles used to divide areas.
pre_transform (Sequence[dict], optional): Sequence of transform object
or config dict to be composed. Defaults to None.
prob (float): The transformation probability. Defaults to 1.0.
"""

def __init__(self,
num_areas: List[int],
pitch_angles: Sequence[float],
pre_transform: Optional[Sequence[dict]] = None,
prob: float = 1.0) -> None:
assert is_list_of(num_areas, int), \
'num_areas should be a list of int.'
self.num_areas = num_areas

assert len(pitch_angles) == 2, \
'The length of pitch_angles should be 2, ' \
f'but got {len(pitch_angles)}.'
assert pitch_angles[1] > pitch_angles[0], \
'pitch_angles[1] should be larger than pitch_angles[0].'
self.pitch_angles = pitch_angles

self.prob = prob
if pre_transform is None:
self.pre_transform = None
else:
self.pre_transform = Compose(pre_transform)

def laser_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
mix_results (dict): Mixed dict picked from dataset.
Returns:
dict: output dict after transformation.
"""
mix_points = mix_results['points']
mix_pts_semantic_mask = mix_results['pts_semantic_mask']

points = input_dict['points']
pts_semantic_mask = input_dict['pts_semantic_mask']

rho = torch.sqrt(points.coord[:, 0]**2 + points.coord[:, 1]**2)
pitch = torch.atan2(points.coord[:, 2], rho)
pitch = torch.clip(pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)

mix_rho = torch.sqrt(mix_points.coord[:, 0]**2 +
mix_points.coord[:, 1]**2)
mix_pitch = torch.atan2(mix_points.coord[:, 2], mix_rho)
mix_pitch = torch.clip(mix_pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)

num_areas = np.random.choice(self.num_areas, size=1)[0]
angle_list = np.linspace(self.pitch_angles[1], self.pitch_angles[0],
num_areas + 1)
out_points = []
out_pts_semantic_mask = []
for i in range(num_areas):
# convert angle to radian
start_angle = angle_list[i + 1] / 180 * np.pi
end_angle = angle_list[i] / 180 * np.pi
if i % 2 == 0: # pick from original point cloud
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
idx = (pitch > start_angle) & (pitch <= end_angle)
out_points.append(points[idx])
out_pts_semantic_mask.append(pts_semantic_mask[idx.numpy()])
else: # pickle from mixed point cloud
idx = (mix_pitch > start_angle) & (mix_pitch <= end_angle)
out_points.append(mix_points[idx])
out_pts_semantic_mask.append(
mix_pts_semantic_mask[idx.numpy()])
out_points = points.cat(out_points)
out_pts_semantic_mask = np.concatenate(out_pts_semantic_mask, axis=0)
input_dict['points'] = out_points
input_dict['pts_semantic_mask'] = out_pts_semantic_mask
return input_dict

def transform(self, input_dict: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: output dict after transformation.
"""
if np.random.rand() > self.prob:
return input_dict

assert 'dataset' in input_dict, \
'`dataset` is needed to pass through LaserMix, while not found.'
dataset = input_dict['dataset']

# get index of other point cloud
index = np.random.randint(0, len(dataset))

mix_results = dataset.get_data_info(index)

if self.pre_transform is not None:
# pre_transform may also require dataset
mix_results.update({'dataset': dataset})
# before lasermix need to go through
# the necessary pre_transform
mix_results = self.pre_transform(mix_results)
mix_results.pop('dataset')

input_dict = self.laser_mix_transform(input_dict, mix_results)

return input_dict

def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_areas={self.num_areas}, '
repr_str += f'pitch_angles={self.pitch_angles}, '
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str
127 changes: 126 additions & 1 deletion tests/test_datasets/test_transforms/test_transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
SemanticKITTIDataset)
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, LaserMix, PolarMix
from mmdet3d.structures import LiDARPoints
from mmdet3d.testing import create_data_info_after_loading
from mmdet3d.utils import register_all_modules
Expand Down Expand Up @@ -222,3 +222,128 @@ def test_transform(self):
results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0])


class TestLaserMix(unittest.TestCase):

def setUp(self):
self.pre_transform = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32'),
dict(type='PointSegClassMapping'),
]
classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
seg_label_mapping = {
0: 0, # "unlabeled"
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
10: 1, # "car"
11: 2, # "bicycle"
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
15: 3, # "motorcycle"
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
18: 4, # "truck"
20: 5, # "other-vehicle"
30: 6, # "person"
31: 7, # "bicyclist"
32: 8, # "motorcyclist"
40: 9, # "road"
44: 10, # "parking"
48: 11, # "sidewalk"
49: 12, # "other-ground"
50: 13, # "building"
51: 14, # "fence"
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
60: 9, # "lane-marking" to "road" ---------------------mapped
70: 15, # "vegetation"
71: 16, # "trunk"
72: 17, # "terrain"
80: 18, # "pole"
81: 19, # "traffic-sign"
99: 0, # "other-object" to "unlabeled" ----------------mapped
252: 1, # "moving-car" to "car" ------------------------mapped
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
254: 6, # "moving-person" to "person" ------------------mapped
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
256: 5, # "moving-on-rails" mapped to "other-vehic------mapped
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
258: 4, # "moving-truck" to "truck" --------------------mapped
259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped
}
max_label = 259
self.dataset = SemanticKITTIDataset(
'./tests/data/semantickitti/',
'semantickitti_infos.pkl',
metainfo=dict(
classes=classes,
palette=palette,
seg_label_mapping=seg_label_mapping,
max_label=max_label),
data_prefix=dict(
pts='sequences/00/velodyne',
pts_semantic_mask='sequences/00/labels'),
pipeline=[],
modality=dict(use_lidar=True, use_camera=False))
points = np.random.random((100, 4))
self.results = {
'points': LiDARPoints(points, points_dim=4),
'pts_semantic_mask': np.random.randint(0, 20, (100, )),
'dataset': self.dataset
}

def test_transform(self):
# test assertion for invalid num_areas
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=3, pitch_angles=[-20, 0])

with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3.0, 4.0], pitch_angles=[-20, 0])

# test assertion for invalid pitch_angles
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[-20])

with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[0, -20])

transform = LaserMix(
num_areas=[3, 4, 5, 6],
pitch_angles=[-20, 0],
pre_transform=self.pre_transform)
results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0])