Skip to content

Commit

Permalink
[Fixed] modify vote_head to support 3dssd (#396)
Browse files Browse the repository at this point in the history
* modify vote_head to support 3dssd

* delete .keys()

* add 3dssd unittest

* 3dssd->ssd3d
  • Loading branch information
xiliu8006 authored Mar 31, 2021
1 parent 043eb99 commit 7c30072
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 25 deletions.
16 changes: 13 additions & 3 deletions mmdet3d/models/dense_heads/vote_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,19 @@ def _extract_input(self, feat_dict):
torch.Tensor: Features of input points.
torch.Tensor: Indices of input points.
"""
seed_points = feat_dict['seed_points']
seed_features = feat_dict['seed_features']
seed_indices = feat_dict['seed_indices']

# for imvotenet
if 'seed_points' in feat_dict and \
'seed_features' in feat_dict and \
'seed_indices' in feat_dict:
seed_points = feat_dict['seed_points']
seed_features = feat_dict['seed_features']
seed_indices = feat_dict['seed_indices']
# for votenet
else:
seed_points = feat_dict['fp_xyz'][-1]
seed_features = feat_dict['fp_features'][-1]
seed_indices = feat_dict['fp_indices'][-1]

return seed_points, seed_features, seed_indices

Expand Down
22 changes: 0 additions & 22 deletions mmdet3d/models/detectors/votenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,28 +22,6 @@ def __init__(self,
test_cfg=test_cfg,
pretrained=pretrained)

def extract_feat(self, points, img_metas=None):
"""Directly extract features from the backbone+neck.
Args:
points (torch.Tensor): Input points.
"""
x = self.backbone(points)
if self.with_neck:
x = self.neck(x)

seed_points = x['fp_xyz'][-1]
seed_features = x['fp_features'][-1]
seed_indices = x['fp_indices'][-1]

feat_dict = {
'seed_points': seed_points,
'seed_features': seed_features,
'seed_indices': seed_indices
}

return feat_dict

def forward_train(self,
points,
img_metas,
Expand Down
41 changes: 41 additions & 0 deletions tests/test_models/test_detectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,47 @@ def test_voxel_net():
assert labels_3d.shape == torch.Size([50])


def test_3dssd():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
_setup_seed(0)
ssd3d_cfg = _get_detector_cfg('3dssd/3dssd_kitti-3d-car.py')
self = build_detector(ssd3d_cfg).cuda()
points_0 = torch.rand([2000, 4], device='cuda')
points_1 = torch.rand([2000, 4], device='cuda')
points = [points_0, points_1]
img_meta_0 = dict(box_type_3d=DepthInstance3DBoxes)
img_meta_1 = dict(box_type_3d=DepthInstance3DBoxes)
img_metas = [img_meta_0, img_meta_1]
gt_bbox_0 = DepthInstance3DBoxes(torch.rand([10, 7], device='cuda'))
gt_bbox_1 = DepthInstance3DBoxes(torch.rand([10, 7], device='cuda'))
gt_bboxes = [gt_bbox_0, gt_bbox_1]
gt_labels_0 = torch.randint(0, 10, [10], device='cuda')
gt_labels_1 = torch.randint(0, 10, [10], device='cuda')
gt_labels = [gt_labels_0, gt_labels_1]

# test forward_train
losses = self.forward_train(points, img_metas, gt_bboxes, gt_labels)
assert losses['vote_loss'] >= 0
assert losses['objectness_loss'] >= 0
assert losses['semantic_loss'] >= 0
assert losses['center_loss'] >= 0
assert losses['dir_class_loss'] >= 0
assert losses['dir_res_loss'] >= 0
assert losses['size_class_loss'] >= 0
assert losses['size_res_loss'] >= 0

# test simple_test
results = self.simple_test(points, img_metas)
boxes_3d = results[0]['boxes_3d']
scores_3d = results[0]['scores_3d']
labels_3d = results[0]['labels_3d']
assert boxes_3d.tensor.shape[0] >= 0
assert boxes_3d.tensor.shape[1] == 7
assert scores_3d.shape[0] >= 0
assert labels_3d.shape[0] >= 0


def test_vote_net():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
Expand Down

0 comments on commit 7c30072

Please sign in to comment.