[Feature] Support entire PAConv and PAConvCUDA models (#783)
* add PAConv decode head

* add config files

* add paconv's correlation loss

* support reg loss in Segmentor class

* minor fix

* add augmentation to configs

* fix ed7 in cfg

* fix bug in corr loss

* enable syncbn in paconv

* rename to loss_regularization

* rename loss_reg to loss_regularize

* use SyncBN

* change weight kernels to kernel weights

* rename corr_loss to reg_loss

* minor fix

* configs fix IndoorPatchPointSample

* fix grouped points minus center error

* update transform_3d & add configs

* merge master

* fix enlarge_size bug

* refine config

* remove cfg files

* minor fix

* add comments on PAConv's ScoreNet

* refine comments

* update compatibility doc

* remove useless lines in transforms_3d

* rename with_loss_regularization to with_regularization_loss

* revert palette change

* remove xavier init from PAConv's ScoreNet
Wuziyi616 authored Jul 28, 2021
1 parent a8f4752 commit 3870001
Showing 21 changed files with 664 additions and 81 deletions.
7 changes: 7 additions & 0 deletions configs/_base_/models/paconv_cuda_ssg.py
@@ -0,0 +1,7 @@
_base_ = './paconv_ssg.py'

model = dict(
backbone=dict(
sa_cfg=dict(
type='PAConvCUDASAModule',
scorenet_cfg=dict(mlp_channels=[8, 16, 16]))))
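The `_base_` line means this config only states the keys that differ from `paconv_ssg.py`. A minimal sketch of that kind of recursive merge follows, assuming nested dicts are merged key by key and non-dict values are overwritten. The helper `merge_cfg` is hypothetical and is not the config loader's actual implementation; the base dict is trimmed to a few keys for illustration.

```python
import copy


def merge_cfg(base, override):
    """Recursively merge `override` into a copy of `base` (hypothetical helper)."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested dicts are merged so untouched keys survive.
            merged[key] = merge_cfg(merged[key], value)
        else:
            # Everything else (lists, strings, numbers) is replaced.
            merged[key] = copy.deepcopy(value)
    return merged


# Trimmed copy of the base `sa_cfg` from paconv_ssg.py.
base_sa_cfg = dict(
    type='PAConvSAModule',
    pool_mod='max',
    scorenet_cfg=dict(mlp_channels=[16, 16, 16], score_norm='softmax'))

# The override from paconv_cuda_ssg.py.
cuda_override = dict(
    type='PAConvCUDASAModule',
    scorenet_cfg=dict(mlp_channels=[8, 16, 16]))

sa_cfg = merge_cfg(base_sa_cfg, cuda_override)
print(sa_cfg)
```

Under this assumption, only `type` and `scorenet_cfg.mlp_channels` change, while `pool_mod` and `score_norm` are inherited from the base file.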
49 changes: 49 additions & 0 deletions configs/_base_/models/paconv_ssg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# model settings
model = dict(
type='EncoderDecoder3D',
backbone=dict(
type='PointNet2SASSG',
in_channels=9, # [xyz, rgb, normalized_xyz]
num_points=(1024, 256, 64, 16),
radius=(None, None, None, None), # use kNN instead of ball query
num_samples=(32, 32, 32, 32),
sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256,
512)),
fp_channels=(),
norm_cfg=dict(type='BN2d', momentum=0.1),
sa_cfg=dict(
type='PAConvSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=False,
paconv_num_kernels=[16, 16, 16],
paconv_kernel_input='w_neighbor',
scorenet_input='w_neighbor_dist',
scorenet_cfg=dict(
mlp_channels=[16, 16, 16],
score_norm='softmax',
temp_factor=1.0,
last_bn=False))),
decode_head=dict(
type='PAConvHead',
# PAConv model's decoder takes skip connections from the backbone;
# unlike PointNet++, it also concatenates input features in the last
# level of the decoder, leading to `128 + 6` as the channel number
fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128),
(128 + 6, 128, 128, 128)),
channels=128,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU'),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None, # should be modified with dataset
loss_weight=1.0)),
# correlation loss to regularize PAConv's kernel weights
loss_regularization=dict(
type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide'))
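The first entry of each `fp_channels` tuple can be cross-checked against `sa_channels`. The sketch below assumes each FP module concatenates the feature coming from the deeper level with the skip feature of the next shallower level, and that the `6` in `128 + 6` is `in_channels` minus the 3 raw xyz channels. Both are assumptions read off the config comments, and `expected_fp_inputs` is a hypothetical helper.

```python
sa_channels = ((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, 512))
fp_channels = ((768, 256, 256), (384, 256, 256), (320, 256, 128),
               (128 + 6, 128, 128, 128))
in_channels = 9  # [xyz, rgb, normalized_xyz]

# Assumption: the 3 raw xyz channels act as coordinates, so the remaining
# 6 channels are the input features concatenated at the last decoder level.
input_feat_channels = in_channels - 3

# Feature width at each level: input features first, then each SA output.
skip_channels = [input_feat_channels] + [mlp[-1] for mlp in sa_channels]


def expected_fp_inputs(skip_channels, fp_channels):
    """Input width of every FP module under the concatenation assumption."""
    current = skip_channels[-1]  # start from the deepest SA level
    expected = []
    for i, mlp in enumerate(fp_channels):
        skip = skip_channels[-(i + 2)]  # next shallower level
        expected.append(current + skip)
        current = mlp[-1]  # output width feeds the next FP module
    return expected


expected = expected_fp_inputs(skip_channels, fp_channels)
print(expected)
```

If the assumption holds, the computed widths match the first element of each tuple in `fp_channels`.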
8 changes: 7 additions & 1 deletion docs/compatibility.md
@@ -4,6 +4,12 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet

## MMDetection3D 0.16.0

### Returned values of `QueryAndGroup` operation

We modified the returned `grouped_xyz` value of the `QueryAndGroup` operation to support the PAConv segmentor. Originally, `grouped_xyz` was centered by subtracting the grouping centers, so it represented the relative positions of the grouped points. Now, this subtraction is no longer performed, and the returned `grouped_xyz` holds the absolute coordinates of these points.

Note that the other returned variables of `QueryAndGroup`, such as `new_features`, `unique_cnt` and `grouped_idx`, are not affected.
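
The difference between the two conventions can be illustrated with a small NumPy sketch. The array names, shapes, and values here are hypothetical and simplified; they do not reproduce the actual tensor layout of `QueryAndGroup`.

```python
import numpy as np

# Hypothetical data: 2 grouping centers, each with 3 grouped neighbors.
centers = np.array([[0.0, 0.0, 0.0],
                    [1.0, 1.0, 1.0]])                      # (M, 3)
grouped_xyz_abs = np.array([
    [[0.1, 0.0, 0.0], [0.0, 0.2, 0.0], [0.0, 0.0, 0.3]],
    [[1.1, 1.0, 1.0], [1.0, 1.2, 1.0], [1.0, 1.0, 1.3]],
])                                                         # (M, K, 3)

# New behavior: absolute coordinates are returned as they are.
# Old behavior: coordinates relative to each grouping center.
grouped_xyz_rel = grouped_xyz_abs - centers[:, None, :]

print(grouped_xyz_rel[1])
```

Code that relied on the old, centered values would need to subtract the centers itself, as in the last assignment above.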

### NuScenes coco-style data pre-processing

We remove the rotation and dimension hack in the monocular 3D detection on nuScenes. Specifically, we transform the rotation and dimension of boxes defined by nuScenes devkit to the coordinate system of our `CameraInstance3DBoxes` in the pre-processing and transform them back in the post-processing. In this way, we can remove the corresponding [hack](https://github.com/open-mmlab/mmdetection3d/pull/744/files#diff-5bee5062bd84e6fa25a2fdd71353f6f283dfdc4a66a0316c3b1ca26078c978b6L165) used in the visualization tools. The modification also guarantees the correctness of all the operations based on our `CameraInstance3DBoxes` (such as NMS and flip augmentation) when training monocular 3D detectors.
@@ -15,7 +21,7 @@ The modification only influences nuScenes coco-style json files. Please re-run t
We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection and segmentation methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes:

- Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data.
-- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating.
+- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating.
- Add instructions in the documents on preparing image data.

Please refer to the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) for more details.
2 changes: 1 addition & 1 deletion mmdet3d/apis/test.py
@@ -43,7 +43,7 @@ def single_gpu_test(model,
models_3d = (Base3DDetector, Base3DSegmentor,
SingleStageMono3DDetector)
if isinstance(model.module, models_3d):
-model.module.show_results(data, result, out_dir)
+model.module.show_results(data, result, out_dir=out_dir)
# Visualize the results of MMDetection model
# 'show_result' is MMdetection visualization API
else:
49 changes: 31 additions & 18 deletions mmdet3d/datasets/pipelines/transforms_3d.py
@@ -928,17 +928,20 @@ class IndoorPatchPointSample(object):
Defaults to None.
ignore_index (int, optional): Label index that won't be used for the
segmentation task. This is set in PointSegClassMapping as neg_cls.
+If not None, will be used as a patch selection criterion.
Defaults to None.
use_normalized_coord (bool, optional): Whether to use normalized xyz as
additional features. Defaults to False.
num_try (int, optional): Number of times to try if the patch selected
is invalid. Defaults to 10.
enlarge_size (float | None, optional): Enlarge the sampled patch to
[-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as
-an augmentation. If None, set it as 0.01. Defaults to 0.2.
+an augmentation. If None, set it as 0. Defaults to 0.2.
min_unique_num (int | None, optional): Minimum number of unique points
the sampled patch should contain. If None, use PointNet++'s method
to judge uniqueness. Defaults to None.
+eps (float, optional): A value added to patch boundary to guarantee
+points coverage. Defaults to 1e-2.
Note:
This transform should only be used in the training process of point
@@ -955,14 +958,16 @@ def __init__(self,
use_normalized_coord=False,
num_try=10,
enlarge_size=0.2,
-min_unique_num=None):
+min_unique_num=None,
+eps=1e-2):
self.num_points = num_points
self.block_size = block_size
self.ignore_index = ignore_index
self.use_normalized_coord = use_normalized_coord
self.num_try = num_try
-self.enlarge_size = enlarge_size if enlarge_size is not None else 0.01
+self.enlarge_size = enlarge_size if enlarge_size is not None else 0.0
self.min_unique_num = min_unique_num
+self.eps = eps

if sample_rate is not None:
warnings.warn(
@@ -1010,7 +1015,7 @@ def _input_generation(self, coords, patch_center, coord_max, attributes,

return points

-def _patch_points_sampling(self, points, sem_mask, replace=None):
+def _patch_points_sampling(self, points, sem_mask):
"""Patch points sampling.
First sample a valid patch.
@@ -1019,8 +1024,6 @@ def _patch_points_sampling(self, points, sem_mask, replace=None):
Args:
points (:obj:`BasePoints`): 3D Points.
sem_mask (np.ndarray): semantic segmentation mask for input points.
-replace (bool): Whether the sample is with or without replacement.
-Defaults to None.
Returns:
tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`:
@@ -1040,7 +1043,8 @@ def _patch_points_sampling(self, points, sem_mask, replace=None):
# random sample a point as patch center
cur_center = coords[np.random.choice(coords.shape[0])]

-# boundary of a patch
+# boundary of a patch, which would be enlarged by
+# `self.enlarge_size` as an augmentation
cur_max = cur_center + np.array(
[self.block_size / 2.0, self.block_size / 2.0, 0.0])
cur_min = cur_center - np.array(
@@ -1057,14 +1061,14 @@ def _patch_points_sampling(self, points, sem_mask, replace=None):

cur_coords = coords[cur_choice, :]
cur_sem_mask = sem_mask[cur_choice]

-# two criterion for patch sampling, adopted from PointNet++
-# points within selected patch shoule be scattered separately
+point_idxs = np.where(cur_choice)[0]
mask = np.sum(
-(cur_coords >= (cur_min - 0.01)) * (cur_coords <=
-(cur_max + 0.01)),
+(cur_coords >= (cur_min - self.eps)) * (cur_coords <=
+(cur_max + self.eps)),
axis=1) == 3

+# two criteria for patch sampling, adopted from PointNet++
+# 1. selected patch should contain enough unique points
if self.min_unique_num is None:
# use PointNet++'s method as default
# [31, 31, 62] are just some big values used to transform
@@ -1077,9 +1081,10 @@ def _patch_points_sampling(self, points, sem_mask, replace=None):
vidx[:, 2])
flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02
else:
+# if `min_unique_num` is provided, directly compare with it
flag1 = mask.sum() >= self.min_unique_num

-# selected patch should contain enough annotated points
+# 2. selected patch should contain enough annotated points
if self.ignore_index is None:
flag2 = True
else:
@@ -1089,11 +1094,19 @@ def _patch_points_sampling(self, points, sem_mask, replace=None):
if flag1 and flag2:
break

-# random sample idx
-if replace is None:
-replace = (cur_sem_mask.shape[0] < self.num_points)
-choices = np.random.choice(
-np.where(cur_choice)[0], self.num_points, replace=replace)
+# sample idx to `self.num_points`
+if point_idxs.size >= self.num_points:
+# no duplicate in sub-sampling
+choices = np.random.choice(
+point_idxs, self.num_points, replace=False)
+else:
+# do not use random choice here to avoid some points not counted
+dup = np.random.choice(point_idxs.size,
+self.num_points - point_idxs.size)
+idx_dup = np.concatenate(
+[np.arange(point_idxs.size),
+np.array(dup)], 0)
+choices = point_idxs[idx_dup]

# construct model input
points = self._input_generation(coords[choices], cur_center, coord_max,
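
The new sampling rule in the hunk above can be sketched as a standalone function. This is an illustration only: `sample_indices` is a hypothetical name, a `numpy.random.Generator` is used in place of the global `np.random` calls in the diff, and a non-empty `point_idxs` is assumed.

```python
import numpy as np


def sample_indices(point_idxs, num_points, rng):
    """Pick exactly `num_points` indices from a non-empty `point_idxs`."""
    point_idxs = np.asarray(point_idxs)
    if point_idxs.size >= num_points:
        # Enough points: sub-sample without duplicates.
        return rng.choice(point_idxs, num_points, replace=False)
    # Too few points: keep every point once, then pad with random
    # duplicates so that no point is left out.
    dup = rng.choice(point_idxs.size, num_points - point_idxs.size)
    idx_dup = np.concatenate([np.arange(point_idxs.size), dup], 0)
    return point_idxs[idx_dup]


rng = np.random.default_rng(0)
few = sample_indices([3, 7, 9], 8, rng)          # padding branch
many = sample_indices(np.arange(100), 8, rng)    # sub-sampling branch
print(few, many)
```

Compared with drawing all indices with replacement, the padding branch guarantees that every point in a sparse patch appears at least once.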
3 changes: 2 additions & 1 deletion mmdet3d/models/decode_heads/__init__.py
@@ -1,3 +1,4 @@
+from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head

-__all__ = ['PointNet2Head']
+__all__ = ['PointNet2Head', 'PAConvHead']
62 changes: 62 additions & 0 deletions mmdet3d/models/decode_heads/paconv_head.py
@@ -0,0 +1,62 @@
from mmcv.cnn.bricks import ConvModule

from mmdet.models import HEADS
from .pointnet2_head import PointNet2Head


@HEADS.register_module()
class PAConvHead(PointNet2Head):
r"""PAConv decoder head.
Decoder head used in `PAConv <https://arxiv.org/abs/2103.14635>`_.
Refer to the `official code <https://github.com/CVMI-Lab/PAConv>`_.
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules.
Default: dict(type='BN2d').
"""

def __init__(self,
fp_channels=((768, 256, 256), (384, 256, 256),
(320, 256, 128), (128 + 6, 128, 128, 128)),
fp_norm_cfg=dict(type='BN2d'),
**kwargs):
super(PAConvHead, self).__init__(fp_channels, fp_norm_cfg, **kwargs)

# https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53
# PointNet++'s decoder conv has bias while PAConv's doesn't have
# so we need to rebuild it here
self.pre_seg_conv = ConvModule(
fp_channels[-1][-1],
self.channels,
kernel_size=1,
bias=False,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)

def forward(self, feat_dict):
"""Forward pass.
Args:
feat_dict (dict): Feature dict from backbone.
Returns:
torch.Tensor: Segmentation map of shape [B, num_classes, N].
"""
sa_xyz, sa_features = self._extract_input(feat_dict)

# PointNet++ doesn't use the first level of `sa_features` as input
# while PAConv inputs it through skip-connection
fp_feature = sa_features[-1]

for i in range(self.num_fp):
# consume the points in a bottom-up manner
fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)],
sa_features[-(i + 2)], fp_feature)

output = self.pre_seg_conv(fp_feature)
output = self.cls_seg(output)

return output
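
The negative indexing in the `forward` loop can be traced with a small sketch. It assumes the backbone returns five levels (the input level plus four SA levels), which is inferred from the config rather than shown in this diff; `fp_level_pairs` is a hypothetical helper.

```python
def fp_level_pairs(num_levels, num_fp):
    """Return (target_level, source_level) for each FP step."""
    pairs = []
    for i in range(num_fp):
        target = num_levels - (i + 2)  # same element as index -(i + 2)
        source = num_levels - (i + 1)  # same element as index -(i + 1)
        pairs.append((target, source))
    return pairs


# Assumed: level 0 is the input, levels 1 to 4 are the SA outputs.
pairs = fp_level_pairs(num_levels=5, num_fp=4)
print(pairs)
```

Under this assumption the last FP step targets level 0, which is how the input features enter the decoder through a skip connection.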
6 changes: 5 additions & 1 deletion mmdet3d/models/decode_heads/pointnet2_head.py
@@ -15,18 +15,22 @@ class PointNet2Head(Base3DDecodeHead):
Args:
fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules.
fp_norm_cfg (dict|None): Config of norm layers used in FP modules.
+Default: dict(type='BN2d').
"""

def __init__(self,
fp_channels=((768, 256, 256), (384, 256, 256),
(320, 256, 128), (128, 128, 128, 128)),
+fp_norm_cfg=dict(type='BN2d'),
**kwargs):
super(PointNet2Head, self).__init__(**kwargs)

self.num_fp = len(fp_channels)
self.FP_modules = nn.ModuleList()
for cur_fp_mlps in fp_channels:
-self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps))
+self.FP_modules.append(
+PointFPModule(mlp_channels=cur_fp_mlps, norm_cfg=fp_norm_cfg))

# https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40
self.pre_seg_conv = ConvModule(
4 changes: 3 additions & 1 deletion mmdet3d/models/losses/__init__.py
@@ -1,8 +1,10 @@
from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy
from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss
from .chamfer_distance import ChamferDistance, chamfer_distance
+from .paconv_regularization_loss import PAConvRegularizationLoss

__all__ = [
'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
-'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss'
+'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss',
+'PAConvRegularizationLoss'
]