[Feature] Spvcnn backbone #2320

Merged · 93 commits · Mar 28, 2023

Commits
b275217
add cylindrical voxelization & voxel feature encoder
sunjiahao1999 Jan 18, 2023
8c08f98
add cylindrical voxelization & voxel feature encoder
sunjiahao1999 Jan 18, 2023
ea4116b
add voxel-wise label & voxelization UT
sunjiahao1999 Feb 3, 2023
532d0f9
fix vfe
sunjiahao1999 Feb 3, 2023
ca3d097
fix vfe UT
sunjiahao1999 Feb 3, 2023
85416ce
rename voxel encoder & add more test case
sunjiahao1999 Feb 7, 2023
460a002
Merge branch 'dev-1.x' into cylinder_voxel
sunjiahao1999 Feb 7, 2023
a693b8b
fix type hint
sunjiahao1999 Feb 7, 2023
56b554d
temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for da…
sunjiahao1999 Feb 7, 2023
953085d
_forward
sunjiahao1999 Feb 14, 2023
aee48b9
del checkpoints
sunjiahao1999 Feb 14, 2023
d5e835e
add if tp
sunjiahao1999 Feb 14, 2023
d4c96b2
add predict
sunjiahao1999 Feb 16, 2023
20a6eb4
fix vfe init bug & fix UT
sunjiahao1999 Feb 16, 2023
f0b7599
add grid_size & move voxelization code
sunjiahao1999 Feb 16, 2023
f8a310d
fix import bug
sunjiahao1999 Feb 16, 2023
42be1af
keep radian to follow origin
sunjiahao1999 Feb 16, 2023
9b5ac39
add doc string
sunjiahao1999 Feb 17, 2023
7715601
fix type hint
sunjiahao1999 Feb 17, 2023
546ddac
Merge branch 'cylinder_voxel' into minkunet
sunjiahao1999 Feb 17, 2023
def76bd
add minkunet voxelization and loss function
sunjiahao1999 Feb 17, 2023
c54ef5d
fix data
sunjiahao1999 Feb 17, 2023
6539d04
Merge branch 'dev-1.x' into minkunet
sunjiahao1999 Feb 17, 2023
8edd1c5
Merge branch 'dev-1.x' into minkunet 2
sunjiahao1999 Feb 17, 2023
cc0ea24
init train
sunjiahao1999 Feb 19, 2023
4865bc0
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
9ceb10c
fix sparsetensor typehint
sunjiahao1999 Feb 19, 2023
af36c48
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
84d4e99
rename dir
sunjiahao1999 Feb 19, 2023
84beaf2
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
3d1967f
fix data config
sunjiahao1999 Feb 19, 2023
abbb30b
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
7608b07
fix data config
sunjiahao1999 Feb 19, 2023
6f0116f
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
64546c9
fix batch_size & replace dynamic_scatter
sunjiahao1999 Feb 21, 2023
f5bdf1e
fix confilcts
sunjiahao1999 Feb 21, 2023
1896b30
fix conflicts 2
sunjiahao1999 Feb 21, 2023
d787a39
fix conflicts on s_70
sunjiahao1999 Feb 21, 2023
8db57c2
Alignment of the original implementation
sunjiahao1999 Feb 27, 2023
0e53760
rename config
sunjiahao1999 Feb 27, 2023
bd73a54
add worker_init_fn_hook
sunjiahao1999 Feb 27, 2023
2e002fd
remove test_config & worker hook
sunjiahao1999 Feb 27, 2023
df91ff7
Merge branch 'dev-1.x' into minkunet
sunjiahao1999 Feb 27, 2023
883586a
add UT
sunjiahao1999 Feb 27, 2023
8a629f5
fix polarmix UT
sunjiahao1999 Feb 27, 2023
56b9e2b
init spcvnn backbone
sunjiahao1999 Feb 28, 2023
10bdefe
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Feb 28, 2023
55bffc5
add seed for cr0p5
sunjiahao1999 Mar 1, 2023
3317d67
Merge branch 'minkunet_train' into minkunet
sunjiahao1999 Mar 1, 2023
963d0e3
spvcnn_init
sunjiahao1999 Mar 1, 2023
e66e5a7
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 1, 2023
06ad676
format
sunjiahao1999 Mar 1, 2023
5223d96
rename SemanticKittiDataset
sunjiahao1999 Mar 2, 2023
372ecba
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 6, 2023
44d1e6f
merge from dev-1.x
sunjiahao1999 Mar 6, 2023
9f6fc1c
add platte & fix visual bug
sunjiahao1999 Mar 6, 2023
7ccd1f9
add platte & fix data info bug
sunjiahao1999 Mar 7, 2023
8c91004
Merge branch 'minkunet' into spvcnn
sunjiahao1999 Mar 7, 2023
8d8864e
fix ut
sunjiahao1999 Mar 7, 2023
93461ee
fix ut
sunjiahao1999 Mar 7, 2023
40fcdda
fix semantic_kitti ut
sunjiahao1999 Mar 7, 2023
9794292
merge format semantic_kitti
sunjiahao1999 Mar 7, 2023
cf9e518
train init
sunjiahao1999 Mar 7, 2023
221798f
fix docstring
sunjiahao1999 Mar 8, 2023
3652f2d
fix config name
sunjiahao1999 Mar 8, 2023
e78b860
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 8, 2023
b7cca29
merge from dev-1.x
sunjiahao1999 Mar 8, 2023
5506766
merge format_semantickitti & rename segmentor ut
sunjiahao1999 Mar 8, 2023
5d787e0
rename layer
sunjiahao1999 Mar 8, 2023
c45b268
fix doc string
sunjiahao1999 Mar 8, 2023
08d9df2
fix review
sunjiahao1999 Mar 15, 2023
ce325f5
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 15, 2023
4827aca
fix conflicts
sunjiahao1999 Mar 15, 2023
57ab8bd
remove filter data
sunjiahao1999 Mar 15, 2023
c7dc194
merge dev-1.x
sunjiahao1999 Mar 16, 2023
9a3689f
Merge branch 'minkunet' into spvcnn
sunjiahao1999 Mar 16, 2023
92d9a02
rename config
sunjiahao1999 Mar 16, 2023
6392474
rename backbone
sunjiahao1999 Mar 16, 2023
fe56a64
rename backbone 2
sunjiahao1999 Mar 16, 2023
c1dc4fe
refactor voxel2point
sunjiahao1999 Mar 16, 2023
6658913
fix coors typo
sunjiahao1999 Mar 16, 2023
9d9b13f
fix ut
sunjiahao1999 Mar 16, 2023
6092aae
fix ut
sunjiahao1999 Mar 16, 2023
dbb2e68
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 20, 2023
7d9b05f
merge dev-1.x
sunjiahao1999 Mar 20, 2023
f7a44a1
pred in segmentor
sunjiahao1999 Mar 20, 2023
8ac5ed0
merge from minkunet
sunjiahao1999 Mar 21, 2023
c2b57a7
fix get voxel seg
sunjiahao1999 Mar 22, 2023
afb9ecb
resolve comments
sunjiahao1999 Mar 22, 2023
98da87e
Merge branch 'minkunet' into spvcnn
sunjiahao1999 Mar 22, 2023
c983cb1
rename p2v and v2p
sunjiahao1999 Mar 22, 2023
3c9e898
rename points and voxels
sunjiahao1999 Mar 23, 2023
445f58a
merge minkunet
sunjiahao1999 Mar 24, 2023
Files changed
29 changes: 29 additions & 0 deletions configs/_base_/models/spvcnn.py
@@ -0,0 +1,29 @@
model = dict(
type='MinkUNet',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='minkunet',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
backbone=dict(
type='SPVCNNBackbone',
in_channels=4,
base_channels=32,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
drop_ratio=0.3),
decode_head=dict(
type='MinkUNetHead',
channels=96,
num_classes=19,
dropout_ratio=0,
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
train_cfg=dict(),
test_cfg=dict())
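
Note (not part of the diff): the base config above only defines the model. The following is a minimal sketch of how it could be instantiated through the mmdet3d registry, assuming an mmdet3d 1.x environment with torchsparse installed and with register_all_modules available as the 1.x registration helper:

from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

register_all_modules()  # registers components such as 'MinkUNet' and 'SPVCNNBackbone'

cfg = Config.fromfile('configs/_base_/models/spvcnn.py')
model = MODELS.build(cfg.model)  # MinkUNet segmentor wrapping SPVCNNBackbone
print(type(model.backbone).__name__)  # expected: SPVCNNBackbone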
10 changes: 10 additions & 0 deletions configs/spvcnn/spvcnn_w16_8xb2-15e_semantickitti.py
@@ -0,0 +1,10 @@
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']

model = dict(
backbone=dict(
base_channels=16,
encoder_channels=[16, 32, 64, 128],
decoder_channels=[128, 64, 48, 48]),
decode_head=dict(channels=48))

randomness = dict(seed=1588147245)
8 changes: 8 additions & 0 deletions configs/spvcnn/spvcnn_w20_8xb2-15e_semantickitti.py
@@ -0,0 +1,8 @@
_base_ = ['./spvcnn_w32_8xb2-15e_semantickitti.py']

model = dict(
backbone=dict(
base_channels=20,
encoder_channels=[20, 40, 81, 163],
decoder_channels=[163, 81, 61, 61]),
decode_head=dict(channels=61))
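
The w16 and w20 widths are not arbitrary: they appear to follow the channel-ratio ("cr") scaling of the original SPVCNN code, which the commit message "add seed for cr0p5" also hints at (w16 corresponding to cr = 0.5, w20 to roughly cr = 0.64). Under that assumption, a small illustrative check (the helper below is hypothetical, not part of the PR):

W32_ENCODER = [32, 64, 128, 256]
W32_DECODER = [256, 128, 96, 96]


def scale_channels(channels, cr):
    """Scale a channel list by a width multiplier, truncating to int."""
    return [int(cr * c) for c in channels]


print(scale_channels(W32_ENCODER, 0.64))  # [20, 40, 81, 163] -> w20 encoder
print(scale_channels(W32_DECODER, 0.64))  # [163, 81, 61, 61] -> w20 decoder
print(scale_channels(W32_ENCODER, 0.5))   # [16, 32, 64, 128] -> w16 encoder
print(scale_channels(W32_DECODER, 0.5))   # [128, 64, 48, 48] -> w16 decoder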
54 changes: 54 additions & 0 deletions configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py
@@ -0,0 +1,54 @@
_base_ = [
'../_base_/datasets/semantickitti.py', '../_base_/models/spvcnn.py',
'../_base_/default_runtime.py'
]

train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32',
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping'),
dict(
type='GlobalRotScaleTrans',
rot_range=[0., 6.28318531],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0],
),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))

lr = 0.24
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))

param_scheduler = [
dict(
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=15,
by_epoch=True,
eta_min=1e-5,
convert_to_iter_based=True)
]

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
env_cfg = dict(cudnn_benchmark=True)
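
For reference, a rough sketch of the learning-rate schedule the two schedulers above produce: a linear warmup from start_factor x base LR (0.008 x 0.24) over the first 125 iterations, then cosine annealing towards 1e-5 across the 15 epochs. This is an approximation for illustration only (mmengine handles the composition of overlapping schedulers itself), and iters_per_epoch is a hypothetical placeholder, not a value from the PR:

import math

BASE_LR = 0.24
START_FACTOR = 0.008
WARMUP_ITERS = 125
ETA_MIN = 1e-5


def approx_lr(cur_iter, iters_per_epoch, max_epochs=15):
    """Approximate LR at a given iteration under the warmup + cosine schedule."""
    total_iters = max_epochs * iters_per_epoch
    if cur_iter < WARMUP_ITERS:
        # LinearLR: scale the base LR from start_factor up to 1.0.
        factor = START_FACTOR + (1 - START_FACTOR) * cur_iter / WARMUP_ITERS
        return BASE_LR * factor
    # CosineAnnealingLR (converted to iter-based): decay towards eta_min.
    progress = cur_iter / total_iters
    return ETA_MIN + 0.5 * (BASE_LR - ETA_MIN) * (1 + math.cos(math.pi * progress))


iters_per_epoch = 1000  # hypothetical value, for illustration only
print(approx_lr(0, iters_per_epoch))                      # ~0.00192 (0.008 * 0.24)
print(approx_lr(WARMUP_ITERS, iters_per_epoch))           # ~0.24 (warmup finished)
print(approx_lr(15 * iters_per_epoch, iters_per_epoch))   # ~1e-5 (end of training)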
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/__init__.py
@@ -11,10 +11,11 @@
from .pointnet2_sa_msg import PointNet2SAMSG
from .pointnet2_sa_ssg import PointNet2SASSG
from .second import SECOND
from .spvcnn_backone import SPVCNNBackbone

__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
'MinkUNetBackbone'
'MinkUNetBackbone', 'SPVCNNBackbone'
]
237 changes: 237 additions & 0 deletions mmdet3d/models/backbones/spvcnn_backone.py
@@ -0,0 +1,237 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

import torch
from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig
from .minkunet_backbone import MinkUNetBackbone

if IS_TORCHSPARSE_AVAILABLE:
import torchsparse
import torchsparse.nn.functional as F
from torchsparse.nn.utils import get_kernel_offsets
from torchsparse.tensor import PointTensor, SparseTensor
else:
PointTensor = SparseTensor = None


@MODELS.register_module()
class SPVCNNBackbone(MinkUNetBackbone):
"""SPVCNN backbone with torchsparse backend.

More details can be found in `paper <https://arxiv.org/abs/2007.16100>`_ .

Args:
in_channels (int): Number of input voxel feature channels.
Defaults to 4.
base_channels (int): The input channels for first encoder layer.
Defaults to 32.
encoder_channels (List[int]): Convolutional channels of each encode
layer. Defaults to [32, 64, 128, 256].
decoder_channels (List[int]): Convolutional channels of each decode
layer. Defaults to [256, 128, 96, 96].
num_stages (int): Number of stages in encoder and decoder.
Defaults to 4.
drop_ratio (float): Dropout ratio of voxel features. Defaults to 0.3.
init_cfg (dict or :obj:`ConfigDict` or list[dict or :obj:`ConfigDict`]
, optional): Initialization config dict. Defaults to None.
"""

def __init__(self,
in_channels: int = 4,
base_channels: int = 32,
encoder_channels: Sequence[int] = [32, 64, 128, 256],
decoder_channels: Sequence[int] = [256, 128, 96, 96],
num_stages: int = 4,
drop_ratio: float = 0.3,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(
in_channels=in_channels,
base_channels=base_channels,
encoder_channels=encoder_channels,
decoder_channels=decoder_channels,
num_stages=num_stages,
init_cfg=init_cfg)

self.point_transforms = nn.ModuleList([
nn.Sequential(
nn.Linear(base_channels, encoder_channels[-1]),
nn.BatchNorm1d(encoder_channels[-1]), nn.ReLU(True)),
nn.Sequential(
nn.Linear(encoder_channels[-1], decoder_channels[2]),
nn.BatchNorm1d(decoder_channels[2]), nn.ReLU(True)),
nn.Sequential(
nn.Linear(decoder_channels[2], decoder_channels[4]),
nn.BatchNorm1d(decoder_channels[4]), nn.ReLU(True))
])
self.dropout = nn.Dropout(drop_ratio, True)

def forward(self, voxel_features: Tensor, coors: Tensor) -> PointTensor:
"""Forward function.

Args:
voxel_features (Tensor): Voxel features in shape (N, C).
coors (Tensor): Coordinates in shape (N, 4),
the columns in the order of (x_idx, y_idx, z_idx, batch_idx).

Returns:
PointTensor: Backbone features.
"""
voxels = SparseTensor(voxel_features, coors)
points = PointTensor(voxels.F, voxels.C.float())
voxels = self.initial_voxelize(points)

voxels = self.conv_input(voxels)
points = self.voxel_to_point(voxels, points)
voxels = self.point_to_voxel(voxels, points)
laterals = [voxels]
for encoder in self.encoder:
voxels = encoder(voxels)
laterals.append(voxels)
laterals = laterals[:-1][::-1]

points = self.voxel_to_point(voxels, points, self.point_transforms[0])
voxels = self.point_to_voxel(voxels, points)
voxels.F = self.dropout(voxels.F)

decoder_outs = []
for i, decoder in enumerate(self.decoder):
voxels = decoder[0](voxels)
voxels = torchsparse.cat((voxels, laterals[i]))
voxels = decoder[1](voxels)
decoder_outs.append(voxels)
if i == 1:
points = self.voxel_to_point(voxels, points,
self.point_transforms[1])
voxels = self.point_to_voxel(voxels, points)
voxels.F = self.dropout(voxels.F)

points = self.voxel_to_point(voxels, points, self.point_transforms[2])
return points

def initial_voxelize(self, points: PointTensor) -> SparseTensor:
"""Voxelization again based on input PointTensor.

Args:
points (PointTensor): Input points after voxelization.

Returns:
SparseTensor: New voxels.
"""
pc_hash = F.sphash(torch.floor(points.C).int())
sparse_hash = torch.unique(pc_hash)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), len(sparse_hash))

inserted_coords = F.spvoxelize(
torch.floor(points.C), idx_query, counts)
inserted_coords = torch.round(inserted_coords).int()
inserted_feat = F.spvoxelize(points.F, idx_query, counts)

new_tensor = SparseTensor(inserted_feat, inserted_coords, 1)
new_tensor.cmaps.setdefault(new_tensor.stride, new_tensor.coords)
points.additional_features['idx_query'][1] = idx_query
points.additional_features['counts'][1] = counts
return new_tensor

def voxel_to_point(self,
voxels: SparseTensor,
points: PointTensor,
point_transform: Optional[nn.Module] = None,
nearest: bool = False) -> PointTensor:
"""Feed voxel features to points.

Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.
point_transform (nn.Module, optional): Point transform module
for input point features. Defaults to None.
nearest (bool): Whether to use nearest neighbor interpolation.
Defaults to False.

Returns:
PointTensor: Points with new features.
"""
if points.idx_query is None or points.weights is None or \
points.idx_query.get(voxels.s) is None or \
points.weights.get(voxels.s) is None:
offsets = get_kernel_offsets(
2, voxels.s, 1, device=points.F.device)
old_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
voxels.s[0], points.C[:, -1].int().view(-1, 1)
], 1), offsets)
pc_hash = F.sphash(voxels.C.to(points.F.device))
idx_query = F.sphashquery(old_hash, pc_hash)
weights = F.calc_ti_weights(
points.C, idx_query,
scale=voxels.s[0]).transpose(0, 1).contiguous()
idx_query = idx_query.transpose(0, 1).contiguous()
if nearest:
weights[:, 1:] = 0.
idx_query[:, 1:] = -1
new_features = F.spdevoxelize(voxels.F, idx_query, weights)
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features
new_tensor.idx_query[voxels.s] = idx_query
new_tensor.weights[voxels.s] = weights
points.idx_query[voxels.s] = idx_query
points.weights[voxels.s] = weights
else:
new_features = F.spdevoxelize(voxels.F,
points.idx_query.get(voxels.s),
points.weights.get(voxels.s))
new_tensor = PointTensor(
new_features,
points.C,
idx_query=points.idx_query,
weights=points.weights)
new_tensor.additional_features = points.additional_features

if point_transform is not None:
new_tensor.F = new_tensor.F + point_transform(points.F)

return new_tensor

def point_to_voxel(self, voxels: SparseTensor,
points: PointTensor) -> SparseTensor:
"""Feed point features to voxels.

Args:
voxels (SparseTensor): Input voxels.
points (PointTensor): Input points.

Returns:
SparseTensor: Voxels with new features.
"""
if points.additional_features is None or \
points.additional_features.get('idx_query') is None or \
points.additional_features['idx_query'].get(voxels.s) is None:
pc_hash = F.sphash(
torch.cat([
torch.floor(points.C[:, :3] / voxels.s[0]).int() *
voxels.s[0], points.C[:, -1].int().view(-1, 1)
], 1))
sparse_hash = F.sphash(voxels.C)
idx_query = F.sphashquery(pc_hash, sparse_hash)
counts = F.spcount(idx_query.int(), voxels.C.shape[0])
points.additional_features['idx_query'][voxels.s] = idx_query
points.additional_features['counts'][voxels.s] = counts
else:
idx_query = points.additional_features['idx_query'][voxels.s]
counts = points.additional_features['counts'][voxels.s]

inserted_features = F.spvoxelize(points.F, idx_query, counts)
new_tensor = SparseTensor(inserted_features, voxels.C, voxels.s)
new_tensor.cmaps = voxels.cmaps
new_tensor.kmaps = voxels.kmaps

return new_tensor
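
Aside (not part of the diff): the torchsparse calls above (F.sphash, F.sphashquery, F.spcount, F.spvoxelize) implement a scatter-style pooling of point features into voxels, while F.calc_ti_weights and F.spdevoxelize interpolate voxel features back to points with trilinear weights. A plain-PyTorch sketch of the first half, purely to illustrate the idea; naive_point_to_voxel is a hypothetical helper, not the actual implementation:

import torch


def naive_point_to_voxel(point_feats, point_coords):
    """Average point features over points sharing an (x, y, z, batch) coordinate."""
    unique_coords, idx_query = torch.unique(
        point_coords, dim=0, return_inverse=True)
    counts = torch.bincount(idx_query, minlength=unique_coords.size(0))
    voxel_feats = point_feats.new_zeros(
        (unique_coords.size(0), point_feats.size(1)))
    voxel_feats.index_add_(0, idx_query, point_feats)
    voxel_feats /= counts.clamp(min=1).unsqueeze(1).to(voxel_feats.dtype)
    return voxel_feats, unique_coords


feats = torch.rand(6, 4)
coords = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0],
                       [1, 0, 0, 0], [2, 2, 2, 0], [0, 0, 0, 1]])
voxel_feats, voxel_coords = naive_point_to_voxel(feats, coords)
print(voxel_feats.shape, voxel_coords.shape)  # 4 unique voxels: (4, 4) and (4, 4)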
34 changes: 34 additions & 0 deletions tests/test_models/test_backbones/test_spvcnn_backbone.py
@@ -0,0 +1,34 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch
import torch.nn.functional as F

from mmdet3d.registry import MODELS


def test_spvcnn_backbone():
if not torch.cuda.is_available():
pytest.skip('test requires GPU and torch+cuda')

try:
import torchsparse # noqa: F401
except ImportError:
pytest.skip('test requires Torchsparse installation')

coordinates, features = [], []
for i in range(2):
c = torch.randint(0, 10, (100, 3)).int()
c = F.pad(c, (0, 1), mode='constant', value=i)
coordinates.append(c)
f = torch.rand(100, 4)
features.append(f)
features = torch.cat(features, dim=0).cuda()
coordinates = torch.cat(coordinates, dim=0).cuda()

cfg = dict(type='SPVCNNBackbone')
model = MODELS.build(cfg).cuda()
model.init_weights()

y = model(features, coordinates)
assert y.F.shape == torch.Size([200, 96])
assert y.C.shape == torch.Size([200, 4])
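
Note on the expected shapes: the backbone returns per-point features (a PointTensor), so the 2 x 100 input points yield a (200, 96) feature tensor, where 96 is the default last decoder channel width, and the coordinates keep their (N, 4) layout.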