Skip to content

Commit

Permalink
Add modules before mg_head in centerpoint (#46)
Browse files Browse the repository at this point in the history
* Add centerpoint_rpn and scn, change pillar encoder and voxel_encoder

* Move test_voxel_encoders.

* Change names, add docstring.

* Reconstruct centerpoint_rpn.

* Add centerpoint_rpn.

* Change SECONDFPN, delete centerpoint_fpn

* Remove SparseBasicBlock.

* Change SpMiddleResNetFHD to SparseEncoder.

* Finish SparseEncoder unittest.

* Change test_hard_simple_VFE.

* Change option, add legacy.

* Change docstring, change legacy.

* Fix legacy bug.

* Change unittest, change docstring.

* Change docstring.
  • Loading branch information
yinchimaoliang authored Aug 18, 2020
1 parent 5f7b31c commit 27d0001
Show file tree
Hide file tree
Showing 7 changed files with 212 additions and 31 deletions.
54 changes: 48 additions & 6 deletions mmdet3d/models/middle_encoders/sparse_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import nn as nn

from mmdet3d.ops import make_sparse_convmodule
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..registry import MIDDLE_ENCODERS

Expand All @@ -12,12 +12,19 @@ class SparseEncoder(nn.Module):
Args:
in_channels (int): The number of input channels.
sparse_shape (list[int]): The sparse shape of input tensor.
norm_cfg (dict): Config of normalization layer.
order (list[str]): Order of conv module. Defaults to ('conv',
'norm', 'act').
norm_cfg (dict): Config of normalization layer. Defaults to
dict(type='BN1d', eps=1e-3, momentum=0.01).
base_channels (int): Out channels for conv_input layer.
Defaults to 16.
output_channels (int): Out channels for conv_out layer.
Defaults to 128.
encoder_channels (tuple[tuple[int]]):
Convolutional channels of each encode block.
encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
Defaults to ((16, ), (32, 32, 32), (64, 64, 64), (64, 64, 64)).
block_type (str): Type of the block to use. Defaults to 'conv_module'.
"""

def __init__(self,
Expand All @@ -30,8 +37,10 @@ def __init__(self,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1))):
1)),
block_type='conv_module'):
super().__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
self.in_channels = in_channels
self.order = order
Expand Down Expand Up @@ -66,7 +75,10 @@ def __init__(self,
conv_type='SubMConv3d')

encoder_out_channels = self.make_encoder_layers(
make_sparse_convmodule, norm_cfg, self.base_channels)
make_sparse_convmodule,
norm_cfg,
self.base_channels,
block_type=block_type)

self.conv_out = make_sparse_convmodule(
encoder_out_channels,
Expand Down Expand Up @@ -111,17 +123,27 @@ def forward(self, voxel_features, coors, batch_size):

return spatial_features

def make_encoder_layers(self, make_block, norm_cfg, in_channels):
def make_encoder_layers(self,
make_block,
norm_cfg,
in_channels,
block_type='conv_module',
conv_cfg=dict(type='SubMConv3d')):
"""make encoder layers using sparse convs.
Args:
make_block (method): A bounded function to build blocks.
norm_cfg (dict[str]): Config of normalization layer.
in_channels (int): The number of encoder input channels.
block_type (str): Type of the block to use. Defaults to
'conv_module'.
conv_cfg (dict): Config of conv layer. Defaults to
dict(type='SubMConv3d').
Returns:
int: The number of encoder output channels.
"""
assert block_type in ['conv_module', 'basicblock']
self.encoder_layers = spconv.SparseSequential()

for i, blocks in enumerate(self.encoder_channels):
Expand All @@ -130,7 +152,7 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
padding = tuple(self.encoder_paddings[i])[j]
# each stage started with a spconv layer
# except the first stage
if i != 0 and j == 0:
if i != 0 and j == 0 and block_type == 'conv_module':
blocks_list.append(
make_block(
in_channels,
Expand All @@ -141,6 +163,26 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
padding=padding,
indice_key=f'spconv{i + 1}',
conv_type='SparseConv3d'))
elif block_type == 'basicblock':
if j == len(blocks) - 1 and i != len(
self.encoder_channels) - 1:
blocks_list.append(
make_block(
in_channels,
out_channels,
3,
norm_cfg=norm_cfg,
stride=2,
padding=padding,
indice_key=f'spconv{i + 1}',
conv_type='SparseConv3d'))
else:
blocks_list.append(
SparseBasicBlock(
out_channels,
out_channels,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg))
else:
blocks_list.append(
make_block(
Expand Down
45 changes: 31 additions & 14 deletions mmdet3d/models/necks/second_fpn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
is_norm, kaiming_init)
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, is_norm, kaiming_init)
from torch import nn as nn

from mmdet.models import NECKS
Expand All @@ -11,19 +12,24 @@ class SECONDFPN(nn.Module):
"""FPN used in SECOND/PointPillars/PartA2/MVXNet.
Args:
in_channels (list[int]): Input channels of multi-scale feature maps
out_channels (list[int]): Output channels of feature maps
upsample_strides (list[int]): Strides used to upsample the feature maps
norm_cfg (dict): Config dict of normalization layers
upsample_cfg (dict): Config dict of upsample layers
in_channels (list[int]): Input channels of multi-scale feature maps.
out_channels (list[int]): Output channels of feature maps.
upsample_strides (list[int]): Strides used to upsample the
feature maps.
norm_cfg (dict): Config dict of normalization layers.
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
"""

def __init__(self,
in_channels=[128, 128, 256],
out_channels=[256, 256, 256],
upsample_strides=[1, 2, 4],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False)):
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
use_conv_for_no_stride=False):
# if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__()
Expand All @@ -33,12 +39,23 @@ def __init__(self,

deblocks = []
for i, out_channel in enumerate(out_channels):
upsample_layer = build_upsample_layer(
upsample_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=upsample_strides[i],
stride=upsample_strides[i])
stride = upsample_strides[i]
if stride > 1 or (stride == 1 and not use_conv_for_no_stride):
upsample_layer = build_upsample_layer(
upsample_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=upsample_strides[i],
stride=upsample_strides[i])
else:
stride = np.round(1 / stride).astype(np.int64)
upsample_layer = build_conv_layer(
conv_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=stride,
stride=stride)

deblock = nn.Sequential(upsample_layer,
build_norm_layer(norm_cfg, out_channel)[1],
nn.ReLU(inplace=True))
Expand Down
32 changes: 23 additions & 9 deletions mmdet3d/models/voxel_encoders/pillar_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class PillarFeatureNet(nn.Module):
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
legacy (bool): Whether to use the new behavior or
the original behavior. Defaults to True.
"""

def __init__(self,
Expand All @@ -42,9 +44,11 @@ def __init__(self,
voxel_size=(0.2, 0.2, 4),
point_cloud_range=(0, -40, -3, 70.4, 40, 1),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
mode='max'):
mode='max',
legacy=True):
super(PillarFeatureNet, self).__init__()
assert len(feat_channels) > 0
self.legacy = legacy
if with_cluster_center:
in_channels += 3
if with_voxel_center:
Expand Down Expand Up @@ -89,7 +93,7 @@ def forward(self, features, num_points, coors):
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel
coors (torch.Tensor): Coordinates of each voxel.
Returns:
torch.Tensor: Features of pillars.
Expand All @@ -104,14 +108,24 @@ def forward(self, features, num_points, coors):
features_ls.append(f_cluster)

# Find distance of x, y, and z from pillar center
dtype = features.dtype
if self._with_voxel_center:
f_center = features[:, :, :2]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
if not self.legacy:
f_center = torch.zeros_like(features[:, :, :2])
f_center[:, :, 0] = features[:, :, 0] - (
coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = features[:, :, 1] - (
coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
self.y_offset)
else:
f_center = features[:, :, :2]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
features_ls.append(f_center)

if self._with_distance:
Expand Down
8 changes: 6 additions & 2 deletions mmdet3d/models/voxel_encoders/voxel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ class HardSimpleVFE(nn.Module):
"""Simple voxel feature encoder used in SECOND.
It simply averages the values of points in a voxel.
Args:
num_features (int): Number of features to use. Default: 4.
"""

def __init__(self):
def __init__(self, num_features=4):
super(HardSimpleVFE, self).__init__()
self.num_features = num_features

def forward(self, features, num_points, coors):
"""Forward function.
Expand All @@ -32,7 +36,7 @@ def forward(self, features, num_points, coors):
Returns:
torch.Tensor: Mean of points inside each voxel in shape (N, 3(4))
"""
points_mean = features[:, :, :4].sum(
points_mean = features[:, :, :self.num_features].sum(
dim=1, keepdim=False) / num_points.type_as(features).view(-1, 1)
return points_mean.contiguous()

Expand Down
26 changes: 26 additions & 0 deletions tests/test_middle_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import torch

from mmdet3d.models.builder import build_middle_encoder


def test_sparse_encoder():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
sparse_encoder_cfg = dict(
type='SparseEncoder',
in_channels=5,
sparse_shape=[40, 1024, 1024],
order=('conv', 'norm', 'act'),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1,
1)),
option='basicblock')

sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda()
voxel_features = torch.rand([207842, 5]).cuda()
coors = torch.randint(0, 4, [207842, 4]).cuda()

ret = sparse_encoder(voxel_features, coors, 4)
assert ret.shape == torch.Size([4, 256, 128, 128])
45 changes: 45 additions & 0 deletions tests/test_necks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import torch

from mmdet3d.models.builder import build_backbone, build_neck


def test_centerpoint_fpn():

second_cfg = dict(
type='SECOND',
in_channels=64,
out_channels=[64, 128, 256],
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False))

second = build_backbone(second_cfg)

# centerpoint usage of fpn
centerpoint_fpn_cfg = dict(
type='SECONDFPN',
in_channels=[64, 128, 256],
out_channels=[128, 128, 128],
upsample_strides=[0.5, 1, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
use_conv_for_no_stride=True)

# original usage of fpn
fpn_cfg = dict(
type='SECONDFPN',
in_channels=[64, 128, 256],
upsample_strides=[1, 2, 4],
out_channels=[128, 128, 128])

second_fpn = build_neck(fpn_cfg)

centerpoint_second_fpn = build_neck(centerpoint_fpn_cfg)

input = torch.rand([4, 64, 512, 512])
sec_output = second(input)
centerpoint_output = centerpoint_second_fpn(sec_output)
second_output = second_fpn(sec_output)
assert centerpoint_output[0].shape == torch.Size([4, 384, 128, 128])
assert second_output[0].shape == torch.Size([4, 384, 256, 256])
33 changes: 33 additions & 0 deletions tests/test_voxel_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch

from mmdet3d.models.builder import build_voxel_encoder


def test_pillar_feature_net():
pillar_feature_net_cfg = dict(
type='PillarFeatureNet',
in_channels=5,
feat_channels=[64],
with_distance=False,
voxel_size=(0.2, 0.2, 8),
point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01))

pillar_feature_net = build_voxel_encoder(pillar_feature_net_cfg)

features = torch.rand([97297, 20, 5])
num_voxels = torch.randint(1, 100, [97297])
coors = torch.randint(0, 100, [97297, 4])

features = pillar_feature_net(features, num_voxels, coors)
assert features.shape == torch.Size([97297, 64])


def test_hard_simple_VFE():
hard_simple_VFE_cfg = dict(type='HardSimpleVFE', num_features=5)
hard_simple_VFE = build_voxel_encoder(hard_simple_VFE_cfg)
features = torch.rand([240000, 10, 5])
num_voxels = torch.randint(1, 10, [240000])

outputs = hard_simple_VFE(features, num_voxels, None)
assert outputs.shape == torch.Size([240000, 5])

0 comments on commit 27d0001

Please sign in to comment.