diff --git a/configs/_base_/models/imvotenet_image.py b/configs/_base_/models/imvotenet_image.py new file mode 100644 index 0000000000..8ddfa8c877 --- /dev/null +++ b/configs/_base_/models/imvotenet_image.py @@ -0,0 +1,108 @@ +model = dict( + type='ImVoteNet', + img_backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe'), + img_neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + img_rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + scales=[8], + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0)), + img_roi_head=dict( + type='StandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32]), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=10, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 0.1, 0.2, 0.2]), + reg_class_agnostic=False, + loss_cls=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=1.0))), + + # model training and testing settings + train_cfg=dict( + img_rpn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.7, + neg_iou_thr=0.3, + min_pos_iou=0.3, + match_low_quality=True, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=256, + pos_fraction=0.5, + neg_pos_ub=-1, + add_gt_as_proposals=False), + allowed_border=-1, + pos_weight=-1, + debug=False), + img_rpn_proposal=dict( + nms_across_levels=False, + nms_pre=2000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + img_rcnn=dict( + assigner=dict( + type='MaxIoUAssigner', + pos_iou_thr=0.5, + neg_iou_thr=0.5, + min_pos_iou=0.5, + match_low_quality=False, + ignore_iof_thr=-1), + sampler=dict( + type='RandomSampler', + num=512, + pos_fraction=0.25, + neg_pos_ub=-1, + add_gt_as_proposals=True), + pos_weight=-1, + debug=False)), + test_cfg=dict( + img_rpn=dict( + nms_across_levels=False, + nms_pre=1000, + nms_post=1000, + max_num=1000, + nms_thr=0.7, + min_bbox_size=0), + img_rcnn=dict( + score_thr=0.05, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100))) diff --git a/configs/imvotenet/imvotenet_faster_rcnn_r50_fpn_2x4_sunrgbd-3d-10class.py b/configs/imvotenet/imvotenet_faster_rcnn_r50_fpn_2x4_sunrgbd-3d-10class.py new file mode 100644 index 0000000000..6743b1f86f --- /dev/null +++ b/configs/imvotenet/imvotenet_faster_rcnn_r50_fpn_2x4_sunrgbd-3d-10class.py @@ -0,0 +1,58 @@ +_base_ = [ + '../_base_/datasets/sunrgbd-3d-10class.py', '../_base_/default_runtime.py', + '../_base_/models/imvotenet_image.py' +] + +# use caffe img_norm +img_norm_cfg = dict( + mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='Resize', + img_scale=[(1333, 480), (1333, 504), (1333, 528), (1333, 552), + (1333, 576), (1333, 600)], + 
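+        # multi-scale training: with multiscale_mode='value' (set just
+        # below) one of the six (long side, short side) pairs above is
+        # picked at random per image; keep_ratio=True then rescales the
+        # image without distorting its aspect ratio.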
multiscale_mode='value', + keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.5), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(1333, 600), + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict(times=1, dataset=dict(pipeline=train_pipeline)), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) + +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001) +optimizer_config = dict(grad_clip=None) +lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=0.001, + step=[6]) +total_epochs = 8 + +load_from = 'http://download.openmmlab.com/mmdetection/v2.0/mask_rcnn/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco/mask_rcnn_r50_caffe_fpn_mstrain-poly_3x_coco_bbox_mAP-0.408__segm_mAP-0.37_20200504_163245-42aa3d00.pth' # noqa diff --git a/configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py b/configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py new file mode 100644 index 0000000000..409e09a56b --- /dev/null +++ b/configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py @@ -0,0 +1,242 @@ +_base_ = [ + '../_base_/datasets/sunrgbd-3d-10class.py', + '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py', + '../_base_/models/imvotenet_image.py' +] + +class_names = ('bed', 'table', 'sofa', 'chair', 'toilet', 'desk', 'dresser', + 'night_stand', 'bookshelf', 'bathtub') + +# use caffe img_norm +img_norm_cfg = dict( + mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) + +model = dict( + pts_backbone=dict( + type='PointNet2SASSG', + in_channels=4, + num_points=(2048, 1024, 512, 256), + radius=(0.2, 0.4, 0.8, 1.2), + num_samples=(64, 32, 16, 16), + sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), + (128, 128, 256)), + fp_channels=((256, 256), (256, 256)), + norm_cfg=dict(type='BN2d'), + sa_cfg=dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True)), + pts_bbox_heads=dict( + common=dict( + type='VoteHead', + num_classes=10, + bbox_coder=dict( + type='PartialBinBasedBBoxCoder', + num_sizes=10, + num_dir_bins=12, + with_rot=True, + mean_sizes=[[2.114256, 1.620300, 0.927272], + [0.791118, 1.279516, 0.718182], + [0.923508, 1.867419, 0.845495], + [0.591958, 0.552978, 0.827272], + [0.699104, 0.454178, 0.75625], + [0.69519, 1.346299, 0.736364], + [0.528526, 1.002642, 1.172878], + [0.500618, 0.632163, 0.683424], + [0.404671, 1.071108, 1.688889], + [0.76584, 1.398258, 0.472728]]), + pred_layer_cfg=dict( + in_channels=128, shared_conv_channels=(128, 128), bias=True), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.2, 0.8], + reduction='sum', + loss_weight=5.0), + center_loss=dict( + type='ChamferDistance', + mode='l2', + reduction='sum', + loss_src_weight=10.0, + loss_dst_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + 
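+                # heading and size are each predicted as a class bin plus a
+                # residual (PartialBinBasedBBoxCoder), hence the paired
+                # CrossEntropy (bin) and SmoothL1 (residual) terms.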
size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0 / 3.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + joint=dict( + vote_module_cfg=dict( + in_channels=512, + vote_per_seed=1, + gt_per_seed=3, + conv_channels=(512, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + type='PointSAModule', + num_point=256, + radius=0.3, + num_sample=16, + mlp_channels=[512, 128, 128, 128], + use_xyz=True, + normalize_xyz=True)), + pts=dict( + vote_module_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=3, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + type='PointSAModule', + num_point=256, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True)), + img=dict( + vote_module_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=3, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + type='PointSAModule', + num_point=256, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True)), + loss_weights=[0.4, 0.3, 0.3]), + img_mlp=dict( + in_channel=18, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU')), + fusion_layer=dict( + type='VoteFusion', + num_classes=len(class_names), + max_imvote_per_pixel=3), + num_sampled_seed=1024, + freeze_img_branch=True, + + # model training and testing settings + train_cfg=dict( + pts=dict( + pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote')), + test_cfg=dict( + img_rcnn=dict(score_thr=0.1), + pts=dict( + sample_mod='seed', + nms_thr=0.25, + score_thr=0.05, + per_class_proposal=True))) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=True, + load_dim=6, + use_dim=[0, 1, 2]), + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations3D'), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='Resize', img_scale=(1333, 600), keep_ratio=True), + dict(type='RandomFlip', flip_ratio=0.0), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + ), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.523599, 0.523599], + scale_ratio_range=[0.85, 1.15], + shift_height=True), + dict(type='IndoorPointSample', num_points=20000), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'img', 'gt_bboxes', 'gt_labels', 'points', 'gt_bboxes_3d', + 'gt_labels_3d', 'calib' + ]) +] + +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + shift_height=True, + load_dim=6, + use_dim=[0, 1, 2]), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 600), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + 
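+            # the image branch is kept unflipped at test time (flip=False
+            # above and flip_ratio=0.0 below); any point-cloud flipping is
+            # configured separately by RandomFlip3D further down this list.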
dict(type='RandomFlip', flip_ratio=0.0), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + ), + dict(type='IndoorPointSample', num_points=20000), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['img', 'points', 'calib']) + ]), +] + +data = dict( + train=dict(dataset=dict(pipeline=train_pipeline)), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) + +load_from = None # TODO after we update model zoo diff --git a/mmdet3d/core/bbox/structures/coord_3d_mode.py b/mmdet3d/core/bbox/structures/coord_3d_mode.py index 06d00e8206..c107a75da6 100644 --- a/mmdet3d/core/bbox/structures/coord_3d_mode.py +++ b/mmdet3d/core/bbox/structures/coord_3d_mode.py @@ -182,7 +182,7 @@ def convert_point(point, src, dst, rt_mat=None): """Convert points from `src` mode to `dst` mode. Args: - box (tuple | list | np.dnarray | + point (tuple | list | np.dnarray | torch.Tensor | BasePoints): Can be a k-tuple, k-list or an Nxk array/tensor. src (:obj:`CoordMode`): The src Point mode. @@ -218,17 +218,25 @@ def convert_point(point, src, dst, rt_mat=None): arr = point.clone() # convert point from `src` mode to `dst` mode. - if rt_mat is not None: - if not isinstance(rt_mat, torch.Tensor): - rt_mat = arr.new_tensor(rt_mat) + # TODO: LIDAR + # only implemented provided Rt matrix in cam-depth conversion if src == Coord3DMode.LIDAR and dst == Coord3DMode.CAM: rt_mat = arr.new_tensor([[0, -1, 0], [0, 0, -1], [1, 0, 0]]) elif src == Coord3DMode.CAM and dst == Coord3DMode.LIDAR: rt_mat = arr.new_tensor([[0, 0, 1], [-1, 0, 0], [0, -1, 0]]) elif src == Coord3DMode.DEPTH and dst == Coord3DMode.CAM: - rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]]) + if rt_mat is None: + rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + else: + rt_mat = rt_mat.new_tensor( + [[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ \ + rt_mat.transpose(1, 0) elif src == Coord3DMode.CAM and dst == Coord3DMode.DEPTH: - rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) + if rt_mat is None: + rt_mat = arr.new_tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]]) + else: + rt_mat = rt_mat @ rt_mat.new_tensor([[1, 0, 0], [0, 0, 1], + [0, -1, 0]]) elif src == Coord3DMode.LIDAR and dst == Coord3DMode.DEPTH: rt_mat = arr.new_tensor([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) elif src == Coord3DMode.DEPTH and dst == Coord3DMode.LIDAR: @@ -245,7 +253,7 @@ def convert_point(point, src, dst, rt_mat=None): else: xyz = arr[:, :3] @ rt_mat.t() - remains = arr[..., 3:] + remains = arr[:, 3:] arr = torch.cat([xyz[:, :3], remains], dim=-1) # convert arr to the original type diff --git a/mmdet3d/core/bbox/structures/utils.py b/mmdet3d/core/bbox/structures/utils.py index 34b69bf13d..f1ac47fe36 100644 --- a/mmdet3d/core/bbox/structures/utils.py +++ b/mmdet3d/core/bbox/structures/utils.py @@ -122,7 +122,20 @@ def points_cam2img(points_3d, proj_mat): torch.Tensor: Points in image coordinates with shape [N, 2]. """ points_num = list(points_3d.shape)[:-1] + points_shape = np.concatenate([points_num, [1]], axis=0).tolist() + assert len(proj_mat.shape) == 2, f'The dimension of the projection'\ + f'matrix should be 2 instead of {len(proj_mat.shape)}.' 
+ d1, d2 = proj_mat.shape[:2] + assert (d1 == 3 and d2 == 3) or (d1 == 3 and d2 == 4) or ( + d1 == 4 and d2 == 4), f'The shape of the projection matrix'\ + f' ({d1}*{d2}) is not supported.' + if d1 == 3: + proj_mat_expanded = torch.eye( + 4, device=proj_mat.device, dtype=proj_mat.dtype) + proj_mat_expanded[:d1, :d2] = proj_mat + proj_mat = proj_mat_expanded + # previous implementation use new_zeros, new_one yeilds better results points_4 = torch.cat( [points_3d, points_3d.new_ones(*points_shape)], dim=-1) diff --git a/mmdet3d/core/points/base_points.py b/mmdet3d/core/points/base_points.py index ac4a816ec3..52a0f74f06 100644 --- a/mmdet3d/core/points/base_points.py +++ b/mmdet3d/core/points/base_points.py @@ -10,14 +10,14 @@ class BasePoints(object): tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix. points_dim (int): Number of the dimension of a point. Each row is (x, y, z). Default to 3. - attribute_dims (dict): Dictinory to indicate the meaning of extra + attribute_dims (dict): Dictionary to indicate the meaning of extra dimension. Default to None. Attributes: tensor (torch.Tensor): Float matrix of N x points_dim. points_dim (int): Integer indicating the dimension of a point. Each row is (x, y, z, ...). - attribute_dims (bool): Dictinory to indicate the meaning of extra + attribute_dims (bool): Dictionary to indicate the meaning of extra dimension. Default to None. rotation_axis (int): Default rotation axis for points rotation. """ diff --git a/mmdet3d/core/points/cam_points.py b/mmdet3d/core/points/cam_points.py index d2ce420888..185680158b 100644 --- a/mmdet3d/core/points/cam_points.py +++ b/mmdet3d/core/points/cam_points.py @@ -8,14 +8,14 @@ class CameraPoints(BasePoints): tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix. points_dim (int): Number of the dimension of a point. Each row is (x, y, z). Default to 3. - attribute_dims (dict): Dictinory to indicate the meaning of extra + attribute_dims (dict): Dictionary to indicate the meaning of extra dimension. Default to None. Attributes: tensor (torch.Tensor): Float matrix of N x points_dim. points_dim (int): Integer indicating the dimension of a point. Each row is (x, y, z, ...). - attribute_dims (bool): Dictinory to indicate the meaning of extra + attribute_dims (bool): Dictionary to indicate the meaning of extra dimension. Default to None. rotation_axis (int): Default rotation axis for points rotation. """ diff --git a/mmdet3d/core/points/depth_points.py b/mmdet3d/core/points/depth_points.py index fe99ca838b..3d194a1e03 100644 --- a/mmdet3d/core/points/depth_points.py +++ b/mmdet3d/core/points/depth_points.py @@ -8,14 +8,14 @@ class DepthPoints(BasePoints): tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix. points_dim (int): Number of the dimension of a point. Each row is (x, y, z). Default to 3. - attribute_dims (dict): Dictinory to indicate the meaning of extra + attribute_dims (dict): Dictionary to indicate the meaning of extra dimension. Default to None. Attributes: tensor (torch.Tensor): Float matrix of N x points_dim. points_dim (int): Integer indicating the dimension of a point. Each row is (x, y, z, ...). - attribute_dims (bool): Dictinory to indicate the meaning of extra + attribute_dims (bool): Dictionary to indicate the meaning of extra dimension. Default to None. rotation_axis (int): Default rotation axis for points rotation. 
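+
+    Example:
+        >>> import numpy as np
+        >>> # a hypothetical extra "height" channel stored in column 3
+        >>> points = DepthPoints(
+        ...     np.random.rand(5, 4), points_dim=4,
+        ...     attribute_dims=dict(height=3))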
""" diff --git a/mmdet3d/core/points/lidar_points.py b/mmdet3d/core/points/lidar_points.py index 26132d34bf..f17323a3af 100644 --- a/mmdet3d/core/points/lidar_points.py +++ b/mmdet3d/core/points/lidar_points.py @@ -8,14 +8,14 @@ class LiDARPoints(BasePoints): tensor (torch.Tensor | np.ndarray | list): a N x points_dim matrix. points_dim (int): Number of the dimension of a point. Each row is (x, y, z). Default to 3. - attribute_dims (dict): Dictinory to indicate the meaning of extra + attribute_dims (dict): Dictionary to indicate the meaning of extra dimension. Default to None. Attributes: tensor (torch.Tensor): Float matrix of N x points_dim. points_dim (int): Integer indicating the dimension of a point. Each row is (x, y, z, ...). - attribute_dims (bool): Dictinory to indicate the meaning of extra + attribute_dims (bool): Dictionary to indicate the meaning of extra dimension. Default to None. rotation_axis (int): Default rotation axis for points rotation. """ diff --git a/mmdet3d/datasets/pipelines/formating.py b/mmdet3d/datasets/pipelines/formating.py index 4992e18175..f5b2def535 100644 --- a/mmdet3d/datasets/pipelines/formating.py +++ b/mmdet3d/datasets/pipelines/formating.py @@ -137,8 +137,8 @@ def __init__(self, 'pcd_horizontal_flip', 'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', 'img_norm_cfg', 'rect', 'Trv2c', 'P2', 'pcd_trans', 'sample_idx', - 'pcd_scale_factor', 'pcd_rotation', - 'pts_filename')): + 'pcd_scale_factor', 'pcd_rotation', 'pts_filename', + 'transformation_3d_flow')): self.keys = keys self.meta_keys = meta_keys diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index abdc5d867e..972615e35b 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -96,10 +96,15 @@ def __call__(self, input_dict): ) < self.flip_ratio_bev_vertical else False input_dict['pcd_vertical_flip'] = flip_vertical + if 'transformation_3d_flow' not in input_dict: + input_dict['transformation_3d_flow'] = [] + if input_dict['pcd_horizontal_flip']: self.random_flip_data_3d(input_dict, 'horizontal') + input_dict['transformation_3d_flow'].extend(['HF']) if input_dict['pcd_vertical_flip']: self.random_flip_data_3d(input_dict, 'vertical') + input_dict['transformation_3d_flow'].extend(['VF']) return input_dict def __repr__(self): @@ -405,6 +410,9 @@ def __call__(self, input_dict): 'pcd_scale_factor', 'pcd_trans' and keys in \ input_dict['bbox3d_fields'] are updated in the result dict. 
""" + if 'transformation_3d_flow' not in input_dict: + input_dict['transformation_3d_flow'] = [] + self._rot_bbox_points(input_dict) if 'pcd_scale_factor' not in input_dict: @@ -412,6 +420,8 @@ def __call__(self, input_dict): self._scale_bbox_points(input_dict) self._trans_bbox_points(input_dict) + + input_dict['transformation_3d_flow'].extend(['R', 'S', 'T']) return input_dict def __repr__(self): diff --git a/mmdet3d/datasets/sunrgbd_dataset.py b/mmdet3d/datasets/sunrgbd_dataset.py index f21fdeaa29..fd099faa6e 100644 --- a/mmdet3d/datasets/sunrgbd_dataset.py +++ b/mmdet3d/datasets/sunrgbd_dataset.py @@ -1,8 +1,10 @@ import numpy as np +from collections import OrderedDict from os import path as osp from mmdet3d.core import show_result from mmdet3d.core.bbox import DepthInstance3DBoxes +from mmdet.core import eval_map from mmdet.datasets import DATASETS from .custom_3d import Custom3DDataset @@ -59,6 +61,52 @@ def __init__(self, box_type_3d=box_type_3d, filter_empty_gt=filter_empty_gt, test_mode=test_mode) + if self.modality is None: + self.modality = dict(use_camera=True, use_lidar=True) + assert self.modality['use_camera'] or self.modality['use_lidar'] + + def get_data_info(self, index): + """Get data info according to the given index. + + Args: + index (int): Index of the sample data to get. + + Returns: + dict: Data information that will be passed to the data \ + preprocessing pipelines. It includes the following keys: + + - sample_idx (str): Sample index. + - pts_filename (str, optional): Filename of point clouds. + - file_name (str, optional): Filename of point clouds. + - img_prefix (str | None, optional): Prefix of image files. + - img_info (dict, optional): Image info. + - calib (dict, optional): Camera calibration info. + - ann_info (dict): Annotation info. + """ + info = self.data_infos[index] + sample_idx = info['point_cloud']['lidar_idx'] + assert info['point_cloud']['lidar_idx'] == info['image']['image_idx'] + input_dict = dict(sample_idx=sample_idx) + + if self.modality['use_lidar']: + pts_filename = osp.join(self.data_root, info['pts_path']) + input_dict['pts_filename'] = pts_filename + input_dict['file_name'] = pts_filename + + if self.modality['use_camera']: + img_filename = osp.join(self.data_root, + info['image']['image_path']) + input_dict['img_prefix'] = None + input_dict['img_info'] = dict(filename=img_filename) + calib = info['calib'] + input_dict['calib'] = calib + + if not self.test_mode: + annos = self.get_ann_info(index) + input_dict['ann_info'] = annos + if self.filter_empty_gt and len(annos['gt_bboxes_3d']) == 0: + return None + return input_dict def get_ann_info(self, index): """Get annotation info according to the given index. 
@@ -91,6 +139,15 @@ def get_ann_info(self, index): anns_results = dict( gt_bboxes_3d=gt_bboxes_3d, gt_labels_3d=gt_labels_3d) + + if self.modality['use_camera']: + if info['annos']['gt_num'] != 0: + gt_bboxes_2d = info['annos']['bbox'].astype(np.float32) + else: + gt_bboxes_2d = np.zeros((0, 4), dtype=np.float32) + anns_results['bboxes'] = gt_bboxes_2d + anns_results['labels'] = gt_labels_3d + return anns_results def show(self, results, out_dir, show=True): @@ -114,3 +171,33 @@ def show(self, results, out_dir, show=True): pred_bboxes = result['boxes_3d'].tensor.numpy() show_result(points, gt_bboxes, pred_bboxes, out_dir, file_name, show) + + def evaluate(self, + results, + metric=None, + iou_thr=(0.25, 0.5), + iou_thr_2d=(0.5, ), + logger=None, + show=False, + out_dir=None): + + # evaluate 3D detection performance + if isinstance(results[0], dict): + return super().evaluate(results, metric, iou_thr, logger, show, + out_dir) + # evaluate 2D detection performance + else: + eval_results = OrderedDict() + annotations = [self.get_ann_info(i) for i in range(len(self))] + iou_thr_2d = (iou_thr_2d) if isinstance(iou_thr_2d, + float) else iou_thr_2d + for iou_thr_2d_single in iou_thr_2d: + mean_ap, _ = eval_map( + results, + annotations, + scale_ranges=None, + iou_thr=iou_thr_2d_single, + dataset=self.CLASSES, + logger=logger) + eval_results['mAP_' + str(iou_thr_2d_single)] = mean_ap + return eval_results diff --git a/mmdet3d/models/dense_heads/vote_head.py b/mmdet3d/models/dense_heads/vote_head.py index 2edab4caef..7fd54cfdcd 100644 --- a/mmdet3d/models/dense_heads/vote_head.py +++ b/mmdet3d/models/dense_heads/vote_head.py @@ -119,9 +119,9 @@ def _extract_input(self, feat_dict): torch.Tensor: Features of input points. torch.Tensor: Indices of input points. """ - seed_points = feat_dict['fp_xyz'][-1] - seed_features = feat_dict['fp_features'][-1] - seed_indices = feat_dict['fp_indices'][-1] + seed_points = feat_dict['seed_points'] + seed_features = feat_dict['seed_features'] + seed_indices = feat_dict['seed_indices'] return seed_points, seed_features, seed_indices diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index 1ee43a9a20..827c15e373 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -2,6 +2,7 @@ from .centerpoint import CenterPoint from .dynamic_voxelnet import DynamicVoxelNet from .h3dnet import H3DNet +from .imvotenet import ImVoteNet from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN from .mvx_two_stage import MVXTwoStageDetector from .parta2 import PartA2 @@ -10,7 +11,16 @@ from .voxelnet import VoxelNet __all__ = [ - 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector', - 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet', - 'CenterPoint', 'SSD3DNet' + 'Base3DDetector', + 'VoxelNet', + 'DynamicVoxelNet', + 'MVXTwoStageDetector', + 'DynamicMVXFasterRCNN', + 'MVXFasterRCNN', + 'PartA2', + 'VoteNet', + 'H3DNet', + 'CenterPoint', + 'SSD3DNet', + 'ImVoteNet', ] diff --git a/mmdet3d/models/detectors/imvotenet.py b/mmdet3d/models/detectors/imvotenet.py new file mode 100644 index 0000000000..9d2c4e4175 --- /dev/null +++ b/mmdet3d/models/detectors/imvotenet.py @@ -0,0 +1,839 @@ +import numpy as np +import torch +from torch import nn as nn + +from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d +from mmdet3d.models.utils import MLP +from mmdet.models import DETECTORS +from .. 
import builder +from .base import Base3DDetector + + +def sample_valid_seeds(mask, num_sampled_seed=1024): + """Randomly sample seeds from all imvotes. + + Args: + mask (torch.Tensor): Bool tensor in shape ( + seed_num*max_imvote_per_pixel), indicates + whether this imvote corresponds to a 2D bbox. + num_sampled_seed (int): How many to sample from all imvotes. + + Returns: + torch.Tensor: Indices with shape (num_sampled_seed). + """ + device = mask.device + batch_size = mask.shape[0] + sample_inds = mask.new_zeros((batch_size, num_sampled_seed), + dtype=torch.int64) + for bidx in range(batch_size): + # return index of non zero elements + valid_inds = torch.nonzero(mask[bidx, :]).squeeze(-1) + if len(valid_inds) < num_sampled_seed: + # compute set t1 - t2 + t1 = torch.arange(num_sampled_seed, device=device) + t2 = valid_inds % num_sampled_seed + combined = torch.cat((t1, t2)) + uniques, counts = combined.unique(return_counts=True) + difference = uniques[counts == 1] + + rand_inds = torch.randperm( + len(difference), + device=device)[:num_sampled_seed - len(valid_inds)] + cur_sample_inds = difference[rand_inds] + cur_sample_inds = torch.cat((valid_inds, cur_sample_inds)) + else: + rand_inds = torch.randperm( + len(valid_inds), device=device)[:num_sampled_seed] + cur_sample_inds = valid_inds[rand_inds] + sample_inds[bidx, :] = cur_sample_inds + return sample_inds + + +@DETECTORS.register_module() +class ImVoteNet(Base3DDetector): + r"""`ImVoteNet `_ for 3D detection.""" + + def __init__(self, + pts_backbone=None, + pts_bbox_heads=None, + pts_neck=None, + img_backbone=None, + img_neck=None, + img_roi_head=None, + img_rpn_head=None, + img_mlp=None, + freeze_img_branch=False, + fusion_layer=None, + num_sampled_seed=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + + super(ImVoteNet, self).__init__() + + # point branch + if pts_backbone is not None: + self.pts_backbone = builder.build_backbone(pts_backbone) + if pts_neck is not None: + self.pts_neck = builder.build_neck(pts_neck) + if pts_bbox_heads is not None: + pts_bbox_head_common = pts_bbox_heads.common + pts_bbox_head_common.update( + train_cfg=train_cfg.pts if train_cfg is not None else None) + pts_bbox_head_common.update(test_cfg=test_cfg.pts) + pts_bbox_head_joint = pts_bbox_head_common.copy() + pts_bbox_head_joint.update(pts_bbox_heads.joint) + pts_bbox_head_pts = pts_bbox_head_common.copy() + pts_bbox_head_pts.update(pts_bbox_heads.pts) + pts_bbox_head_img = pts_bbox_head_common.copy() + pts_bbox_head_img.update(pts_bbox_heads.img) + + self.pts_bbox_head_joint = builder.build_head(pts_bbox_head_joint) + self.pts_bbox_head_pts = builder.build_head(pts_bbox_head_pts) + self.pts_bbox_head_img = builder.build_head(pts_bbox_head_img) + self.pts_bbox_heads = [ + self.pts_bbox_head_joint, self.pts_bbox_head_pts, + self.pts_bbox_head_img + ] + self.loss_weights = pts_bbox_heads.loss_weights + + # image branch + if img_backbone: + self.img_backbone = builder.build_backbone(img_backbone) + if img_neck is not None: + self.img_neck = builder.build_neck(img_neck) + if img_rpn_head is not None: + rpn_train_cfg = train_cfg.img_rpn if train_cfg \ + is not None else None + img_rpn_head_ = img_rpn_head.copy() + img_rpn_head_.update( + train_cfg=rpn_train_cfg, test_cfg=test_cfg.img_rpn) + self.img_rpn_head = builder.build_head(img_rpn_head_) + if img_roi_head is not None: + rcnn_train_cfg = train_cfg.img_rcnn if train_cfg \ + is not None else None + img_roi_head.update( + train_cfg=rcnn_train_cfg, test_cfg=test_cfg.img_rcnn) + 
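+            # the 2D RoI head is a standard mmdet head; its train/test
+            # cfgs are injected above so it can be built with the generic
+            # mmdet builder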
self.img_roi_head = builder.build_head(img_roi_head) + + # fusion + if fusion_layer is not None: + self.fusion_layer = builder.build_fusion_layer(fusion_layer) + self.max_imvote_per_pixel = fusion_layer.max_imvote_per_pixel + + self.freeze_img_branch = freeze_img_branch + if freeze_img_branch: + self.freeze_img_branch_params() + + if img_mlp is not None: + self.img_mlp = MLP(**img_mlp) + + self.num_sampled_seed = num_sampled_seed + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.init_weights(pretrained=pretrained) + + def init_weights(self, pretrained=None): + """Initialize model weights.""" + super(ImVoteNet, self).init_weights(pretrained) + if pretrained is None: + img_pretrained = None + pts_pretrained = None + elif isinstance(pretrained, dict): + img_pretrained = pretrained.get('img', None) + pts_pretrained = pretrained.get('pts', None) + else: + raise ValueError( + f'pretrained should be a dict, got {type(pretrained)}') + if self.with_img_backbone: + self.img_backbone.init_weights(pretrained=img_pretrained) + if self.with_img_neck: + if isinstance(self.img_neck, nn.Sequential): + for m in self.img_neck: + m.init_weights() + else: + self.img_neck.init_weights() + + if self.with_img_roi_head: + self.img_roi_head.init_weights(img_pretrained) + if self.with_img_rpn: + self.img_rpn_head.init_weights() + if self.with_pts_backbone: + self.pts_backbone.init_weights(pretrained=pts_pretrained) + if self.with_pts_bbox: + self.pts_bbox_head.init_weights() + if self.with_pts_neck: + if isinstance(self.pts_neck, nn.Sequential): + for m in self.pts_neck: + m.init_weights() + else: + self.pts_neck.init_weights() + + def freeze_img_branch_params(self): + """Freeze all image branch parameters.""" + if self.with_img_bbox_head: + for param in self.img_bbox_head.parameters(): + param.requires_grad = False + if self.with_img_backbone: + for param in self.img_backbone.parameters(): + param.requires_grad = False + if self.with_img_neck: + for param in self.img_neck.parameters(): + param.requires_grad = False + if self.with_img_rpn: + for param in self.img_rpn_head.parameters(): + param.requires_grad = False + if self.with_img_roi_head: + for param in self.img_roi_head.parameters(): + param.requires_grad = False + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + """Overload in order to load img network ckpts into img branch.""" + module_names = ['backbone', 'neck', 'roi_head', 'rpn_head'] + for key in list(state_dict): + for module_name in module_names: + if key.startswith(module_name) and ('img_' + + key) not in state_dict: + state_dict['img_' + key] = state_dict.pop(key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) + + def train(self, mode=True): + """Overload in order to keep image branch modules in eval mode.""" + super(ImVoteNet, self).train(mode) + if self.freeze_img_branch: + if self.with_img_bbox_head: + self.img_bbox_head.eval() + if self.with_img_backbone: + self.img_backbone.eval() + if self.with_img_neck: + self.img_neck.eval() + if self.with_img_rpn: + self.img_rpn_head.eval() + if self.with_img_roi_head: + self.img_roi_head.eval() + + @property + def with_img_bbox(self): + """bool: Whether the detector has a 2D image box head.""" + return ((hasattr(self, 'img_roi_head') and self.img_roi_head.with_bbox) + or (hasattr(self, 'img_bbox_head') + and self.img_bbox_head is not None)) + + @property + def with_img_bbox_head(self): + 
"""bool: Whether the detector has a 2D image box head (not roi).""" + return hasattr(self, + 'img_bbox_head') and self.img_bbox_head is not None + + @property + def with_img_backbone(self): + """bool: Whether the detector has a 2D image backbone.""" + return hasattr(self, 'img_backbone') and self.img_backbone is not None + + @property + def with_img_neck(self): + """bool: Whether the detector has a neck in image branch.""" + return hasattr(self, 'img_neck') and self.img_neck is not None + + @property + def with_img_rpn(self): + """bool: Whether the detector has a 2D RPN in image detector branch.""" + return hasattr(self, 'img_rpn_head') and self.img_rpn_head is not None + + @property + def with_img_roi_head(self): + """bool: Whether the detector has a RoI Head in image branch.""" + return hasattr(self, 'img_roi_head') and self.img_roi_head is not None + + @property + def with_pts_bbox(self): + """bool: Whether the detector has a 3D box head.""" + return hasattr(self, + 'pts_bbox_head') and self.pts_bbox_head is not None + + @property + def with_pts_backbone(self): + """bool: Whether the detector has a 3D backbone.""" + return hasattr(self, 'pts_backbone') and self.pts_backbone is not None + + @property + def with_pts_neck(self): + """bool: Whether the detector has a neck in 3D detector branch.""" + return hasattr(self, 'pts_neck') and self.pts_neck is not None + + def extract_feat(self, imgs): + """Just to inherit from abstract method.""" + pass + + def extract_img_feat(self, img): + """Directly extract features from the img backbone+neck.""" + x = self.img_backbone(img) + if self.with_img_neck: + x = self.img_neck(x) + return x + + def extract_img_feats(self, imgs): + """Extract features from multiple images. + + Args: + imgs (list[torch.Tensor]): A list of images. The images are + augmented from the same image but in different ways. + + Returns: + list[torch.Tensor]: Features of different images + """ + + assert isinstance(imgs, list) + return [self.extract_img_feat(img) for img in imgs] + + def extract_pts_feat(self, pts): + """Extract features of points.""" + x = self.pts_backbone(pts) + if self.with_pts_neck: + x = self.pts_neck(x) + + seed_points = x['fp_xyz'][-1] + seed_features = x['fp_features'][-1] + seed_indices = x['fp_indices'][-1] + + return (seed_points, seed_features, seed_indices) + + def extract_pts_feats(self, pts): + """Extract features of points from multiple samples.""" + assert isinstance(pts, list) + return [self.extract_pts_feat(pt) for pt in pts] + + @torch.no_grad() + def extract_bboxes_2d(self, + img, + img_metas, + train=True, + bboxes_2d=None, + **kwargs): + """Extract bounding boxes from 2d detector. + + Args: + img (torch.Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + img_metas (list[dict]): Image meta info. + train (bool): train-time or not. + bboxes_2d (list[torch.Tensor]): provided 2d bboxes, + not supported yet. + + Return: + list[torch.Tensor]: a list of processed 2d bounding boxes. 
+ """ + if bboxes_2d is None: + x = self.extract_img_feat(img) + proposal_list = self.img_rpn_head.simple_test_rpn(x, img_metas) + rets = self.img_roi_head.simple_test( + x, proposal_list, img_metas, rescale=False) + + rets_processed = [] + for ret in rets: + tmp = np.concatenate(ret, axis=0) + sem_class = img.new_zeros((len(tmp))) + start = 0 + for i, bboxes in enumerate(ret): + sem_class[start:start + len(bboxes)] = i + start += len(bboxes) + ret = img.new_tensor(tmp) + + # append class index + ret = torch.cat([ret, sem_class[:, None]], dim=-1) + inds = torch.argsort(ret[:, 4], descending=True) + ret = ret.index_select(0, inds) + + # drop half bboxes during training for better generalization + if train: + rand_drop = torch.randperm(len(ret))[:(len(ret) + 1) // 2] + rand_drop = torch.sort(rand_drop)[0] + ret = ret[rand_drop] + + rets_processed.append(ret.float()) + return rets_processed + else: + rets_processed = [] + for ret in bboxes_2d: + if len(ret) > 0 and train: + rand_drop = torch.randperm(len(ret))[:(len(ret) + 1) // 2] + rand_drop = torch.sort(rand_drop)[0] + ret = ret[rand_drop] + rets_processed.append(ret.float()) + return rets_processed + + def forward_train(self, + points=None, + img=None, + img_metas=None, + gt_bboxes=None, + gt_labels=None, + gt_bboxes_ignore=None, + gt_masks=None, + proposals=None, + calib=None, + bboxes_2d=None, + gt_bboxes_3d=None, + gt_labels_3d=None, + pts_semantic_mask=None, + pts_instance_mask=None, + **kwargs): + """Forwarding of train for image branch pretrain or stage 2 train. + + Args: + points (list[torch.Tensor]): Points of each batch. + img (torch.Tensor): of shape (N, C, H, W) encoding input images. + Typically these should be mean centered and std scaled. + img_metas (list[dict]): list of image and point cloud meta info + dict. For example, keys include 'ori_shape', 'img_norm_cfg', + and 'transformation_3d_flow'. For details on the values of + the keys see `mmdet/datasets/pipelines/formatting.py:Collect`. + gt_bboxes (list[torch.Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (list[torch.Tensor]): class indices for each + 2d bounding box. + gt_bboxes_ignore (None | list[torch.Tensor]): specify which + 2d bounding boxes can be ignored when computing the loss. + gt_masks (None | torch.Tensor): true segmentation masks for each + 2d bbox, used if the architecture supports a segmentation task. + proposals: override rpn proposals (2d) with custom proposals. + Use when `with_rpn` is False. + calib (dict[str, torch.Tensor]): camera calibration matrices, + Rt and K. + bboxes_2d (list[torch.Tensor]): provided 2d bboxes, + not supported yet. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): 3d gt bboxes. + gt_labels_3d (list[torch.Tensor]): gt class labels for 3d bboxes. + pts_semantic_mask (None | list[torch.Tensor]): point-wise semantic + label of each batch. + pts_instance_mask (None | list[torch.Tensor]): point-wise instance + label of each batch. + + Returns: + dict[str, torch.Tensor]: a dictionary of loss components. 
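+
+        Note:
+            When `points` is None only the 2D detection losses (RPN and
+            RoI head) are computed, i.e. the image branch pretrain stage;
+            otherwise the three VoteNet towers (joint, pts, img) are
+            trained and their losses are combined with `loss_weights`.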
+ """ + if points is None: + x = self.extract_img_feat(img) + losses = dict() + + # RPN forward and loss + if self.with_img_rpn: + proposal_cfg = self.train_cfg.get('img_rpn_proposal', + self.test_cfg.img_rpn) + rpn_losses, proposal_list = self.img_rpn_head.forward_train( + x, + img_metas, + gt_bboxes, + gt_labels=None, + gt_bboxes_ignore=gt_bboxes_ignore, + proposal_cfg=proposal_cfg) + losses.update(rpn_losses) + else: + proposal_list = proposals + + roi_losses = self.img_roi_head.forward_train( + x, img_metas, proposal_list, gt_bboxes, gt_labels, + gt_bboxes_ignore, gt_masks, **kwargs) + losses.update(roi_losses) + return losses + else: + bboxes_2d = self.extract_bboxes_2d( + img, img_metas, bboxes_2d=bboxes_2d, **kwargs) + + points = torch.stack(points) + seeds_3d, seed_3d_features, seed_indices = \ + self.extract_pts_feat(points) + + img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d, + img_metas, calib) + + inds = sample_valid_seeds(masks, self.num_sampled_seed) + batch_size, img_feat_size = img_features.shape[:2] + pts_feat_size = seed_3d_features.shape[1] + inds_img = inds.view(batch_size, 1, + -1).expand(-1, img_feat_size, -1) + img_features = img_features.gather(-1, inds_img) + inds = inds % inds.shape[1] + inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3) + seeds_3d = seeds_3d.gather(1, inds_seed_xyz) + inds_seed_feats = inds.view(batch_size, 1, + -1).expand(-1, pts_feat_size, -1) + seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats) + seed_indices = seed_indices.gather(1, inds) + + img_features = self.img_mlp(img_features) + fused_features = torch.cat([seed_3d_features, img_features], dim=1) + + feat_dict_joint = dict( + seed_points=seeds_3d, + seed_features=fused_features, + seed_indices=seed_indices) + feat_dict_pts = dict( + seed_points=seeds_3d, + seed_features=seed_3d_features, + seed_indices=seed_indices) + feat_dict_img = dict( + seed_points=seeds_3d, + seed_features=img_features, + seed_indices=seed_indices) + + loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask, img_metas) + bbox_preds_joints = self.pts_bbox_head_joint( + feat_dict_joint, self.train_cfg.pts.sample_mod) + bbox_preds_pts = self.pts_bbox_head_pts( + feat_dict_pts, self.train_cfg.pts.sample_mod) + bbox_preds_img = self.pts_bbox_head_img( + feat_dict_img, self.train_cfg.pts.sample_mod) + losses_towers = [] + losses_joint = self.pts_bbox_head_joint.loss( + bbox_preds_joints, + *loss_inputs, + gt_bboxes_ignore=gt_bboxes_ignore) + losses_pts = self.pts_bbox_head_pts.loss( + bbox_preds_pts, + *loss_inputs, + gt_bboxes_ignore=gt_bboxes_ignore) + losses_img = self.pts_bbox_head_img.loss( + bbox_preds_img, + *loss_inputs, + gt_bboxes_ignore=gt_bboxes_ignore) + losses_towers.append(losses_joint) + losses_towers.append(losses_pts) + losses_towers.append(losses_img) + combined_losses = dict() + for loss_term in losses_joint: + if 'loss' in loss_term: + combined_losses[loss_term] = 0 + for i in range(len(losses_towers)): + combined_losses[loss_term] += \ + losses_towers[i][loss_term] * \ + self.loss_weights[i] + else: + # only save the metric of the joint head + # if it is not a loss + combined_losses[loss_term] = \ + losses_towers[0][loss_term] + + return combined_losses + + def forward_test(self, + points=None, + img_metas=None, + img=None, + calib=None, + bboxes_2d=None, + **kwargs): + """Forwarding of test for image branch pretrain or stage 2 train. 
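+        If `points` is None only the 2D image detector is tested;
+        otherwise the fused 3D pipeline is run. Single- and
+        multi-augmentation inputs are dispatched to the
+        simple_test*/aug_test* variants respectively.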
+ + Args: + points (list[list[torch.Tensor]], optional): the outer + list indicates test-time augmentations and the inner + list contains all points in the batch, where each Tensor + should have a shape NxC. Defaults to None. + img_metas (list[list[dict]], optional): the outer list + indicates test-time augs (multiscale, flip, etc.) + and the inner list indicates images in a batch. + Defaults to None. + img (list[list[torch.Tensor]], optional): the outer + list indicates test-time augmentations and inner Tensor + should have a shape NxCxHxW, which contains all images + in the batch. Defaults to None. Defaults to None. + calibs (list[dict[str, torch.Tensor]], optional): camera + calibration matrices, Rt and K. + List indicates test-time augs. Defaults to None. + bboxes_2d (list[list[torch.Tensor]], optional): + Provided 2d bboxes, not supported yet. Defaults to None. + + Returns: + list[list[torch.Tensor]]|list[dict]: Predicted 2d or 3d boxes. + """ + if points is None: + for var, name in [(img, 'img'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError( + f'{name} must be a list, but got {type(var)}') + + num_augs = len(img) + if num_augs != len(img_metas): + raise ValueError(f'num of augmentations ({len(img)}) ' + f'!= num of image meta ({len(img_metas)})') + + if num_augs == 1: + # proposals (List[List[Tensor]]): the outer list indicates + # test-time augs (multiscale, flip, etc.) and the inner list + # indicates images in a batch. + # The Tensor should have a shape Px4, where P is the number of + # proposals. + if 'proposals' in kwargs: + kwargs['proposals'] = kwargs['proposals'][0] + return self.simple_test_img_only( + img=img[0], img_metas=img_metas[0], **kwargs) + else: + assert img[0].size(0) == 1, 'aug test does not support ' \ + 'inference with batch size ' \ + f'{img[0].size(0)}' + # TODO: support test augmentation for predefined proposals + assert 'proposals' not in kwargs + return self.aug_test_img_only( + img=img, img_metas=img_metas, **kwargs) + + else: + for var, name in [(points, 'points'), (img_metas, 'img_metas')]: + if not isinstance(var, list): + raise TypeError('{} must be a list, but got {}'.format( + name, type(var))) + + num_augs = len(points) + if num_augs != len(img_metas): + raise ValueError( + 'num of augmentations ({}) != num of image meta ({})'. + format(len(points), len(img_metas))) + + if num_augs == 1: + return self.simple_test( + points[0], + img_metas[0], + img[0], + calibs=calib[0], + bboxes_2d=bboxes_2d[0] if bboxes_2d is not None else None, + **kwargs) + else: + return self.aug_test(points, img_metas, img, calib, bboxes_2d, + **kwargs) + + def simple_test_img_only(self, + img, + img_metas, + proposals=None, + rescale=False): + """Test without augmentation, image network pretrain. May refer to + https://github.com/open- + mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py # + noqa. + + Args: + img (torch.Tensor): Should have a shape NxCxHxW, which contains + all images in the batch. + img_metas (list[dict]): + proposals (list[Tensor], optional): override rpn proposals + with custom proposals. Defaults to None. + rescale (bool, optional): Whether or not rescale bboxes to the + original shape of input image. Defaults to False. + + Returns: + list[list[torch.Tensor]]: Predicted 2d boxes. + """ + assert self.with_img_bbox, 'Img bbox head must be implemented.' + assert self.with_img_backbone, 'Img backbone must be implemented.' + assert self.with_img_rpn, 'Img rpn must be implemented.' 
+ assert self.with_img_roi_head, 'Img roi head must be implemented.' + + x = self.extract_img_feat(img) + + if proposals is None: + proposal_list = self.img_rpn_head.simple_test_rpn(x, img_metas) + else: + proposal_list = proposals + + ret = self.img_roi_head.simple_test( + x, proposal_list, img_metas, rescale=rescale) + + return ret + + def simple_test(self, + points=None, + img_metas=None, + img=None, + calibs=None, + bboxes_2d=None, + rescale=False, + **kwargs): + """Test without augmentation, stage 2. + + Args: + points (list[torch.Tensor], optional): Elements in the list + should have a shape NxC, the list indicates all point-clouds + in the batch. Defaults to None. + img_metas (list[dict], optional): List indicates + images in a batch. Defaults to None. + img (torch.Tensor, optional): Should have a shape NxCxHxW, + which contains all images in the batch. Defaults to None. + calibs (dict[str, torch.Tensor], optional): camera + calibration matrices, Rt and K. Defaults to None. + bboxes_2d (list[torch.Tensor], optional): + Provided 2d bboxes, not supported yet. Defaults to None. + rescale (bool, optional): Whether or not rescale bboxes. + Defaults to False. + + Returns: + list[dict]: Predicted 3d boxes. + """ + bboxes_2d = self.extract_bboxes_2d( + img, img_metas, train=False, bboxes_2d=bboxes_2d, **kwargs) + + points = torch.stack(points) + seeds_3d, seed_3d_features, seed_indices = \ + self.extract_pts_feat(points) + + img_features, masks = self.fusion_layer(img, bboxes_2d, seeds_3d, + img_metas, calibs) + + inds = sample_valid_seeds(masks, self.num_sampled_seed) + batch_size, img_feat_size = img_features.shape[:2] + pts_feat_size = seed_3d_features.shape[1] + inds_img = inds.view(batch_size, 1, -1).expand(-1, img_feat_size, -1) + img_features = img_features.gather(-1, inds_img) + inds = inds % inds.shape[1] + inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3) + seeds_3d = seeds_3d.gather(1, inds_seed_xyz) + inds_seed_feats = inds.view(batch_size, 1, + -1).expand(-1, pts_feat_size, -1) + seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats) + seed_indices = seed_indices.gather(1, inds) + + img_features = self.img_mlp(img_features) + + fused_features = torch.cat([seed_3d_features, img_features], dim=1) + + feat_dict = dict( + seed_points=seeds_3d, + seed_features=fused_features, + seed_indices=seed_indices) + bbox_preds = self.pts_bbox_head_joint(feat_dict, + self.test_cfg.pts.sample_mod) + bbox_list = self.pts_bbox_head_joint.get_bboxes( + points, bbox_preds, img_metas, rescale=rescale) + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results + + def aug_test_img_only(self, img, img_metas, rescale=False): + """Test function with augmentation, image network pretrain. May refer + to https://github.com/open- + mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py # + noqa. + + Args: + img (list[list[torch.Tensor]], optional): the outer + list indicates test-time augmentations and inner Tensor + should have a shape NxCxHxW, which contains all images + in the batch. Defaults to None. Defaults to None. + img_metas (list[list[dict]], optional): the outer list + indicates test-time augs (multiscale, flip, etc.) + and the inner list indicates images in a batch. + Defaults to None. + rescale (bool, optional): Whether or not rescale bboxes to the + original shape of input image. If rescale is False, then + returned bboxes and masks will fit the scale of imgs[0]. + Defaults to None. 
+ + Returns: + list[list[torch.Tensor]]: Predicted 2d boxes. + """ + assert self.with_img_bbox, 'Img bbox head must be implemented.' + assert self.with_img_backbone, 'Img backbone must be implemented.' + assert self.with_img_rpn, 'Img rpn must be implemented.' + assert self.with_img_roi_head, 'Img roi head must be implemented.' + + x = self.extract_img_feats(img) + proposal_list = self.img_rpn_head.aug_test_rpn(x, img_metas) + + return self.img_roi_head.aug_test( + x, proposal_list, img_metas, rescale=rescale) + + def aug_test(self, + points=None, + img_metas=None, + imgs=None, + calibs=None, + bboxes_2d=None, + rescale=False, + **kwargs): + """Test function with augmentation, stage 2. + + Args: + points (list[list[torch.Tensor]], optional): the outer + list indicates test-time augmentations and the inner + list contains all points in the batch, where each Tensor + should have a shape NxC. Defaults to None. + img_metas (list[list[dict]], optional): the outer list + indicates test-time augs (multiscale, flip, etc.) + and the inner list indicates images in a batch. + Defaults to None. + imgs (list[list[torch.Tensor]], optional): the outer + list indicates test-time augmentations and inner Tensor + should have a shape NxCxHxW, which contains all images + in the batch. Defaults to None. Defaults to None. + calibs (list[dict[str, torch.Tensor]], optional): camera + calibration matrices, Rt and K. + List indicates test-time augs. Defaults to None. + bboxes_2d (list[list[torch.Tensor]], optional): + Provided 2d bboxes, not supported yet. Defaults to None. + rescale (bool, optional): Whether or not rescale bboxes. + Defaults to False. + + Returns: + list[dict]: Predicted 3d boxes. + """ + points_cat = [torch.stack(pts) for pts in points] + feats = self.extract_pts_feats(points_cat, img_metas) + + # only support aug_test for one sample + aug_bboxes = [] + for x, pts_cat, img_meta, bbox_2d, img, calib in zip( + feats, points_cat, img_metas, bboxes_2d, imgs, calibs): + + bbox_2d = self.extract_bboxes_2d( + img, img_metas, train=False, bboxes_2d=bbox_2d, **kwargs) + + seeds_3d, seed_3d_features, seed_indices = x + + img_features, masks = self.fusion_layer(img, bbox_2d, seeds_3d, + img_metas, calib) + + inds = sample_valid_seeds(masks, self.num_sampled_seed) + batch_size, img_feat_size = img_features.shape[:2] + pts_feat_size = seed_3d_features.shape[1] + inds_img = inds.view(batch_size, 1, + -1).expand(-1, img_feat_size, -1) + img_features = img_features.gather(-1, inds_img) + inds = inds % inds.shape[1] + inds_seed_xyz = inds.view(batch_size, -1, 1).expand(-1, -1, 3) + seeds_3d = seeds_3d.gather(1, inds_seed_xyz) + inds_seed_feats = inds.view(batch_size, 1, + -1).expand(-1, pts_feat_size, -1) + seed_3d_features = seed_3d_features.gather(-1, inds_seed_feats) + seed_indices = seed_indices.gather(1, inds) + + img_features = self.img_mlp(img_features) + + fused_features = torch.cat([seed_3d_features, img_features], dim=1) + + feat_dict = dict( + seed_points=seeds_3d, + seed_features=fused_features, + seed_indices=seed_indices) + bbox_preds = self.pts_bbox_head_joint(feat_dict, + self.test_cfg.pts.sample_mod) + bbox_list = self.pts_bbox_head_joint.get_bboxes( + pts_cat, bbox_preds, img_metas, rescale=rescale) + + bbox_list = [ + dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels) + for bboxes, scores, labels in bbox_list + ] + aug_bboxes.append(bbox_list[0]) + + # after merging, bboxes will be rescaled to the original image size + merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas, + 
self.bbox_head.test_cfg) + + return [merged_bboxes] diff --git a/mmdet3d/models/detectors/votenet.py b/mmdet3d/models/detectors/votenet.py index d1e676736f..706a35fad2 100644 --- a/mmdet3d/models/detectors/votenet.py +++ b/mmdet3d/models/detectors/votenet.py @@ -7,10 +7,7 @@ @DETECTORS.register_module() class VoteNet(SingleStage3DDetector): - """VoteNet model. - - https://arxiv.org/pdf/1904.09664.pdf - """ + r"""`VoteNet `_ for 3D detection.""" def __init__(self, backbone, @@ -25,6 +22,28 @@ def __init__(self, test_cfg=test_cfg, pretrained=pretrained) + def extract_feat(self, points, img_metas=None): + """Directly extract features from the backbone+neck. + + Args: + points (torch.Tensor): Input points. + """ + x = self.backbone(points) + if self.with_neck: + x = self.neck(x) + + seed_points = x['fp_xyz'][-1] + seed_features = x['fp_features'][-1] + seed_indices = x['fp_indices'][-1] + + feat_dict = { + 'seed_points': seed_points, + 'seed_features': seed_features, + 'seed_indices': seed_indices + } + + return feat_dict + def forward_train(self, points, img_metas, diff --git a/mmdet3d/models/fusion_layers/__init__.py b/mmdet3d/models/fusion_layers/__init__.py index 93142ced2f..2ea683efa5 100644 --- a/mmdet3d/models/fusion_layers/__init__.py +++ b/mmdet3d/models/fusion_layers/__init__.py @@ -1,3 +1,9 @@ +from .coord_transform import (apply_3d_transformation, bbox_2d_transform, + coord_2d_transform) from .point_fusion import PointFusion +from .vote_fusion import VoteFusion -__all__ = ['PointFusion'] +__all__ = [ + 'PointFusion', 'VoteFusion', 'apply_3d_transformation', + 'bbox_2d_transform', 'coord_2d_transform' +] diff --git a/mmdet3d/models/fusion_layers/coord_transform.py b/mmdet3d/models/fusion_layers/coord_transform.py new file mode 100644 index 0000000000..b1805f6b53 --- /dev/null +++ b/mmdet3d/models/fusion_layers/coord_transform.py @@ -0,0 +1,214 @@ +import torch +from functools import partial + +from mmdet3d.core.points import get_points_type + + +def apply_3d_transformation(pcd, coords_type, img_meta, reverse=False): + """Apply transformation to input point cloud. + + Args: + pcd (torch.Tensor): The point cloud to be transformed. + coords_type (str): 'DEPTH' or 'CAMERA' or 'LIDAR' + img_meta(dict): Meta info regarding data transformation. + reverse (bool): Reversed transformation or not. + + Note: + The elements in img_meta['transformation_3d_flow']: + "T" stands for translation; + "S" stands for scale; + "R" stands for rotation; + "HF" stands for horizontal flip; + "VF" stands for vertical flip. + + Returns: + torch.Tensor: The transformed point cloud. + """ + + dtype = pcd.dtype + device = pcd.device + + pcd_rotate_mat = ( + torch.tensor(img_meta['pcd_rotation'], dtype=dtype, device=device) + if 'pcd_rotation' in img_meta else torch.eye( + 3, dtype=dtype, device=device)) + + pcd_scale_factor = ( + img_meta['pcd_scale_factor'] if 'pcd_scale_factor' in img_meta else 1.) 
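+    # as with pcd_rotation above, a missing key means the corresponding
+    # augmentation was never applied, so identity/no-op fallbacks are used
+    # for the remaining factors below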
+ + pcd_trans_factor = ( + torch.tensor(img_meta['pcd_trans'], dtype=dtype, device=device) + if 'pcd_trans' in img_meta else torch.zeros( + (3), dtype=dtype, device=device)) + + pcd_horizontal_flip = img_meta[ + 'pcd_horizontal_flip'] if 'pcd_horizontal_flip' in \ + img_meta else False + + pcd_vertical_flip = img_meta[ + 'pcd_vertical_flip'] if 'pcd_vertical_flip' in \ + img_meta else False + + flow = img_meta['transformation_3d_flow'] \ + if 'transformation_3d_flow' in img_meta else [] + + pcd = pcd.clone() # prevent inplace modification + pcd = get_points_type(coords_type)(pcd) + + horizontal_flip_func = partial(pcd.flip, bev_direction='horizontal') \ + if pcd_horizontal_flip else lambda: None + vertical_flip_func = partial(pcd.flip, bev_direction='vertical') \ + if pcd_vertical_flip else lambda: None + if reverse: + scale_func = partial(pcd.scale, scale_factor=1.0 / pcd_scale_factor) + translate_func = partial(pcd.translate, trans_vector=-pcd_trans_factor) + # pcd_rotate_mat @ pcd_rotate_mat.inverse() is not + # exactly an identity matrix + # use angle to create the inverse rot matrix neither. + rotate_func = partial(pcd.rotate, rotation=pcd_rotate_mat.inverse()) + + # reverse the pipeline + flow = flow[::-1] + else: + scale_func = partial(pcd.scale, scale_factor=pcd_scale_factor) + translate_func = partial(pcd.translate, trans_vector=pcd_trans_factor) + rotate_func = partial(pcd.rotate, rotation=pcd_rotate_mat) + + flow_mapping = { + 'T': translate_func, + 'S': scale_func, + 'R': rotate_func, + 'HF': horizontal_flip_func, + 'VF': vertical_flip_func + } + for op in flow: + assert op in flow_mapping, f'This 3D data '\ + f'transformation op ({op}) is not supported' + func = flow_mapping[op] + func() + + return pcd.coord + + +def extract_2d_info(img_meta, tensor): + """Extract image augmentation information from img_meta. + + Args: + img_meta(dict): Meta info regarding data transformation. + tensor(torch.Tensor): Input tensor used to create new ones. + + Returns: + (int, int, int, int, torch.Tensor, bool, torch.Tensor): + The extracted information. + """ + img_shape = img_meta['img_shape'] + ori_shape = img_meta['ori_shape'] + img_h, img_w, _ = img_shape + ori_h, ori_w, _ = ori_shape + + img_scale_factor = ( + tensor.new_tensor(img_meta['scale_factor'][:2]) + if 'scale_factor' in img_meta else tensor.new_tensor([1.0, 1.0])) + img_flip = img_meta['flip'] if 'flip' in img_meta else False + img_crop_offset = ( + tensor.new_tensor(img_meta['img_crop_offset']) + if 'img_crop_offset' in img_meta else tensor.new_tensor([0.0, 0.0])) + + return (img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, + img_crop_offset) + + +def bbox_2d_transform(img_meta, bbox_2d, ori2new): + """Transform 2d bbox according to img_meta. + + Args: + img_meta(dict): Meta info regarding data transformation. + bbox_2d (torch.Tensor): Shape (..., >4) + The input 2d bboxes to transform. + ori2new (bool): Origin img coord system to new or not. + + Returns: + torch.Tensor: The transformed 2d bboxes. 
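+
+    Note:
+        With ori2new=True the boxes are scaled, shifted by the crop offset
+        and then flipped; with ori2new=False the inverse operations are
+        applied in reverse order.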
+ """ + + img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \ + img_crop_offset = extract_2d_info(img_meta, bbox_2d) + + bbox_2d_new = bbox_2d.clone() + + if ori2new: + bbox_2d_new[:, 0] = bbox_2d_new[:, 0] * img_scale_factor[0] + bbox_2d_new[:, 2] = bbox_2d_new[:, 2] * img_scale_factor[0] + bbox_2d_new[:, 1] = bbox_2d_new[:, 1] * img_scale_factor[1] + bbox_2d_new[:, 3] = bbox_2d_new[:, 3] * img_scale_factor[1] + + bbox_2d_new[:, 0] = bbox_2d_new[:, 0] + img_crop_offset[0] + bbox_2d_new[:, 2] = bbox_2d_new[:, 2] + img_crop_offset[0] + bbox_2d_new[:, 1] = bbox_2d_new[:, 1] + img_crop_offset[1] + bbox_2d_new[:, 3] = bbox_2d_new[:, 3] + img_crop_offset[1] + + if img_flip: + bbox_2d_r = img_w - bbox_2d_new[:, 0] + bbox_2d_l = img_w - bbox_2d_new[:, 2] + bbox_2d_new[:, 0] = bbox_2d_l + bbox_2d_new[:, 2] = bbox_2d_r + else: + if img_flip: + bbox_2d_r = img_w - bbox_2d_new[:, 0] + bbox_2d_l = img_w - bbox_2d_new[:, 2] + bbox_2d_new[:, 0] = bbox_2d_l + bbox_2d_new[:, 2] = bbox_2d_r + + bbox_2d_new[:, 0] = bbox_2d_new[:, 0] - img_crop_offset[0] + bbox_2d_new[:, 2] = bbox_2d_new[:, 2] - img_crop_offset[0] + bbox_2d_new[:, 1] = bbox_2d_new[:, 1] - img_crop_offset[1] + bbox_2d_new[:, 3] = bbox_2d_new[:, 3] - img_crop_offset[1] + + bbox_2d_new[:, 0] = bbox_2d_new[:, 0] / img_scale_factor[0] + bbox_2d_new[:, 2] = bbox_2d_new[:, 2] / img_scale_factor[0] + bbox_2d_new[:, 1] = bbox_2d_new[:, 1] / img_scale_factor[1] + bbox_2d_new[:, 3] = bbox_2d_new[:, 3] / img_scale_factor[1] + + return bbox_2d_new + + +def coord_2d_transform(img_meta, coord_2d, ori2new): + """Transform 2d pixel coordinates according to img_meta. + + Args: + img_meta(dict): Meta info regarding data transformation. + coord_2d (torch.Tensor): Shape (..., 2) + The input 2d coords to transform. + ori2new (bool): Origin img coord system to new or not. + + Returns: + torch.Tensor: The transformed 2d coordinates. + """ + + img_h, img_w, ori_h, ori_w, img_scale_factor, img_flip, \ + img_crop_offset = extract_2d_info(img_meta, coord_2d) + + coord_2d_new = coord_2d.clone() + + if ori2new: + # TODO here we assume this order of transformation + coord_2d_new[..., 0] = coord_2d_new[..., 0] * img_scale_factor[0] + coord_2d_new[..., 1] = coord_2d_new[..., 1] * img_scale_factor[1] + + coord_2d_new[..., 0] += img_crop_offset[0] + coord_2d_new[..., 1] += img_crop_offset[1] + + # flip uv coordinates and bbox + if img_flip: + coord_2d_new[..., 0] = img_w - coord_2d_new[..., 0] + else: + if img_flip: + coord_2d_new[..., 0] = img_w - coord_2d_new[..., 0] + + coord_2d_new[..., 0] -= img_crop_offset[0] + coord_2d_new[..., 1] -= img_crop_offset[1] + + coord_2d_new[..., 0] = coord_2d_new[..., 0] / img_scale_factor[0] + coord_2d_new[..., 1] = coord_2d_new[..., 1] / img_scale_factor[1] + + return coord_2d_new diff --git a/mmdet3d/models/fusion_layers/point_fusion.py b/mmdet3d/models/fusion_layers/point_fusion.py index c8a4b4b195..388a225f8a 100644 --- a/mmdet3d/models/fusion_layers/point_fusion.py +++ b/mmdet3d/models/fusion_layers/point_fusion.py @@ -4,18 +4,16 @@ from torch.nn import functional as F from ..registry import FUSION_LAYERS +from . import apply_3d_transformation def point_sample( + img_meta, img_features, points, lidar2img_rt, - pcd_rotate_mat, img_scale_factor, img_crop_offset, - pcd_trans_factor, - pcd_scale_factor, - pcd_flip, img_flip, img_pad_shape, img_shape, @@ -26,19 +24,14 @@ def point_sample( """Obtain image features using points. Args: + img_meta (dict): Meta info. img_features (torch.Tensor): 1 x C x H x W image features. 
points (torch.Tensor): Nx3 point cloud in LiDAR coordinates. lidar2img_rt (torch.Tensor): 4x4 transformation matrix. - pcd_rotate_mat (torch.Tensor): 3x3 rotation matrix of points - during augmentation. img_scale_factor (torch.Tensor): Scale factor with shape of \ (w_scale, h_scale). img_crop_offset (torch.Tensor): Crop offset used to crop \ image during data augmentation with shape of (w_offset, h_offset). - pcd_trans_factor ([type]): Translation of points in augmentation. - pcd_scale_factor (float): Scale factor of points during. - data augmentation - pcd_flip (bool): Whether the points are flipped. img_flip (bool): Whether the image is flipped. img_pad_shape (tuple[int]): int tuple indicates the h & w after padding, this is necessary to obtain features in feature map. @@ -54,19 +47,9 @@ def point_sample( Returns: torch.Tensor: NxC image features sampled by point coordinates. """ - # aug order: flip -> trans -> scale -> rot - # The transformation follows the augmentation order in data pipeline - if pcd_flip: - # if the points are flipped, flip them back first - points[:, 1] = -points[:, 1] - - points -= pcd_trans_factor - # the points should be scaled to the original scale in velo coordinate - points /= pcd_scale_factor - # the points should be rotated back - # pcd_rotate_mat @ pcd_rotate_mat.inverse() is not exactly an identity - # matrix, use angle to create the inverse rot matrix neither. - points = points @ pcd_rotate_mat.inverse() + + # apply transformation based on info in img_meta + points = apply_3d_transformation(points, 'LIDAR', img_meta, reverse=True) # project points from velo coordinate to camera coordinate num_points = points.shape[0] @@ -298,34 +281,21 @@ def sample_single(self, img_feats, pts, img_meta): Returns: torch.Tensor: Single level image features of each point. """ - pcd_scale_factor = ( - img_meta['pcd_scale_factor'] - if 'pcd_scale_factor' in img_meta.keys() else 1) - pcd_trans_factor = ( - pts.new_tensor(img_meta['pcd_trans']) - if 'pcd_trans' in img_meta.keys() else 0) - pcd_rotate_mat = ( - pts.new_tensor(img_meta['pcd_rotation']) if 'pcd_rotation' - in img_meta.keys() else torch.eye(3).type_as(pts).to(pts.device)) + # TODO: image transformation also extracted img_scale_factor = ( pts.new_tensor(img_meta['scale_factor'][:2]) if 'scale_factor' in img_meta.keys() else 1) - pcd_flip = img_meta['pcd_flip'] if 'pcd_flip' in img_meta.keys( - ) else False img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False img_crop_offset = ( pts.new_tensor(img_meta['img_crop_offset']) if 'img_crop_offset' in img_meta.keys() else 0) img_pts = point_sample( + img_meta, img_feats, pts, pts.new_tensor(img_meta['lidar2img']), - pcd_rotate_mat, img_scale_factor, img_crop_offset, - pcd_trans_factor, - pcd_scale_factor, - pcd_flip=pcd_flip, img_flip=img_flip, img_pad_shape=img_meta['input_shape'][:2], img_shape=img_meta['img_shape'][:2], diff --git a/mmdet3d/models/fusion_layers/vote_fusion.py b/mmdet3d/models/fusion_layers/vote_fusion.py new file mode 100644 index 0000000000..c26daa9f67 --- /dev/null +++ b/mmdet3d/models/fusion_layers/vote_fusion.py @@ -0,0 +1,212 @@ +import torch +from torch import nn as nn + +from mmdet3d.core.bbox import Coord3DMode, points_cam2img +from ..registry import FUSION_LAYERS +from . import apply_3d_transformation, bbox_2d_transform, coord_2d_transform + +EPS = 1e-6 + + +@FUSION_LAYERS.register_module() +class VoteFusion(nn.Module): + """Fuse 2d features from 3d seeds. + + Args: + num_classes (int): number of classes. 
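With the point_fusion refactor above, every point cloud augmentation parameter that point_sample used to receive explicitly (pcd_rotate_mat, pcd_trans_factor, pcd_scale_factor, pcd_flip) now travels inside img_meta and is undone by apply_3d_transformation. A sketch of the keys sample_single now expects; the values are placeholders rather than output of a real pipeline run, and the same contract is exercised by tests/test_models/test_fusion/test_point_fusion.py later in this patch:

import torch

img_meta = dict(
    # point cloud augmentation, consumed through apply_3d_transformation
    pcd_rotation=torch.eye(3),
    pcd_scale_factor=1.05,
    pcd_trans=torch.tensor([0.5, 0., 0.]),
    pcd_horizontal_flip=False,
    transformation_3d_flow=['R', 'S', 'T'],
    # image-side augmentation, still read directly in sample_single
    scale_factor=[1.0, 1.0, 1.0, 1.0],
    flip=False,
    img_crop_offset=[0., 0.],
    # projection matrix and shapes used by point_sample
    lidar2img=torch.eye(4),
    input_shape=[370, 1224],
    img_shape=[370, 1224])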
+ max_imvote_per_pixel (int): max number of imvotes. + """ + + def __init__(self, num_classes=10, max_imvote_per_pixel=3): + super(VoteFusion, self).__init__() + self.num_classes = num_classes + self.max_imvote_per_pixel = max_imvote_per_pixel + + def forward(self, imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas, + calibs): + """Forward function. + + Args: + imgs (list[torch.Tensor]): Image features. + bboxes_2d_rescaled (list[torch.Tensor]): 2D bboxes. + seeds_3d_depth (torch.Tensor): 3D seeds. + img_metas (list[dict]): Meta information of images. + calibs: Camera calibration information of the images. + + Returns: + torch.Tensor: Concatenated cues of each point. + torch.Tensor: Validity mask of each feature. + """ + img_features = [] + masks = [] + for i, data in enumerate( + zip(imgs, bboxes_2d_rescaled, seeds_3d_depth, img_metas)): + img, bbox_2d_rescaled, seed_3d_depth, img_meta = data + bbox_num = bbox_2d_rescaled.shape[0] + seed_num = seed_3d_depth.shape[0] + + img_shape = img_meta['img_shape'] + img_h, img_w, _ = img_shape + + # first reverse the data transformations + xyz_depth = apply_3d_transformation( + seed_3d_depth, 'DEPTH', img_meta, reverse=True) + + # then convert from depth coords to camera coords + xyz_cam = Coord3DMode.convert_point( + xyz_depth, + Coord3DMode.DEPTH, + Coord3DMode.CAM, + rt_mat=calibs['Rt'][i]) + + # project to 2d to get image coords (uv) + uv_origin = points_cam2img(xyz_cam, calibs['K'][i]) + uv_origin = (uv_origin - 1).round() + + # rescale 2d coordinates and bboxes + uv_rescaled = coord_2d_transform(img_meta, uv_origin, True) + bbox_2d_origin = bbox_2d_transform(img_meta, bbox_2d_rescaled, + False) + + if bbox_num == 0: + imvote_num = seed_num * self.max_imvote_per_pixel + + # use zero features + two_cues = torch.zeros((15, imvote_num), + device=seed_3d_depth.device) + mask_zero = torch.zeros( + imvote_num - seed_num, device=seed_3d_depth.device).bool() + mask_one = torch.ones( + seed_num, device=seed_3d_depth.device).bool() + mask = torch.cat([mask_one, mask_zero], dim=0) + else: + # expand bboxes and seeds + bbox_expanded = bbox_2d_origin.view(1, bbox_num, -1).expand( + seed_num, -1, -1) + seed_2d_expanded = uv_origin.view(seed_num, 1, + -1).expand(-1, bbox_num, -1) + seed_2d_expanded_x, seed_2d_expanded_y = \ + seed_2d_expanded.split(1, dim=-1) + + bbox_expanded_l, bbox_expanded_t, bbox_expanded_r, \ + bbox_expanded_b, bbox_expanded_conf, bbox_expanded_cls = \ + bbox_expanded.split(1, dim=-1) + bbox_expanded_midx = (bbox_expanded_l + bbox_expanded_r) / 2 + bbox_expanded_midy = (bbox_expanded_t + bbox_expanded_b) / 2 + + seed_2d_in_bbox_x = (seed_2d_expanded_x > bbox_expanded_l) * \ + (seed_2d_expanded_x < bbox_expanded_r) + seed_2d_in_bbox_y = (seed_2d_expanded_y > bbox_expanded_t) * \ + (seed_2d_expanded_y < bbox_expanded_b) + seed_2d_in_bbox = seed_2d_in_bbox_x * seed_2d_in_bbox_y + + # semantic cues, dim=class_num + sem_cue = torch.zeros_like(bbox_expanded_conf).expand( + -1, -1, self.num_classes) + sem_cue = sem_cue.scatter(-1, bbox_expanded_cls.long(), + bbox_expanded_conf) + + # bbox center - uv + delta_u = bbox_expanded_midx - seed_2d_expanded_x + delta_v = bbox_expanded_midy - seed_2d_expanded_y + + seed_3d_expanded = seed_3d_depth.view(seed_num, 1, -1).expand( + -1, bbox_num, -1) + + z_cam = xyz_cam[..., 2:3].view(seed_num, 1, + 1).expand(-1, bbox_num, -1) + + delta_u = delta_u * z_cam / calibs['K'][i, 0, 0] + delta_v = delta_v * z_cam / calibs['K'][i, 0, 0] + + imvote = torch.cat( + [delta_u, delta_v, + torch.zeros_like(delta_v)], 
dim=-1).view(-1, 3) + + # convert from camera coords to depth coords + imvote = Coord3DMode.convert_point( + imvote.view((-1, 3)), + Coord3DMode.CAM, + Coord3DMode.DEPTH, + rt_mat=calibs['Rt'][i]) + + # apply transformation to lifted imvotes + imvote = apply_3d_transformation( + imvote, 'DEPTH', img_meta, reverse=False) + + seed_3d_expanded = seed_3d_expanded.reshape(imvote.shape) + + # ray angle + ray_angle = seed_3d_expanded + imvote + ray_angle /= torch.sqrt(torch.sum(ray_angle**2, -1) + + EPS).unsqueeze(-1) + + # imvote lifted to 3d + xz = ray_angle[:, [0, 2]] / (ray_angle[:, [1]] + EPS) \ + * seed_3d_expanded[:, [1]] - seed_3d_expanded[:, [0, 2]] + + # geometric cues, dim=5 + geo_cue = torch.cat([xz, ray_angle], + dim=-1).view(seed_num, -1, 5) + + two_cues = torch.cat([geo_cue, sem_cue], dim=-1) + # mask to 0 if seed not in bbox + two_cues = two_cues * seed_2d_in_bbox.float() + + feature_size = two_cues.shape[-1] + # if bbox number is too small, append zeros + if bbox_num < self.max_imvote_per_pixel: + append_num = self.max_imvote_per_pixel - bbox_num + append_zeros = torch.zeros( + (seed_num, append_num, 1), + device=seed_2d_in_bbox.device).bool() + seed_2d_in_bbox = torch.cat( + [seed_2d_in_bbox, append_zeros], dim=1) + append_zeros = torch.zeros( + (seed_num, append_num, feature_size), + device=two_cues.device) + two_cues = torch.cat([two_cues, append_zeros], dim=1) + append_zeros = torch.zeros((seed_num, append_num, 1), + device=two_cues.device) + bbox_expanded_conf = torch.cat( + [bbox_expanded_conf, append_zeros], dim=1) + + # sort the valid seed-bbox pair according to confidence + pair_score = seed_2d_in_bbox.float() + bbox_expanded_conf + # and find the largests + mask, indices = pair_score.topk( + self.max_imvote_per_pixel, + dim=1, + largest=True, + sorted=True) + + indices_img = indices.expand(-1, -1, feature_size) + two_cues = two_cues.gather(dim=1, index=indices_img) + two_cues = two_cues.transpose(1, 0) + two_cues = two_cues.reshape(-1, feature_size).transpose( + 1, 0).contiguous() + + # since conf is ~ (0, 1), floor gives us validity + mask = mask.floor().int() + mask = mask.transpose(1, 0).reshape(-1).bool() + + # clear the padding + img = img[:, :img_shape[0], :img_shape[1]] + img_flatten = img.reshape(3, -1).float() + img_flatten /= 255. + + # take the normalized pixel value as texture cue + uv_flatten = uv_rescaled[:, 1].round() * \ + img_shape[1] + uv_rescaled[:, 0].round() + uv_expanded = uv_flatten.unsqueeze(0).expand(3, -1).long() + txt_cue = torch.gather(img_flatten, dim=-1, index=uv_expanded) + txt_cue = txt_cue.unsqueeze(1).expand(-1, + self.max_imvote_per_pixel, + -1).reshape(3, -1) + + # append texture cue + img_feature = torch.cat([two_cues, txt_cue], dim=0) + img_features.append(img_feature) + masks.append(mask) + + return torch.stack(img_features, 0), torch.stack(masks, 0) diff --git a/mmdet3d/models/utils/__init__.py b/mmdet3d/models/utils/__init__.py index 2206490be1..94aa1923e1 100644 --- a/mmdet3d/models/utils/__init__.py +++ b/mmdet3d/models/utils/__init__.py @@ -1,3 +1,4 @@ from .clip_sigmoid import clip_sigmoid +from .mlp import MLP -__all__ = ['clip_sigmoid'] +__all__ = ['clip_sigmoid', 'MLP'] diff --git a/mmdet3d/models/utils/mlp.py b/mmdet3d/models/utils/mlp.py new file mode 100644 index 0000000000..4bb91d8b33 --- /dev/null +++ b/mmdet3d/models/utils/mlp.py @@ -0,0 +1,48 @@ +from mmcv.cnn import ConvModule +from torch import nn as nn + + +class MLP(nn.Module): + """A simple MLP module. + + Pass features (B, C, N) through an MLP. 
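The seed-to-box pairing above relies on a small trick worth spelling out: because the detector confidence lies in (0, 1), adding it to the boolean in-box mask lets one topk call both rank the boxes per seed and, via floor(), recover which of the kept pairs are actually valid. A toy illustration:

import torch

inside = torch.tensor([[1., 0., 0.]])   # this seed projects into box 0 only
conf = torch.tensor([[0.9, 0.8, 0.3]])  # 2D detector confidences in (0, 1)
score, idx = (inside + conf).topk(2, dim=1)
print(idx)                   # tensor([[0, 1]]): the in-box pair is ranked first
print(score.floor().bool())  # tensor([[ True, False]]): only that pair is valid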
+ + Args: + in_channels (int): Number of channels of input features. + Default: 18. + conv_channels (tuple[int]): Out channels of the convolution. + Default: (256, 256). + conv_cfg (dict): Config of convolution. + Default: dict(type='Conv1d'). + norm_cfg (dict): Config of normalization. + Default: dict(type='BN1d'). + act_cfg (dict): Config of activation. + Default: dict(type='ReLU'). + """ + + def __init__(self, + in_channel=18, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU')): + super().__init__() + self.mlp = nn.Sequential() + prev_channels = in_channel + for i, conv_channel in enumerate(conv_channels): + self.mlp.add_module( + f'layer{i}', + ConvModule( + prev_channels, + conv_channels[i], + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + bias=True, + inplace=True)) + prev_channels = conv_channels[i] + + def forward(self, img_features): + return self.mlp(img_features) diff --git a/tests/test_models/test_fusion/test_fusion_coord_trans.py b/tests/test_models/test_fusion/test_fusion_coord_trans.py new file mode 100644 index 0000000000..88c1cae5c1 --- /dev/null +++ b/tests/test_models/test_fusion/test_fusion_coord_trans.py @@ -0,0 +1,136 @@ +"""Tests coords transformation in fusion modules. + +CommandLine: + pytest tests/test_models/test_fusion/test_fusion_coord_trans.py +""" + +import torch + +from mmdet3d.models.fusion_layers import apply_3d_transformation + + +def test_coords_transformation(): + """Test the transformation of 3d coords.""" + + # H+R+S+T, not reverse, depth + img_meta = { + 'pcd_scale_factor': + 1.2311e+00, + 'pcd_rotation': [[8.660254e-01, 0.5, 0], [-0.5, 8.660254e-01, 0], + [0, 0, 1.0e+00]], + 'pcd_trans': [1.111e-02, -8.88e-03, 0.0], + 'pcd_horizontal_flip': + True, + 'transformation_3d_flow': ['HF', 'R', 'S', 'T'] + } + + pcd = torch.tensor([[-5.2422e+00, -2.9757e-01, 4.0021e+01], + [-9.1435e-01, 2.6675e+01, -5.5950e+00], + [2.0089e-01, 5.8098e+00, -3.5409e+01], + [-1.9461e-01, 3.1309e+01, -1.0901e+00]]) + + pcd_transformed = apply_3d_transformation( + pcd, 'DEPTH', img_meta, reverse=False) + + expected_tensor = torch.tensor( + [[5.78332345e+00, 2.900697e+00, 4.92698531e+01], + [-1.5433839e+01, 2.8993850e+01, -6.8880045e+00], + [-3.77929405e+00, 6.061661e+00, -4.35920199e+01], + [-1.9053658e+01, 3.3491436e+01, -1.34202211e+00]]) + + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) + + # H+R+S+T, reverse, depth + img_meta = { + 'pcd_scale_factor': + 7.07106781e-01, + 'pcd_rotation': [[7.07106781e-01, 7.07106781e-01, 0.0], + [-7.07106781e-01, 7.07106781e-01, 0.0], + [0.0, 0.0, 1.0e+00]], + 'pcd_trans': [0.0, 0.0, 0.0], + 'pcd_horizontal_flip': + False, + 'transformation_3d_flow': ['HF', 'R', 'S', 'T'] + } + + pcd = torch.tensor([[-5.2422e+00, -2.9757e-01, 4.0021e+01], + [-9.1435e+01, 2.6675e+01, -5.5950e+00], + [6.061661e+00, -0.0, -1.0e+02]]) + + pcd_transformed = apply_3d_transformation( + pcd, 'DEPTH', img_meta, reverse=True) + + expected_tensor = torch.tensor( + [[-5.53977e+00, 4.94463e+00, 5.65982409e+01], + [-6.476e+01, 1.1811e+02, -7.91252488e+00], + [6.061661e+00, -6.061661e+00, -1.41421356e+02]]) + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) + + # H+R+S+T, not reverse, camera + img_meta = { + 'pcd_scale_factor': + 1.0 / 7.07106781e-01, + 'pcd_rotation': [[7.07106781e-01, 0.0, 7.07106781e-01], + [0.0, 1.0e+00, 0.0], + [-7.07106781e-01, 0.0, 7.07106781e-01]], + 'pcd_trans': [1.0e+00, -1.0e+00, 0.0], + 'pcd_horizontal_flip': + True, + 
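A minimal usage sketch of the MLP above with its defaults. The 18-channel default lines up with the cues produced by VoteFusion for num_classes=10, that is 5 geometric + 10 semantic + 3 texture channels per seed-box pair; the batch and point counts below are arbitrary:

import torch

from mmdet3d.models.utils import MLP

# Two Conv1d + BN1d + ReLU layers; shapes are (B, C, N) throughout.
mlp = MLP(in_channel=18, conv_channels=(256, 256))
seed_img_feats = mlp(torch.rand(2, 18, 1024))
assert seed_img_feats.shape == (2, 256, 1024)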
'transformation_3d_flow': ['HF', 'S', 'R', 'T'] + } + + pcd = torch.tensor([[-5.2422e+00, 4.0021e+01, -2.9757e-01], + [-9.1435e+01, -5.5950e+00, 2.6675e+01], + [6.061661e+00, -1.0e+02, -0.0]]) + + pcd_transformed = apply_3d_transformation( + pcd, 'CAMERA', img_meta, reverse=False) + + expected_tensor = torch.tensor( + [[6.53977e+00, 5.55982409e+01, 4.94463e+00], + [6.576e+01, -8.91252488e+00, 1.1811e+02], + [-5.061661e+00, -1.42421356e+02, -6.061661e+00]]) + + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) + + # V, reverse, camera + img_meta = {'pcd_vertical_flip': True, 'transformation_3d_flow': ['VF']} + + pcd_transformed = apply_3d_transformation( + pcd, 'CAMERA', img_meta, reverse=True) + + expected_tensor = torch.tensor([[-5.2422e+00, 4.0021e+01, 2.9757e-01], + [-9.1435e+01, -5.5950e+00, -2.6675e+01], + [6.061661e+00, -1.0e+02, 0.0]]) + + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) + + # V+H, not reverse, depth + img_meta = { + 'pcd_vertical_flip': True, + 'pcd_horizontal_flip': True, + 'transformation_3d_flow': ['VF', 'HF'] + } + + pcd_transformed = apply_3d_transformation( + pcd, 'DEPTH', img_meta, reverse=False) + + expected_tensor = torch.tensor([[5.2422e+00, -4.0021e+01, -2.9757e-01], + [9.1435e+01, 5.5950e+00, 2.6675e+01], + [-6.061661e+00, 1.0e+02, 0.0]]) + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) + + # V+H, reverse, lidar + img_meta = { + 'pcd_vertical_flip': True, + 'pcd_horizontal_flip': True, + 'transformation_3d_flow': ['VF', 'HF'] + } + + pcd_transformed = apply_3d_transformation( + pcd, 'LIDAR', img_meta, reverse=True) + + expected_tensor = torch.tensor([[5.2422e+00, -4.0021e+01, -2.9757e-01], + [9.1435e+01, 5.5950e+00, 2.6675e+01], + [-6.061661e+00, 1.0e+02, 0.0]]) + assert torch.allclose(expected_tensor, pcd_transformed, 1e-4) diff --git a/tests/test_models/test_fusion/test_point_fusion.py b/tests/test_models/test_fusion/test_point_fusion.py new file mode 100644 index 0000000000..31911b0ba1 --- /dev/null +++ b/tests/test_models/test_fusion/test_point_fusion.py @@ -0,0 +1,60 @@ +"""Tests the core function of point fusion. 
+ +CommandLine: + pytest tests/test_models/test_fusion/test_point_fusion.py +""" + +import torch + +from mmdet3d.models.fusion_layers import PointFusion + + +def test_sample_single(): + # this function makes sure the rewriting of 3d coords transformation + # in point fusion does not change the original behaviour + lidar2img = torch.tensor( + [[6.0294e+02, -7.0791e+02, -1.2275e+01, -1.7094e+02], + [1.7678e+02, 8.8088e+00, -7.0794e+02, -1.0257e+02], + [9.9998e-01, -1.5283e-03, -5.2907e-03, -3.2757e-01], + [0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00]]) + + # all use default + img_meta = { + 'transformation_3d_flow': ['R', 'S', 'T', 'HF'], + 'input_shape': [370, 1224], + 'img_shape': [370, 1224], + 'lidar2img': lidar2img, + } + + # dummy parameters + fuse = PointFusion(1, 1, 1, 1) + img_feat = torch.arange(370 * 1224)[None, ...].view( + 370, 1224)[None, None, ...].float() / (370 * 1224) + pts = torch.tensor([[8.356, -4.312, -0.445], [11.777, -6.724, -0.564], + [6.453, 2.53, -1.612], [6.227, -3.839, -0.563]]) + out = fuse.sample_single(img_feat, pts, img_meta) + + expected_tensor = torch.tensor( + [0.5560822, 0.5476625, 0.9687978, 0.6241757]) + assert torch.allclose(expected_tensor, out, 1e-4) + + pcd_rotation = torch.tensor([[8.660254e-01, 0.5, 0], + [-0.5, 8.660254e-01, 0], [0, 0, 1.0e+00]]) + pcd_scale_factor = 1.111 + pcd_trans = torch.tensor([1.0, -1.0, 0.5]) + pts = pts @ pcd_rotation + pts *= pcd_scale_factor + pts += pcd_trans + pts[:, 1] = -pts[:, 1] + + # not use default + img_meta.update({ + 'pcd_scale_factor': pcd_scale_factor, + 'pcd_rotation': pcd_rotation, + 'pcd_trans': pcd_trans, + 'pcd_horizontal_flip': True + }) + out = fuse.sample_single(img_feat, pts, img_meta) + expected_tensor = torch.tensor( + [0.5560822, 0.5476625, 0.9687978, 0.6241757]) + assert torch.allclose(expected_tensor, out, 1e-4) diff --git a/tests/test_models/test_fusion/test_vote_fusion.py b/tests/test_models/test_fusion/test_vote_fusion.py new file mode 100644 index 0000000000..b249310421 --- /dev/null +++ b/tests/test_models/test_fusion/test_vote_fusion.py @@ -0,0 +1,321 @@ +"""Tests the core function of vote fusion. 
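The geometry exercised by the vote fusion test below is the depth to camera to pixel chain used in VoteFusion.forward: undo the recorded augmentation, move into the camera frame with the calibration Rt, then project with K. A minimal sketch using a placeholder calibration (identity extrinsics and a made-up intrinsic matrix), not a real SUN RGB-D sample:

import torch

from mmdet3d.core.bbox import Coord3DMode, points_cam2img
from mmdet3d.models.fusion_layers import apply_3d_transformation

seeds_depth = torch.tensor([[0.5, 2.0, 0.8]])  # (N, 3) seeds in depth coords
img_meta = dict(transformation_3d_flow=[])     # no augmentation recorded

Rt = torch.eye(3)                              # placeholder extrinsics
K = torch.tensor([[529.5, 0., 365.],
                  [0., 529.5, 265.],
                  [0., 0., 1.]])               # placeholder intrinsics

xyz_depth = apply_3d_transformation(seeds_depth, 'DEPTH', img_meta, reverse=True)
xyz_cam = Coord3DMode.convert_point(
    xyz_depth, Coord3DMode.DEPTH, Coord3DMode.CAM, rt_mat=Rt)
uv = points_cam2img(xyz_cam, K)                # (N, 2) pixel coordinates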
+ +CommandLine: + pytest tests/test_models/test_fusion/test_vote_fusion.py +""" + +import torch + +from mmdet3d.models.fusion_layers import VoteFusion + + +def test_vote_fusion(): + img_meta = { + 'ori_shape': (530, 730, 3), + 'img_shape': (600, 826, 3), + 'pad_shape': (608, 832, 3), + 'scale_factor': + torch.tensor([1.1315, 1.1321, 1.1315, 1.1321]), + 'flip': + False, + 'pcd_horizontal_flip': + False, + 'pcd_vertical_flip': + False, + 'pcd_trans': + torch.tensor([0., 0., 0.]), + 'pcd_scale_factor': + 1.0308290128214932, + 'pcd_rotation': + torch.tensor([[0.9747, 0.2234, 0.0000], [-0.2234, 0.9747, 0.0000], + [0.0000, 0.0000, 1.0000]]), + 'transformation_3d_flow': ['HF', 'R', 'S', 'T'] + } + + calibs = { + 'Rt': + torch.tensor([[[0.979570, 0.047954, -0.195330], + [0.047954, 0.887470, 0.458370], + [0.195330, -0.458370, 0.867030]]]), + 'K': + torch.tensor([[[529.5000, 0.0000, 365.0000], + [0.0000, 529.5000, 265.0000], [0.0000, 0.0000, + 1.0000]]]) + } + + bboxes = torch.tensor([[[ + 5.4286e+02, 9.8283e+01, 6.1700e+02, 1.6742e+02, 9.7922e-01, 3.0000e+00 + ], [ + 4.2613e+02, 8.4646e+01, 4.9091e+02, 1.6237e+02, 9.7848e-01, 3.0000e+00 + ], [ + 2.5606e+02, 7.3244e+01, 3.7883e+02, 1.8471e+02, 9.7317e-01, 3.0000e+00 + ], [ + 6.0104e+02, 1.0648e+02, 6.6757e+02, 1.9216e+02, 8.4607e-01, 3.0000e+00 + ], [ + 2.2923e+02, 1.4984e+02, 7.0163e+02, 4.6537e+02, 3.5719e-01, 0.0000e+00 + ], [ + 2.5614e+02, 7.4965e+01, 3.3275e+02, 1.5908e+02, 2.8688e-01, 3.0000e+00 + ], [ + 9.8718e+00, 1.4142e+02, 2.0213e+02, 3.3878e+02, 1.0935e-01, 3.0000e+00 + ], [ + 6.1930e+02, 1.1768e+02, 6.8505e+02, 2.0318e+02, 1.0720e-01, 3.0000e+00 + ]]]) + + seeds_3d = torch.tensor([[[0.044544, 1.675476, -1.531831], + [2.500625, 7.238662, -0.737675], + [-0.600003, 4.827733, -0.084022], + [1.396212, 3.994484, -1.551180], + [-2.054746, 2.012759, -0.357472], + [-0.582477, 6.580470, -1.466052], + [1.313331, 5.722039, 0.123904], + [-1.107057, 3.450359, -1.043422], + [1.759746, 5.655951, -1.519564], + [-0.203003, 6.453243, 0.137703], + [-0.910429, 0.904407, -0.512307], + [0.434049, 3.032374, -0.763842], + [1.438146, 2.289263, -1.546332], + [0.575622, 5.041906, -0.891143], + [-1.675931, 1.417597, -1.588347]]]) + + imgs = torch.linspace( + -1, 1, steps=608 * 832).reshape(1, 608, 832).repeat(3, 1, 1)[None] + + expected_tensor1 = torch.tensor( + [[[ + 0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, + 0.000000e+00, 1.193706e-01, -0.000000e+00, -2.879214e-01, + -0.000000e+00, 0.000000e+00, 1.422463e-01, -6.474612e-01, + -0.000000e+00, 1.490057e-02, 0.000000e+00 + ], + [ + 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, + 0.000000e+00, -1.873745e+00, -0.000000e+00, 1.576240e-01, + 0.000000e+00, -0.000000e+00, -3.646177e-02, -7.751858e-01, + 0.000000e+00, 9.593642e-02, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, -6.263277e-02, 0.000000e+00, -3.646387e-01, + 0.000000e+00, 0.000000e+00, -5.875812e-01, -6.263450e-02, + 0.000000e+00, 1.149264e-01, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 8.899736e-01, 0.000000e+00, 9.019017e-01, + 0.000000e+00, 0.000000e+00, 6.917775e-01, 8.899733e-01, + 0.000000e+00, 9.812444e-01, 0.000000e+00 + ], + [ + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, -4.516903e-01, -0.000000e+00, -2.315422e-01, + -0.000000e+00, -0.000000e+00, -4.197519e-01, -4.516906e-01, + -0.000000e+00, -1.547615e-01, -0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 
0.000000e+00, 3.571937e-01, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 3.571937e-01, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 9.731653e-01, + 0.000000e+00, 0.000000e+00, 1.093455e-01, 0.000000e+00, + 0.000000e+00, 8.460656e-01, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, + -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, + -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, + -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ]]]) + + expected_tensor2 = torch.tensor([[ + False, False, False, False, False, True, False, True, False, False, + True, True, False, True, False, False, False, False, False, False, + False, False, True, False, False, False, False, False, True, False, + False, False, False, False, False, False, False, False, False, False, + False, False, False, True, False + ]]) + + expected_tensor3 = torch.tensor( + [[[ + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + 0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, + -0.000000e+00, -0.000000e+00, 0.000000e+00, -0.000000e+00, + -0.000000e+00, 1.720988e-01, 0.000000e+00 + ], + [ + 
0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, + -0.000000e+00, 0.000000e+00, -0.000000e+00, 0.000000e+00, + 0.000000e+00, -0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 4.824460e-02, 0.000000e+00 + ], + [ + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, -0.000000e+00, -0.000000e+00, 0.000000e+00, + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, 1.447314e-01, -0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 9.759269e-01, 0.000000e+00 + ], + [ + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, -0.000000e+00, -0.000000e+00, -0.000000e+00, + -0.000000e+00, -1.631542e-01, -0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 1.072001e-01, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00, 0.000000e+00, + 0.000000e+00, 0.000000e+00, 0.000000e+00 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, + -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, 
+ -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ], + [ + 2.316288e-03, -1.948284e-03, -3.694394e-03, 2.176163e-04, + -3.882605e-03, -1.901490e-03, -3.355042e-03, -1.774631e-03, + -6.981542e-04, -3.886823e-03, -1.302233e-03, -1.189933e-03, + 2.540967e-03, -1.834944e-03, 1.032048e-03 + ]]]) + + fusion = VoteFusion() + out1, out2 = fusion(imgs, bboxes, seeds_3d, [img_meta], calibs) + assert torch.allclose(expected_tensor1, out1[:, :, :15], 1e-3) + assert torch.allclose(expected_tensor2.float(), out2.float(), 1e-3) + assert torch.allclose(expected_tensor3, out1[:, :, 30:45], 1e-3) + + out1, out2 = fusion(imgs, bboxes[:, :2], seeds_3d, [img_meta], calibs) + out1 = out1[:, :15, 30:45] + out2 = out2[:, 30:45].float() + assert torch.allclose(torch.zeros_like(out1), out1, 1e-3) + assert torch.allclose(torch.zeros_like(out2), out2, 1e-3) diff --git a/tests/test_utils/test_coord_3d_mode.py b/tests/test_utils/test_coord_3d_mode.py index 00ba83ebbf..73d15ff902 100644 --- a/tests/test_utils/test_coord_3d_mode.py +++ b/tests/test_utils/test_coord_3d_mode.py @@ -62,21 +62,21 @@ def test_points_conversion(): convert_depth_points = cam_points.convert_to(Coord3DMode.DEPTH) expected_tensor = torch.tensor([[ - -5.2422e+00, -2.9757e-01, 4.0021e+01, 6.6660e-01, 1.9560e-01, + -5.2422e+00, 2.9757e-01, -4.0021e+01, 6.6660e-01, 1.9560e-01, 4.9740e-01, 9.4090e-01 ], [ - -2.6675e+01, 9.1435e-01, 5.5950e+00, + -2.6675e+01, -9.1435e-01, -5.5950e+00, 1.5020e-01, 3.7070e-01, 1.0860e-01, 6.2970e-01 ], [ - -5.8098e+00, -2.0089e-01, 3.5409e+01, + -5.8098e+00, 2.0089e-01, -3.5409e+01, 6.5650e-01, 6.2480e-01, 6.9540e-01, 2.5380e-01 ], [ - -3.1309e+01, 1.9461e-01, 1.0901e+00, + -3.1309e+01, -1.9461e-01, -1.0901e+00, 2.8030e-01, 2.5800e-02, 4.8960e-01, 3.2690e-01 ]]) @@ -157,21 +157,21 @@ def test_points_conversion(): convert_cam_points = depth_points.convert_to(Coord3DMode.CAM) expected_tensor = torch.tensor([[ - -5.2422e+00, 2.9757e-01, -4.0021e+01, 6.6660e-01, 1.9560e-01, + -5.2422e+00, -2.9757e-01, 4.0021e+01, 6.6660e-01, 1.9560e-01, 4.9740e-01, 9.4090e-01 ], [ - -2.6675e+01, -9.1435e-01, -5.5950e+00, + -2.6675e+01, 9.1435e-01, 5.5950e+00, 1.5020e-01, 3.7070e-01, 1.0860e-01, 6.2970e-01 ], [ - -5.8098e+00, 2.0089e-01, -3.5409e+01, + -5.8098e+00, -2.0089e-01, 3.5409e+01, 6.5650e-01, 6.2480e-01, 6.9540e-01, 2.5380e-01 ], [ - -3.1309e+01, -1.9461e-01, -1.0901e+00, + -3.1309e+01, 1.9461e-01, 1.0901e+00, 2.8030e-01, 2.5800e-02, 4.8960e-01, 3.2690e-01 ]]) @@ -182,6 +182,22 @@ def test_points_conversion(): assert torch.allclose(expected_tensor, convert_cam_points.tensor, 1e-4) assert torch.allclose(cam_point_tensor, convert_cam_points.tensor, 1e-4) + rt_mat_provided = torch.tensor([[0.99789, -0.012698, -0.063678], + [-0.012698, 0.92359, -0.38316], + [0.063678, 0.38316, 0.92148]]) + + depth_points_new = torch.cat([ + depth_points.tensor[:, :3] @ rt_mat_provided.t(), + depth_points.tensor[:, 3:] + ], + dim=1) + cam_point_tensor_new = Coord3DMode.convert_point( + depth_points_new, + Coord3DMode.DEPTH, + Coord3DMode.CAM, + rt_mat=rt_mat_provided) + assert torch.allclose(expected_tensor, cam_point_tensor_new, 1e-4) + convert_lidar_points = depth_points.convert_to(Coord3DMode.LIDAR) expected_tensor = torch.tensor([[ 4.0021e+01, 5.2422e+00, 2.9757e-01, 6.6660e-01, 1.9560e-01, 4.9740e-01, diff --git a/tools/data_converter/sunrgbd_data_utils.py b/tools/data_converter/sunrgbd_data_utils.py index 65187b8444..1149fab5d8 100644 --- a/tools/data_converter/sunrgbd_data_utils.py +++ 
b/tools/data_converter/sunrgbd_data_utils.py @@ -111,8 +111,9 @@ def get_calibration(self, idx): calib_filepath = osp.join(self.calib_dir, f'{idx:06d}.txt') lines = [line.rstrip() for line in open(calib_filepath)] Rt = np.array([float(x) for x in lines[0].split(' ')]) - Rt = np.reshape(Rt, (3, 3), order='F') + Rt = np.reshape(Rt, (3, 3), order='F').astype(np.float32) K = np.array([float(x) for x in lines[1].split(' ')]) + K = np.reshape(K, (3, 3), order='F').astype(np.float32) return K, Rt def get_label_objects(self, idx): @@ -155,8 +156,7 @@ def process_single_scene(sample_idx): osp.join(self.root_dir, 'points', f'{sample_idx:06d}.bin')) info['pts_path'] = osp.join('points', f'{sample_idx:06d}.bin') - img_name = osp.join(self.image_dir, f'{sample_idx:06d}') - img_path = osp.join(self.image_dir, img_name) + img_path = osp.join('image', f'{sample_idx:06d}.jpg') image_info = { 'image_idx': sample_idx, 'image_shape': self.get_image_shape(sample_idx),
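For reference, the calibration files parsed by get_calibration above are expected to consist of two lines of nine space-separated floats: the extrinsic Rt on the first line and the intrinsic K on the second, both flattened column-major (hence order='F'). A hedged sketch with made-up numbers, just to illustrate the parsing:

import numpy as np

calib_text = ('0.98 0.05 0.2 0.05 0.89 -0.46 -0.2 0.46 0.87\n'
              '529.5 0.0 0.0 0.0 529.5 0.0 365.0 265.0 1.0\n')
lines = [line.rstrip() for line in calib_text.splitlines()]
# Same parsing as get_calibration, minus the file I/O.
Rt = np.reshape([float(x) for x in lines[0].split(' ')], (3, 3),
                order='F').astype(np.float32)
K = np.reshape([float(x) for x in lines[1].split(' ')], (3, 3),
               order='F').astype(np.float32)
# K comes out in the usual [[fx, 0, cx], [0, fy, cy], [0, 0, 1]] layout.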