Skip to content

Commit

Permalink
dataset read aligned bbox
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuziyi616 committed May 7, 2021
1 parent 0a7abe7 commit bebbc2b
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 349 deletions.
27 changes: 9 additions & 18 deletions configs/_base_/datasets/scannet-3d-18class.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,16 @@
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_bbox_3d=True,
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39),
max_cat_id=40),
dict(
type='GlobalAlignment', rotation_axis=2,
ignore_index=len(class_names)),
dict(type='IndoorPointSample', num_points=40000),
dict(
type='RandomFlip3D',
Expand All @@ -54,9 +52,7 @@
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(
type='GlobalAlignment', rotation_axis=2,
ignore_index=len(class_names)),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
Expand Down Expand Up @@ -93,16 +89,11 @@
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=True,
with_seg_3d=True),
dict(type='PointSegClassMapping', valid_cat_ids=valid_class_ids),
dict(
type='GlobalAlignment',
rotation_axis=2,
ignore_index=len(class_names),
extract_bbox=True),
with_bbox_3d=True,
with_label_3d=True,
with_mask_3d=False,
with_seg_3d=False),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
84 changes: 2 additions & 82 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,28 +296,20 @@ def __repr__(self):
@PIPELINES.register_module()
class GlobalAlignment(object):
"""Apply global alignment to 3D scene points by rotation and translation.
Extract 3D bboxes from the aligned points and instance mask if provided.
Args:
rotation_axis (int): Rotation axis for points and bboxes rotation.
ignore_index (int): Label index for which we won't extract bboxes.
extract_bbox (bool): Whether extract new ground-truth bboxes after \
alignment. This requires instance and semantic mask inputs.
Defaults to False.
Note:
This function should be called after PointSegClassMapping in pipeline.
We do not record the applied rotation and translation as in \
GlobalRotScaleTrans, because we usually do not need to reverse \
the alignment step.
For example, ScanNet 3D detection task uses aligned ground-truth \
bounding boxes for evaluation.
"""

def __init__(self, rotation_axis, ignore_index, extract_bbox=False):
def __init__(self, rotation_axis):
self.rotation_axis = rotation_axis
self.ignore_index = ignore_index
self.extract_bbox = extract_bbox

def _trans_points(self, input_dict, trans_factor):
"""Private function to translate points.
Expand Down Expand Up @@ -357,74 +349,6 @@ def _check_rot_mat(self, rot_mat):
is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all()
assert is_valid, f'invalid rotation matrix {rot_mat}'

def _bbox_from_points(self, points):
"""Get the bounding box of a set of points.
Args:
points (np.ndarray): A set of points belonging to one instance.
Returns:
np.ndarray: A bounding box of input points. We use origin as \
(0.5, 0.5, 0.5) without yaw.
"""
xmin = np.min(points[:, 0])
ymin = np.min(points[:, 1])
zmin = np.min(points[:, 2])
xmax = np.max(points[:, 0])
ymax = np.max(points[:, 1])
zmax = np.max(points[:, 2])
bbox = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2,
(zmin + zmax) / 2, xmax - xmin, ymax - ymin,
zmax - zmin])
return bbox

def _extract_bboxes(self, input_dict):
"""Extract bounding boxes from points, semantic mask and instance mask.
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after extracting bboxes, keys in \
input_dict['bbox3d_fields'] are updated in the dict.
"""
# TODO: this function is only used in ScanNet-Det pipeline currently
# TODO: we only extract gt_bboxes_3d which is DepthInstance3DBoxes
from mmdet3d.core.bbox import DepthInstance3DBoxes

assert 'pts_instance_mask' in input_dict.keys(), \
'instance mask is not provided in GlobalAlignment'
assert 'pts_semantic_mask' in input_dict.keys(), \
'semantic mask is not provided in GlobalAlignment'

coords = input_dict['points'].coord.numpy()
inst_mask = input_dict['pts_instance_mask']
sem_mask = input_dict['pts_semantic_mask']

# select points from valid categories where we want to extract bboxes
valid_cat_mask = (sem_mask != self.ignore_index)
inst_ids = np.unique(inst_mask[valid_cat_mask]) # ids of valid insts
instance_bboxes = np.zeros((inst_ids.shape[0], 7))
inst_id2cat_id = {
inst_id: sem_mask[inst_mask == inst_id][0]
for inst_id in inst_ids
}
for bbox_idx, inst_id in enumerate(inst_ids):
cat_id = inst_id2cat_id[inst_id]
inst_coords = coords[inst_mask == inst_id]
bbox = self._bbox_from_points(inst_coords)
instance_bboxes[bbox_idx, :6] = bbox
instance_bboxes[bbox_idx, 6] = cat_id

if 'gt_bboxes_3d' not in input_dict['bbox3d_fields']:
input_dict['bbox3d_fields'].append('gt_bboxes_3d')
input_dict['gt_bboxes_3d'] = DepthInstance3DBoxes(
instance_bboxes[:, :6],
box_dim=6,
with_yaw=False,
origin=(0.5, 0.5, 0.5))
input_dict['gt_labels_3d'] = instance_bboxes[:, 6].astype(np.long)

def __call__(self, input_dict):
"""Call function to apply global alignment (rotation and translation) to points.
Expand All @@ -447,16 +371,12 @@ def __call__(self, input_dict):
self._check_rot_mat(rot_mat)
self._rot_points(input_dict, rot_mat)
self._trans_points(input_dict, trans_vec)
if self.extract_bbox:
self._extract_bboxes(input_dict)

return input_dict

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(rotation_axis={self.rotation_axis},'
repr_str += f' ignore_index={self.ignore_index},'
repr_str += f' extract_bbox={self.extract_bbox})'
repr_str += f'(rotation_axis={self.rotation_axis})'
return repr_str


Expand Down
108 changes: 8 additions & 100 deletions mmdet3d/datasets/scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,89 +154,6 @@ def _get_axis_align_matrix(info):
'use new pre-process scripts to re-generate ScanNet data')
return np.eye(4).astype(np.float32)

def evaluate(self,
results,
metric=None,
iou_thr=(0.25, 0.5),
logger=None,
show=False,
out_dir=None,
pipeline=None):
"""Evaluate.
Evaluation in indoor protocol.
Since ScanNet detection data pipeline re-computes ground-truth boxes,
we can't directly use gt_bboxes from self.data_infos.
Args:
results (list[dict]): List of results.
metric (str | list[str]): Metrics to be evaluated.
iou_thr (list[float]): AP IoU thresholds.
show (bool): Whether to visualize.
Default: False.
out_dir (str): Path to save the visualization results.
Default: None.
pipeline (list[dict], optional): raw data loading for showing.
Default: None.
Returns:
dict: Evaluation results.
"""
from mmdet3d.core.evaluation import indoor_eval
assert isinstance(
results, list), f'Expect results to be list, got {type(results)}.'
assert len(results) > 0, 'Expect length of results > 0.'
assert len(results) == len(self.data_infos)
assert isinstance(
results[0], dict
), f'Expect elements in results to be dict, got {type(results[0])}.'
# load gt_bboxes via pipeline
pipeline = self._get_pipeline(pipeline)
gt_bboxes = [
self._extract_data(
i, pipeline, ['gt_bboxes_3d', 'gt_labels_3d'], load_annos=True)
for i in range(len(self.data_infos))
]
gt_annos = [self._build_annos(*gt_bbox) for gt_bbox in gt_bboxes]
label2cat = {i: cat_id for i, cat_id in enumerate(self.CLASSES)}
ret_dict = indoor_eval(
gt_annos,
results,
iou_thr,
label2cat,
logger=logger,
box_type_3d=self.box_type_3d,
box_mode_3d=self.box_mode_3d)
if show:
self.show(results, out_dir, pipeline=pipeline)

return ret_dict

@staticmethod
def _build_annos(gt_bboxes, gt_labels):
"""Transform gt bboxes and labels into self.data_infos['annos'] format.
Args:
gt_bboxes (:obj:`BaseInstance3DBoxes`): \
3D bounding boxes in Depth coordinate
gt_labels (torch.Tensor): Labels of boxes.
Returns:
dict: annotations including the following keys
- gt_boxes_upright_depth (np.ndarray): 3D bounding boxes.
- class (np.ndarray): Labels of boxes.
- gt_num (int): Number of boxes.
"""
bbox = gt_bboxes.tensor.numpy()[:, :6].copy() # drop yaw dimension
bbox[..., 2] += bbox[..., 5] / 2 # bottom center to gravity center
anno = {
'gt_boxes_upright_depth': bbox,
'class': gt_labels.numpy(),
'gt_num': gt_labels.shape[0]
}
return anno

def _build_default_pipeline(self):
"""Build the default pipeline for this dataset."""
pipeline = [
Expand All @@ -248,19 +165,11 @@ def _build_default_pipeline(self):
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=True,
with_seg_3d=True),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28,
33, 34, 36, 39)),
dict(
type='GlobalAlignment',
rotation_axis=2,
ignore_index=len(self.CLASSES),
extract_bbox=True),
with_bbox_3d=True,
with_label_3d=True,
with_mask_3d=False,
with_seg_3d=False),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=self.CLASSES,
Expand All @@ -287,10 +196,9 @@ def show(self, results, out_dir, show=True, pipeline=None):
data_info = self.data_infos[i]
pts_path = data_info['pts_path']
file_name = osp.split(pts_path)[-1].split('.')[0]
points, gt_bboxes = self._extract_data(
i, pipeline, ['points', 'gt_bboxes_3d'], load_annos=True)
points = points.numpy()
gt_bboxes = gt_bboxes.tensor.numpy()
points = self._extract_data(
i, pipeline, 'points', load_annos=True).numpy()
gt_bboxes = self.get_ann_info(i)['gt_bboxes_3d'].tensor.numpy()
pred_bboxes = result['boxes_3d'].tensor.numpy()
show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name,
show)
Expand Down
Binary file modified tests/data/scannet/scannet_infos.pkl
Binary file not shown.
Loading

0 comments on commit bebbc2b

Please sign in to comment.