diff --git a/mmdet3d/datasets/__init__.py b/mmdet3d/datasets/__init__.py index f47335ebd9..61b34ec98b 100644 --- a/mmdet3d/datasets/__init__.py +++ b/mmdet3d/datasets/__init__.py @@ -11,6 +11,7 @@ ObjectSample, PointShuffle, PointsRangeFilter, RandomFlip3D, VoxelBasedPointSampler) from .scannet_dataset import ScanNetDataset +from .semantickitti_dataset import SemanticKITTIDataset from .sunrgbd_dataset import SUNRGBDDataset from .waymo_dataset import WaymoDataset @@ -21,7 +22,7 @@ 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample', - 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset', - 'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter', - 'VoxelBasedPointSampler' + 'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', + 'SemanticKITTIDataset', 'Custom3DDataset', 'LoadPointsFromMultiSweeps', + 'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler' ] diff --git a/mmdet3d/datasets/pipelines/loading.py b/mmdet3d/datasets/pipelines/loading.py index 85c12a3a05..4e9bfd2a53 100644 --- a/mmdet3d/datasets/pipelines/loading.py +++ b/mmdet3d/datasets/pipelines/loading.py @@ -419,6 +419,8 @@ class LoadAnnotations3D(LoadAnnotations): Defaults to False. poly2mask (bool, optional): Whether to convert polygon annotations to bitmasks. Defaults to True. + seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks. + Defaults to int64 file_client_args (dict): Config dict of file clients, refer to https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py for more details. @@ -434,6 +436,7 @@ def __init__(self, with_mask=False, with_seg=False, poly2mask=True, + seg_3d_dtype='int', file_client_args=dict(backend='disk')): super().__init__( with_bbox, @@ -446,6 +449,7 @@ def __init__(self, self.with_label_3d = with_label_3d self.with_mask_3d = with_mask_3d self.with_seg_3d = with_seg_3d + self.seg_3d_dtype = seg_3d_dtype def _load_bboxes_3d(self, results): """Private function to load 3D bounding box annotations. @@ -513,7 +517,8 @@ def _load_semantic_seg_3d(self, results): try: mask_bytes = self.file_client.get(pts_semantic_mask_path) # add .copy() to fix read-only bug - pts_semantic_mask = np.frombuffer(mask_bytes, dtype=np.int).copy() + pts_semantic_mask = np.frombuffer( + mask_bytes, dtype=self.seg_3d_dtype).copy() except ConnectionError: mmcv.check_file_exist(pts_semantic_mask_path) pts_semantic_mask = np.fromfile( diff --git a/mmdet3d/datasets/semantickitti_dataset.py b/mmdet3d/datasets/semantickitti_dataset.py new file mode 100644 index 0000000000..446ff35f89 --- /dev/null +++ b/mmdet3d/datasets/semantickitti_dataset.py @@ -0,0 +1,80 @@ +from os import path as osp + +from mmdet.datasets import DATASETS +from .custom_3d import Custom3DDataset + + +@DATASETS.register_module() +class SemanticKITTIDataset(Custom3DDataset): + r"""SemanticKITTI Dataset. + + This class serves as the API for experiments on the SemanticKITTI Dataset + Please refer to `_ + for data downloading + + Args: + data_root (str): Path of dataset root. + ann_file (str): Path of annotation file. + pipeline (list[dict], optional): Pipeline used for data processing. + Defaults to None. + classes (tuple[str], optional): Classes used in the dataset. + Defaults to None. + modality (dict, optional): Modality to specify the sensor data used + as input. Defaults to None. + box_type_3d (str, optional): NO 3D box for this dataset. + You can choose any type + Based on the `box_type_3d`, the dataset will encapsulate the box + to its original format then converted them to `box_type_3d`. + Defaults to 'LiDAR' in this dataset. Available options includes + + - 'LiDAR': Box in LiDAR coordinates. + - 'Depth': Box in depth coordinates, usually for indoor dataset. + - 'Camera': Box in camera coordinates. + filter_empty_gt (bool, optional): Whether to filter empty GT. + Defaults to True. + test_mode (bool, optional): Whether the dataset is in test mode. + Defaults to False. + """ + CLASSES = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', + 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', + 'sidewalk', 'other-ground', 'building', 'fence', 'vegetation', + 'trunck', 'terrian', 'pole', 'traffic-sign') + + def __init__(self, + data_root, + ann_file, + pipeline=None, + classes=None, + modality=None, + box_type_3d='Lidar', + filter_empty_gt=False, + test_mode=False): + super().__init__( + data_root=data_root, + ann_file=ann_file, + pipeline=pipeline, + classes=classes, + modality=modality, + box_type_3d=box_type_3d, + filter_empty_gt=filter_empty_gt, + test_mode=test_mode) + + def get_ann_info(self, index): + """Get annotation info according to the given index. + + Args: + index (int): Index of the annotation data to get. + + Returns: + dict: annotation information consists of the following keys: + + - pts_semantic_mask_path (str): Path of semantic masks. + """ + # Use index to get the annos, thus the evalhook could also use this api + info = self.data_infos[index] + + pts_semantic_mask_path = osp.join(self.data_root, + info['pts_semantic_mask_path']) + + anns_results = dict(pts_semantic_mask_path=pts_semantic_mask_path) + return anns_results diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 774491ce98..f0e7ae42e4 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -2,6 +2,7 @@ lyft_dataset_sdk networkx>=2.2,<2.3 # we may unlock the verion of numba in the future numba==0.48.0 +numpy<1.20.0 nuscenes-devkit plyfile scikit-image diff --git a/tests/data/semantickitti/semantickitti_infos.pkl b/tests/data/semantickitti/semantickitti_infos.pkl new file mode 100644 index 0000000000..b32ceb4ed2 Binary files /dev/null and b/tests/data/semantickitti/semantickitti_infos.pkl differ diff --git a/tests/data/semantickitti/sequences/00/labels/000000.label b/tests/data/semantickitti/sequences/00/labels/000000.label new file mode 100644 index 0000000000..f30abe9c27 Binary files /dev/null and b/tests/data/semantickitti/sequences/00/labels/000000.label differ diff --git a/tests/data/semantickitti/sequences/00/velodyne/000000.bin b/tests/data/semantickitti/sequences/00/velodyne/000000.bin new file mode 100644 index 0000000000..cf45a816db Binary files /dev/null and b/tests/data/semantickitti/sequences/00/velodyne/000000.bin differ diff --git a/tests/test_dataset/test_semantickitti_dataset.py b/tests/test_dataset/test_semantickitti_dataset.py new file mode 100644 index 0000000000..a6e31e63d9 --- /dev/null +++ b/tests/test_dataset/test_semantickitti_dataset.py @@ -0,0 +1,52 @@ +import numpy as np + +from mmdet3d.datasets import SemanticKITTIDataset + + +def test_getitem(): + np.random.seed(0) + root_path = './tests/data/semantickitti/' + ann_file = './tests/data/semantickitti/semantickitti_infos.pkl' + class_names = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus', + 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', + 'sidewalk', 'other-ground', 'building', 'fence', + 'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign') + pipelines = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + shift_height=True, + load_dim=4, + use_dim=[0, 1, 2]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_mask_3d=False, + with_seg_3d=True, + seg_3d_dtype=np.int32), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=1.0, + flip_ratio_bev_vertical=1.0), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[1.0, 1.0], + shift_height=True), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', + 'pts_semantic_mask', + ], + meta_keys=['file_name', 'sample_idx', 'pcd_rotation']), + ] + + semantickitti_dataset = SemanticKITTIDataset(root_path, ann_file, + pipelines) + data = semantickitti_dataset[0] + assert data['points']._data.shape[0] == data[ + 'pts_semantic_mask']._data.shape[0]