diff --git a/configs/_base_/models/h3dnet.py b/configs/_base_/models/h3dnet.py new file mode 100644 index 0000000000..1b34aca396 --- /dev/null +++ b/configs/_base_/models/h3dnet.py @@ -0,0 +1,332 @@ +primitive_z_cfg = dict( + type='PrimitiveHead', + num_dims=2, + num_classes=18, + primitive_mode='z', + upper_thresh=100.0, + surface_thresh=0.5, + vote_moudule_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=1, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + num_point=1024, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True), + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.4, 0.6], + reduction='mean', + loss_weight=30.0), + center_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=0.5, + loss_dst_weight=0.5), + semantic_reg_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=0.5, + loss_dst_weight=0.5), + semantic_cls_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + train_cfg=dict( + dist_thresh=0.2, + var_thresh=1e-2, + lower_thresh=1e-6, + num_point=100, + num_point_line=10, + line_thresh=0.2)) + +primitive_xy_cfg = dict( + type='PrimitiveHead', + num_dims=1, + num_classes=18, + primitive_mode='xy', + upper_thresh=100.0, + surface_thresh=0.5, + vote_moudule_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=1, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + num_point=1024, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True), + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.4, 0.6], + reduction='mean', + loss_weight=30.0), + center_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=0.5, + loss_dst_weight=0.5), + semantic_reg_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=0.5, + loss_dst_weight=0.5), + semantic_cls_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + train_cfg=dict( + dist_thresh=0.2, + var_thresh=1e-2, + lower_thresh=1e-6, + num_point=100, + num_point_line=10, + line_thresh=0.2)) + +primitive_line_cfg = dict( + type='PrimitiveHead', + num_dims=0, + num_classes=18, + primitive_mode='line', + upper_thresh=100.0, + surface_thresh=0.5, + vote_moudule_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=1, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + num_point=1024, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True), + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.4, 0.6], + reduction='mean', 
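+            # class_weight weighs the two objectness classes (negative, positive)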
+ loss_weight=30.0), + center_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=1.0, + loss_dst_weight=1.0), + semantic_reg_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='sum', + loss_src_weight=1.0, + loss_dst_weight=1.0), + semantic_cls_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=2.0), + train_cfg=dict( + dist_thresh=0.2, + var_thresh=1e-2, + lower_thresh=1e-6, + num_point=100, + num_point_line=10, + line_thresh=0.2)) + +proposal_module_cfg = dict( + suface_matching_cfg=dict( + num_point=256 * 6, + radius=0.5, + num_sample=32, + mlp_channels=[128 + 6, 128, 64, 32], + use_xyz=True, + normalize_xyz=True), + line_matching_cfg=dict( + num_point=256 * 12, + radius=0.5, + num_sample=32, + mlp_channels=[128 + 12, 128, 64, 32], + use_xyz=True, + normalize_xyz=True), + primitive_refine_channels=[128, 128, 128], + upper_thresh=100.0, + surface_thresh=0.5, + line_thresh=0.5, + train_cfg=dict( + far_threshold=0.6, + near_threshold=0.3, + mask_surface_threshold=0.3, + label_surface_threshold=0.3, + mask_line_threshold=0.3, + label_line_threshold=0.3), + cues_objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.3, 0.7], + reduction='mean', + loss_weight=5.0), + cues_semantic_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.3, 0.7], + reduction='mean', + loss_weight=5.0), + proposal_objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.2, 0.8], + reduction='none', + loss_weight=5.0), + primitive_center_loss=dict( + type='MSELoss', reduction='none', loss_weight=1.0)) + +model = dict( + type='H3DNet', + backbone=dict( + type='MultiBackbone', + num_streams=4, + suffixes=['net0', 'net1', 'net2', 'net3'], + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01), + act_cfg=dict(type='ReLU'), + backbones=dict( + type='PointNet2SASSG', + in_channels=4, + num_points=(2048, 1024, 512, 256), + radius=(0.2, 0.4, 0.8, 1.2), + num_samples=(64, 32, 16, 16), + sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), + (128, 128, 256)), + fp_channels=((256, 256), (256, 256)), + norm_cfg=dict(type='BN2d'), + pool_mod='max')), + rpn_head=dict( + type='VoteHead', + vote_moudule_cfg=dict( + in_channels=256, + vote_per_seed=1, + gt_per_seed=3, + conv_channels=(256, 256), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + norm_feats=True, + vote_loss=dict( + type='ChamferDistance', + mode='l1', + reduction='none', + loss_dst_weight=10.0)), + vote_aggregation_cfg=dict( + num_point=256, + radius=0.3, + num_sample=16, + mlp_channels=[256, 128, 128, 128], + use_xyz=True, + normalize_xyz=True), + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.2, 0.8], + reduction='sum', + loss_weight=5.0), + center_loss=dict( + type='ChamferDistance', + mode='l2', + reduction='sum', + loss_src_weight=10.0, + loss_dst_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + roi_head=dict( + type='H3DRoIHead', + primitive_list=[primitive_z_cfg, primitive_xy_cfg, primitive_line_cfg], + bbox_head=dict( + 
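# second-stage head: refines the RPN proposals using primitive cues +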
type='H3DBboxHead', + gt_per_seed=3, + num_proposal=256, + # H3DBboxHead takes suface_matching_cfg, line_matching_cfg and the + # cue losses as direct keyword arguments, so unpack the shared config + **proposal_module_cfg, + feat_channels=(128, 128), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=dict( + type='CrossEntropyLoss', + class_weight=[0.2, 0.8], + reduction='sum', + loss_weight=5.0), + center_loss=dict( + type='ChamferDistance', + mode='l2', + reduction='sum', + loss_src_weight=10.0, + loss_dst_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=0.1), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=0.1), + size_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=0.1)))) + +# model training and testing settings +train_cfg = dict( + rpn=dict(pos_distance_thr=0.3, neg_distance_thr=0.6, sample_mod='vote'), + rpn_proposal=dict(use_nms=False), + rcnn=dict( + pos_distance_thr=0.3, + neg_distance_thr=0.6, + sample_mod='vote', + # matching thresholds read by H3DBboxHead.get_targets_single + far_threshold=0.6, + near_threshold=0.3, + mask_surface_threshold=0.3, + label_surface_threshold=0.3, + mask_line_threshold=0.3, + label_line_threshold=0.3)) + +test_cfg = dict( + rpn=dict( + sample_mod='seed', + nms_thr=0.25, + score_thr=0.05, + per_class_proposal=True, + use_nms=False), + rcnn=dict( + sample_mod='seed', + nms_thr=0.25, + score_thr=0.05, + per_class_proposal=True)) diff --git a/configs/h3dnet/README.md b/configs/h3dnet/README.md new file mode 100644 index 0000000000..0c084e0619 --- /dev/null +++ b/configs/h3dnet/README.md @@ -0,0 +1,19 @@ +# H3DNet: 3D Object Detection Using Hybrid Geometric Primitives + +## Introduction +We implement H3DNet and provide the results and checkpoints on the ScanNet dataset. +``` +@inproceedings{zhang2020h3dnet, + author = {Zhang, Zaiwei and Sun, Bo and Yang, Haitao and Huang, Qixing}, + title = {H3DNet: 3D Object Detection Using Hybrid Geometric Primitives}, + booktitle = {Proceedings of the European Conference on Computer Vision}, + year = {2020} +} +``` + +## Results + +### ScanNet +| Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download | | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | | [MultiBackbone](./h3dnet_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238-2cea9c3a.pth) | [log](https://openmmlab.oss-accelerate.aliyuncs.com/mmdetection3d/v0.1.0_models/votenet/votenet_8x8_scannet-3d-18class/votenet_8x8_scannet-3d-18class_20200620_230238.log.json)|
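+ +## Usage + +Training and evaluation follow the standard MMDetection3D entry points (a sketch; adjust the GPU count and checkpoint path to your setup): + +``` +./tools/dist_train.sh configs/h3dnet/h3dnet_scannet-3d-18class.py 8 +python tools/test.py configs/h3dnet/h3dnet_scannet-3d-18class.py ${CHECKPOINT_FILE} --eval mAP +```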
diff --git a/configs/h3dnet/h3dnet_scannet-3d-18class.py b/configs/h3dnet/h3dnet_scannet-3d-18class.py new file mode 100644 index 0000000000..c90abb7803 --- /dev/null +++ b/configs/h3dnet/h3dnet_scannet-3d-18class.py @@ -0,0 +1,70 @@ +_base_ = [ + '../_base_/datasets/scannet-3d-18class.py', '../_base_/models/h3dnet.py', + '../_base_/schedules/schedule_3x.py', '../_base_/default_runtime.py' +] + +# model settings +model = dict( + rpn_head=dict( + num_classes=18, + bbox_coder=dict( + type='PartialBinBasedBBoxCoder', + num_sizes=18, + num_dir_bins=24, + with_rot=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + [1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]])), + roi_head=dict( + bbox_head=dict( + num_classes=18, + bbox_coder=dict( + type='PartialBinBasedBBoxCoder', + num_sizes=18, + num_dir_bins=24, + with_rot=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + [1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]])))) + +data = dict(samples_per_gpu=3, workers_per_gpu=2) + +# yapf:disable +log_config = dict( + interval=30, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable diff --git a/mmdet3d/core/bbox/coders/partial_bin_based_bbox_coder.py b/mmdet3d/core/bbox/coders/partial_bin_based_bbox_coder.py index 724635383f..5edeaad115 100644 --- a/mmdet3d/core/bbox/coders/partial_bin_based_bbox_coder.py +++ b/mmdet3d/core/bbox/coders/partial_bin_based_bbox_coder.py @@ -55,7 +55,7 @@ def encode(self, gt_bboxes_3d, gt_labels_3d): return (center_target, size_class_target, size_res_target, dir_class_target, dir_res_target) - def decode(self, bbox_out): + def decode(self, bbox_out, suffix=''): """Decode predicted parts to bbox3d. Args: @@ -66,17 +66,18 @@ def decode(self, bbox_out): - dir_res: predicted bbox direction residual. - size_class: predicted bbox size class. - size_res: predicted bbox size residual. + suffix (str): Suffix appended to the prediction keys before decoding. Returns: torch.Tensor: Decoded bbox3d with shape (batch, n, 7).
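+ + Note: + Keys are looked up as ``bbox_out['center' + suffix]``, so passing + ``suffix='_optimized'`` decodes the refined second-stage + predictions of H3DNet.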
""" - center = bbox_out['center'] + center = bbox_out['center' + suffix] batch_size, num_proposal = center.shape[:2] # decode heading angle if self.with_rot: - dir_class = torch.argmax(bbox_out['dir_class'], -1) - dir_res = torch.gather(bbox_out['dir_res'], 2, + dir_class = torch.argmax(bbox_out['dir_class' + suffix], -1) + dir_res = torch.gather(bbox_out['dir_res' + suffix], 2, dir_class.unsqueeze(-1)) dir_res.squeeze_(2) dir_angle = self.class2angle(dir_class, dir_res).reshape( @@ -85,8 +86,9 @@ def decode(self, bbox_out): dir_angle = center.new_zeros(batch_size, num_proposal, 1) # decode bbox size - size_class = torch.argmax(bbox_out['size_class'], -1, keepdim=True) - size_res = torch.gather(bbox_out['size_res'], 2, + size_class = torch.argmax( + bbox_out['size_class' + suffix], -1, keepdim=True) + size_res = torch.gather(bbox_out['size_res' + suffix], 2, size_class.unsqueeze(-1).repeat(1, 1, 1, 3)) mean_sizes = center.new_tensor(self.mean_sizes) size_base = torch.index_select(mean_sizes, 0, size_class.reshape(-1)) diff --git a/mmdet3d/core/bbox/structures/depth_box3d.py b/mmdet3d/core/bbox/structures/depth_box3d.py index 7e5cdbb75b..5d675cd411 100644 --- a/mmdet3d/core/bbox/structures/depth_box3d.py +++ b/mmdet3d/core/bbox/structures/depth_box3d.py @@ -251,3 +251,52 @@ def points_in_boxes(self, points): box_idxs_of_pts = points_in_boxes_batch(points_lidar, boxes_lidar) return box_idxs_of_pts.squeeze(0) + + def get_surface_line_center(self): + """Compute surface and line center of bounding boxes. + + Returns: + torch.Tensor: Surface and line center of bounding boxes. + """ + obj_size = self.dims + center = self.gravity_center + batch_size = center.shape[0] + + rot_sin = torch.sin(-self.yaw) + rot_cos = torch.cos(-self.yaw) + rot_mat_T = self.yaw.new_zeros(tuple(list(self.yaw.shape) + [3, 3])) + rot_mat_T[..., 0, 0] = rot_cos + rot_mat_T[..., 0, 1] = -rot_sin + rot_mat_T[..., 1, 0] = rot_sin + rot_mat_T[..., 1, 1] = rot_cos + rot_mat_T[..., 2, 2] = 1 + + # Get the object surface center + offset = obj_size.new_tensor([[0, 0, 1], [0, 0, -1], [0, 1, 0], + [0, -1, 0], [1, 0, 0], [-1, 0, 0]]) + offset = offset.view(1, 6, 3) / 2 + surface_3d = (offset * obj_size.view(batch_size, 1, 3).repeat( + 1, 6, 1)).transpose(0, 1).reshape(-1, 3) + + # Get the object line center + offset = obj_size.new_tensor([[1, 0, 1], [-1, 0, 1], [0, 1, 1], + [0, -1, 1], [1, 0, -1], [-1, 0, -1], + [0, 1, -1], [0, -1, -1], [1, 1, 0], + [1, -1, 0], [-1, 1, 0], [-1, -1, 0]]) + offset = offset.view(1, 12, 3) / 2 + + line_3d = (offset * + obj_size.view(batch_size, 1, 3).repeat(1, 12, 1)).transpose( + 0, 1).reshape(-1, 3) + + surface_rot = rot_mat_T.repeat(6, 1, 1) + surface_3d = torch.matmul( + surface_3d.unsqueeze(-2), surface_rot.transpose(2, 1)).squeeze(-2) + surface_center = center.repeat(6, 1) + surface_3d + + line_rot = rot_mat_T.repeat(12, 1, 1) + line_3d = torch.matmul( + line_3d.unsqueeze(-2), line_rot.transpose(2, 1)).squeeze(-2) + line_center = center.repeat(12, 1) + line_3d + + return surface_center, line_center diff --git a/mmdet3d/models/dense_heads/vote_head.py b/mmdet3d/models/dense_heads/vote_head.py index f9ecc063d1..c37917b86b 100644 --- a/mmdet3d/models/dense_heads/vote_head.py +++ b/mmdet3d/models/dense_heads/vote_head.py @@ -164,6 +164,7 @@ def forward(self, feat_dict, sample_mod): sample_indices) aggregated_points, features, aggregated_indices = vote_aggregation_ret results['aggregated_points'] = aggregated_points + results['aggregated_features'] = features results['aggregated_indices'] = 
aggregated_indices # 3. predict bbox and score @@ -183,7 +184,8 @@ def loss(self, pts_semantic_mask=None, pts_instance_mask=None, img_metas=None, - gt_bboxes_ignore=None): + gt_bboxes_ignore=None, + ret_target=False): """Compute loss. Args: @@ -199,6 +201,7 @@ def loss(self, img_metas (list[dict]): Contain pcd and img's meta info. gt_bboxes_ignore (None | list[torch.Tensor]): Specify which bounding. + ret_target (bool): Whether to return targets. Returns: dict: Losses of Votenet. @@ -283,6 +286,10 @@ def loss(self, dir_res_loss=dir_res_loss, size_class_loss=size_class_loss, size_res_loss=size_res_loss) + + if ret_target: + losses['targets'] = targets + return losses def get_targets(self, @@ -494,7 +501,12 @@ def get_targets_single(self, dir_class_targets, dir_res_targets, center_targets, mask_targets.long(), objectness_targets, objectness_masks) - def get_bboxes(self, points, bbox_preds, input_metas, rescale=False): + def get_bboxes(self, + points, + bbox_preds, + input_metas, + rescale=False, + use_nms=True): """Generate bboxes from vote head predictions. Args: @@ -502,6 +514,8 @@ def get_bboxes(self, points, bbox_preds, input_metas, rescale=False): bbox_preds (dict): Predictions from vote head. input_metas (list[dict]): Point cloud and image's meta info. rescale (bool): Whether to rescale bboxes. + use_nms (bool): Whether to apply NMS; set to False to skip the + NMS post-processing when the vote head serves as the RPN. Returns: list[tuple[torch.Tensor]]: Bounding boxes, scores and labels. @@ -511,19 +525,23 @@ sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1) bbox3d = self.bbox_coder.decode(bbox_preds) - batch_size = bbox3d.shape[0] - results = list() - for b in range(batch_size): - bbox_selected, score_selected, labels = self.multiclass_nms_single( - obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3], - input_metas[b]) - bbox = input_metas[b]['box_type_3d']( - bbox_selected, - box_dim=bbox_selected.shape[-1], - with_yaw=self.bbox_coder.with_rot) - results.append((bbox, score_selected, labels)) - - return results + if use_nms: + batch_size = bbox3d.shape[0] + results = list() + for b in range(batch_size): + bbox_selected, score_selected, labels = \ + self.multiclass_nms_single(obj_scores[b], sem_scores[b], + bbox3d[b], points[b, ..., :3], + input_metas[b]) + bbox = input_metas[b]['box_type_3d']( + bbox_selected, + box_dim=bbox_selected.shape[-1], + with_yaw=self.bbox_coder.with_rot) + results.append((bbox, score_selected, labels)) + + return results + else: + return bbox3d def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, input_meta): diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index 8abf95f2eb..c0995a8b1d 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -1,5 +1,6 @@ from .base import Base3DDetector from .dynamic_voxelnet import DynamicVoxelNet +from .h3dnet import H3DNet from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN from .mvx_two_stage import MVXTwoStageDetector from .parta2 import PartA2 @@ -8,5 +9,5 @@ __all__ = [ 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector', - 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet' + 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet' ] diff --git a/mmdet3d/models/detectors/h3dnet.py b/mmdet3d/models/detectors/h3dnet.py new file mode 100644 index 0000000000..831f7c3491 --- /dev/null +++ 
b/mmdet3d/models/detectors/h3dnet.py @@ -0,0 +1,173 @@ +import torch + +from mmdet3d.core import merge_aug_bboxes_3d +from mmdet.models import DETECTORS +from .two_stage import TwoStage3DDetector + + +@DETECTORS.register_module() +class H3DNet(TwoStage3DDetector): + r"""H3DNet model. + + Please refer to the `paper <https://arxiv.org/abs/2006.05682>`_ + """ + + def __init__(self, + backbone, + neck=None, + rpn_head=None, + roi_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(H3DNet, self).__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained) + + def forward_train(self, + points, + img_metas, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + gt_bboxes_ignore=None): + """Forward of training. + + Args: + points (list[torch.Tensor]): Points of each batch. + img_metas (list): Image metas. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): gt bboxes of each batch. + gt_labels_3d (list[torch.Tensor]): gt class labels of each batch. + pts_semantic_mask (None | list[torch.Tensor]): point-wise semantic + label of each batch. + pts_instance_mask (None | list[torch.Tensor]): point-wise instance + label of each batch. + gt_bboxes_ignore (None | list[torch.Tensor]): Specify + which bounding. + + Returns: + dict: Losses. + """ + points_cat = torch.stack(points) + + feats_dict = self.extract_feat(points_cat) + feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]] + feats_dict['fp_features'] = [feats_dict['hd_feature']] + feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]] + + losses = dict() + if self.with_rpn: + rpn_outs = self.rpn_head(feats_dict, self.train_cfg.rpn.sample_mod) + feats_dict.update(rpn_outs) + + rpn_loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask, img_metas) + rpn_losses = self.rpn_head.loss( + rpn_outs, + *rpn_loss_inputs, + gt_bboxes_ignore=gt_bboxes_ignore, + ret_target=True) + feats_dict['targets'] = rpn_losses.pop('targets') + losses.update(rpn_losses) + + # Generate rpn proposals + proposal_cfg = self.train_cfg.get('rpn_proposal', + self.test_cfg.rpn) + proposal_inputs = (points, rpn_outs, img_metas) + proposal_list = self.rpn_head.get_bboxes( + *proposal_inputs, use_nms=proposal_cfg.use_nms) + feats_dict['proposal_list'] = proposal_list + else: + raise NotImplementedError + + roi_losses = self.roi_head.forward_train(feats_dict, img_metas, points, + gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, + pts_instance_mask, + gt_bboxes_ignore) + losses.update(roi_losses) + + return losses + + def simple_test(self, points, img_metas, imgs=None, rescale=False): + """Forward of testing. + + Args: + points (list[torch.Tensor]): Points of each sample. + img_metas (list): Image metas. + rescale (bool): Whether to rescale results. + + Returns: + list: Predicted 3d boxes.
+ """ + points_cat = torch.stack(points) + + feats_dict = self.extract_feat(points_cat) + feats_dict['fp_xyz'] = [feats_dict['fp_xyz_net0'][-1]] + feats_dict['fp_features'] = [feats_dict['hd_feature']] + feats_dict['fp_indices'] = [feats_dict['fp_indices_net0'][-1]] + + if self.with_rpn: + proposal_cfg = self.test_cfg.rpn + rpn_outs = self.rpn_head(feats_dict, proposal_cfg.sample_mod) + feats_dict.update(rpn_outs) + # Generate rpn proposals + proposal_list = self.rpn_head.get_bboxes( + points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms) + feats_dict['proposal_list'] = proposal_list + else: + raise NotImplementedError + + return self.roi_head.simple_test( + feats_dict, img_metas, points_cat, rescale=rescale) + + def aug_test(self, points, img_metas, imgs=None, rescale=False): + """Test with augmentation.""" + points_cat = [torch.stack(pts) for pts in points] + feats_dict = self.extract_feats(points_cat, img_metas) + for feat_dict in feats_dict: + feat_dict['fp_xyz'] = [feat_dict['fp_xyz_net0'][-1]] + feat_dict['fp_features'] = [feat_dict['hd_feature']] + feat_dict['fp_indices'] = [feat_dict['fp_indices_net0'][-1]] + + # only support aug_test for one sample + aug_bboxes = [] + for feat_dict, pts_cat, img_meta in zip(feats_dict, points_cat, + img_metas): + if self.with_rpn: + proposal_cfg = self.test_cfg.rpn + rpn_outs = self.rpn_head(feat_dict, proposal_cfg.sample_mod) + feat_dict.update(rpn_outs) + # Generate rpn proposals + proposal_list = self.rpn_head.get_bboxes( + points, rpn_outs, img_metas, use_nms=proposal_cfg.use_nms) + feat_dict['proposal_list'] = proposal_list + else: + raise NotImplementedError + + bbox_results = self.roi_head.simple_test( + feat_dict, + self.test_cfg.rcnn.sample_mod, + img_meta, + pts_cat, + rescale=rescale) + aug_bboxes.append(bbox_results) + + # after merging, bboxes will be rescaled to the original image size + merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas, + self.bbox_head.test_cfg) + + return merged_bboxes + + def extract_feats(self, points, img_metas): + """Extract features of multiple samples.""" + return [ + self.extract_feat(pts, img_meta) + for pts, img_meta in zip(points, img_metas) + ] diff --git a/mmdet3d/models/roi_heads/__init__.py b/mmdet3d/models/roi_heads/__init__.py index be1f856e35..c93a3e3481 100644 --- a/mmdet3d/models/roi_heads/__init__.py +++ b/mmdet3d/models/roi_heads/__init__.py @@ -1,10 +1,12 @@ from .base_3droi_head import Base3DRoIHead from .bbox_heads import PartA2BboxHead -from .mask_heads import PointwiseSemanticHead +from .h3d_roi_head import H3DRoIHead +from .mask_heads import PointwiseSemanticHead, PrimitiveHead from .part_aggregation_roi_head import PartAggregationROIHead from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor __all__ = [ 'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead', - 'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor' + 'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor', + 'H3DRoIHead', 'PrimitiveHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py index 0da15dbda6..0256706a0c 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py @@ -2,9 +2,11 @@ DoubleConvFCBBoxHead, Shared2FCBBoxHead, Shared4Conv1FCBBoxHead) +from .h3d_bbox_head import H3DBboxHead from .parta2_bbox_head import PartA2BboxHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', - 'Shared4Conv1FCBBoxHead', 
'DoubleConvFCBBoxHead', 'PartA2BboxHead' + 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'H3DBboxHead', + 'PartA2BboxHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/h3d_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/h3d_bbox_head.py new file mode 100644 index 0000000000..d28242440a --- /dev/null +++ b/mmdet3d/models/roi_heads/bbox_heads/h3d_bbox_head.py @@ -0,0 +1,931 @@ +import torch +from mmcv.cnn import ConvModule +from torch import nn as nn +from torch.nn import functional as F + +from mmdet3d.core.bbox import DepthInstance3DBoxes +from mmdet3d.core.post_processing import aligned_3d_nms +from mmdet3d.models.builder import build_loss +from mmdet3d.models.losses import chamfer_distance +from mmdet3d.ops import PointSAModule +from mmdet.core import build_bbox_coder, multi_apply +from mmdet.models import HEADS + + +@HEADS.register_module() +class H3DBboxHead(nn.Module): + r"""Bbox head of `H3DNet <https://arxiv.org/abs/2006.05682>`_. + + Args: + num_classes (int): The number of classes. + suface_matching_cfg (dict): Config for surface primitive matching. + line_matching_cfg (dict): Config for line primitive matching. + bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and + decoding boxes. + train_cfg (dict): Config for training. + test_cfg (dict): Config for testing. + gt_per_seed (int): Number of ground truth votes generated + from each seed point. + num_proposal (int): Number of proposal votes generated. + feat_channels (tuple[int]): Convolution channels of + prediction layer. + primitive_feat_refine_streams (int): The number of mlps to + refine primitive feature. + primitive_refine_channels (tuple[int]): Convolution channels of + prediction layer. + upper_thresh (float): Threshold for line matching. + surface_thresh (float): Threshold for surface matching. + line_thresh (float): Threshold for line matching. + conv_cfg (dict): Config of convolution in prediction layer. + norm_cfg (dict): Config of BN in prediction layer. + objectness_loss (dict): Config of objectness loss. + center_loss (dict): Config of center loss. + dir_class_loss (dict): Config of direction classification loss. + dir_res_loss (dict): Config of direction residual regression loss. + size_class_loss (dict): Config of size classification loss. + size_res_loss (dict): Config of size residual regression loss. + semantic_loss (dict): Config of point-wise semantic segmentation loss. + cues_objectness_loss (dict): Config of cues objectness loss. + cues_semantic_loss (dict): Config of cues semantic loss. + proposal_objectness_loss (dict): Config of proposal objectness + loss. + primitive_center_loss (dict): Config of primitive center regression + loss.
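+ + Note: + The head matches the surface/line centers predicted by the + primitive heads against the surface/line centers of the + first-stage proposals (via ``PointSAModule`` grouping) and emits + refined predictions under keys carrying the ``_optimized`` suffix.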
+ """ + + def __init__(self, + num_classes, + suface_matching_cfg, + line_matching_cfg, + bbox_coder, + train_cfg=None, + test_cfg=None, + proposal_module_cfg=None, + gt_per_seed=1, + num_proposal=256, + feat_channels=(128, 128), + primitive_feat_refine_streams=2, + primitive_refine_channels=[128, 128, 128], + upper_thresh=100.0, + surface_thresh=0.5, + line_thresh=0.5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + objectness_loss=None, + center_loss=None, + dir_class_loss=None, + dir_res_loss=None, + size_class_loss=None, + size_res_loss=None, + semantic_loss=None, + cues_objectness_loss=None, + cues_semantic_loss=None, + proposal_objectness_loss=None, + primitive_center_loss=None): + super(H3DBboxHead, self).__init__() + self.num_classes = num_classes + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.gt_per_seed = gt_per_seed + self.num_proposal = num_proposal + self.with_angle = bbox_coder['with_rot'] + self.upper_thresh = upper_thresh + self.surface_thresh = surface_thresh + self.line_thresh = line_thresh + + self.objectness_loss = build_loss(objectness_loss) + self.center_loss = build_loss(center_loss) + self.dir_class_loss = build_loss(dir_class_loss) + self.dir_res_loss = build_loss(dir_res_loss) + self.size_class_loss = build_loss(size_class_loss) + self.size_res_loss = build_loss(size_res_loss) + self.semantic_loss = build_loss(semantic_loss) + + self.bbox_coder = build_bbox_coder(bbox_coder) + self.num_sizes = self.bbox_coder.num_sizes + self.num_dir_bins = self.bbox_coder.num_dir_bins + + self.cues_objectness_loss = build_loss(cues_objectness_loss) + self.cues_semantic_loss = build_loss(cues_semantic_loss) + self.proposal_objectness_loss = build_loss(proposal_objectness_loss) + self.primitive_center_loss = build_loss(primitive_center_loss) + + assert suface_matching_cfg['mlp_channels'][-1] == \ + line_matching_cfg['mlp_channels'][-1] + + # surface center matching + self.surface_center_matcher = PointSAModule(**suface_matching_cfg) + # line center matching + self.line_center_matcher = PointSAModule(**line_matching_cfg) + + # Compute the matching scores + matching_feat_dims = suface_matching_cfg['mlp_channels'][-1] + self.matching_conv = ConvModule( + matching_feat_dims, + matching_feat_dims, + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True) + self.matching_pred = nn.Conv1d(matching_feat_dims, 2, 1) + + # Compute the semantic matching scores + self.semantic_matching_conv = ConvModule( + matching_feat_dims, + matching_feat_dims, + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True) + self.semantic_matching_pred = nn.Conv1d(matching_feat_dims, 2, 1) + + # Surface feature aggregation + self.surface_feats_aggregation = list() + for k in range(primitive_feat_refine_streams): + self.surface_feats_aggregation.append( + ConvModule( + matching_feat_dims, + matching_feat_dims, + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True)) + self.surface_feats_aggregation = nn.Sequential( + *self.surface_feats_aggregation) + + # Line feature aggregation + self.line_feats_aggregation = list() + for k in range(primitive_feat_refine_streams): + self.line_feats_aggregation.append( + ConvModule( + matching_feat_dims, + matching_feat_dims, + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=True)) + self.line_feats_aggregation = nn.Sequential( + *self.line_feats_aggregation) + + # surface center(6) + line center(12) + prev_channel = 18 * 
matching_feat_dims + self.bbox_pred = nn.ModuleList() + for k in range(len(primitive_refine_channels)): + self.bbox_pred.append( + ConvModule( + prev_channel, + primitive_refine_channels[k], + 1, + padding=0, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + bias=True, + inplace=False)) + prev_channel = primitive_refine_channels[k] + + # Final object detection + # Objectness scores (2), center residual (3), + # heading class+residual (num_heading_bin*2), size class + + # residual(num_size_cluster*4) + conv_out_channel = (2 + 3 + bbox_coder['num_dir_bins'] * 2 + + bbox_coder['num_sizes'] * 4 + self.num_classes) + self.bbox_pred.append(nn.Conv1d(prev_channel, conv_out_channel, 1)) + + def init_weights(self, pretrained=None): + """Initialize the weights in detector. + + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + pass + + def forward(self, feats_dict, sample_mod): + """Forward pass. + + Args: + feats_dict (dict): Feature dict from backbone. + sample_mod (str): Sample mode for vote aggregation layer. + valid modes are "vote", "seed" and "random". + + Returns: + dict: Predictions of vote head. + """ + ret_dict = {} + aggregated_points = feats_dict['aggregated_points'] + original_feature = feats_dict['aggregated_features'] + batch_size = original_feature.shape[0] + object_proposal = original_feature.shape[2] + + # Extract surface center, features and semantic predictions + z_center = feats_dict['pred_z_center'] + xy_center = feats_dict['pred_xy_center'] + z_semantic = feats_dict['sem_cls_scores_z'] + xy_semantic = feats_dict['sem_cls_scores_xy'] + z_feature = feats_dict['aggregated_features_z'] + xy_feature = feats_dict['aggregated_features_xy'] + # Extract line points and features + line_center = feats_dict['pred_line_center'] + line_feature = feats_dict['aggregated_features_line'] + + surface_center_pred = torch.cat((z_center, xy_center), dim=1) + ret_dict['surface_center_pred'] = surface_center_pred + ret_dict['surface_sem_pred'] = torch.cat((z_semantic, xy_semantic), + dim=1) + + # Extract the surface and line centers of rpn proposals + rpn_proposals = feats_dict['proposal_list'] + rpn_proposals_bbox = DepthInstance3DBoxes( + rpn_proposals.reshape(-1, 7).clone(), + box_dim=rpn_proposals.shape[-1], + with_yaw=self.with_angle, + origin=(0.5, 0.5, 0.5)) + + obj_surface_center, obj_line_center = \ + rpn_proposals_bbox.get_surface_line_center() + obj_surface_center = obj_surface_center.reshape( + batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3) + obj_line_center = obj_line_center.reshape(batch_size, -1, 12, + 3).transpose(1, 2).reshape( + batch_size, -1, 3) + ret_dict['surface_center_object'] = obj_surface_center + ret_dict['line_center_object'] = obj_line_center + + # aggregate primitive z and xy features to rpn proposals + surface_center_feature_pred = torch.cat((z_feature, xy_feature), dim=2) + surface_center_feature_pred = torch.cat( + (surface_center_feature_pred.new_zeros( + (batch_size, 6, surface_center_feature_pred.shape[2])), + surface_center_feature_pred), + dim=1) + + surface_xyz, surface_features, _ = self.surface_center_matcher( + surface_center_pred, + surface_center_feature_pred, + target_xyz=obj_surface_center) + + # aggregate primitive line features to rpn proposals + line_feature = torch.cat((line_feature.new_zeros( + (batch_size, 12, line_feature.shape[2])), line_feature), + dim=1) + line_xyz, line_features, _ = self.line_center_matcher( + line_center, line_feature, target_xyz=obj_line_center) + + # combine the surface 
and line features + combine_features = torch.cat((surface_features, line_features), dim=2) + + matching_features = self.matching_conv(combine_features) + matching_score = self.matching_pred(matching_features) + ret_dict['matching_score'] = matching_score.transpose(2, 1) + + semantic_matching_features = self.semantic_matching_conv( + combine_features) + semantic_matching_score = self.semantic_matching_pred( + semantic_matching_features) + ret_dict['semantic_matching_score'] = \ + semantic_matching_score.transpose(2, 1) + + surface_features = self.surface_feats_aggregation(surface_features) + line_features = self.line_feats_aggregation(line_features) + + # Combine all surface and line features + surface_features = surface_features.view(batch_size, -1, + object_proposal) + line_features = line_features.view(batch_size, -1, object_proposal) + + combine_feature = torch.cat((surface_features, line_features), dim=1) + + # Final bbox predictions + bbox_predictions = self.bbox_pred[0](combine_feature) + bbox_predictions += original_feature + for conv_module in self.bbox_pred[1:]: + bbox_predictions = conv_module(bbox_predictions) + + refine_decode_res = self.bbox_coder.split_pred(bbox_predictions, + aggregated_points) + for key in refine_decode_res.keys(): + ret_dict[key + '_optimized'] = refine_decode_res[key] + return ret_dict + + def loss(self, + bbox_preds, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + img_metas=None, + rpn_targets=None, + gt_bboxes_ignore=None): + """Compute loss. + + Args: + bbox_preds (dict): Predictions from forward of h3d bbox head. + points (list[torch.Tensor]): Input points. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each sample. + gt_labels_3d (list[torch.Tensor]): Labels of each sample. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise + semantic mask. + pts_instance_mask (None | list[torch.Tensor]): Point-wise + instance mask. + img_metas (list[dict]): Contain pcd and img's meta info. + rpn_targets (Tuple) : Targets generated by rpn head. + gt_bboxes_ignore (None | list[torch.Tensor]): Specify + which bounding. + + Returns: + dict: Losses of H3dnet. 
+ """ + (vote_targets, vote_target_masks, size_class_targets, size_res_targets, + dir_class_targets, dir_res_targets, center_targets, mask_targets, + valid_gt_masks, objectness_targets, objectness_weights, + box_loss_weights, valid_gt_weights) = rpn_targets + + losses = {} + + # calculate refined proposal loss + refined_proposal_loss = self.get_proposal_stage_loss( + bbox_preds, + size_class_targets, + size_res_targets, + dir_class_targets, + dir_res_targets, + center_targets, + mask_targets, + objectness_targets, + objectness_weights, + box_loss_weights, + valid_gt_weights, + suffix='_optimized') + for key in refined_proposal_loss.keys(): + losses[key + '_optimized'] = refined_proposal_loss[key] + + bbox3d_optimized = self.bbox_coder.decode( + bbox_preds, suffix='_optimized') + + targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask, + bbox_preds) + + (cues_objectness_label, cues_sem_label, proposal_objectness_label, + cues_mask, cues_match_mask, proposal_objectness_mask, + cues_matching_label, obj_surface_line_center) = targets + + # match scores for each geometric primitive + objectness_scores = bbox_preds['matching_score'] + # match scores for the semantics of primitives + objectness_scores_sem = bbox_preds['semantic_matching_score'] + + primitive_objectness_loss = self.cues_objectness_loss( + objectness_scores.transpose(2, 1), + cues_objectness_label, + weight=cues_mask, + avg_factor=cues_mask.sum() + 1e-6) + + primitive_sem_loss = self.cues_semantic_loss( + objectness_scores_sem.transpose(2, 1), + cues_sem_label, + weight=cues_mask, + avg_factor=cues_mask.sum() + 1e-6) + + objectness_scores = bbox_preds['obj_scores_optimized'] + objectness_loss_refine = self.proposal_objectness_loss( + objectness_scores.transpose(2, 1), proposal_objectness_label) + primitive_matching_loss = (objectness_loss_refine * + cues_match_mask).sum() / ( + cues_match_mask.sum() + 1e-6) * 0.5 + primitive_sem_matching_loss = ( + objectness_loss_refine * proposal_objectness_mask).sum() / ( + proposal_objectness_mask.sum() + 1e-6) * 0.5 + + # Get the object surface center here + batch_size, object_proposal = bbox3d_optimized.shape[:2] + refined_bbox = DepthInstance3DBoxes( + bbox3d_optimized.reshape(-1, 7).clone(), + box_dim=bbox3d_optimized.shape[-1], + with_yaw=self.with_angle, + origin=(0.5, 0.5, 0.5)) + + pred_obj_surface_center, pred_obj_line_center = \ + refined_bbox.get_surface_line_center() + pred_obj_surface_center = pred_obj_surface_center.reshape( + batch_size, -1, 6, 3).transpose(1, 2).reshape(batch_size, -1, 3) + pred_obj_line_center = pred_obj_line_center.reshape( + batch_size, -1, 12, 3).transpose(1, 2).reshape(batch_size, -1, 3) + pred_surface_line_center = torch.cat( + (pred_obj_surface_center, pred_obj_line_center), 1) + + square_dist = self.primitive_center_loss(pred_surface_line_center, + obj_surface_line_center) + + match_dist = torch.sqrt(square_dist.sum(dim=-1) + 1e-6) + primitive_centroid_reg_loss = torch.sum( + match_dist * cues_matching_label) / ( + cues_matching_label.sum() + 1e-6) + + refined_loss = dict( + primitive_objectness_loss=primitive_objectness_loss, + primitive_sem_loss=primitive_sem_loss, + primitive_matching_loss=primitive_matching_loss, + primitive_sem_matching_loss=primitive_sem_matching_loss, + primitive_centroid_reg_loss=primitive_centroid_reg_loss) + + losses.update(refined_loss) + + return losses + + def get_bboxes(self, + points, + bbox_preds, + input_metas, + rescale=False, + suffix=''): + """Generate bboxes from vote 
head predictions. + + Args: + points (torch.Tensor): Input points. + bbox_preds (dict): Predictions from vote head. + input_metas (list[dict]): Point cloud and image's meta info. + rescale (bool): Whether to rescale bboxes. + + Returns: + list[tuple[torch.Tensor]]: Bounding boxes, scores and labels. + """ + # decode boxes + obj_scores = F.softmax( + bbox_preds['obj_scores' + suffix], dim=-1)[..., -1] + + sem_scores = F.softmax(bbox_preds['sem_scores'], dim=-1) + + prediction_collection = {} + prediction_collection['center'] = bbox_preds['center' + suffix] + prediction_collection['dir_class'] = bbox_preds['dir_class'] + prediction_collection['dir_res'] = bbox_preds['dir_res' + suffix] + prediction_collection['size_class'] = bbox_preds['size_class'] + prediction_collection['size_res'] = bbox_preds['size_res' + suffix] + + bbox3d = self.bbox_coder.decode(prediction_collection) + + batch_size = bbox3d.shape[0] + results = list() + for b in range(batch_size): + bbox_selected, score_selected, labels = self.multiclass_nms_single( + obj_scores[b], sem_scores[b], bbox3d[b], points[b, ..., :3], + input_metas[b]) + bbox = input_metas[b]['box_type_3d']( + bbox_selected, + box_dim=bbox_selected.shape[-1], + with_yaw=self.bbox_coder.with_rot) + results.append((bbox, score_selected, labels)) + + return results + + def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, + input_meta): + """Multi-class nms in single batch. + + Args: + obj_scores (torch.Tensor): Objectness score of bounding boxes. + sem_scores (torch.Tensor): semantic class score of bounding boxes. + bbox (torch.Tensor): Predicted bounding boxes. + points (torch.Tensor): Input points. + input_meta (dict): Point cloud and image's meta info. + + Returns: + tuple[torch.Tensor]: Bounding boxes, scores and labels. 
+ """ + bbox = input_meta['box_type_3d']( + bbox, + box_dim=bbox.shape[-1], + with_yaw=self.bbox_coder.with_rot, + origin=(0.5, 0.5, 0.5)) + box_indices = bbox.points_in_boxes(points) + + corner3d = bbox.corners + minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6))) + minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0] + minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0] + + nonempty_box_mask = box_indices.T.sum(1) > 5 + + bbox_classes = torch.argmax(sem_scores, -1) + nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask], + obj_scores[nonempty_box_mask], + bbox_classes[nonempty_box_mask], + self.test_cfg.nms_thr) + + # filter empty boxes and boxes with low score + scores_mask = (obj_scores > self.test_cfg.score_thr) + nonempty_box_inds = torch.nonzero(nonempty_box_mask).flatten() + nonempty_mask = torch.zeros_like(bbox_classes).scatter( + 0, nonempty_box_inds[nms_selected], 1) + selected = (nonempty_mask.bool() & scores_mask.bool()) + + if self.test_cfg.per_class_proposal: + bbox_selected, score_selected, labels = [], [], [] + for k in range(sem_scores.shape[-1]): + bbox_selected.append(bbox[selected].tensor) + score_selected.append(obj_scores[selected] * + sem_scores[selected][:, k]) + labels.append( + torch.zeros_like(bbox_classes[selected]).fill_(k)) + bbox_selected = torch.cat(bbox_selected, 0) + score_selected = torch.cat(score_selected, 0) + labels = torch.cat(labels, 0) + else: + bbox_selected = bbox[selected].tensor + score_selected = obj_scores[selected] + labels = bbox_classes[selected] + + return bbox_selected, score_selected, labels + + def get_proposal_stage_loss(self, + bbox_preds, + size_class_targets, + size_res_targets, + dir_class_targets, + dir_res_targets, + center_targets, + mask_targets, + objectness_targets, + objectness_weights, + box_loss_weights, + valid_gt_weights, + suffix=''): + """Compute loss for the aggregation module. + + Args: + bbox_preds (dict): Predictions from forward of vote head. + size_class_targets (torch.Tensor): Ground truth \ + size class of each prediction bounding box. + size_res_targets (torch.Tensor): Ground truth \ + size residual of each prediction bounding box. + dir_class_targets (torch.Tensor): Ground truth \ + direction class of each prediction bounding box. + dir_res_targets (torch.Tensor): Ground truth \ + direction residual of each prediction bounding box. + center_targets (torch.Tensor): Ground truth center \ + of each prediction bounding box. + mask_targets (torch.Tensor): Validation of each \ + prediction bounding box. + objectness_targets (torch.Tensor): Ground truth \ + objectness label of each prediction bounding box. + objectness_weights (torch.Tensor): Weights of objectness \ + loss for each prediction bounding box. + box_loss_weights (torch.Tensor): Weights of regression \ + loss for each prediction bounding box. + valid_gt_weights (torch.Tensor): Validation of each \ + ground truth bounding box. + + Returns: + dict: Losses of aggregation module. 
+ """ + # calculate objectness loss + objectness_loss = self.objectness_loss( + bbox_preds['obj_scores' + suffix].transpose(2, 1), + objectness_targets, + weight=objectness_weights) + + # calculate center loss + source2target_loss, target2source_loss = self.center_loss( + bbox_preds['center' + suffix], + center_targets, + src_weight=box_loss_weights, + dst_weight=valid_gt_weights) + center_loss = source2target_loss + target2source_loss + + # calculate direction class loss + dir_class_loss = self.dir_class_loss( + bbox_preds['dir_class' + suffix].transpose(2, 1), + dir_class_targets, + weight=box_loss_weights) + + # calculate direction residual loss + batch_size, proposal_num = size_class_targets.shape[:2] + heading_label_one_hot = dir_class_targets.new_zeros( + (batch_size, proposal_num, self.num_dir_bins)) + heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), 1) + dir_res_norm = (bbox_preds['dir_res_norm' + suffix] * + heading_label_one_hot).sum(dim=-1) + dir_res_loss = self.dir_res_loss( + dir_res_norm, dir_res_targets, weight=box_loss_weights) + + # calculate size class loss + size_class_loss = self.size_class_loss( + bbox_preds['size_class' + suffix].transpose(2, 1), + size_class_targets, + weight=box_loss_weights) + + # calculate size residual loss + one_hot_size_targets = box_loss_weights.new_zeros( + (batch_size, proposal_num, self.num_sizes)) + one_hot_size_targets.scatter_(2, size_class_targets.unsqueeze(-1), 1) + one_hot_size_targets_expand = one_hot_size_targets.unsqueeze( + -1).repeat(1, 1, 1, 3) + size_residual_norm = (bbox_preds['size_res_norm' + suffix] * + one_hot_size_targets_expand).sum(dim=2) + box_loss_weights_expand = box_loss_weights.unsqueeze(-1).repeat( + 1, 1, 3) + size_res_loss = self.size_res_loss( + size_residual_norm, + size_res_targets, + weight=box_loss_weights_expand) + + # calculate semantic loss + semantic_loss = self.semantic_loss( + bbox_preds['sem_scores' + suffix].transpose(2, 1), + mask_targets, + weight=box_loss_weights) + + losses = dict( + objectness_loss=objectness_loss, + semantic_loss=semantic_loss, + center_loss=center_loss, + dir_class_loss=dir_class_loss, + dir_res_loss=dir_res_loss, + size_class_loss=size_class_loss, + size_res_loss=size_res_loss) + + return losses + + def get_targets(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + bbox_preds=None): + """Generate targets of proposal module. + + Args: + points (list[torch.Tensor]): Points of each batch. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each batch. + gt_labels_3d (list[torch.Tensor]): Labels of each batch. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic + label of each batch. + pts_instance_mask (None | list[torch.Tensor]): Point-wise instance + label of each batch. + bbox_preds (torch.Tensor): Bounding box predictions of vote head. + + Returns: + tuple[torch.Tensor]: Targets of proposal module. 
+ """ + # find empty example + valid_gt_masks = list() + gt_num = list() + for index in range(len(gt_labels_3d)): + if len(gt_labels_3d[index]) == 0: + fake_box = gt_bboxes_3d[index].tensor.new_zeros( + 1, gt_bboxes_3d[index].tensor.shape[-1]) + gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) + gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) + valid_gt_masks.append(gt_labels_3d[index].new_zeros(1)) + gt_num.append(1) + else: + valid_gt_masks.append(gt_labels_3d[index].new_ones( + gt_labels_3d[index].shape)) + gt_num.append(gt_labels_3d[index].shape[0]) + + if pts_semantic_mask is None: + pts_semantic_mask = [None for i in range(len(gt_labels_3d))] + pts_instance_mask = [None for i in range(len(gt_labels_3d))] + + aggregated_points = [ + bbox_preds['aggregated_points'][i] + for i in range(len(gt_labels_3d)) + ] + + surface_center_pred = [ + bbox_preds['surface_center_pred'][i] + for i in range(len(gt_labels_3d)) + ] + + line_center_pred = [ + bbox_preds['pred_line_center'][i] + for i in range(len(gt_labels_3d)) + ] + + surface_center_object = [ + bbox_preds['surface_center_object'][i] + for i in range(len(gt_labels_3d)) + ] + + line_center_object = [ + bbox_preds['line_center_object'][i] + for i in range(len(gt_labels_3d)) + ] + + surface_sem_pred = [ + bbox_preds['surface_sem_pred'][i] + for i in range(len(gt_labels_3d)) + ] + + line_sem_pred = [ + bbox_preds['sem_cls_scores_line'][i] + for i in range(len(gt_labels_3d)) + ] + + (cues_objectness_label, cues_sem_label, proposal_objectness_label, + cues_mask, cues_match_mask, proposal_objectness_mask, + cues_matching_label, obj_surface_line_center) = multi_apply( + self.get_targets_single, points, gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask, aggregated_points, + surface_center_pred, line_center_pred, surface_center_object, + line_center_object, surface_sem_pred, line_sem_pred) + + cues_objectness_label = torch.stack(cues_objectness_label) + cues_sem_label = torch.stack(cues_sem_label) + proposal_objectness_label = torch.stack(proposal_objectness_label) + cues_mask = torch.stack(cues_mask) + cues_match_mask = torch.stack(cues_match_mask) + proposal_objectness_mask = torch.stack(proposal_objectness_mask) + cues_matching_label = torch.stack(cues_matching_label) + obj_surface_line_center = torch.stack(obj_surface_line_center) + + return (cues_objectness_label, cues_sem_label, + proposal_objectness_label, cues_mask, cues_match_mask, + proposal_objectness_mask, cues_matching_label, + obj_surface_line_center) + + def get_targets_single(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + aggregated_points=None, + pred_surface_center=None, + pred_line_center=None, + pred_obj_surface_center=None, + pred_obj_line_center=None, + pred_surface_sem=None, + pred_line_sem=None): + """Generate targets for primitive cues for single batch. + + Args: + points (torch.Tensor): Points of each batch. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \ + boxes of each batch. + gt_labels_3d (torch.Tensor): Labels of each batch. + pts_semantic_mask (None | torch.Tensor): Point-wise semantic + label of each batch. + pts_instance_mask (None | torch.Tensor): Point-wise instance + label of each batch. + aggregated_points (torch.Tensor): Aggregated points from + vote aggregation layer. + pred_surface_center (torch.Tensor): Prediction of surface center. + pred_line_center (torch.Tensor): Prediction of line center. 
+ pred_obj_surface_center (torch.Tensor): Objectness prediction \ + of surface center. + pred_obj_line_center (torch.Tensor): Objectness prediction of \ + line center. + pred_surface_sem (torch.Tensor): Semantic prediction of \ + surface center. + pred_line_sem (torch.Tensor): Semantic prediction of line center. + Returns: + tuple[torch.Tensor]: Targets for primitive cues. + """ + device = points.device + gt_bboxes_3d = gt_bboxes_3d.to(device) + num_proposals = aggregated_points.shape[0] + gt_center = gt_bboxes_3d.gravity_center + + dist1, dist2, ind1, _ = chamfer_distance( + aggregated_points.unsqueeze(0), + gt_center.unsqueeze(0), + reduction='none') + # Set assignment + object_assignment = ind1.squeeze(0) + + # Generate objectness label and mask + # objectness_label: 1 if pred object center is within + # self.train_cfg['near_threshold'] of any GT object + # objectness_mask: 0 if pred object center is in gray + # zone (DONOTCARE), 1 otherwise + euclidean_dist1 = torch.sqrt(dist1.squeeze(0) + 1e-6) + proposal_objectness_label = euclidean_dist1.new_zeros( + num_proposals, dtype=torch.long) + proposal_objectness_mask = euclidean_dist1.new_zeros(num_proposals) + + gt_sem = gt_labels_3d[object_assignment] + + obj_surface_center, obj_line_center = \ + gt_bboxes_3d.get_surface_line_center() + obj_surface_center = obj_surface_center.reshape(-1, 6, + 3).transpose(0, 1) + obj_line_center = obj_line_center.reshape(-1, 12, 3).transpose(0, 1) + obj_surface_center = obj_surface_center[:, object_assignment].reshape( + 1, -1, 3) + obj_line_center = obj_line_center[:, + object_assignment].reshape(1, -1, 3) + + surface_sem = torch.argmax(pred_surface_sem, dim=1).float() + line_sem = torch.argmax(pred_line_sem, dim=1).float() + + dist_surface, _, surface_ind, _ = chamfer_distance( + obj_surface_center, + pred_surface_center.unsqueeze(0), + reduction='none') + dist_line, _, line_ind, _ = chamfer_distance( + obj_line_center, pred_line_center.unsqueeze(0), reduction='none') + + surface_sel = pred_surface_center[surface_ind.squeeze(0)] + line_sel = pred_line_center[line_ind.squeeze(0)] + surface_sel_sem = surface_sem[surface_ind.squeeze(0)] + line_sel_sem = line_sem[line_ind.squeeze(0)] + + surface_sel_sem_gt = gt_sem.repeat(6).float() + line_sel_sem_gt = gt_sem.repeat(12).float() + + euclidean_dist_surface = torch.sqrt(dist_surface.squeeze(0) + 1e-6) + euclidean_dist_line = torch.sqrt(dist_line.squeeze(0) + 1e-6) + objectness_label_surface = euclidean_dist_line.new_zeros( + num_proposals * 6, dtype=torch.long) + objectness_mask_surface = euclidean_dist_line.new_zeros(num_proposals * + 6) + objectness_label_line = euclidean_dist_line.new_zeros( + num_proposals * 12, dtype=torch.long) + objectness_mask_line = euclidean_dist_line.new_zeros(num_proposals * + 12) + objectness_label_surface_sem = euclidean_dist_line.new_zeros( + num_proposals * 6, dtype=torch.long) + objectness_label_line_sem = euclidean_dist_line.new_zeros( + num_proposals * 12, dtype=torch.long) + + euclidean_dist_obj_surface = torch.sqrt(( + (pred_obj_surface_center - surface_sel)**2).sum(dim=-1) + 1e-6) + euclidean_dist_obj_line = torch.sqrt( + torch.sum((pred_obj_line_center - line_sel)**2, dim=-1) + 1e-6) + + # Objectness score just with centers + proposal_objectness_label[ + euclidean_dist1 < self.train_cfg['near_threshold']] = 1 + proposal_objectness_mask[ + euclidean_dist1 < self.train_cfg['near_threshold']] = 1 + proposal_objectness_mask[ + euclidean_dist1 > self.train_cfg['far_threshold']] = 1 + + objectness_label_surface[ + 
(euclidean_dist_obj_surface < + self.train_cfg['label_surface_threshold']) * + (euclidean_dist_surface < + self.train_cfg['mask_surface_threshold'])] = 1 + objectness_label_surface_sem[ + (euclidean_dist_obj_surface < + self.train_cfg['label_surface_threshold']) * + (euclidean_dist_surface < self.train_cfg['mask_surface_threshold']) + * (surface_sel_sem == surface_sel_sem_gt)] = 1 + + objectness_label_line[ + (euclidean_dist_obj_line < self.train_cfg['label_line_threshold']) + * + (euclidean_dist_line < self.train_cfg['mask_line_threshold'])] = 1 + objectness_label_line_sem[ + (euclidean_dist_obj_line < self.train_cfg['label_line_threshold']) + * (euclidean_dist_line < self.train_cfg['mask_line_threshold']) * + (line_sel_sem == line_sel_sem_gt)] = 1 + + objectness_label_surface_obj = proposal_objectness_label.repeat(6) + objectness_mask_surface_obj = proposal_objectness_mask.repeat(6) + objectness_label_line_obj = proposal_objectness_label.repeat(12) + objectness_mask_line_obj = proposal_objectness_mask.repeat(12) + + objectness_mask_surface = objectness_mask_surface_obj + objectness_mask_line = objectness_mask_line_obj + + cues_objectness_label = torch.cat( + (objectness_label_surface, objectness_label_line), 0) + cues_sem_label = torch.cat( + (objectness_label_surface_sem, objectness_label_line_sem), 0) + cues_mask = torch.cat((objectness_mask_surface, objectness_mask_line), + 0) + + objectness_label_surface *= objectness_label_surface_obj + objectness_label_line *= objectness_label_line_obj + cues_matching_label = torch.cat( + (objectness_label_surface, objectness_label_line), 0) + + objectness_label_surface_sem *= objectness_label_surface_obj + objectness_label_line_sem *= objectness_label_line_obj + + cues_match_mask = (torch.sum( + cues_objectness_label.view(18, num_proposals), dim=0) >= + 1).float() + + obj_surface_line_center = torch.cat( + (obj_surface_center, obj_line_center), 1).squeeze(0) + + return (cues_objectness_label, cues_sem_label, + proposal_objectness_label, cues_mask, cues_match_mask, + proposal_objectness_mask, cues_matching_label, + obj_surface_line_center) diff --git a/mmdet3d/models/roi_heads/h3d_roi_head.py b/mmdet3d/models/roi_heads/h3d_roi_head.py new file mode 100644 index 0000000000..a2aada2dd7 --- /dev/null +++ b/mmdet3d/models/roi_heads/h3d_roi_head.py @@ -0,0 +1,158 @@ +from mmdet3d.core.bbox import bbox3d2result +from mmdet.models import HEADS +from ..builder import build_head +from .base_3droi_head import Base3DRoIHead + + +@HEADS.register_module() +class H3DRoIHead(Base3DRoIHead): + """H3D roi head for H3DNet. + + Args: + primitive_list (List): Configs of primitive heads. + bbox_head (ConfigDict): Config of bbox_head. + train_cfg (ConfigDict): Training config. + test_cfg (ConfigDict): Testing config. 
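+ + Note: + The three primitive heads predict z-aligned surface, xy-aligned + surface and line cues respectively; ``bbox_head`` then matches + these cues against the first-stage proposals to refine them.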
+ """ + + def __init__(self, + primitive_list, + bbox_head=None, + train_cfg=None, + test_cfg=None): + super(H3DRoIHead, self).__init__( + bbox_head=bbox_head, train_cfg=train_cfg, test_cfg=test_cfg) + # Primitive module + assert len(primitive_list) == 3 + self.primitive_z = build_head(primitive_list[0]) + self.primitive_xy = build_head(primitive_list[1]) + self.primitive_line = build_head(primitive_list[2]) + + def init_weights(self, pretrained): + """Initialize weights, skip since ``H3DROIHead`` does not need to + initialize weights.""" + pass + + def init_mask_head(self): + """Initialize mask head, skip since ``H3DROIHead`` does not have + one.""" + pass + + def init_bbox_head(self, bbox_head): + """Initialize box head.""" + bbox_head['train_cfg'] = self.train_cfg + bbox_head['test_cfg'] = self.test_cfg + self.bbox_head = build_head(bbox_head) + + def init_assigner_sampler(self): + """Initialize assigner and sampler.""" + pass + + def forward_train(self, + feats_dict, + img_metas, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask, + pts_instance_mask, + gt_bboxes_ignore=None): + """Training forward function of PartAggregationROIHead. + + Args: + feats_dict (dict): Contains features from the first stage. + img_metas (list[dict]): Contain pcd and img's meta info. + points (list[torch.Tensor]): Input points. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each sample. + gt_labels_3d (list[torch.Tensor]): Labels of each sample. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise + semantic mask. + pts_instance_mask (None | list[torch.Tensor]): Point-wise + instance mask. + gt_bboxes_ignore (None | list[torch.Tensor]): Specify + which bounding. + + Returns: + dict: losses from each head. + """ + losses = dict() + + sample_mod = self.train_cfg.sample_mod + assert sample_mod in ['vote', 'seed', 'random'] + result_z = self.primitive_z(feats_dict, sample_mod) + feats_dict.update(result_z) + + result_xy = self.primitive_xy(feats_dict, sample_mod) + feats_dict.update(result_xy) + + result_line = self.primitive_line(feats_dict, sample_mod) + feats_dict.update(result_line) + + primitive_loss_inputs = (feats_dict, points, gt_bboxes_3d, + gt_labels_3d, pts_semantic_mask, + pts_instance_mask, img_metas, + gt_bboxes_ignore) + + loss_z = self.primitive_z.loss(*primitive_loss_inputs) + losses.update(loss_z) + + loss_xy = self.primitive_xy.loss(*primitive_loss_inputs) + losses.update(loss_xy) + + loss_line = self.primitive_line.loss(*primitive_loss_inputs) + losses.update(loss_line) + + targets = feats_dict.pop('targets') + + bbox_results = self.bbox_head(feats_dict, sample_mod) + + feats_dict.update(bbox_results) + bbox_loss = self.bbox_head.loss(feats_dict, points, gt_bboxes_3d, + gt_labels_3d, pts_semantic_mask, + pts_instance_mask, img_metas, targets, + gt_bboxes_ignore) + losses.update(bbox_loss) + + return losses + + def simple_test(self, feats_dict, img_metas, points, rescale=False): + """Simple testing forward function of PartAggregationROIHead. + + Note: + This function assumes that the batch size is 1 + + Args: + feats_dict (dict): Contains features from the first stage. + img_metas (list[dict]): Contain pcd and img's meta info. + points (torch.Tensor): Input points. + rescale (bool): Whether to rescale results. + + Returns: + dict: Bbox results of one frame. 
+ """ + sample_mod = self.test.sample_mod + assert sample_mod in ['vote', 'seed', 'random'] + + result_z = self.primitive_z(feats_dict, sample_mod) + feats_dict.update(result_z) + + result_xy = self.primitive_xy(feats_dict, sample_mod) + feats_dict.update(result_xy) + + result_line = self.primitive_line(feats_dict, sample_mod) + feats_dict.update(result_line) + + bbox_preds = self.bbox_head(feats_dict, sample_mod) + feats_dict.update(bbox_preds) + bbox_list = self.bbox_head.get_bboxes( + points, + feats_dict, + img_metas, + rescale=rescale, + suffix='_optimized') + bbox_results = [ + bbox3d2result(bboxes, scores, labels) + for bboxes, scores, labels in bbox_list + ] + return bbox_results[0] diff --git a/tests/test_box3d.py b/tests/test_box3d.py index 4369ec8123..5076b4331d 100644 --- a/tests/test_box3d.py +++ b/tests/test_box3d.py @@ -1135,6 +1135,54 @@ def test_depth_boxes3d(): dtype=torch.int32) assert torch.all(box_idxs_of_pts == expected_idxs_of_pts) + # test get_surface_line_center + boxes = torch.tensor( + [[0.3294, 1.0359, 0.1171, 1.0822, 1.1247, 1.3721, 0.4916], + [-2.4630, -2.6324, -0.1616, 0.9202, 1.7896, 0.1992, 0.3185]]) + boxes = DepthInstance3DBoxes( + boxes, box_dim=boxes.shape[-1], with_yaw=True, origin=(0.5, 0.5, 0.5)) + surface_center, line_center = boxes.get_surface_line_center() + expected_surface_center = torch.tensor([[0.3294, 1.0359, 0.8031], + [-2.4630, -2.6324, -0.0620], + [0.3294, 1.0359, -0.5689], + [-2.4630, -2.6324, -0.2612], + [0.5949, 1.5317, 0.1171], + [-2.1828, -1.7826, -0.1616], + [0.0640, 0.5401, 0.1171], + [-2.7432, -3.4822, -0.1616], + [0.8064, 0.7805, 0.1171], + [-2.0260, -2.7765, -0.1616], + [-0.1476, 1.2913, 0.1171], + [-2.9000, -2.4883, -0.1616]]) + + expected_line_center = torch.tensor([[0.8064, 0.7805, 0.8031], + [-2.0260, -2.7765, -0.0620], + [-0.1476, 1.2913, 0.8031], + [-2.9000, -2.4883, -0.0620], + [0.5949, 1.5317, 0.8031], + [-2.1828, -1.7826, -0.0620], + [0.0640, 0.5401, 0.8031], + [-2.7432, -3.4822, -0.0620], + [0.8064, 0.7805, -0.5689], + [-2.0260, -2.7765, -0.2612], + [-0.1476, 1.2913, -0.5689], + [-2.9000, -2.4883, -0.2612], + [0.5949, 1.5317, -0.5689], + [-2.1828, -1.7826, -0.2612], + [0.0640, 0.5401, -0.5689], + [-2.7432, -3.4822, -0.2612], + [1.0719, 1.2762, 0.1171], + [-1.7458, -1.9267, -0.1616], + [0.5410, 0.2847, 0.1171], + [-2.3062, -3.6263, -0.1616], + [0.1178, 1.7871, 0.1171], + [-2.6198, -1.6385, -0.1616], + [-0.4131, 0.7956, 0.1171], + [-3.1802, -3.3381, -0.1616]]) + + assert torch.allclose(surface_center, expected_surface_center, atol=1e-04) + assert torch.allclose(line_center, expected_line_center, atol=1e-04) + def test_rotation_3d_in_axis(): points = torch.tensor([[[-0.4599, -0.0471, 0.0000], diff --git a/tests/test_heads.py b/tests/test_heads.py index 637ca5751f..28232f6396 100644 --- a/tests/test_heads.py +++ b/tests/test_heads.py @@ -569,3 +569,92 @@ def test_primitive_head(): with pytest.raises(AssertionError): primitive_head_cfg['vote_moudule_cfg']['in_channels'] = 'xyz' build_head(primitive_head_cfg) + + +def test_h3d_head(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + _setup_seed(0) + + h3d_head_cfg = _get_roi_head_cfg('h3dnet/h3dnet_8x8_scannet-3d-18class.py') + self = build_head(h3d_head_cfg).cuda() + + # prepare roi outputs + fp_xyz = [torch.rand([1, 1024, 3], dtype=torch.float32).cuda()] + hd_features = torch.rand([1, 256, 1024], dtype=torch.float32).cuda() + fp_indices = [torch.randint(0, 128, [1, 1024]).cuda()] + aggregated_points = torch.rand([1, 256, 3], 
+    aggregated_points = torch.rand([1, 256, 3],
+                                   dtype=torch.float32).cuda()
+    aggregated_features = torch.rand([1, 128, 256],
+                                     dtype=torch.float32).cuda()
+    rpn_proposals = torch.cat([
+        torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4 - 2,
+        torch.rand([1, 256, 3], dtype=torch.float32).cuda() * 4,
+        torch.zeros([1, 256, 1]).cuda()
+    ],
+                              dim=-1)
+
+    input_dict = dict(
+        fp_xyz_net0=fp_xyz,
+        hd_feature=hd_features,
+        aggregated_points=aggregated_points,
+        aggregated_features=aggregated_features,
+        seed_points=fp_xyz[0],
+        seed_indices=fp_indices[0],
+        rpn_proposals=rpn_proposals)
+
+    # prepare gt label
+    from mmdet3d.core.bbox import DepthInstance3DBoxes
+    gt_bboxes_3d = [
+        DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda()),
+        DepthInstance3DBoxes(torch.rand([4, 7], dtype=torch.float32).cuda())
+    ]
+    gt_labels_3d = torch.randint(0, 18, [1, 4]).cuda()
+    gt_labels_3d = [gt_labels_3d[0]]
+    pts_semantic_mask = torch.randint(0, 19, [1, 1024]).cuda()
+    pts_semantic_mask = [pts_semantic_mask[0]]
+    pts_instance_mask = torch.randint(0, 4, [1, 1024]).cuda()
+    pts_instance_mask = [pts_instance_mask[0]]
+    points = torch.rand([1, 1024, 3], dtype=torch.float32).cuda()
+
+    # prepare rpn targets
+    vote_targets = torch.rand([1, 1024, 9], dtype=torch.float32).cuda()
+    vote_target_masks = torch.rand([1, 1024], dtype=torch.float32).cuda()
+    size_class_targets = torch.rand([1, 256],
+                                    dtype=torch.float32).cuda().long()
+    size_res_targets = torch.rand([1, 256, 3], dtype=torch.float32).cuda()
+    dir_class_targets = torch.rand([1, 256], dtype=torch.float32).cuda().long()
+    dir_res_targets = torch.rand([1, 256], dtype=torch.float32).cuda()
+    center_targets = torch.rand([1, 4, 3], dtype=torch.float32).cuda()
+    mask_targets = torch.rand([1, 256], dtype=torch.float32).cuda().long()
+    valid_gt_masks = torch.rand([1, 4], dtype=torch.float32).cuda()
+    objectness_targets = torch.rand([1, 256],
+                                    dtype=torch.float32).cuda().long()
+    objectness_weights = torch.rand([1, 256], dtype=torch.float32).cuda()
+    box_loss_weights = torch.rand([1, 256], dtype=torch.float32).cuda()
+    valid_gt_weights = torch.rand([1, 4], dtype=torch.float32).cuda()
+
+    targets = (vote_targets, vote_target_masks, size_class_targets,
+               size_res_targets, dir_class_targets, dir_res_targets,
+               center_targets, mask_targets, valid_gt_masks,
+               objectness_targets, objectness_weights, box_loss_weights,
+               valid_gt_weights)
+
+    input_dict['targets'] = targets
+
+    # train forward; the sample mode is read from ``train_cfg.sample_mod``
+    # inside ``forward_train``, so it is not passed as a positional argument
+    ret_dict = self.forward_train(
+        input_dict,
+        img_metas=None,
+        points=points,
+        gt_bboxes_3d=gt_bboxes_3d,
+        gt_labels_3d=gt_labels_3d,
+        pts_semantic_mask=pts_semantic_mask,
+        pts_instance_mask=pts_instance_mask)
+
+    assert ret_dict['flag_loss_z'] >= 0
+    assert ret_dict['vote_loss_z'] >= 0
+    assert ret_dict['center_loss_z'] >= 0
+    assert ret_dict['size_loss_z'] >= 0
+    assert ret_dict['sem_loss_z'] >= 0
+    assert ret_dict['objectness_loss_opt'] >= 0
+    assert ret_dict['primitive_sem_matching_loss'] >= 0
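A closing note on the geometry these tests and targets rely on: every box contributes 6 surface (face) centers and 12 line (edge) centers, which is why the cue labels above are tiled by 6 and 12 and viewed as ``(18, num_proposals)``. The sketch below is a minimal illustration of the yaw-free case; ``surface_line_centers_no_yaw`` is a hypothetical helper, not part of mmdet3d, and the library's ``get_surface_line_center`` additionally rotates the offsets about the z-axis by the box yaw and orders its outputs differently.

import torch


def surface_line_centers_no_yaw(center, dims):
    """Face and edge centers of one axis-aligned box (illustrative only).

    Args:
        center (torch.Tensor): Gravity center (x, y, z), shape (3,).
        dims (torch.Tensor): Box sizes (dx, dy, dz), shape (3,).

    Returns:
        tuple[torch.Tensor]: Face centers of shape (6, 3) and edge
            centers of shape (12, 3).
    """
    hx, hy, hz = (dims / 2).tolist()
    # 6 faces: top, bottom and the four vertical sides.
    faces = [[0, 0, hz], [0, 0, -hz], [0, hy, 0], [0, -hy, 0],
             [hx, 0, 0], [-hx, 0, 0]]
    # 12 edges: 4 around the top face, 4 around the bottom face and the
    # 4 vertical edges at mid height (ordering here is illustrative).
    edges = [[hx, 0, hz], [-hx, 0, hz], [0, hy, hz], [0, -hy, hz],
             [hx, 0, -hz], [-hx, 0, -hz], [0, hy, -hz], [0, -hy, -hz],
             [hx, hy, 0], [hx, -hy, 0], [-hx, hy, 0], [-hx, -hy, 0]]
    offsets = torch.tensor(faces + edges, dtype=center.dtype)
    centers = center + offsets
    return centers[:6], centers[6:]


box = torch.tensor([0.3294, 1.0359, 0.1171, 1.0822, 1.1247, 1.3721])
faces, edges = surface_line_centers_no_yaw(box[:3], box[3:])
# faces[0] -> tensor([0.3294, 1.0359, 0.8031]): 0.1171 + 1.3721 / 2, the
# top-face center, matching the first row of expected_surface_center above;
# the +-z face centers are invariant to the yaw rotation about the z-axis,
# so they agree exactly even for the rotated test boxes.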