Skip to content

Commit

Permalink
[Feature] Support ImVoxelNet on KITTI (#627)
Browse files Browse the repository at this point in the history
* first try of kitti

* fix python3

* imvoxelnet is ready for open-mmlab/mmdetection3d

* apply pre-commit

* update to ConvModule and AlignedAnchor3DGenerator

* add unit tests

* fix torch.Tensor in docstrings

* revert anchor ranges
  • Loading branch information
filaPro authored Jun 16, 2021
1 parent b814e7f commit c1f6bba
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 2 deletions.
160 changes: 160 additions & 0 deletions configs/imvoxelnet/imvoxelnet_kitti-3d-car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
model = dict(
type='ImVoxelNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=64,
num_outs=4),
neck_3d=dict(type='OutdoorImVoxelNeck', in_channels=64, out_channels=256),
bbox_head=dict(
type='Anchor3DHead',
num_classes=1,
in_channels=256,
feat_channels=256,
use_direction_classifier=True,
anchor_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-0.16, -39.68, -1.78, 68.96, 39.68, -1.78]],
sizes=[[1.6, 3.9, 1.56]],
rotations=[0, 1.57],
reshape_out=True),
diff_rad_by_sin=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
n_voxels=[216, 248, 12],
anchor_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-0.16, -39.68, -3.08, 68.96, 39.68, 0.76]],
rotations=[.0]),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1),
allowed_border=0,
pos_weight=-1,
debug=False),
test_cfg=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_thr=0.01,
score_thr=0.1,
min_bbox_size=0,
nms_pre=100,
max_num=50))

dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Car']
input_modality = dict(use_lidar=False, use_camera=True)
point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(type='LoadImageFromFile'),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='Resize',
img_scale=[(1173, 352), (1387, 416)],
keep_ratio=True,
multiscale_mode='range'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', img_scale=(1280, 384), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img'])
]

data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_train.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=train_pipeline,
modality=input_modality,
classes=class_names,
test_mode=False)),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
test_mode=True),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
test_mode=True))

optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
lr_config = dict(policy='step', step=[8, 11])
total_epochs = 12

checkpoint_config = dict(interval=1, max_keep_ckpts=1)
log_config = dict(
interval=50,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
evaluation = dict(interval=1)
dist_params = dict(backend='nccl')
find_unused_parameters = True # only 1 of 4 FPN outputs is used
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
3 changes: 2 additions & 1 deletion mmdet3d/models/detectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .fcos_mono3d import FCOSMono3D
from .h3dnet import H3DNet
from .imvotenet import ImVoteNet
from .imvoxelnet import ImVoxelNet
from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
Expand All @@ -16,5 +17,5 @@
'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D'
'FCOSMono3D', 'ImVoxelNet'
]
149 changes: 149 additions & 0 deletions mmdet3d/models/detectors/imvoxelnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
import torch

from mmdet3d.core import bbox3d2result, build_anchor_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector


@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_."""

def __init__(self,
backbone,
neck,
neck_3d,
bbox_head,
n_voxels,
anchor_generator,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__()
self.backbone = build_backbone(backbone)
self.neck = build_neck(neck)
self.neck_3d = build_neck(neck_3d)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.n_voxels = n_voxels
self.anchor_generator = build_anchor_generator(anchor_generator)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)

def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super().init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.neck.init_weights()
self.neck_3d.init_weights()
self.bbox_head.init_weights()

def extract_feat(self, img, img_metas):
"""Extract 3d features from the backbone -> fpn -> 3d projection.
Args:
img (torch.Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
Returns:
torch.Tensor: of shape (N, C_out, N_x, N_y, N_z)
"""
x = self.backbone(img)
x = self.neck(x)[0]
points = self.anchor_generator.grid_anchors(
[self.n_voxels[::-1]], device=img.device)[0][:, :3]
volumes = []
for feature, img_meta in zip(x, img_metas):
img_scale_factor = (
points.new_tensor(img_meta['scale_factor'][:2])
if 'scale_factor' in img_meta.keys() else 1)
img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
img_crop_offset = (
points.new_tensor(img_meta['img_crop_offset'])
if 'img_crop_offset' in img_meta.keys() else 0)
volume = point_sample(
img_meta,
img_features=feature[None, ...],
points=points,
lidar2img_rt=points.new_tensor(img_meta['lidar2img']),
img_scale_factor=img_scale_factor,
img_crop_offset=img_crop_offset,
img_flip=img_flip,
img_pad_shape=img.shape[-2:],
img_shape=img_meta['img_shape'][:2],
aligned=False)
volumes.append(
volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
x = torch.stack(volumes)
x = self.neck_3d(x)
return x

def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
**kwargs):
"""Forward of training.
Args:
img (torch.Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch.
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch.
Returns:
dict[str, torch.Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img, img_metas)
x = self.bbox_head(x)
losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas)
return losses

def forward_test(self, img, img_metas, **kwargs):
"""Forward of testing.
Args:
img (torch.Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
Returns:
list[dict]: Predicted 3d boxes.
"""
# not supporting aug_test for now
return self.simple_test(img, img_metas)

def simple_test(self, img, img_metas):
"""Test without augmentations.
Args:
img (torch.Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
Returns:
list[dict]: Predicted 3d boxes.
"""
x = self.extract_feat(img, img_metas)
x = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
bbox_results = [
bbox3d2result(det_bboxes, det_scores, det_labels)
for det_bboxes, det_scores, det_labels in bbox_list
]
return bbox_results

def aug_test(self, imgs, img_metas, **kwargs):
"""Test with augmentations.
Args:
imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
Returns:
list[dict]: Predicted 3d boxes.
"""
raise NotImplementedError
3 changes: 2 additions & 1 deletion mmdet3d/models/necks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from mmdet.models.necks.fpn import FPN
from .imvoxel_neck import OutdoorImVoxelNeck
from .second_fpn import SECONDFPN

__all__ = ['FPN', 'SECONDFPN']
__all__ = ['FPN', 'SECONDFPN', 'OutdoorImVoxelNeck']
Loading

0 comments on commit c1f6bba

Please sign in to comment.