
[Feature] Support ImVoxelNet on KITTI #627

Merged · 10 commits · Jun 16, 2021
160 changes: 160 additions & 0 deletions configs/imvoxelnet/imvoxelnet_kitti-3d-car.py
@@ -0,0 +1,160 @@
model = dict(
type='ImVoxelNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
style='pytorch'),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=64,
num_outs=4),
neck_3d=dict(type='OutdoorImVoxelNeck', in_channels=64, out_channels=256),
bbox_head=dict(
type='Anchor3DHead',
num_classes=1,
in_channels=256,
feat_channels=256,
use_direction_classifier=True,
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[[0, -39.68, -1.78, 69.12 - .32, 39.68 - .32, -1.78]],
sizes=[[1.6, 3.9, 1.56]],
rotations=[0, 1.57],
reshape_out=True),
diff_rad_by_sin=True,
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
n_voxels=[216, 248, 12],
anchor_generator=dict(
type='Anchor3DRangeGenerator',
ranges=[[0, -39.68, -2.92, 69.12 - .32, 39.68 - .32, 0.92 - .32]],
rotations=[.0]),
train_cfg=dict(
assigner=dict(
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.45,
min_pos_iou=0.45,
ignore_iof_thr=-1),
allowed_border=0,
pos_weight=-1,
debug=False),
test_cfg=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_thr=0.01,
score_thr=0.1,
min_bbox_size=0,
nms_pre=100,
max_num=50))

dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Car']
input_modality = dict(use_lidar=False, use_camera=True)
point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
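# Voxel size works out to 0.32 m per axis: e.g. (69.12 - 0.32) / (216 - 1)
# = 0.32 with model.n_voxels=[216, 248, 12]; the `- .32` offsets in the
# anchor ranges above keep the generated voxel centers on this uniform grid.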
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(type='LoadImageFromFile'),
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
dict(
type='Resize',
img_scale=[(1173, 352), (1387, 416)],
keep_ratio=True,
multiscale_mode='range'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', img_scale=(1280, 384), keep_ratio=True),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=32),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['img'])
]

data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
train=dict(
type='RepeatDataset',
times=3,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_train.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=train_pipeline,
modality=input_modality,
classes=class_names,
test_mode=False)),
val=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
test_mode=True),
test=dict(
type=dataset_type,
data_root=data_root,
ann_file=data_root + 'kitti_infos_val.pkl',
split='training',
pts_prefix='velodyne_reduced',
pipeline=test_pipeline,
modality=input_modality,
classes=class_names,
test_mode=True))

optimizer = dict(
type='AdamW',
lr=0.0001,
weight_decay=0.0001,
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2))
lr_config = dict(policy='step', step=[8, 11])
total_epochs = 12

checkpoint_config = dict(interval=1, max_keep_ckpts=1)
log_config = dict(
interval=50,
hooks=[dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')])
evaluation = dict(interval=1)
dist_params = dict(backend='nccl')
find_unused_parameters = True  # only 1 of 4 FPN outputs is used
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
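A minimal loading sketch (mine, not part of this diff; it assumes the standard mmcv `Config` API and mmdet3d's usual `tools/train.py` entry point):

# Hedged sketch: the config is consumed by the standard entry points, e.g.
#   python tools/train.py configs/imvoxelnet/imvoxelnet_kitti-3d-car.py
from mmcv import Config

cfg = Config.fromfile('configs/imvoxelnet/imvoxelnet_kitti-3d-car.py')
assert cfg.model.type == 'ImVoxelNet'
assert cfg.data.train.type == 'RepeatDataset'
print(cfg.model.n_voxels, cfg.total_epochs)  # [216, 248, 12] 12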
3 changes: 2 additions & 1 deletion mmdet3d/models/detectors/__init__.py
@@ -4,6 +4,7 @@
from .fcos_mono3d import FCOSMono3D
from .h3dnet import H3DNet
from .imvotenet import ImVoteNet
from .imvoxelnet import ImVoxelNet
from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
from .mvx_two_stage import MVXTwoStageDetector
from .parta2 import PartA2
@@ -16,5 +17,5 @@
'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
'FCOSMono3D'
'FCOSMono3D', 'ImVoxelNet'
]
149 changes: 149 additions & 0 deletions mmdet3d/models/detectors/imvoxelnet.py
@@ -0,0 +1,149 @@
import torch

from mmdet3d.core import bbox3d2result, build_anchor_generator
from mmdet3d.models.fusion_layers.point_fusion import point_sample
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck
from mmdet.models.detectors import BaseDetector


@DETECTORS.register_module()
class ImVoxelNet(BaseDetector):
"""ImVoxelNet <https://arxiv.org/abs/2106.01178>."""

def __init__(self,
backbone,
neck,
neck_3d,
bbox_head,
n_voxels,
anchor_generator,
train_cfg=None,
test_cfg=None,
pretrained=None):
super().__init__()
self.backbone = build_backbone(backbone)
self.neck = build_neck(neck)
self.neck_3d = build_neck(neck_3d)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.n_voxels = n_voxels
self.anchor_generator = build_anchor_generator(anchor_generator)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)

def init_weights(self, pretrained=None):
"""Initialize the weights in detector.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super().init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
self.neck.init_weights()
self.neck_3d.init_weights()
self.bbox_head.init_weights()

def extract_feat(self, img, img_metas):
"""Extract 3d features from the backbone -> fpn -> 3d projection.

Args:
img (Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.

Returns:
            torch.Tensor: 3D features of shape (N, C_out, N_x, N_y, N_z).
"""
x = self.backbone(img)
x = self.neck(x)[0]
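        # Enumerate voxel-center coordinates by reusing the 3D anchor
        # generator; only the (x, y, z) of each anchor is kept below.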
points = self.anchor_generator.grid_anchors(
[self.n_voxels[::-1]], device=img.device)[0][:, :3]
volumes = []
for feature, img_meta in zip(x, img_metas):
img_scale_factor = (
points.new_tensor(img_meta['scale_factor'][:2])
if 'scale_factor' in img_meta.keys() else 1)
img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False
img_crop_offset = (
points.new_tensor(img_meta['img_crop_offset'])
if 'img_crop_offset' in img_meta.keys() else 0)
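            # Project every voxel center into this image with the
            # lidar2img matrix and sample a per-voxel feature vector
            # from the 2D feature map.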
volume = point_sample(
img_meta,
img_features=feature[None, ...],
points=points,
lidar2img_rt=points.new_tensor(img_meta['lidar2img']),
img_scale_factor=img_scale_factor,
img_crop_offset=img_crop_offset,
img_flip=img_flip,
img_pad_shape=img.shape[-2:],
img_shape=img_meta['img_shape'][:2],
aligned=False)
volumes.append(
volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0))
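        # Stack per-image volumes into a (N, C, N_x, N_y, N_z) tensor
        # and refine it with the 3D neck.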
x = torch.stack(volumes)
x = self.neck_3d(x)
return x

def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d,
**kwargs):
"""Forward of training.

Args:
img (Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.
            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
                bboxes for each sample in the batch.
            gt_labels_3d (list[torch.Tensor]): Ground truth class labels
                for each sample in the batch.

Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
x = self.extract_feat(img, img_metas)
x = self.bbox_head(x)
losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas)
return losses

def forward_test(self, img, img_metas, **kwargs):
"""Forward of testing.

Args:
img (Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.

Returns:
list[dict]: Predicted 3d boxes.
"""
# not supporting aug_test for now
return self.simple_test(img, img_metas)

def simple_test(self, img, img_metas):
"""Test without augmentations.

Args:
img (Tensor): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.

Returns:
list[dict]: Predicted 3d boxes.
"""
x = self.extract_feat(img, img_metas)
x = self.bbox_head(x)
bbox_list = self.bbox_head.get_bboxes(*x, img_metas)
bbox_results = [
bbox3d2result(det_bboxes, det_scores, det_labels)
for det_bboxes, det_scores, det_labels in bbox_list
]
return bbox_results

def aug_test(self, imgs, img_metas, **kwargs):
"""Test with augmentations.

Args:
imgs (list[Tensor]): Input images of shape (N, C_in, H, W).
img_metas (list): Image metas.

Returns:
list[dict]: Predicted 3d boxes.
"""
raise NotImplementedError
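A hedged build sketch (not part of this diff); it assumes `mmdet3d.models.build_detector` of this release, with `train_cfg`/`test_cfg` nested inside `model` as in the config above:

from mmcv import Config
from mmdet3d.models import build_detector

cfg = Config.fromfile('configs/imvoxelnet/imvoxelnet_kitti-3d-car.py')
# Building also runs init_weights(), which fetches the ImageNet ResNet-50
# checkpoint named by `pretrained` in the config.
model = build_detector(cfg.model)
model.eval()
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M parameters')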
3 changes: 2 additions & 1 deletion mmdet3d/models/necks/__init__.py
@@ -1,4 +1,5 @@
from mmdet.models.necks.fpn import FPN
from .imvoxel_neck import OutdoorImVoxelNeck
from .second_fpn import SECONDFPN

__all__ = ['FPN', 'SECONDFPN']
__all__ = ['FPN', 'SECONDFPN', 'OutdoorImVoxelNeck']