diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e45e0c9520..f98d172a63 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -30,7 +30,7 @@ jobs: - name: Check docstring coverage run: | pip install interrogate - interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 95 mmdeploy + interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 mmdeploy - name: Check pylint score run: | pip install pylint diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_dynamic.py b/configs/mmdet3d/voxel-detection/voxel-detection_dynamic.py new file mode 100644 index 0000000000..1a2402e03a --- /dev/null +++ b/configs/mmdet3d/voxel-detection/voxel-detection_dynamic.py @@ -0,0 +1,15 @@ +_base_ = ['./voxel-detection_static.py'] + +onnx_config = dict( + dynamic_axes={ + 'voxels': { + 0: 'voxels_num', + }, + 'num_points': { + 0: 'voxels_num', + }, + 'coors': { + 0: 'voxels_num', + } + }, + input_shape=None) diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py b/configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py new file mode 100644 index 0000000000..705d2c32e7 --- /dev/null +++ b/configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py @@ -0,0 +1,3 @@ +_base_ = [ + './voxel-detection_dynamic.py', '../../_base_/backends/onnxruntime.py' +] diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic.py b/configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic.py new file mode 100644 index 0000000000..2cfc965763 --- /dev/null +++ b/configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic.py @@ -0,0 +1,9 @@ +_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/openvino.py'] + +onnx_config = dict(input_shape=None) + +backend_config = dict(model_inputs=[ + dict( + opt_shapes=dict( + voxels=[5000, 32, 4], num_points=[5000], coors=[5000, 4])) +]) diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_static.py b/configs/mmdet3d/voxel-detection/voxel-detection_static.py new file mode 100644 index 0000000000..406c16513d --- /dev/null +++ b/configs/mmdet3d/voxel-detection/voxel-detection_static.py @@ -0,0 +1,6 @@ +_base_ = ['../../_base_/onnx_config.py'] +codebase_config = dict( + type='mmdet3d', task='VoxelDetection', model_type='end2end') +onnx_config = dict( + input_names=['voxels', 'num_points', 'coors'], + output_names=['scores', 'bbox_preds', 'dir_scores']) diff --git a/configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-kitti.py b/configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-kitti.py new file mode 100644 index 0000000000..4286e12c40 --- /dev/null +++ b/configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-kitti.py @@ -0,0 +1,18 @@ +_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py'] +backend_config = dict( + common_config=dict(max_workspace_size=1 << 30), + model_inputs=[ + dict( + input_shapes=dict( + voxels=dict( + min_shape=[2000, 32, 4], + opt_shape=[5000, 32, 4], + max_shape=[9000, 32, 4]), + num_points=dict( + min_shape=[2000], opt_shape=[5000], max_shape=[9000]), + coors=dict( + min_shape=[2000, 4], + opt_shape=[5000, 4], + max_shape=[9000, 4]), + )) + ]) diff --git a/docs/en/codebases/mmdet3d.md b/docs/en/codebases/mmdet3d.md new file mode 100644 index 0000000000..fdf1d4f5bc --- /dev/null +++ 
b/docs/en/codebases/mmdet3d.md
@@ -0,0 +1,43 @@
+## MMDetection3d Support
+
+MMDetection3d is a next-generation platform for general 3D object detection. It is a part of the [OpenMMLab](https://openmmlab.com/) project.
+
+### MMDetection3d installation tutorial
+
+Please refer to [getting_started.md](https://github.com/open-mmlab/mmdetection3d/blob/master/docs/en/getting_started.md) for installation.
+
+### Example
+
+```bash
+python tools/deploy.py \
+    configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-kitti.py \
+    ${MMDET3D_DIR}/configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py \
+    checkpoints/point_pillars.pth \
+    ${MMDET3D_DIR}/demo/data/kitti/kitti_000008.bin \
+    --work-dir work_dir \
+    --show \
+    --device cuda:0
+```
+
+### List of MMDetection3d models supported by MMDeploy
+
+|    Model     |      Task      | OnnxRuntime | TensorRT | NCNN | PPLNN | OpenVINO |                                      Model config                                      |
+| :----------: | :------------: | :---------: | :------: | :--: | :---: | :------: | :-------------------------------------------------------------------------------------: |
+| PointPillars | VoxelDetection |      Y      |    Y     |  N   |   N   |    Y     | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
+
+### Reminder
+
+The exported ONNX model for voxel detection excludes the `model.voxelize` layer and the model post-processing, so you need to call these functions through the Python API.
+
+Example:
+
+```python
+from mmdeploy.codebase.mmdet3d.deploy import VoxelDetectionModel
+
+# voxelize the raw points before feeding the backend model
+voxels, num_points, coors = VoxelDetectionModel.voxelize(model_cfg, points)
+# decode the raw head outputs into 3D bboxes, scores and labels
+result = VoxelDetectionModel.post_process(model_cfg, outs, img_metas, device)
+```
+
+### FAQs
+
+None
diff --git a/docs/en/supported_models.md b/docs/en/supported_models.md
index fa7cf4f4ea..edf51c6591 100644
--- a/docs/en/supported_models.md
+++ b/docs/en/supported_models.md
@@ -61,6 +61,8 @@ The table below lists the models that are guaranteed to be exportable to other b
 | DBNet | MMOCR | Y | Y | Y | Y | Y | Y | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textdet/dbnet) |
 | CRNN | MMOCR | Y | Y | Y | Y | Y | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/crnn) |
 | SAR | MMOCR | N | Y | N | N | N | N | [config](https://github.com/open-mmlab/mmocr/tree/main/configs/textrecog/sar) |
+| PointPillars | MMDetection3d | ? | Y | Y | N | N | Y | [config](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/pointpillars) |
+
 
 ### Note
diff --git a/mmdeploy/apis/pytorch2onnx.py b/mmdeploy/apis/pytorch2onnx.py
index 35683f4e16..f627c9a346 100644
--- a/mmdeploy/apis/pytorch2onnx.py
+++ b/mmdeploy/apis/pytorch2onnx.py
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import os.path as osp
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union
 
 import mmcv
 import torch
@@ -10,13 +10,13 @@
                            get_onnx_config, load_config)
 
 
-def torch2onnx_impl(model: torch.nn.Module, input: torch.Tensor,
+def torch2onnx_impl(model: torch.nn.Module, input: Union[torch.Tensor, Tuple],
                     deploy_cfg: Union[str, mmcv.Config], output_file: str):
     """Converting torch model to ONNX.
 
     Args:
         model (torch.nn.Module): Input pytorch model.
-        input (torch.Tensor): Input tensor used to convert model.
+        input (torch.Tensor | Tuple): Input tensor used to convert model.
         deploy_cfg (str | mmcv.Config): Deployment config file or Config
             object.
         output_file (str): Output file to save ONNX model.
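+
+    Example:
+        >>> # a hedged illustration: with mmdet3d, `input` is a tuple
+        >>> # such as (voxels, num_points, coors) instead of a single
+        >>> # tensor; the output file name is a placeholder
+        >>> torch2onnx_impl(model, (voxels, num_points, coors), deploy_cfg,
+        ...                 'end2end.onnx')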
@@ -86,7 +86,7 @@ def torch2onnx(img: Any,
     torch_model = task_processor.init_pytorch_model(model_checkpoint)
     data, model_inputs = task_processor.create_input(img, input_shape)
-    if not isinstance(model_inputs, torch.Tensor):
+    if not isinstance(model_inputs, torch.Tensor) and len(model_inputs) == 1:
         model_inputs = model_inputs[0]
 
     torch2onnx_impl(
diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py
index 3388397e01..e33c1892a8 100644
--- a/mmdeploy/apis/visualize.py
+++ b/mmdeploy/apis/visualize.py
@@ -55,7 +55,6 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
         model = task_processor.init_backend_model(model)
 
     model_inputs, _ = task_processor.create_input(img, input_shape)
-    with torch.no_grad():
-        result = task_processor.run_inference(model, model_inputs)[0]
+    result = task_processor.run_inference(model, model_inputs)[0]
diff --git a/mmdeploy/backend/openvino/wrapper.py b/mmdeploy/backend/openvino/wrapper.py
index 589906f345..7a41db24ad 100644
--- a/mmdeploy/backend/openvino/wrapper.py
+++ b/mmdeploy/backend/openvino/wrapper.py
@@ -42,7 +42,11 @@ def __init__(self,
         self.net = self.ie.read_network(ir_model_file, bin_path)
         for input in self.net.input_info.values():
             batch_size = input.input_data.shape[0]
-            assert batch_size == 1, 'Only batch 1 is supported.'
+            dims = len(input.input_data.shape)
+            # if the input is an image, its shape is (B, C, H, W)
+            # and only batch_size == 1 is supported
+            assert dims != 4 or batch_size == 1, \
+                'Only batch 1 is supported.'
         self.device = 'cpu'
         self.sess = self.ie.load_network(
             network=self.net, device_name=self.device.upper(), num_requests=1)
diff --git a/mmdeploy/codebase/mmdet3d/__init__.py b/mmdeploy/codebase/mmdet3d/__init__.py
new file mode 100644
index 0000000000..1974ef569c
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .deploy import MMDetection3d, VoxelDetection
+from .models import *  # noqa: F401,F403
+
+__all__ = ['MMDetection3d', 'VoxelDetection']
diff --git a/mmdeploy/codebase/mmdet3d/deploy/__init__.py b/mmdeploy/codebase/mmdet3d/deploy/__init__.py
new file mode 100644
index 0000000000..60ef615aca
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/deploy/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mmdetection3d import MMDetection3d
+from .voxel_detection import VoxelDetection
+from .voxel_detection_model import VoxelDetectionModel
+
+__all__ = ['MMDetection3d', 'VoxelDetection', 'VoxelDetectionModel']
diff --git a/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py b/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py
new file mode 100644
index 0000000000..01f9fbf28e
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/deploy/mmdetection3d.py
@@ -0,0 +1,114 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional, Union
+
+import mmcv
+from mmcv.utils import Registry
+from torch.utils.data import DataLoader, Dataset
+
+from mmdeploy.codebase.base import CODEBASE, BaseTask, MMCodebase
+from mmdeploy.utils import Codebase, get_task_type
+
+
+def __build_mmdet3d_task(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
+                         device: str, registry: Registry) -> BaseTask:
+    task = get_task_type(deploy_cfg)
+    return registry.module_dict[task.value](model_cfg, deploy_cfg, device)
+
+
+MMDET3D_TASK = Registry('mmdet3d_tasks', build_func=__build_mmdet3d_task)
+
+
+@CODEBASE.register_module(Codebase.MMDET3D.value)
+class MMDetection3d(MMCodebase):
+    """MMDetection3d codebase class."""
+
+    task_registry = MMDET3D_TASK
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def build_task_processor(model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
+                             device: str) -> BaseTask:
+        """The interface to build the task processors of mmdet3d.
+
+        Args:
+            model_cfg (str | mmcv.Config): Model config file.
+            deploy_cfg (str | mmcv.Config): Deployment config file.
+            device (str): A string specifying device type.
+
+        Returns:
+            BaseTask: A task processor.
+        """
+        return MMDET3D_TASK.build(model_cfg, deploy_cfg, device)
+
+    @staticmethod
+    def build_dataset(dataset_cfg: Union[str, mmcv.Config], *args,
+                      **kwargs) -> Dataset:
+        """Build dataset for detection3d.
+
+        Args:
+            dataset_cfg (str | mmcv.Config): The input dataset config.
+
+        Returns:
+            Dataset: A PyTorch dataset.
+        """
+        from mmdet3d.datasets import build_dataset as build_dataset_mmdet3d
+
+        from mmdeploy.utils import load_config
+        dataset_cfg = load_config(dataset_cfg)[0]
+        data = dataset_cfg.data
+
+        dataset = build_dataset_mmdet3d(data.test)
+        return dataset
+
+    @staticmethod
+    def build_dataloader(dataset: Dataset,
+                         samples_per_gpu: int,
+                         workers_per_gpu: int,
+                         num_gpus: int = 1,
+                         dist: bool = False,
+                         shuffle: bool = False,
+                         seed: Optional[int] = None,
+                         runner_type: str = 'EpochBasedRunner',
+                         persistent_workers: bool = True,
+                         **kwargs) -> DataLoader:
+        """Build dataloader for detection3d.
+
+        Args:
+            dataset (Dataset): Input dataset.
+            samples_per_gpu (int): Number of training samples on each GPU,
+                i.e., batch size of each GPU.
+            workers_per_gpu (int): How many subprocesses to use for data
+                loading for each GPU.
+            num_gpus (int): Number of GPUs. Only used in non-distributed
+                training.
+            dist (bool): Distributed training/test or not.
+                Defaults to `False`.
+            shuffle (bool): Whether to shuffle the data at every epoch.
+                Defaults to `False`.
+            seed (int): An integer seed for random number generators.
+                Defaults to `None`.
+            runner_type (str): Type of runner. Default: `EpochBasedRunner`.
+            persistent_workers (bool): If True, the data loader will not
+                shut down the worker processes after a dataset has been
+                consumed once. This keeps the worker `Dataset` instances
+                alive. The argument is only valid when PyTorch>=1.7.0.
+                Default: True.
+            kwargs: Any other keyword argument to be used to initialize
+                DataLoader.
+
+        Returns:
+            DataLoader: A PyTorch dataloader.
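+
+        Example:
+            >>> # a minimal usage sketch; assumes `model_cfg` is a loaded
+            >>> # mmdet3d config that contains a `data.test` section
+            >>> dataset = MMDetection3d.build_dataset(model_cfg)
+            >>> dataloader = MMDetection3d.build_dataloader(
+            ...     dataset, samples_per_gpu=1, workers_per_gpu=1)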
+        """
+        from mmdet3d.datasets import \
+            build_dataloader as build_dataloader_mmdet3d
+        return build_dataloader_mmdet3d(
+            dataset,
+            samples_per_gpu,
+            workers_per_gpu,
+            num_gpus=num_gpus,
+            dist=dist,
+            shuffle=shuffle,
+            seed=seed,
+            runner_type=runner_type,
+            persistent_workers=persistent_workers,
+            **kwargs)
diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py
new file mode 100644
index 0000000000..63eb87b7ab
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection.py
@@ -0,0 +1,301 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.parallel import collate, scatter
+from mmdet3d.core.bbox import get_box_type
+from mmdet3d.datasets.pipelines import Compose
+from torch.utils.data import DataLoader, Dataset
+
+from mmdeploy.codebase.base import BaseTask
+from mmdeploy.codebase.mmdet3d.deploy.mmdetection3d import MMDET3D_TASK
+from mmdeploy.utils import Task, get_root_logger, load_config
+from .voxel_detection_model import VoxelDetectionModel
+
+
+@MMDET3D_TASK.register_module(Task.VOXEL_DETECTION.value)
+class VoxelDetection(BaseTask):
+    """Voxel detection task processor of mmdet3d."""
+
+    def __init__(self, model_cfg: mmcv.Config, deploy_cfg: mmcv.Config,
+                 device: str):
+        super().__init__(model_cfg, deploy_cfg, device)
+
+    def init_backend_model(self,
+                           model_files: Sequence[str] = None,
+                           **kwargs) -> torch.nn.Module:
+        """Initialize backend model.
+
+        Args:
+            model_files (Sequence[str]): Input model files.
+
+        Returns:
+            nn.Module: An initialized backend model.
+        """
+        from .voxel_detection_model import build_voxel_detection_model
+        model = build_voxel_detection_model(
+            model_files, self.model_cfg, self.deploy_cfg, device=self.device)
+        return model
+
+    def init_pytorch_model(self,
+                           model_checkpoint: Optional[str] = None,
+                           cfg_options: Optional[Dict] = None,
+                           **kwargs) -> torch.nn.Module:
+        """Initialize torch model.
+
+        Args:
+            model_checkpoint (str): The checkpoint file of torch model,
+                defaults to `None`.
+            cfg_options (dict): Optional config key-value pairs to override
+                the model config.
+
+        Returns:
+            nn.Module: An initialized torch model generated by other
+                OpenMMLab codebases.
+        """
+        from mmdet3d.apis import init_model
+        device = self.device
+        model = init_model(self.model_cfg, model_checkpoint, device)
+        return model.eval()
+
+    def create_input(self, pcd: str, *args) -> Tuple[Dict, torch.Tensor]:
+        """Create input for detector.
+
+        Args:
+            pcd (str): Input pcd file path.
+
+        Returns:
+            tuple: (data, input), meta information for the input pcd
+                and model input.
+        """
+        data = VoxelDetection.read_pcd_file(pcd, self.model_cfg, self.device)
+        voxels, num_points, coors = VoxelDetectionModel.voxelize(
+            self.model_cfg, data['points'][0])
+        return data, (voxels, num_points, coors)
+
+    def visualize(self,
+                  model: torch.nn.Module,
+                  image: str,
+                  result: list,
+                  output_file: str,
+                  window_name: str,
+                  show_result: bool = False,
+                  score_thr: float = 0.3):
+        """Visualize predictions of a model.
+
+        Args:
+            model (nn.Module): Input model.
+            image (str): Pcd file to draw predictions on.
+            result (list): A list of predictions.
+            output_file (str): Output file to save result.
+            window_name (str): The name of the visualization window.
+            show_result (bool): Whether to show result in windows, defaults
+                to `False`.
+            score_thr (float): The score threshold to display the bbox.
+                Defaults to 0.3.
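+
+        Example:
+            >>> # a hedged usage sketch; `task` is a VoxelDetection
+            >>> # instance and the file paths are placeholders
+            >>> model = task.init_pytorch_model(None)
+            >>> data, _ = task.create_input('kitti_000008.bin')
+            >>> result = task.run_inference(model, data)[0]
+            >>> task.visualize(model, 'kitti_000008.bin', result,
+            ...                'output.bin', 'window', False)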
+        """
+        from mmdet3d.apis import show_result_meshlab
+        data = VoxelDetection.read_pcd_file(image, self.model_cfg,
+                                            self.device)
+        show_result_meshlab(
+            data,
+            result,
+            output_file,
+            score_thr,
+            show=show_result,
+            snapshot=not show_result,
+            task='det')
+
+    @staticmethod
+    def read_pcd_file(pcd: str, model_cfg: Union[str, mmcv.Config],
+                      device: str) -> Dict:
+        """Read data from pcd file and run test pipeline.
+
+        Args:
+            pcd (str): Pcd file path.
+            model_cfg (str | mmcv.Config): The model config.
+            device (str): A string specifying device type.
+
+        Returns:
+            dict: Meta information for the input pcd.
+        """
+        if isinstance(pcd, (list, tuple)):
+            pcd = pcd[0]
+        model_cfg = load_config(model_cfg)[0]
+        test_pipeline = Compose(model_cfg.data.test.pipeline)
+        box_type_3d, box_mode_3d = get_box_type(
+            model_cfg.data.test.box_type_3d)
+        data = dict(
+            pts_filename=pcd,
+            box_type_3d=box_type_3d,
+            box_mode_3d=box_mode_3d,
+            # for ScanNet demo we need axis_align_matrix
+            ann_info=dict(axis_align_matrix=np.eye(4)),
+            sweeps=[],
+            # set timestamp = 0
+            timestamp=[0],
+            img_fields=[],
+            bbox3d_fields=[],
+            pts_mask_fields=[],
+            pts_seg_fields=[],
+            bbox_fields=[],
+            mask_fields=[],
+            seg_fields=[])
+        data = test_pipeline(data)
+        data = collate([data], samples_per_gpu=1)
+        data['img_metas'] = [
+            img_metas.data[0] for img_metas in data['img_metas']
+        ]
+        data['points'] = [point.data[0] for point in data['points']]
+        if device != 'cpu':
+            data = scatter(data, [device])[0]
+        return data
+
+    @staticmethod
+    def run_inference(model: nn.Module,
+                      model_inputs: Dict[str, torch.Tensor]) -> List:
+        """Run inference once for an object detection model of mmdet3d.
+
+        Args:
+            model (nn.Module): Input model.
+            model_inputs (dict): A dict containing model inputs tensor and
+                meta info.
+
+        Returns:
+            list: The predictions of model inference.
+        """
+        result = model(
+            return_loss=False,
+            points=model_inputs['points'],
+            img_metas=model_inputs['img_metas'])
+        return [result]
+
+    @staticmethod
+    def evaluate_outputs(model_cfg,
+                         outputs: Sequence,
+                         dataset: Dataset,
+                         metrics: Optional[str] = None,
+                         out: Optional[str] = None,
+                         metric_options: Optional[dict] = None,
+                         format_only: bool = False,
+                         log_file: Optional[str] = None):
+        """Evaluate predictions or save them to file.
+
+        Args:
+            model_cfg (mmcv.Config): The model config.
+            outputs (Sequence): A list of predictions.
+            dataset (Dataset): The dataset to evaluate on.
+            metrics (str): Evaluation metrics, e.g. 'bbox'. Defaults to
+                `None`.
+            out (str): Output file to save the predictions. Defaults to
+                `None`.
+            metric_options (dict): Custom options for evaluation. Defaults
+                to `None`.
+            format_only (bool): Format the output results without
+                evaluating. Defaults to `False`.
+            log_file (str): Log file path. Defaults to `None`.
+        """
+        if out:
+            logger = get_root_logger()
+            logger.info(f'\nwriting results to {out}')
+            mmcv.dump(outputs, out)
+        kwargs = {} if metric_options is None else metric_options
+        if format_only:
+            dataset.format_results(outputs, **kwargs)
+        if metrics:
+            eval_kwargs = model_cfg.get('evaluation', {}).copy()
+            # hard-code way to remove EvalHook args
+            for key in [
+                    'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
+                    'rule'
+            ]:
+                eval_kwargs.pop(key, None)
+            eval_kwargs.update(dict(metric=metrics, **kwargs))
+            dataset.evaluate(outputs, **eval_kwargs)
+
+    def get_model_name(self) -> str:
+        """Get the model name.
+
+        Return:
+            str: the name of the model.
+        """
+        raise NotImplementedError
+
+    def get_tensor_from_input(self, input_data: Dict[str, Any],
+                              **kwargs) -> torch.Tensor:
+        """Get input tensor from input data.
+
+        Args:
+            input_data (dict): Input data containing meta info and image
+                tensor.
+
+        Returns:
+            torch.Tensor: An image in `Tensor`.
+        """
+        raise NotImplementedError
+
+    @staticmethod
+    def get_partition_cfg(partition_type: str, **kwargs) -> Dict:
+        """Get a certain partition config for mmdet3d.
+
+        Args:
+            partition_type (str): A string specifying partition type.
+
+        Returns:
+            dict: A dictionary of partition config.
+        """
+        raise NotImplementedError
+
+    def get_postprocess(self) -> Dict:
+        """Get the postprocess information for SDK.
+
+        Return:
+            dict: Composed of the postprocess information.
+        """
+        raise NotImplementedError
+
+    def get_preprocess(self) -> Dict:
+        """Get the preprocess information for SDK.
+
+        Return:
+            dict: Composed of the preprocess information.
+        """
+        raise NotImplementedError
+
+    def single_gpu_test(self,
+                        model: nn.Module,
+                        data_loader: DataLoader,
+                        show: bool = False,
+                        out_dir: Optional[str] = None,
+                        **kwargs) -> List:
+        """Run test with single gpu.
+
+        Args:
+            model (nn.Module): Input model from nn.Module.
+            data_loader (DataLoader): PyTorch data loader.
+            show (bool): Specifying whether to show plotted results.
+                Defaults to `False`.
+            out_dir (str): A directory to save results, defaults to `None`.
+
+        Returns:
+            list: The prediction results.
+        """
+        model.eval()
+        results = []
+        dataset = data_loader.dataset
+
+        prog_bar = mmcv.ProgressBar(len(dataset))
+        for i, data in enumerate(data_loader):
+            with torch.no_grad():
+                result = model(data['points'][0].data,
+                               data['img_metas'][0].data, False)
+            if show:
+                # visualize the results of the MMDetection3D model;
+                # `show_result` is the MMDetection3D visualization API
+                if out_dir is None:
+                    model.module.show_result(
+                        data,
+                        result,
+                        out_dir='',
+                        file_name='',
+                        show=show,
+                        snapshot=False,
+                        score_thr=0.3)
+                else:
+                    model.module.show_result(
+                        data,
+                        result,
+                        out_dir=out_dir,
+                        file_name=f'model_output{i}',
+                        show=show,
+                        snapshot=True,
+                        score_thr=0.3)
+            results.extend(result)
+
+            batch_size = len(result)
+            for _ in range(batch_size):
+                prog_bar.update()
+        return results
diff --git a/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
new file mode 100644
index 0000000000..f33b8b60c8
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
@@ -0,0 +1,248 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Dict, List, Sequence, Union
+
+import mmcv
+import torch
+from mmcv.utils import Registry
+from torch.nn import functional as F
+
+from mmdeploy.codebase.base import BaseBackendModel
+from mmdeploy.utils import (Backend, get_backend, get_codebase_config,
+                            get_root_logger, load_config)
+
+
+def __build_backend_voxel_model(cls_name: str, registry: Registry, *args,
+                                **kwargs):
+    return registry.module_dict[cls_name](*args, **kwargs)
+
+
+__BACKEND_MODEL = mmcv.utils.Registry(
+    'backend_voxel_detectors', build_func=__build_backend_voxel_model)
+
+
+@__BACKEND_MODEL.register_module('end2end')
+class VoxelDetectionModel(BaseBackendModel):
+    """End to end model for inference of 3d voxel detection.
+
+    Args:
+        backend (Backend): The backend enum, specifying backend type.
+        backend_files (Sequence[str]): Paths to all required backend files
+            (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin' for ncnn).
+        device (str): A string specifying device type.
+        model_cfg (str | mmcv.Config): The model config.
+        deploy_cfg (str | mmcv.Config): Deployment config file or loaded
+            Config object.
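+
+    Example:
+        >>> # an illustrative sketch; the model file name is a placeholder
+        >>> from mmdeploy.utils import Backend
+        >>> model = VoxelDetectionModel(Backend.ONNXRUNTIME,
+        ...                             ['end2end.onnx'], 'cpu',
+        ...                             model_cfg, deploy_cfg)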
+    """
+
+    def __init__(self,
+                 backend: Backend,
+                 backend_files: Sequence[str],
+                 device: str,
+                 model_cfg: mmcv.Config,
+                 deploy_cfg: Union[str, mmcv.Config] = None):
+        super().__init__(deploy_cfg=deploy_cfg)
+        self.deploy_cfg = deploy_cfg
+        self.model_cfg = model_cfg
+        self.device = device
+        self._init_wrapper(
+            backend=backend, backend_files=backend_files, device=device)
+
+    def _init_wrapper(self, backend: Backend, backend_files: Sequence[str],
+                      device: str):
+        """Initialize backend wrapper.
+
+        Args:
+            backend (Backend): The backend enum, specifying backend type.
+            backend_files (Sequence[str]): Paths to all required backend
+                files (e.g. '.onnx' for ONNX Runtime, '.param' and '.bin'
+                for ncnn).
+            device (str): A string specifying device type.
+        """
+        output_names = self.output_names
+        self.wrapper = BaseBackendModel._build_wrapper(
+            backend=backend,
+            backend_files=backend_files,
+            device=device,
+            output_names=output_names,
+            deploy_cfg=self.deploy_cfg)
+
+    def forward(self,
+                points: Sequence[torch.Tensor],
+                img_metas: Sequence[dict],
+                return_loss=False):
+        """Run forward inference.
+
+        Args:
+            points (Sequence[torch.Tensor]): A list containing input point
+                cloud(s), each an [N, ndim] float tensor. points[:, :3]
+                contain xyz points and points[:, 3:] contain other
+                information such as reflectivity.
+            img_metas (Sequence[dict]): A list of meta info for image(s).
+            return_loss (bool): Consistent with the pytorch model.
+                Defaults to `False`.
+
+        Returns:
+            list: A list containing predictions.
+        """
+        result_list = []
+        for i in range(len(img_metas)):
+            voxels, num_points, coors = VoxelDetectionModel.voxelize(
+                self.model_cfg, points[i])
+            input_dict = {
+                'voxels': voxels,
+                'num_points': num_points,
+                'coors': coors
+            }
+            outputs = self.wrapper(input_dict)
+            result = VoxelDetectionModel.post_process(self.model_cfg, outputs,
+                                                      img_metas[i],
+                                                      self.device)[0]
+            result_list.append(result)
+        return result_list
+
+    def show_result(self,
+                    data: Dict,
+                    result: List,
+                    out_dir: str,
+                    file_name: str,
+                    show=False,
+                    snapshot=False,
+                    **kwargs):
+        """Show the predictions of the backend model on the input pcd."""
+        from mmcv.parallel import DataContainer as DC
+        from mmdet3d.core import show_result
+        if isinstance(data['points'][0], DC):
+            points = data['points'][0]._data[0][0].numpy()
+        elif mmcv.is_list_of(data['points'][0], torch.Tensor):
+            points = data['points'][0][0]
+        else:
+            raise ValueError(
+                f"Unsupported data type {type(data['points'][0])} "
+                f'for visualization!')
+        pred_bboxes = result[0]['boxes_3d']
+        pred_labels = result[0]['labels_3d']
+        pred_bboxes = pred_bboxes.tensor.cpu().numpy()
+        show_result(
+            points,
+            None,
+            pred_bboxes,
+            out_dir,
+            file_name,
+            show=show,
+            snapshot=snapshot,
+            pred_labels=pred_labels)
+
+    @staticmethod
+    def voxelize(model_cfg: Union[str, mmcv.Config],
+                 points: Sequence[torch.Tensor]):
+        """Convert KITTI points (N, >=3) to voxels.
+
+        Args:
+            model_cfg (str | mmcv.Config): The model config.
+            points (Sequence[torch.Tensor]): A sequence of [N, ndim] float
+                tensors. points[:, :3] contain xyz points and points[:, 3:]
+                contain other information such as reflectivity.
+
+        Returns:
+            voxels: [M, max_points, ndim] float tensor. Only contains
+                points and is returned only when max_points != -1.
+            coordinates: [M, 3] int32 tensor, always returned.
+            num_points_per_voxel: [M] int32 tensor. Only returned when
+                max_points != -1.
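+
+        Example:
+            >>> # a hedged sketch: `points` holds one [N, 4] tensor loaded
+            >>> # from a KITTI .bin file (x, y, z, reflectivity)
+            >>> voxels, num_points, coors = VoxelDetectionModel.voxelize(
+            ...     model_cfg, [points])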
+        """
+        from mmcv.ops import Voxelization
+        model_cfg = load_config(model_cfg)[0]
+        if 'voxel_layer' in model_cfg.model.keys():
+            voxel_layer = model_cfg.model['voxel_layer']
+        elif 'pts_voxel_layer' in model_cfg.model.keys():
+            voxel_layer = model_cfg.model['pts_voxel_layer']
+        else:
+            raise ValueError(
+                'Cannot find voxel layer settings in the model config.')
+        voxel_layer = Voxelization(**voxel_layer)
+        voxels, coors, num_points = [], [], []
+        for res in points:
+            res_voxels, res_coors, res_num_points = voxel_layer(res)
+            voxels.append(res_voxels)
+            coors.append(res_coors)
+            num_points.append(res_num_points)
+        voxels = torch.cat(voxels, dim=0)
+        num_points = torch.cat(num_points, dim=0)
+        coors_batch = []
+        for i, coor in enumerate(coors):
+            coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
+            coors_batch.append(coor_pad)
+        coors_batch = torch.cat(coors_batch, dim=0)
+        return voxels, num_points, coors_batch
+
+    @staticmethod
+    def post_process(model_cfg: Union[str, mmcv.Config],
+                     outs: torch.Tensor,
+                     img_metas: Dict,
+                     device: str,
+                     rescale=False):
+        """Model post-processing.
+
+        Args:
+            model_cfg (str | mmcv.Config): The model config.
+            outs (torch.Tensor): Output of model's head.
+            img_metas (Dict): Meta info for pcd.
+            device (str): A string specifying device type.
+            rescale (bool): Whether to rescale the bboxes. Defaults to
+                `False`.
+
+        Returns:
+            list: A list of predictions, including bboxes, scores and
+                labels.
+        """
+        from mmdet3d.core import bbox3d2result
+        from mmdet3d.models.builder import build_head
+        model_cfg = load_config(model_cfg)[0]
+        head_cfg = dict(**model_cfg.model['bbox_head'])
+        head_cfg['train_cfg'] = None
+        head_cfg['test_cfg'] = model_cfg.model['test_cfg']
+        head = build_head(head_cfg)
+        if device == 'cpu':
+            logger = get_root_logger()
+            logger.warning(
+                'Post processing does not support CPU device; '
+                'trying CUDA instead.')
+            if torch.cuda.is_available():
+                device = 'cuda'
+            else:
+                raise NotImplementedError(
+                    'Post processing does not support device=cpu.')
+        cls_scores = [outs['scores'].to(device)]
+        bbox_preds = [outs['bbox_preds'].to(device)]
+        dir_scores = [outs['dir_scores'].to(device)]
+        bbox_list = head.get_bboxes(
+            cls_scores, bbox_preds, dir_scores, img_metas, rescale=rescale)
+        bbox_results = [
+            bbox3d2result(bboxes, scores, labels)
+            for bboxes, scores, labels in bbox_list
+        ]
+        return bbox_results
+
+
+def build_voxel_detection_model(model_files: Sequence[str],
+                                model_cfg: Union[str, mmcv.Config],
+                                deploy_cfg: Union[str, mmcv.Config],
+                                device: str):
+    """Build 3d voxel object detection model for different backends.
+
+    Args:
+        model_files (Sequence[str]): Input model file(s).
+        model_cfg (str | mmcv.Config): Input model config file or Config
+            object.
+        deploy_cfg (str | mmcv.Config): Input deployment config file or
+            Config object.
+        device (str): Device to put the model on.
+
+    Returns:
+        VoxelDetectionModel: Detector for a configured backend.
+    """
+    deploy_cfg, model_cfg = load_config(deploy_cfg, model_cfg)
+
+    backend = get_backend(deploy_cfg)
+    model_type = get_codebase_config(deploy_cfg).get('model_type', 'end2end')
+
+    backend_detector = __BACKEND_MODEL.build(
+        model_type,
+        backend=backend,
+        backend_files=model_files,
+        device=device,
+        model_cfg=model_cfg,
+        deploy_cfg=deploy_cfg)
+
+    return backend_detector
diff --git a/mmdeploy/codebase/mmdet3d/models/__init__.py b/mmdeploy/codebase/mmdet3d/models/__init__.py
new file mode 100644
index 0000000000..47d3e08b57
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/models/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .base import *  # noqa: F401,F403
+from .pillar_encode import *  # noqa: F401,F403
+from .pillar_scatter import *  # noqa: F401,F403
+from .voxelnet import *  # noqa: F401,F403
diff --git a/mmdeploy/codebase/mmdet3d/models/base.py b/mmdeploy/codebase/mmdet3d/models/base.py
new file mode 100644
index 0000000000..4a1226e2b4
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/models/base.py
@@ -0,0 +1,23 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.detectors.base.Base3DDetector.forward_test')
+def base3ddetector__forward_test(ctx,
+                                 self,
+                                 voxels,
+                                 num_points,
+                                 coors,
+                                 img_metas=None,
+                                 img=None,
+                                 rescale=True):
+    """Rewrite this function to run simple_test directly."""
+    return self.simple_test(voxels, num_points, coors, img_metas, img)
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.detectors.base.Base3DDetector.forward')
+def base3ddetector__forward(ctx, self, *args):
+    """Rewrite this function to run the model directly."""
+    return self.forward_test(*args)
diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_encode.py b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py
new file mode 100644
index 0000000000..71a30647b7
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/models/pillar_encode.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmdet3d.models.voxel_encoders.utils import get_paddings_indicator
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.voxel_encoders.pillar_encoder.PillarFeatureNet.forward')
+def pillar_encoder__forward(ctx, self, features, num_points, coors):
+    """Rewrite this function to generate ONNX-friendly nodes: the
+    `_with_voxel_center` branch computes the voxel-center offsets with
+    tensor slicing instead of the original per-coordinate assignments.
+
+    Args:
+        features (torch.Tensor): Point features or raw points in shape
+            (N, M, C).
+        num_points (torch.Tensor): Number of points in each pillar.
+        coors (torch.Tensor): Coordinates of each voxel.
+
+    Returns:
+        torch.Tensor: Features of pillars.
+    """
+    features_ls = [features]
+    # Find distance of x, y, and z from cluster center
+    if self._with_cluster_center:
+        points_mean = features[:, :, :3].sum(
+            dim=1, keepdim=True) / num_points.type_as(features).view(-1, 1, 1)
+        f_cluster = features[:, :, :3] - points_mean
+        features_ls.append(f_cluster)
+
+    # Find distance of x, y, and z from pillar center
+    device = features.device
+    if self._with_voxel_center:
+        if not self.legacy:
+            f_center = features[..., :3] - (
+                coors * torch.tensor([1, self.vz, self.vy, self.vx]).to(device)
+                +
+                torch.tensor([1, self.z_offset, self.y_offset, self.x_offset
+                              ]).to(device)).unsqueeze(1).flip(2)[..., :3]
+        else:
+            # the legacy mode computes the same slice-based offsets, but
+            # additionally rewrites the raw xyz features
+            f_center = features[..., :3] - (
+                coors * torch.tensor([1, self.vz, self.vy, self.vx]).to(device)
+                +
+                torch.tensor([1, self.z_offset, self.y_offset, self.x_offset
+                              ]).to(device)).unsqueeze(1).flip(2)[..., :3]
+            features_ls[0] = torch.cat((f_center, features[..., 3:]), dim=-1)
+        features_ls.append(f_center)
+
+    if self._with_distance:
+        points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
+        features_ls.append(points_dist)
+
+    # Combine together feature decorations
+    features = torch.cat(features_ls, dim=-1)
+    # The feature decorations were calculated without regard to whether
+    # pillar was empty. Need to ensure that
+    # empty pillars remain set to zeros.
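+    # note: `get_paddings_indicator` builds a (num_voxels, max_points)
+    # boolean mask from `num_points`, so the zero-padded point slots do
+    # not contribute to the decorated features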
+    voxel_count = features.shape[1]
+    mask = get_paddings_indicator(num_points, voxel_count, axis=0)
+    mask = torch.unsqueeze(mask, -1).type_as(features)
+    features *= mask
+    for pfn in self.pfn_layers:
+        features = pfn(features, num_points)
+
+    return features.squeeze(1)
diff --git a/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py
new file mode 100644
index 0000000000..7056e3d481
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/models/pillar_scatter.py
@@ -0,0 +1,36 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.middle_encoders.pillar_scatter.'
+    'PointPillarsScatter.forward_batch', )
+def pointpillarsscatter__forward(ctx,
+                                 self,
+                                 voxel_features,
+                                 coors,
+                                 batch_size=1):
+    """Scatter features of a single sample.
+
+    Args:
+        voxel_features (torch.Tensor): Voxel features from the voxel
+            encoder layer.
+        coors (torch.Tensor): Coordinates of each voxel.
+            The first column indicates the sample ID.
+        batch_size (int): Number of samples in the current batch. Only
+            batch size 1 is supported in this rewrite.
+
+    Returns:
+        torch.Tensor: A (1, C, ny, nx) canvas with the voxel features
+            scattered back to their spatial positions.
+    """
+    canvas = torch.zeros(
+        self.in_channels,
+        self.nx * self.ny,
+        dtype=voxel_features.dtype,
+        device=voxel_features.device)
+
+    indices = coors[:, 2] * self.nx + coors[:, 3]
+    indices = indices.long()
+    voxels = voxel_features.t()
+    # Now scatter the blob back to the canvas.
+    canvas[:, indices] = voxels
+    # Undo the column stacking to final 4-dim tensor
+    canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
+    return canvas
diff --git a/mmdeploy/codebase/mmdet3d/models/voxelnet.py b/mmdeploy/codebase/mmdet3d/models/voxelnet.py
new file mode 100644
index 0000000000..6a47d20cae
--- /dev/null
+++ b/mmdeploy/codebase/mmdet3d/models/voxelnet.py
@@ -0,0 +1,58 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmdeploy.core import FUNCTION_REWRITER
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.detectors.voxelnet.VoxelNet.simple_test')
+def voxelnet__simple_test(ctx,
+                          self,
+                          voxels,
+                          num_points,
+                          coors,
+                          img_metas=None,
+                          imgs=None,
+                          rescale=False):
+    """Test function without augmentation. Rewrite this function to remove
+    the model post-processing.
+
+    Args:
+        voxels (torch.Tensor): Point features or raw points in shape
+            (N, M, C).
+        num_points (torch.Tensor): Number of points in each pillar.
+        coors (torch.Tensor): Coordinates of each voxel.
+        img_metas (list[dict]): Contains pcd meta info.
+
+    Returns:
+        List: Result of model.
+    """
+    x = self.extract_feat(voxels, num_points, coors, img_metas)
+    scores, bbox_preds, dir_scores = self.bbox_head(x)
+    return scores, bbox_preds, dir_scores
+
+
+@FUNCTION_REWRITER.register_rewriter(
+    'mmdet3d.models.detectors.voxelnet.VoxelNet.extract_feat')
+def voxelnet__extract_feat(ctx,
+                           self,
+                           voxels,
+                           num_points,
+                           coors,
+                           img_metas=None):
+    """Extract features from points. Rewrite this function to remove the
+    voxelize op.
+
+    Args:
+        voxels (torch.Tensor): Point features or raw points in shape
+            (N, M, C).
+        num_points (torch.Tensor): Number of points in each pillar.
+        coors (torch.Tensor): Coordinates of each voxel.
+        img_metas (list[dict]): Contains pcd meta info.
+
+    Returns:
+        torch.Tensor: Features from points.
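+
+    Example:
+        >>> # a hedged sketch: the inputs come from
+        >>> # VoxelDetectionModel.voxelize, outside the exported graph
+        >>> x = self.extract_feat(voxels, num_points, coors)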
+    """
+    voxel_features = self.voxel_encoder(voxels, num_points, coors)
+    batch_size = coors[-1, 0] + 1  # TODO: refactor to support batch size > 1
+    assert batch_size == 1
+    x = self.middle_encoder(voxel_features, coors, batch_size)
+    x = self.backbone(x)
+    if self.with_neck:
+        x = self.neck(x)
+    return x
diff --git a/mmdeploy/utils/constants.py b/mmdeploy/utils/constants.py
index c36e610fea..094f7b7ab5 100644
--- a/mmdeploy/utils/constants.py
+++ b/mmdeploy/utils/constants.py
@@ -24,6 +24,7 @@ class Task(AdvancedEnum):
     CLASSIFICATION = 'Classification'
     OBJECT_DETECTION = 'ObjectDetection'
     INSTANCE_SEGMENTATION = 'InstanceSegmentation'
+    VOXEL_DETECTION = 'VoxelDetection'
     POSE_DETECTION = 'PoseDetection'
 
 
@@ -34,6 +35,7 @@ class Codebase(AdvancedEnum):
     MMCLS = 'mmcls'
     MMOCR = 'mmocr'
     MMEDIT = 'mmedit'
+    MMDET3D = 'mmdet3d'
     MMPOSE = 'mmpose'
diff --git a/tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin b/tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin
new file mode 100644
index 0000000000..24cefd327f
Binary files /dev/null and b/tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin differ
diff --git a/tests/test_codebase/test_mmdet3d/data/kitti/kitti_infos_val.pkl b/tests/test_codebase/test_mmdet3d/data/kitti/kitti_infos_val.pkl
new file mode 100644
index 0000000000..f2acbd3dcc
Binary files /dev/null and b/tests/test_codebase/test_mmdet3d/data/kitti/kitti_infos_val.pkl differ
diff --git a/tests/test_codebase/test_mmdet3d/data/model_cfg.py b/tests/test_codebase/test_mmdet3d/data/model_cfg.py
new file mode 100644
index 0000000000..b1934a2ca7
--- /dev/null
+++ b/tests/test_codebase/test_mmdet3d/data/model_cfg.py
@@ -0,0 +1,143 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+voxel_size = [0.16, 0.16, 4]
+
+model = dict(
+    type='VoxelNet',
+    voxel_layer=dict(
+        max_num_points=32,  # max_points_per_voxel
+        point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1],
+        voxel_size=voxel_size,
+        max_voxels=(16000, 40000)  # (training, testing) max_voxels
+    ),
+    voxel_encoder=dict(
+        type='PillarFeatureNet',
+        in_channels=4,
+        feat_channels=[64],
+        with_distance=False,
+        voxel_size=voxel_size,
+        point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1]),
+    middle_encoder=dict(
+        type='PointPillarsScatter', in_channels=64, output_shape=[496, 432]),
+    backbone=dict(
+        type='SECOND',
+        in_channels=64,
+        layer_nums=[3, 5, 5],
+        layer_strides=[2, 2, 2],
+        out_channels=[64, 128, 256]),
+    neck=dict(
+        type='SECONDFPN',
+        in_channels=[64, 128, 256],
+        upsample_strides=[1, 2, 4],
+        out_channels=[128, 128, 128]),
+    test_cfg=dict(
+        use_rotate_nms=True,
+        nms_across_levels=False,
+        nms_thr=0.01,
+        score_thr=0.1,
+        min_bbox_size=0,
+        nms_pre=100,
+        max_num=50),
+    bbox_head=dict(
+        type='Anchor3DHead',
+        num_classes=3,
+        in_channels=384,
+        feat_channels=384,
+        use_direction_classifier=True,
+        anchor_generator=dict(
+            type='AlignedAnchor3DRangeGenerator',
+            ranges=[
+                [0, -39.68, -0.6, 69.12, 39.68, -0.6],
+                [0, -39.68, -0.6, 69.12, 39.68, -0.6],
+                [0, -39.68, -1.78, 69.12, 39.68, -1.78],
+            ],
+            sizes=[[0.6, 0.8, 1.73], [0.6, 1.76, 1.73], [1.6, 3.9, 1.56]],
+            rotations=[0, 1.57],
+            reshape_out=False),
+        diff_rad_by_sin=True,
+        bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
+        loss_cls=dict(
+            type='FocalLoss',
+            use_sigmoid=True,
+            gamma=2.0,
+            alpha=0.25,
+            loss_weight=1.0),
+        loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
+        loss_dir=dict(
+            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)))
+point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
+# dataset settings
+data_root = 'tests/test_codebase/test_mmdet3d/data/kitti/'
+dataset_type = 'KittiDataset'
+class_names = ['Pedestrian', 'Cyclist', 'Car']
+input_modality = dict(use_lidar=True, use_camera=False)
+# PointPillars adopted different sampling strategies among classes
+db_sampler = dict(
+    data_root=data_root,
+    info_path=data_root + 'kitti_dbinfos_train.pkl',
+    rate=1.0,
+    prepare=dict(
+        filter_by_difficulty=[-1],
+        filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)),
+    classes=class_names,
+    sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10))
+train_pipeline = [
+    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
+    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
+    dict(type='ObjectSample', db_sampler=db_sampler),
+    dict(
+        type='ObjectNoise',
+        num_try=100,
+        translation_std=[0.25, 0.25, 0.25],
+        global_rot_range=[0.0, 0.0],
+        rot_range=[-0.15707963267, 0.15707963267]),
+    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
+    dict(
+        type='GlobalRotScaleTrans',
+        rot_range=[-0.78539816, 0.78539816],
+        scale_ratio_range=[0.95, 1.05]),
+    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='PointShuffle'),
+    dict(type='DefaultFormatBundle3D', class_names=class_names),
+    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+]
+test_pipeline = [
+    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
+    dict(
+        type='MultiScaleFlipAug3D',
+        img_scale=(1333, 800),
+        pts_scale_ratio=1,
+        flip=False,
+        transforms=[
+            dict(
+                type='GlobalRotScaleTrans',
+                rot_range=[0, 0],
+                scale_ratio_range=[1., 1.],
+                translation_std=[0, 0, 0]),
+            dict(type='RandomFlip3D'),
+            dict(
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+            dict(
+                type='DefaultFormatBundle3D',
+                class_names=class_names,
+                with_label=False),
+            dict(type='Collect3D', keys=['points'])
+        ])
+]
+data = dict(
+    train=dict(
+        dataset=dict(
+            pipeline=train_pipeline, classes=class_names,
+            box_type_3d='LiDAR')),
+    val=dict(pipeline=test_pipeline, classes=class_names, box_type_3d='LiDAR'),
+    test=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file=data_root + 'kitti_infos_val.pkl',
+        split='training',
+        pts_prefix='velodyne_reduced',
+        pipeline=test_pipeline,
+        modality=input_modality,
+        classes=class_names,
+        test_mode=True,
+        box_type_3d='LiDAR'))
diff --git a/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
new file mode 100644
index 0000000000..8b1b0039ed
--- /dev/null
+++ b/tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
@@ -0,0 +1,109 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mmcv
+import numpy as np
+import pytest
+import torch
+
+from mmdeploy.codebase import import_codebase
+from mmdeploy.utils import Backend, Codebase, Task
+from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
+
+try:
+    import_codebase(Codebase.MMDET3D)
+except ImportError:
+    pytest.skip(
+        f'{Codebase.MMDET3D} is not installed.', allow_module_level=True)
+
+
+def get_pillar_encoder():
+    from mmdet3d.models.voxel_encoders import PillarFeatureNet
+    model = PillarFeatureNet(
+        in_channels=4,
+        feat_channels=(64, ),
+        with_distance=False,
+        with_cluster_center=True,
+        with_voxel_center=True,
+        voxel_size=(0.2, 0.2, 4),
+        point_cloud_range=(0, -40, -3, 70.4, 40, 1),
+        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
+        mode='max')
+    model.requires_grad_(False)
+    return model
+
+
+def get_pointpillars_scatter():
+    from mmdet3d.models.middle_encoders import PointPillarsScatter
+    model = PointPillarsScatter(in_channels=64, output_shape=(16, 16))
+    model.requires_grad_(False)
+    return model
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_pillar_encoder(backend_type: Backend):
+    check_backend(backend_type, True)
+    model = get_pillar_encoder()
+    model.cpu().eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(
+                input_shape=None,
+                input_names=['features', 'num_points', 'coors'],
+                output_names=['outputs']),
+            codebase_config=dict(
+                type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value)))
+    features = torch.rand(3945, 32, 4) * 100
+    num_points = torch.randint(0, 32, (3945, ), dtype=torch.int32)
+    coors = torch.randint(0, 10, (3945, 4), dtype=torch.int32)
+    model_outputs = [model.forward(features, num_points, coors)]
+    wrapped_model = WrapModel(model, 'forward')
+    rewrite_inputs = {
+        'features': features,
+        'num_points': num_points,
+        'coors': coors
+    }
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    if isinstance(rewrite_outputs, dict):
+        rewrite_outputs = rewrite_outputs['outputs']
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        if isinstance(rewrite_output, torch.Tensor):
+            rewrite_output = rewrite_output.cpu().numpy()
+        assert np.allclose(
+            model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
+
+
+@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
+def test_pointpillars_scatter(backend_type: Backend):
+    check_backend(backend_type, True)
+    model = get_pointpillars_scatter()
+    model.cpu().eval()
+
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=backend_type.value),
+            onnx_config=dict(
+                input_shape=None,
+                input_names=['voxel_features', 'coors'],
+                output_names=['outputs']),
+            codebase_config=dict(
+                type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value)))
+    voxel_features = torch.rand(16 * 16, 64) * 100
+    coors = torch.randint(0, 10, (16 * 16, 4), dtype=torch.int32)
+    model_outputs = [model.forward_batch(voxel_features, coors, 1)]
+    wrapped_model = WrapModel(model, 'forward_batch')
+    rewrite_inputs = {'voxel_features': voxel_features, 'coors': coors}
+    rewrite_outputs, is_backend_output = get_rewrite_outputs(
+        wrapped_model=wrapped_model,
+        model_inputs=rewrite_inputs,
+        deploy_cfg=deploy_cfg)
+    if isinstance(rewrite_outputs, dict):
+        rewrite_outputs = rewrite_outputs['outputs']
+    for model_output, rewrite_output in zip(model_outputs, rewrite_outputs):
+        if isinstance(rewrite_output, torch.Tensor):
+            rewrite_output = rewrite_output.cpu().numpy()
+        assert np.allclose(
+            model_output.shape, rewrite_output.shape, rtol=1e-03, atol=1e-03)
diff --git a/tests/test_codebase/test_mmdet3d/test_voxel_detection.py b/tests/test_codebase/test_mmdet3d/test_voxel_detection.py
new file mode 100644
index 0000000000..aec5c5901c
--- /dev/null
+++ b/tests/test_codebase/test_mmdet3d/test_voxel_detection.py
@@ -0,0 +1,152 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+from tempfile import NamedTemporaryFile, TemporaryDirectory
+
+import mmcv
+import pytest
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data.dataset import Dataset
+
+import mmdeploy.backend.onnxruntime as ort_apis
+from mmdeploy.apis import build_task_processor
+from mmdeploy.codebase import import_codebase
+from mmdeploy.utils import Codebase, load_config
+from mmdeploy.utils.test import DummyModel, SwitchBackendWrapper
+
+try:
+    import_codebase(Codebase.MMDET3D)
+except ImportError:
+    pytest.skip(
+        f'{Codebase.MMDET3D} is not installed.', allow_module_level=True)
+
+model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
+pcd_path = 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin'
+model_cfg = load_config(model_cfg_path)[0]
+deploy_cfg = mmcv.Config(
+    dict(
+        backend_config=dict(type='onnxruntime'),
+        codebase_config=dict(type='mmdet3d', task='VoxelDetection'),
+        onnx_config=dict(
+            type='onnx',
+            export_params=True,
+            keep_initializers_as_inputs=False,
+            opset_version=11,
+            input_shape=None,
+            input_names=['voxels', 'num_points', 'coors'],
+            output_names=['scores', 'bbox_preds', 'dir_scores'])))
+onnx_file = NamedTemporaryFile(suffix='.onnx').name
+task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
+
+
+def test_init_pytorch_model():
+    from mmdet3d.models import Base3DDetector
+    model = task_processor.init_pytorch_model(None)
+    assert isinstance(model, Base3DDetector)
+
+
+@pytest.fixture
+def backend_model():
+    from mmdeploy.backend.onnxruntime import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+    wrapper = SwitchBackendWrapper(ORTWrapper)
+    wrapper.set(
+        outputs={
+            'scores': torch.rand(1, 18, 32, 32),
+            'bbox_preds': torch.rand(1, 42, 32, 32),
+            'dir_scores': torch.rand(1, 12, 32, 32)
+        })
+
+    yield task_processor.init_backend_model([''])
+
+    wrapper.recover()
+
+
+def test_init_backend_model(backend_model):
+    from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \
+        VoxelDetectionModel
+    assert isinstance(backend_model, VoxelDetectionModel)
+
+
+@pytest.mark.parametrize('device', ['cpu', 'cuda:0'])
+def test_create_input(device):
+    if device == 'cuda:0' and not torch.cuda.is_available():
+        pytest.skip('cuda is not available')
+    original_device = task_processor.device
+    task_processor.device = device
+    inputs = task_processor.create_input(pcd_path)
+    assert len(inputs) == 2
+    task_processor.device = original_device
+
+
+@pytest.mark.skipif(
+    reason='Only support GPU test', condition=not torch.cuda.is_available())
+def test_run_inference(backend_model):
+    task_processor.device = 'cuda:0'
+    torch_model = task_processor.init_pytorch_model(None)
+    input_dict, _ = task_processor.create_input(pcd_path)
+    torch_results = task_processor.run_inference(torch_model, input_dict)
+    backend_results = task_processor.run_inference(backend_model, input_dict)
+    assert torch_results is not None
+    assert backend_results is not None
+    assert len(torch_results[0]) == len(backend_results[0])
+    task_processor.device = 'cpu'
+
+
+@pytest.mark.skipif(
+    reason='Only support GPU test', condition=not torch.cuda.is_available())
+def test_visualize():
+    task_processor.device = 'cuda:0'
+    input_dict, _ = task_processor.create_input(pcd_path)
+    torch_model = task_processor.init_pytorch_model(None)
+    results = task_processor.run_inference(torch_model, input_dict)
+    with TemporaryDirectory() as dir:
+        filename = dir + 'tmp.bin'
+        task_processor.visualize(torch_model, pcd_path, results[0], filename,
+                                 'test', False)
+        assert os.path.exists(filename)
+    task_processor.device = 'cpu'
+
+
+def test_build_dataset_and_dataloader():
+    dataset = task_processor.build_dataset(
+        dataset_cfg=model_cfg, dataset_type='test')
+    assert isinstance(dataset, Dataset), 'Failed to build dataset'
+    dataloader = task_processor.build_dataloader(dataset, 1, 1)
+    assert isinstance(dataloader, DataLoader), 'Failed to build dataloader'
+
+
+@pytest.mark.skipif(
+    reason='Only support GPU test', condition=not torch.cuda.is_available())
+def test_single_gpu_test_and_evaluate():
+    from mmcv.parallel import MMDataParallel
+    task_processor.device = 'cuda:0'
+
+    class DummyDataset(Dataset):
+
+        def __getitem__(self, index):
+            return 0
+
+        def __len__(self):
+            return 0
+
+        def evaluate(self, *args, **kwargs):
+            return 0
+
+        def format_results(self, *args, **kwargs):
+            return 0
+
+    dataset = DummyDataset()
+    # Prepare dataloader
+    dataloader = DataLoader(dataset)
+
+    # Prepare dummy model
+    model = DummyModel(outputs=[torch.rand([1, 10, 5]), torch.rand([1, 10])])
+    model = MMDataParallel(model, device_ids=[0])
+    # Run test
+    outputs = task_processor.single_gpu_test(model, dataloader)
+    assert isinstance(outputs, list)
+    output_file = NamedTemporaryFile(suffix='.pkl').name
+    task_processor.evaluate_outputs(
+        model_cfg, outputs, dataset, 'bbox', out=output_file,
+        format_only=True)
+    task_processor.device = 'cpu'
diff --git a/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py
new file mode 100644
index 0000000000..1959996d93
--- /dev/null
+++ b/tests/test_codebase/test_mmdet3d/test_voxel_detection_model.py
@@ -0,0 +1,95 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+import mmcv
+import pytest
+import torch
+
+import mmdeploy.backend.onnxruntime as ort_apis
+from mmdeploy.codebase import import_codebase
+from mmdeploy.utils import Backend, Codebase
+from mmdeploy.utils.test import SwitchBackendWrapper, backend_checker
+
+try:
+    import_codebase(Codebase.MMDET3D)
+except ImportError:
+    pytest.skip(
+        f'{Codebase.MMDET3D} is not installed.', allow_module_level=True)
+from mmdeploy.codebase.mmdet3d.deploy.voxel_detection import VoxelDetection
+
+pcd_path = 'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin'
+model_cfg = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
+
+
+@backend_checker(Backend.ONNXRUNTIME)
+class TestVoxelDetectionModel:
+
+    @classmethod
+    def setup_class(cls):
+        # force add backend wrapper regardless of plugins
+        from mmdeploy.backend.onnxruntime import ORTWrapper
+        ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+        # simplify backend inference
+        cls.wrapper = SwitchBackendWrapper(ORTWrapper)
+        cls.outputs = {
+            'scores': torch.rand(1, 18, 32, 32),
+            'bbox_preds': torch.rand(1, 42, 32, 32),
+            'dir_scores': torch.rand(1, 12, 32, 32)
+        }
+        cls.wrapper.set(outputs=cls.outputs)
+        deploy_cfg = mmcv.Config({
+            'onnx_config': {
+                'input_names': ['voxels', 'num_points', 'coors'],
+                'output_names': ['scores', 'bbox_preds', 'dir_scores']
+            }
+        })
+
+        from mmdeploy.utils import load_config
+        model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
+        model_cfg = load_config(model_cfg_path)[0]
+        from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \
+            VoxelDetectionModel
+        cls.end2end_model = VoxelDetectionModel(
+            Backend.ONNXRUNTIME, [''],
+            device='cuda',
+            deploy_cfg=deploy_cfg,
+            model_cfg=model_cfg)
+
+    @pytest.mark.skipif(
+        reason='Only support GPU test',
+        condition=not torch.cuda.is_available())
+    def test_forward_and_show_result(self):
+        data = VoxelDetection.read_pcd_file(pcd_path, model_cfg, 'cuda')
+        results = self.end2end_model.forward(data['points'],
+                                             data['img_metas'])
+        assert results is not None
+        from tempfile import TemporaryDirectory
+        with TemporaryDirectory() as dir:
+            self.end2end_model.show_result(
+                data, results, dir, 'backend_output.bin', show=False)
+            assert osp.exists(dir + '/backend_output.bin')
+
+
+@backend_checker(Backend.ONNXRUNTIME)
+def test_build_voxel_detection_model():
+    from mmdeploy.utils import load_config
+    model_cfg_path = 'tests/test_codebase/test_mmdet3d/data/model_cfg.py'
+    model_cfg = load_config(model_cfg_path)[0]
+    deploy_cfg = mmcv.Config(
+        dict(
+            backend_config=dict(type=Backend.ONNXRUNTIME.value),
+            onnx_config=dict(
+                output_names=['scores', 'bbox_preds', 'dir_scores']),
+            codebase_config=dict(type=Codebase.MMDET3D.value)))
+
+    from mmdeploy.backend.onnxruntime import ORTWrapper
+    ort_apis.__dict__.update({'ORTWrapper': ORTWrapper})
+
+    # simplify backend inference
+    with SwitchBackendWrapper(ORTWrapper) as wrapper:
+        wrapper.set(model_cfg=model_cfg, deploy_cfg=deploy_cfg)
+        from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import (
+            VoxelDetectionModel, build_voxel_detection_model)
+        voxeldetector = build_voxel_detection_model([''], model_cfg,
+                                                    deploy_cfg, 'cpu')
+        assert isinstance(voxeldetector, VoxelDetectionModel)