diff --git a/configs/_base_/models/groupfree3d.py b/configs/_base_/models/groupfree3d.py new file mode 100644 index 0000000000..077d049662 --- /dev/null +++ b/configs/_base_/models/groupfree3d.py @@ -0,0 +1,71 @@ +model = dict( + type='GroupFree3DNet', + backbone=dict( + type='PointNet2SASSG', + in_channels=3, + num_points=(2048, 1024, 512, 256), + radius=(0.2, 0.4, 0.8, 1.2), + num_samples=(64, 32, 16, 16), + sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256), + (128, 128, 256)), + fp_channels=((256, 256), (256, 288)), + norm_cfg=dict(type='BN2d'), + sa_cfg=dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True)), + bbox_head=dict( + type='GroupFree3DHead', + in_channels=288, + num_decoder_layers=6, + num_proposal=256, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='GroupFree3DMHA', + embed_dims=288, + num_heads=8, + attn_drop=0.1, + dropout_layer=dict(type='Dropout', drop_prob=0.1)), + ffn_cfgs=dict( + embed_dims=288, + feedforward_channels=2048, + ffn_drop=0.1, + act_cfg=dict(type='ReLU', inplace=True)), + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', + 'norm')), + pred_layer_cfg=dict( + in_channels=288, shared_conv_channels=(288, 288), bias=True), + sampling_objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=8.0), + objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + center_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', beta=1.0, reduction='sum', loss_weight=10.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(sample_mod='kps'), + test_cfg=dict( + sample_mod='kps', + nms_thr=0.25, + score_thr=0.0, + per_class_proposal=True, + prediction_stages='last')) diff --git a/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L12-O256.py b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L12-O256.py new file mode 100644 index 0000000000..a540e79105 --- /dev/null +++ b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L12-O256.py @@ -0,0 +1,206 @@ +_base_ = [ + '../_base_/datasets/scannet-3d-18class.py', + '../_base_/models/groupfree3d.py', '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + bbox_head=dict( + num_classes=18, + num_decoder_layers=12, + size_cls_agnostic=False, + bbox_coder=dict( + type='GroupFree3DBBoxCoder', + num_sizes=18, + num_dir_bins=1, + with_rot=False, + size_cls_agnostic=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + 
[1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]]), + sampling_objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=8.0), + objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + center_loss=dict( + type='SmoothL1Loss', beta=0.04, reduction='sum', loss_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=10.0 / 9.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + test_cfg=dict( + sample_mod='kps', + nms_thr=0.25, + score_thr=0.0, + per_class_proposal=True, + prediction_stages='last_three')) + +# dataset settings +dataset_type = 'ScanNetDataset' +data_root = './data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_mask_3d=True, + with_seg_3d=True), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='PointSegClassMapping', + valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, + 36, 39)), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[1.0, 1.0]), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', 'gt_bboxes_3d', 'gt_labels_3d', 'pts_semantic_mask', + 'pts_instance_mask' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=False, + classes=class_names, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. 
+ box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +# optimizer +lr = 0.006 +optimizer = dict( + lr=lr, + weight_decay=0.0005, + paramwise_cfg=dict( + custom_keys={ + 'bbox_head.decoder_layers': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_self_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_cross_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_query_proj': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_key_proj': dict(lr_mult=0.1, decay_mult=1.0) + })) + +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +lr_config = dict(policy='step', warmup=None, step=[280, 340]) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=400) +# yapf:disable +log_config = dict( + interval=30, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable diff --git a/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L6-O256.py b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L6-O256.py new file mode 100644 index 0000000000..016d6d723c --- /dev/null +++ b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-L6-O256.py @@ -0,0 +1,205 @@ +_base_ = [ + '../_base_/datasets/scannet-3d-18class.py', + '../_base_/models/groupfree3d.py', '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + bbox_head=dict( + num_classes=18, + size_cls_agnostic=False, + bbox_coder=dict( + type='GroupFree3DBBoxCoder', + num_sizes=18, + num_dir_bins=1, + with_rot=False, + size_cls_agnostic=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + [1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]]), + sampling_objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=8.0), + objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + center_loss=dict( + type='SmoothL1Loss', beta=0.04, reduction='sum', loss_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=10.0 / 9.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + test_cfg=dict( + sample_mod='kps', + nms_thr=0.25, + score_thr=0.0, + per_class_proposal=True, + prediction_stages='last_three')) + +# dataset 
settings +dataset_type = 'ScanNetDataset' +data_root = './data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_mask_3d=True, + with_seg_3d=True), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='PointSegClassMapping', + valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, + 36, 39)), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[1.0, 1.0]), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', 'gt_bboxes_3d', 'gt_labels_3d', 'pts_semantic_mask', + 'pts_instance_mask' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=False, + classes=class_names, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. 
+ box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +# optimizer +lr = 0.006 +optimizer = dict( + lr=lr, + weight_decay=0.0005, + paramwise_cfg=dict( + custom_keys={ + 'bbox_head.decoder_layers': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_self_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_cross_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_query_proj': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_key_proj': dict(lr_mult=0.1, decay_mult=1.0) + })) + +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +lr_config = dict(policy='step', warmup=None, step=[280, 340]) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=400) +# yapf:disable +log_config = dict( + interval=30, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable diff --git a/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O256.py b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O256.py new file mode 100644 index 0000000000..3482bfe253 --- /dev/null +++ b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O256.py @@ -0,0 +1,221 @@ +_base_ = [ + '../_base_/datasets/scannet-3d-18class.py', + '../_base_/models/groupfree3d.py', '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + backbone=dict( + type='PointNet2SASSG', + in_channels=3, + num_points=(2048, 1024, 512, 256), + radius=(0.2, 0.4, 0.8, 1.2), + num_samples=(64, 32, 16, 16), + sa_channels=((128, 128, 256), (256, 256, 512), (256, 256, 512), + (256, 256, 512)), + fp_channels=((512, 512), (512, 288)), + norm_cfg=dict(type='BN2d'), + sa_cfg=dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True)), + bbox_head=dict( + num_classes=18, + num_decoder_layers=12, + size_cls_agnostic=False, + bbox_coder=dict( + type='GroupFree3DBBoxCoder', + num_sizes=18, + num_dir_bins=1, + with_rot=False, + size_cls_agnostic=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + [1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]]), + sampling_objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=8.0), + objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + center_loss=dict( + type='SmoothL1Loss', beta=0.04, reduction='sum', loss_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + type='SmoothL1Loss', 
reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=10.0 / 9.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + test_cfg=dict( + sample_mod='kps', + nms_thr=0.25, + score_thr=0.0, + per_class_proposal=True, + prediction_stages='last_three')) + +# dataset settings +dataset_type = 'ScanNetDataset' +data_root = './data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_mask_3d=True, + with_seg_3d=True), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='PointSegClassMapping', + valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, + 36, 39)), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[1.0, 1.0]), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', 'gt_bboxes_3d', 'gt_labels_3d', 'pts_semantic_mask', + 'pts_instance_mask' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=False, + classes=class_names, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. 
+ box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +# optimizer +lr = 0.006 +optimizer = dict( + lr=lr, + weight_decay=0.0005, + paramwise_cfg=dict( + custom_keys={ + 'bbox_head.decoder_layers': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_self_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_cross_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_query_proj': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_key_proj': dict(lr_mult=0.1, decay_mult=1.0) + })) + +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +lr_config = dict(policy='step', warmup=None, step=[280, 340]) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=400) +# yapf:disable +log_config = dict( + interval=30, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable diff --git a/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O512.py b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O512.py new file mode 100644 index 0000000000..89d199a25b --- /dev/null +++ b/configs/groupfree3d/groupfree3d_8x4_scannet-3d-18class-w2x-L12-O512.py @@ -0,0 +1,222 @@ +_base_ = [ + '../_base_/datasets/scannet-3d-18class.py', + '../_base_/models/groupfree3d.py', '../_base_/schedules/schedule_3x.py', + '../_base_/default_runtime.py' +] + +# model settings +model = dict( + backbone=dict( + type='PointNet2SASSG', + in_channels=3, + num_points=(2048, 1024, 512, 256), + radius=(0.2, 0.4, 0.8, 1.2), + num_samples=(64, 32, 16, 16), + sa_channels=((128, 128, 256), (256, 256, 512), (256, 256, 512), + (256, 256, 512)), + fp_channels=((512, 512), (512, 288)), + norm_cfg=dict(type='BN2d'), + sa_cfg=dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True)), + bbox_head=dict( + num_classes=18, + num_decoder_layers=12, + num_proposal=512, + size_cls_agnostic=False, + bbox_coder=dict( + type='GroupFree3DBBoxCoder', + num_sizes=18, + num_dir_bins=1, + with_rot=False, + size_cls_agnostic=False, + mean_sizes=[[0.76966727, 0.8116021, 0.92573744], + [1.876858, 1.8425595, 1.1931566], + [0.61328, 0.6148609, 0.7182701], + [1.3955007, 1.5121545, 0.83443564], + [0.97949594, 1.0675149, 0.6329687], + [0.531663, 0.5955577, 1.7500148], + [0.9624706, 0.72462326, 1.1481868], + [0.83221924, 1.0490936, 1.6875663], + [0.21132214, 0.4206159, 0.5372846], + [1.4440073, 1.8970833, 0.26985747], + [1.0294262, 1.4040797, 0.87554324], + [1.3766412, 0.65521795, 1.6813129], + [0.6650819, 0.71111923, 1.298853], + [0.41999173, 0.37906948, 1.7513971], + [0.59359556, 0.5912492, 0.73919016], + [0.50867593, 0.50656086, 0.30136237], + [1.1511526, 1.0546296, 0.49706793], + [0.47535285, 0.49249494, 0.5802117]]), + sampling_objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=8.0), + objectness_loss=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + center_loss=dict( + type='SmoothL1Loss', beta=0.04, reduction='sum', loss_weight=10.0), + dir_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + dir_res_loss=dict( + 
type='SmoothL1Loss', reduction='sum', loss_weight=10.0), + size_class_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0), + size_res_loss=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=10.0 / 9.0), + semantic_loss=dict( + type='CrossEntropyLoss', reduction='sum', loss_weight=1.0)), + test_cfg=dict( + sample_mod='kps', + nms_thr=0.25, + score_thr=0.0, + per_class_proposal=True, + prediction_stages='last_three')) + +# dataset settings +dataset_type = 'ScanNetDataset' +data_root = './data/scannet/' +class_names = ('cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', + 'bookshelf', 'picture', 'counter', 'desk', 'curtain', + 'refrigerator', 'showercurtrain', 'toilet', 'sink', 'bathtub', + 'garbagebin') +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_mask_3d=True, + with_seg_3d=True), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='PointSegClassMapping', + valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34, + 36, 39)), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.087266, 0.087266], + scale_ratio_range=[1.0, 1.0]), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=[ + 'points', 'gt_bboxes_3d', 'gt_labels_3d', 'pts_semantic_mask', + 'pts_instance_mask' + ]) +] +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='DEPTH', + load_dim=6, + use_dim=[0, 1, 2]), + dict(type='GlobalAlignment', rotation_axis=2), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict( + type='RandomFlip3D', + sync_2d=False, + flip_ratio_bev_horizontal=0.5, + flip_ratio_bev_vertical=0.5), + dict(type='IndoorPointSample', num_points=50000), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points']) + ]) +] + +data = dict( + samples_per_gpu=8, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=1, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_train.pkl', + pipeline=train_pipeline, + filter_empty_gt=False, + classes=class_names, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. 
+ box_type_3d='Depth')), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth'), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'scannet_infos_val.pkl', + pipeline=test_pipeline, + classes=class_names, + test_mode=True, + box_type_3d='Depth')) + +# optimizer +lr = 0.006 +optimizer = dict( + lr=lr, + weight_decay=0.0005, + paramwise_cfg=dict( + custom_keys={ + 'bbox_head.decoder_layers': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_self_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_cross_posembeds': dict( + lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_query_proj': dict(lr_mult=0.1, decay_mult=1.0), + 'bbox_head.decoder_key_proj': dict(lr_mult=0.1, decay_mult=1.0) + })) + +optimizer_config = dict(grad_clip=dict(max_norm=0.1, norm_type=2)) +lr_config = dict(policy='step', warmup=None, step=[280, 340]) + +# runtime settings +runner = dict(type='EpochBasedRunner', max_epochs=400) +# yapf:disable +log_config = dict( + interval=30, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable diff --git a/mmdet3d/core/bbox/coders/__init__.py b/mmdet3d/core/bbox/coders/__init__.py index 4d2c93d806..d2c2dae1a4 100644 --- a/mmdet3d/core/bbox/coders/__init__.py +++ b/mmdet3d/core/bbox/coders/__init__.py @@ -2,9 +2,10 @@ from .anchor_free_bbox_coder import AnchorFreeBBoxCoder from .centerpoint_bbox_coders import CenterPointBBoxCoder from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder +from .groupfree3d_bbox_coder import GroupFree3DBBoxCoder from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder __all__ = [ 'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder', - 'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder' + 'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder' ] diff --git a/mmdet3d/core/bbox/coders/groupfree3d_bbox_coder.py b/mmdet3d/core/bbox/coders/groupfree3d_bbox_coder.py new file mode 100644 index 0000000000..0732af0754 --- /dev/null +++ b/mmdet3d/core/bbox/coders/groupfree3d_bbox_coder.py @@ -0,0 +1,189 @@ +import numpy as np +import torch + +from mmdet.core.bbox.builder import BBOX_CODERS +from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder + + +@BBOX_CODERS.register_module() +class GroupFree3DBBoxCoder(PartialBinBasedBBoxCoder): + """Modified partial bin based bbox coder for GroupFree3D. + + Args: + num_dir_bins (int): Number of bins to encode direction angle. + num_sizes (int): Number of size clusters. + mean_sizes (list[list[int]]): Mean size of bboxes in each class. + with_rot (bool): Whether the bbox is with rotation. Defaults to True. + size_cls_agnostic (bool): Whether the predicted size is class-agnostic. + Defaults to True. + """ + + def __init__(self, + num_dir_bins, + num_sizes, + mean_sizes, + with_rot=True, + size_cls_agnostic=True): + super(GroupFree3DBBoxCoder, self).__init__( + num_dir_bins=num_dir_bins, + num_sizes=num_sizes, + mean_sizes=mean_sizes, + with_rot=with_rot) + self.size_cls_agnostic = size_cls_agnostic + + def encode(self, gt_bboxes_3d, gt_labels_3d): + """Encode ground truth to prediction targets. + + Args: + gt_bboxes_3d (BaseInstance3DBoxes): Ground truth bboxes \ + with shape (n, 7). + gt_labels_3d (torch.Tensor): Ground truth classes. + + Returns: + tuple: Targets of center, size and direction. 
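+
+        Example:
+            A minimal usage sketch (``coder`` is assumed to be a built
+            ``GroupFree3DBBoxCoder``; box/label shapes follow the Args
+            above):
+
+            >>> # gt_bboxes_3d: :obj:`DepthInstance3DBoxes` with (n, 7) tensor
+            >>> # gt_labels_3d: (n, ) tensor of class indices
+            >>> (center_target, size_target, size_class_target,
+            ...  size_res_target, dir_class_target,
+            ...  dir_res_target) = coder.encode(gt_bboxes_3d, gt_labels_3d)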
+ """ + # generate center target + center_target = gt_bboxes_3d.gravity_center + + # generate bbox size target + size_target = gt_bboxes_3d.dims + size_class_target = gt_labels_3d + size_res_target = gt_bboxes_3d.dims - gt_bboxes_3d.tensor.new_tensor( + self.mean_sizes)[size_class_target] + + # generate dir target + box_num = gt_labels_3d.shape[0] + if self.with_rot: + (dir_class_target, + dir_res_target) = self.angle2class(gt_bboxes_3d.yaw) + else: + dir_class_target = gt_labels_3d.new_zeros(box_num) + dir_res_target = gt_bboxes_3d.tensor.new_zeros(box_num) + + return (center_target, size_target, size_class_target, size_res_target, + dir_class_target, dir_res_target) + + def decode(self, bbox_out, prefix=''): + """Decode predicted parts to bbox3d. + + Args: + bbox_out (dict): Predictions from model, should contain keys below. + + - center: predicted bottom center of bboxes. + - dir_class: predicted bbox direction class. + - dir_res: predicted bbox direction residual. + - size_class: predicted bbox size class. + - size_res: predicted bbox size residual. + - size: predicted class-agnostic bbox size + prefix (str): Decode predictions with specific prefix. + Defaults to ''. + + Returns: + torch.Tensor: Decoded bbox3d with shape (batch, n, 7). + """ + center = bbox_out[f'{prefix}center'] + batch_size, num_proposal = center.shape[:2] + + # decode heading angle + if self.with_rot: + dir_class = torch.argmax(bbox_out[f'{prefix}dir_class'], -1) + dir_res = torch.gather(bbox_out[f'{prefix}dir_res'], 2, + dir_class.unsqueeze(-1)) + dir_res.squeeze_(2) + dir_angle = self.class2angle(dir_class, dir_res).reshape( + batch_size, num_proposal, 1) + else: + dir_angle = center.new_zeros(batch_size, num_proposal, 1) + + # decode bbox size + if self.size_cls_agnostic: + bbox_size = bbox_out[f'{prefix}size'].reshape( + batch_size, num_proposal, 3) + else: + size_class = torch.argmax( + bbox_out[f'{prefix}size_class'], -1, keepdim=True) + size_res = torch.gather( + bbox_out[f'{prefix}size_res'], 2, + size_class.unsqueeze(-1).repeat(1, 1, 1, 3)) + mean_sizes = center.new_tensor(self.mean_sizes) + size_base = torch.index_select(mean_sizes, 0, + size_class.reshape(-1)) + bbox_size = size_base.reshape(batch_size, num_proposal, + -1) + size_res.squeeze(2) + + bbox3d = torch.cat([center, bbox_size, dir_angle], dim=-1) + return bbox3d + + def split_pred(self, cls_preds, reg_preds, base_xyz, prefix=''): + """Split predicted features to specific parts. + + Args: + cls_preds (torch.Tensor): Class predicted features to split. + reg_preds (torch.Tensor): Regression predicted features to split. + base_xyz (torch.Tensor): Coordinates of points. + prefix (str): Decode predictions with specific prefix. + Defaults to ''. + + Returns: + dict[str, torch.Tensor]: Split results. 
+ """ + results = {} + start, end = 0, 0 + + cls_preds_trans = cls_preds.transpose(2, 1) + reg_preds_trans = reg_preds.transpose(2, 1) + + # decode center + end += 3 + # (batch_size, num_proposal, 3) + results[f'{prefix}center_residual'] = \ + reg_preds_trans[..., start:end].contiguous() + results[f'{prefix}center'] = base_xyz + \ + reg_preds_trans[..., start:end].contiguous() + start = end + + # decode direction + end += self.num_dir_bins + results[f'{prefix}dir_class'] = \ + reg_preds_trans[..., start:end].contiguous() + start = end + + end += self.num_dir_bins + dir_res_norm = reg_preds_trans[..., start:end].contiguous() + start = end + + results[f'{prefix}dir_res_norm'] = dir_res_norm + results[f'{prefix}dir_res'] = dir_res_norm * ( + np.pi / self.num_dir_bins) + + # decode size + if self.size_cls_agnostic: + end += 3 + results[f'{prefix}size'] = \ + reg_preds_trans[..., start:end].contiguous() + else: + end += self.num_sizes + results[f'{prefix}size_class'] = reg_preds_trans[ + ..., start:end].contiguous() + start = end + + end += self.num_sizes * 3 + size_res_norm = reg_preds_trans[..., start:end] + batch_size, num_proposal = reg_preds_trans.shape[:2] + size_res_norm = size_res_norm.view( + [batch_size, num_proposal, self.num_sizes, 3]) + start = end + + results[f'{prefix}size_res_norm'] = size_res_norm.contiguous() + mean_sizes = reg_preds.new_tensor(self.mean_sizes) + results[f'{prefix}size_res'] = ( + size_res_norm * mean_sizes.unsqueeze(0).unsqueeze(0)) + + # decode objectness score + # Group-Free-3D objectness output shape (batch, proposal, 1) + results[f'{prefix}obj_scores'] = cls_preds_trans[..., :1].contiguous() + + # decode semantic score + results[f'{prefix}sem_scores'] = cls_preds_trans[..., 1:].contiguous() + + return results diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py index 1f69c5910c..b814d02d9b 100644 --- a/mmdet3d/models/dense_heads/__init__.py +++ b/mmdet3d/models/dense_heads/__init__.py @@ -5,6 +5,7 @@ from .centerpoint_head import CenterHead from .fcos_mono3d_head import FCOSMono3DHead from .free_anchor3d_head import FreeAnchor3DHead +from .groupfree3d_head import GroupFree3DHead from .parta2_rpn_head import PartA2RPNHead from .shape_aware_head import ShapeAwareHead from .ssd_3d_head import SSD3DHead @@ -13,5 +14,6 @@ __all__ = [ 'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead', 'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead', - 'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead' + 'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead', + 'GroupFree3DHead' ] diff --git a/mmdet3d/models/dense_heads/groupfree3d_head.py b/mmdet3d/models/dense_heads/groupfree3d_head.py new file mode 100644 index 0000000000..87fd9598a1 --- /dev/null +++ b/mmdet3d/models/dense_heads/groupfree3d_head.py @@ -0,0 +1,992 @@ +import copy +import numpy as np +import torch +from mmcv import ConfigDict +from mmcv.cnn import ConvModule +from mmcv.cnn.bricks.transformer import (build_positional_encoding, + build_transformer_layer) +from mmcv.runner import force_fp32 +from torch import nn as nn +from torch.nn import functional as F + +from mmdet3d.core.post_processing import aligned_3d_nms +from mmdet3d.models.builder import build_loss +from mmdet3d.ops import Points_Sampler, gather_points +from mmdet.core import build_bbox_coder, multi_apply +from mmdet.models import HEADS +from .base_conv_bbox_head import BaseConvBboxHead + +EPS = 1e-6 + + +class PointsObjClsModule(nn.Module): + """object 
candidate point prediction from seed point features.
+
+    Args:
+        in_channel (int): Number of channels of seed point features.
+        num_convs (int): Number of conv layers.
+            Default: 3.
+        conv_cfg (dict): Config of convolution.
+            Default: dict(type='Conv1d').
+        norm_cfg (dict): Config of normalization.
+            Default: dict(type='BN1d').
+        act_cfg (dict): Config of activation.
+            Default: dict(type='ReLU').
+    """
+
+    def __init__(self,
+                 in_channel,
+                 num_convs=3,
+                 conv_cfg=dict(type='Conv1d'),
+                 norm_cfg=dict(type='BN1d'),
+                 act_cfg=dict(type='ReLU')):
+        super().__init__()
+        conv_channels = [in_channel for _ in range(num_convs - 1)]
+        conv_channels.append(1)
+
+        self.mlp = nn.Sequential()
+        prev_channels = in_channel
+        for i in range(num_convs):
+            self.mlp.add_module(
+                f'layer{i}',
+                ConvModule(
+                    prev_channels,
+                    conv_channels[i],
+                    1,
+                    padding=0,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg if i < num_convs - 1 else None,
+                    act_cfg=act_cfg if i < num_convs - 1 else None,
+                    bias=True,
+                    inplace=True))
+            prev_channels = conv_channels[i]
+
+    def forward(self, seed_features):
+        """Forward pass.
+
+        Args:
+            seed_features (torch.Tensor): Seed features with shape
+                (batch_size, feature_dim, num_seed).
+
+        Returns:
+            torch.Tensor: Objectness logits with shape
+                (batch_size, 1, num_seed).
+        """
+        return self.mlp(seed_features)
+
+
+class GeneralSamplingModule(nn.Module):
+    """Sampling Points.
+
+    Sampling points with given index.
+    """
+
+    def forward(self, xyz, features, sample_inds):
+        """Forward pass.
+
+        Args:
+            xyz (Tensor): (B, N, 3) the coordinates of the features.
+            features (Tensor): (B, C, N) features to sample.
+            sample_inds (Tensor): (B, M) the given index,
+                where M is the number of points.
+
+        Returns:
+            Tensor: (B, M, 3) coordinates of sampled features.
+            Tensor: (B, C, M) the sampled features.
+            Tensor: (B, M) the given index.
+        """
+        xyz_t = xyz.transpose(1, 2).contiguous()
+        new_xyz = gather_points(xyz_t, sample_inds).transpose(1,
+                                                              2).contiguous()
+        new_features = gather_points(features, sample_inds).contiguous()
+
+        return new_xyz, new_features, sample_inds
+
+
+@HEADS.register_module()
+class GroupFree3DHead(nn.Module):
+    r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
+
+    Args:
+        num_classes (int): The number of classes.
+        in_channels (int): The dims of input features from backbone.
+        bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and
+            decoding boxes.
+        num_decoder_layers (int): The number of transformer decoder layers.
+        transformerlayers (dict): Config for transformer decoder.
+        train_cfg (dict): Config for training.
+        test_cfg (dict): Config for testing.
+        num_proposal (int): The number of initial sampling candidates.
+        pred_layer_cfg (dict): Config of classification and regression
+            prediction layers.
+        size_cls_agnostic (bool): Whether the predicted size is class-agnostic.
+        gt_per_seed (int): The number of candidate instances each point
+            belongs to.
+        sampling_objectness_loss (dict): Config of initial sampling
+            objectness loss.
+        objectness_loss (dict): Config of objectness loss.
+        center_loss (dict): Config of center loss.
+        dir_class_loss (dict): Config of direction classification loss.
+        dir_res_loss (dict): Config of direction residual regression loss.
+        size_class_loss (dict): Config of size classification loss.
+        size_res_loss (dict): Config of size residual regression loss.
+        size_reg_loss (dict): Config of class-agnostic size regression loss.
+        semantic_loss (dict): Config of point-wise semantic segmentation loss.
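+
+    Example:
+        An illustrative forward call (``feat_dict`` is assumed to be the
+        output dict of the ``PointNet2SASSG`` backbone; key names follow
+        ``forward`` below):
+
+        >>> results = head(feat_dict, sample_mod='kps')
+        >>> # with num_decoder_layers=6, the last-stage boxes are
+        >>> # results['s5.bbox3d'] with shape (B, num_proposal, 7)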
+ """ + + def __init__(self, + num_classes, + in_channels, + bbox_coder, + num_decoder_layers, + transformerlayers, + decoder_self_posembeds=dict( + type='ConvBNPositionalEncoding', + input_channel=6, + num_pos_feats=288), + decoder_cross_posembeds=dict( + type='ConvBNPositionalEncoding', + input_channel=3, + num_pos_feats=288), + train_cfg=None, + test_cfg=None, + num_proposal=128, + pred_layer_cfg=None, + size_cls_agnostic=True, + gt_per_seed=3, + sampling_objectness_loss=None, + objectness_loss=None, + center_loss=None, + dir_class_loss=None, + dir_res_loss=None, + size_class_loss=None, + size_res_loss=None, + size_reg_loss=None, + semantic_loss=None): + super(GroupFree3DHead, self).__init__() + self.num_classes = num_classes + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.num_proposal = num_proposal + self.in_channels = in_channels + self.num_decoder_layers = num_decoder_layers + self.size_cls_agnostic = size_cls_agnostic + self.gt_per_seed = gt_per_seed + + # Transformer decoder layers + if isinstance(transformerlayers, ConfigDict): + transformerlayers = [ + copy.deepcopy(transformerlayers) + for _ in range(num_decoder_layers) + ] + else: + assert isinstance(transformerlayers, list) and \ + len(transformerlayers) == num_decoder_layers + self.decoder_layers = nn.ModuleList() + for i in range(self.num_decoder_layers): + self.decoder_layers.append( + build_transformer_layer(transformerlayers[i])) + self.embed_dims = self.decoder_layers[0].embed_dims + assert self.embed_dims == decoder_self_posembeds['num_pos_feats'] + assert self.embed_dims == decoder_cross_posembeds['num_pos_feats'] + + # bbox_coder + self.bbox_coder = build_bbox_coder(bbox_coder) + self.num_sizes = self.bbox_coder.num_sizes + self.num_dir_bins = self.bbox_coder.num_dir_bins + + # Initial object candidate sampling + self.gsample_module = GeneralSamplingModule() + self.fps_module = Points_Sampler([self.num_proposal]) + self.points_obj_cls = PointsObjClsModule(self.in_channels) + + self.fp16_enabled = False + + # initial candidate prediction + self.conv_pred = BaseConvBboxHead( + **pred_layer_cfg, + num_cls_out_channels=self._get_cls_out_channels(), + num_reg_out_channels=self._get_reg_out_channels()) + + # query proj and key proj + self.decoder_query_proj = nn.Conv1d( + self.embed_dims, self.embed_dims, kernel_size=1) + self.decoder_key_proj = nn.Conv1d( + self.embed_dims, self.embed_dims, kernel_size=1) + + # query position embed + self.decoder_self_posembeds = nn.ModuleList() + for _ in range(self.num_decoder_layers): + self.decoder_self_posembeds.append( + build_positional_encoding(decoder_self_posembeds)) + # key position embed + self.decoder_cross_posembeds = nn.ModuleList() + for _ in range(self.num_decoder_layers): + self.decoder_cross_posembeds.append( + build_positional_encoding(decoder_cross_posembeds)) + + # Prediction Head + self.prediction_heads = nn.ModuleList() + for i in range(self.num_decoder_layers): + self.prediction_heads.append( + BaseConvBboxHead( + **pred_layer_cfg, + num_cls_out_channels=self._get_cls_out_channels(), + num_reg_out_channels=self._get_reg_out_channels())) + + self.sampling_objectness_loss = build_loss(sampling_objectness_loss) + self.objectness_loss = build_loss(objectness_loss) + self.center_loss = build_loss(center_loss) + self.dir_res_loss = build_loss(dir_res_loss) + self.dir_class_loss = build_loss(dir_class_loss) + self.semantic_loss = build_loss(semantic_loss) + if self.size_cls_agnostic: + self.size_reg_loss = build_loss(size_reg_loss) + else: + 
+            self.size_res_loss = build_loss(size_res_loss)
+            self.size_class_loss = build_loss(size_class_loss)
+
+    def init_weights(self):
+        """Initialize weights of transformer decoder in GroupFree3DHead."""
+        # initialize transformer
+        for m in self.decoder_layers.parameters():
+            if m.dim() > 1:
+                nn.init.xavier_uniform_(m)
+
+        for m in self.decoder_self_posembeds.parameters():
+            if m.dim() > 1:
+                nn.init.xavier_uniform_(m)
+
+        for m in self.decoder_cross_posembeds.parameters():
+            if m.dim() > 1:
+                nn.init.xavier_uniform_(m)
+
+    def _get_cls_out_channels(self):
+        """Return the channel number of classification outputs."""
+        # Class numbers (k) + objectness (1)
+        return self.num_classes + 1
+
+    def _get_reg_out_channels(self):
+        """Return the channel number of regression outputs."""
+        # center residual (3),
+        # heading class+residual (num_dir_bins*2),
+        # size class+residual (num_sizes*4 or 3)
+        if self.size_cls_agnostic:
+            return 6 + self.num_dir_bins * 2
+        else:
+            return 3 + self.num_dir_bins * 2 + self.num_sizes * 4
+
+    def _extract_input(self, feat_dict):
+        """Extract inputs from features dictionary.
+
+        Args:
+            feat_dict (dict): Feature dict from backbone.
+
+        Returns:
+            torch.Tensor: Coordinates of input points.
+            torch.Tensor: Features of input points.
+            torch.Tensor: Indices of input points.
+        """
+
+        seed_points = feat_dict['fp_xyz'][-1]
+        seed_features = feat_dict['fp_features'][-1]
+        seed_indices = feat_dict['fp_indices'][-1]
+
+        return seed_points, seed_features, seed_indices
+
+    def forward(self, feat_dict, sample_mod):
+        """Forward pass.
+
+        Note:
+            The forward of GroupFree3DHead is divided into 2 steps:
+
+                1. Initial object candidates sampling.
+                2. Iterative object box prediction by transformer decoder.
+
+        Args:
+            feat_dict (dict): Feature dict from backbone.
+            sample_mod (str): Sample mode for initial candidates sampling.
+
+        Returns:
+            results (dict): Predictions of GroupFree3D head.
+        """
+        assert sample_mod in ['fps', 'kps']
+
+        seed_xyz, seed_features, seed_indices = self._extract_input(feat_dict)
+
+        results = dict(
+            seed_points=seed_xyz,
+            seed_features=seed_features,
+            seed_indices=seed_indices)
+
+        # 1. Initial object candidates sampling.
+        if sample_mod == 'fps':
+            sample_inds = self.fps_module(seed_xyz, seed_features)
+        elif sample_mod == 'kps':
+            points_obj_cls_logits = self.points_obj_cls(
+                seed_features)  # (batch_size, 1, num_seed)
+            points_obj_cls_scores = points_obj_cls_logits.sigmoid().squeeze(1)
+            sample_inds = torch.topk(points_obj_cls_scores,
+                                     self.num_proposal)[1].int()
+            results['seeds_obj_cls_logits'] = points_obj_cls_logits
+        else:
+            raise NotImplementedError(
+                f'Sample mode {sample_mod} is not supported!')
+
+        candidate_xyz, candidate_features, sample_inds = self.gsample_module(
+            seed_xyz, seed_features, sample_inds)
+
+        results['query_points_xyz'] = candidate_xyz  # (B, M, 3)
+        results['query_points_feature'] = candidate_features  # (B, C, M)
+        results['query_points_sample_inds'] = sample_inds.long()  # (B, M)
+
+        prefix = 'proposal.'
+        cls_predictions, reg_predictions = self.conv_pred(candidate_features)
+        decode_res = self.bbox_coder.split_pred(cls_predictions,
+                                                reg_predictions, candidate_xyz,
+                                                prefix)
+
+        results.update(decode_res)
+        bbox3d = self.bbox_coder.decode(results, prefix)
+
+        # 2. Iterative object box prediction by transformer decoder.
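+        # Each decoder layer refines the previous stage's boxes: the
+        # self-attention position embedding is computed from the current
+        # (x, y, z, dx, dy, dz) estimates and the cross-attention position
+        # embedding from the seed coordinates, so the object queries attend
+        # to seed features with up-to-date spatial context.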
+        base_bbox3d = bbox3d[:, :, :6].detach().clone()
+
+        query = self.decoder_query_proj(candidate_features).permute(2, 0, 1)
+        key = self.decoder_key_proj(seed_features).permute(2, 0, 1)
+        value = key
+
+        # transformer decoder
+        results['num_decoder_layers'] = 0
+        for i in range(self.num_decoder_layers):
+            prefix = f's{i}.'
+
+            query_pos = self.decoder_self_posembeds[i](base_bbox3d).permute(
+                2, 0, 1)
+            key_pos = self.decoder_cross_posembeds[i](seed_xyz).permute(
+                2, 0, 1)
+
+            query = self.decoder_layers[i](
+                query, key, value, query_pos=query_pos,
+                key_pos=key_pos).permute(1, 2, 0)
+
+            results[f'{prefix}query'] = query
+
+            cls_predictions, reg_predictions = self.prediction_heads[i](query)
+            decode_res = self.bbox_coder.split_pred(cls_predictions,
+                                                    reg_predictions,
+                                                    candidate_xyz, prefix)
+            # TODO: should save bbox3d instead of decode_res?
+            results.update(decode_res)
+
+            bbox3d = self.bbox_coder.decode(results, prefix)
+            results[f'{prefix}bbox3d'] = bbox3d
+            base_bbox3d = bbox3d[:, :, :6].detach().clone()
+            query = query.permute(2, 0, 1)
+
+            results['num_decoder_layers'] += 1
+
+        return results
+
+    @force_fp32(apply_to=('bbox_preds', ))
+    def loss(self,
+             bbox_preds,
+             points,
+             gt_bboxes_3d,
+             gt_labels_3d,
+             pts_semantic_mask=None,
+             pts_instance_mask=None,
+             img_metas=None,
+             gt_bboxes_ignore=None,
+             ret_target=False):
+        """Compute loss.
+
+        Args:
+            bbox_preds (dict): Predictions from forward of GroupFree3D head.
+            points (list[torch.Tensor]): Input points.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \
+                bboxes of each sample.
+            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
+            pts_semantic_mask (None | list[torch.Tensor]): Point-wise
+                semantic mask.
+            pts_instance_mask (None | list[torch.Tensor]): Point-wise
+                instance mask.
+            img_metas (list[dict]): Contain pcd and img's meta info.
+            gt_bboxes_ignore (None | list[torch.Tensor]): Specify
+                which bounding boxes to ignore.
+            ret_target (bool): Return targets or not.
+
+        Returns:
+            dict: Losses of GroupFree3D.
+        """
+        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d,
+                                   pts_semantic_mask, pts_instance_mask,
+                                   bbox_preds)
+        (sampling_targets, sampling_weights, assigned_size_targets,
+         size_class_targets, size_res_targets, dir_class_targets,
+         dir_res_targets, center_targets, assigned_center_targets,
+         mask_targets, valid_gt_masks, objectness_targets, objectness_weights,
+         box_loss_weights, valid_gt_weights) = targets
+
+        batch_size, proposal_num = size_class_targets.shape[:2]
+
+        losses = dict()
+
+        # calculate objectness classification loss
+        sampling_obj_score = bbox_preds['seeds_obj_cls_logits'].reshape(-1, 1)
+        sampling_objectness_loss = self.sampling_objectness_loss(
+            sampling_obj_score,
+            1 - sampling_targets.reshape(-1),
+            sampling_weights.reshape(-1),
+            avg_factor=batch_size)
+        losses['sampling_objectness_loss'] = sampling_objectness_loss
+
+        prefixes = ['proposal.'] + [
+            f's{i}.'
for i in range(bbox_preds['num_decoder_layers']) + ] + num_stages = len(prefixes) + for prefix in prefixes: + + # calculate objectness loss + obj_score = bbox_preds[f'{prefix}obj_scores'].transpose(2, 1) + objectness_loss = self.objectness_loss( + obj_score.reshape(-1, 1), + 1 - objectness_targets.reshape(-1), + objectness_weights.reshape(-1), + avg_factor=batch_size) + losses[f'{prefix}objectness_loss'] = objectness_loss / num_stages + + # calculate center loss + box_loss_weights_expand = box_loss_weights.unsqueeze(-1).expand( + -1, -1, 3) + center_loss = self.center_loss( + bbox_preds[f'{prefix}center'], + assigned_center_targets, + weight=box_loss_weights_expand) + losses[f'{prefix}center_loss'] = center_loss / num_stages + + # calculate direction class loss + dir_class_loss = self.dir_class_loss( + bbox_preds[f'{prefix}dir_class'].transpose(2, 1), + dir_class_targets, + weight=box_loss_weights) + losses[f'{prefix}dir_class_loss'] = dir_class_loss / num_stages + + # calculate direction residual loss + heading_label_one_hot = size_class_targets.new_zeros( + (batch_size, proposal_num, self.num_dir_bins)) + heading_label_one_hot.scatter_(2, dir_class_targets.unsqueeze(-1), + 1) + dir_res_norm = torch.sum( + bbox_preds[f'{prefix}dir_res_norm'] * heading_label_one_hot, + -1) + dir_res_loss = self.dir_res_loss( + dir_res_norm, dir_res_targets, weight=box_loss_weights) + losses[f'{prefix}dir_res_loss'] = dir_res_loss / num_stages + + if self.size_cls_agnostic: + # calculate class-agnostic size loss + size_reg_loss = self.size_reg_loss( + bbox_preds[f'{prefix}size'], + assigned_size_targets, + weight=box_loss_weights_expand) + losses[f'{prefix}size_reg_loss'] = size_reg_loss / num_stages + + else: + # calculate size class loss + size_class_loss = self.size_class_loss( + bbox_preds[f'{prefix}size_class'].transpose(2, 1), + size_class_targets, + weight=box_loss_weights) + losses[ + f'{prefix}size_class_loss'] = size_class_loss / num_stages + + # calculate size residual loss + one_hot_size_targets = size_class_targets.new_zeros( + (batch_size, proposal_num, self.num_sizes)) + one_hot_size_targets.scatter_(2, + size_class_targets.unsqueeze(-1), + 1) + one_hot_size_targets_expand = one_hot_size_targets.unsqueeze( + -1).expand(-1, -1, -1, 3).contiguous() + size_residual_norm = torch.sum( + bbox_preds[f'{prefix}size_res_norm'] * + one_hot_size_targets_expand, 2) + box_loss_weights_expand = box_loss_weights.unsqueeze( + -1).expand(-1, -1, 3) + size_res_loss = self.size_res_loss( + size_residual_norm, + size_res_targets, + weight=box_loss_weights_expand) + losses[f'{prefix}size_res_loss'] = size_res_loss / num_stages + + # calculate semantic loss + semantic_loss = self.semantic_loss( + bbox_preds[f'{prefix}sem_scores'].transpose(2, 1), + mask_targets, + weight=box_loss_weights) + losses[f'{prefix}semantic_loss'] = semantic_loss / num_stages + + if ret_target: + losses['targets'] = targets + + return losses + + def get_targets(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + bbox_preds=None, + max_gt_num=64): + """Generate targets of GroupFree3D head. + + Args: + points (list[torch.Tensor]): Points of each batch. + gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth \ + bboxes of each batch. + gt_labels_3d (list[torch.Tensor]): Labels of each batch. + pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic + label of each batch. + pts_instance_mask (None | list[torch.Tensor]): Point-wise instance + label of each batch. 
+ bbox_preds (torch.Tensor): Bounding box predictions of vote head. + max_gt_num (int): Max number of GTs for single batch. + + Returns: + tuple[torch.Tensor]: Targets of GroupFree3D head. + """ + # find empty example + valid_gt_masks = list() + gt_num = list() + for index in range(len(gt_labels_3d)): + if len(gt_labels_3d[index]) == 0: + fake_box = gt_bboxes_3d[index].tensor.new_zeros( + 1, gt_bboxes_3d[index].tensor.shape[-1]) + gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) + gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) + valid_gt_masks.append(gt_labels_3d[index].new_zeros(1)) + gt_num.append(1) + else: + valid_gt_masks.append(gt_labels_3d[index].new_ones( + gt_labels_3d[index].shape)) + gt_num.append(gt_labels_3d[index].shape[0]) + # max_gt_num = max(gt_num) + + max_gt_nums = [max_gt_num for _ in range(len(gt_labels_3d))] + + if pts_semantic_mask is None: + pts_semantic_mask = [None for i in range(len(gt_labels_3d))] + pts_instance_mask = [None for i in range(len(gt_labels_3d))] + + seed_points = [ + bbox_preds['seed_points'][i] for i in range(len(gt_labels_3d)) + ] + + seed_indices = [ + bbox_preds['seed_indices'][i] for i in range(len(gt_labels_3d)) + ] + + candidate_indices = [ + bbox_preds['query_points_sample_inds'][i] + for i in range(len(gt_labels_3d)) + ] + + (sampling_targets, assigned_size_targets, size_class_targets, + size_res_targets, dir_class_targets, dir_res_targets, center_targets, + assigned_center_targets, mask_targets, objectness_targets, + objectness_masks) = multi_apply(self.get_targets_single, points, + gt_bboxes_3d, gt_labels_3d, + pts_semantic_mask, pts_instance_mask, + max_gt_nums, seed_points, + seed_indices, candidate_indices) + + # pad targets as original code of GroupFree3D. + for index in range(len(gt_labels_3d)): + pad_num = max_gt_num - gt_labels_3d[index].shape[0] + valid_gt_masks[index] = F.pad(valid_gt_masks[index], (0, pad_num)) + + sampling_targets = torch.stack(sampling_targets) + sampling_weights = (sampling_targets >= 0).float() + sampling_normalizer = sampling_weights.sum(dim=1, keepdim=True).float() + sampling_weights /= sampling_normalizer.clamp(min=1.0) + + assigned_size_targets = torch.stack(assigned_size_targets) + center_targets = torch.stack(center_targets) + valid_gt_masks = torch.stack(valid_gt_masks) + + assigned_center_targets = torch.stack(assigned_center_targets) + objectness_targets = torch.stack(objectness_targets) + + objectness_weights = torch.stack(objectness_masks) + cls_normalizer = objectness_weights.sum(dim=1, keepdim=True).float() + objectness_weights /= cls_normalizer.clamp(min=1.0) + + box_loss_weights = objectness_targets.float() / ( + objectness_targets.sum().float() + EPS) + + valid_gt_weights = valid_gt_masks.float() / ( + valid_gt_masks.sum().float() + EPS) + + dir_class_targets = torch.stack(dir_class_targets) + dir_res_targets = torch.stack(dir_res_targets) + size_class_targets = torch.stack(size_class_targets) + size_res_targets = torch.stack(size_res_targets) + mask_targets = torch.stack(mask_targets) + + return (sampling_targets, sampling_weights, assigned_size_targets, + size_class_targets, size_res_targets, dir_class_targets, + dir_res_targets, center_targets, assigned_center_targets, + mask_targets, valid_gt_masks, objectness_targets, + objectness_weights, box_loss_weights, valid_gt_weights) + + def get_targets_single(self, + points, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + max_gt_nums=None, + seed_points=None, + seed_indices=None, + 
candidate_indices=None, + seed_points_obj_topk=4): + """Generate targets of GroupFree3D head for single batch. + + Args: + points (torch.Tensor): Points of each batch. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth \ + boxes of each batch. + gt_labels_3d (torch.Tensor): Labels of each batch. + pts_semantic_mask (None | torch.Tensor): Point-wise semantic + label of each batch. + pts_instance_mask (None | torch.Tensor): Point-wise instance + label of each batch. + max_gt_nums (int): Max number of GTs for single batch. + seed_points (torch.Tensor): Coordinates of seed points. + seed_indices (torch.Tensor): Indices of seed points. + candidate_indices (torch.Tensor): Indices of object candidates. + seed_points_obj_topk (int): k value of k-Closest Points Sampling. + + Returns: + tuple[torch.Tensor]: Targets of GroupFree3D head. + """ + + assert self.bbox_coder.with_rot or pts_semantic_mask is not None + + gt_bboxes_3d = gt_bboxes_3d.to(points.device) + + # generate center, dir, size target + (center_targets, size_targets, size_class_targets, size_res_targets, + dir_class_targets, + dir_res_targets) = self.bbox_coder.encode(gt_bboxes_3d, gt_labels_3d) + + # pad targets as original code of GroupFree3D + pad_num = max_gt_nums - gt_labels_3d.shape[0] + box_label_mask = points.new_zeros([max_gt_nums]) + box_label_mask[:gt_labels_3d.shape[0]] = 1 + + gt_bboxes_pad = F.pad(gt_bboxes_3d.tensor, (0, 0, 0, pad_num)) + gt_bboxes_pad[gt_labels_3d.shape[0]:, 0:3] += 1000 + gt_bboxes_3d = gt_bboxes_3d.new_box(gt_bboxes_pad) + + gt_labels_3d = F.pad(gt_labels_3d, (0, pad_num)) + + center_targets = F.pad(center_targets, (0, 0, 0, pad_num), value=1000) + size_targets = F.pad(size_targets, (0, 0, 0, pad_num)) + size_class_targets = F.pad(size_class_targets, (0, pad_num)) + size_res_targets = F.pad(size_res_targets, (0, 0, 0, pad_num)) + dir_class_targets = F.pad(dir_class_targets, (0, pad_num)) + dir_res_targets = F.pad(dir_res_targets, (0, pad_num)) + + # 0. 
+        num_points = points.shape[0]
+        pts_obj_mask = points.new_zeros([num_points], dtype=torch.long)
+        pts_instance_label = points.new_zeros([num_points],
+                                              dtype=torch.long) - 1
+
+        if self.bbox_coder.with_rot:
+            vote_targets = points.new_zeros([num_points, 4 * self.gt_per_seed])
+            vote_target_idx = points.new_zeros([num_points], dtype=torch.long)
+            box_indices_all = gt_bboxes_3d.points_in_boxes(points)
+            for i in range(gt_labels_3d.shape[0]):
+                box_indices = box_indices_all[:, i]
+                indices = torch.nonzero(
+                    box_indices, as_tuple=False).squeeze(-1)
+                selected_points = points[indices]
+                pts_obj_mask[indices] = 1
+                vote_targets_tmp = vote_targets[indices]
+                votes = gt_bboxes_3d.gravity_center[i].unsqueeze(
+                    0) - selected_points[:, :3]
+
+                for j in range(self.gt_per_seed):
+                    column_indices = torch.nonzero(
+                        vote_target_idx[indices] == j,
+                        as_tuple=False).squeeze(-1)
+                    vote_targets_tmp[column_indices,
+                                     int(j * 3):int(j * 3 +
+                                                    3)] = votes[column_indices]
+                    vote_targets_tmp[column_indices,
+                                     j + 3 * self.gt_per_seed] = i
+                    if j == 0:
+                        vote_targets_tmp[
+                            column_indices, :3 *
+                            self.gt_per_seed] = votes[column_indices].repeat(
+                                1, self.gt_per_seed)
+                        vote_targets_tmp[column_indices,
+                                         3 * self.gt_per_seed:] = i
+
+                vote_targets[indices] = vote_targets_tmp
+                vote_target_idx[indices] = torch.clamp(
+                    vote_target_idx[indices] + 1, max=2)
+
+            dist = points.new_zeros([num_points, self.gt_per_seed]) + 1000
+            for j in range(self.gt_per_seed):
+                dist[:, j] = (vote_targets[:, 3 * j:3 * j + 3]**2).sum(-1)
+
+            instance_indices = torch.argmin(
+                dist, dim=-1).unsqueeze(-1) + 3 * self.gt_per_seed
+            instance_label = torch.gather(vote_targets, 1,
+                                          instance_indices).squeeze(-1)
+            pts_instance_label = instance_label.long()
+            pts_instance_label[pts_obj_mask == 0] = -1
+
+        elif pts_semantic_mask is not None:
+            for i in torch.unique(pts_instance_mask):
+                indices = torch.nonzero(
+                    pts_instance_mask == i, as_tuple=False).squeeze(-1)
+
+                if pts_semantic_mask[indices[0]] < self.num_classes:
+                    selected_points = points[indices, :3]
+                    center = 0.5 * (
+                        selected_points.min(0)[0] + selected_points.max(0)[0])
+
+                    delta_xyz = center - center_targets
+                    instance_label = torch.argmin((delta_xyz**2).sum(-1))
+                    pts_instance_label[indices] = instance_label
+                    pts_obj_mask[indices] = 1
+
+        else:
+            raise NotImplementedError
+
+        # 1. generate objectness targets in sampling head
+        gt_num = gt_labels_3d.shape[0]
+        num_seed = seed_points.shape[0]
+        num_candidate = candidate_indices.shape[0]
+
+        object_assignment = torch.gather(pts_instance_label, 0, seed_indices)
+        # set background points to the last gt bbox as original code
+        object_assignment[object_assignment < 0] = gt_num - 1
+        object_assignment_one_hot = gt_bboxes_3d.tensor.new_zeros(
+            (num_seed, gt_num))
+        object_assignment_one_hot.scatter_(1, object_assignment.unsqueeze(-1),
+                                           1)  # (num_seed, gt_num)
+
+        delta_xyz = seed_points.unsqueeze(
+            1) - gt_bboxes_3d.gravity_center.unsqueeze(
+                0)  # (num_seed, gt_num, 3)
+        delta_xyz = delta_xyz / (gt_bboxes_3d.dims.unsqueeze(0) + EPS)
+
+        new_dist = torch.sum(delta_xyz**2, dim=-1)
+        euclidean_dist1 = torch.sqrt(new_dist + EPS)
+        euclidean_dist1 = euclidean_dist1 * object_assignment_one_hot + 100 * (
+            1 - object_assignment_one_hot)
+        # (gt_num, num_seed)
+        euclidean_dist1 = euclidean_dist1.permute(1, 0)
+
+        # gt_num x topk
+        topk_inds = torch.topk(
+            euclidean_dist1,
+            seed_points_obj_topk,
+            largest=False)[1] * box_label_mask[:, None] + \
+            (box_label_mask[:, None] - 1)
+        topk_inds = topk_inds.long()
+        topk_inds = topk_inds.view(-1).contiguous()
+
+        sampling_targets = torch.zeros(
+            num_seed + 1, dtype=torch.long).to(points.device)
+        sampling_targets[topk_inds] = 1
+        sampling_targets = sampling_targets[:num_seed]
+        # seed points on background (pts_instance_label < 0) are not positives
+        objectness_label_mask = torch.gather(pts_instance_label, 0,
+                                             seed_indices)  # num_seed
+        sampling_targets[objectness_label_mask < 0] = 0
+
+        # 2. objectness target
+        seed_obj_gt = torch.gather(pts_obj_mask, 0, seed_indices)  # num_seed
+        objectness_targets = torch.gather(seed_obj_gt, 0,
+                                          candidate_indices)  # num_candidate
+
+        # 3. box target
+        seed_instance_label = torch.gather(pts_instance_label, 0,
+                                           seed_indices)  # num_seed
+        query_points_instance_label = torch.gather(
+            seed_instance_label, 0, candidate_indices)  # num_candidate
+
+        # Set assignment
+        # (num_candidate, ) with values in 0, 1, ..., gt_num - 1
+        assignment = query_points_instance_label
+        # set background points to the last gt bbox as original code
+        assignment[assignment < 0] = gt_num - 1
+        assignment_expand = assignment.unsqueeze(1).expand(-1, 3)
+
+        assigned_center_targets = center_targets[assignment]
+        assigned_size_targets = size_targets[assignment]
+
+        dir_class_targets = dir_class_targets[assignment]
+        dir_res_targets = dir_res_targets[assignment]
+        dir_res_targets /= (np.pi / self.num_dir_bins)
+
+        size_class_targets = size_class_targets[assignment]
+        size_res_targets = \
+            torch.gather(size_res_targets, 0, assignment_expand)
+        one_hot_size_targets = gt_bboxes_3d.tensor.new_zeros(
+            (num_candidate, self.num_sizes))
+        one_hot_size_targets.scatter_(1, size_class_targets.unsqueeze(-1), 1)
+        one_hot_size_targets = one_hot_size_targets.unsqueeze(-1).expand(
+            -1, -1, 3)  # (num_candidate, num_size_cluster, 3)
+        mean_sizes = size_res_targets.new_tensor(
+            self.bbox_coder.mean_sizes).unsqueeze(0)
+        pos_mean_sizes = torch.sum(one_hot_size_targets * mean_sizes, 1)
+        size_res_targets /= pos_mean_sizes
+
+        mask_targets = gt_labels_3d[assignment].long()
+
+        objectness_masks = points.new_ones((num_candidate, ))
+
+        return (sampling_targets, assigned_size_targets, size_class_targets,
+                size_res_targets, dir_class_targets, dir_res_targets,
+                center_targets, assigned_center_targets, mask_targets,
+                objectness_targets, objectness_masks)
+
+    def get_bboxes(self,
+                   points,
+                   bbox_preds,
+                   input_metas,
+                   rescale=False,
+                   use_nms=True):
"""Generate bboxes from GroupFree3D head predictions. + + Args: + points (torch.Tensor): Input points. + bbox_preds (dict): Predictions from GroupFree3D head. + input_metas (list[dict]): Point cloud and image's meta info. + rescale (bool): Whether to rescale bboxes. + use_nms (bool): Whether to apply NMS, skip nms postprocessing + while using GroupFree3D head in rpn stage. + + Returns: + list[tuple[torch.Tensor]]: Bounding boxes, scores and labels. + """ + # support multi-stage predicitons + assert self.test_cfg['prediction_stages'] in \ + ['last', 'all', 'last_three'] + + prefixes = list() + if self.test_cfg['prediction_stages'] == 'last': + prefixes = [f'_{self.num_decoder_layers - 1}'] + elif self.test_cfg['prediction_stages'] == 'all': + prefixes = ['proposal.'] + \ + [f's{i}.' for i in range(self.num_decoder_layers)] + elif self.test_cfg['prediction_stages'] == 'last_three': + prefixes = [ + f's{i}.' for i in range(self.num_decoder_layers - + 3, self.num_decoder_layers) + ] + else: + raise NotImplementedError + + obj_scores = list() + sem_scores = list() + bbox3d = list() + for prefix in prefixes: + # decode boxes + obj_score = bbox_preds[f'{prefix}obj_scores'][..., -1].sigmoid() + sem_score = bbox_preds[f'{prefix}sem_scores'].softmax(-1) + bbox = self.bbox_coder.decode(bbox_preds, prefix) + obj_scores.append(obj_score) + sem_scores.append(sem_score) + bbox3d.append(bbox) + + obj_scores = torch.cat(obj_scores, dim=1) + sem_scores = torch.cat(sem_scores, dim=1) + bbox3d = torch.cat(bbox3d, dim=1) + + if use_nms: + batch_size = bbox3d.shape[0] + results = list() + for b in range(batch_size): + bbox_selected, score_selected, labels = \ + self.multiclass_nms_single(obj_scores[b], sem_scores[b], + bbox3d[b], points[b, ..., :3], + input_metas[b]) + bbox = input_metas[b]['box_type_3d']( + bbox_selected, + box_dim=bbox_selected.shape[-1], + with_yaw=self.bbox_coder.with_rot) + results.append((bbox, score_selected, labels)) + + return results + else: + return bbox3d + + def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, + input_meta): + """Multi-class nms in single batch. + + Args: + obj_scores (torch.Tensor): Objectness score of bounding boxes. + sem_scores (torch.Tensor): semantic class score of bounding boxes. + bbox (torch.Tensor): Predicted bounding boxes. + points (torch.Tensor): Input points. + input_meta (dict): Point cloud and image's meta info. + + Returns: + tuple[torch.Tensor]: Bounding boxes, scores and labels. 
+ """ + bbox = input_meta['box_type_3d']( + bbox, + box_dim=bbox.shape[-1], + with_yaw=self.bbox_coder.with_rot, + origin=(0.5, 0.5, 0.5)) + box_indices = bbox.points_in_boxes(points) + + corner3d = bbox.corners + minmax_box3d = corner3d.new(torch.Size((corner3d.shape[0], 6))) + minmax_box3d[:, :3] = torch.min(corner3d, dim=1)[0] + minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0] + + nonempty_box_mask = box_indices.T.sum(1) > 5 + + bbox_classes = torch.argmax(sem_scores, -1) + nms_selected = aligned_3d_nms(minmax_box3d[nonempty_box_mask], + obj_scores[nonempty_box_mask], + bbox_classes[nonempty_box_mask], + self.test_cfg.nms_thr) + + # filter empty boxes and boxes with low score + scores_mask = (obj_scores > self.test_cfg.score_thr) + nonempty_box_inds = torch.nonzero( + nonempty_box_mask, as_tuple=False).flatten() + nonempty_mask = torch.zeros_like(bbox_classes).scatter( + 0, nonempty_box_inds[nms_selected], 1) + selected = (nonempty_mask.bool() & scores_mask.bool()) + + if self.test_cfg.per_class_proposal: + bbox_selected, score_selected, labels = [], [], [] + for k in range(sem_scores.shape[-1]): + bbox_selected.append(bbox[selected].tensor) + score_selected.append(obj_scores[selected] * + sem_scores[selected][:, k]) + labels.append( + torch.zeros_like(bbox_classes[selected]).fill_(k)) + bbox_selected = torch.cat(bbox_selected, 0) + score_selected = torch.cat(score_selected, 0) + labels = torch.cat(labels, 0) + else: + bbox_selected = bbox[selected].tensor + score_selected = obj_scores[selected] + labels = bbox_classes[selected] + + return bbox_selected, score_selected, labels diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index cd0aba794a..b295c0aa06 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -2,6 +2,7 @@ from .centerpoint import CenterPoint from .dynamic_voxelnet import DynamicVoxelNet from .fcos_mono3d import FCOSMono3D +from .groupfree3dnet import GroupFree3DNet from .h3dnet import H3DNet from .imvotenet import ImVoteNet from .imvoxelnet import ImVoxelNet @@ -17,5 +18,5 @@ 'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector', 'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet', 'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector', - 'FCOSMono3D', 'ImVoxelNet' + 'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet' ] diff --git a/mmdet3d/models/detectors/groupfree3dnet.py b/mmdet3d/models/detectors/groupfree3dnet.py new file mode 100644 index 0000000000..e50aec5ecb --- /dev/null +++ b/mmdet3d/models/detectors/groupfree3dnet.py @@ -0,0 +1,104 @@ +import torch + +from mmdet3d.core import bbox3d2result, merge_aug_bboxes_3d +from mmdet.models import DETECTORS +from .single_stage import SingleStage3DDetector + + +@DETECTORS.register_module() +class GroupFree3DNet(SingleStage3DDetector): + """`Group-Free 3D `_.""" + + def __init__(self, + backbone, + bbox_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None): + super(GroupFree3DNet, self).__init__( + backbone=backbone, + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained) + + def forward_train(self, + points, + img_metas, + gt_bboxes_3d, + gt_labels_3d, + pts_semantic_mask=None, + pts_instance_mask=None, + gt_bboxes_ignore=None): + """Forward of training. + + Args: + points (list[torch.Tensor]): Points of each batch. + img_metas (list): Image metas. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. 
+            gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
+            pts_semantic_mask (None | list[torch.Tensor]): Point-wise semantic
+                label of each batch.
+            pts_instance_mask (None | list[torch.Tensor]): Point-wise instance
+                label of each batch.
+            gt_bboxes_ignore (None | list[torch.Tensor]): Specify which
+                bounding boxes can be ignored when computing the loss.
+
+        Returns:
+            dict[str, torch.Tensor]: Losses.
+        """
+        # TODO: refactor votenet series to reduce redundant codes.
+        points_cat = torch.stack(points)
+
+        x = self.extract_feat(points_cat)
+        bbox_preds = self.bbox_head(x, self.train_cfg.sample_mod)
+        loss_inputs = (points, gt_bboxes_3d, gt_labels_3d, pts_semantic_mask,
+                       pts_instance_mask, img_metas)
+        losses = self.bbox_head.loss(
+            bbox_preds, *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+        return losses
+
+    def simple_test(self, points, img_metas, imgs=None, rescale=False):
+        """Forward of testing.
+
+        Args:
+            points (list[torch.Tensor]): Points of each sample.
+            img_metas (list): Image metas.
+            rescale (bool): Whether to rescale results.
+
+        Returns:
+            list: Predicted 3d boxes.
+        """
+        points_cat = torch.stack(points)
+
+        x = self.extract_feat(points_cat)
+        bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
+        bbox_list = self.bbox_head.get_bboxes(
+            points_cat, bbox_preds, img_metas, rescale=rescale)
+        bbox_results = [
+            bbox3d2result(bboxes, scores, labels)
+            for bboxes, scores, labels in bbox_list
+        ]
+        return bbox_results
+
+    def aug_test(self, points, img_metas, imgs=None, rescale=False):
+        """Test with augmentation."""
+        points_cat = [torch.stack(pts) for pts in points]
+        feats = self.extract_feats(points_cat, img_metas)
+
+        # only support aug_test for one sample
+        aug_bboxes = []
+        for x, pts_cat, img_meta in zip(feats, points_cat, img_metas):
+            bbox_preds = self.bbox_head(x, self.test_cfg.sample_mod)
+            bbox_list = self.bbox_head.get_bboxes(
+                pts_cat, bbox_preds, img_meta, rescale=rescale)
+            bbox_list = [
+                dict(boxes_3d=bboxes, scores_3d=scores, labels_3d=labels)
+                for bboxes, scores, labels in bbox_list
+            ]
+            aug_bboxes.append(bbox_list[0])
+
+        # after merging, bboxes will be rescaled to the original image size
+        merged_bboxes = merge_aug_bboxes_3d(aug_bboxes, img_metas,
+                                            self.bbox_head.test_cfg)
+
+        return [merged_bboxes]
diff --git a/mmdet3d/models/model_utils/__init__.py b/mmdet3d/models/model_utils/__init__.py
index b8276279b9..87f73d27a9 100644
--- a/mmdet3d/models/model_utils/__init__.py
+++ b/mmdet3d/models/model_utils/__init__.py
@@ -1,3 +1,4 @@
+from .transformer import GroupFree3DMHA
 from .vote_module import VoteModule
 
-__all__ = ['VoteModule']
+__all__ = ['VoteModule', 'GroupFree3DMHA']
diff --git a/mmdet3d/models/model_utils/transformer.py b/mmdet3d/models/model_utils/transformer.py
new file mode 100644
index 0000000000..2db8a2859d
--- /dev/null
+++ b/mmdet3d/models/model_utils/transformer.py
@@ -0,0 +1,137 @@
+from mmcv.cnn.bricks.registry import ATTENTION
+from mmcv.cnn.bricks.transformer import POSITIONAL_ENCODING, MultiheadAttention
+from torch import nn as nn
+
+
+@ATTENTION.register_module()
+class GroupFree3DMHA(MultiheadAttention):
+    """A wrapper around torch.nn.MultiheadAttention for GroupFree3D.
+
+    This module implements MultiheadAttention with identity connection,
+    and the positional encoding used in DETR is also passed as input.
+
+    Args:
+        embed_dims (int): The embedding dimension.
+        num_heads (int): Parallel attention heads. Same as
+            `nn.MultiheadAttention`.
+        attn_drop (float): A Dropout layer on attn_output_weights.
+            Default 0.0.
+        proj_drop (float): A Dropout layer. Default 0.0.
+        dropout_layer (obj:`ConfigDict`): The dropout_layer used
+            when adding the shortcut.
+        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
+            Default: None.
+        batch_first (bool): Whether Key, Query and Value are shaped
+            (batch, n, embed_dim) or (n, batch, embed_dim).
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 embed_dims,
+                 num_heads,
+                 attn_drop=0.,
+                 proj_drop=0.,
+                 dropout_layer=dict(type='Dropout', drop_prob=0.),
+                 init_cfg=None,
+                 batch_first=False,
+                 **kwargs):
+        super().__init__(embed_dims, num_heads, attn_drop, proj_drop,
+                         dropout_layer, init_cfg, batch_first, **kwargs)
+
+    def forward(self,
+                query,
+                key,
+                value,
+                identity,
+                query_pos=None,
+                key_pos=None,
+                attn_mask=None,
+                key_padding_mask=None,
+                **kwargs):
+        """Forward function for `GroupFree3DMHA`.
+
+        **kwargs allow passing a more general data flow when combining
+        with other operations in `transformerlayer`.
+
+        Args:
+            query (Tensor): The input query with shape [num_queries, bs,
+                embed_dims]. Same in `nn.MultiheadAttention.forward`.
+            key (Tensor): The key tensor with shape [num_keys, bs,
+                embed_dims]. Same in `nn.MultiheadAttention.forward`.
+                If None, the ``query`` will be used. Defaults to None.
+            value (Tensor): The value tensor with same shape as `key`.
+                Same in `nn.MultiheadAttention.forward`.
+                If None, the `key` will be used. Defaults to None.
+            identity (Tensor): This tensor, with the same shape as x,
+                will be used for the identity link.
+                If None, `x` will be used. Defaults to None.
+            query_pos (Tensor): The positional encoding for query, with
+                the same shape as `x`. If not None, it will
+                be added to `x` before forward function. Defaults to None.
+            key_pos (Tensor): The positional encoding for `key`, with the
+                same shape as `key`. If not None, it will be added to
+                `key` before forward function. If None, and `query_pos`
+                has the same shape as `key`, then `query_pos` will be
+                used for `key_pos`. Defaults to None.
+            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
+                num_keys]. Same in `nn.MultiheadAttention.forward`.
+                Defaults to None.
+            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
+                Same in `nn.MultiheadAttention.forward`. Defaults to None.
+
+        Returns:
+            Tensor: Forwarded results with shape
+                [num_queries, bs, embed_dims].
+        """
+        if hasattr(self, 'operation_name'):
+            if self.operation_name == 'self_attn':
+                value = value + query_pos
+            elif self.operation_name == 'cross_attn':
+                value = value + key_pos
+            else:
+                raise NotImplementedError(
+                    f'{self.__class__.__name__} '
+                    f"can't be used as {self.operation_name}")
+        else:
+            value = value + query_pos
+
+        return super(GroupFree3DMHA, self).forward(
+            query=query,
+            key=key,
+            value=value,
+            identity=identity,
+            query_pos=query_pos,
+            key_pos=key_pos,
+            attn_mask=attn_mask,
+            key_padding_mask=key_padding_mask,
+            **kwargs)
+
+
+@POSITIONAL_ENCODING.register_module()
+class ConvBNPositionalEncoding(nn.Module):
+    """Absolute position embedding with Conv learning.
+
+    Args:
+        input_channel (int): Input features dim.
+        num_pos_feats (int): Output position features dim.
+            Defaults to 288 to be consistent with seed features dim.
+    """
+
+    def __init__(self, input_channel, num_pos_feats=288):
+        super().__init__()
+        self.position_embedding_head = nn.Sequential(
+            nn.Conv1d(input_channel, num_pos_feats, kernel_size=1),
+            nn.BatchNorm1d(num_pos_feats), nn.ReLU(inplace=True),
+            nn.Conv1d(num_pos_feats, num_pos_feats, kernel_size=1))
+
+    def forward(self, xyz):
+        """Forward pass.
+
+        Args:
+            xyz (Tensor): (B, N, 3) the coordinates to embed.
+
+        Returns:
+            Tensor: (B, num_pos_feats, N) the embedded position features.
+        """
+        xyz = xyz.permute(0, 2, 1)
+        position_embedding = self.position_embedding_head(xyz)
+        return position_embedding
diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py
index bff2875e2a..828ecc83a9 100644
--- a/tests/test_models/test_detectors.py
+++ b/tests/test_models/test_detectors.py
@@ -379,6 +379,59 @@ def test_fcos3d():
     assert attrs_3d.shape[0] >= 0
 
 
+def test_groupfree3dnet():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+
+    _setup_seed(0)
+    groupfree3d_cfg = _get_detector_cfg(
+        'groupfree3d/groupfree3d_8x8_scannet-3d-18class-L6-O256.py')
+    self = build_detector(groupfree3d_cfg).cuda()
+
+    points_0 = torch.rand([50000, 3], device='cuda')
+    points_1 = torch.rand([50000, 3], device='cuda')
+    points = [points_0, points_1]
+    img_meta_0 = dict(box_type_3d=DepthInstance3DBoxes)
+    img_meta_1 = dict(box_type_3d=DepthInstance3DBoxes)
+    img_metas = [img_meta_0, img_meta_1]
+    gt_bbox_0 = DepthInstance3DBoxes(torch.rand([10, 7], device='cuda'))
+    gt_bbox_1 = DepthInstance3DBoxes(torch.rand([10, 7], device='cuda'))
+    gt_bboxes = [gt_bbox_0, gt_bbox_1]
+    gt_labels_0 = torch.randint(0, 18, [10], device='cuda')
+    gt_labels_1 = torch.randint(0, 18, [10], device='cuda')
+    gt_labels = [gt_labels_0, gt_labels_1]
+    pts_instance_mask_1 = torch.randint(0, 10, [50000], device='cuda')
+    pts_instance_mask_2 = torch.randint(0, 10, [50000], device='cuda')
+    pts_instance_mask = [pts_instance_mask_1, pts_instance_mask_2]
+    pts_semantic_mask_1 = torch.randint(0, 19, [50000], device='cuda')
+    pts_semantic_mask_2 = torch.randint(0, 19, [50000], device='cuda')
+    pts_semantic_mask = [pts_semantic_mask_1, pts_semantic_mask_2]
+
+    # test forward_train
+    losses = self.forward_train(points, img_metas, gt_bboxes, gt_labels,
+                                pts_semantic_mask, pts_instance_mask)
+
+    assert losses['sampling_objectness_loss'] >= 0
+    assert losses['s5.objectness_loss'] >= 0
+    assert losses['s5.semantic_loss'] >= 0
+    assert losses['s5.center_loss'] >= 0
+    assert losses['s5.dir_class_loss'] >= 0
+    assert losses['s5.dir_res_loss'] >= 0
+    assert losses['s5.size_class_loss'] >= 0
+    assert losses['s5.size_res_loss'] >= 0
+
+    # test simple_test
+    with torch.no_grad():
+        results = self.simple_test(points, img_metas)
+    boxes_3d = results[0]['boxes_3d']
+    scores_3d = results[0]['scores_3d']
+    labels_3d = results[0]['labels_3d']
+    assert boxes_3d.tensor.shape[0] >= 0
+    assert boxes_3d.tensor.shape[1] == 7
+    assert scores_3d.shape[0] >= 0
+    assert labels_3d.shape[0] >= 0
+
+
 def test_imvoxelnet():
     if not torch.cuda.is_available():
         pytest.skip('test requires GPU and torch+cuda')
diff --git a/tests/test_models/test_heads/test_heads.py b/tests/test_models/test_heads/test_heads.py
index 0f0c33f4b8..5bd82a0295 100644
--- a/tests/test_models/test_heads/test_heads.py
+++ b/tests/test_models/test_heads/test_heads.py
@@ -1114,3 +1114,110 @@ def test_fcos_mono3d_head():
     assert results[0][1].shape == torch.Size([200])
     assert results[0][2].shape == torch.Size([200])
     assert results[0][3].shape == torch.Size([200])
+
+
+def test_groupfree3d_head():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+    _setup_seed(0)
+    vote_head_cfg = _get_vote_head_cfg(
+        'groupfree3d/groupfree3d_8x8_scannet-3d-18class-L6-O256.py')
+    self = build_head(vote_head_cfg).cuda()
+
+    fp_xyz = [torch.rand([2, 256, 3],
dtype=torch.float32).cuda()] + fp_features = [torch.rand([2, 288, 256], dtype=torch.float32).cuda()] + fp_indices = [torch.randint(0, 128, [2, 256]).cuda()] + + input_dict = dict( + fp_xyz=fp_xyz, fp_features=fp_features, fp_indices=fp_indices) + + # test forward + ret_dict = self(input_dict, 'kps') + assert ret_dict['seeds_obj_cls_logits'].shape == torch.Size([2, 1, 256]) + assert ret_dict['s5.center'].shape == torch.Size([2, 128, 3]) + assert ret_dict['s5.dir_class'].shape == torch.Size([2, 128, 1]) + assert ret_dict['s5.dir_res'].shape == torch.Size([2, 128, 1]) + assert ret_dict['s5.size_class'].shape == torch.Size([2, 128, 18]) + assert ret_dict['s5.size_res'].shape == torch.Size([2, 128, 18, 3]) + assert ret_dict['s5.obj_scores'].shape == torch.Size([2, 128, 1]) + assert ret_dict['s5.sem_scores'].shape == torch.Size([2, 128, 18]) + + # test losses + points = [torch.rand([50000, 4], device='cuda') for i in range(2)] + gt_bbox1 = torch.rand([10, 7], dtype=torch.float32).cuda() + gt_bbox2 = torch.rand([10, 7], dtype=torch.float32).cuda() + + gt_bbox1 = DepthInstance3DBoxes(gt_bbox1) + gt_bbox2 = DepthInstance3DBoxes(gt_bbox2) + gt_bboxes = [gt_bbox1, gt_bbox2] + + pts_instance_mask_1 = torch.randint(0, 10, [50000], device='cuda') + pts_instance_mask_2 = torch.randint(0, 10, [50000], device='cuda') + pts_instance_mask = [pts_instance_mask_1, pts_instance_mask_2] + + pts_semantic_mask_1 = torch.randint(0, 19, [50000], device='cuda') + pts_semantic_mask_2 = torch.randint(0, 19, [50000], device='cuda') + pts_semantic_mask = [pts_semantic_mask_1, pts_semantic_mask_2] + + labels_1 = torch.randint(0, 18, [10], device='cuda') + labels_2 = torch.randint(0, 18, [10], device='cuda') + gt_labels = [labels_1, labels_2] + + losses = self.loss(ret_dict, points, gt_bboxes, gt_labels, + pts_semantic_mask, pts_instance_mask) + + assert losses['s5.objectness_loss'] >= 0 + assert losses['s5.semantic_loss'] >= 0 + assert losses['s5.center_loss'] >= 0 + assert losses['s5.dir_class_loss'] >= 0 + assert losses['s5.dir_res_loss'] >= 0 + assert losses['s5.size_class_loss'] >= 0 + assert losses['s5.size_res_loss'] >= 0 + + # test multiclass_nms_single + obj_scores = torch.rand([256], device='cuda') + sem_scores = torch.rand([256, 18], device='cuda') + points = torch.rand([50000, 3], device='cuda') + bbox = torch.rand([256, 7], device='cuda') + input_meta = dict(box_type_3d=DepthInstance3DBoxes) + bbox_selected, score_selected, labels = \ + self.multiclass_nms_single(obj_scores, + sem_scores, + bbox, + points, + input_meta) + assert bbox_selected.shape[0] >= 0 + assert bbox_selected.shape[1] == 7 + assert score_selected.shape[0] >= 0 + assert labels.shape[0] >= 0 + + # test get_boxes + points = torch.rand([1, 50000, 3], device='cuda') + seed_points = torch.rand([1, 1024, 3], device='cuda') + seed_indices = torch.randint(0, 50000, [1, 1024], device='cuda') + obj_scores = torch.rand([1, 256, 1], device='cuda') + center = torch.rand([1, 256, 3], device='cuda') + dir_class = torch.rand([1, 256, 1], device='cuda') + dir_res_norm = torch.rand([1, 256, 1], device='cuda') + dir_res = torch.rand([1, 256, 1], device='cuda') + size_class = torch.rand([1, 256, 18], device='cuda') + size_res = torch.rand([1, 256, 18, 3], device='cuda') + sem_scores = torch.rand([1, 256, 18], device='cuda') + bbox_preds = dict() + bbox_preds['seed_points'] = seed_points + bbox_preds['seed_indices'] = seed_indices + bbox_preds['s5.obj_scores'] = obj_scores + bbox_preds['s5.center'] = center + bbox_preds['s5.dir_class'] = dir_class + 
+    bbox_preds['s5.dir_res_norm'] = dir_res_norm
+    bbox_preds['s5.dir_res'] = dir_res
+    bbox_preds['s5.size_class'] = size_class
+    bbox_preds['s5.size_res'] = size_res
+    bbox_preds['s5.sem_scores'] = sem_scores
+
+    self.test_cfg['prediction_stages'] = 'last'
+    results = self.get_bboxes(points, bbox_preds, [input_meta])
+    assert results[0][0].tensor.shape[0] >= 0
+    assert results[0][0].tensor.shape[1] == 7
+    assert results[0][1].shape[0] >= 0
+    assert results[0][2].shape[0] >= 0
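
Usage sketch (not part of the patch): the ConvBNPositionalEncoding module added in transformer.py maps seed coordinates to channel-first position features, so its output can be added directly to the 288-dim seed features. A minimal shape check, assuming the module path introduced by this diff:

    import torch
    from mmdet3d.models.model_utils.transformer import ConvBNPositionalEncoding

    pos_enc = ConvBNPositionalEncoding(input_channel=3, num_pos_feats=288)
    xyz = torch.rand(2, 1024, 3)  # (B, N, 3) seed point coordinates
    pos_feats = pos_enc(xyz)      # (B, 288, N), channel-first like seed features
    assert pos_feats.shape == (2, 288, 1024)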
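The get_bboxes fix above matters because every per-stage key in bbox_preds carries a 'proposal.' or 's{i}.' prefix (e.g. 's5.obj_scores' in the tests). A standalone sketch of the prefix selection, using a hypothetical helper that mirrors the head's logic for the three supported test_cfg['prediction_stages'] values:

    def stage_prefixes(mode, num_decoder_layers=6):
        # mirrors the prefix selection in GroupFree3DHead.get_bboxes
        if mode == 'last':
            return [f's{num_decoder_layers - 1}.']
        if mode == 'all':
            return ['proposal.'] + [f's{i}.' for i in range(num_decoder_layers)]
        if mode == 'last_three':
            return [
                f's{i}.'
                for i in range(num_decoder_layers - 3, num_decoder_layers)
            ]
        raise NotImplementedError(mode)

    assert stage_prefixes('last') == ['s5.']
    assert stage_prefixes('last_three') == ['s3.', 's4.', 's5.']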
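For reference, the sampling targets in get_targets_single follow a k-Closest Points Sampling rule: each non-padded GT box marks its seed_points_obj_topk nearest seeds, measured by size-normalized distance to the box center, as positives. A simplified sketch with a hypothetical helper; it omits the per-seed instance masking that the real code applies via object_assignment_one_hot:

    import torch

    def kps_sampling_targets(seed_xyz, gt_centers, gt_dims, box_label_mask,
                             k=4, eps=1e-6):
        # seed_xyz: (num_seed, 3); gt_*: (gt_num, 3); box_label_mask: (gt_num,)
        num_seed = seed_xyz.shape[0]
        delta = seed_xyz[:, None] - gt_centers[None]  # (num_seed, gt_num, 3)
        dist = ((delta / (gt_dims[None] + eps))**2).sum(-1).sqrt()
        topk = torch.topk(dist.t(), k, largest=False)[1]  # (gt_num, k)
        # padded GT rows (mask == 0) are redirected to index -1, i.e. the
        # spare slot appended below, which is then sliced off
        topk = (topk * box_label_mask[:, None] +
                (box_label_mask[:, None] - 1)).long()
        targets = torch.zeros(num_seed + 1, dtype=torch.long)
        targets[topk.view(-1)] = 1
        return targets[:num_seed]

    seeds = torch.tensor([[0., 0., 0.], [1., 0., 0.], [5., 5., 5.]])
    targets = kps_sampling_targets(seeds, torch.zeros(1, 3), torch.ones(1, 3),
                                   torch.ones(1), k=2)
    assert targets.tolist() == [1, 1, 0]  # the two closest seeds are positive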