[Feature] Add MinkUNet segmentor #2294

Merged: 74 commits, Mar 23, 2023
Commits (74)
b275217
add cylindrical voxelization & voxel feature encoder
sunjiahao1999 Jan 18, 2023
8c08f98
add cylindrical voxelization & voxel feature encoder
sunjiahao1999 Jan 18, 2023
ea4116b
add voxel-wise label & voxelization UT
sunjiahao1999 Feb 3, 2023
532d0f9
fix vfe
sunjiahao1999 Feb 3, 2023
ca3d097
fix vfe UT
sunjiahao1999 Feb 3, 2023
85416ce
rename voxel encoder & add more test case
sunjiahao1999 Feb 7, 2023
460a002
Merge branch 'dev-1.x' into cylinder_voxel
sunjiahao1999 Feb 7, 2023
a693b8b
fix type hint
sunjiahao1999 Feb 7, 2023
56b554d
temporarily refactoring mmcv's voxelize and dynamic in mmdet3d for da…
sunjiahao1999 Feb 7, 2023
953085d
_forward
sunjiahao1999 Feb 14, 2023
aee48b9
del checkpoints
sunjiahao1999 Feb 14, 2023
d5e835e
add if tp
sunjiahao1999 Feb 14, 2023
d4c96b2
add predict
sunjiahao1999 Feb 16, 2023
20a6eb4
fix vfe init bug & fix UT
sunjiahao1999 Feb 16, 2023
f0b7599
add grid_size & move voxelization code
sunjiahao1999 Feb 16, 2023
f8a310d
fix import bug
sunjiahao1999 Feb 16, 2023
42be1af
keep radian to follow origin
sunjiahao1999 Feb 16, 2023
9b5ac39
add doc string
sunjiahao1999 Feb 17, 2023
7715601
fix type hint
sunjiahao1999 Feb 17, 2023
546ddac
Merge branch 'cylinder_voxel' into minkunet
sunjiahao1999 Feb 17, 2023
def76bd
add minkunet voxelization and loss function
sunjiahao1999 Feb 17, 2023
c54ef5d
fix data
sunjiahao1999 Feb 17, 2023
6539d04
Merge branch 'dev-1.x' into minkunet
sunjiahao1999 Feb 17, 2023
8edd1c5
Merge branch 'dev-1.x' into minkunet 2
sunjiahao1999 Feb 17, 2023
cc0ea24
init train
sunjiahao1999 Feb 19, 2023
4865bc0
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
9ceb10c
fix sparsetensor typehint
sunjiahao1999 Feb 19, 2023
af36c48
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
84d4e99
rename dir
sunjiahao1999 Feb 19, 2023
84beaf2
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
3d1967f
fix data config
sunjiahao1999 Feb 19, 2023
abbb30b
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
7608b07
fix data config
sunjiahao1999 Feb 19, 2023
6f0116f
Merge branch 'minkunet_train' of https://github.com/sunjiahao1999/mmd…
sunjiahao1999 Feb 19, 2023
64546c9
fix batch_size & replace dynamic_scatter
sunjiahao1999 Feb 21, 2023
f5bdf1e
fix confilcts
sunjiahao1999 Feb 21, 2023
1896b30
fix conflicts 2
sunjiahao1999 Feb 21, 2023
d787a39
fix conflicts on s_70
sunjiahao1999 Feb 21, 2023
8db57c2
Alignment of the original implementation
sunjiahao1999 Feb 27, 2023
0e53760
rename config
sunjiahao1999 Feb 27, 2023
bd73a54
add worker_init_fn_hook
sunjiahao1999 Feb 27, 2023
2e002fd
remove test_config & worker hook
sunjiahao1999 Feb 27, 2023
df91ff7
Merge branch 'dev-1.x' into minkunet
sunjiahao1999 Feb 27, 2023
883586a
add UT
sunjiahao1999 Feb 27, 2023
8a629f5
fix polarmix UT
sunjiahao1999 Feb 27, 2023
10bdefe
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Feb 28, 2023
55bffc5
add seed for cr0p5
sunjiahao1999 Mar 1, 2023
3317d67
Merge branch 'minkunet_train' into minkunet
sunjiahao1999 Mar 1, 2023
e66e5a7
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 1, 2023
06ad676
format
sunjiahao1999 Mar 1, 2023
5223d96
rename SemanticKittiDataset
sunjiahao1999 Mar 2, 2023
372ecba
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 6, 2023
44d1e6f
merge from dev-1.x
sunjiahao1999 Mar 6, 2023
9f6fc1c
add platte & fix visual bug
sunjiahao1999 Mar 6, 2023
7ccd1f9
add platte & fix data info bug
sunjiahao1999 Mar 7, 2023
93461ee
fix ut
sunjiahao1999 Mar 7, 2023
40fcdda
fix semantic_kitti ut
sunjiahao1999 Mar 7, 2023
221798f
fix docstring
sunjiahao1999 Mar 8, 2023
3652f2d
fix config name
sunjiahao1999 Mar 8, 2023
e78b860
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 8, 2023
b7cca29
merge from dev-1.x
sunjiahao1999 Mar 8, 2023
5506766
merge format_semantickitti & rename segmentor ut
sunjiahao1999 Mar 8, 2023
5d787e0
rename layer
sunjiahao1999 Mar 8, 2023
c45b268
fix doc string
sunjiahao1999 Mar 8, 2023
08d9df2
fix review
sunjiahao1999 Mar 15, 2023
4827aca
fix conflicts
sunjiahao1999 Mar 15, 2023
57ab8bd
remove filter data
sunjiahao1999 Mar 15, 2023
6658913
fix coors typo
sunjiahao1999 Mar 16, 2023
9d9b13f
fix ut
sunjiahao1999 Mar 16, 2023
dbb2e68
Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into de…
sunjiahao1999 Mar 20, 2023
7d9b05f
merge dev-1.x
sunjiahao1999 Mar 20, 2023
f7a44a1
pred in segmentor
sunjiahao1999 Mar 20, 2023
c2b57a7
fix get voxel seg
sunjiahao1999 Mar 22, 2023
afb9ecb
resolve comments
sunjiahao1999 Mar 22, 2023
29 changes: 29 additions & 0 deletions configs/_base_/models/minkunet.py
@@ -0,0 +1,29 @@
model = dict(
type='MinkUNet',
data_preprocessor=dict(
type='Det3DDataPreprocessor',
voxel=True,
voxel_type='minkunet',
voxel_layer=dict(
max_num_points=-1,
point_cloud_range=[-100, -100, -20, 100, 100, 20],
voxel_size=[0.05, 0.05, 0.05],
max_voxels=(-1, -1)),
),
backbone=dict(
type='MinkUNetBackbone',
in_channels=4,
base_channels=32,
encoder_channels=[32, 64, 128, 256],
decoder_channels=[256, 128, 96, 96],
num_stages=4,
init_cfg=None),
decode_head=dict(
type='MinkUNetHead',
channels=96,
num_classes=19,
dropout_ratio=0,
loss_decode=dict(type='mmdet.CrossEntropyLoss', avg_non_ignore=True),
ignore_index=19),
train_cfg=dict(),
test_cfg=dict())
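
For orientation: a base model file like the one above is consumed through the MMEngine registry rather than imported directly. The snippet below is a minimal sketch of building the model from this config; it is illustrative only (not part of the diff) and assumes the mmdet3d 1.x API with TorchSparse installed.

from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

# Register mmdet3d modules so the registry knows the 'MinkUNet' type,
# then build the segmentor from the base config shown above.
register_all_modules()
cfg = Config.fromfile('configs/_base_/models/minkunet.py')
model = MODELS.build(cfg.model)
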
13 changes: 13 additions & 0 deletions configs/minkunet/minkunet_w16_8xb2-15e_semantickitti.py
@@ -0,0 +1,13 @@
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']

model = dict(
backbone=dict(
base_channels=16,
encoder_channels=[16, 32, 64, 128],
decoder_channels=[128, 64, 48, 48]),
decode_head=dict(channels=48))

# NOTE: With the TorchSparse backend, model performance is somewhat sensitive
# to the random seed; if no seed is specified, results may vary by about
# ±1.5 mIoU.
randomness = dict(seed=1588147245)
8 changes: 8 additions & 0 deletions configs/minkunet/minkunet_w20_8xb2-15e_semantickitti.py
@@ -0,0 +1,8 @@
_base_ = ['./minkunet_w32_8xb2-15e_semantickitti.py']

model = dict(
backbone=dict(
base_channels=20,
encoder_channels=[20, 40, 81, 163],
decoder_channels=[163, 81, 61, 61]),
decode_head=dict(channels=61))
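
A side note on the w16/w20 variants: their channel widths appear to follow the SPVNAS convention of scaling the w32 widths by a channel ratio (0.5 for w16, roughly 0.64 for w20) and truncating to an integer. This is an observation about the numbers, not something stated in the PR; the quick check below uses a hypothetical cr variable that is not a config field.

# Hypothetical channel-ratio check (illustrative only).
w32_encoder = [32, 64, 128, 256]
w32_decoder = [256, 128, 96, 96]
cr = 0.64
print([int(c * cr) for c in w32_encoder])  # [20, 40, 81, 163], matches the w20 encoder_channels
print([int(c * cr) for c in w32_decoder])  # [163, 81, 61, 61], matches the w20 decoder_channels
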
54 changes: 54 additions & 0 deletions configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
@@ -0,0 +1,54 @@
_base_ = [
'../_base_/datasets/semantickitti.py', '../_base_/models/minkunet.py',
'../_base_/default_runtime.py'
]

train_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_seg_3d=True,
seg_3d_dtype='np.int32',
seg_offset=2**16,
dataset_type='semantickitti'),
dict(type='PointSegClassMapping'),
dict(
type='GlobalRotScaleTrans',
rot_range=[0., 6.28318531],
scale_ratio_range=[0.95, 1.05],
translation_std=[0, 0, 0],
),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))

lr = 0.24
optim_wrapper = dict(
type='AmpOptimWrapper',
loss_scale='dynamic',
optimizer=dict(
type='SGD', lr=lr, weight_decay=0.0001, momentum=0.9, nesterov=True))

param_scheduler = [
dict(
type='LinearLR', start_factor=0.008, by_epoch=False, begin=0, end=125),
dict(
type='CosineAnnealingLR',
begin=0,
T_max=15,
by_epoch=True,
eta_min=1e-5,
convert_to_iter_based=True)
]

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=15, val_interval=1)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1))
randomness = dict(seed=0, deterministic=False, diff_rank_seed=True)
env_cfg = dict(cudnn_benchmark=True)
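
For completeness, a config like the one above is normally launched with the standard training entry point (for example tools/train.py on 8 GPUs, matching the 8xb2 tag). The sketch below shows the equivalent single-process MMEngine usage for quick debugging; it is illustrative, and the work_dir value is a placeholder.

from mmengine.config import Config
from mmengine.runner import Runner

from mmdet3d.utils import register_all_modules

register_all_modules()
cfg = Config.fromfile('configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py')
cfg.work_dir = './work_dirs/minkunet_w32'  # placeholder output directory
runner = Runner.from_cfg(cfg)
runner.train()
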
4 changes: 3 additions & 1 deletion mmdet3d/models/backbones/__init__.py
@@ -5,6 +5,7 @@
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .mink_resnet import MinkResNet
from .minkunet_backbone import MinkUNetBackbone
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
@@ -14,5 +15,6 @@
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv'
'MultiBackbone', 'DLANet', 'MinkResNet', 'Asymm3DSpconv',
'MinkUNetBackbone'
]
121 changes: 121 additions & 0 deletions mmdet3d/models/backbones/minkunet_backbone.py
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

from mmengine.model import BaseModule
from mmengine.registry import MODELS
from torch import Tensor, nn

from mmdet3d.models.layers import (TorchSparseConvModule,
TorchSparseResidualBlock)
from mmdet3d.models.layers.torchsparse import IS_TORCHSPARSE_AVAILABLE
from mmdet3d.utils import OptMultiConfig

if IS_TORCHSPARSE_AVAILABLE:
import torchsparse
from torchsparse.tensor import SparseTensor
else:
SparseTensor = None


@MODELS.register_module()
class MinkUNetBackbone(BaseModule):
r"""MinkUNet backbone with TorchSparse backend.

Refer to `implementation code <https://github.com/mit-han-lab/spvnas>`_.

Args:
in_channels (int): Number of input voxel feature channels.
Defaults to 4.
base_channels (int): Number of input channels for the first encoder
layer. Defaults to 32.
encoder_channels (List[int]): Convolutional channels of each encoder
layer. Defaults to [32, 64, 128, 256].
decoder_channels (List[int]): Convolutional channels of each decoder
layer. Defaults to [256, 128, 96, 96].
num_stages (int): Number of stages in encoder and decoder.
Defaults to 4.
init_cfg (dict or :obj:`ConfigDict` or List[dict or :obj:`ConfigDict`]
, optional): Initialization config dict.
"""

def __init__(self,
in_channels: int = 4,
base_channels: int = 32,
encoder_channels: List[int] = [32, 64, 128, 256],
decoder_channels: List[int] = [256, 128, 96, 96],
num_stages: int = 4,
init_cfg: OptMultiConfig = None) -> None:
super().__init__(init_cfg)
assert num_stages == len(encoder_channels) == len(decoder_channels)
self.num_stages = num_stages
self.conv_input = nn.Sequential(
TorchSparseConvModule(in_channels, base_channels, kernel_size=3),
TorchSparseConvModule(base_channels, base_channels, kernel_size=3))
self.encoder = nn.ModuleList()
self.decoder = nn.ModuleList()

encoder_channels.insert(0, base_channels)
decoder_channels.insert(0, encoder_channels[-1])
for i in range(num_stages):
self.encoder.append(
nn.Sequential(
TorchSparseConvModule(
encoder_channels[i],
encoder_channels[i],
kernel_size=2,
stride=2),
TorchSparseResidualBlock(
encoder_channels[i],
encoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
encoder_channels[i + 1],
encoder_channels[i + 1],
kernel_size=3)))

self.decoder.append(
nn.ModuleList([
TorchSparseConvModule(
decoder_channels[i],
decoder_channels[i + 1],
kernel_size=2,
stride=2,
transposed=True),
nn.Sequential(
TorchSparseResidualBlock(
decoder_channels[i + 1] + encoder_channels[-2 - i],
decoder_channels[i + 1],
kernel_size=3),
TorchSparseResidualBlock(
decoder_channels[i + 1],
decoder_channels[i + 1],
kernel_size=3))
]))

def forward(self, voxel_features: Tensor, coors: Tensor) -> SparseTensor:
"""Forward function.

Args:
voxel_features (Tensor): Voxel features in shape (N, C).
coors (Tensor): Coordinates in shape (N, 4), with columns in the
order (x_idx, y_idx, z_idx, batch_idx).

Returns:
SparseTensor: Backbone features.
"""
x = torchsparse.SparseTensor(voxel_features, coors)
x = self.conv_input(x)
laterals = [x]
for encoder_layer in self.encoder:
x = encoder_layer(x)
laterals.append(x)
laterals = laterals[:-1][::-1]

decoder_outs = []
for i, decoder_layer in enumerate(self.decoder):
x = decoder_layer[0](x)
x = torchsparse.cat((x, laterals[i]))
x = decoder_layer[1](x)
decoder_outs.append(x)

return decoder_outs[-1]
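
As a quick sanity check of the forward() interface documented above, the backbone can be fed random voxel features and integer coordinates. The sketch below assumes TorchSparse and a CUDA device are available; the shapes and values are illustrative only.

import torch

from mmdet3d.models.backbones import MinkUNetBackbone

# Build a small batch of unique voxel coordinates (x_idx, y_idx, z_idx) plus a batch index column.
coors = torch.unique(torch.randint(0, 50, (100, 3)).int(), dim=0)
coors = torch.cat([coors, coors.new_zeros(coors.shape[0], 1)], dim=1).cuda()
voxel_features = torch.rand(coors.shape[0], 4).cuda()

backbone = MinkUNetBackbone().cuda()
out = backbone(voxel_features, coors)
print(out.F.shape)  # SparseTensor features with 96 output channels
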
77 changes: 77 additions & 0 deletions mmdet3d/models/data_preprocessors/data_preprocessor.py
@@ -415,6 +415,33 @@ def voxelize(self, points: List[torch.Tensor],
coors.append(res_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)
elif self.voxel_type == 'minkunet':
voxels, coors = [], []
voxel_size = points[0].new_tensor(self.voxel_layer.voxel_size)
for i, (res, data_sample) in enumerate(zip(points, data_samples)):
res_coors = torch.round(res[:, :3] / voxel_size).int()
res_coors -= res_coors.min(0)[0]

res_coors_numpy = res_coors.cpu().numpy()
inds, voxel2point_map = self.sparse_quantize(
res_coors_numpy, return_index=True, return_inverse=True)
voxel2point_map = torch.from_numpy(voxel2point_map).cuda()
if self.training:
if len(inds) > 80000:
inds = np.random.choice(inds, 80000, replace=False)
inds = torch.from_numpy(inds).cuda()
data_sample.gt_pts_seg.voxel_semantic_mask \
= data_sample.gt_pts_seg.pts_semantic_mask[inds]
res_voxel_coors = res_coors[inds]
res_voxels = res[inds]
res_voxel_coors = F.pad(
res_voxel_coors, (0, 1), mode='constant', value=i)
data_sample.voxel2point_map = voxel2point_map.long()
voxels.append(res_voxels)
coors.append(res_voxel_coors)
voxels = torch.cat(voxels, dim=0)
coors = torch.cat(coors, dim=0)

else:
raise ValueError(f'Invalid voxelization type {self.voxel_type}')

@@ -445,3 +472,53 @@ def get_voxel_seg(self, res_coors: torch.Tensor, data_sample: SampleList):
_, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor,
res_coors, 'mean', True)
data_sample.gt_pts_seg.point2voxel_map = point2voxel_map

def ravel_hash(self, x: np.ndarray) -> np.ndarray:
"""Get voxel coordinates hash for np.unique().

Args:
x (np.ndarray): The voxel coordinates of points, Nx3.

Returns:
np.ndarray: Voxel coordinate hashes.
"""
assert x.ndim == 2, x.shape

x = x - np.min(x, axis=0)
x = x.astype(np.uint64, copy=False)
xmax = np.max(x, axis=0).astype(np.uint64) + 1

h = np.zeros(x.shape[0], dtype=np.uint64)
for k in range(x.shape[1] - 1):
h += x[:, k]
h *= xmax[k + 1]
h += x[:, -1]
return h

def sparse_quantize(self,
coords: np.ndarray,
return_index: bool = False,
return_inverse: bool = False) -> List[np.ndarray]:
"""Sparse Quantization for voxel coordinates used in Minkunet.

Args:
coords (np.ndarray): The voxel coordinates of points, Nx3.
return_index (bool): Whether to return the indices of the unique
coordinates, with shape (M,).
return_inverse (bool): Whether to return the inverse indices mapping
each original coordinate to its unique coordinate, with shape (N,).

Returns:
List[np.ndarray]: Return the indices of the unique coordinates and/or
the inverse map, depending on whether return_index and
return_inverse are True.
"""
_, indices, inverse_indices = np.unique(
self.ravel_hash(coords), return_index=True, return_inverse=True)
coords = coords[indices]

outputs = []
if return_index:
outputs += [indices]
if return_inverse:
outputs += [inverse_indices]
return outputs
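
To make the hash-then-np.unique idea concrete, here is a small self-contained NumPy example of the same behaviour (it does not call the methods above; the coordinates are made up).

import numpy as np

# Five points falling into three voxels: hash each integer coordinate row into a
# single uint64 so np.unique can pick one representative point per voxel.
coords = np.array([[0, 0, 0], [0, 0, 0], [1, 2, 3], [1, 2, 3], [4, 5, 6]])
coords = (coords - coords.min(0)).astype(np.uint64)
xmax = coords.max(0) + 1

h = np.zeros(coords.shape[0], dtype=np.uint64)
for k in range(coords.shape[1] - 1):
    h += coords[:, k]
    h *= xmax[k + 1]
h += coords[:, -1]

_, indices, inverse = np.unique(h, return_index=True, return_inverse=True)
print(indices)  # [0 2 4]     -> index of the first point in each occupied voxel
print(inverse)  # [0 0 1 1 2] -> voxel id assigned to every original point
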
6 changes: 5 additions & 1 deletion mmdet3d/models/decode_heads/__init__.py
@@ -1,7 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cylinder3d_head import Cylinder3DHead
from .dgcnn_head import DGCNNHead
from .minkunet_head import MinkUNetHead
from .paconv_head import PAConvHead
from .pointnet2_head import PointNet2Head

__all__ = ['PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead']
__all__ = [
'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead',
'MinkUNetHead'
]