Skip to content

Commit 60d848b

Browse files
authored
[Feature] Support PolarMix augmentation (#2265)
* support polarmix * Update __init__.py * add UT * use `BasePoints` instead of numpy * Update transforms_3d.py * Update transforms_3d.py * Update test_transforms_3d.py * update docs * update polarmix without MultiImageMixDataset * add comments * fix UT * update docstring * fix yaw calculation * fix UT * refactor * update * update docs * fix typo * Update transforms_3d.py * update ut * fix typehint * add prob argument
1 parent 21de1af commit 60d848b

File tree

4 files changed

+316
-6
lines changed

4 files changed

+316
-6
lines changed

mmdet3d/datasets/seg3d_dataset.py

+18
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,24 @@ def parse_data_info(self, info: dict) -> dict:
283283

284284
return info
285285

286+
def prepare_data(self, idx: int) -> dict:
287+
"""Get data processed by ``self.pipeline``.
288+
289+
Args:
290+
idx (int): The index of ``data_info``.
291+
292+
Returns:
293+
dict: Results passed through ``self.pipeline``.
294+
"""
295+
if not self.test_mode:
296+
data_info = self.get_data_info(idx)
297+
# Pass the dataset to the pipeline during training to support mixed
298+
# data augmentation, such as polarmix.
299+
data_info['dataset'] = self
300+
return self.pipeline(data_info)
301+
else:
302+
return super().prepare_data(idx)
303+
286304
def get_scene_idxs(self, scene_idxs: Union[None, str,
287305
np.ndarray]) -> np.ndarray:
288306
"""Compute scene_idxs for data sampling.

mmdet3d/datasets/transforms/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
MultiViewWrapper, ObjectNameFilter, ObjectNoise,
1515
ObjectRangeFilter, ObjectSample,
1616
PhotoMetricDistortion3D, PointSample, PointShuffle,
17-
PointsRangeFilter, RandomDropPointsColor,
17+
PointsRangeFilter, PolarMix, RandomDropPointsColor,
1818
RandomFlip3D, RandomJitterPoints, RandomResize3D,
1919
RandomShiftScale, Resize3D, VoxelBasedPointSampler)
2020

@@ -30,5 +30,5 @@
3030
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
3131
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
3232
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
33-
'LidarDet3DInferencerLoader'
33+
'LidarDet3DInferencerLoader', 'PolarMix'
3434
]

mmdet3d/datasets/transforms/transforms_3d.py

+171-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import random
33
import warnings
4-
from typing import List, Optional, Tuple, Union
4+
from typing import List, Optional, Sequence, Tuple, Union
55

66
import cv2
77
import mmcv
88
import numpy as np
9+
import torch
910
from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize
1011
from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop,
1112
RandomFlip)
12-
from mmengine import is_tuple_of
13+
from mmengine import is_list_of, is_tuple_of
1314

1415
from mmdet3d.models.task_modules import VoxelGenerator
1516
from mmdet3d.registry import TRANSFORMS
@@ -2352,3 +2353,171 @@ def transform(self, input_dict: dict) -> dict:
23522353
if len(input_dict[key]) == 0:
23532354
input_dict.pop(key)
23542355
return input_dict
2356+
2357+
2358+
@TRANSFORMS.register_module()
2359+
class PolarMix(BaseTransform):
2360+
"""PolarMix data augmentation.
2361+
2362+
The polarmix transform steps are as follows:
2363+
2364+
1. Another random point cloud is picked by dataset.
2365+
2. Exchange sectors of two point clouds that are cut with certain
2366+
azimuth angles.
2367+
3. Cut point instances from picked point cloud, rotate them by multiple
2368+
azimuth angles, and paste the cut and rotated instances.
2369+
2370+
Required Keys:
2371+
2372+
- points (:obj:`BasePoints`)
2373+
- pts_semantic_mask (np.int64)
2374+
- dataset (:obj:`BaseDataset`)
2375+
2376+
Modified Keys:
2377+
2378+
- points (:obj:`BasePoints`)
2379+
- pts_semantic_mask (np.int64)
2380+
2381+
Args:
2382+
instance_classes (List[int]): Semantic masks which represent the
2383+
instance.
2384+
swap_ratio (float): Swap ratio of two point cloud. Defaults to 0.5.
2385+
rotate_paste_ratio (float): Rotate paste ratio. Defaults to 1.0.
2386+
pre_transform (Sequence[dict], optional): Sequence of transform object
2387+
or config dict to be composed. Defaults to None.
2388+
prob (float): The transformation probability. Defaults to 1.0.
2389+
"""
2390+
2391+
def __init__(self,
2392+
instance_classes: List[int],
2393+
swap_ratio: float = 0.5,
2394+
rotate_paste_ratio: float = 1.0,
2395+
pre_transform: Optional[Sequence[dict]] = None,
2396+
prob: float = 1.0) -> None:
2397+
assert is_list_of(instance_classes, int), \
2398+
'instance_classes should be a list of int'
2399+
self.instance_classes = instance_classes
2400+
self.swap_ratio = swap_ratio
2401+
self.rotate_paste_ratio = rotate_paste_ratio
2402+
2403+
self.prob = prob
2404+
if pre_transform is None:
2405+
self.pre_transform = None
2406+
else:
2407+
self.pre_transform = Compose(pre_transform)
2408+
2409+
def polar_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
2410+
"""PolarMix transform function.
2411+
2412+
Args:
2413+
input_dict (dict): Result dict from loading pipeline.
2414+
mix_results (dict): Mixed dict picked from dataset.
2415+
2416+
Returns:
2417+
dict: output dict after transformation.
2418+
"""
2419+
mix_points = mix_results['points']
2420+
mix_pts_semantic_mask = mix_results['pts_semantic_mask']
2421+
2422+
points = input_dict['points']
2423+
pts_semantic_mask = input_dict['pts_semantic_mask']
2424+
2425+
# 1. swap point cloud
2426+
if np.random.random() < self.swap_ratio:
2427+
start_angle = (np.random.random() - 1) * np.pi # -pi~0
2428+
end_angle = start_angle + np.pi
2429+
# calculate horizontal angle for each point
2430+
yaw = -torch.atan2(points.coord[:, 1], points.coord[:, 0])
2431+
mix_yaw = -torch.atan2(mix_points.coord[:, 1], mix_points.coord[:,
2432+
0])
2433+
2434+
# select points in sector
2435+
idx = (yaw <= start_angle) | (yaw >= end_angle)
2436+
mix_idx = (mix_yaw > start_angle) & (mix_yaw < end_angle)
2437+
2438+
# swap
2439+
points = points.cat([points[idx], mix_points[mix_idx]])
2440+
pts_semantic_mask = np.concatenate(
2441+
(pts_semantic_mask[idx.numpy()],
2442+
mix_pts_semantic_mask[mix_idx.numpy()]),
2443+
axis=0)
2444+
2445+
# 2. rotate-pasting
2446+
if np.random.random() < self.rotate_paste_ratio:
2447+
# extract instance points
2448+
instance_points, instance_pts_semantic_mask = [], []
2449+
for instance_class in self.instance_classes:
2450+
mix_idx = mix_pts_semantic_mask == instance_class
2451+
instance_points.append(mix_points[mix_idx])
2452+
instance_pts_semantic_mask.append(
2453+
mix_pts_semantic_mask[mix_idx])
2454+
instance_points = mix_points.cat(instance_points)
2455+
instance_pts_semantic_mask = np.concatenate(
2456+
instance_pts_semantic_mask, axis=0)
2457+
2458+
# rotate-copy
2459+
copy_points = [instance_points]
2460+
copy_pts_semantic_mask = [instance_pts_semantic_mask]
2461+
angle_list = [
2462+
np.random.random() * np.pi * 2 / 3,
2463+
(np.random.random() + 1) * np.pi * 2 / 3
2464+
]
2465+
for angle in angle_list:
2466+
new_points = instance_points.clone()
2467+
new_points.rotate(angle)
2468+
copy_points.append(new_points)
2469+
copy_pts_semantic_mask.append(instance_pts_semantic_mask)
2470+
copy_points = instance_points.cat(copy_points)
2471+
copy_pts_semantic_mask = np.concatenate(
2472+
copy_pts_semantic_mask, axis=0)
2473+
2474+
points = points.cat([points, copy_points])
2475+
pts_semantic_mask = np.concatenate(
2476+
(pts_semantic_mask, copy_pts_semantic_mask), axis=0)
2477+
2478+
input_dict['points'] = points
2479+
input_dict['pts_semantic_mask'] = pts_semantic_mask
2480+
return input_dict
2481+
2482+
def transform(self, input_dict: dict) -> dict:
2483+
"""PolarMix transform function.
2484+
2485+
Args:
2486+
input_dict (dict): Result dict from loading pipeline.
2487+
2488+
Returns:
2489+
dict: output dict after transformation.
2490+
"""
2491+
if np.random.rand() > self.prob:
2492+
return input_dict
2493+
2494+
assert 'dataset' in input_dict, \
2495+
'`dataset` is needed to pass through PolarMix, while not found.'
2496+
dataset = input_dict['dataset']
2497+
2498+
# get index of other point cloud
2499+
index = np.random.randint(0, len(dataset))
2500+
2501+
mix_results = dataset.get_data_info(index)
2502+
2503+
if self.pre_transform is not None:
2504+
# pre_transform may also require dataset
2505+
mix_results.update({'dataset': dataset})
2506+
# before polarmix need to go through
2507+
# the necessary pre_transform
2508+
mix_results = self.pre_transform(mix_results)
2509+
mix_results.pop('dataset')
2510+
2511+
input_dict = self.polar_mix_transform(input_dict, mix_results)
2512+
2513+
return input_dict
2514+
2515+
def __repr__(self) -> str:
2516+
"""str: Return a string that describes the module."""
2517+
repr_str = self.__class__.__name__
2518+
repr_str += f'(instance_classes={self.instance_classes}, '
2519+
repr_str += f'swap_ratio={self.swap_ratio}, '
2520+
repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, '
2521+
repr_str += f'pre_transform={self.pre_transform}, '
2522+
repr_str += f'prob={self.prob})'
2523+
return repr_str

tests/test_datasets/test_transforms/test_transforms_3d.py

+125-2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,14 @@
66
import torch
77
from mmengine.testing import assert_allclose
88

9-
from mmdet3d.datasets import GlobalAlignment, RandomFlip3D
10-
from mmdet3d.datasets.transforms import GlobalRotScaleTrans
9+
from mmdet3d.datasets import (GlobalAlignment, RandomFlip3D,
10+
SemanticKITTIDataset)
11+
from mmdet3d.datasets.transforms import GlobalRotScaleTrans, PolarMix
12+
from mmdet3d.structures import LiDARPoints
1113
from mmdet3d.testing import create_data_info_after_loading
14+
from mmdet3d.utils import register_all_modules
15+
16+
register_all_modules()
1217

1318

1419
class TestGlobalRotScaleTrans(unittest.TestCase):
@@ -99,3 +104,121 @@ def test_global_alignment(self):
99104
# assert the rot metric
100105
with self.assertRaises(AssertionError):
101106
global_align_transform(data_info)
107+
108+
109+
class TestPolarMix(unittest.TestCase):
110+
111+
def setUp(self):
112+
self.pre_transform = [
113+
dict(
114+
type='LoadPointsFromFile',
115+
coord_type='LIDAR',
116+
load_dim=4,
117+
use_dim=4),
118+
dict(
119+
type='LoadAnnotations3D',
120+
with_bbox_3d=False,
121+
with_label_3d=False,
122+
with_mask_3d=False,
123+
with_seg_3d=True,
124+
seg_3d_dtype='np.int32'),
125+
dict(type='PointSegClassMapping'),
126+
]
127+
classes = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
128+
'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
129+
'sidewalk', 'other-ground', 'building', 'fence',
130+
'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')
131+
palette = [
132+
[174, 199, 232],
133+
[152, 223, 138],
134+
[31, 119, 180],
135+
[255, 187, 120],
136+
[188, 189, 34],
137+
[140, 86, 75],
138+
[255, 152, 150],
139+
[214, 39, 40],
140+
[197, 176, 213],
141+
[148, 103, 189],
142+
[196, 156, 148],
143+
[23, 190, 207],
144+
[247, 182, 210],
145+
[219, 219, 141],
146+
[255, 127, 14],
147+
[158, 218, 229],
148+
[44, 160, 44],
149+
[112, 128, 144],
150+
[227, 119, 194],
151+
[82, 84, 163],
152+
]
153+
seg_label_mapping = {
154+
0: 0, # "unlabeled"
155+
1: 0, # "outlier" mapped to "unlabeled" --------------mapped
156+
10: 1, # "car"
157+
11: 2, # "bicycle"
158+
13: 5, # "bus" mapped to "other-vehicle" --------------mapped
159+
15: 3, # "motorcycle"
160+
16: 5, # "on-rails" mapped to "other-vehicle" ---------mapped
161+
18: 4, # "truck"
162+
20: 5, # "other-vehicle"
163+
30: 6, # "person"
164+
31: 7, # "bicyclist"
165+
32: 8, # "motorcyclist"
166+
40: 9, # "road"
167+
44: 10, # "parking"
168+
48: 11, # "sidewalk"
169+
49: 12, # "other-ground"
170+
50: 13, # "building"
171+
51: 14, # "fence"
172+
52: 0, # "other-structure" mapped to "unlabeled" ------mapped
173+
60: 9, # "lane-marking" to "road" ---------------------mapped
174+
70: 15, # "vegetation"
175+
71: 16, # "trunk"
176+
72: 17, # "terrain"
177+
80: 18, # "pole"
178+
81: 19, # "traffic-sign"
179+
99: 0, # "other-object" to "unlabeled" ----------------mapped
180+
252: 1, # "moving-car" to "car" ------------------------mapped
181+
253: 7, # "moving-bicyclist" to "bicyclist" ------------mapped
182+
254: 6, # "moving-person" to "person" ------------------mapped
183+
255: 8, # "moving-motorcyclist" to "motorcyclist" ------mapped
184+
256: 5, # "moving-on-rails" mapped to "other-vehic------mapped
185+
257: 5, # "moving-bus" mapped to "other-vehicle" -------mapped
186+
258: 4, # "moving-truck" to "truck" --------------------mapped
187+
259: 5 # "moving-other"-vehicle to "other-vehicle"-----mapped
188+
}
189+
max_label = 259
190+
self.dataset = SemanticKITTIDataset(
191+
'./tests/data/semantickitti/',
192+
'semantickitti_infos.pkl',
193+
metainfo=dict(
194+
classes=classes,
195+
palette=palette,
196+
seg_label_mapping=seg_label_mapping,
197+
max_label=max_label),
198+
data_prefix=dict(
199+
pts='sequences/00/velodyne',
200+
pts_semantic_mask='sequences/00/labels'),
201+
pipeline=[],
202+
modality=dict(use_lidar=True, use_camera=False))
203+
points = np.random.random((100, 4))
204+
self.results = {
205+
'points': LiDARPoints(points, points_dim=4),
206+
'pts_semantic_mask': np.random.randint(0, 20, (100, )),
207+
'dataset': self.dataset
208+
}
209+
210+
def test_transform(self):
211+
# test assertion for invalid instance_classes
212+
with self.assertRaises(AssertionError):
213+
transform = PolarMix(instance_classes=1)
214+
215+
with self.assertRaises(AssertionError):
216+
transform = PolarMix(instance_classes=[1.0, 2.0])
217+
218+
transform = PolarMix(
219+
instance_classes=[1, 2],
220+
swap_ratio=1.0,
221+
pre_transform=self.pre_transform)
222+
results = transform.transform(copy.deepcopy(self.results))
223+
self.assertTrue(results['points'].shape[0] ==
224+
results['pts_semantic_mask'].shape[0])

0 commit comments

Comments
 (0)