Skip to content

Commit

Permalink
[Enhance] Support nuscenes demo (#353)
Browse files Browse the repository at this point in the history
* support nuscenes dataset in demo

* add convert_SyncBN in __init__

* fix meshlab visualization bug

* modify meshlab unittest

* add docstring

* add empty line in docstring

Co-authored-by: xiliu8006 <xiliu800@gmail.com>
  • Loading branch information
xiliu8006 and xiliu8006 authored Mar 26, 2021
1 parent 4eed122 commit 7684978
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 10 deletions.
5 changes: 3 additions & 2 deletions mmdet3d/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .inference import inference_detector, init_detector, show_result_meshlab
from .inference import (convert_SyncBN, inference_detector, init_detector,
show_result_meshlab)
from .test import single_gpu_test

__all__ = [
'inference_detector', 'init_detector', 'single_gpu_test',
'show_result_meshlab'
'show_result_meshlab', 'convert_SyncBN'
]
31 changes: 26 additions & 5 deletions mmdet3d/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@
from mmdet3d.models import build_detector


def convert_SyncBN(config):
"""Convert config's naiveSyncBN to BN.
Args:
config (str or :obj:`mmcv.Config`): Config file path or the config
object.
"""
if isinstance(config, dict):
for item in config:
if item == 'norm_cfg':
config[item]['type'] = config[item]['type']. \
replace('naiveSyncBN', 'BN')
else:
convert_SyncBN(config[item])


def init_detector(config, checkpoint=None, device='cuda:0'):
"""Initialize a detector from config file.
Expand All @@ -30,6 +46,7 @@ def init_detector(config, checkpoint=None, device='cuda:0'):
raise TypeError('config must be a filename or Config object, '
f'but got {type(config)}')
config.model.pretrained = None
convert_SyncBN(config.model)
config.model.train_cfg = None
model = build_detector(config.model, test_cfg=config.get('test_cfg'))
if checkpoint is not None:
Expand Down Expand Up @@ -64,6 +81,9 @@ def inference_detector(model, pcd):
pts_filename=pcd,
box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d,
sweeps=[],
# set timestamp = 0
timestamp=[0],
img_fields=[],
bbox3d_fields=[],
pts_mask_fields=[],
Expand Down Expand Up @@ -100,15 +120,16 @@ def show_result_meshlab(data, result, out_dir):

assert out_dir is not None, 'Expect out_dir, got none.'

pred_bboxes = result[0]['boxes_3d'].tensor.numpy()
if 'pts_bbox' in result[0].keys():
pred_bboxes = result[0]['pts_bbox']['boxes_3d'].tensor.numpy()
else:
pred_bboxes = result[0]['boxes_3d'].tensor.numpy()
# for now we convert points into depth mode
if data['img_metas'][0][0]['box_mode_3d'] != Box3DMode.DEPTH:
points = points[..., [1, 0, 2]]
points[..., 0] *= -1
pred_bboxes = Box3DMode.convert(pred_bboxes,
data['img_metas'][0][0]['box_mode_3d'],
Box3DMode.DEPTH)
pred_bboxes[..., 2] += pred_bboxes[..., 5] / 2
else:
pred_bboxes[..., 2] += pred_bboxes[..., 5] / 2
show_result(points, None, pred_bboxes, out_dir, file_name)
show_result(points, None, pred_bboxes, out_dir, file_name, show=False)
return out_dir, file_name
4 changes: 2 additions & 2 deletions mmdet3d/core/visualizer/show_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ def show_result(points, gt_bboxes, pred_bboxes, out_dir, filename, show=True):
filename (str): Filename of the current frame.
show (bool): Visualize the results online.
"""
from .open3d_vis import Visualizer

if show:
from .open3d_vis import Visualizer

vis = Visualizer(points)
if pred_bboxes is not None:
vis.add_bboxes(bbox3d=pred_bboxes)
Expand Down
50 changes: 49 additions & 1 deletion tests/test_runtime/test_apis.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import numpy as np
import os
import pytest
import tempfile
import torch
from mmcv.parallel import MMDataParallel
from os.path import dirname, exists, join

from mmdet3d.apis import inference_detector, init_detector, single_gpu_test
from mmdet3d.apis import (convert_SyncBN, inference_detector, init_detector,
show_result_meshlab, single_gpu_test)
from mmdet3d.core import Box3DMode
from mmdet3d.core.bbox import LiDARInstance3DBoxes
from mmdet3d.datasets import build_dataloader, build_dataset
from mmdet3d.models import build_detector

Expand Down Expand Up @@ -32,6 +38,48 @@ def _get_config_module(fname):
return config_mod


def test_convert_SyncBN():
cfg = _get_config_module(
'pointpillars/hv_pointpillars_fpn_sbn-all_4x8_2x_nus-3d.py')
model_cfg = cfg.model
convert_SyncBN(model_cfg)
assert model_cfg['pts_voxel_encoder']['norm_cfg']['type'] == 'BN1d'
assert model_cfg['pts_backbone']['norm_cfg']['type'] == 'BN2d'
assert model_cfg['pts_neck']['norm_cfg']['type'] == 'BN2d'


def test_show_result_meshlab():
pcd = 'tests/data/nuscenes/samples/LIDAR_TOP/n015-2018-08-02-17-16-37+' \
'0800__LIDAR_TOP__1533201470948018.pcd.bin'
box_3d = LiDARInstance3DBoxes(
torch.tensor(
[[8.7314, -1.8559, -1.5997, 0.4800, 1.2000, 1.8900, 0.0100]]))
labels_3d = torch.tensor([0])
scores_3d = torch.tensor([0.5])
points = np.random.rand(100, 4)
img_meta = dict(
pts_filename=pcd, boxes_3d=box_3d, box_mode_3d=Box3DMode.LIDAR)
data = dict(points=[[torch.tensor(points)]], img_metas=[[img_meta]])
result = [
dict(
pts_bbox=dict(
boxes_3d=box_3d, labels_3d=labels_3d, scores_3d=scores_3d))
]
temp_out_dir = tempfile.mkdtemp()
out_dir, file_name = show_result_meshlab(data, result, temp_out_dir)
expected_outfile_ply = file_name + '_pred.ply'
expected_outfile_obj = file_name + '_points.obj'
expected_outfile_ply_path = os.path.join(out_dir, file_name,
expected_outfile_ply)
expected_outfile_obj_path = os.path.join(out_dir, file_name,
expected_outfile_obj)
assert os.path.exists(expected_outfile_ply_path)
assert os.path.exists(expected_outfile_obj_path)
os.remove(expected_outfile_obj_path)
os.remove(expected_outfile_ply_path)
os.removedirs(os.path.join(temp_out_dir, file_name))


def test_inference_detector():
pcd = 'tests/data/kitti/training/velodyne_reduced/000000.bin'
detector_cfg = 'configs/pointpillars/hv_pointpillars_secfpn_' \
Expand Down

0 comments on commit 7684978

Please sign in to comment.