
[Feature] Support multi-modality visualization (demos and dataset show function) #405

Merged: 16 commits, Apr 8, 2021
@@ -202,7 +202,8 @@
            pipeline=train_pipeline,
            modality=input_modality,
            classes=class_names,
-            test_mode=False)),
+            test_mode=False,
+            box_type_3d='LiDAR')),
    val=dict(
        type=dataset_type,
        data_root=data_root,
@@ -212,7 +213,8 @@
        pipeline=test_pipeline,
        modality=input_modality,
        classes=class_names,
-        test_mode=True),
+        test_mode=True,
+        box_type_3d='LiDAR'),
    test=dict(
        type=dataset_type,
        data_root=data_root,
@@ -222,7 +224,8 @@
        pipeline=test_pipeline,
        modality=input_modality,
        classes=class_names,
-        test_mode=True))
+        test_mode=True,
+        box_type_3d='LiDAR'))
# Training settings
optimizer = dict(type='AdamW', lr=0.003, betas=(0.95, 0.99), weight_decay=0.01)
# max_norm=10 is better for SECOND
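
The `box_type_3d` field added above tells each dataset split which 3D box convention its annotations use, and the new inference code resolves it via `get_box_type`. A minimal sketch of the mapping, assuming the standard mmdet3d box types:

```python
from mmdet3d.core.bbox import get_box_type

# 'LiDAR' is used for outdoor datasets such as KITTI; indoor datasets
# such as SUN RGB-D use 'Depth' instead
box_type_3d, box_mode_3d = get_box_type('LiDAR')
# box_type_3d -> LiDARInstance3DBoxes, box_mode_3d -> Box3DMode.LIDAR
```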
File renamed without changes.
Binary file added demo/data/kitti/kitti_000008.png
Binary file added demo/data/kitti/kitti_000008_infos.pkl
Binary file not shown.
File renamed without changes.
Binary file added demo/data/sunrgbd/sunrgbd_000017.jpg
Binary file added demo/data/sunrgbd/sunrgbd_000017_infos.pkl
Binary file not shown.
32 changes: 32 additions & 0 deletions demo/multi_modality_demo.py
@@ -0,0 +1,32 @@
from argparse import ArgumentParser

from mmdet3d.apis import (inference_multi_modality_detector, init_detector,
                          show_result_meshlab)


def main():
    parser = ArgumentParser()
    parser.add_argument('pcd', help='Point cloud file')
    parser.add_argument('image', help='Image file')
    parser.add_argument('ann', help='Annotation file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--score-thr', type=float, default=0.0, help='bbox score threshold')
    parser.add_argument(
        '--out-dir', type=str, default='demo', help='dir to save results')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # test a single sample
    result, data = inference_multi_modality_detector(model, args.pcd,
                                                     args.image, args.ann)
    # show the results
    show_result_meshlab(data, result, args.out_dir, args.score_thr)


if __name__ == '__main__':
    main()
4 changes: 2 additions & 2 deletions demo/pcd_demo.py
@@ -11,7 +11,7 @@ def main():
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
-        '--score-thr', type=float, default=0.6, help='bbox score threshold')
+        '--score-thr', type=float, default=0.0, help='bbox score threshold')
    parser.add_argument(
        '--out-dir', type=str, default='demo', help='dir to save results')
    args = parser.parse_args()
@@ -21,7 +21,7 @@ def main():
    # test a single image
    result, data = inference_detector(model, args.pcd)
    # show the results
-    show_result_meshlab(data, result, args.out_dir)
+    show_result_meshlab(data, result, args.out_dir, args.score_thr)


if __name__ == '__main__':
53 changes: 53 additions & 0 deletions docs/0_demo.md
@@ -0,0 +1,53 @@
# Demo

## Introduction

We provide scripts for multi-modality/single-modality and indoor/outdoor 3D detection demos. The pre-trained models can be downloaded from the [model zoo](../docs/model_zoo.md). We provide pre-processed sample data from the KITTI and SUN RGB-D datasets. You can use any other data, as long as it follows our pre-processing steps.

## Testing

### Single-modality demo

To test a 3D detector on point cloud data, simply run:

```shell
python demo/pcd_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
```

The visualization results, including a point cloud and predicted 3D bounding boxes, will be saved in `demo/PCD_NAME`, which you can open using [MeshLab](http://www.meshlab.net/).
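
Equivalently, you can call the high-level Python APIs directly. Below is a minimal sketch mirroring `demo/pcd_demo.py`, reusing the KITTI SECOND config and checkpoint from the example that follows:

```python
from mmdet3d.apis import inference_detector, init_detector, show_result_meshlab

config_file = 'configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py'
checkpoint_file = 'checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth'

# build the model, run inference on one point cloud and
# dump MeshLab-readable result files into demo/
model = init_detector(config_file, checkpoint_file, device='cuda:0')
result, data = inference_detector(model, 'demo/data/kitti/kitti_000008.bin')
show_result_meshlab(data, result, 'demo', score_thr=0.0)
```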

Example on KITTI data using [SECOND](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/second) model:

```shell
python demo/pcd_demo.py demo/data/kitti/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth
```

Example on SUN RGB-D data using [VoteNet](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/votenet) model:

```shell
python demo/pcd_demo.py demo/data/sunrgbd/sunrgbd_000017.bin configs/votenet/votenet_16x8_sunrgbd-3d-10class.py checkpoints/votenet_16x8_sunrgbd-3d-10class_20200620_230238-4483c0c0.pth
```

Remember to convert the VoteNet checkpoint if you are using mmdetection3d version >= 0.6.0. See its [README](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/votenet/README.md) for detailed instructions on how to convert the checkpoint.
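
For reference, the conversion is done by a helper script shipped with the repo. The invocation below is an assumption based on that README; the script path and flags may differ between versions, so verify them there:

```shell
# assumed script path and flags -- double-check against the VoteNet README
python ./tools/model_converters/convert_votenet_checkpoints.py ${ORIGINAL_CHECKPOINT_PATH} --out=${NEW_CHECKPOINT_PATH}
```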

### Multi-modality demo

To test a 3D detector on multi-modality data (typically point cloud and image), simply run:

```shell
python demo/multi_modality_demo.py ${PCD_FILE} ${IMAGE_FILE} ${ANNOTATION_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
```

where `ANNOTATION_FILE` should provide the 3D-to-2D projection matrix (camera calibration). The visualization results, including the point cloud, the image, the predicted 3D bounding boxes, and their projection onto the image, will be saved in `demo/PCD_NAME`.
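
For KITTI-style (LiDAR) data, the info `.pkl` is a list of per-sample dicts carrying the image index and the calibration matrices from which the projection is composed. A minimal sketch of what the new `inference_multi_modality_detector` reads (field names taken from its diff further down this page):

```python
import mmcv
import numpy as np

infos = mmcv.load('demo/data/kitti/kitti_000008_infos.pkl')
info = infos[0]
idx = int(info['image']['image_idx'])  # matched against the image file name
# project LiDAR points onto the image: camera intrinsics (P2) after
# rectification (R0_rect) and LiDAR-to-camera extrinsics (Tr_velo_to_cam)
rect = info['calib']['R0_rect'].astype(np.float32)
Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
P2 = info['calib']['P2'].astype(np.float32)
lidar2img = P2 @ rect @ Trv2c
```

For SUN RGB-D (depth-mode) data, the calibration dict instead carries `Rt` and `K`, as the inference code below shows.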

Example on KITTI data using [MVX-Net](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/mvxnet) model:

```shell
python demo/multi_modality_demo.py demo/data/kitti/kitti_000008.bin demo/data/kitti/kitti_000008.png demo/data/kitti/kitti_000008_infos.pkl configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py checkpoints/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class_20200621_003904-10140f2d.pth
```

Example on SUN RGB-D data using [ImVoteNet](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/imvotenet) model:

```shell
python demo/multi_modality_demo.py demo/data/sunrgbd/sunrgbd_000017.bin demo/data/sunrgbd/sunrgbd_000017.jpg demo/data/sunrgbd/sunrgbd_000017_infos.pkl configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py checkpoints/imvotenet_stage2_16x8_sunrgbd-3d-10class_20210323_184021-d44dcb66.pth
```
27 changes: 24 additions & 3 deletions docs/getting_started.md
@@ -194,7 +194,7 @@ PYTHONPATH="$(dirname $0)/..":$PYTHONPATH

### Point cloud demo

-We provide a demo script to test a single sample. Pre-trained models can be downloaded from [model zoo](model_zoo.md)
+We provide several demo scripts to test a single sample. Pre-trained models can be downloaded from the [model zoo](model_zoo.md). To test single-modality 3D detection on point cloud scenes:

```shell
python demo/pcd_demo.py ${PCD_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} [--device ${GPU_ID}] [--score-thr ${SCORE_THR}] [--out-dir ${OUT_DIR}]
```
@@ -203,16 +203,18 @@
Examples:

```shell
-python demo/pcd_demo.py demo/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth
+python demo/pcd_demo.py demo/data/kitti/kitti_000008.bin configs/second/hv_second_secfpn_6x8_80e_kitti-3d-car.py checkpoints/hv_second_secfpn_6x8_80e_kitti-3d-car_20200620_230238-393f000c.pth
```

If you want to input a `ply` file, you can use the following function to convert it to `bin` format, and then use the converted `bin` file to generate the demo.
Note that you need to install `pandas` and `plyfile` before using this script. The function can also be used to preprocess `ply` data for training.

```python
import numpy as np
import pandas as pd
from plyfile import PlyData

-def conver_ply(input_path, output_path):
+def convert_ply(input_path, output_path):
    plydata = PlyData.read(input_path)  # read file
    data = plydata.elements[0].data  # read data
    data_pd = pd.DataFrame(data)  # convert to DataFrame
@@ -223,12 +225,31 @@
        data_np[:, i] = data_pd[name]
    data_np.astype(np.float32).tofile(output_path)
```

Examples:

```python
convert_ply('./test.ply', './test.bin')
```
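
As a quick sanity check, the written `bin` file can be read back as a flat `float32` array. The property count (4 below) is an assumption and must match the number of properties in your `ply` file:

```python
import numpy as np

points = np.fromfile('./test.bin', dtype=np.float32).reshape(-1, 4)
print(points.shape)  # (num_points, num_properties)
```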

If you have point clouds in another format (`off`, `obj`, etc.), you can use `trimesh` to convert them into `ply`.

```python
import trimesh

def to_ply(input_path, output_path, original_type):
    mesh = trimesh.load(input_path, file_type=original_type)  # read file
    mesh.export(output_path, file_type='ply')  # convert to ply
```

Examples:

```python
to_ply('./test.obj', './test.ply', 'obj')
```
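
The two helpers compose if you start from a mesh format. A small sketch chaining them (file names are placeholders):

```python
# obj -> ply -> bin, reusing to_ply and convert_ply defined above
to_ply('./test.obj', './test.ply', 'obj')
convert_ply('./test.ply', './test.bin')
```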

More demos about single/multi-modality and indoor/outdoor 3D detection can be found in [demo](0_demo.md).

## High-level APIs for testing point clouds

### Synchronous interface
1 change: 1 addition & 0 deletions docs/index.rst
@@ -13,6 +13,7 @@ Welcome to MMDetection3D's documentation!
   :maxdepth: 2
   :caption: Quick Run

+   0_demo.md
   1_exist_data_model.md
   2_new_data_model.md

6 changes: 4 additions & 2 deletions mmdet3d/apis/__init__.py
@@ -1,8 +1,10 @@
-from .inference import (convert_SyncBN, inference_detector, init_detector,
+from .inference import (convert_SyncBN, inference_detector,
+                        inference_multi_modality_detector, init_detector,
                         show_result_meshlab)
from .test import single_gpu_test

__all__ = [
    'inference_detector', 'init_detector', 'single_gpu_test',
-    'show_result_meshlab', 'convert_SyncBN'
+    'show_result_meshlab', 'convert_SyncBN',
+    'inference_multi_modality_detector'
]
147 changes: 140 additions & 7 deletions mmdet3d/apis/inference.py
@@ -1,11 +1,15 @@
import mmcv
import numpy as np
import re
import torch
from copy import deepcopy
from mmcv.parallel import collate, scatter
from mmcv.runner import load_checkpoint
from os import path as osp

-from mmdet3d.core import Box3DMode, show_result
+from mmdet3d.core import (Box3DMode, DepthInstance3DBoxes,
+                          LiDARInstance3DBoxes, show_multi_modality_result,
+                          show_result)
from mmdet3d.core.bbox import get_box_type
from mmdet3d.datasets.pipelines import Compose
from mmdet3d.models import build_detector
@@ -106,13 +110,89 @@ def inference_detector(model, pcd):
    return result, data


-def show_result_meshlab(data, result, out_dir):
+def inference_multi_modality_detector(model, pcd, image, ann_file):
"""Inference point cloud with the multimodality detector.

Args:
model (nn.Module): The loaded detector.
pcd (str): Point cloud files.
image (str): Image files.
ann_file (str): Annotation files.

Returns:
tuple: Predicted results and data from pipeline.
"""
    cfg = model.cfg
    device = next(model.parameters()).device  # model device
    # build the data pipeline
    test_pipeline = deepcopy(cfg.data.test.pipeline)
    test_pipeline = Compose(test_pipeline)
    box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
    # get data info containing calib
    data_infos = mmcv.load(ann_file)
    image_idx = int(re.findall(r'\d+', image)[-1])  # xxx/sunrgbd_000017.jpg
    for x in data_infos:
        if int(x['image']['image_idx']) != image_idx:
            continue
        info = x
        break
    data = dict(
        pts_filename=pcd,
        img_prefix=osp.dirname(image),
        img_info=dict(filename=osp.basename(image)),
        box_type_3d=box_type_3d,
        box_mode_3d=box_mode_3d,
        img_fields=[],
        bbox3d_fields=[],
        pts_mask_fields=[],
        pts_seg_fields=[],
        bbox_fields=[],
        mask_fields=[],
        seg_fields=[])

    # depth map points to image conversion
    if box_mode_3d == Box3DMode.DEPTH:
        data.update(dict(calib=info['calib']))

    data = test_pipeline(data)

    # LiDAR to image conversion
    if box_mode_3d == Box3DMode.LIDAR:
        rect = info['calib']['R0_rect'].astype(np.float32)
        Trv2c = info['calib']['Tr_velo_to_cam'].astype(np.float32)
        P2 = info['calib']['P2'].astype(np.float32)
        lidar2img = P2 @ rect @ Trv2c
        data['img_metas'][0].data['lidar2img'] = lidar2img
    elif box_mode_3d == Box3DMode.DEPTH:
        data['calib'][0]['Rt'] = data['calib'][0]['Rt'].astype(np.float32)
        data['calib'][0]['K'] = data['calib'][0]['K'].astype(np.float32)

    data = collate([data], samples_per_gpu=1)
    if next(model.parameters()).is_cuda:
        # scatter to specified GPU
        data = scatter(data, [device.index])[0]
    else:
        # this is a workaround to avoid the bug of MMDataParallel
        data['img_metas'] = data['img_metas'][0].data
        data['points'] = data['points'][0].data
        data['img'] = data['img'][0].data
        if box_mode_3d == Box3DMode.DEPTH:
            data['calib'] = data['calib'][0].data

    # forward the model
    with torch.no_grad():
        result = model(return_loss=False, rescale=True, **data)
    return result, data


def show_result_meshlab(data, result, out_dir, score_thr=0.0):
"""Show result by meshlab.

Args:
data (dict): Contain data from pipeline.
result (dict): Predicted result from model.
out_dir (str): Directory to save visualized result.
score_thr (float): Minimum score of bboxes to be shown. Default: 0.0
"""
points = data['points'][0][0].cpu().numpy()
pts_filename = data['img_metas'][0][0]['pts_filename']
@@ -122,14 +202,67 @@ def show_result_meshlab(data, result, out_dir):

    if 'pts_bbox' in result[0].keys():
        pred_bboxes = result[0]['pts_bbox']['boxes_3d'].tensor.numpy()
        pred_scores = result[0]['pts_bbox']['scores_3d'].numpy()
    else:
        pred_bboxes = result[0]['boxes_3d'].tensor.numpy()
        pred_scores = result[0]['scores_3d'].numpy()

    # filter out low score bboxes for visualization
    if score_thr > 0:
        inds = pred_scores > score_thr
        pred_bboxes = pred_bboxes[inds]
    # for now we convert points into depth mode
-    if data['img_metas'][0][0]['box_mode_3d'] != Box3DMode.DEPTH:
+    box_mode = data['img_metas'][0][0]['box_mode_3d']
+    if box_mode != Box3DMode.DEPTH:
        points = points[..., [1, 0, 2]]
        points[..., 0] *= -1
-        pred_bboxes = Box3DMode.convert(pred_bboxes,
-                                        data['img_metas'][0][0]['box_mode_3d'],
-                                        Box3DMode.DEPTH)
-        show_result(points, None, pred_bboxes, out_dir, file_name, show=False)
+        show_bboxes = Box3DMode.convert(pred_bboxes, box_mode, Box3DMode.DEPTH)
+    else:
+        show_bboxes = deepcopy(pred_bboxes)
+    show_result(points, None, show_bboxes, out_dir, file_name, show=False)

    if 'img' not in data.keys():
        return out_dir, file_name

    # multi-modality visualization
    # project 3D bbox to 2D image plane
    if box_mode == Box3DMode.LIDAR:
        if 'lidar2img' not in data['img_metas'][0][0]:
            raise NotImplementedError(
                'LiDAR to image transformation matrix is not provided')

        show_bboxes = LiDARInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))
        img = mmcv.imread(data['img_metas'][0][0]['filename'])

        show_multi_modality_result(
            img,
            None,
            show_bboxes,
            data['img_metas'][0][0]['lidar2img'],
            out_dir,
            file_name,
            show=False)
    elif box_mode == Box3DMode.DEPTH:
        if 'calib' not in data.keys():
            raise NotImplementedError(
                'camera calibration information is not provided')

        show_bboxes = DepthInstance3DBoxes(pred_bboxes, origin=(0.5, 0.5, 0))
        img = mmcv.imread(data['img_metas'][0][0]['filename'])

        show_multi_modality_result(
            img,
            None,
            show_bboxes,
            data['calib'][0],
            out_dir,
            file_name,
            depth_bbox=True,
            img_metas=data['img_metas'][0][0],
            show=False)
    else:
        raise NotImplementedError(
            f'visualization of {box_mode} bbox is not supported')

    return out_dir, file_name