diff --git a/configs/_base_/models/minkunet.py b/configs/_base_/models/minkunet.py
new file mode 100644
index 0000000000..0a691d876d
--- /dev/null
+++ b/configs/_base_/models/minkunet.py
@@ -0,0 +1,29 @@
+model = dict(
+    type='MinkUNet',
+    data_preprocessor=dict(
+        type='Det3DDataPreprocessor',
+        voxel=True,
+        voxel_type='minkunet',
+        voxel_layer=dict(
+            max_num_points=-1,
+            point_cloud_range=[-100, -100, -20, 100, 100, 20],
+            voxel_size=[0.05, 0.05, 0.05],
+            max_voxels=(-1, -1)),
+    ),
+    backbone=dict(
+        type='MinkUNetBackbone',
+        in_channels=4,
+        base_channels=32,
+        encoder_channels=[32, 64, 128, 256],
+        decoder_channels=[256, 128, 96, 96],
+        num_stages=4,
+        init_cfg=None),
+    decode_head=dict(
+        type='MinkUNetHead',
+        channels=96,
+        num_classes=19,
+        dropout_ratio=0,
+        loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
+        ignore_index=19),
+    train_cfg=dict(),
+    test_cfg=dict())
diff --git a/configs/minkunet/minkunet_w16_8xb2-15e_semantickitti.py b/configs/minkunet/minkunet_w16_8xb2-15e_semantickitti.py
new file mode 100644
index 0000000000..ac450bf03a
--- /dev/null
+++ b/configs/minkunet/minkunet_w16_8xb2-15e_semantickitti.py
@@ -0,0 +1,13 @@
+_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
+
+model = dict(
+    backbone=dict(
+        base_channels=16,
+        encoder_channels=[16, 32, 64, 128],
+        decoder_channels=[128, 64, 48, 48]),
+    decode_head=dict(channels=48))
+
+# NOTE: With the TorchSparse backend, model performance is sensitive to the
+# random seed; if no seed is specified, results can vary by about ±1.5 mIoU.
+randomness = dict(seed=1588147245)
diff --git a/configs/minkunet/minkunet_w20_8xb2-15e_semantickitti.py b/configs/minkunet/minkunet_w20_8xb2-15e_semantickitti.py
new file mode 100644
index 0000000000..34c501f52a
--- /dev/null
+++ b/configs/minkunet/minkunet_w20_8xb2-15e_semantickitti.py
@@ -0,0 +1,8 @@
+_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']
+
+model = dict(
+    backbone=dict(
+        base_channels=20,
+        encoder_channels=[20, 40, 81, 163],
+        decoder_channels=[163, 81, 61, 61]),
+    decode_head=dict(channels=61))
diff --git a/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py b/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
new file mode 100644
index 0000000000..80f5283ce8
--- /dev/null
+++ b/configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
@@ -0,0 +1,54 @@
+_base_ = [
+    '../_base_/datasets/semantickitti.py', '../_base_/models/minkunet.py',
+    '../_base_/default_runtime.py'
+]
+
+train_pipeline = [
+    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
+    dict(
+        type='LoadAnnotations3D',
+        with_bbox_3d=False,
+        with_label_3d=False,
+        with_seg_3d=True,
+        seg_3d_dtype='np.int32',
+        seg_offset=2**16,
+        dataset_type='semantickitti'),
+    dict(type='PointSegClassMapping'),
+    dict(
+        type='GlobalRotScaleTrans',
+        rot_range=[0., 6.28318531],
+        scale_ratio_range=[0.95, 1.05],
+        translation_std=[0, 0, 0],
+    ),
+    dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
+]
+
+train_dataloader = dict(
+    sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
+
+lr = 0.24
+optim_wrapper = dict(
+    type='AmpOptimWrapper',
+    loss_scale='dynamic',
+    optimizer=dict(
+        type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))
+
+param_scheduler = [
+    dict(
+        type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
+    dict(
+        type='CosineAnnealingLR',
+        begin=0,
+        T_max=15,
+        by_epoch=True,
+        eta_min=1e-5,
+        convert_to_iter_based=True)
+]
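+
+# Schedule note: LinearLR above warms the learning rate up over the first 125
+# iterations, then CosineAnnealingLR (converted to iteration-based stepping)
+# decays it from lr=0.24 to eta_min=1e-5 across the 15 training epochs.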
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+
+default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
+randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
+env_cfg = dict(cudnn_benchmark=True)
diff --git a/mmdet3d/models/backbones/__init__.py b/mmdet3d/models/backbones/__init__.py
index bd7dc04ad4..02bd6787a4 100644
--- a/mmdet3d/models/backbones/__init__.py
+++ b/mmdet3d/models/backbones/__init__.py
@@ -5,6 +5,7 @@
 from .dgcnn import DGCNNBackbone
 from .dla import DLANet
 from .mink_resnet import MinkResNet
+from .minkunet_backbone import MinkUNetBackbone
 from .multi_backbone import MultiBackbone
 from .nostem_regnet import NoStemRegNet
 from .pointnet2_sa_msg import PointNet2SAMSG
@@ -14,5 +15,6 @@
 __all__ = [
     'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
     'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
-    'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
+    'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
+    'MinkUNetBackbone'
 ]
diff --git a/mmdet3d/models/backbones/minkunet_backbone.py b/mmdet3d/models/backbones/minkunet_backbone.py
new file mode 100644
index 0000000000..22a725c50f
--- /dev/null
+++ b/mmdet3d/models/backbones/minkunet_backbone.py
@@ -0,0 +1,121 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List
+
+from mmengine.model import BaseModule
+from mmengine.registry import MODELS
+from torch import Tensor, nn
+
+from mmdet3d.models.layers import (TorchSparseConvModule,
+                                   TorchSparseResidualBlock)
+from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
+from mmdet3d.utils import OptMultiConfig
+
+if IS_TORCHSPARSE_AVAILABLE:
+    import torchsparse
+    from torchsparse.tensor import SparseTensor
+else:
+    SparseTensor = None
+
+
+@MODELS.register_module()
+class MinkUNetBackbone(BaseModule):
+    r"""MinkUNet backbone with TorchSparse backend.
+
+    Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.
+
+    Args:
+        in_channels (int): Number of input voxel feature channels.
+            Defaults to 4.
+        base_channels (int): Number of input channels of the first encoder
+            layer. Defaults to 32.
+        encoder_channels (List[int]): Convolutional channels of each encoder
+            layer. Defaults to [32, 64, 128, 256].
+        decoder_channels (List[int]): Convolutional channels of each decoder
+            layer. Defaults to [256, 128, 96, 96].
+        num_stages (int): Number of stages in the encoder and decoder.
+            Defaults to 4.
+        init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`],
+            optional): Initialization config dict. Defaults to None.
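+
+    Example:
+        >>> # Minimal sketch with random inputs (values illustrative only;
+        >>> # requires CUDA and torchsparse):
+        >>> import torch
+        >>> backbone = MinkUNetBackbone().cuda()
+        >>> feats = torch.rand(100, 4).cuda()
+        >>> # coors columns: (x_idx, y_idx, z_idx, batch_idx)
+        >>> coors = torch.randint(0, 10, (100, 4)).int().cuda()
+        >>> coors[:, -1] = 0  # put all points in batch 0
+        >>> out = backbone(feats, coors)  # SparseTensor; out.F: (100, 96)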
+ """ + + def __init__(self, + in_channels: int = 4, + base_channels: int = 32, + encoder_channels: List[int] = [32, 64, 128, 256], + decoder_channels: List[int] = [256, 128, 96, 96], + num_stages: int = 4, + init_cfg: OptMultiConfig = None) -> None: + super().__init__(init_cfg) + assert num_stages == len(encoder_channels) == len(decoder_channels) + self.num_stages = num_stages + self.conv_input = nn.Sequential( + TorchSparseConvModule(in_channels, base_channels, kernel_size=3), + TorchSparseConvModule(base_channels, base_channels, kernel_size=3)) + self.encoder = nn.ModuleList() + self.decoder = nn.ModuleList() + + encoder_channels.insert(0, base_channels) + decoder_channels.insert(0, encoder_channels[-1]) + for i in range(num_stages): + self.encoder.append( + nn.Sequential( + TorchSparseConvModule( + encoder_channels[i], + encoder_channels[i], + kernel_size=2, + stride=2), + TorchSparseResidualBlock( + encoder_channels[i], + encoder_channels[i + 1], + kernel_size=3), + TorchSparseResidualBlock( + encoder_channels[i + 1], + encoder_channels[i + 1], + kernel_size=3))) + + self.decoder.append( + nn.ModuleList([ + TorchSparseConvModule( + decoder_channels[i], + decoder_channels[i + 1], + kernel_size=2, + stride=2, + transposed=True), + nn.Sequential( + TorchSparseResidualBlock( + decoder_channels[i + 1] + encoder_channels[-2 - i], + decoder_channels[i + 1], + kernel_size=3), + TorchSparseResidualBlock( + decoder_channels[i + 1], + decoder_channels[i + 1], + kernel_size=3)) + ])) + + def forward(self, voxel_features: Tensor, coors: Tensor) -> SparseTensor: + """Forward function. + + Args: + voxel_features (Tensor): Voxel features in shape (N, C). + coors (Tensor): Coordinates in shape (N, 4), + the columns in the order of (x_idx, y_idx, z_idx, batch_idx). + + Returns: + SparseTensor: Backbone features. 
+ """ + x = torchsparse.SparseTensor(voxel_features, coors) + x = self.conv_input(x) + laterals = [x] + for encoder_layer in self.encoder: + x = encoder_layer(x) + laterals.append(x) + laterals = laterals[:-1][::-1] + + decoder_outs = [] + for i, decoder_layer in enumerate(self.decoder): + x = decoder_layer[0](x) + x = torchsparse.cat((x, laterals[i])) + x = decoder_layer[1](x) + decoder_outs.append(x) + + return decoder_outs[-1] diff --git a/mmdet3d/models/data_preprocessors/data_preprocessor.py b/mmdet3d/models/data_preprocessors/data_preprocessor.py index 4e2a506a2f..85286c9b62 100644 --- a/mmdet3d/models/data_preprocessors/data_preprocessor.py +++ b/mmdet3d/models/data_preprocessors/data_preprocessor.py @@ -415,6 +415,33 @@ def voxelize(self, points: List[torch.Tensor], coors.append(res_coors) voxels = torch.cat(voxels, dim=0) coors = torch.cat(coors, dim=0) + elif self.voxel_type == 'minkunet': + voxels, coors = [], [] + voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size) + for i, (res, data_sample) in enumerate(zip(points, data_samples)): + res_coors = torch.round(res[:, :3] / voxel_size).int() + res_coors -= res_coors.min(0)[0] + + res_coors_numpy = res_coors.cpu().numpy() + inds, voxel2point_map = self.sparse_quantize( + res_coors_numpy, return_index=True, return_inverse=True) + voxel2point_map = torch.from_numpy(voxel2point_map).cuda() + if self.training: + if len(inds) > 80000: + inds = np.random.choice(inds, 80000, replace=False) + inds = torch.from_numpy(inds).cuda() + data_sample.gt_pts_seg.voxel_semantic_mask \ + = data_sample.gt_pts_seg.pts_semantic_mask[inds] + res_voxel_coors = res_coors[inds] + res_voxels = res[inds] + res_voxel_coors = F.pad( + res_voxel_coors, (0, 1), mode='constant', value=i) + data_sample.voxel2point_map = voxel2point_map.long() + voxels.append(res_voxels) + coors.append(res_voxel_coors) + voxels = torch.cat(voxels, dim=0) + coors = torch.cat(coors, dim=0) + else: raise ValueError(f'Invalid voxelization type {self.voxel_type}') @@ -445,3 +472,53 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList): _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor, res_coors, 'mean', True) data_sample.gt_pts_seg.point2voxel_map = point2voxel_map + + def ravel_hash(self, x: np.ndarray) -> np.ndarray: + """Get voxel coordinates hash for np.unique(). + + Args: + x (np.ndarray): The voxel coordinates of points, Nx3. + + Returns: + np.ndarray: Voxels coordinates hash. + """ + assert x.ndim == 2, x.shape + + x = x - np.min(x, axis=0) + x = x.astype(np.uint64, copy=False) + xmax = np.max(x, axis=0).astype(np.uint64) + 1 + + h = np.zeros(x.shape[0], dtype=np.uint64) + for k in range(x.shape[1] - 1): + h += x[:, k] + h *= xmax[k + 1] + h += x[:, -1] + return h + + def sparse_quantize(self, + coords: np.ndarray, + return_index: bool = False, + return_inverse: bool = False) -> List[np.ndarray]: + """Sparse Quantization for voxel coordinates used in Minkunet. + + Args: + coords (np.ndarray): The voxel coordinates of points, Nx3. + return_index (bool): Whether to return the indices of the + unique coords, shape (M,). + return_inverse (bool): Whether to return the indices of the + original coords shape (N,). + + Returns: + List[np.ndarray] or None: Return index and inverse map if + return_index and return_inverse is True. 
+ """ + _, indices, inverse_indices = np.unique( + self.ravel_hash(coords), return_index=True, return_inverse=True) + coords = coords[indices] + + outputs = [] + if return_index: + outputs += [indices] + if return_inverse: + outputs += [inverse_indices] + return outputs diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index 2a1f07a338..f7560e5a64 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,7 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cylinder3d_head import Cylinder3DHead from .dgcnn_head import DGCNNHead +from .minkunet_head import MinkUNetHead from .paconv_head import PAConvHead from .pointnet2_head import PointNet2Head -__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead'] +__all__ = [ + 'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead', + 'MinkUNetHead' +] diff --git a/mmdet3d/models/decode_heads/minkunet_head.py b/mmdet3d/models/decode_heads/minkunet_head.py new file mode 100644 index 0000000000..97d8fdf59f --- /dev/null +++ b/mmdet3d/models/decode_heads/minkunet_head.py @@ -0,0 +1,80 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List + +import torch +from torch import Tensor +from torch import nn as nn + +from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList +from .decode_head import Base3DDecodeHead + +if IS_TORCHSPARSE_AVAILABLE: + from torchsparse import SparseTensor +else: + SparseTensor = None + + +@MODELS.register_module() +class MinkUNetHead(Base3DDecodeHead): + r"""MinkUNet decoder head with TorchSparse backend. + + Refer to `implementation code `_. + + Args: + channels (int): The input channel of conv_seg. + num_classes (int): Number of classes. + """ + + def __init__(self, channels: int, num_classes: int, **kwargs) -> None: + super().__init__(channels, num_classes, **kwargs) + + def build_conv_seg(self, channels: int, num_classes: int, + kernel_size: int) -> nn.Module: + """Build Convolutional Segmentation Layers.""" + return nn.Linear(channels, num_classes) + + def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: + """Concat voxel-wise Groud Truth.""" + gt_semantic_segs = [ + data_sample.gt_pts_seg.voxel_semantic_mask + for data_sample in batch_data_samples + ] + return torch.cat(gt_semantic_segs) + + def predict(self, inputs: SparseTensor, + batch_data_samples: SampleList) -> List[Tensor]: + """Forward function for testing. + + Args: + inputs (SparseTensor): Features from backone. + batch_data_samples (List[:obj:`Det3DDataSample`]): The seg + data samples. + + Returns: + List[Tensor]: The segmentation prediction mask of each batch. + """ + seg_logits = self.forward(inputs) + + batch_idx = inputs.C[:, -1] + seg_logit_list = [] + for i, data_sample in enumerate(batch_data_samples): + seg_logit = seg_logits[batch_idx == i] + seg_logit = seg_logit[data_sample.voxel2point_map] + seg_logit_list.append(seg_logit) + + return seg_logit_list + + def forward(self, x: SparseTensor) -> Tensor: + """Forward function. + + Args: + x (SparseTensor): Features from backbone. + + Returns: + Tensor: Segmentation map of shape [N, C]. + Note that output contains all points from each batch. 
+ """ + output = self.cls_seg(x.F) + return output diff --git a/mmdet3d/models/segmentors/__init__.py b/mmdet3d/models/segmentors/__init__.py index bc458d311f..5e43985d82 100644 --- a/mmdet3d/models/segmentors/__init__.py +++ b/mmdet3d/models/segmentors/__init__.py @@ -2,5 +2,6 @@ from .base import Base3DSegmentor from .cylinder3d import Cylinder3D from .encoder_decoder import EncoderDecoder3D +from .minkunet import MinkUNet -__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D'] +__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet'] diff --git a/mmdet3d/models/segmentors/minkunet.py b/mmdet3d/models/segmentors/minkunet.py new file mode 100644 index 0000000000..fcf8c22421 --- /dev/null +++ b/mmdet3d/models/segmentors/minkunet.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from torch import Tensor + +from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import OptSampleList, SampleList +from .encoder_decoder import EncoderDecoder3D + +if IS_TORCHSPARSE_AVAILABLE: + from torchsparse import SparseTensor +else: + SparseTensor = None + + +@MODELS.register_module() +class MinkUNet(EncoderDecoder3D): + r"""MinkUNet is the implementation of `4D Spatio-Temporal ConvNets. + `_ with TorchSparse backend. + + Refer to `implementation code `_. + + Args: + kwargs (dict): Arguments are the same as those in + :class:`EncoderDecoder3D`. + """ + + def __init__(self, **kwargs) -> None: + if not IS_TORCHSPARSE_AVAILABLE: + raise ImportError( + 'Please follow `get_started.md` to install Torchsparse.`') + super().__init__(**kwargs) + + def loss(self, inputs: dict, data_samples: SampleList): + """Calculate losses from a batch of inputs and data samples. + + Args: + batch_inputs_dict (dict): Input sample dict which + includes 'points' and 'voxels' keys. + + - points (List[Tensor]): Point cloud of each sample. + - voxels (dict): Voxel feature and coords after voxelization. + batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. + + Returns: + Dict[str, Tensor]: A dictionary of loss components. + """ + x = self.extract_feat(inputs) + losses = self.decode_head.loss(x, data_samples, self.train_cfg) + return losses + + def predict(self, inputs: dict, data_samples: SampleList) -> SampleList: + """Simple test with single scene. + + Args: + batch_inputs_dict (dict): Input sample dict which + includes 'points' and 'voxels' keys. + + - points (List[Tensor]): Point cloud of each sample. + - voxels (dict): Voxel feature and coords after voxelization. + batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data + samples. It usually includes information such as `metainfo` and + `gt_pts_seg`. + + Returns: + List[:obj:`Det3DDataSample`]: Segmentation results of the input + points. Each Det3DDataSample usually contains: + + - ``pred_pts_seg`` (PixelData): Prediction of 3D semantic + segmentation. + """ + x = self.extract_feat(inputs) + seg_logits = self.decode_head.predict(x, data_samples) + seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits] + + return self.postprocess_result(seg_preds, data_samples) + + def _forward(self, + batch_inputs_dict: dict, + batch_data_samples: OptSampleList = None) -> Tensor: + """Network forward process. + + Args: + batch_inputs_dict (dict): Input sample dict which + includes 'points' and 'voxels' keys. 
+
+                - points (List[Tensor]): Point cloud of each sample.
+                - voxels (dict): Voxel feature and coords after voxelization.
+            batch_data_samples (List[:obj:`Det3DDataSample`]): The seg data
+                samples. It usually includes information such as `metainfo`
+                and `gt_pts_seg`. Defaults to None.
+
+        Returns:
+            Tensor: Forward output of model without any post-processes.
+        """
+        x = self.extract_feat(batch_inputs_dict)
+        return self.decode_head.forward(x)
+
+    def extract_feat(self, batch_inputs_dict: dict) -> SparseTensor:
+        """Extract features from voxels.
+
+        Args:
+            batch_inputs_dict (dict): Input sample dict which
+                includes 'points' and 'voxels' keys.
+
+                - points (List[Tensor]): Point cloud of each sample.
+                - voxels (dict): Voxel feature and coords after voxelization.
+
+        Returns:
+            SparseTensor: Sparse voxel features from the backbone (and the
+                neck, if one is configured).
+        """
+        voxel_dict = batch_inputs_dict['voxels']
+        x = self.backbone(voxel_dict['voxels'], voxel_dict['coors'])
+        if self.with_neck:
+            x = self.neck(x)
+        return x
diff --git a/mmdet3d/testing/model_utils.py b/mmdet3d/testing/model_utils.py
index 505c0c0e06..da449398d6 100644
--- a/mmdet3d/testing/model_utils.py
+++ b/mmdet3d/testing/model_utils.py
@@ -84,6 +84,7 @@ def create_detector_inputs(seed=0,
                            gt_bboxes_dim=7,
                            with_pts_semantic_mask=False,
                            with_pts_instance_mask=False,
+                           with_eval_ann_info=False,
                            bboxes_3d_type='lidar'):
     setup_seed(seed)
     assert bboxes_3d_type in ('lidar', 'depth', 'cam')
@@ -145,5 +146,9 @@ def create_detector_inputs(seed=0,
     if with_pts_semantic_mask:
         pts_semantic_mask = torch.randint(0, num_classes, [num_points])
         data_sample.gt_pts_seg['pts_semantic_mask'] = pts_semantic_mask
+    if with_eval_ann_info:
+        data_sample.eval_ann_info = dict()
+    else:
+        data_sample.eval_ann_info = None
 
     return dict(inputs=inputs_dict, data_samples=[data_sample])
diff --git a/tests/test_models/test_backbones/test_minkunet_backbone.py b/tests/test_models/test_backbones/test_minkunet_backbone.py
new file mode 100644
index 0000000000..086c269a9f
--- /dev/null
+++ b/tests/test_models/test_backbones/test_minkunet_backbone.py
@@ -0,0 +1,34 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import pytest
+import torch
+import torch.nn.functional as F
+
+from mmdet3d.registry import MODELS
+
+
+def test_minkunet_backbone():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+
+    try:
+        import torchsparse  # noqa: F401
+    except ImportError:
+        pytest.skip('test requires Torchsparse installation')
+
+    coordinates, features = [], []
+    for i in range(2):
+        c = torch.randint(0, 10, (100, 3)).int()
+        # Pad the batch index as the last coordinate column.
+        c = F.pad(c, (0, 1), mode='constant', value=i)
+        coordinates.append(c)
+        f = torch.rand(100, 4)
+        features.append(f)
+    features = torch.cat(features, dim=0).cuda()
+    coordinates = torch.cat(coordinates, dim=0).cuda()
+
+    cfg = dict(type='MinkUNetBackbone')
+    model = MODELS.build(cfg).cuda()
+    model.init_weights()
+
+    y = model(features, coordinates)
+    assert y.F.shape == torch.Size([200, 96])
+    assert y.C.shape == torch.Size([200, 4])
diff --git a/tests/test_models/test_decode_heads/test_minkunet_head.py b/tests/test_models/test_decode_heads/test_minkunet_head.py
new file mode 100644
index 0000000000..c684565ded
--- /dev/null
+++ b/tests/test_models/test_decode_heads/test_minkunet_head.py
@@ -0,0 +1,54 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest import TestCase
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from mmdet3d.models.decode_heads import MinkUNetHead
+from mmdet3d.structures import Det3DDataSample, PointData
+
+
+class TestMinkUNetHead(TestCase):
+
+    def test_minkunet_head_loss(self):
+        """Tests MinkUNet head loss."""
+
+        try:
+            import torchsparse
+        except ImportError:
+            pytest.skip('test requires Torchsparse installation')
+        if torch.cuda.is_available():
+            minkunet_head = MinkUNetHead(channels=4, num_classes=19)
+
+            minkunet_head.cuda()
+            coordinates, features = [], []
+            for i in range(2):
+                c = torch.randint(0, 10, (100, 3)).int()
+                # Pad the batch index as the last coordinate column.
+                c = F.pad(c, (0, 1), mode='constant', value=i)
+                coordinates.append(c)
+                f = torch.rand(100, 4)
+                features.append(f)
+            features = torch.cat(features, dim=0).cuda()
+            coordinates = torch.cat(coordinates, dim=0).cuda()
+            x = torchsparse.SparseTensor(feats=features, coords=coordinates)
+
+            # Test forward
+            seg_logits = minkunet_head.forward(x)
+
+            self.assertEqual(seg_logits.shape, torch.Size([200, 19]))
+
+            # When truth is non-empty then losses
+            # should be nonzero for random inputs
+            voxel_semantic_mask = torch.randint(0, 19, (100, )).long().cuda()
+            gt_pts_seg = PointData(voxel_semantic_mask=voxel_semantic_mask)
+
+            datasample = Det3DDataSample()
+            datasample.gt_pts_seg = gt_pts_seg
+
+            gt_losses = minkunet_head.loss(x, [datasample, datasample], {})
+
+            gt_sem_seg_loss = gt_losses['loss_sem_seg'].item()
+
+            self.assertGreater(gt_sem_seg_loss, 0,
+                               'semantic seg loss should be positive')
diff --git a/tests/test_models/test_segmentor/test_minkunet.py b/tests/test_models/test_segmentor/test_minkunet.py
new file mode 100644
index 0000000000..16312c293e
--- /dev/null
+++ b/tests/test_models/test_segmentor/test_minkunet.py
@@ -0,0 +1,46 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import unittest
+
+import pytest
+import torch
+from mmengine import DefaultScope
+
+from mmdet3d.registry import MODELS
+from mmdet3d.testing import (create_detector_inputs, get_detector_cfg,
+                             setup_seed)
+
+
+class TestMinkUNet(unittest.TestCase):
+
+    def test_minkunet(self):
+        try:
+            import torchsparse  # noqa
+        except ImportError:
+            pytest.skip('test requires Torchsparse installation')
+
+        import mmdet3d.models
+
+        assert hasattr(mmdet3d.models, 'MinkUNet')
+        DefaultScope.get_instance('test_minkunet', scope_name='mmdet3d')
+        setup_seed(0)
+        model_cfg = get_detector_cfg('_base_/models/minkunet.py')
+        model = MODELS.build(model_cfg)
+        num_gt_instance = 3
+        packed_inputs = create_detector_inputs(
+            num_gt_instance=num_gt_instance,
+            num_classes=19,
+            with_pts_semantic_mask=True)
+
+        if torch.cuda.is_available():
+            model = model.cuda()
+            # test simple_test
+            with torch.no_grad():
+                data = model.data_preprocessor(packed_inputs, True)
+                torch.cuda.empty_cache()
+                results = model.forward(**data, mode='predict')
+            self.assertEqual(len(results), 1)
+            self.assertIn('pts_semantic_mask', results[0].pred_pts_seg)
+
+            losses = model.forward(**data, mode='loss')
+
+            self.assertGreater(losses['loss_sem_seg'], 0)
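For reference, a minimal end-to-end sketch of how the pieces above fit together
(mirroring tests/test_models/test_segmentor/test_minkunet.py; requires CUDA and
torchsparse, and uses only the test utilities shipped in this diff):

    import torch
    from mmengine import DefaultScope
    from mmdet3d.registry import MODELS
    from mmdet3d.testing import create_detector_inputs, get_detector_cfg

    DefaultScope.get_instance('minkunet_demo', scope_name='mmdet3d')
    # Build the segmentor from the base config added above.
    model = MODELS.build(get_detector_cfg('_base_/models/minkunet.py')).cuda()
    packed_inputs = create_detector_inputs(
        num_gt_instance=3, num_classes=19, with_pts_semantic_mask=True)
    with torch.no_grad():
        # Voxelize ('minkunet' voxel type), then run voxel-to-point prediction.
        data = model.data_preprocessor(packed_inputs, True)
        results = model.forward(**data, mode='predict')
    print(results[0].pred_pts_seg.pts_semantic_mask)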