From 3870001ad33d09a3a691375c4bd2de7f56b29808 Mon Sep 17 00:00:00 2001 From: Ziyi Wu Date: Wed, 28 Jul 2021 16:36:01 +0800 Subject: [PATCH] [Feature] Support entire PAConv and PAConvCUDA models (#783) * add PAConv decode head * add config files * add paconv's correlation loss * support reg loss in Segmentor class * minor fix * add augmentation to configs * fix ed7 in cfg * fix bug in corr loss * enable syncbn in paconv * rename to loss_regularization * rename loss_reg to loss_regularize * use SyncBN * change weight kernels to kernel weights * rename corr_loss to reg_loss * minor fix * configs fix IndoorPatchPointSample * fix grouped points minus center error * update transform_3d & add configs * merge master * fix enlarge_size bug * refine config * remove cfg files * minor fix * add comments on PAConv's ScoreNet * refine comments * update compatibility doc * remove useless lines in transforms_3d * rename with_loss_regularization to with_regularization_loss * revert palette change * remove xavier init from PAConv's ScoreNet --- configs/_base_/models/paconv_cuda_ssg.py | 7 + configs/_base_/models/paconv_ssg.py | 49 ++++++ docs/compatibility.md | 8 +- mmdet3d/apis/test.py | 2 +- mmdet3d/datasets/pipelines/transforms_3d.py | 49 +++--- mmdet3d/models/decode_heads/__init__.py | 3 +- mmdet3d/models/decode_heads/paconv_head.py | 62 ++++++++ mmdet3d/models/decode_heads/pointnet2_head.py | 6 +- mmdet3d/models/losses/__init__.py | 4 +- .../losses/paconv_regularization_loss.py | 107 +++++++++++++ mmdet3d/models/segmentors/base.py | 17 ++- mmdet3d/models/segmentors/encoder_decoder.py | 33 +++- mmdet3d/ops/group_points/group_points.py | 9 +- mmdet3d/ops/paconv/paconv.py | 61 ++++---- .../ops/pointnet_modules/paconv_sa_module.py | 8 +- tests/test_metrics/test_losses.py | 39 +++++ .../test_paconv_modules.py | 40 ++--- .../test_common_modules/test_paconv_ops.py | 10 +- .../test_heads/test_paconv_decode_head.py | 82 ++++++++++ tests/test_models/test_segmentors.py | 144 ++++++++++++++++++ tools/test.py | 5 + 21 files changed, 664 insertions(+), 81 deletions(-) create mode 100644 configs/_base_/models/paconv_cuda_ssg.py create mode 100644 configs/_base_/models/paconv_ssg.py create mode 100644 mmdet3d/models/decode_heads/paconv_head.py create mode 100644 mmdet3d/models/losses/paconv_regularization_loss.py create mode 100644 tests/test_models/test_heads/test_paconv_decode_head.py diff --git a/configs/_base_/models/paconv_cuda_ssg.py b/configs/_base_/models/paconv_cuda_ssg.py new file mode 100644 index 0000000000..f513bd4a2f --- /dev/null +++ b/configs/_base_/models/paconv_cuda_ssg.py @@ -0,0 +1,7 @@ +_base_ = './paconv_ssg.py' + +model = dict( + backbone=dict( + sa_cfg=dict( + type='PAConvCUDASAModule', + scorenet_cfg=dict(mlp_channels=[8, 16, 16])))) diff --git a/configs/_base_/models/paconv_ssg.py b/configs/_base_/models/paconv_ssg.py new file mode 100644 index 0000000000..1d4f1ed393 --- /dev/null +++ b/configs/_base_/models/paconv_ssg.py @@ -0,0 +1,49 @@ +# model settings +model = dict( + type='EncoderDecoder3D', + backbone=dict( + type='PointNet2SASSG', + in_channels=9, # [xyz, rgb, normalized_xyz] + num_points=(1024, 256, 64, 16), + radius=(None, None, None, None), # use kNN instead of ball query + num_samples=(32, 32, 32, 32), + sa_channels=((32, 32, 64), (64, 64, 128), (128, 128, 256), (256, 256, + 512)), + fp_channels=(), + norm_cfg=dict(type='BN2d', momentum=0.1), + sa_cfg=dict( + type='PAConvSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=False, + paconv_num_kernels=[16, 16, 16], + paconv_kernel_input='w_neighbor', + scorenet_input='w_neighbor_dist', + scorenet_cfg=dict( + mlp_channels=[16, 16, 16], + score_norm='softmax', + temp_factor=1.0, + last_bn=False))), + decode_head=dict( + type='PAConvHead', + # PAConv model's decoder takes skip connections from beckbone + # different from PointNet++, it also concats input features in the last + # level of decoder, leading to `128 + 6` as the channel number + fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128), + (128 + 6, 128, 128, 128)), + channels=128, + dropout_ratio=0.5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU'), + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, # should be modified with dataset + loss_weight=1.0)), + # correlation loss to regularize PAConv's kernel weights + loss_regularization=dict( + type='PAConvRegularizationLoss', reduction='sum', loss_weight=10.0), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide')) diff --git a/docs/compatibility.md b/docs/compatibility.md index 904cebbcf5..b091fbef58 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -4,6 +4,12 @@ This document provides detailed descriptions of the BC-breaking changes in MMDet ## MMDetection3D 0.16.0 +### Returned values of `QueryAndGroup` operation + +We modified the returned `grouped_xyz` value of operation `QueryAndGroup` to support PAConv segmentor. Originally, the `grouped_xyz` is centered by subtracting the grouping centers, which represents the relative positions of grouped points. Now, we didn't perform such subtraction and the returned `grouped_xyz` stands for the absolute coordinates of these points. + +Note that, the other returned variables of `QueryAndGroup` such as `new_features`, `unique_cnt` and `grouped_idx` are not affected. + ### NuScenes coco-style data pre-processing We remove the rotation and dimension hack in the monocular 3D detection on nuScenes. Specifically, we transform the rotation and dimension of boxes defined by nuScenes devkit to the coordinate system of our `CameraInstance3DBoxes` in the pre-processing and transform them back in the post-processing. In this way, we can remove the corresponding [hack](https://github.com/open-mmlab/mmdetection3d/pull/744/files#diff-5bee5062bd84e6fa25a2fdd71353f6f283dfdc4a66a0316c3b1ca26078c978b6L165) used in the visualization tools. The modification also guarantees the correctness of all the operations based on our `CameraInstance3DBoxes` (such as NMS and flip augmentation) when training monocular 3D detectors. @@ -15,7 +21,7 @@ The modification only influences nuScenes coco-style json files. Please re-run t We adopt a new pre-processing procedure for the ScanNet dataset in order to support ImVoxelNet, which is a multi-view method requiring image data. In previous versions of MMDetection3D, ScanNet dataset was only used for point cloud based 3D detection and segmentation methods. We plan adding ImVoxelNet to our model zoo, thus updating ScanNet correspondingly by adding image-related pre-processing steps. Specifically, we made these changes: - Add [script](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/extract_posed_images.py) for extracting RGB data. -- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating. +- Update [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/data_converter/scannet_data_utils.py) for annotation creating. - Add instructions in the documents on preparing image data. Please refer to the ScanNet [README.md](https://github.com/open-mmlab/mmdetection3d/blob/master/data/scannet/README.md/) for more details. diff --git a/mmdet3d/apis/test.py b/mmdet3d/apis/test.py index e7f32c6a91..f60c56a119 100644 --- a/mmdet3d/apis/test.py +++ b/mmdet3d/apis/test.py @@ -43,7 +43,7 @@ def single_gpu_test(model, models_3d = (Base3DDetector, Base3DSegmentor, SingleStageMono3DDetector) if isinstance(model.module, models_3d): - model.module.show_results(data, result, out_dir) + model.module.show_results(data, result, out_dir=out_dir) # Visualize the results of MMDetection model # 'show_result' is MMdetection visualization API else: diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py index 6788df68c4..b3670c364f 100644 --- a/mmdet3d/datasets/pipelines/transforms_3d.py +++ b/mmdet3d/datasets/pipelines/transforms_3d.py @@ -928,6 +928,7 @@ class IndoorPatchPointSample(object): Defaults to None. ignore_index (int, optional): Label index that won't be used for the segmentation task. This is set in PointSegClassMapping as neg_cls. + If not None, will be used as a patch selection criterion. Defaults to None. use_normalized_coord (bool, optional): Whether to use normalized xyz as additional features. Defaults to False. @@ -935,10 +936,12 @@ class IndoorPatchPointSample(object): is invalid. Defaults to 10. enlarge_size (float | None, optional): Enlarge the sampled patch to [-block_size / 2 - enlarge_size, block_size / 2 + enlarge_size] as - an augmentation. If None, set it as 0.01. Defaults to 0.2. + an augmentation. If None, set it as 0. Defaults to 0.2. min_unique_num (int | None, optional): Minimum number of unique points the sampled patch should contain. If None, use PointNet++'s method to judge uniqueness. Defaults to None. + eps (float, optional): A value added to patch boundary to guarantee + points coverage. Defaults to 1e-2. Note: This transform should only be used in the training process of point @@ -955,14 +958,16 @@ def __init__(self, use_normalized_coord=False, num_try=10, enlarge_size=0.2, - min_unique_num=None): + min_unique_num=None, + eps=1e-2): self.num_points = num_points self.block_size = block_size self.ignore_index = ignore_index self.use_normalized_coord = use_normalized_coord self.num_try = num_try - self.enlarge_size = enlarge_size if enlarge_size is not None else 0.01 + self.enlarge_size = enlarge_size if enlarge_size is not None else 0.0 self.min_unique_num = min_unique_num + self.eps = eps if sample_rate is not None: warnings.warn( @@ -1010,7 +1015,7 @@ def _input_generation(self, coords, patch_center, coord_max, attributes, return points - def _patch_points_sampling(self, points, sem_mask, replace=None): + def _patch_points_sampling(self, points, sem_mask): """Patch points sampling. First sample a valid patch. @@ -1019,8 +1024,6 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): Args: points (:obj:`BasePoints`): 3D Points. sem_mask (np.ndarray): semantic segmentation mask for input points. - replace (bool): Whether the sample is with or without replacement. - Defaults to None. Returns: tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: @@ -1040,7 +1043,8 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): # random sample a point as patch center cur_center = coords[np.random.choice(coords.shape[0])] - # boundary of a patch + # boundary of a patch, which would be enlarged by + # `self.enlarge_size` as an augmentation cur_max = cur_center + np.array( [self.block_size / 2.0, self.block_size / 2.0, 0.0]) cur_min = cur_center - np.array( @@ -1057,14 +1061,14 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): cur_coords = coords[cur_choice, :] cur_sem_mask = sem_mask[cur_choice] - - # two criterion for patch sampling, adopted from PointNet++ - # points within selected patch shoule be scattered separately + point_idxs = np.where(cur_choice)[0] mask = np.sum( - (cur_coords >= (cur_min - 0.01)) * (cur_coords <= - (cur_max + 0.01)), + (cur_coords >= (cur_min - self.eps)) * (cur_coords <= + (cur_max + self.eps)), axis=1) == 3 + # two criteria for patch sampling, adopted from PointNet++ + # 1. selected patch should contain enough unique points if self.min_unique_num is None: # use PointNet++'s method as default # [31, 31, 62] are just some big values used to transform @@ -1077,9 +1081,10 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): vidx[:, 2]) flag1 = len(vidx) / 31.0 / 31.0 / 62.0 >= 0.02 else: + # if `min_unique_num` is provided, directly compare with it flag1 = mask.sum() >= self.min_unique_num - # selected patch should contain enough annotated points + # 2. selected patch should contain enough annotated points if self.ignore_index is None: flag2 = True else: @@ -1089,11 +1094,19 @@ def _patch_points_sampling(self, points, sem_mask, replace=None): if flag1 and flag2: break - # random sample idx - if replace is None: - replace = (cur_sem_mask.shape[0] < self.num_points) - choices = np.random.choice( - np.where(cur_choice)[0], self.num_points, replace=replace) + # sample idx to `self.num_points` + if point_idxs.size >= self.num_points: + # no duplicate in sub-sampling + choices = np.random.choice( + point_idxs, self.num_points, replace=False) + else: + # do not use random choice here to avoid some points not counted + dup = np.random.choice(point_idxs.size, + self.num_points - point_idxs.size) + idx_dup = np.concatenate( + [np.arange(point_idxs.size), + np.array(dup)], 0) + choices = point_idxs[idx_dup] # construct model input points = self._input_generation(coords[choices], cur_center, coord_max, diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index 38ecaf094b..0ecf553392 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,3 +1,4 @@ +from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head -__all__ = ['PointNet2Head'] +__all__ = ['PointNet2Head', 'PAConvHead'] diff --git a/mmdet3d/models/decode_heads/paconv_head.py b/mmdet3d/models/decode_heads/paconv_head.py new file mode 100644 index 0000000000..cf24ba510a --- /dev/null +++ b/mmdet3d/models/decode_heads/paconv_head.py @@ -0,0 +1,62 @@ +from mmcv.cnn.bricks import ConvModule + +from mmdet.models import HEADS +from .pointnet2_head import PointNet2Head + + +@HEADS.register_module() +class PAConvHead(PointNet2Head): + r"""PAConv decoder head. + + Decoder head used in `PAConv `_. + Refer to the `official code `_. + + Args: + fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. + fp_norm_cfg (dict|None): Config of norm layers used in FP modules. + Default: dict(type='BN2d'). + """ + + def __init__(self, + fp_channels=((768, 256, 256), (384, 256, 256), + (320, 256, 128), (128 + 6, 128, 128, 128)), + fp_norm_cfg=dict(type='BN2d'), + **kwargs): + super(PAConvHead, self).__init__(fp_channels, fp_norm_cfg, **kwargs) + + # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/pointnet2_paconv_seg.py#L53 + # PointNet++'s decoder conv has bias while PAConv's doesn't have + # so we need to rebuild it here + self.pre_seg_conv = ConvModule( + fp_channels[-1][-1], + self.channels, + kernel_size=1, + bias=False, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg) + + def forward(self, feat_dict): + """Forward pass. + + Args: + feat_dict (dict): Feature dict from backbone. + + Returns: + torch.Tensor: Segmentation map of shape [B, num_classes, N]. + """ + sa_xyz, sa_features = self._extract_input(feat_dict) + + # PointNet++ doesn't use the first level of `sa_features` as input + # while PAConv inputs it through skip-connection + fp_feature = sa_features[-1] + + for i in range(self.num_fp): + # consume the points in a bottom-up manner + fp_feature = self.FP_modules[i](sa_xyz[-(i + 2)], sa_xyz[-(i + 1)], + sa_features[-(i + 2)], fp_feature) + + output = self.pre_seg_conv(fp_feature) + output = self.cls_seg(output) + + return output diff --git a/mmdet3d/models/decode_heads/pointnet2_head.py b/mmdet3d/models/decode_heads/pointnet2_head.py index 8624cf6712..271b3d2c83 100644 --- a/mmdet3d/models/decode_heads/pointnet2_head.py +++ b/mmdet3d/models/decode_heads/pointnet2_head.py @@ -15,18 +15,22 @@ class PointNet2Head(Base3DDecodeHead): Args: fp_channels (tuple[tuple[int]]): Tuple of mlp channels in FP modules. + fp_norm_cfg (dict|None): Config of norm layers used in FP modules. + Default: dict(type='BN2d'). """ def __init__(self, fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128), (128, 128, 128, 128)), + fp_norm_cfg=dict(type='BN2d'), **kwargs): super(PointNet2Head, self).__init__(**kwargs) self.num_fp = len(fp_channels) self.FP_modules = nn.ModuleList() for cur_fp_mlps in fp_channels: - self.FP_modules.append(PointFPModule(mlp_channels=cur_fp_mlps)) + self.FP_modules.append( + PointFPModule(mlp_channels=cur_fp_mlps, norm_cfg=fp_norm_cfg)) # https://github.com/charlesq34/pointnet2/blob/master/models/pointnet2_sem_seg.py#L40 self.pre_seg_conv = ConvModule( diff --git a/mmdet3d/models/losses/__init__.py b/mmdet3d/models/losses/__init__.py index eda0a0e34b..d14a0ca7df 100644 --- a/mmdet3d/models/losses/__init__.py +++ b/mmdet3d/models/losses/__init__.py @@ -1,8 +1,10 @@ from mmdet.models.losses import FocalLoss, SmoothL1Loss, binary_cross_entropy from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss from .chamfer_distance import ChamferDistance, chamfer_distance +from .paconv_regularization_loss import PAConvRegularizationLoss __all__ = [ 'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance', - 'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss' + 'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss', + 'PAConvRegularizationLoss' ] diff --git a/mmdet3d/models/losses/paconv_regularization_loss.py b/mmdet3d/models/losses/paconv_regularization_loss.py new file mode 100644 index 0000000000..f2e3f2650f --- /dev/null +++ b/mmdet3d/models/losses/paconv_regularization_loss.py @@ -0,0 +1,107 @@ +import torch +from torch import nn as nn + +from mmdet3d.ops import PAConv, PAConvCUDA +from mmdet.models.builder import LOSSES +from mmdet.models.losses.utils import weight_reduce_loss + + +def weight_correlation(conv): + """Calculate correlations between kernel weights in Conv's weight bank as + regularization loss. The cosine similarity is used as metrics. + + Args: + conv (nn.Module): A Conv modules to be regularized. + Currently we only support `PAConv` and `PAConvCUDA`. + + Returns: + torch.Tensor: Correlations between each kernel weights in weight bank. + """ + assert isinstance(conv, (PAConv, PAConvCUDA)), \ + f'unsupported module type {type(conv)}' + kernels = conv.weight_bank # [C_in, num_kernels * C_out] + in_channels = conv.in_channels + out_channels = conv.out_channels + num_kernels = conv.num_kernels + + # [num_kernels, Cin * Cout] + flatten_kernels = kernels.view(in_channels, num_kernels, out_channels).\ + permute(1, 0, 2).reshape(num_kernels, -1) + # [num_kernels, num_kernels] + inner_product = torch.matmul(flatten_kernels, flatten_kernels.T) + # [num_kernels, 1] + kernel_norms = torch.sum(flatten_kernels**2, dim=-1, keepdim=True)**0.5 + # [num_kernels, num_kernels] + kernel_norms = torch.matmul(kernel_norms, kernel_norms.T) + cosine_sims = inner_product / kernel_norms + # take upper triangular part excluding diagonal since we only compute + # correlation between different kernels once + # the square is to ensure positive loss, refer to: + # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/tool/train.py#L208 + corr = torch.sum(torch.triu(cosine_sims, diagonal=1)**2) + + return corr + + +def paconv_regularization_loss(modules, reduction): + """Computes correlation loss of PAConv weight kernels as regularization. + + Args: + modules (List[nn.Module] | :obj:`generator`): + A list or a python generator of torch.nn.Modules. + reduction (str): Method to reduce losses among PAConv modules. + The valid reduction method are none, sum or mean. + + Returns: + torch.Tensor: Correlation loss of kernel weights. + """ + corr_loss = [] + for module in modules: + if isinstance(module, (PAConv, PAConvCUDA)): + corr_loss.append(weight_correlation(module)) + corr_loss = torch.stack(corr_loss) + + # perform reduction + corr_loss = weight_reduce_loss(corr_loss, reduction=reduction) + + return corr_loss + + +@LOSSES.register_module() +class PAConvRegularizationLoss(nn.Module): + """Calculate correlation loss of kernel weights in PAConv's weight bank. + + This is used as a regularization term in PAConv model training. + + Args: + reduction (str): Method to reduce losses. The reduction is performed + among all PAConv modules instead of prediction tensors. + The valid reduction method are none, sum or mean. + loss_weight (float, optional): Weight of loss. Defaults to 1.0. + """ + + def __init__(self, reduction='mean', loss_weight=1.0): + super(PAConvRegularizationLoss, self).__init__() + assert reduction in ['none', 'sum', 'mean'] + self.reduction = reduction + self.loss_weight = loss_weight + + def forward(self, modules, reduction_override=None, **kwargs): + """Forward function of loss calculation. + + Args: + modules (List[nn.Module] | :obj:`generator`): + A list or a python generator of torch.nn.Modules. + reduction_override (str, optional): Method to reduce losses. + The valid reduction method are 'none', 'sum' or 'mean'. + Defaults to None. + + Returns: + torch.Tensor: Correlation loss of kernel weights. + """ + assert reduction_override in (None, 'none', 'mean', 'sum') + reduction = ( + reduction_override if reduction_override else self.reduction) + + return self.loss_weight * paconv_regularization_loss( + modules, reduction=reduction) diff --git a/mmdet3d/models/segmentors/base.py b/mmdet3d/models/segmentors/base.py index c372790edc..d4fc8917ff 100644 --- a/mmdet3d/models/segmentors/base.py +++ b/mmdet3d/models/segmentors/base.py @@ -16,6 +16,12 @@ class Base3DSegmentor(BaseSegmentor): data_dict and use a 3D seg specific visualization function. """ + @property + def with_regularization_loss(self): + """bool: whether the segmentor has regularization loss for weight""" + return hasattr(self, 'loss_regularization') and \ + self.loss_regularization is not None + def forward_test(self, points, img_metas, **kwargs): """Calls either simple_test or aug_test depending on the length of outer list of points. If len(points) == 1, call simple_test. Otherwise @@ -108,5 +114,12 @@ def show_results(self, pred_sem_mask = result[batch_id]['semantic_mask'].cpu().numpy() - show_seg_result(points, None, pred_sem_mask, out_dir, file_name, - palette, ignore_index) + show_seg_result( + points, + None, + pred_sem_mask, + out_dir, + file_name, + palette, + ignore_index, + show=True) diff --git a/mmdet3d/models/segmentors/encoder_decoder.py b/mmdet3d/models/segmentors/encoder_decoder.py index 4841b4ba17..353c461861 100644 --- a/mmdet3d/models/segmentors/encoder_decoder.py +++ b/mmdet3d/models/segmentors/encoder_decoder.py @@ -5,7 +5,7 @@ from mmseg.core import add_prefix from mmseg.models import SEGMENTORS -from ..builder import build_backbone, build_head, build_neck +from ..builder import build_backbone, build_head, build_loss, build_neck from .base import Base3DSegmentor @@ -23,6 +23,7 @@ def __init__(self, decode_head, neck=None, auxiliary_head=None, + loss_regularization=None, train_cfg=None, test_cfg=None, pretrained=None, @@ -33,6 +34,7 @@ def __init__(self, self.neck = build_neck(neck) self._init_decode_head(decode_head) self._init_auxiliary_head(auxiliary_head) + self._init_loss_regularization(loss_regularization) self.train_cfg = train_cfg self.test_cfg = test_cfg @@ -54,6 +56,16 @@ def _init_auxiliary_head(self, auxiliary_head): else: self.auxiliary_head = build_head(auxiliary_head) + def _init_loss_regularization(self, loss_regularization): + """Initialize ``loss_regularization``""" + if loss_regularization is not None: + if isinstance(loss_regularization, list): + self.loss_regularization = nn.ModuleList() + for loss_cfg in loss_regularization: + self.loss_regularization.append(build_loss(loss_cfg)) + else: + self.loss_regularization = build_loss(loss_regularization) + def extract_feat(self, points): """Extract features from points.""" x = self.backbone(points) @@ -110,6 +122,21 @@ def _auxiliary_head_forward_train(self, x, img_metas, pts_semantic_mask): return losses + def _loss_regularization_forward_train(self): + """Calculate regularization loss for model weight in training.""" + losses = dict() + if isinstance(self.loss_regularization, nn.ModuleList): + for idx, regularize_loss in enumerate(self.loss_regularization): + loss_regularize = dict( + loss_regularize=regularize_loss(self.modules())) + losses.update(add_prefix(loss_regularize, f'regularize_{idx}')) + else: + loss_regularize = dict( + loss_regularize=self.loss_regularization(self.modules())) + losses.update(add_prefix(loss_regularize, 'regularize')) + + return losses + def forward_dummy(self, points): """Dummy forward function.""" seg_logit = self.encode_decode(points, None) @@ -145,6 +172,10 @@ def forward_train(self, points, img_metas, pts_semantic_mask): x, img_metas, pts_semantic_mask_cat) losses.update(loss_aux) + if self.with_regularization_loss: + loss_regularize = self._loss_regularization_forward_train() + losses.update(loss_regularize) + return losses @staticmethod diff --git a/mmdet3d/ops/group_points/group_points.py b/mmdet3d/ops/group_points/group_points.py index 52787c8276..88122a88d5 100644 --- a/mmdet3d/ops/group_points/group_points.py +++ b/mmdet3d/ops/group_points/group_points.py @@ -98,22 +98,23 @@ def forward(self, points_xyz, center_xyz, features=None): xyz_trans = points_xyz.transpose(1, 2).contiguous() # (B, 3, npoint, sample_num) grouped_xyz = grouping_operation(xyz_trans, idx) - grouped_xyz -= center_xyz.transpose(1, 2).unsqueeze(-1) + grouped_xyz_diff = grouped_xyz - \ + center_xyz.transpose(1, 2).unsqueeze(-1) # relative offsets if self.normalize_xyz: - grouped_xyz /= self.max_radius + grouped_xyz_diff /= self.max_radius if features is not None: grouped_features = grouping_operation(features, idx) if self.use_xyz: # (B, C + 3, npoint, sample_num) - new_features = torch.cat([grouped_xyz, grouped_features], + new_features = torch.cat([grouped_xyz_diff, grouped_features], dim=1) else: new_features = grouped_features else: assert (self.use_xyz ), 'Cannot have not features and not use xyz as a feature!' - new_features = grouped_xyz + new_features = grouped_xyz_diff ret = [new_features] if self.return_grouped_xyz: diff --git a/mmdet3d/ops/paconv/paconv.py b/mmdet3d/ops/paconv/paconv.py index e23da71063..3e401a49b8 100644 --- a/mmdet3d/ops/paconv/paconv.py +++ b/mmdet3d/ops/paconv/paconv.py @@ -1,7 +1,7 @@ import copy import torch from mmcv.cnn import (ConvModule, build_activation_layer, build_norm_layer, - constant_init, xavier_init) + constant_init) from torch import nn as nn from torch.nn import functional as F @@ -10,7 +10,7 @@ class ScoreNet(nn.Module): - """ScoreNet that outputs coefficient scores to assemble weight kernels in + r"""ScoreNet that outputs coefficient scores to assemble kernel weights in the weight bank according to the relative position of point pairs. Args: @@ -26,6 +26,13 @@ class ScoreNet(nn.Module): bias (bool | str, optional): If specified as `auto`, it will be decided by the norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise False. Defaults to 'auto'. + + Note: + The official code applies xavier_init to all Conv layers in ScoreNet, + see `PAConv `_. However in our experiments, we + did not find much difference in applying such xavier initialization + or not. So we neglect this initialization in our implementation. """ def __init__(self, @@ -70,13 +77,6 @@ def __init__(self, act_cfg=None, bias=bias)) - def init_weights(self): - """Initialize weights of shared MLP layers.""" - # refer to https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/model/pointnet2/paconv.py#L105 # noqa - for m in self.mlps.modules(): - if isinstance(m, nn.Conv2d): - xavier_init(m) - def forward(self, xyz_features): """Forward. @@ -106,14 +106,14 @@ def forward(self, xyz_features): class PAConv(nn.Module): """Non-CUDA version of PAConv. - PAConv stores a trainable weight bank containing several weight kernels. + PAConv stores a trainable weight bank containing several kernel weights. Given input points and features, it computes coefficient scores to assemble those kernels to form conv kernels, and then runs convolution on the input. Args: in_channels (int): Input channels of point features. out_channels (int): Output channels of point features. - num_kernels (int): Number of weight kernels in the weight bank. + num_kernels (int): Number of kernel weights in the weight bank. norm_cfg (dict, optional): Type of normalization method. Defaults to dict(type='BN2d', momentum=0.1). act_cfg (dict, optional): Type of activation method. @@ -124,7 +124,7 @@ class PAConv(nn.Module): weight_bank_init (str, optional): Init method of weight bank kernels. Can be 'kaiming' or 'xavier'. Defaults to 'kaiming'. kernel_input (str, optional): Input features to be multiplied with - weight kernels. Can be 'identity' or 'w_neighbor'. + kernel weights. Can be 'identity' or 'w_neighbor'. Defaults to 'w_neighbor'. scorenet_cfg (dict, optional): Config of the ScoreNet module, which may contain the following keys and values: @@ -147,7 +147,7 @@ def __init__(self, weight_bank_init='kaiming', kernel_input='w_neighbor', scorenet_cfg=dict( - mlp_channels=[8, 16, 16], + mlp_channels=[16, 16, 16], score_norm='softmax', temp_factor=1.0, last_bn=False)): @@ -156,14 +156,15 @@ def __init__(self, # determine weight kernel size according to used features if kernel_input == 'identity': # only use grouped_features - self.kernel_mul = 1 + kernel_mul = 1 elif kernel_input == 'w_neighbor': # concat of (grouped_features - center_features, grouped_features) - self.kernel_mul = 2 + kernel_mul = 2 else: raise NotImplementedError( f'unsupported kernel_input {kernel_input}') self.kernel_input = kernel_input + in_channels = kernel_mul * in_channels # determine mlp channels in ScoreNet according to used xyz features if scorenet_input == 'identity': @@ -180,7 +181,7 @@ def __init__(self, f'unsupported scorenet_input {scorenet_input}') self.scorenet_input = scorenet_input - # construct weight kernels in weight bank + # construct kernel weights in weight bank # self.weight_bank is of shape [C, num_kernels * out_c] # where C can be in_c or (2 * in_c) if weight_bank_init == 'kaiming': @@ -191,17 +192,17 @@ def __init__(self, raise NotImplementedError( f'unsupported weight bank init method {weight_bank_init}') - self.m = num_kernels + self.num_kernels = num_kernels # the parameter `m` in the paper weight_bank = weight_init( - torch.empty(self.m, in_channels * self.kernel_mul, out_channels)) + torch.empty(self.num_kernels, in_channels, out_channels)) weight_bank = weight_bank.permute(1, 0, 2).reshape( - in_channels * self.kernel_mul, self.m * out_channels).contiguous() + in_channels, self.num_kernels * out_channels).contiguous() self.weight_bank = nn.Parameter(weight_bank, requires_grad=True) # construct ScoreNet scorenet_cfg_ = copy.deepcopy(scorenet_cfg) scorenet_cfg_['mlp_channels'].insert(0, self.scorenet_in_channels) - scorenet_cfg_['mlp_channels'].append(self.m) + scorenet_cfg_['mlp_channels'].append(self.num_kernels) self.scorenet = ScoreNet(**scorenet_cfg_) self.bn = build_norm_layer(norm_cfg, out_channels)[1] if \ @@ -209,13 +210,16 @@ def __init__(self, self.activate = build_activation_layer(act_cfg) if \ act_cfg is not None else None + # set some basic attributes of Conv layers + self.in_channels = in_channels + self.out_channels = out_channels + self.init_weights() def init_weights(self): - """Initialize weights of shared MLP layers.""" - self.scorenet.init_weights() + """Initialize weights of shared MLP layers and BN layers.""" if self.bn is not None: - constant_init(self.bn, val=1) + constant_init(self.bn, val=1, bias=0) def _prepare_scorenet_input(self, points_xyz): """Prepare input point pairs features for self.ScoreNet. @@ -273,14 +277,15 @@ def forward(self, inputs): # prepare features for between each point and its grouping center xyz_features = self._prepare_scorenet_input(points_xyz) - # scores to assemble weight kernels + # scores to assemble kernel weights scores = self.scorenet(xyz_features) # [B, npoint, K, m] # first compute out features over all kernels # features is [B, C, npoint, K], weight_bank is [C, m * out_c] new_features = torch.matmul( - features.permute(0, 2, 3, 1), self.weight_bank).\ - view(B, npoint, K, self.m, -1) # [B, npoint, K, m, out_c] + features.permute(0, 2, 3, 1), + self.weight_bank).view(B, npoint, K, self.num_kernels, + -1) # [B, npoint, K, m, out_c] # then aggregate using scores new_features = assign_score(scores, new_features) @@ -363,13 +368,13 @@ def forward(self, inputs): # prepare features for between each point and its grouping center xyz_features = self._prepare_scorenet_input(points_xyz) - # scores to assemble weight kernels + # scores to assemble kernel weights scores = self.scorenet(xyz_features) # [B, npoint, K, m] # pre-compute features for points and centers separately # features is [B, in_c, N], weight_bank is [C, m * out_dim] point_feat, center_feat = assign_kernel_withoutk( - features, self.weight_bank, self.m) + features, self.weight_bank, self.num_kernels) # aggregate features using custom cuda op new_features = assign_score_cuda( diff --git a/mmdet3d/ops/pointnet_modules/paconv_sa_module.py b/mmdet3d/ops/pointnet_modules/paconv_sa_module.py index 73dedce1e8..4d5ac218f0 100644 --- a/mmdet3d/ops/pointnet_modules/paconv_sa_module.py +++ b/mmdet3d/ops/pointnet_modules/paconv_sa_module.py @@ -15,10 +15,10 @@ class PAConvSAModuleMSG(BasePointSAModule): See the `paper `_ for more details. Args: - paconv_num_kernels (list[list[int]]): Number of weight kernels in the + paconv_num_kernels (list[list[int]]): Number of kernel weights in the weight banks of each layer's PAConv. paconv_kernel_input (str, optional): Input features to be multiplied - with weight kernels. Can be 'identity' or 'w_neighbor'. + with kernel weights. Can be 'identity' or 'w_neighbor'. Defaults to 'w_neighbor'. scorenet_input (str, optional): Type of the input to ScoreNet. Defaults to 'w_neighbor_dist'. Can be the following values: @@ -77,7 +77,7 @@ def __init__(self, assert len(paconv_num_kernels) == len(mlp_channels) for i in range(len(mlp_channels)): assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \ - 'PAConv number of weight kernels wrong' + 'PAConv number of kernel weights wrong' # in PAConv, bias only exists in ScoreNet scorenet_cfg['bias'] = bias @@ -197,7 +197,7 @@ def __init__(self, assert len(paconv_num_kernels) == len(mlp_channels) for i in range(len(mlp_channels)): assert len(paconv_num_kernels[i]) == len(mlp_channels[i]) - 1, \ - 'PAConv number of weight kernels wrong' + 'PAConv number of kernel weights wrong' # in PAConv, bias only exists in ScoreNet scorenet_cfg['bias'] = bias diff --git a/tests/test_metrics/test_losses.py b/tests/test_metrics/test_losses.py index c1b543dfd4..ef24128c9d 100644 --- a/tests/test_metrics/test_losses.py +++ b/tests/test_metrics/test_losses.py @@ -1,5 +1,6 @@ import pytest import torch +from torch import nn as nn def test_chamfer_disrance(): @@ -69,3 +70,41 @@ def test_chamfer_disrance(): or torch.equal(indices1, indices1.new_tensor(expected_inds2))) assert (indices2 == indices2.new_tensor([[0, 0, 0, 0, 0], [0, 3, 6, 0, 0]])).all() + + +def test_paconv_regularization_loss(): + from mmdet3d.models.losses import PAConvRegularizationLoss + from mmdet3d.ops import PAConv, PAConvCUDA + from mmdet.apis import set_random_seed + + class ToyModel(nn.Module): + + def __init__(self): + super(ToyModel, self).__init__() + + self.paconvs = nn.ModuleList() + self.paconvs.append(PAConv(8, 16, 8)) + self.paconvs.append(PAConv(8, 16, 8, kernel_input='identity')) + self.paconvs.append(PAConvCUDA(8, 16, 8)) + + self.conv1 = nn.Conv1d(3, 8, 1) + + set_random_seed(0, True) + model = ToyModel() + + # reduction shoule be in ['none', 'mean', 'sum'] + with pytest.raises(AssertionError): + paconv_corr_loss = PAConvRegularizationLoss(reduction='l2') + + paconv_corr_loss = PAConvRegularizationLoss(reduction='mean') + mean_corr_loss = paconv_corr_loss(model.modules()) + assert mean_corr_loss >= 0 + assert mean_corr_loss.requires_grad + + sum_corr_loss = paconv_corr_loss(model.modules(), reduction_override='sum') + assert torch.allclose(sum_corr_loss, mean_corr_loss * 3) + + none_corr_loss = paconv_corr_loss( + model.modules(), reduction_override='none') + assert none_corr_loss.shape[0] == 3 + assert torch.allclose(none_corr_loss.mean(), mean_corr_loss) diff --git a/tests/test_models/test_common_modules/test_paconv_modules.py b/tests/test_models/test_common_modules/test_paconv_modules.py index b5730f03c2..cb69bac8b8 100644 --- a/tests/test_models/test_common_modules/test_paconv_modules.py +++ b/tests/test_models/test_common_modules/test_paconv_modules.py @@ -37,10 +37,10 @@ def test_paconv_sa_module_msg(): pool_mod='max', paconv_kernel_input='w_neighbor').cuda() - assert self.mlps[0].layer0.weight_bank.shape[0] == 12 * 2 - assert self.mlps[0].layer0.weight_bank.shape[1] == 16 * 4 - assert self.mlps[1].layer0.weight_bank.shape[0] == 12 * 2 - assert self.mlps[1].layer0.weight_bank.shape[1] == 32 * 8 + assert self.mlps[0].layer0.in_channels == 12 * 2 + assert self.mlps[0].layer0.out_channels == 16 + assert self.mlps[1].layer0.in_channels == 12 * 2 + assert self.mlps[1].layer0.out_channels == 32 assert self.mlps[0].layer0.bn.num_features == 16 assert self.mlps[1].layer0.bn.num_features == 32 @@ -80,10 +80,12 @@ def test_paconv_sa_module_msg(): pool_mod='max', paconv_kernel_input='identity').cuda() - assert self.mlps[0].layer0.weight_bank.shape[0] == 12 * 1 - assert self.mlps[0].layer0.weight_bank.shape[1] == 16 * 4 - assert self.mlps[1].layer0.weight_bank.shape[0] == 12 * 1 - assert self.mlps[1].layer0.weight_bank.shape[1] == 32 * 8 + assert self.mlps[0].layer0.in_channels == 12 * 1 + assert self.mlps[0].layer0.out_channels == 16 + assert self.mlps[0].layer0.num_kernels == 4 + assert self.mlps[1].layer0.in_channels == 12 * 1 + assert self.mlps[1].layer0.out_channels == 32 + assert self.mlps[1].layer0.num_kernels == 8 xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32) @@ -116,8 +118,9 @@ def test_paconv_sa_module(): paconv_kernel_input='w_neighbor') self = build_sa_module(sa_cfg).cuda() - assert self.mlps[0].layer0.weight_bank.shape[0] == 15 * 2 - assert self.mlps[0].layer0.weight_bank.shape[1] == 32 * 8 + assert self.mlps[0].layer0.in_channels == 15 * 2 + assert self.mlps[0].layer0.out_channels == 32 + assert self.mlps[0].layer0.num_kernels == 8 xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32) @@ -145,7 +148,7 @@ def test_paconv_sa_module(): pool_mod='max', paconv_kernel_input='identity') self = build_sa_module(sa_cfg).cuda() - assert self.mlps[0].layer0.weight_bank.shape[0] == 15 * 1 + assert self.mlps[0].layer0.in_channels == 15 * 1 xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32) @@ -191,11 +194,13 @@ def test_paconv_cuda_sa_module_msg(): pool_mod='max', paconv_kernel_input='w_neighbor').cuda() - assert self.mlps[0][0].weight_bank.shape[0] == 12 * 2 - assert self.mlps[0][0].weight_bank.shape[1] == 16 * 4 - assert self.mlps[1][0].weight_bank.shape[0] == 12 * 2 - assert self.mlps[1][0].weight_bank.shape[1] == 32 * 8 + assert self.mlps[0][0].in_channels == 12 * 2 + assert self.mlps[0][0].out_channels == 16 + assert self.mlps[0][0].num_kernels == 4 assert self.mlps[0][0].bn.num_features == 16 + assert self.mlps[1][0].in_channels == 12 * 2 + assert self.mlps[1][0].out_channels == 32 + assert self.mlps[1][0].num_kernels == 8 assert self.mlps[1][0].bn.num_features == 32 assert self.mlps[0][0].scorenet.mlps.layer0.conv.in_channels == 7 @@ -253,8 +258,9 @@ def test_paconv_cuda_sa_module(): paconv_kernel_input='w_neighbor') self = build_sa_module(sa_cfg).cuda() - assert self.mlps[0][0].weight_bank.shape[0] == 15 * 2 - assert self.mlps[0][0].weight_bank.shape[1] == 32 * 8 + assert self.mlps[0][0].in_channels == 15 * 2 + assert self.mlps[0][0].out_channels == 32 + assert self.mlps[0][0].num_kernels == 8 xyz = np.fromfile('tests/data/sunrgbd/points/000001.bin', np.float32) diff --git a/tests/test_models/test_common_modules/test_paconv_ops.py b/tests/test_models/test_common_modules/test_paconv_ops.py index 88346ecb57..6add7359db 100644 --- a/tests/test_models/test_common_modules/test_paconv_ops.py +++ b/tests/test_models/test_common_modules/test_paconv_ops.py @@ -193,10 +193,13 @@ def test_paconv(): out_channels = 12 npoint = 4 K = 3 + num_kernels = 4 points_xyz = torch.randn(B, 3, npoint, K) features = torch.randn(B, in_channels, npoint, K) - paconv = PAConv(in_channels, out_channels, 4) + paconv = PAConv(in_channels, out_channels, num_kernels) + assert paconv.weight_bank.shape == torch.Size( + [in_channels * 2, out_channels * num_kernels]) with torch.no_grad(): new_features, _ = paconv((features, points_xyz)) @@ -213,11 +216,14 @@ def test_paconv_cuda(): N = 32 npoint = 4 K = 3 + num_kernels = 4 points_xyz = torch.randn(B, 3, npoint, K).float().cuda() features = torch.randn(B, in_channels, N).float().cuda() points_idx = torch.randint(0, N, (B, npoint, K)).long().cuda() - paconv = PAConvCUDA(in_channels, out_channels, 4).cuda() + paconv = PAConvCUDA(in_channels, out_channels, num_kernels).cuda() + assert paconv.weight_bank.shape == torch.Size( + [in_channels * 2, out_channels * num_kernels]) with torch.no_grad(): new_features, _, _ = paconv((features, points_xyz, points_idx)) diff --git a/tests/test_models/test_heads/test_paconv_decode_head.py b/tests/test_models/test_heads/test_paconv_decode_head.py new file mode 100644 index 0000000000..4b152100e2 --- /dev/null +++ b/tests/test_models/test_heads/test_paconv_decode_head.py @@ -0,0 +1,82 @@ +import numpy as np +import pytest +import torch +from mmcv.cnn.bricks import ConvModule + +from mmdet3d.models.builder import build_head + + +def test_paconv_decode_head_loss(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + paconv_decode_head_cfg = dict( + type='PAConvHead', + fp_channels=((768, 256, 256), (384, 256, 256), (320, 256, 128), + (128 + 6, 128, 128, 128)), + channels=128, + num_classes=20, + dropout_ratio=0.5, + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU'), + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + loss_weight=1.0), + ignore_index=20) + + self = build_head(paconv_decode_head_cfg) + self.cuda() + assert isinstance(self.conv_seg, torch.nn.Conv1d) + assert self.conv_seg.in_channels == 128 + assert self.conv_seg.out_channels == 20 + assert self.conv_seg.kernel_size == (1, ) + assert isinstance(self.pre_seg_conv, ConvModule) + assert isinstance(self.pre_seg_conv.conv, torch.nn.Conv1d) + assert self.pre_seg_conv.conv.in_channels == 128 + assert self.pre_seg_conv.conv.out_channels == 128 + assert self.pre_seg_conv.conv.kernel_size == (1, ) + assert isinstance(self.pre_seg_conv.bn, torch.nn.BatchNorm1d) + assert self.pre_seg_conv.bn.num_features == 128 + assert isinstance(self.pre_seg_conv.activate, torch.nn.ReLU) + + # test forward + sa_xyz = [ + torch.rand(2, 4096, 3).float().cuda(), + torch.rand(2, 1024, 3).float().cuda(), + torch.rand(2, 256, 3).float().cuda(), + torch.rand(2, 64, 3).float().cuda(), + torch.rand(2, 16, 3).float().cuda(), + ] + sa_features = [ + torch.rand(2, 6, 4096).float().cuda(), + torch.rand(2, 64, 1024).float().cuda(), + torch.rand(2, 128, 256).float().cuda(), + torch.rand(2, 256, 64).float().cuda(), + torch.rand(2, 512, 16).float().cuda(), + ] + input_dict = dict(sa_xyz=sa_xyz, sa_features=sa_features) + seg_logits = self(input_dict) + assert seg_logits.shape == torch.Size([2, 20, 4096]) + + # test loss + pts_semantic_mask = torch.randint(0, 20, (2, 4096)).long().cuda() + losses = self.losses(seg_logits, pts_semantic_mask) + assert losses['loss_sem_seg'].item() > 0 + + # test loss with ignore_index + ignore_index_mask = torch.ones_like(pts_semantic_mask) * 20 + losses = self.losses(seg_logits, ignore_index_mask) + assert losses['loss_sem_seg'].item() == 0 + + # test loss with class_weight + paconv_decode_head_cfg['loss_decode'] = dict( + type='CrossEntropyLoss', + use_sigmoid=False, + class_weight=np.random.rand(20), + loss_weight=1.0) + self = build_head(paconv_decode_head_cfg) + self.cuda() + losses = self.losses(seg_logits, pts_semantic_mask) + assert losses['loss_sem_seg'].item() > 0 diff --git a/tests/test_models/test_segmentors.py b/tests/test_models/test_segmentors.py index 87c5f55434..e10af68382 100644 --- a/tests/test_models/test_segmentors.py +++ b/tests/test_models/test_segmentors.py @@ -159,3 +159,147 @@ def test_pointnet2_msg(): results = self.aug_test(scene_points, img_metas) assert results[0]['semantic_mask'].shape == torch.Size([500]) assert results[1]['semantic_mask'].shape == torch.Size([200]) + + +def test_paconv_ssg(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + + set_random_seed(0, True) + paconv_ssg_cfg = _get_segmentor_cfg( + 'paconv/paconv_ssg_8x2_step_100e_s3dis_seg-3d-13class.py') + # for GPU memory consideration + paconv_ssg_cfg.backbone.num_points = (256, 64, 16, 4) + paconv_ssg_cfg.test_cfg.num_points = 32 + self = build_segmentor(paconv_ssg_cfg).cuda() + points = [torch.rand(1024, 9).float().cuda() for _ in range(2)] + img_metas = [dict(), dict()] + gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)] + + # test forward_train + losses = self.forward_train(points, img_metas, gt_masks) + assert losses['decode.loss_sem_seg'].item() >= 0 + assert losses['regularize.loss_regularize'].item() >= 0 + + # test forward function + set_random_seed(0, True) + data_dict = dict( + points=points, img_metas=img_metas, pts_semantic_mask=gt_masks) + forward_losses = self.forward(return_loss=True, **data_dict) + assert np.allclose(losses['decode.loss_sem_seg'].item(), + forward_losses['decode.loss_sem_seg'].item()) + assert np.allclose(losses['regularize.loss_regularize'].item(), + forward_losses['regularize.loss_regularize'].item()) + + # test loss with ignore_index + ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)] + losses = self.forward_train(points, img_metas, ignore_masks) + assert losses['decode.loss_sem_seg'].item() == 0 + + # test simple_test + self.eval() + with torch.no_grad(): + scene_points = [ + torch.randn(200, 6).float().cuda() * 3.0, + torch.randn(100, 6).float().cuda() * 2.5 + ] + results = self.simple_test(scene_points, img_metas) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test forward function calling simple_test + with torch.no_grad(): + data_dict = dict(points=[scene_points], img_metas=[img_metas]) + results = self.forward(return_loss=False, **data_dict) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test aug_test + with torch.no_grad(): + scene_points = [ + torch.randn(2, 200, 6).float().cuda() * 3.0, + torch.randn(2, 100, 6).float().cuda() * 2.5 + ] + img_metas = [[dict(), dict()], [dict(), dict()]] + results = self.aug_test(scene_points, img_metas) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test forward function calling aug_test + with torch.no_grad(): + data_dict = dict(points=scene_points, img_metas=img_metas) + results = self.forward(return_loss=False, **data_dict) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + +def test_paconv_cuda_ssg(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + + set_random_seed(0, True) + paconv_cuda_ssg_cfg = _get_segmentor_cfg( + 'paconv/paconv_ssg_8x2_step_100e_s3dis_seg-3d-13class.py') + # for GPU memory consideration + paconv_cuda_ssg_cfg.backbone.num_points = (256, 64, 16, 4) + paconv_cuda_ssg_cfg.test_cfg.num_points = 32 + self = build_segmentor(paconv_cuda_ssg_cfg).cuda() + points = [torch.rand(1024, 9).float().cuda() for _ in range(2)] + img_metas = [dict(), dict()] + gt_masks = [torch.randint(0, 13, (1024, )).long().cuda() for _ in range(2)] + + # test forward_train + losses = self.forward_train(points, img_metas, gt_masks) + assert losses['decode.loss_sem_seg'].item() >= 0 + assert losses['regularize.loss_regularize'].item() >= 0 + + # test forward function + set_random_seed(0, True) + data_dict = dict( + points=points, img_metas=img_metas, pts_semantic_mask=gt_masks) + forward_losses = self.forward(return_loss=True, **data_dict) + assert np.allclose(losses['decode.loss_sem_seg'].item(), + forward_losses['decode.loss_sem_seg'].item()) + assert np.allclose(losses['regularize.loss_regularize'].item(), + forward_losses['regularize.loss_regularize'].item()) + + # test loss with ignore_index + ignore_masks = [torch.ones_like(gt_masks[0]) * 13 for _ in range(2)] + losses = self.forward_train(points, img_metas, ignore_masks) + assert losses['decode.loss_sem_seg'].item() == 0 + + # test simple_test + self.eval() + with torch.no_grad(): + scene_points = [ + torch.randn(200, 6).float().cuda() * 3.0, + torch.randn(100, 6).float().cuda() * 2.5 + ] + results = self.simple_test(scene_points, img_metas) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test forward function calling simple_test + with torch.no_grad(): + data_dict = dict(points=[scene_points], img_metas=[img_metas]) + results = self.forward(return_loss=False, **data_dict) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test aug_test + with torch.no_grad(): + scene_points = [ + torch.randn(2, 200, 6).float().cuda() * 3.0, + torch.randn(2, 100, 6).float().cuda() * 2.5 + ] + img_metas = [[dict(), dict()], [dict(), dict()]] + results = self.aug_test(scene_points, img_metas) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) + + # test forward function calling aug_test + with torch.no_grad(): + data_dict = dict(points=scene_points, img_metas=img_metas) + results = self.forward(return_loss=False, **data_dict) + assert results[0]['semantic_mask'].shape == torch.Size([200]) + assert results[1]['semantic_mask'].shape == torch.Size([100]) diff --git a/tools/test.py b/tools/test.py index 0e6d65638f..bf65c265bf 100644 --- a/tools/test.py +++ b/tools/test.py @@ -178,6 +178,11 @@ def main(): model.CLASSES = checkpoint['meta']['CLASSES'] else: model.CLASSES = dataset.CLASSES + # palette for visualization in segmentation tasks + if 'PALETTE' in checkpoint.get('meta', {}): + model.PALETTE = checkpoint['meta']['PALETTE'] + else: + model.PALETTE = dataset.PALETTE if not distributed: model = MMDataParallel(model, device_ids=[0])