
[Feature] Add inferencer for lidar-based segmentation #2304

Merged · 12 commits · Mar 20, 2023
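For context, a minimal usage sketch of the new LidarSeg3DInferencer. This is a sketch only: the input format is assumed to mirror LidarDet3DInferencer (a dict with a 'points' key), the model name uses the metafile alias added below, and the point-cloud path is a placeholder.

from mmdet3d.apis import LidarSeg3DInferencer

# Build from the metafile alias added in this PR; the config and checkpoint
# are resolved from the metafile when only a model name is given.
inferencer = LidarSeg3DInferencer(model='pointnet2-ssg_s3dis-seg', device='cuda:0')

# 'demo.bin' is a placeholder point-cloud file. The call returns a dict with
# 'predictions' and 'visualization' entries (see postprocess below).
results = inferencer(dict(points='demo.bin'), print_result=True)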
6 changes: 0 additions & 6 deletions configs/_base_/datasets/s3dis-seg.py
@@ -58,12 +58,6 @@
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
6 changes: 0 additions & 6 deletions configs/_base_/datasets/scannet-seg.py
@@ -58,12 +58,6 @@
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
6 changes: 0 additions & 6 deletions configs/_base_/datasets/semantickitti.py
@@ -119,12 +119,6 @@
load_dim=4,
use_dim=4,
file_client_args=file_client_args),
dict(
type='LoadAnnotations3D',
with_seg_3d=True,
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]
# construct a pipeline for data and gt loading in show function
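With the LoadAnnotations3D step above removed, the pipeline in this hunk reduces to roughly the following. This is a sketch reconstructed from the visible context lines; the variable name and the coord_type value are assumptions, not part of the diff.

pipeline = [
    dict(
        type='LoadPointsFromFile',
        coord_type='LIDAR',  # assumed; not shown in the hunk
        load_dim=4,
        use_dim=4,
        file_client_args=file_client_args),
    # ground-truth segmentation labels are no longer loaded, so the pipeline
    # also works for pure inference
    dict(type='PointSegClassMapping'),
    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]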
13 changes: 7 additions & 6 deletions configs/pointnet2/metafile.yml
@@ -17,7 +17,7 @@ Collections:
Models:
- Name: pointnet2_ssg_2xb16-cosine-200e_scannet-seg-xyz-only
In Collection: PointNet++
Config: configs/pointnet/pointnet2_ssg_2xb16-cosine-200e_scannet-seg-xyz-only.py
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-200e_scannet-seg-xyz-only.py
Metadata:
Training Data: ScanNet
Training Memory (GB): 1.9
@@ -30,7 +30,7 @@ Models:

- Name: pointnet2_ssg_2xb16-cosine-200e_scannet-seg
In Collection: PointNet++
Config: configs/pointnet/pointnet2_ssg_2xb16-cosine-200e_scannet-seg.py
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-200e_scannet-seg.py
Metadata:
Training Data: ScanNet
Training Memory (GB): 1.9
@@ -43,7 +43,7 @@

- Name: pointnet2_msg_2xb16-cosine-250e_scannet-seg-xyz-only
In Collection: PointNet++
Config: configs/pointnet/pointnet2_msg_2xb16-cosine-250e_scannet-seg-xyz-only.py
Config: configs/pointnet2/pointnet2_msg_2xb16-cosine-250e_scannet-seg-xyz-only.py
Metadata:
Training Data: ScanNet
Training Memory (GB): 2.4
@@ -56,7 +56,7 @@

- Name: pointnet2_msg_2xb16-cosine-250e_scannet-seg
In Collection: PointNet++
Config: configs/pointnet/pointnet2_msg_2xb16-cosine-250e_scannet-seg.py
Config: configs/pointnet2/pointnet2_msg_2xb16-cosine-250e_scannet-seg.py
Metadata:
Training Data: ScanNet
Training Memory (GB): 2.4
@@ -68,8 +68,9 @@
Weights: https://download.openmmlab.com/mmdetection3d/v0.1.0_models/pointnet2/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class/pointnet2_msg_16x2_cosine_250e_scannet_seg-3d-20class_20210514_144009-24477ab1.pth

- Name: pointnet2_ssg_2xb16-cosine-50e_s3dis-seg
Alias: pointnet2-ssg_s3dis-seg
In Collection: PointNet++
Config: configs/pointnet/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py
Config: configs/pointnet2/pointnet2_ssg_2xb16-cosine-50e_s3dis-seg.py
Metadata:
Training Data: S3DIS
Training Memory (GB): 3.6
@@ -82,7 +83,7 @@

- Name: pointnet2_msg_2xb16-cosine-80e_s3dis-seg
In Collection: PointNet++
Config: configs/pointnet/pointnet2_msg_2xb16-cosine-80e_s3dis-seg.py
Config: configs/pointnet2/pointnet2_msg_2xb16-cosine-80e_s3dis-seg.py
Metadata:
Training Data: S3DIS
Training Memory (GB): 3.6
@@ -67,12 +67,6 @@
use_color=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
@@ -67,12 +67,6 @@
use_color=False,
load_dim=6,
use_dim=[0, 1, 2]),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
4 changes: 2 additions & 2 deletions mmdet3d/apis/__init__.py
@@ -4,11 +4,11 @@
inference_multi_modality_detector, inference_segmentor,
init_model)
from .inferencers import (BaseDet3DInferencer, LidarDet3DInferencer,
MonoDet3DInferencer)
LidarSeg3DInferencer, MonoDet3DInferencer)

__all__ = [
'inference_detector', 'init_model', 'inference_mono_3d_detector',
'convert_SyncBN', 'inference_multi_modality_detector',
'inference_segmentor', 'BaseDet3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer'
'LidarDet3DInferencer', 'LidarSeg3DInferencer'
]
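mmdet3d.apis now exposes both the existing functional helpers and the new class-based inferencer. A hedged comparison sketch: the config and checkpoint paths are placeholders, and the exact return format of inference_segmentor may differ between versions.

from mmdet3d.apis import LidarSeg3DInferencer, inference_segmentor, init_model

# Functional API (already exported): build the model explicitly, then call it.
model = init_model('seg_config.py', 'seg_checkpoint.pth', device='cuda:0')
result = inference_segmentor(model, 'demo.bin')  # return format is version-dependent

# Class-based API (added in this PR): model building, preprocessing,
# visualization and result dumping are handled by one object.
inferencer = LidarSeg3DInferencer(model='seg_config.py', weights='seg_checkpoint.pth')
results = inferencer(dict(points='demo.bin'))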
4 changes: 3 additions & 1 deletion mmdet3d/apis/inferencers/__init__.py
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer
from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer

__all__ = [
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer'
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'LidarSeg3DInferencer'
]
39 changes: 24 additions & 15 deletions mmdet3d/apis/inferencers/base_det3d_inferencer.py
@@ -38,7 +38,7 @@ class BaseDet3DInferencer(BaseInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of the model. Defaults to mmdet3d.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""
Expand All @@ -57,7 +57,7 @@ def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
self.palette = palette
register_all_modules()
@@ -96,16 +96,16 @@ def _init_model(
elif 'CLASSES' in checkpoint.get('meta', {}):
# < mmdet3d 1.x
classes = checkpoint['meta']['CLASSES']
model.dataset_meta = {'CLASSES': classes}
model.dataset_meta = {'classes': classes}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']
else:
# < mmdet3d 1.x
model.dataset_meta = {'CLASSES': cfg.class_names}
model.dataset_meta = {'classes': cfg.class_names}

if 'PALETTE' in checkpoint.get('meta', {}): # 3D Segmentor
model.dataset_meta['PALETTE'] = checkpoint['meta']['PALETTE']
model.dataset_meta['palette'] = checkpoint['meta']['PALETTE']

model.cfg = cfg # save the config in the model for convenience
model.to(device)
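The key rename in this hunk ('CLASSES'/'PALETTE' in the checkpoint meta become lowercase 'classes'/'palette' in model.dataset_meta) can be summarised with a small standalone sketch; the checkpoint contents below are invented for illustration.

# Metadata as stored by an older (< mmdet3d 1.x) checkpoint.
checkpoint_meta = {
    'CLASSES': ('ceiling', 'floor', 'wall'),
    'PALETTE': [(0, 255, 0), (0, 0, 255), (255, 0, 0)],
}

# New convention: dataset_meta uses lowercase keys.
dataset_meta = {'classes': checkpoint_meta['CLASSES']}
if 'PALETTE' in checkpoint_meta:  # 3D segmentor checkpoints also carry a palette
    dataset_meta['palette'] = checkpoint_meta['PALETTE']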
@@ -129,8 +129,8 @@ def _inputs_to_list(

Args:
inputs (Union[dict, list]): Inputs for the inferencer.
modality_key (Union[str, List[str]], optional): The key of the
modality. Defaults to 'points'.
modality_key (Union[str, List[str]]): The key of the modality.
Defaults to 'points'.

Returns:
list: List of input for the :meth:`preprocess`.
@@ -186,6 +186,7 @@ def __call__(self,
pred_out_file: str = '',
**kwargs) -> dict:
"""Call the inferencer.

Args:
inputs (InputsType): Inputs for the inferencer.
return_datasamples (bool): Whether to return results as
@@ -204,14 +205,15 @@
If left as empty, no file will be saved. Defaults to ''.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.
**kwargs: Other keyword arguments passed to :meth:`preprocess`,
:meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
Each key in kwargs should be in the corresponding set of
``preprocess_kwargs``, ``forward_kwargs``, ``visualize_kwargs``
and ``postprocess_kwargs``.

Returns:
dict: Inference and visualization results.
"""
@@ -239,29 +241,36 @@ def postprocess(
) -> Union[ResType, Tuple[ResType, np.ndarray]]:
"""Process the predictions and visualization results from ``forward``
and ``visualize``.

This method should be responsible for the following tasks:

1. Convert datasamples into a json-serializable dict if needed.
2. Pack the predictions and visualization results and return them.
3. Dump or log the predictions.

Args:
preds (List[Dict]): Predictions of the model.
visualization (Optional[np.ndarray]): Visualized predictions.
visualization (np.ndarray, optional): Visualized predictions.
Defaults to None.
return_datasample (bool): Whether to use Datasample to store
inference results. If False, dict will be used.
Defaults to False.
print_result (bool): Whether to print the inference result w/o
visualization to the console. Defaults to False.
pred_out_file: File to save the inference results w/o
pred_out_file (str): File to save the inference results w/o
visualization. If left as empty, no file will be saved.
Defaults to ''.

Returns:
dict: Inference and visualization results with key ``predictions``
and ``visualization``.

- ``visualization`` (Any): Returned by :meth:`visualize`.
- ``predictions`` (dict or DataSample): Returned by
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
:meth:`forward` and processed in :meth:`postprocess`.
If ``return_datasample=False``, it usually should be a
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
result_dict = {}
results = preds
20 changes: 12 additions & 8 deletions mmdet3d/apis/inferencers/lidar_det3d_inferencer.py
@@ -9,7 +9,7 @@
from mmengine.structures import InstanceData

from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType, register_all_modules
from mmdet3d.utils import ConfigType
from .base_det3d_inferencer import BaseDet3DInferencer

InstanceList = List[InstanceData]
@@ -38,8 +38,9 @@ class LidarDet3DInferencer(BaseDet3DInferencer):
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str, optional): The scope of registry.
palette (str, optional): The palette of visualization.
scope (str): The scope of the model. Defaults to 'mmdet3d'.
palette (str): Color palette used for visualization. The order of
priority is palette -> config -> checkpoint. Defaults to 'none'.
"""

preprocess_kwargs: set = set()
@@ -56,15 +57,17 @@ def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: Optional[str] = 'mmdet3d',
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
# A global counter tracking the number of frames processed, for
# naming of the output results
self.num_visualized_frames = 0
self.palette = palette
register_all_modules()
super().__init__(
model=model, weights=weights, device=device, scope=scope)
super(LidarDet3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)

def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.
@@ -130,6 +133,7 @@ def visualize(self,
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.

Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.