[Features] Support mmdet3d (#103)
* add mmdet3d code

* add code

* update code

* [log] This commit finishes PointPillars export and evaluation on ONNX Runtime. The model is the same as the NVIDIA repo model

* add tensorrt config

* fix config

* update

* support for tensorrt

* add config

* fix config

* fix apis about torch2onnx

* update

* mmdet3d deploy version1.0

* map is ok

* fix code

* version1.0

* fix code

* fix visual

* fix bug

* tensorrt support success

* add docstring

* add docs

* fix docs

* fix comments

* fix comment

* fix comment

* fix openvino wrapper

* add unit test

* fix device about cpu

* fix comment

* fix show_result

* fix lint

* fix requirements

* remove ci about det3d

* fix ut

* add ut data

* support for new version pointpillars

* fix comment

* fix support_list

* fix comments

* fix config name
VVsssssk authored and lvhan028 committed Apr 1, 2022
1 parent 5adbfc5 commit dea2410
Showing 28 changed files with 1,467 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -30,7 +30,7 @@ jobs:
- name: Check docstring coverage
run: |
pip install interrogate
interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 95 mmdeploy
interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 mmdeploy
- name: Check pylint score
run: |
pip install pylint
15 changes: 15 additions & 0 deletions configs/mmdet3d/voxel-detection/voxel-detection_dynamic.py
@@ -0,0 +1,15 @@
_base_ = ['./voxel-detection_static.py']

onnx_config = dict(
dynamic_axes={
'voxels': {
0: 'voxels_num',
},
'num_points': {
0: 'voxels_num',
},
'coors': {
0: 'voxels_num',
}
},
input_shape=None)
@@ -0,0 +1,3 @@
_base_ = [
'./voxel-detection_dynamic.py', '../../_base_/backends/onnxruntime.py'
]
@@ -0,0 +1,9 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/openvino.py']

onnx_config = dict(input_shape=None)

backend_config = dict(model_inputs=[
dict(
opt_shapes=dict(
voxels=[5000, 32, 4], num_points=[5000], coors=[5000, 4]))
])
6 changes: 6 additions & 0 deletions configs/mmdet3d/voxel-detection/voxel-detection_static.py
@@ -0,0 +1,6 @@
_base_ = ['../../_base_/onnx_config.py']
codebase_config = dict(
type='mmdet3d', task='VoxelDetection', model_type='end2end')
onnx_config = dict(
input_names=['voxels', 'num_points', 'coors'],
output_names=['scores', 'bbox_preds', 'dir_scores'])
@@ -0,0 +1,18 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py']
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
voxels=dict(
min_shape=[2000, 32, 4],
opt_shape=[5000, 32, 4],
max_shape=[9000, 32, 4]),
num_points=dict(
min_shape=[2000], opt_shape=[5000], max_shape=[9000]),
coors=dict(
min_shape=[2000, 4],
opt_shape=[5000, 4],
max_shape=[9000, 4]),
))
])
43 changes: 43 additions & 0 deletions docs/en/codebases/mmdet3d.md
@@ -0,0 +1,43 @@
## MMDetection3d Support

MMDetection3d is a next-generation platform for general 3D object detection. It is a part of the [OpenMMLab](https://openmmlab.com/) project.

### MMDetection3d installation tutorial

Please refer to [getting_started.md](https://github.com/open-mmlab/mmdetection3d/blob/master/docs/en/getting_started.md) for installation.
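
For a quick local setup, the sketch below condenses the steps from that guide. The package list and the editable install come from the guide, but exact version pins depend on your CUDA and PyTorch versions, so treat the commands as an assumption and check the guide for the matching constraints.

```bash
# Condensed sketch; see getting_started.md for the exact version pins.
pip install mmcv-full mmdet mmsegmentation
git clone https://github.com/open-mmlab/mmdetection3d.git
cd mmdetection3d
pip install -v -e .
```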

### Example

```bash
python tools/deploy.py \
configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic.py \
${MMDET3D_DIR}/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py \
checkpoints/point_pillars.pth \
${MMDET3D_DIR}/demo/data/kitti/kitti_000008.bin \
--work-dir \
work_dir \
--show \
--device \
cuda:0
```

### List of MMDetection3d models supported by MMDeploy

| Model | Task | OnnxRuntime | TensorRT | NCNN | PPLNN | OpenVINO | Model config |
| :----------------: | :------------------: | :---------: | :------: | :---: | :---: | :------: | :------------------------------------------------------------------------------------------------------: |
| PointPillars | VoxelDetection | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |

### Reminder

The exported voxel detection ONNX model excludes the `model.voxelize` layer and the model post-processing; you can use the Python API to call these functions.

Example:

```python
from mmdeploy.codebase.mmdet3d.deploy import VoxelDetectionModel
VoxelDetectionModel.voxelize(...)
VoxelDetectionModel.post_process(...)
```
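
For orientation, here is a slightly fuller sketch of where these calls sit around the exported backend model. It reuses the PointPillars config and KITTI demo point cloud from the example above; the argument lists for `voxelize` and `post_process` and the `backend_model` wrapper name are illustrative assumptions, so check `mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py` for the actual signatures.

```python
# Illustrative sketch: argument names are assumptions, not the exact API.
import os

import mmcv
import numpy as np
import torch

from mmdeploy.codebase.mmdet3d.deploy import VoxelDetectionModel

MMDET3D_DIR = os.environ.get('MMDET3D_DIR', '.')
model_cfg = mmcv.Config.fromfile(
    os.path.join(MMDET3D_DIR, 'configs/pointpillars/'
                 'hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py'))
deploy_cfg = mmcv.Config.fromfile(
    'configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic.py')

# Load a raw KITTI point cloud (x, y, z, reflectance).
points = torch.from_numpy(
    np.fromfile(
        os.path.join(MMDET3D_DIR, 'demo/data/kitti/kitti_000008.bin'),
        dtype=np.float32).reshape(-1, 4))

# Voxelization is excluded from the exported graph, so it runs in Python.
voxels, num_points, coors = VoxelDetectionModel.voxelize(model_cfg, [points])

# The backend wrapper built by deploy.py (ONNX Runtime / TensorRT / OpenVINO)
# consumes the voxelized tensors and returns the raw head outputs named in
# the config: scores, bbox_preds, dir_scores.
# outs = backend_model(voxels, num_points, coors)

# Decoding and NMS also run in Python after backend inference.
# results = VoxelDetectionModel.post_process(model_cfg, deploy_cfg, outs,
#                                            img_metas, device='cuda:0')
```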

### FAQs

None
1 change: 1 addition & 0 deletions docs/en/supported_models.md
@@ -64,6 +64,7 @@ The table below lists the models that are guaranteed to be exportable to other b
| HRNet | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#hrnet-cvpr-2019) |
| MSPN | MMPose | N | Y | Y | Y | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#mspn-arxiv-2019) |
| LiteHRNet | MMPose | N | Y | Y | N | N | Y | [config](https://mmpose.readthedocs.io/en/latest/papers/backbones.html#litehrnet-cvpr-2021) |
| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |

### Note

8 changes: 4 additions & 4 deletions mmdeploy/apis/pytorch2onnx.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Any, Optional, Union
from typing import Any, Optional, Tuple, Union

import mmcv
import torch
@@ -10,13 +10,13 @@
get_onnx_config, load_config)


def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
def torch2onnx_impl(model: torch.nn.Module, input: Union[torch.Tensor, Tuple],
deploy_cfg: Union[str, mmcv.Config], output_file: str):
"""Converting torch model to ONNX.
Args:
model (torch.nn.Module): Input pytorch model.
input (torch.Tensor): Input tensor used to convert model.
input (torch.Tensor | Tuple): Input tensor used to convert model.
deploy_cfg (str | mmcv.Config): Deployment config file or
Config object.
output_file (str): Output file to save ONNX model.
@@ -103,7 +103,7 @@ def torch2onnx(img: Any,

torch_model = task_processor.init_pytorch_model(model_checkpoint)
data, model_inputs = task_processor.create_input(img, input_shape)
if not isinstance(model_inputs, torch.Tensor):
if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
model_inputs = model_inputs[0]

torch2onnx_impl(
1 change: 0 additions & 1 deletion mmdeploy/apis/visualize.py
@@ -67,7 +67,6 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
model = task_processor.init_backend_model(model)

model_inputs, _ = task_processor.create_input(img, input_shape)

with torch.no_grad():
result = task_processor.run_inference(model, model_inputs)[0]

6 changes: 5 additions & 1 deletion mmdeploy/backend/openvino/wrapper.py
@@ -42,7 +42,11 @@ def __init__(self,
self.net = self.ie.read_network(ir_model_file, bin_path)
for input in self.net.input_info.values():
batch_size = input.input_data.shape[0]
assert batch_size == 1, 'Only batch 1 is supported.'
dims = len(input.input_data.shape)
# if the input is an image, it has (B, C, H, W) dims,
# so batch_size must be 1
assert not dims == 4 or batch_size == 1, \
'Only batch 1 is supported.'
self.device = 'cpu'
self.sess = self.ie.load_network(
network=self.net, device_name=self.device.upper(), num_requests=1)
5 changes: 5 additions & 0 deletions mmdeploy/codebase/mmdet3d/__init__.py
@@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .deploy import MMDetection3d, VoxelDetection
from .models import * # noqa: F401,F403

__all__ = ['MMDetection3d', 'VoxelDetection']
6 changes: 6 additions & 0 deletions mmdeploy/codebase/mmdet3d/deploy/__init__.py
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .mmdetection3d import MMDetection3d
from .voxel_detection import VoxelDetection
from .voxel_detection_model import VoxelDetectionModel

__all__ = ['MMDetection3d', 'VoxelDetection', 'VoxelDetectionModel']
114 changes: 114 additions & 0 deletions mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py
@@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmcv
from mmcv.utils import Registry
from torch.utils.data import DataLoader, Dataset

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, get_task_type


def __build_mmdet3d_task(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str, registry: Registry) -> BaseTask:
task = get_task_type(deploy_cfg)
return registry.module_dict[task.value](model_cfg, deploy_cfg, device)


MMDET3D_TASK = Registry('mmdet3d_tasks', build_func=__build_mmdet3d_task)


@CODEBASE.register_module(Codebase.MMDET3D.value)
class MMDetection3d(MMCodebase):

task_registry = MMDET3D_TASK

def __init__(self):
super().__init__()

@staticmethod
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str) -> BaseTask:
"""The interface to build the task processors of mmdet3d.
Args:
model_cfg (str | mmcv.Config): Model config file.
deploy_cfg (str | mmcv.Config): Deployment config file.
device (str): A string specifying device type.
Returns:
BaseTask: A task processor.
"""
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)

@staticmethod
def build_dataset(dataset_cfg: Union[str, mmcv.Config], *args,
**kwargs) -> Dataset:
"""Build dataset for detection3d.
Args:
dataset_cfg (str | mmcv.Config): The input dataset config.
Returns:
Dataset: A PyTorch dataset.
"""
from mmdet3d.datasets import build_dataset as build_dataset_mmdet3d

from mmdeploy.utils import load_config
dataset_cfg = load_config(dataset_cfg)[0]
data = dataset_cfg.data

dataset = build_dataset_mmdet3d(data.test)
return dataset

@staticmethod
def build_dataloader(dataset: Dataset,
samples_per_gpu: int,
workers_per_gpu: int,
num_gpus: int = 1,
dist: bool = False,
shuffle: bool = False,
seed: Optional[int] = None,
runner_type: str = 'EpochBasedRunner',
persistent_workers: bool = True,
**kwargs) -> DataLoader:
"""Build dataloader for detection3d.
Args:
dataset (Dataset): Input dataset.
samples_per_gpu (int): Number of training samples on each GPU,
i.e., the batch size of each GPU.
workers_per_gpu (int): How many subprocesses to use for data
loading for each GPU.
num_gpus (int): Number of GPUs. Only used in non-distributed
training.
dist (bool): Distributed training/test or not.
Defaults to `False`.
shuffle (bool): Whether to shuffle the data at every epoch.
Defaults to `False`.
seed (int): Random seed. Default is `None`.
runner_type (str): Type of runner. Default: `EpochBasedRunner`.
persistent_workers (bool): If True, the data loader will not
shut down the worker processes after a dataset has been consumed
once. This keeps the worker `Dataset` instances alive. The
argument is only valid when PyTorch >= 1.7.0.
Defaults to `True`.
kwargs: Any other keyword argument to be used to initialize
DataLoader.
Returns:
DataLoader: A PyTorch dataloader.
"""
from mmdet3d.datasets import \
build_dataloader as build_dataloader_mmdet3d
return build_dataloader_mmdet3d(
dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=num_gpus,
dist=dist,
shuffle=shuffle,
seed=seed,
runner_type=runner_type,
persistent_workers=persistent_workers,
**kwargs)