Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add inferencer for lidar-based segmentation #2304

Merged
merged 12 commits into from
Mar 20, 2023
1 change: 1 addition & 0 deletions configs/pointnet2/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Models:
Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class_20210514_144009-24477ab1.pth

- Name: pointnet2_ssg_2xb16-cosine-50e_s3dis-seg
Alias: pointnet2-ssg_s3dis-seg
In Collection: PointNet++
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py
Metadata:
Expand Down
8 changes: 4 additions & 4 deletions mmdet3d/apis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
inference_mono_3d_detector,
inference_multi_modality_detector, inference_segmentor,
init_model)
from .inferencers import (BaseDet3DInferencer, LidarDet3DInferencer,
MonoDet3DInferencer)
from .inferencers import (Base3DInferencer, LidarDet3DInferencer,
LidarSeg3DInferencer, MonoDet3DInferencer)

__all__ = [
'inference_detector', 'init_model', 'inference_mono_3d_detector',
'convert_SyncBN', 'inference_multi_modality_detector',
'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer'
'inference_segmentor', 'Base3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer', 'LidarSeg3DInferencer'
]
8 changes: 4 additions & 4 deletions mmdet3d/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,16 +76,16 @@ def init_model(config: Union[str, Path, Config],
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
model.dataset_meta = {'classes': classes}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': config.class_names}
model.dataset_meta = {'classes': config.class_names}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']

model.cfg = config # save the config in the model for convenience
if device != 'cpu':
Expand Down
6 changes: 4 additions & 2 deletions mmdet3d/apis/inferencers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_3d_inferencer import Base3DInferencer
from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer

__all__ = [
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer'
'Base3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'LidarSeg3DInferencer'
]
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


class BaseDet3DInferencer(BaseInferencer):
"""Base 3D object detection inferencer.
class Base3DInferencer(BaseInferencer):
"""Base 3D model inferencer.

Args:
model (str, optional): Path to the config file or the model name
Expand All @@ -39,7 +39,7 @@ class BaseDet3DInferencer(BaseInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
Expand All @@ -58,7 +58,7 @@ def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
self.palette = palette
init_default_scope(scope)
Expand Down Expand Up @@ -97,16 +97,16 @@ def _init_model(
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
model.dataset_meta = {'classes': classes}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': cfg.class_names}
model.dataset_meta = {'classes': cfg.class_names}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']

model.cfg = cfg # save the config in the model for convenience
model.to(device)
Expand All @@ -130,8 +130,8 @@ def _inputs_to_list(

Args:
inputs (Union[dict, list]): Inputs for the inferencer.
modality_key (Union[str, List[str]], optional): The key of the
modality. Defaults to 'points'.
modality_key (Union[str, List[str]]): The key of the modality.
Defaults to 'points'.

Returns:
list: List of input for the :meth:`preprocess`.
Expand Down Expand Up @@ -187,6 +187,7 @@ def __call__(self,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.

Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
Expand All @@ -205,14 +206,15 @@ def __call__(self,
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.

Returns:
dict: Inference and visualization results.
"""
Expand Down Expand Up @@ -240,29 +242,36 @@ def postprocess(
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.

This method should be responsible for the following tasks:

1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.

Args:
preds (List[Dict]): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions.
visualization (np.ndarray, optional): Visualized predictions.
Defaults to None.
return_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.
Defaults to False.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.

Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``.

- ``visualization`` (Any): Returned by :meth:`visualize`.
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
result_dict = {}
results = preds
Expand All @@ -286,11 +295,18 @@ def pred2dict(self, data_sample: InstanceData) -> Dict:
It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable.
"""
pred_instances = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances.labels_3d.tolist(),
'scores_3d': pred_instances.scores_3d.tolist()
}
result = {}
if 'pred_instances_3d' in data_sample:
pred_instances_3d = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances_3d.labels_3d.tolist(),
'scores_3d': pred_instances_3d.scores_3d.tolist()
}

if 'pred_pts_seg' in data_sample:
pred_pts_seg = data_sample.pred_pts_seg.numpy()
result['pts_semantic_mask'] = \
pred_pts_seg.pts_semantic_mask.tolist()

return result
21 changes: 13 additions & 8 deletions mmdet3d/apis/inferencers/lidar_det3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_3d_inferencer import Base3DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
Expand All @@ -22,7 +22,7 @@

@INFERENCERS.register_module(name='det3d-lidar')
@INFERENCERS.register_module()
class LidarDet3DInferencer(BaseDet3DInferencer):
class LidarDet3DInferencer(Base3DInferencer):
"""The inferencer of LiDAR-based detection.

Args:
Expand All @@ -38,8 +38,9 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of registry.
palette (str, optional): The palette of visualization.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""

preprocess_kwargs: set = set()
Expand All @@ -56,14 +57,17 @@ def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for
# naming of the output results
self.num_visualized_frames = 0
self.palette = palette
super().__init__(
model=model, weights=weights, device=device, scope=scope)
super(LidarDet3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)

def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.
Expand Down Expand Up @@ -129,6 +133,7 @@ def visualize(self,
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.

Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
Expand Down
Loading