Skip to content

Commit

Permalink
[Feature] SemanticKITTI Dataset (#287)
Browse files Browse the repository at this point in the history
* add ini

* add semantickitti_dataset

* add test semantickitti_dataset

* delete last line in test_semmaticdataset

* add test data

* change_names

* load_labels dytype

* change_name

* numpy

* name

* dtype string

* minor issue-string

* seg_3d_dtype
  • Loading branch information
junhaozhang98 authored Feb 1, 2021
1 parent c556d27 commit b050172
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 4 deletions.
7 changes: 4 additions & 3 deletions mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)
from .scannet_dataset import ScanNetDataset
from .semantickitti_dataset import SemanticKITTIDataset
from .sunrgbd_dataset import SUNRGBDDataset
from .waymo_dataset import WaymoDataset

Expand All @@ -21,7 +22,7 @@
'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle',
'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'LoadPointsFromFile', 'NormalizePointsColor', 'IndoorPointSample',
'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset', 'Custom3DDataset',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'VoxelBasedPointSampler'
'LoadAnnotations3D', 'SUNRGBDDataset', 'ScanNetDataset',
'SemanticKITTIDataset', 'Custom3DDataset', 'LoadPointsFromMultiSweeps',
'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler'
]
7 changes: 6 additions & 1 deletion mmdet3d/datasets/pipelines/loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,8 @@ class LoadAnnotations3D(LoadAnnotations):
Defaults to False.
poly2mask (bool, optional): Whether to convert polygon annotations
to bitmasks. Defaults to True.
seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks.
Defaults to 'int' (NumPy's default integer type).
file_client_args (dict): Config dict of file clients, refer to
https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py
for more details.
Expand All @@ -434,6 +436,7 @@ def __init__(self,
with_mask=False,
with_seg=False,
poly2mask=True,
seg_3d_dtype='int',
file_client_args=dict(backend='disk')):
super().__init__(
with_bbox,
Expand All @@ -446,6 +449,7 @@ def __init__(self,
self.with_label_3d = with_label_3d
self.with_mask_3d = with_mask_3d
self.with_seg_3d = with_seg_3d
self.seg_3d_dtype = seg_3d_dtype

def _load_bboxes_3d(self, results):
"""Private function to load 3D bounding box annotations.
Expand Down Expand Up @@ -513,7 +517,8 @@ def _load_semantic_seg_3d(self, results):
try:
mask_bytes = self.file_client.get(pts_semantic_mask_path)
# add .copy() to fix read-only bug
pts_semantic_mask = np.frombuffer(mask_bytes, dtype=np.int).copy()
pts_semantic_mask = np.frombuffer(
mask_bytes, dtype=self.seg_3d_dtype).copy()
except ConnectionError:
mmcv.check_file_exist(pts_semantic_mask_path)
pts_semantic_mask = np.fromfile(
Expand Down
80 changes: 80 additions & 0 deletions mmdet3d/datasets/semantickitti_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from os import path as osp

from mmdet.datasets import DATASETS
from .custom_3d import Custom3DDataset


@DATASETS.register_module()
class SemanticKITTIDataset(Custom3DDataset):
    r"""SemanticKITTI Dataset.

    This class serves as the API for experiments on the SemanticKITTI
    dataset. Please refer to
    `SemanticKITTI <http://www.semantic-kitti.org/dataset.html>`_
    for data downloading.

    Args:
        data_root (str): Path of dataset root.
        ann_file (str): Path of annotation file.
        pipeline (list[dict], optional): Pipeline used for data processing.
            Defaults to None.
        classes (tuple[str], optional): Classes used in the dataset.
            Defaults to None.
        modality (dict, optional): Modality to specify the sensor data used
            as input. Defaults to None.
        box_type_3d (str, optional): Box type forwarded to the parent
            dataset. There is no 3D box in this dataset, so any type can be
            chosen. Defaults to 'Lidar'. Available options include

            - 'LiDAR': Box in LiDAR coordinates.
            - 'Depth': Box in depth coordinates, usually for indoor dataset.
            - 'Camera': Box in camera coordinates.
        filter_empty_gt (bool, optional): Whether to filter empty GT.
            Defaults to False.
        test_mode (bool, optional): Whether the dataset is in test mode.
            Defaults to False.
    """

    # NOTE(review): 'trunck' and 'terrian' look like misspellings of
    # 'trunk' and 'terrain'. They are kept unchanged because other code
    # (e.g. class name lists in tests/configs) may match these exact strings.
    CLASSES = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
               'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
               'sidewalk', 'other-ground', 'building', 'fence', 'vegetation',
               'trunck', 'terrian', 'pole', 'traffic-sign')

    def __init__(self,
                 data_root,
                 ann_file,
                 pipeline=None,
                 classes=None,
                 modality=None,
                 box_type_3d='Lidar',
                 filter_empty_gt=False,
                 test_mode=False):
        # Every argument is handed to the parent class untouched; this
        # subclass only supplies dataset-specific defaults.
        parent_kwargs = dict(
            data_root=data_root,
            ann_file=ann_file,
            pipeline=pipeline,
            classes=classes,
            modality=modality,
            box_type_3d=box_type_3d,
            filter_empty_gt=filter_empty_gt,
            test_mode=test_mode)
        super().__init__(**parent_kwargs)

    def get_ann_info(self, index):
        """Get annotation info according to the given index.

        Args:
            index (int): Index of the annotation data to get.

        Returns:
            dict: Annotation information with the following key:

                - pts_semantic_mask_path (str): Path of the semantic mask,
                  joined onto ``self.data_root``.
        """
        # Lookup is purely index-based so that evaluation hooks can call
        # this API as well.
        relative_mask_path = self.data_infos[index]['pts_semantic_mask_path']
        return dict(
            pts_semantic_mask_path=osp.join(self.data_root,
                                            relative_mask_path))
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ lyft_dataset_sdk
networkx>=2.2,<2.3
# we may unlock the version of numba in the future
numba==0.48.0
numpy<1.20.0
nuscenes-devkit
plyfile
scikit-image
Expand Down
Binary file added tests/data/semantickitti/semantickitti_infos.pkl
Binary file not shown.
Binary file not shown.
Binary file not shown.
52 changes: 52 additions & 0 deletions tests/test_dataset/test_semantickitti_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np

from mmdet3d.datasets import SemanticKITTIDataset


def test_getitem():
    """Load the first SemanticKITTI test sample through a full pipeline.

    Checks that the number of loaded points equals the number of entries
    in the loaded per-point semantic mask.
    """
    # Seed first so the random augmentations in the pipeline are repeatable.
    np.random.seed(0)

    data_root = './tests/data/semantickitti/'
    info_file = './tests/data/semantickitti/semantickitti_infos.pkl'
    class_names = ('unlabeled', 'car', 'bicycle', 'motorcycle', 'truck', 'bus',
                   'person', 'bicyclist', 'motorcyclist', 'road', 'parking',
                   'sidewalk', 'other-ground', 'building', 'fence',
                   'vegetation', 'trunck', 'terrian', 'pole', 'traffic-sign')

    # Each pipeline stage is named separately to keep the final list short.
    load_points = dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',
        shift_height=True,
        load_dim=4,
        use_dim=[0, 1, 2])
    # Only the 3D semantic mask is requested, read with an explicit dtype.
    load_annotations = dict(
        type='LoadAnnotations3D',
        with_bbox_3d=False,
        with_label_3d=False,
        with_mask_3d=False,
        with_seg_3d=True,
        seg_3d_dtype=np.int32)
    random_flip = dict(
        type='RandomFlip3D',
        sync_2d=False,
        flip_ratio_bev_horizontal=1.0,
        flip_ratio_bev_vertical=1.0)
    rot_scale_trans = dict(
        type='GlobalRotScaleTrans',
        rot_range=[-0.087266, 0.087266],
        scale_ratio_range=[1.0, 1.0],
        shift_height=True)
    format_bundle = dict(type='DefaultFormatBundle3D', class_names=class_names)
    collect = dict(
        type='Collect3D',
        keys=[
            'points',
            'pts_semantic_mask',
        ],
        meta_keys=['file_name', 'sample_idx', 'pcd_rotation'])

    pipelines = [
        load_points,
        load_annotations,
        random_flip,
        rot_scale_trans,
        format_bundle,
        collect,
    ]

    dataset = SemanticKITTIDataset(data_root, info_file, pipelines)
    sample = dataset[0]

    # One semantic label is expected per point.
    num_points = sample['points']._data.shape[0]
    num_labels = sample['pts_semantic_mask']._data.shape[0]
    assert num_points == num_labels

0 comments on commit b050172

Please sign in to comment.