Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Features]Support mmdet3d #103

Merged
merged 49 commits into from
Mar 10, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
5d76b06
add mmdet3d code
Jan 19, 2022
e10f7b0
add code
Jan 19, 2022
2a015ba
update code
Jan 20, 2022
012643e
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Jan 20, 2022
5d3da0c
[log]This commit finish pointpillar export and evaluate on onnxruntim…
Jan 20, 2022
f104a37
add tensorrt config
Jan 21, 2022
84c6d2a
fix config
Jan 24, 2022
871e901
update
Jan 25, 2022
1227bb8
support for tensorrt
Jan 25, 2022
1197289
add config
Jan 25, 2022
1d262f5
fix config`
Jan 25, 2022
5c9972b
fix apis about torch2onnx
Jan 25, 2022
fc45c4d
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Jan 26, 2022
ef2340b
update
Jan 26, 2022
172b513
mmdet3d deploy version1.0
Feb 15, 2022
95795ac
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Feb 15, 2022
0d4a49c
map is ok
Feb 18, 2022
b121888
fix code
Feb 18, 2022
a2d5a27
version1.0
Feb 22, 2022
f5e9338
fix code
Feb 22, 2022
12d3a7c
fix visual
Feb 22, 2022
08bf6f2
fix bug
Feb 22, 2022
c0eab56
tensorrt support success
Feb 22, 2022
4bae5ca
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Feb 22, 2022
26e5a69
add docstring
Feb 24, 2022
2e8bc54
add docs
Feb 24, 2022
b63c7b7
fix docs
Feb 24, 2022
5332844
fix comments
Feb 25, 2022
a3a909f
fix comment
Feb 28, 2022
e4aaab0
fix comment
Feb 28, 2022
2ac4360
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Feb 28, 2022
43478ba
fix openvino wrapper
Feb 28, 2022
7390383
add unit test
Mar 2, 2022
bac4f90
fix device about cpu
Mar 2, 2022
a321c7c
fix comment
Mar 2, 2022
2219f99
fix show_result
Mar 2, 2022
28a3530
fix lint
Mar 2, 2022
e1297fa
fix requirments
Mar 2, 2022
ea72d65
remove ci about det3d
Mar 2, 2022
40e9e8e
fix ut
Mar 2, 2022
11ff68e
add ut data
Mar 2, 2022
2b84036
support for new version pointpillars
Mar 3, 2022
630f853
fix comment
Mar 4, 2022
f6d727a
Merge branch 'master' of https://github.com/open-mmlab/mmdeploy into …
Mar 7, 2022
e73fb9d
fix support_list
Mar 7, 2022
4f66162
fix comments
Mar 7, 2022
d164378
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 8, 2022
2f7923c
Merge branch 'dev-v0.4.0' of https://github.com/open-mmlab/mmdeploy i…
Mar 9, 2022
f5372a4
fix config name
Mar 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions configs/mmdet3d/voxel-detection/voxel-detection_dynamic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
_base_ = ['./voxel-detection_static.py']

onnx_config = dict(
dynamic_axes={
'voxels': {
0: 'voxels_num',
},
'num_points': {
0: 'voxels_num',
},
'coors': {
0: 'voxels_num',
}
},
input_shape=None)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
_base_ = [
'./voxel-detection_dynamic.py', '../../_base_/backends/onnxruntime.py'
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
_base_ = [
'./voxel-detection_static.py', '../../_base_/backends/onnxruntime.py'
]

onnx_config = dict(input_shape=None)
6 changes: 6 additions & 0 deletions configs/mmdet3d/voxel-detection/voxel-detection_static.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
_base_ = ['../../_base_/onnx_config.py']
codebase_config = dict(
type='mmdet3d', task='VoxelDetection', model_type='end2end')
onnx_config = dict(
input_names=['voxels', 'num_points', 'coors'],
output_names=['bbox_preds', 'scores', 'dir_scores'])
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py']
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
model_inputs=[
dict(
input_shapes=dict(
voxels=dict(
min_shape=[4000, 32, 4],
opt_shape=[4000, 32, 4],
max_shape=[4000, 32, 4]),
num_points=dict(
min_shape=[4000], opt_shape=[4000], max_shape=[4000]),
coors=dict(
min_shape=[4000, 4],
opt_shape=[4000, 4],
max_shape=[4000, 4]),
))
])
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_pplnn_dynamic-512x1024.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
_base_ = ['./segmentation_dynamic.py', '../_base_/backends/pplnn.py']

onnx_config = dict(input_shape=None)
onnx_config = dict(input_shape=[512, 512])
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved

backend_config = dict(model_inputs=dict(opt_shape=[1, 3, 512, 1024]))
6 changes: 3 additions & 3 deletions mmdeploy/apis/pytorch2onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
opset=opset_version), torch.no_grad():
torch.onnx.export(
patched_model,
input,
tuple(input),
output_file,
export_params=onnx_cfg['export_params'],
input_names=input_names,
Expand Down Expand Up @@ -86,8 +86,8 @@ def torch2onnx(img: Any,

torch_model = task_processor.init_pytorch_model(model_checkpoint)
data, model_inputs = task_processor.create_input(img, input_shape)
if not isinstance(model_inputs, torch.Tensor):
model_inputs = model_inputs[0]
# if not isinstance(model_inputs, torch.Tensor):
# model_inputs = model_inputs[0]

torch2onnx_impl(
torch_model,
Expand Down
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmdet3d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .deploy import MMDetection3d, VoxelDetection
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
from .models import * # noqa: F401,F403

__all__ = ['MMDetection3d', 'VoxelDetection']
4 changes: 4 additions & 0 deletions mmdeploy/codebase/mmdet3d/deploy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .mmdetection3d import MMDetection3d
RunningLeon marked this conversation as resolved.
Show resolved Hide resolved
from .voxel_detection import VoxelDetection

__all__ = ['MMDetection3d', 'VoxelDetection']
78 changes: 78 additions & 0 deletions mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmcv
import torch
from mmcv.utils import Registry
from torch.utils.data import DataLoader, Dataset

from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
from mmdeploy.utils import Codebase, get_task_type


def __build_mmdet3d_task(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str, registry: Registry) -> BaseTask:
task = get_task_type(deploy_cfg)
return registry.module_dict[task.value](model_cfg, deploy_cfg, device)


MMDET3D_TASK = Registry('mmdet3d_tasks', build_func=__build_mmdet3d_task)


@CODEBASE.register_module(Codebase.MMDET3D.value)
class MMDetection3d(MMCodebase):
task_registry = MMDET3D_TASK

def __init__(self):
super().__init__()

@staticmethod
def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str) -> BaseTask:
return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)

@staticmethod
def build_dataset(dataset_cfg: Union[str, mmcv.Config], *args,
**kwargs) -> Dataset:

from mmdet3d.datasets import build_dataset as build_dataset_mmdet3d
from mmdeploy.utils import load_config
dataset_cfg = load_config(dataset_cfg)[0]
data = dataset_cfg.data

dataset = build_dataset_mmdet3d(data.test)
return dataset

@staticmethod
def build_dataloader(dataset: Dataset,
samples_per_gpu: int,
workers_per_gpu: int,
num_gpus: int = 1,
dist: bool = False,
shuffle: bool = False,
seed: Optional[int] = None,
drop_last: bool = False,
pin_memory: bool = True,
persistent_workers: bool = True,
**kwargs) -> DataLoader:
from mmdet3d.datasets import build_dataloader \
as build_dataloader_mmdet3d
return build_dataloader_mmdet3d(
dataset,
samples_per_gpu,
workers_per_gpu,
num_gpus=num_gpus,
dist=dist,
shuffle=shuffle,
seed=seed,
**kwargs)

@staticmethod
def single_gpu_test(model: torch.nn.Module,
data_loader: DataLoader,
show: bool = False,
out_dir: Optional[str] = None,
**kwargs):
from mmdet3d.apis import single_gpu_test
outputs = single_gpu_test(model, data_loader, show, out_dir, kwargs)
return outputs
173 changes: 173 additions & 0 deletions mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset

from mmdeploy.codebase.base import BaseTask
from mmdeploy.codebase.mmdet3d.deploy.mmdetection3d import MMDET3D_TASK
from mmdeploy.utils import Task


def voxelize(points, model_cfg):
from mmdet3d.ops import Voxelization
voxel_layer = model_cfg.model['voxel_layer']
voxel_layer = Voxelization(**voxel_layer)
voxels, coors, num_points = [], [], []
for res in points:
res_voxels, res_coors, res_num_points = voxel_layer(res)
voxels.append(res_voxels)
coors.append(res_coors)
num_points.append(res_num_points)
voxels = torch.cat(voxels, dim=0)
num_points = torch.cat(num_points, dim=0)
coors_batch = []
for i, coor in enumerate(coors):
coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
coors_batch.append(coor_pad)
coors_batch = torch.cat(coors_batch, dim=0)
return voxels, num_points, coors_batch


class VoxelDetectionWrap(nn.Module):

def __init__(self, model):
super(VoxelDetectionWrap, self).__init__()
self.model = model

def forward(self, voxels, num_points, coors):
result = self.model(
voxel_input=[voxels, num_points, coors], img_metas=[0])
VVsssssk marked this conversation as resolved.
Show resolved Hide resolved
return result[0], result[1], result[2]


@MMDET3D_TASK.register_module(Task.VOXEL_DETECTION.value)
class VoxelDetection(BaseTask):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to add docstring here


def __init__(self, model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
device: str):
super().__init__(model_cfg, deploy_cfg, device)

def init_backend_model(self,
model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module:
from .voxel_detection_model import build_voxel_detection_model
model = build_voxel_detection_model(
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model

def init_pytorch_model(self,
model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module:
from mmdet3d.apis import init_model
model = init_model(self.model_cfg, model_checkpoint, self.device)
model = VoxelDetectionWrap(model)
return model.eval()

def create_input(self,
pcds: Union[str, np.ndarray],
img_shape=None,
**kwargs) -> Tuple[Dict, torch.Tensor]:

from mmdet3d.datasets.pipelines import Compose
from mmcv.parallel import collate, scatter
from mmdet3d.core.bbox import get_box_type
if not isinstance(pcds, (list, tuple)):
pcds = [pcds]
cfg = self.model_cfg
test_pipeline = Compose(cfg.data.test.pipeline)
box_type_3d, box_mode_3d = get_box_type(cfg.data.test.box_type_3d)
data_list = []
for pcd in pcds:
data = dict(
pts_filename=pcd,
box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d,
# for ScanNet demo we need axis_align_matrix
ann_info=dict(axis_align_matrix=np.eye(4)),
sweeps=[],
# set timestamp = 0
timestamp=[0],
img_fields=[],
bbox3d_fields=[],
pts_mask_fields=[],
pts_seg_fields=[],
bbox_fields=[],
mask_fields=[],
seg_fields=[])
data = test_pipeline(data)
data_list.append(data)

result = collate(data_list, samples_per_gpu=len(pcds))
result['img_metas'] = [
img_metas.data[0] for img_metas in result['img_metas']
]
result['points'] = [point.data[0] for point in result['points']]
if self.device != 'cpu':
result = scatter(result, [self.device])[0]
else:
result['img_metas'] = result['img_metas'][0]
result['points'] = result['points'][0]
voxels_batch, num_points_batch, coors_batch = [], [], []
for point in result['points'][0]:
voxels, num_points, coors = voxelize([point], self.model_cfg)
voxels_batch.append(voxels)
num_points_batch.append(num_points)
coors_batch.append(coors)
result['voxels'] = voxels_batch
result['num_points'] = num_points_batch
result['coors'] = coors_batch
return result, [voxels_batch[0], num_points_batch[0], coors_batch[0]]

def visualize(self,
model: torch.nn.Module,
image: Union[str, np.ndarray],
result: list,
output_file: str,
window_name: str = '',
show_result: bool = False,
**kwargs):
print(result)

@staticmethod
def run_inference(model, model_inputs: Dict[str, torch.Tensor]):
batch_size = len(model_inputs['voxels'])
result = []
for i in range(batch_size):
voxels = model_inputs['voxels'][i]
num_points = model_inputs['num_points'][i]
coors = model_inputs['coors'][i]
result.append(model(voxels, num_points, coors))
return result

def get_tensor_from_input(self, input_data: Dict[str, Any],
**kwargs) -> torch.Tensor:
pass

def evaluate_outputs(model_cfg,
outputs: Sequence,
dataset: Dataset,
metrics: Optional[str] = None,
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
**kwargs):
pass

def get_model_name(self) -> str:
assert 'type' in self.model_cfg.model, 'model config contains no type'
name = self.model_cfg.model.type.lower()
return name

def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
pass

def get_postprocess(self) -> Dict:
pass

def get_preprocess(self) -> Dict:
pass
Loading