diff --git a/mmdet3d/models/middle_encoders/sparse_encoder.py b/mmdet3d/models/middle_encoders/sparse_encoder.py
index 7f372b60ff..00462aad01 100644
--- a/mmdet3d/models/middle_encoders/sparse_encoder.py
+++ b/mmdet3d/models/middle_encoders/sparse_encoder.py
@@ -1,6 +1,6 @@
 from torch import nn as nn
 
-from mmdet3d.ops import make_sparse_convmodule
+from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
 from mmdet3d.ops import spconv as spconv
 from ..registry import MIDDLE_ENCODERS
 
@@ -12,12 +12,19 @@ class SparseEncoder(nn.Module):
     Args:
         in_channels (int): The number of input channels.
         sparse_shape (list[int]): The sparse shape of input tensor.
-        norm_cfg (dict): Config of normalization layer.
+        order (list[str]): Order of conv module. Defaults to ('conv',
+            'norm', 'act').
+        norm_cfg (dict): Config of normalization layer. Defaults to
+            dict(type='BN1d', eps=1e-3, momentum=0.01).
         base_channels (int): Out channels for conv_input layer.
+            Defaults to 16.
         output_channels (int): Out channels for conv_out layer.
+            Defaults to 128.
         encoder_channels (tuple[tuple[int]]): Convolutional channels of
             each encode block.
+            Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
         encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
+        block_type (str): Type of the block to use. Defaults to 'conv_module'.
     """
 
     def __init__(self,
@@ -30,8 +37,10 @@ def __init__(self,
                  encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64,
                                                                         64, 64)),
                  encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
-                                                                 1))):
+                                                                 1)),
+                 block_type='conv_module'):
         super().__init__()
+        assert block_type in ['conv_module', 'basicblock']
         self.sparse_shape = sparse_shape
         self.in_channels = in_channels
         self.order = order
@@ -66,7 +75,10 @@ def __init__(self,
             conv_type='SubMConv3d')
 
         encoder_out_channels = self.make_encoder_layers(
-            make_sparse_convmodule, norm_cfg, self.base_channels)
+            make_sparse_convmodule,
+            norm_cfg,
+            self.base_channels,
+            block_type=block_type)
 
         self.conv_out = make_sparse_convmodule(
             encoder_out_channels,
@@ -111,17 +123,27 @@ def forward(self, voxel_features, coors, batch_size):
 
         return spatial_features
 
-    def make_encoder_layers(self, make_block, norm_cfg, in_channels):
+    def make_encoder_layers(self,
+                            make_block,
+                            norm_cfg,
+                            in_channels,
+                            block_type='conv_module',
+                            conv_cfg=dict(type='SubMConv3d')):
         """make encoder layers using sparse convs.
 
         Args:
             make_block (method): A bounded function to build blocks.
             norm_cfg (dict[str]): Config of normalization layer.
             in_channels (int): The number of encoder input channels.
+            block_type (str): Type of the block to use. Defaults to
+                'conv_module'.
+            conv_cfg (dict): Config of conv layer. Defaults to
+                dict(type='SubMConv3d').
 
         Returns:
             int: The number of encoder output channels.
""" + assert block_type in ['conv_module', 'basicblock'] self.encoder_layers = spconv.SparseSequential() for i, blocks in enumerate(self.encoder_channels): @@ -130,7 +152,7 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels): padding = tuple(self.encoder_paddings[i])[j] # each stage started with a spconv layer # except the first stage - if i != 0 and j == 0: + if i != 0 and j == 0 and block_type == 'conv_module': blocks_list.append( make_block( in_channels, @@ -141,6 +163,26 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels): padding=padding, indice_key=f'spconv{i + 1}', conv_type='SparseConv3d')) + elif block_type == 'basicblock': + if j == len(blocks) - 1 and i != len( + self.encoder_channels) - 1: + blocks_list.append( + make_block( + in_channels, + out_channels, + 3, + norm_cfg=norm_cfg, + stride=2, + padding=padding, + indice_key=f'spconv{i + 1}', + conv_type='SparseConv3d')) + else: + blocks_list.append( + SparseBasicBlock( + out_channels, + out_channels, + norm_cfg=norm_cfg, + conv_cfg=conv_cfg)) else: blocks_list.append( make_block( diff --git a/mmdet3d/models/necks/second_fpn.py b/mmdet3d/models/necks/second_fpn.py index 2b2a404840..ec9d40ee5c 100644 --- a/mmdet3d/models/necks/second_fpn.py +++ b/mmdet3d/models/necks/second_fpn.py @@ -1,6 +1,7 @@ +import numpy as np import torch -from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init, - is_norm, kaiming_init) +from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer, + constant_init, is_norm, kaiming_init) from torch import nn as nn from mmdet.models import NECKS @@ -11,11 +12,14 @@ class SECONDFPN(nn.Module): """FPN used in SECOND/PointPillars/PartA2/MVXNet. Args: - in_channels (list[int]): Input channels of multi-scale feature maps - out_channels (list[int]): Output channels of feature maps - upsample_strides (list[int]): Strides used to upsample the feature maps - norm_cfg (dict): Config dict of normalization layers - upsample_cfg (dict): Config dict of upsample layers + in_channels (list[int]): Input channels of multi-scale feature maps. + out_channels (list[int]): Output channels of feature maps. + upsample_strides (list[int]): Strides used to upsample the + feature maps. + norm_cfg (dict): Config dict of normalization layers. + upsample_cfg (dict): Config dict of upsample layers. + conv_cfg (dict): Config dict of conv layers. + use_conv_for_no_stride (bool): Whether to use conv when stride is 1. 
""" def __init__(self, @@ -23,7 +27,9 @@ def __init__(self, out_channels=[256, 256, 256], upsample_strides=[1, 2, 4], norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), - upsample_cfg=dict(type='deconv', bias=False)): + upsample_cfg=dict(type='deconv', bias=False), + conv_cfg=dict(type='Conv2d', bias=False), + use_conv_for_no_stride=False): # if for GroupNorm, # cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True) super(SECONDFPN, self).__init__() @@ -33,12 +39,23 @@ def __init__(self, deblocks = [] for i, out_channel in enumerate(out_channels): - upsample_layer = build_upsample_layer( - upsample_cfg, - in_channels=in_channels[i], - out_channels=out_channel, - kernel_size=upsample_strides[i], - stride=upsample_strides[i]) + stride = upsample_strides[i] + if stride > 1 or (stride == 1 and not use_conv_for_no_stride): + upsample_layer = build_upsample_layer( + upsample_cfg, + in_channels=in_channels[i], + out_channels=out_channel, + kernel_size=upsample_strides[i], + stride=upsample_strides[i]) + else: + stride = np.round(1 / stride).astype(np.int64) + upsample_layer = build_conv_layer( + conv_cfg, + in_channels=in_channels[i], + out_channels=out_channel, + kernel_size=stride, + stride=stride) + deblock = nn.Sequential(upsample_layer, build_norm_layer(norm_cfg, out_channel)[1], nn.ReLU(inplace=True)) diff --git a/mmdet3d/models/voxel_encoders/pillar_encoder.py b/mmdet3d/models/voxel_encoders/pillar_encoder.py index 58a28cd58a..3411260971 100644 --- a/mmdet3d/models/voxel_encoders/pillar_encoder.py +++ b/mmdet3d/models/voxel_encoders/pillar_encoder.py @@ -31,6 +31,8 @@ class PillarFeatureNet(nn.Module): Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01). mode (str, optional): The mode to gather point features. Options are 'max' or 'avg'. Defaults to 'max'. + legacy (bool): Whether to use the new behavior or + the original behavior. Defaults to True. """ def __init__(self, @@ -42,9 +44,11 @@ def __init__(self, voxel_size=(0.2, 0.2, 4), point_cloud_range=(0, -40, -3, 70.4, 40, 1), norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - mode='max'): + mode='max', + legacy=True): super(PillarFeatureNet, self).__init__() assert len(feat_channels) > 0 + self.legacy = legacy if with_cluster_center: in_channels += 3 if with_voxel_center: @@ -89,7 +93,7 @@ def forward(self, features, num_points, coors): features (torch.Tensor): Point features or raw points in shape (N, M, C). num_points (torch.Tensor): Number of points in each pillar. - coors (torch.Tensor): Coordinates of each voxel + coors (torch.Tensor): Coordinates of each voxel. Returns: torch.Tensor: Features of pillars. 
@@ -104,14 +108,24 @@ def forward(self, features, num_points, coors):
             features_ls.append(f_cluster)
 
         # Find distance of x, y, and z from pillar center
+        dtype = features.dtype
         if self._with_voxel_center:
-            f_center = features[:, :, :2]
-            f_center[:, :, 0] = f_center[:, :, 0] - (
-                coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
-                self.x_offset)
-            f_center[:, :, 1] = f_center[:, :, 1] - (
-                coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
-                self.y_offset)
+            if not self.legacy:
+                f_center = torch.zeros_like(features[:, :, :2])
+                f_center[:, :, 0] = features[:, :, 0] - (
+                    coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
+                    self.x_offset)
+                f_center[:, :, 1] = features[:, :, 1] - (
+                    coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
+                    self.y_offset)
+            else:
+                f_center = features[:, :, :2]
+                f_center[:, :, 0] = f_center[:, :, 0] - (
+                    coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
+                    self.x_offset)
+                f_center[:, :, 1] = f_center[:, :, 1] - (
+                    coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
+                    self.y_offset)
             features_ls.append(f_center)
 
         if self._with_distance:
diff --git a/mmdet3d/models/voxel_encoders/voxel_encoder.py b/mmdet3d/models/voxel_encoders/voxel_encoder.py
index 67a4ab8162..b647eb608f 100644
--- a/mmdet3d/models/voxel_encoders/voxel_encoder.py
+++ b/mmdet3d/models/voxel_encoders/voxel_encoder.py
@@ -13,10 +13,14 @@ class HardSimpleVFE(nn.Module):
     """Simple voxel feature encoder used in SECOND.
 
     It simply averages the values of points in a voxel.
+
+    Args:
+        num_features (int): Number of features to use. Default: 4.
     """
 
-    def __init__(self):
+    def __init__(self, num_features=4):
         super(HardSimpleVFE, self).__init__()
+        self.num_features = num_features
 
     def forward(self, features, num_points, coors):
         """Forward function.
@@ -32,7 +36,7 @@ def forward(self, features, num_points, coors):
         Returns:
             torch.Tensor: Mean of points inside each voxel in shape (N, 3(4))
         """
-        points_mean = features[:, :, :4].sum(
+        points_mean = features[:, :, :self.num_features].sum(
             dim=1, keepdim=False) / num_points.type_as(features).view(-1, 1)
 
         return points_mean.contiguous()
diff --git a/tests/test_middle_encoders.py b/tests/test_middle_encoders.py
new file mode 100644
index 0000000000..e3ab306eb4
--- /dev/null
+++ b/tests/test_middle_encoders.py
@@ -0,0 +1,25 @@
+import pytest
+import torch
+
+from mmdet3d.models.builder import build_middle_encoder
+
+
+def test_sparse_encoder():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+    sparse_encoder_cfg = dict(
+        type='SparseEncoder',
+        in_channels=5,
+        sparse_shape=[40, 1024, 1024],
+        order=('conv', 'norm', 'act'),
+        encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
+                                                                      128)),
+        encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1)),
+        block_type='basicblock')
+
+    sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda()
+    voxel_features = torch.rand([207842, 5]).cuda()
+    coors = torch.randint(0, 4, [207842, 4]).cuda()
+
+    ret = sparse_encoder(voxel_features, coors, 4)
+    assert ret.shape == torch.Size([4, 256, 128, 128])
diff --git a/tests/test_necks.py b/tests/test_necks.py
new file mode 100644
index 0000000000..7e924cadaa
--- /dev/null
+++ b/tests/test_necks.py
@@ -0,0 +1,45 @@
+import torch
+
+from mmdet3d.models.builder import build_backbone, build_neck
+
+
+def test_centerpoint_fpn():
+
+    second_cfg = dict(
+        type='SECOND',
+        in_channels=64,
+        out_channels=[64, 128, 256],
+        layer_nums=[3, 5, 5],
+        layer_strides=[2, 2, 2],
+        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
+        conv_cfg=dict(type='Conv2d', bias=False))
+
+    second = build_backbone(second_cfg)
+
+    # centerpoint usage of fpn
+    centerpoint_fpn_cfg = dict(
+        type='SECONDFPN',
+        in_channels=[64, 128, 256],
+        out_channels=[128, 128, 128],
+        upsample_strides=[0.5, 1, 2],
+        norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
+        upsample_cfg=dict(type='deconv', bias=False),
+        use_conv_for_no_stride=True)
+
+    # original usage of fpn
+    fpn_cfg = dict(
+        type='SECONDFPN',
+        in_channels=[64, 128, 256],
+        upsample_strides=[1, 2, 4],
+        out_channels=[128, 128, 128])
+
+    second_fpn = build_neck(fpn_cfg)
+
+    centerpoint_second_fpn = build_neck(centerpoint_fpn_cfg)
+
+    input = torch.rand([4, 64, 512, 512])
+    sec_output = second(input)
+    centerpoint_output = centerpoint_second_fpn(sec_output)
+    second_output = second_fpn(sec_output)
+    assert centerpoint_output[0].shape == torch.Size([4, 384, 128, 128])
+    assert second_output[0].shape == torch.Size([4, 384, 256, 256])
diff --git a/tests/test_voxel_encoders.py b/tests/test_voxel_encoders.py
new file mode 100644
index 0000000000..f7503a8c34
--- /dev/null
+++ b/tests/test_voxel_encoders.py
@@ -0,0 +1,33 @@
+import torch
+
+from mmdet3d.models.builder import build_voxel_encoder
+
+
+def test_pillar_feature_net():
+    pillar_feature_net_cfg = dict(
+        type='PillarFeatureNet',
+        in_channels=5,
+        feat_channels=[64],
+        with_distance=False,
+        voxel_size=(0.2, 0.2, 8),
+        point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
+        norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))
+
+    pillar_feature_net = build_voxel_encoder(pillar_feature_net_cfg)
+
+    features = torch.rand([97297, 20, 5])
+    num_voxels = torch.randint(1, 100, [97297])
+    coors = torch.randint(0, 100, [97297, 4])
+
+    features = pillar_feature_net(features, num_voxels, coors)
+    assert features.shape == torch.Size([97297, 64])
+
+
+def test_hard_simple_VFE():
+    hard_simple_VFE_cfg = dict(type='HardSimpleVFE', num_features=5)
+    hard_simple_VFE = build_voxel_encoder(hard_simple_VFE_cfg)
+    features = torch.rand([240000, 10, 5])
+    num_voxels = torch.randint(1, 10, [240000])
+
+    outputs = hard_simple_VFE(features, num_voxels, None)
+    assert outputs.shape == torch.Size([240000, 5])
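
The stage layout that make_encoder_layers produces for block_type='basicblock' can be hard to read straight off the diff. The standalone sketch below replays the new branching with the channel tuples from test_sparse_encoder; the printed names are illustrative only. Inside a stage the encoder stacks SparseBasicBlocks, and every stage except the last ends with a stride-2 SparseConv3d that does the downsampling (with conv_module, by contrast, each stage begins with the strided conv).

# Replays the block-selection logic of make_encoder_layers for
# block_type='basicblock'; channel tuples match test_sparse_encoder.
encoder_channels = ((16, 16, 32), (32, 32, 64), (64, 64, 128), (128, 128))

for i, blocks in enumerate(encoder_channels):
    layout = []
    for j, out_channels in enumerate(blocks):
        if j == len(blocks) - 1 and i != len(encoder_channels) - 1:
            # last block of a non-final stage: strided sparse conv
            layout.append(f'SparseConv3d(stride=2, out={out_channels})')
        else:
            # residual block that keeps the channel count
            layout.append(f'SparseBasicBlock({out_channels})')
    print(f'stage {i}: ' + ' -> '.join(layout))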
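In SECONDFPN, each upsample_strides entry now selects between a deconv and a plain conv: a fractional stride such as 0.5 means "downsample by 2" and is implemented as a stride-2 conv, while a stride of 1 becomes a stride-1 conv when use_conv_for_no_stride=True. A minimal sketch of that mapping, using the strides from test_centerpoint_fpn:

import numpy as np

upsample_strides = [0.5, 1, 2]  # values from test_centerpoint_fpn
use_conv_for_no_stride = True

for stride in upsample_strides:
    if stride > 1 or (stride == 1 and not use_conv_for_no_stride):
        # integer stride > 1 (or stride 1 without the flag): deconv
        print(f'stride {stride}: deconv with kernel/stride {stride}')
    else:
        # fractional or unit stride: regular conv whose kernel and stride
        # are the rounded reciprocal (0.5 -> 2, 1 -> 1)
        conv_stride = np.round(1 / stride).astype(np.int64)
        print(f'stride {stride}: conv with kernel/stride {conv_stride}')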
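The practical difference of legacy=False in PillarFeatureNet is that the x/y center offsets are written into a fresh tensor rather than into a slice of the input, so the caller's point tensor is left untouched. A minimal sketch under that assumption; the config values are illustrative and mirror test_pillar_feature_net:

import torch

from mmdet3d.models.builder import build_voxel_encoder

pfn = build_voxel_encoder(
    dict(
        type='PillarFeatureNet',
        in_channels=5,
        feat_channels=[64],
        voxel_size=(0.2, 0.2, 8),
        point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
        legacy=False))

features = torch.rand([8, 20, 5])
snapshot = features.clone()
pfn(features, torch.randint(1, 20, [8]), torch.randint(0, 4, [8, 4]))
# with legacy=False the input features are not modified in place
assert torch.equal(features, snapshot)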