|
1 | 1 | # Copyright (c) OpenMMLab. All rights reserved.
|
2 | 2 | import random
|
3 | 3 | import warnings
|
4 |
| -from typing import List, Optional, Tuple, Union |
| 4 | +from typing import List, Optional, Sequence, Tuple, Union |
5 | 5 |
|
6 | 6 | import cv2
|
7 | 7 | import mmcv
|
8 | 8 | import numpy as np
|
| 9 | +import torch |
9 | 10 | from mmcv.transforms import BaseTransform, Compose, RandomResize, Resize
|
10 | 11 | from mmdet.datasets.transforms import (PhotoMetricDistortion, RandomCrop,
|
11 | 12 | RandomFlip)
|
12 |
| -from mmengine import is_tuple_of |
| 13 | +from mmengine import is_list_of, is_tuple_of |
13 | 14 |
|
14 | 15 | from mmdet3d.models.task_modules import VoxelGenerator
|
15 | 16 | from mmdet3d.registry import TRANSFORMS
|
@@ -2352,3 +2353,171 @@ def transform(self, input_dict: dict) -> dict:
|
2352 | 2353 | if len(input_dict[key]) == 0:
|
2353 | 2354 | input_dict.pop(key)
|
2354 | 2355 | return input_dict
|
| 2356 | + |
| 2357 | + |
@TRANSFORMS.register_module()
class PolarMix(BaseTransform):
    """PolarMix data augmentation.

    The polarmix transform steps are as follows:

        1. Another random point cloud is picked by dataset.
        2. Exchange sectors of two point clouds that are cut with certain
           azimuth angles.
        3. Cut point instances from picked point cloud, rotate them by multiple
           azimuth angles, and paste the cut and rotated instances.

    Steps 2 and 3 are each applied independently with probability
    ``swap_ratio`` and ``rotate_paste_ratio`` respectively.

    Required Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)
    - dataset (:obj:`BaseDataset`)

    Modified Keys:

    - points (:obj:`BasePoints`)
    - pts_semantic_mask (np.int64)

    Args:
        instance_classes (List[int]): Semantic masks which represent the
            instance. If the list is empty, the rotate-paste step is skipped.
        swap_ratio (float): Swap ratio of two point cloud. Defaults to 0.5.
        rotate_paste_ratio (float): Rotate paste ratio. Defaults to 1.0.
        pre_transform (Sequence[dict], optional): Sequence of transform object
            or config dict to be composed. Defaults to None.
        prob (float): The transformation probability. Defaults to 1.0.
    """

    def __init__(self,
                 instance_classes: List[int],
                 swap_ratio: float = 0.5,
                 rotate_paste_ratio: float = 1.0,
                 pre_transform: Optional[Sequence[dict]] = None,
                 prob: float = 1.0) -> None:
        assert is_list_of(instance_classes, int), \
            'instance_classes should be a list of int'
        self.instance_classes = instance_classes
        self.swap_ratio = swap_ratio
        self.rotate_paste_ratio = rotate_paste_ratio

        self.prob = prob
        if pre_transform is None:
            self.pre_transform = None
        else:
            self.pre_transform = Compose(pre_transform)

    def polar_mix_transform(self, input_dict: dict, mix_results: dict) -> dict:
        """PolarMix transform function.

        Args:
            input_dict (dict): Result dict from loading pipeline. Its
                ``points`` and ``pts_semantic_mask`` entries are overwritten
                with the mixed result.
            mix_results (dict): Mixed dict picked from dataset. Must provide
                ``points`` and ``pts_semantic_mask``.

        Returns:
            dict: output dict after transformation.
        """
        mix_points = mix_results['points']
        mix_pts_semantic_mask = mix_results['pts_semantic_mask']

        points = input_dict['points']
        pts_semantic_mask = input_dict['pts_semantic_mask']

        # 1. swap point cloud
        if np.random.random() < self.swap_ratio:
            # The sector spans exactly pi: start is drawn from [-pi, 0), so
            # end lies in [0, pi).
            start_angle = (np.random.random() - 1) * np.pi  # -pi~0
            end_angle = start_angle + np.pi
            # calculate horizontal angle for each point
            yaw = -torch.atan2(points.coord[:, 1], points.coord[:, 0])
            mix_yaw = -torch.atan2(mix_points.coord[:, 1],
                                   mix_points.coord[:, 0])

            # select points in sector: keep the current cloud outside the
            # open interval (start, end) and take the picked cloud inside it
            idx = (yaw <= start_angle) | (yaw >= end_angle)
            mix_idx = (mix_yaw > start_angle) & (mix_yaw < end_angle)

            # swap
            points = points.cat([points[idx], mix_points[mix_idx]])
            # the torch masks are converted with ``.numpy()`` before they
            # index the semantic masks
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask[idx.numpy()],
                 mix_pts_semantic_mask[mix_idx.numpy()]),
                axis=0)

        # 2. rotate-pasting
        # The random draw is evaluated first so that the RNG stream is
        # consumed identically whatever ``instance_classes`` holds. With no
        # instance classes there is nothing to concatenate
        # (``np.concatenate`` raises on an empty sequence), so the step is
        # skipped.
        if (np.random.random() < self.rotate_paste_ratio
                and self.instance_classes):
            # extract instance points
            instance_points, instance_pts_semantic_mask = [], []
            for instance_class in self.instance_classes:
                mix_idx = mix_pts_semantic_mask == instance_class
                instance_points.append(mix_points[mix_idx])
                instance_pts_semantic_mask.append(
                    mix_pts_semantic_mask[mix_idx])
            instance_points = mix_points.cat(instance_points)
            instance_pts_semantic_mask = np.concatenate(
                instance_pts_semantic_mask, axis=0)

            # rotate-copy: the unrotated instances plus two rotated copies,
            # with angles drawn from [0, 2*pi/3) and [2*pi/3, 4*pi/3)
            copy_points = [instance_points]
            copy_pts_semantic_mask = [instance_pts_semantic_mask]
            angle_list = [
                np.random.random() * np.pi * 2 / 3,
                (np.random.random() + 1) * np.pi * 2 / 3
            ]
            for angle in angle_list:
                # rotate a clone so the extracted instances stay unrotated
                new_points = instance_points.clone()
                new_points.rotate(angle)
                copy_points.append(new_points)
                # the same mask array is reused for every copy
                copy_pts_semantic_mask.append(instance_pts_semantic_mask)
            copy_points = instance_points.cat(copy_points)
            copy_pts_semantic_mask = np.concatenate(
                copy_pts_semantic_mask, axis=0)

            points = points.cat([points, copy_points])
            pts_semantic_mask = np.concatenate(
                (pts_semantic_mask, copy_pts_semantic_mask), axis=0)

        input_dict['points'] = points
        input_dict['pts_semantic_mask'] = pts_semantic_mask
        return input_dict

    def transform(self, input_dict: dict) -> dict:
        """PolarMix transform function.

        With probability ``1 - prob`` the input is returned untouched.
        Otherwise a random sample is fetched from ``input_dict['dataset']``,
        passed through ``pre_transform`` (if any) and mixed into the input.

        Args:
            input_dict (dict): Result dict from loading pipeline.

        Returns:
            dict: output dict after transformation.
        """
        if np.random.rand() > self.prob:
            return input_dict

        assert 'dataset' in input_dict, \
            '`dataset` is needed to pass through PolarMix, while not found.'
        dataset = input_dict['dataset']

        # get index of other point cloud
        index = np.random.randint(0, len(dataset))

        mix_results = dataset.get_data_info(index)

        if self.pre_transform is not None:
            # pre_transform may also require dataset
            mix_results.update({'dataset': dataset})
            # before polarmix need to go through
            # the necessary pre_transform
            mix_results = self.pre_transform(mix_results)
            # drop the dataset entry that was added above
            mix_results.pop('dataset')

        input_dict = self.polar_mix_transform(input_dict, mix_results)

        return input_dict

    def __repr__(self) -> str:
        """str: Return a string that describes the module."""
        repr_str = self.__class__.__name__
        repr_str += f'(instance_classes={self.instance_classes}, '
        repr_str += f'swap_ratio={self.swap_ratio}, '
        repr_str += f'rotate_paste_ratio={self.rotate_paste_ratio}, '
        repr_str += f'pre_transform={self.pre_transform}, '
        repr_str += f'prob={self.prob})'
        return repr_str
0 commit comments