Skip to content

Commit

Permalink
[Enhance]: Add typehints for dataset transforms and fix potential bug…
Browse files Browse the repository at this point in the history
… for `PointSample` (#1875)

* update dataset transforms

* update dbsampler docstring and add typehints

* add type hints and fix potential point sample bug

* fix lint

* fix

* fix
  • Loading branch information
Xiangxu-0103 authored Oct 8, 2022
1 parent d8c9bc6 commit 5412046
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 118 deletions.
2 changes: 1 addition & 1 deletion mmdet3d/datasets/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __call__(self, data):
data (dict): A result dict contains the data to transform.
Returns:
dict: Transformed data.
dict: Transformed data.
"""

for t in self.transforms:
Expand Down
86 changes: 46 additions & 40 deletions mmdet3d/datasets/transforms/dbsampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from typing import List, Optional

import mmengine
import numpy as np
Expand All @@ -16,18 +16,19 @@ class BatchSampler:
Args:
sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Default: None.
epoch (int, optional): Sampling epoch. Default: None.
shuffle (bool, optional): Whether to shuffle indices. Default: False.
drop_reminder (bool, optional): Drop reminder. Default: False.
name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices.
Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
"""

def __init__(self,
sampled_list,
name=None,
epoch=None,
shuffle=True,
drop_reminder=False):
sampled_list: List[dict],
name: Optional[str] = None,
epoch: Optional[int] = None,
shuffle: bool = True,
drop_reminder: bool = False) -> None:
self._sampled_list = sampled_list
self._indices = np.arange(len(sampled_list))
if shuffle:
Expand All @@ -40,7 +41,7 @@ def __init__(self,
self._epoch_counter = 0
self._drop_reminder = drop_reminder

def _sample(self, num):
def _sample(self, num: int) -> List[int]:
"""Sample specific number of ground truths and return indices.
Args:
Expand All @@ -57,15 +58,15 @@ def _sample(self, num):
self._idx += num
return ret

def _reset(self):
def _reset(self) -> None:
"""Reset the index of batchsampler to zero."""
assert self._name is not None
# print("reset", self._name)
if self._shuffle:
np.random.shuffle(self._indices)
self._idx = 0

def sample(self, num):
def sample(self, num: int) -> List[dict]:
"""Sample specific number of ground truths.
Args:
Expand All @@ -88,24 +89,28 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Default: None.
points_loader(dict, optional): Config of points loader. Default:
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict, optional): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""

def __init__(self,
info_path,
data_root,
rate,
prepare,
sample_groups,
classes=None,
points_loader=dict(
info_path: str,
data_root: str,
rate: float,
prepare: dict,
sample_groups: dict,
classes: Optional[List[str]] = None,
points_loader: dict = dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=[0, 1, 2, 3]),
file_client_args=dict(backend='disk')):
file_client_args: dict = dict(backend='disk')) -> None:
super().__init__()
self.data_root = data_root
self.info_path = info_path
Expand All @@ -118,18 +123,9 @@ def __init__(self,
self.file_client = mmengine.FileClient(**file_client_args)

# load data base infos
if hasattr(self.file_client, 'get_local_path'):
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(
open(local_path, 'rb'), file_format='pkl')
else:
warnings.warn(
'The used MMCV version does not have get_local_path. '
f'We treat the {info_path} as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.')
db_infos = mmengine.load(info_path)
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')

# filter database infos
from mmengine.logging import MMLogger
Expand Down Expand Up @@ -163,7 +159,7 @@ def __init__(self,
# TODO: No group_sampling currently

@staticmethod
def filter_by_difficulty(db_infos, removed_difficulty):
def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict:
"""Filter ground truths by difficulties.
Args:
Expand All @@ -182,7 +178,7 @@ def filter_by_difficulty(db_infos, removed_difficulty):
return new_db_infos

@staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict):
def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict:
"""Filter ground truths by number of points in the bbox.
Args:
Expand All @@ -203,12 +199,19 @@ def filter_by_min_points(db_infos, min_gt_points_dict):
db_infos[name] = filtered_infos
return db_infos

def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):
def sample_all(self,
gt_bboxes: np.ndarray,
gt_labels: np.ndarray,
img: Optional[np.ndarray] = None,
ground_plane: Optional[np.ndarray] = None) -> dict:
"""Sampling all categories of bboxes.
Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Ground truth labels of boxes.
img (np.ndarray, optional): Image array. Defaults to None.
ground_plane (np.ndarray, optional): Ground plane information.
Defaults to None.
Returns:
dict: Dict of sampled 'pseudo ground truths'.
Expand Down Expand Up @@ -301,7 +304,10 @@ def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):

return ret

def sample_class_v2(self, name, num, gt_bboxes):
def sample_class_v2(self,
name: str,
num: int,
gt_bboxes: np.ndarray) -> List[dict]:
"""Sampling specific categories of bounding boxes.
Args:
Expand Down
27 changes: 14 additions & 13 deletions mmdet3d/datasets/transforms/formating.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform):

def __init__(
self,
keys: dict,
meta_keys: dict = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'num_pts_feats', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')):
keys: tuple,
meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape',
'scale_factor', 'flip', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')) -> None:
self.keys = keys
self.meta_keys = meta_keys

Expand All @@ -99,7 +99,7 @@ def transform(self, results: Union[dict,
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info of
the sample.
the sample.
"""
# augtest
if isinstance(results, list):
Expand All @@ -116,7 +116,7 @@ def transform(self, results: Union[dict,
else:
raise NotImplementedError

def pack_single_results(self, results):
def pack_single_results(self, results: dict) -> dict:
"""Method to pack the single input data. when the value in this dict is
a list, it usually is in Augmentations Testing.
Expand All @@ -132,7 +132,7 @@ def pack_single_results(self, results):
- points
- img
- 'data_samples' (obj:`Det3DDataSample`): The annotation info
- 'data_samples' (:obj:`Det3DDataSample`): The annotation info
of the sample.
"""
# Format 3D data
Expand Down Expand Up @@ -220,6 +220,7 @@ def pack_single_results(self, results):
return packed_results

def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys})'
repr_str += f'(meta_keys={self.meta_keys})'
Expand Down
Loading

0 comments on commit 5412046

Please sign in to comment.