[Refactor] Refactor voxelization for faster speed (open-mmlab#2062)
* refactor voxelization for faster speed

* fix doc typo
Xiangxu-0103 authored and ZwwWayne committed Dec 3, 2022
1 parent c53516a commit 1b09bb6
Showing 2 changed files with 56 additions and 53 deletions.
97 changes: 50 additions & 47 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
@@ -1,7 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import math
 from numbers import Number
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Union
 
 import numpy as np
 import torch
@@ -28,24 +28,25 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
     - 1) For image data:
       - Pad images in inputs to the maximum size of current batch with defined
         ``pad_value``. The padding size can be divisible by a defined
-        ``pad_size_divisor``
+        ``pad_size_divisor``.
       - Stack images in inputs to batch_imgs.
       - Convert images in inputs from bgr to rgb if the shape of input is
         (3, H, W).
       - Normalize images in inputs with defined std and mean.
       - Do batch augmentations during training.
     - 2) For point cloud data:
-      - if no voxelization, directly return list of point cloud data.
-      - if voxelization is applied, voxelize point cloud according to
+      - If no voxelization, directly return list of point cloud data.
+      - If voxelization is applied, voxelize point cloud according to
         ``voxel_type`` and obtain ``voxels``.
 
     Args:
-        voxel (bool): Whether to apply voxelziation to point cloud.
+        voxel (bool): Whether to apply voxelization to point cloud.
             Defaults to False.
         voxel_type (str): Voxelization type. Two voxelization types are
             provided: 'hard' and 'dynamic', respectively for hard
             voxelization and dynamic voxelization. Defaults to 'hard'.
-        voxel_layer (:obj:`ConfigDict`, optional): Voxelization layer
+        voxel_layer (dict or :obj:`ConfigDict`, optional): Voxelization layer
             config. Defaults to None.
         mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
             Defaults to None.
@@ -54,11 +55,21 @@ class Det3DDataPreprocessor(DetDataPreprocessor):
         pad_size_divisor (int): The size of padded image should be
             divisible by ``pad_size_divisor``. Defaults to 1.
         pad_value (Number): The padded pixel value. Defaults to 0.
-        bgr_to_rgb (bool): whether to convert image from BGR to RGB.
+        pad_mask (bool): Whether to pad instance masks. Defaults to False.
+        mask_pad_value (int): The padded pixel value for instance masks.
+            Defaults to 0.
+        pad_seg (bool): Whether to pad semantic segmentation maps.
+            Defaults to False.
+        seg_pad_value (int): The padded pixel value for semantic
+            segmentation maps. Defaults to 255.
+        bgr_to_rgb (bool): Whether to convert image from BGR to RGB.
             Defaults to False.
-        rgb_to_bgr (bool): whether to convert image from RGB to RGB.
+        rgb_to_bgr (bool): Whether to convert image from RGB to BGR.
             Defaults to False.
-        batch_augments (list[dict], optional): Batch-level augmentations
+        boxtype2tensor (bool): Whether to keep the ``BaseBoxes`` type of
+            bboxes data or not. Defaults to True.
+        batch_augments (List[dict], optional): Batch-level augmentations.
             Defaults to None.
     """
 
     def __init__(self,
@@ -76,8 +87,8 @@ def __init__(self,
                  bgr_to_rgb: bool = False,
                  rgb_to_bgr: bool = False,
                  boxtype2tensor: bool = True,
-                 batch_augments: Optional[List[dict]] = None):
-        super().__init__(
+                 batch_augments: Optional[List[dict]] = None) -> None:
+        super(Det3DDataPreprocessor, self).__init__(
             mean=mean,
             std=std,
             pad_size_divisor=pad_size_divisor,
@@ -94,24 +105,21 @@ def __init__(self,
         if voxel:
             self.voxel_layer = Voxelization(**voxel_layer)
 
-    def forward(
-        self,
-        data: Union[dict, List[dict]],
-        training: bool = False
-    ) -> Tuple[Union[dict, List[dict]], Optional[list]]:
-        """Perform normalization、padding and bgr2rgb conversion based on
+    def forward(self,
+                data: Union[dict, List[dict]],
+                training: bool = False) -> Union[dict, List[dict]]:
+        """Perform normalization, padding and bgr2rgb conversion based on
         ``BaseDataPreprocessor``.
 
         Args:
-            data (dict | List[dict]): data from dataloader.
+            data (dict or List[dict]): Data from dataloader.
                 The dict contains the whole batch data, when it is
                 a list[dict], the list indicate test time augmentation.
             training (bool): Whether to enable training time augmentation.
                 Defaults to False.
 
         Returns:
-            Dict | List[Dict]: Data in the same format as the model input.
+            dict or List[dict]: Data in the same format as the model input.
         """
         if isinstance(data, list):
             num_augs = len(data)
@@ -126,7 +134,7 @@ def forward(
             return self.simple_process(data, training)
 
     def simple_process(self, data: dict, training: bool = False) -> dict:
-        """Perform normalization、padding and bgr2rgb conversion for img data
+        """Perform normalization, padding and bgr2rgb conversion for img data
         based on ``BaseDataPreprocessor``, and voxelize point cloud if `voxel`
         is set to be True.
@@ -188,7 +196,7 @@ def simple_process(self, data: dict, training: bool = False) -> dict:
 
         return {'inputs': batch_inputs, 'data_samples': data_samples}
 
-    def preprocess_img(self, _batch_img):
+    def preprocess_img(self, _batch_img: torch.Tensor) -> torch.Tensor:
         # channel transform
         if self._channel_conversion:
             _batch_img = _batch_img[[2, 1, 0], ...]
@@ -206,7 +214,7 @@ def preprocess_img(self, _batch_img):
         return _batch_img
 
     def collate_data(self, data: dict) -> dict:
-        """Copying data to the target device and Performs normalization
+        """Copying data to the target device and Performs normalization,
         padding and bgr2rgb conversion and stack based on
         ``BaseDataPreprocessor``.
@@ -273,7 +281,7 @@ def collate_data(self, data: dict) -> dict:
                 raise TypeError(
                     'Output of `cast_data` should be a list of dict '
                     'or a tuple with inputs and data_samples, but got'
-                    f'{type(data)} {data}')
+                    f'{type(data)}: {data}')
 
         data['inputs']['imgs'] = batch_imgs
 
@@ -284,14 +292,14 @@ def collate_data(self, data: dict) -> dict:
     def _get_pad_shape(self, data: dict) -> List[tuple]:
         """Get the pad_shape of each image based on data and
         pad_size_divisor."""
-        # rewrite `_get_pad_shape` for obaining image inputs.
+        # rewrite `_get_pad_shape` for obtaining image inputs.
         _batch_inputs = data['inputs']['img']
         # Process data with `pseudo_collate`.
         if is_list_of(_batch_inputs, torch.Tensor):
             batch_pad_shape = []
             for ori_input in _batch_inputs:
                 if ori_input.dim() == 4:
-                    # mean multiivew input, select ont of the
+                    # mean multiview input, select one of the
                     # image to calculate the pad shape
                     ori_input = ori_input[0]
                     pad_h = int(
@@ -316,24 +324,24 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
             batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
         else:
             raise TypeError('Output of `cast_data` should be a list of dict '
-                            'or a tuple with inputs and data_samples, but got'
+                            'or a tuple with inputs and data_samples, but got '
                             f'{type(data)}: {data}')
         return batch_pad_shape
 
     @torch.no_grad()
-    def voxelize(self, points: List[torch.Tensor]) -> Dict:
+    def voxelize(self, points: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
         """Apply voxelization to point cloud.
 
         Args:
             points (List[Tensor]): Point cloud in one data batch.
 
         Returns:
-            dict[str, Tensor]: Voxelization information.
+            Dict[str, Tensor]: Voxelization information.
 
-            - voxels (Tensor): Features of voxels, shape is MXNxC for hard
-              voxelization, NXC for dynamic voxelization.
-            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
-              where 1 represents the batch index.
+            - voxels (Tensor): Features of voxels, shape is MxNxC for hard
+              voxelization, NxC for dynamic voxelization.
+            - coors (Tensor): Coordinates of voxels, shape is Nx(1+NDim),
+              where 1 represents the batch index.
             - num_points (Tensor, optional): Number of points in each voxel.
             - voxel_centers (Tensor, optional): Centers of voxels.
         """
@@ -342,43 +350,38 @@ def voxelize(self, points: List[torch.Tensor]) -> Dict:
 
         if self.voxel_type == 'hard':
             voxels, coors, num_points, voxel_centers = [], [], [], []
-            for res in points:
+            for i, res in enumerate(points):
                 res_voxels, res_coors, res_num_points = self.voxel_layer(res)
                 res_voxel_centers = (
                     res_coors[:, [2, 1, 0]] + 0.5) * res_voxels.new_tensor(
                         self.voxel_layer.voxel_size) + res_voxels.new_tensor(
                             self.voxel_layer.point_cloud_range[0:3])
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 voxels.append(res_voxels)
                 coors.append(res_coors)
                 num_points.append(res_num_points)
                 voxel_centers.append(res_voxel_centers)
 
             voxels = torch.cat(voxels, dim=0)
+            coors = torch.cat(coors, dim=0)
             num_points = torch.cat(num_points, dim=0)
             voxel_centers = torch.cat(voxel_centers, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
 
             voxel_dict['num_points'] = num_points
             voxel_dict['voxel_centers'] = voxel_centers
         elif self.voxel_type == 'dynamic':
             coors = []
             # dynamic voxelization only provide a coors mapping
-            for res in points:
+            for i, res in enumerate(points):
                 res_coors = self.voxel_layer(res)
+                res_coors = F.pad(res_coors, (1, 0), mode='constant', value=i)
                 coors.append(res_coors)
             voxels = torch.cat(points, dim=0)
-            coors_batch = []
-            for i, coor in enumerate(coors):
-                coor_pad = F.pad(coor, (1, 0), mode='constant', value=i)
-                coors_batch.append(coor_pad)
-            coors_batch = torch.cat(coors_batch, dim=0)
+            coors = torch.cat(coors, dim=0)
         else:
             raise ValueError(f'Invalid voxelization type {self.voxel_type}')
 
         voxel_dict['voxels'] = voxels
-        voxel_dict['coors'] = coors_batch
+        voxel_dict['coors'] = coors
 
         return voxel_dict
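
The speedup comes from folding the batch-index padding into the voxelization loop: instead of a second pass over `coors` to build `coors_batch`, each sample's coordinates get their batch index via `F.pad(res_coors, (1, 0), value=i)` as soon as the voxel layer returns them, and a single `torch.cat` replaces the extra loop. A minimal standalone sketch of the pattern, with fake coordinate tensors standing in for the voxel layer's per-sample output:

import torch
import torch.nn.functional as F


def batch_coors(coors_list):
    """Pad each per-sample coordinate tensor (Nx3, zyx order) with its
    batch index as it is produced, then concatenate once."""
    padded = []
    for i, res_coors in enumerate(coors_list):
        # F.pad with (1, 0) prepends one column to the last dim; value=i
        # fills it with the batch index, giving rows of (batch_idx, z, y, x).
        padded.append(F.pad(res_coors, (1, 0), mode='constant', value=i))
    return torch.cat(padded, dim=0)


# Two fake samples with 2 and 3 voxels respectively.
coors_list = [
    torch.zeros(2, 3, dtype=torch.int32),
    torch.ones(3, 3, dtype=torch.int32),
]
batched = batch_coors(coors_list)
print(batched.shape)  # torch.Size([5, 4])
print(batched[:, 0])  # tensor([0, 0, 1, 1, 1], dtype=torch.int32)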
12 changes: 6 additions & 6 deletions mmdet3d/models/data_preprocessors/utils.py
@@ -12,7 +12,7 @@ def multiview_img_stack_batch(
     """
     Compared to the stack_batch in mmengine.model.utils,
     multiview_img_stack_batch further handle the multiview images.
-    see diff of padded_sizes[:, :-2] = 0 vs padded_sizees[:, 0] = 0 in line 47
+    see diff of padded_sizes[:, :-2] = 0 vs padded_sizes[:, 0] = 0 in line 47
     Stack multiple tensors to form a batch and pad the tensor to the max
     shape use the right bottom padding mode in these images. If
     ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
@@ -23,20 +23,20 @@ def multiview_img_stack_batch(
         pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
             to ensure the shape of each dim is divisible by
             ``pad_size_divisor``. This depends on the model, and many
-            models need to be divisible by 32. Defaults to 1
-        pad_value (int, float): The padding value. Defaults to 0.
+            models need to be divisible by 32. Defaults to 1.
+        pad_value (int or float): The padding value. Defaults to 0.
 
     Returns:
         Tensor: The n dim tensor.
     """
     assert isinstance(
         tensor_list,
-        list), (f'Expected input type to be list, but got {type(tensor_list)}')
+        list), f'Expected input type to be list, but got {type(tensor_list)}'
     assert tensor_list, '`tensor_list` could not be an empty list'
     assert len({
         tensor.ndim
         for tensor in tensor_list
-    }) == 1, (f'Expected the dimensions of all tensors must be the same, '
+    }) == 1, ('Expected the dimensions of all tensors must be the same, '
              f'but got {[tensor.ndim for tensor in tensor_list]}')
 
     dim = tensor_list[0].dim()
@@ -46,7 +46,7 @@ def multiview_img_stack_batch(
     max_sizes = torch.ceil(
         torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
     padded_sizes = max_sizes - all_sizes
     # The first dim normally means channel, which should not be padded.
     padded_sizes[:, :-2] = 0
     if padded_sizes.sum() == 0:
         return torch.stack(tensor_list)
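
On the utils side, the "line 47" comment above is the crux of the multiview handling: for inputs of shape (N_views, C, H, W), every leading dim must be excluded from padding, not just the channel dim, so `padded_sizes[:, :-2] = 0` zeroes everything except H and W (plain stack_batch would use `padded_sizes[:, 0] = 0`). A self-contained sketch of that padding logic, assuming (N_views, C, H, W) inputs and a hypothetical helper name:

import torch
import torch.nn.functional as F


def pad_multiview_batch(tensor_list, pad_size_divisor=32, pad_value=0):
    """Sketch of the multiview padding: only the last two dims (H, W)
    are padded, using right-bottom padding."""
    all_sizes = torch.Tensor([tensor.shape for tensor in tensor_list])
    # Round the per-dim maxima up to a multiple of pad_size_divisor.
    max_sizes = torch.ceil(
        torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
    padded_sizes = max_sizes - all_sizes
    # Zero every dim except H and W so views/channels are never padded.
    padded_sizes[:, :-2] = 0
    batch = []
    for tensor, pad in zip(tensor_list, padded_sizes):
        # Right-bottom padding: (left, right, top, bottom) = (0, pad_w, 0, pad_h).
        batch.append(
            F.pad(tensor, (0, int(pad[-1]), 0, int(pad[-2])), value=pad_value))
    return torch.stack(batch)


# Two fake 6-view samples with slightly different image sizes.
imgs = [torch.rand(6, 3, 370, 1220), torch.rand(6, 3, 376, 1241)]
print(pad_multiview_batch(imgs).shape)  # torch.Size([2, 6, 3, 384, 1248])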
