Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] Move ScanNet point alignment from data pre-processing to pipeline #439

Merged
merged 12 commits into from
May 11, 2021
3 changes: 3 additions & 0 deletions configs/_base_/datasets/scannet-3d-18class.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
with_label_3d=True,
with_mask_3d=True,
with_seg_3d=True),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
Expand Down Expand Up @@ -49,6 +50,7 @@
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
Expand Down Expand Up @@ -82,6 +84,7 @@
shift_height=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='GlobalAlignment', rotation_axis=2),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
22 changes: 15 additions & 7 deletions data/scannet/batch_load_scannet_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ def export_one_scan(scan_name,
scan_name + '_vh_clean_2.0.010000.segs.json')
# includes axisAlignment info for the train set scans.
meta_file = osp.join(scannet_dir, scan_name, f'{scan_name}.txt')
mesh_vertices, semantic_labels, instance_labels, instance_bboxes, \
instance2semantic = export(mesh_file, agg_file, seg_file,
meta_file, label_map_file, None, test_mode)
mesh_vertices, semantic_labels, instance_labels, unaligned_bboxes, \
aligned_bboxes, instance2semantic, axis_align_matrix = export(
mesh_file, agg_file, seg_file, meta_file, label_map_file, None,
test_mode)

if not test_mode:
mask = np.logical_not(np.in1d(semantic_labels, DONOTCARE_CLASS_IDS))
Expand All @@ -47,9 +48,12 @@ def export_one_scan(scan_name,
num_instances = len(np.unique(instance_labels))
print(f'Num of instances: {num_instances}')

bbox_mask = np.in1d(instance_bboxes[:, -1], OBJ_CLASS_IDS)
instance_bboxes = instance_bboxes[bbox_mask, :]
print(f'Num of care instances: {instance_bboxes.shape[0]}')
bbox_mask = np.in1d(unaligned_bboxes[:, -1], OBJ_CLASS_IDS)
unaligned_bboxes = unaligned_bboxes[bbox_mask, :]
bbox_mask = np.in1d(aligned_bboxes[:, -1], OBJ_CLASS_IDS)
aligned_bboxes = aligned_bboxes[bbox_mask, :]
assert unaligned_bboxes.shape[0] == aligned_bboxes.shape[0]
print(f'Num of care instances: {unaligned_bboxes.shape[0]}')

if max_num_point is not None:
max_num_point = int(max_num_point)
Expand All @@ -65,7 +69,11 @@ def export_one_scan(scan_name,
if not test_mode:
np.save(f'{output_filename_prefix}_sem_label.npy', semantic_labels)
np.save(f'{output_filename_prefix}_ins_label.npy', instance_labels)
np.save(f'{output_filename_prefix}_bbox.npy', instance_bboxes)
np.save(f'{output_filename_prefix}_unaligned_bbox.npy',
unaligned_bboxes)
np.save(f'{output_filename_prefix}_aligned_bbox.npy', aligned_bboxes)
np.save(f'{output_filename_prefix}_axis_align_matrix.npy',
axis_align_matrix)


def batch_export(max_num_point,
Expand Down
66 changes: 40 additions & 26 deletions data/scannet/load_scannet_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ def read_segmentation(filename):
return seg_to_verts, num_verts


def extract_bbox(mesh_vertices, object_id_to_segs, object_id_to_label_id,
                 instance_ids):
    """Compute one axis-aligned bounding box per annotated instance.

    Args:
        mesh_vertices (np.ndarray): Per-vertex array whose first three
            columns are read as the x, y, z coordinates.
        object_id_to_segs (dict): Mapping keyed by instance id. Only the
            keys are used here; they define which instances get a box.
        object_id_to_label_id (dict): Mapping from instance id to the label
            id stored in the last column of the box.
        instance_ids (np.ndarray): Per-vertex instance id, compared
            element-wise against each key of ``object_id_to_segs``.

    Returns:
        np.ndarray: Array of shape (num_instances, 7). Each row is
            (cx, cy, cz, dx, dy, dz, label_id). Rows of instances that own
            no vertex are left as zeros.

    Raises:
        IndexError: If an instance that owns vertices has an id outside
            ``1..num_instances``.
    """
    # dict keys are already unique, so no de-duplication is needed
    num_instances = len(object_id_to_segs)
    instance_bboxes = np.zeros((num_instances, 7))
    for obj_id in object_id_to_segs:
        label_id = object_id_to_label_id[obj_id]
        obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
        if len(obj_pc) == 0:
            continue
        # Rows are addressed by ``obj_id - 1``, i.e. ids must be
        # 1, 2, 3, ..., NUM_INSTANCES. Fail loudly instead of letting a
        # non-positive id wrap around and overwrite another instance's row.
        if not 1 <= obj_id <= num_instances:
            raise IndexError(
                f'object id {obj_id} is out of range [1, {num_instances}]')
        lower = obj_pc.min(axis=0)
        upper = obj_pc.max(axis=0)
        instance_bboxes[obj_id - 1, 0:3] = (lower + upper) / 2
        instance_bboxes[obj_id - 1, 3:6] = upper - lower
        instance_bboxes[obj_id - 1, 6] = label_id
    return instance_bboxes


def export(mesh_file,
agg_file,
seg_file,
Expand All @@ -69,7 +92,7 @@ def export(mesh_file,
label_map_file (str): Path of the label_map_file.
output_file (str): Path of the output folder.
Default: None.
test_mode (bool): Whether is generating training data without labels.
test_mode (bool): Whether to generate test data without labels.
Default: False.

It returns a tuple, which contains the following things:
Expand All @@ -86,8 +109,7 @@ def export(mesh_file,

# Load scene axis alignment matrix
lines = open(meta_file).readlines()
# TODO: test set data doesn't have align_matrix!
# TODO: save align_matrix and move align step to pipeline in the future
# test set data doesn't have align_matrix
axis_align_matrix = np.eye(4)
for line in lines:
if 'axisAlignment' in line:
Expand All @@ -97,10 +119,13 @@ def export(mesh_file,
]
break
axis_align_matrix = np.array(axis_align_matrix).reshape((4, 4))

# perform global alignment of mesh vertices
pts = np.ones((mesh_vertices.shape[0], 4))
pts[:, 0:3] = mesh_vertices[:, 0:3]
pts = np.dot(pts, axis_align_matrix.transpose()) # Nx4
mesh_vertices[:, 0:3] = pts[:, 0:3]
aligned_mesh_vertices = np.concatenate([pts[:, 0:3], mesh_vertices[:, 3:]],
axis=1)

# Load semantic and instance labels
if not test_mode:
Expand All @@ -115,45 +140,34 @@ def export(mesh_file,
label_ids[verts] = label_id
instance_ids = np.zeros(
shape=(num_verts), dtype=np.uint32) # 0: unannotated
num_instances = len(np.unique(list(object_id_to_segs.keys())))
for object_id, segs in object_id_to_segs.items():
for seg in segs:
verts = seg_to_verts[seg]
instance_ids[verts] = object_id
if object_id not in object_id_to_label_id:
object_id_to_label_id[object_id] = label_ids[verts][0]
instance_bboxes = np.zeros((num_instances, 7))
for obj_id in object_id_to_segs:
label_id = object_id_to_label_id[obj_id]
obj_pc = mesh_vertices[instance_ids == obj_id, 0:3]
if len(obj_pc) == 0:
continue
xmin = np.min(obj_pc[:, 0])
ymin = np.min(obj_pc[:, 1])
zmin = np.min(obj_pc[:, 2])
xmax = np.max(obj_pc[:, 0])
ymax = np.max(obj_pc[:, 1])
zmax = np.max(obj_pc[:, 2])
bbox = np.array([(xmin + xmax) / 2, (ymin + ymax) / 2,
(zmin + zmax) / 2, xmax - xmin, ymax - ymin,
zmax - zmin, label_id])
# NOTE: this assumes obj_id is in 1, 2, 3, ..., NUM_INSTANCES
instance_bboxes[obj_id - 1, :] = bbox
unaligned_bboxes = extract_bbox(mesh_vertices, object_id_to_segs,
object_id_to_label_id, instance_ids)
aligned_bboxes = extract_bbox(aligned_mesh_vertices, object_id_to_segs,
object_id_to_label_id, instance_ids)
else:
label_ids = None
instance_ids = None
instance_bboxes = None
unaligned_bboxes = None
aligned_bboxes = None
object_id_to_label_id = None

if output_file is not None:
np.save(output_file + '_vert.npy', mesh_vertices)
if not test_mode:
np.save(output_file + '_sem_label.npy', label_ids)
np.save(output_file + '_ins_label.npy', instance_ids)
np.save(output_file + '_bbox.npy', instance_bboxes)
np.save(output_file + '_unaligned_bbox.npy', unaligned_bboxes)
np.save(output_file + '_aligned_bbox.npy', aligned_bboxes)
np.save(output_file + '_axis_align_matrix.npy', axis_align_matrix)

return mesh_vertices, label_ids, instance_ids, \
instance_bboxes, object_id_to_label_id
return mesh_vertices, label_ids, instance_ids, unaligned_bboxes, \
aligned_bboxes, object_id_to_label_id, axis_align_matrix


def main():
Expand Down
13 changes: 8 additions & 5 deletions mmdet3d/core/bbox/structures/base_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,15 @@ def corners(self):
pass

@abstractmethod
def rotate(self, angles, axis=0):
"""Calculate whether the points are in any of the boxes.
def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.

Args:
angles (float): Rotation angles.
axis (int): The axis to rotate the boxes.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.
"""
pass

Expand All @@ -144,7 +147,7 @@ def flip(self, bev_direction='horizontal'):
pass

def translate(self, trans_vector):
"""Calculate whether the points are in any of the boxes.
"""Translate boxes with the given translation vector.

Args:
trans_vector (torch.Tensor): Translation vector of size 1x3.
Expand Down
24 changes: 18 additions & 6 deletions mmdet3d/core/bbox/structures/cam_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,12 @@ def nearest_bev(self):
return bev_boxes

def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.

Args:
angle (float, torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.

Expand All @@ -183,10 +185,20 @@ def rotate(self, angle, points=None):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, 0, -rot_sin], [0, 1, 0],
[rot_sin, 0, rot_cos]])
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'

if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, 0, -rot_sin],
[0, 1, 0],
[rot_sin, 0, rot_cos]])
else:
rot_mat_T = angle
rot_sin = rot_mat_T[2, 0]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)

self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
self.tensor[:, 6] += angle
Expand Down
26 changes: 19 additions & 7 deletions mmdet3d/core/bbox/structures/depth_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,12 @@ def nearest_bev(self):
return bev_boxes

def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.

Args:
angle (float, torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.

Expand All @@ -130,11 +132,21 @@ def rotate(self, angle, points=None):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0], [0, 0,
1]]).T
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'

if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0],
[0, 0, 1]]).T
else:
rot_mat_T = angle.T
Wuziyi616 marked this conversation as resolved.
Show resolved Hide resolved
rot_sin = rot_mat_T[0, 1]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)

self.tensor[:, 0:3] = self.tensor[:, 0:3] @ rot_mat_T
if self.with_yaw:
self.tensor[:, 6] -= angle
Expand Down
24 changes: 18 additions & 6 deletions mmdet3d/core/bbox/structures/lidar_box3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,12 @@ def nearest_bev(self):
return bev_boxes

def rotate(self, angle, points=None):
"""Rotate boxes with points (optional) with the given angle.
"""Rotate boxes with points (optional) with the given angle or \
rotation matrix.

Args:
angle (float | torch.Tensor): Rotation angle.
angle (float | torch.Tensor | np.ndarray):
Rotation angle or rotation matrix.
points (torch.Tensor, numpy.ndarray, :obj:`BasePoints`, optional):
Points to rotate. Defaults to None.

Expand All @@ -128,10 +130,20 @@ def rotate(self, angle, points=None):
"""
if not isinstance(angle, torch.Tensor):
angle = self.tensor.new_tensor(angle)
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0], [0, 0, 1]])
assert angle.shape == torch.Size([3, 3]) or angle.numel() == 1, \
f'invalid rotation angle shape {angle.shape}'

if angle.numel() == 1:
rot_sin = torch.sin(angle)
rot_cos = torch.cos(angle)
rot_mat_T = self.tensor.new_tensor([[rot_cos, -rot_sin, 0],
[rot_sin, rot_cos, 0],
[0, 0, 1]])
else:
rot_mat_T = angle
rot_sin = rot_mat_T[1, 0]
rot_cos = rot_mat_T[0, 0]
angle = np.arctan2(rot_sin, rot_cos)

self.tensor[:, :3] = self.tensor[:, :3] @ rot_mat_T
self.tensor[:, 6] += angle
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/core/points/base_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def rotate(self, rotation, axis=None):
if not isinstance(rotation, torch.Tensor):
rotation = self.tensor.new_tensor(rotation)
assert rotation.shape == torch.Size([3, 3]) or \
rotation.numel() == 1
rotation.numel() == 1, f'invalid rotation shape {rotation.shape}'

if axis is None:
axis = self.rotation_axis
Expand Down
15 changes: 8 additions & 7 deletions mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
from .nuscenes_mono_dataset import NuScenesMonoDataset
from .pipelines import (BackgroundPointsFilter, GlobalRotScaleTrans,
from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNoise, ObjectRangeFilter,
Expand All @@ -26,10 +27,10 @@
'DATASETS', 'build_dataset', 'CocoDataset', 'NuScenesDataset',
'NuScenesMonoDataset', 'LyftDataset', 'ObjectSample', 'RandomFlip3D',
'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter',
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile',
'NormalizePointsColor', 'IndoorPointSample', 'LoadAnnotations3D',
'SUNRGBDDataset', 'ScanNetDataset', 'ScanNetSegDataset', 'S3DISSegDataset',
'SemanticKITTIDataset', 'Custom3DDataset', 'Custom3DSegDataset',
'LoadPointsFromMultiSweeps', 'WaymoDataset', 'BackgroundPointsFilter',
'VoxelBasedPointSampler', 'get_loading_pipeline'
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline'
]
12 changes: 6 additions & 6 deletions mmdet3d/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
LoadPointsFromMultiSweeps, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
from .transforms_3d import (BackgroundPointsFilter, GlobalRotScaleTrans,
IndoorPatchPointSample, IndoorPointSample,
ObjectNoise, ObjectRangeFilter, ObjectSample,
PointShuffle, PointsRangeFilter, RandomFlip3D,
VoxelBasedPointSampler)
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNoise, ObjectRangeFilter,
ObjectSample, PointShuffle, PointsRangeFilter,
RandomFlip3D, VoxelBasedPointSampler)

__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
Expand All @@ -19,6 +19,6 @@
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D'
]
Loading