Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add modules before mg_head in centerpoint #46

Merged
merged 15 commits into from
Aug 18, 2020
44 changes: 39 additions & 5 deletions mmdet3d/models/middle_encoders/sparse_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from torch import nn as nn

from mmdet3d.ops import make_sparse_convmodule
from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule
from mmdet3d.ops import spconv as spconv
from ..registry import MIDDLE_ENCODERS

Expand All @@ -18,6 +18,7 @@ class SparseEncoder(nn.Module):
encoder_channels (tuple[tuple[int]]):
Convolutional channels of each encode block.
encoder_paddings (tuple[tuple[int]]): Paddings of each encode block.
block_type (str): Type of the block to use.
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self,
Expand All @@ -30,8 +31,10 @@ def __init__(self,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1))):
1)),
block_type='submblock'):
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
super().__init__()
assert block_type in ['submblock', 'basicblock']
self.sparse_shape = sparse_shape
self.in_channels = in_channels
self.order = order
Expand Down Expand Up @@ -66,7 +69,10 @@ def __init__(self,
conv_type='SubMConv3d')

encoder_out_channels = self.make_encoder_layers(
make_sparse_convmodule, norm_cfg, self.base_channels)
make_sparse_convmodule,
norm_cfg,
self.base_channels,
block_type=block_type)

self.conv_out = make_sparse_convmodule(
encoder_out_channels,
Expand Down Expand Up @@ -111,17 +117,25 @@ def forward(self, voxel_features, coors, batch_size):

return spatial_features

def make_encoder_layers(self, make_block, norm_cfg, in_channels):
def make_encoder_layers(self,
make_block,
norm_cfg,
in_channels,
block_type='submblock',
conv_cfg=dict(type='SubMConv3d')):
"""make encoder layers using sparse convs.

Args:
make_block (method): A bounded function to build blocks.
norm_cfg (dict[str]): Config of normalization layer.
in_channels (int): The number of encoder input channels.
block_type (str): Type of the block to use.
conv_cfg (dict): Config of conv layer.

Returns:
int: The number of encoder output channels.
"""
assert block_type in ['submblock', 'basicblock']
self.encoder_layers = spconv.SparseSequential()

for i, blocks in enumerate(self.encoder_channels):
Expand All @@ -130,7 +144,7 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
padding = tuple(self.encoder_paddings[i])[j]
# each stage started with a spconv layer
# except the first stage
if i != 0 and j == 0:
if i != 0 and j == 0 and block_type == 'submblock':
blocks_list.append(
make_block(
in_channels,
Expand All @@ -141,6 +155,26 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
padding=padding,
indice_key=f'spconv{i + 1}',
conv_type='SparseConv3d'))
elif block_type == 'basicblock':
if j == len(blocks) - 1 and i != len(
self.encoder_channels) - 1:
blocks_list.append(
make_block(
in_channels,
out_channels,
3,
norm_cfg=norm_cfg,
stride=2,
padding=padding,
indice_key=f'spconv{i + 1}',
conv_type='SparseConv3d'))
else:
blocks_list.append(
SparseBasicBlock(
out_channels,
out_channels,
norm_cfg=norm_cfg,
conv_cfg=conv_cfg))
else:
blocks_list.append(
make_block(
Expand Down
45 changes: 31 additions & 14 deletions mmdet3d/models/necks/second_fpn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import torch
from mmcv.cnn import (build_norm_layer, build_upsample_layer, constant_init,
is_norm, kaiming_init)
from mmcv.cnn import (build_conv_layer, build_norm_layer, build_upsample_layer,
constant_init, is_norm, kaiming_init)
from torch import nn as nn

from mmdet.models import NECKS
Expand All @@ -11,19 +12,24 @@ class SECONDFPN(nn.Module):
"""FPN used in SECOND/PointPillars/PartA2/MVXNet.
Args:
in_channels (list[int]): Input channels of multi-scale feature maps
out_channels (list[int]): Output channels of feature maps
upsample_strides (list[int]): Strides used to upsample the feature maps
norm_cfg (dict): Config dict of normalization layers
upsample_cfg (dict): Config dict of upsample layers
in_channels (list[int]): Input channels of multi-scale feature maps.
out_channels (list[int]): Output channels of feature maps.
upsample_strides (list[int]): Strides used to upsample the
feature maps.
norm_cfg (dict): Config dict of normalization layers.
upsample_cfg (dict): Config dict of upsample layers.
conv_cfg (dict): Config dict of conv layers.
use_conv_for_no_stride (bool): Whether to use conv when stride is 1.
"""

def __init__(self,
in_channels=[128, 128, 256],
out_channels=[256, 256, 256],
upsample_strides=[1, 2, 4],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False)):
upsample_cfg=dict(type='deconv', bias=False),
conv_cfg=dict(type='Conv2d', bias=False),
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
use_conv_for_no_stride=False):
# if for GroupNorm,
# cfg is dict(type='GN', num_groups=num_groups, eps=1e-3, affine=True)
super(SECONDFPN, self).__init__()
Expand All @@ -33,12 +39,23 @@ def __init__(self,

deblocks = []
for i, out_channel in enumerate(out_channels):
upsample_layer = build_upsample_layer(
upsample_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=upsample_strides[i],
stride=upsample_strides[i])
stride = upsample_strides[i]
if stride > 1 or (stride == 1 and not use_conv_for_no_stride):
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
upsample_layer = build_upsample_layer(
upsample_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=upsample_strides[i],
stride=upsample_strides[i])
else:
stride = np.round(1 / stride).astype(np.int64)
upsample_layer = build_conv_layer(
conv_cfg,
in_channels=in_channels[i],
out_channels=out_channel,
kernel_size=stride,
stride=stride)

deblock = nn.Sequential(upsample_layer,
build_norm_layer(norm_cfg, out_channel)[1],
nn.ReLU(inplace=True))
Expand Down
32 changes: 23 additions & 9 deletions mmdet3d/models/voxel_encoders/pillar_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class PillarFeatureNet(nn.Module):
Defaults to dict(type='BN1d', eps=1e-3, momentum=0.01).
mode (str, optional): The mode to gather point features. Options are
'max' or 'avg'. Defaults to 'max'.
legacy (bool): Whether to use the new behavior or
the original behavior.
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self,
Expand All @@ -42,9 +44,11 @@ def __init__(self,
voxel_size=(0.2, 0.2, 4),
point_cloud_range=(0, -40, -3, 70.4, 40, 1),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
mode='max'):
mode='max',
legacy=False):
super(PillarFeatureNet, self).__init__()
assert len(feat_channels) > 0
self.egacy = legacy
if with_cluster_center:
in_channels += 3
if with_voxel_center:
Expand Down Expand Up @@ -89,7 +93,7 @@ def forward(self, features, num_points, coors):
features (torch.Tensor): Point features or raw points in shape
(N, M, C).
num_points (torch.Tensor): Number of points in each pillar.
coors (torch.Tensor): Coordinates of each voxel
coors (torch.Tensor): Coordinates of each voxel.

Returns:
torch.Tensor: Features of pillars.
Expand All @@ -104,14 +108,24 @@ def forward(self, features, num_points, coors):
features_ls.append(f_cluster)

# Find distance of x, y, and z from pillar center
dtype = features.dtype
if self._with_voxel_center:
f_center = features[:, :, :2]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
if self.legacy:
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
f_center = torch.zeros_like(features[:, :, :2])
f_center[:, :, 0] = features[:, :, 0] - (
coors[:, 3].to(dtype).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = features[:, :, 1] - (
coors[:, 2].to(dtype).unsqueeze(1) * self.vy +
self.y_offset)
else:
f_center = features[:, :, :2]
f_center[:, :, 0] = f_center[:, :, 0] - (
coors[:, 3].type_as(features).unsqueeze(1) * self.vx +
self.x_offset)
f_center[:, :, 1] = f_center[:, :, 1] - (
coors[:, 2].type_as(features).unsqueeze(1) * self.vy +
self.y_offset)
features_ls.append(f_center)

if self._with_distance:
Expand Down
8 changes: 6 additions & 2 deletions mmdet3d/models/voxel_encoders/voxel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ class HardSimpleVFE(nn.Module):
"""Simple voxel feature encoder used in SECOND.
It simply averages the values of points in a voxel.
Args:
num_features (int): Number of features to use. Default: 4.
"""
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self):
def __init__(self, num_features=4):
super(HardSimpleVFE, self).__init__()
self.num_features = num_features

def forward(self, features, num_points, coors):
"""Forward function.
Expand All @@ -32,7 +36,7 @@ def forward(self, features, num_points, coors):
Returns:
torch.Tensor: Mean of points inside each voxel in shape (N, 3(4))
"""
points_mean = features[:, :, :4].sum(
points_mean = features[:, :, :self.num_features].sum(
dim=1, keepdim=False) / num_points.type_as(features).view(-1, 1)
return points_mean.contiguous()

Expand Down
26 changes: 26 additions & 0 deletions tests/test_middle_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import pytest
import torch

from mmdet3d.models.builder import build_middle_encoder


def test_sparse_encoder():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')
sparse_encoder_cfg = dict(
type='SparseEncoder',
in_channels=5,
sparse_shape=[40, 1024, 1024],
order=('conv', 'norm', 'act'),
encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128,
128)),
encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1, 1), (1, 1,
1)),
option='basicblock')

sparse_encoder = build_middle_encoder(sparse_encoder_cfg).cuda()
voxel_features = torch.rand([207842, 5]).cuda()
coors = torch.randint(0, 4, [207842, 4]).cuda()

ret = sparse_encoder(voxel_features, coors, 4)
assert ret.shape == torch.Size([4, 256, 128, 128])
32 changes: 32 additions & 0 deletions tests/test_necks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch

from mmdet3d.models.builder import build_backbone, build_neck


def test_centerpoint_rpn():
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
second_cfg = dict(
type='SECOND',
in_channels=64,
out_channels=[64, 128, 256],
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
conv_cfg=dict(type='Conv2d', bias=False))

second = build_backbone(second_cfg)

centerpoint_fpn_cfg = dict(
type='SECONDFPN',
in_channels=[64, 128, 256],
out_channels=[128, 128, 128],
upsample_strides=[0.5, 1, 2],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
use_conv_for_no_stride=True)

second_fpn = build_neck(centerpoint_fpn_cfg)

input = torch.rand([4, 64, 512, 512])
sec_output = second(input)
output = second_fpn(sec_output)
assert output[0].shape == torch.Size([4, 384, 128, 128])
43 changes: 43 additions & 0 deletions tests/test_voxel_encoders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import numpy as np
import torch

from mmdet3d.models.builder import build_voxel_encoder


def _set_seed():
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)


def test_pillar_feature_net():
_set_seed()
pillar_feature_net_cfg = dict(
type='PillarFeatureNet',
in_channels=5,
feat_channels=[64],
with_distance=False,
voxel_size=(0.2, 0.2, 8),
point_cloud_range=(-51.2, -51.2, -5.0, 51.2, 51.2, 3.0),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved
)

pillar_feature_net = build_voxel_encoder(pillar_feature_net_cfg)

features = torch.rand([97297, 20, 5])
num_voxels = torch.randint(1, 100, [97297])
coors = torch.randint(0, 100, [97297, 4])

features = pillar_feature_net(features, num_voxels, coors)
assert features.shape == torch.Size([97297, 64])


def test_hard_simple_VFE():
hard_simple_VFE_cfg = dict(type='HardSimpleVFE', num_features=5)
hard_simple_VFE = build_voxel_encoder(hard_simple_VFE_cfg)
features = torch.rand([240000, 10, 5])
num_voxels = torch.randint(1, 10, [240000])

outputs = hard_simple_VFE(features, num_voxels, None)
assert outputs.shape == torch.Size([240000, 5])
yinchimaoliang marked this conversation as resolved.
Show resolved Hide resolved