Skip to content

Commit

Permalink
[Feature] Support LaserMix augmentation (#2302)
Browse files Browse the repository at this point in the history
* add lasermix

* add prob

* update description

* update
  • Loading branch information
Xiangxu-0103 authored Feb 28, 2023
1 parent 7beabbd commit a5627bf
Show file tree
Hide file tree
Showing 4 changed files with 275 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mmdet3d/datasets/seg3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def prepare_data(self, idx: int) -> dict:
if not self.test_mode:
data_info = self.get_data_info(idx)
# Pass the dataset to the pipeline during training to support mixed
# data augmentation, such as polarmix.
# data augmentation, such as polarmix and lasermix.
data_info['dataset'] = self
return self.pipeline(data_info)
else:
Expand Down
6 changes: 3 additions & 3 deletions mmdet3d/datasets/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from .transforms_3d import (AffineResize, BackgroundPointsFilter,
GlobalAlignment, GlobalRotScaleTrans,
IndoorPatchPointSample, IndoorPointSample,
MultiViewWrapper, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample,
LaserMix, MultiViewWrapper, ObjectNameFilter,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PhotoMetricDistortion3D, PointSample, PointShuffle,
PointsRangeFilter, PolarMix, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints, RandomResize3D,
Expand All @@ -30,5 +30,5 @@
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader', 'PolarMix'
'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix'
]
145 changes: 145 additions & 0 deletions mmdet3d/datasets/transforms/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2521,3 +2521,148 @@ def __repr__(self) -> str:
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str


@TRANSFORMS.register_module()
class LaserMix(BaseTransform):
"""LaserMix data augmentation.
The lasermix transform steps are as follows:
1. Another random point cloud is picked by dataset.
2. Divide the point cloud into several regions according to pitch
angles and combine the areas crossly.
Required Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
- dataset (:obj:`BaseDataset`)
Modified Keys:
- points (:obj:`BasePoints`)
- pts_semantic_mask (np.int64)
Args:
num_areas (List[int]): A list of area numbers will be divided into.
pitch_angles (Sequence[float]): Pitch angles used to divide areas.
pre_transform (Sequence[dict], optional): Sequence of transform object
or config dict to be composed. Defaults to None.
prob (float): The transformation probability. Defaults to 1.0.
"""

def __init__(self,
num_areas: List[int],
pitch_angles: Sequence[float],
pre_transform: Optional[Sequence[dict]] = None,
prob: float = 1.0) -> None:
assert is_list_of(num_areas, int), \
'num_areas should be a list of int.'
self.num_areas = num_areas

assert len(pitch_angles) == 2, \
'The length of pitch_angles should be 2, ' \
f'but got {len(pitch_angles)}.'
assert pitch_angles[1] > pitch_angles[0], \
'pitch_angles[1] should be larger than pitch_angles[0].'
self.pitch_angles = pitch_angles

self.prob = prob
if pre_transform is None:
self.pre_transform = None
else:
self.pre_transform = Compose(pre_transform)

def laser_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
mix_results (dict): Mixed dict picked from dataset.
Returns:
dict: output dict after transformation.
"""
mix_points = mix_results['points']
mix_pts_semantic_mask = mix_results['pts_semantic_mask']

points = input_dict['points']
pts_semantic_mask = input_dict['pts_semantic_mask']

rho = torch.sqrt(points.coord[:, 0]**2 + points.coord[:, 1]**2)
pitch = torch.atan2(points.coord[:, 2], rho)
pitch = torch.clip(pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)

mix_rho = torch.sqrt(mix_points.coord[:, 0]**2 +
mix_points.coord[:, 1]**2)
mix_pitch = torch.atan2(mix_points.coord[:, 2], mix_rho)
mix_pitch = torch.clip(mix_pitch, self.pitch_angles[0] + 1e-5,
self.pitch_angles[1] - 1e-5)

num_areas = np.random.choice(self.num_areas, size=1)[0]
angle_list = np.linspace(self.pitch_angles[1], self.pitch_angles[0],
num_areas + 1)
out_points = []
out_pts_semantic_mask = []
for i in range(num_areas):
# convert angle to radian
start_angle = angle_list[i + 1] / 180 * np.pi
end_angle = angle_list[i] / 180 * np.pi
if i % 2 == 0: # pick from original point cloud
idx = (pitch > start_angle) & (pitch <= end_angle)
out_points.append(points[idx])
out_pts_semantic_mask.append(pts_semantic_mask[idx.numpy()])
else: # pickle from mixed point cloud
idx = (mix_pitch > start_angle) & (mix_pitch <= end_angle)
out_points.append(mix_points[idx])
out_pts_semantic_mask.append(
mix_pts_semantic_mask[idx.numpy()])
out_points = points.cat(out_points)
out_pts_semantic_mask = np.concatenate(out_pts_semantic_mask, axis=0)
input_dict['points'] = out_points
input_dict['pts_semantic_mask'] = out_pts_semantic_mask
return input_dict

def transform(self, input_dict: dict) -> dict:
"""LaserMix transform function.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: output dict after transformation.
"""
if np.random.rand() > self.prob:
return input_dict

assert 'dataset' in input_dict, \
'`dataset` is needed to pass through LaserMix, while not found.'
dataset = input_dict['dataset']

# get index of other point cloud
index = np.random.randint(0, len(dataset))

mix_results = dataset.get_data_info(index)

if self.pre_transform is not None:
# pre_transform may also require dataset
mix_results.update({'dataset': dataset})
# before lasermix need to go through
# the necessary pre_transform
mix_results = self.pre_transform(mix_results)
mix_results.pop('dataset')

input_dict = self.laser_mix_transform(input_dict, mix_results)

return input_dict

def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_areas={self.num_areas}, '
repr_str += f'pitch_angles={self.pitch_angles}, '
repr_str += f'pre_transform={self.pre_transform}, '
repr_str += f'prob={self.prob})'
return repr_str
127 changes: 126 additions & 1 deletion tests/test_datasets/test_transforms/test_transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
SemanticKITTIDataset)
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, LaserMix, PolarMix
from mmdet3d.structures import LiDARPoints
from mmdet3d.testing import create_data_info_after_loading
from mmdet3d.utils import register_all_modules
Expand Down Expand Up @@ -222,3 +222,128 @@ def test_transform(self):
results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0])


class TestLaserMix(unittest.TestCase):

def setUp(self):
self.pre_transform = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32'),
dict(type='PointSegClassMapping'),
]
classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
'sidewalk', 'other-ground', 'building', 'fence',
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
palette = [
[174, 199, 232],
[152, 223, 138],
[31, 119, 180],
[255, 187, 120],
[188, 189, 34],
[140, 86, 75],
[255, 152, 150],
[214, 39, 40],
[197, 176, 213],
[148, 103, 189],
[196, 156, 148],
[23, 190, 207],
[247, 182, 210],
[219, 219, 141],
[255, 127, 14],
[158, 218, 229],
[44, 160, 44],
[112, 128, 144],
[227, 119, 194],
[82, 84, 163],
]
seg_label_mapping = {
0: 0, # "unlabeled"
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
10: 1, # "car"
11: 2, # "bicycle"
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
15: 3, # "motorcycle"
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
18: 4, # "truck"
20: 5, # "other-vehicle"
30: 6, # "person"
31: 7, # "bicyclist"
32: 8, # "motorcyclist"
40: 9, # "road"
44: 10, # "parking"
48: 11, # "sidewalk"
49: 12, # "other-ground"
50: 13, # "building"
51: 14, # "fence"
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
60: 9, # "lane-marking" to "road" ---------------------mapped
70: 15, # "vegetation"
71: 16, # "trunk"
72: 17, # "terrain"
80: 18, # "pole"
81: 19, # "traffic-sign"
99: 0, # "other-object" to "unlabeled" ----------------mapped
252: 1, # "moving-car" to "car" ------------------------mapped
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
254: 6, # "moving-person" to "person" ------------------mapped
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
256: 5, # "moving-on-rails" mapped to "other-vehic------mapped
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
258: 4, # "moving-truck" to "truck" --------------------mapped
259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped
}
max_label = 259
self.dataset = SemanticKITTIDataset(
'./tests/data/semantickitti/',
'semantickitti_infos.pkl',
metainfo=dict(
classes=classes,
palette=palette,
seg_label_mapping=seg_label_mapping,
max_label=max_label),
data_prefix=dict(
pts='sequences/00/velodyne',
pts_semantic_mask='sequences/00/labels'),
pipeline=[],
modality=dict(use_lidar=True, use_camera=False))
points = np.random.random((100, 4))
self.results = {
'points': LiDARPoints(points, points_dim=4),
'pts_semantic_mask': np.random.randint(0, 20, (100, )),
'dataset': self.dataset
}

def test_transform(self):
# test assertion for invalid num_areas
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=3, pitch_angles=[-20, 0])

with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3.0, 4.0], pitch_angles=[-20, 0])

# test assertion for invalid pitch_angles
with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[-20])

with self.assertRaises(AssertionError):
transform = LaserMix(num_areas=[3, 4], pitch_angles=[0, -20])

transform = LaserMix(
num_areas=[3, 4, 5, 6],
pitch_angles=[-20, 0],
pre_transform=self.pre_transform)
results = transform.transform(copy.deepcopy(self.results))
self.assertTrue(results['points'].shape[0] ==
results['pts_semantic_mask'].shape[0])

0 comments on commit a5627bf

Please sign in to comment.