-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support ImVoxelNet on KITTI (#627)
* first try of kitti * fix python3 * imvoxelnet is ready for open-mmlab/mmdetection3d * apply pre-commit * update to ConvModule and AlignedAnchor3DGenerator * add unit tests * fix torch.Tensor in docstrings * revert anchor ranges
- Loading branch information
Showing
7 changed files
with
473 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
model = dict( | ||
type='ImVoxelNet', | ||
pretrained='torchvision://resnet50', | ||
backbone=dict( | ||
type='ResNet', | ||
depth=50, | ||
num_stages=4, | ||
out_indices=(0, 1, 2, 3), | ||
frozen_stages=1, | ||
norm_cfg=dict(type='BN', requires_grad=False), | ||
norm_eval=True, | ||
style='pytorch'), | ||
neck=dict( | ||
type='FPN', | ||
in_channels=[256, 512, 1024, 2048], | ||
out_channels=64, | ||
num_outs=4), | ||
neck_3d=dict(type='OutdoorImVoxelNeck', in_channels=64, out_channels=256), | ||
bbox_head=dict( | ||
type='Anchor3DHead', | ||
num_classes=1, | ||
in_channels=256, | ||
feat_channels=256, | ||
use_direction_classifier=True, | ||
anchor_generator=dict( | ||
type='AlignedAnchor3DRangeGenerator', | ||
ranges=[[-0.16, -39.68, -1.78, 68.96, 39.68, -1.78]], | ||
sizes=[[1.6, 3.9, 1.56]], | ||
rotations=[0, 1.57], | ||
reshape_out=True), | ||
diff_rad_by_sin=True, | ||
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), | ||
loss_cls=dict( | ||
type='FocalLoss', | ||
use_sigmoid=True, | ||
gamma=2.0, | ||
alpha=0.25, | ||
loss_weight=1.0), | ||
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), | ||
loss_dir=dict( | ||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), | ||
n_voxels=[216, 248, 12], | ||
anchor_generator=dict( | ||
type='AlignedAnchor3DRangeGenerator', | ||
ranges=[[-0.16, -39.68, -3.08, 68.96, 39.68, 0.76]], | ||
rotations=[.0]), | ||
train_cfg=dict( | ||
assigner=dict( | ||
type='MaxIoUAssigner', | ||
iou_calculator=dict(type='BboxOverlapsNearest3D'), | ||
pos_iou_thr=0.6, | ||
neg_iou_thr=0.45, | ||
min_pos_iou=0.45, | ||
ignore_iof_thr=-1), | ||
allowed_border=0, | ||
pos_weight=-1, | ||
debug=False), | ||
test_cfg=dict( | ||
use_rotate_nms=True, | ||
nms_across_levels=False, | ||
nms_thr=0.01, | ||
score_thr=0.1, | ||
min_bbox_size=0, | ||
nms_pre=100, | ||
max_num=50)) | ||
|
||
dataset_type = 'KittiDataset' | ||
data_root = 'data/kitti/' | ||
class_names = ['Car'] | ||
input_modality = dict(use_lidar=False, use_camera=True) | ||
point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1] | ||
img_norm_cfg = dict( | ||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) | ||
|
||
train_pipeline = [ | ||
dict(type='LoadAnnotations3D'), | ||
dict(type='LoadImageFromFile'), | ||
dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), | ||
dict( | ||
type='Resize', | ||
img_scale=[(1173, 352), (1387, 416)], | ||
keep_ratio=True, | ||
multiscale_mode='range'), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), | ||
dict(type='DefaultFormatBundle3D', class_names=class_names), | ||
dict(type='Collect3D', keys=['img', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict(type='Resize', img_scale=(1280, 384), keep_ratio=True), | ||
dict(type='Normalize', **img_norm_cfg), | ||
dict(type='Pad', size_divisor=32), | ||
dict( | ||
type='DefaultFormatBundle3D', | ||
class_names=class_names, | ||
with_label=False), | ||
dict(type='Collect3D', keys=['img']) | ||
] | ||
|
||
data = dict( | ||
samples_per_gpu=4, | ||
workers_per_gpu=4, | ||
train=dict( | ||
type='RepeatDataset', | ||
times=3, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file=data_root + 'kitti_infos_train.pkl', | ||
split='training', | ||
pts_prefix='velodyne_reduced', | ||
pipeline=train_pipeline, | ||
modality=input_modality, | ||
classes=class_names, | ||
test_mode=False)), | ||
val=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file=data_root + 'kitti_infos_val.pkl', | ||
split='training', | ||
pts_prefix='velodyne_reduced', | ||
pipeline=test_pipeline, | ||
modality=input_modality, | ||
classes=class_names, | ||
test_mode=True), | ||
test=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file=data_root + 'kitti_infos_val.pkl', | ||
split='training', | ||
pts_prefix='velodyne_reduced', | ||
pipeline=test_pipeline, | ||
modality=input_modality, | ||
classes=class_names, | ||
test_mode=True)) | ||
|
||
optimizer = dict( | ||
type='AdamW', | ||
lr=0.0001, | ||
weight_decay=0.0001, | ||
paramwise_cfg=dict( | ||
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)})) | ||
optimizer_config = dict(grad_clip=dict(max_norm=35., norm_type=2)) | ||
lr_config = dict(policy='step', step=[8, 11]) | ||
total_epochs = 12 | ||
|
||
checkpoint_config = dict(interval=1, max_keep_ckpts=1) | ||
log_config = dict( | ||
interval=50, | ||
hooks=[dict(type='TextLoggerHook'), | ||
dict(type='TensorboardLoggerHook')]) | ||
evaluation = dict(interval=1) | ||
dist_params = dict(backend='nccl') | ||
find_unused_parameters = True # only 1 of 4 FPN outputs is used | ||
log_level = 'INFO' | ||
load_from = None | ||
resume_from = None | ||
workflow = [('train', 1)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import torch | ||
|
||
from mmdet3d.core import bbox3d2result, build_anchor_generator | ||
from mmdet3d.models.fusion_layers.point_fusion import point_sample | ||
from mmdet.models import DETECTORS, build_backbone, build_head, build_neck | ||
from mmdet.models.detectors import BaseDetector | ||
|
||
|
||
@DETECTORS.register_module() | ||
class ImVoxelNet(BaseDetector): | ||
r"""`ImVoxelNet <https://arxiv.org/abs/2106.01178>`_.""" | ||
|
||
def __init__(self, | ||
backbone, | ||
neck, | ||
neck_3d, | ||
bbox_head, | ||
n_voxels, | ||
anchor_generator, | ||
train_cfg=None, | ||
test_cfg=None, | ||
pretrained=None): | ||
super().__init__() | ||
self.backbone = build_backbone(backbone) | ||
self.neck = build_neck(neck) | ||
self.neck_3d = build_neck(neck_3d) | ||
bbox_head.update(train_cfg=train_cfg) | ||
bbox_head.update(test_cfg=test_cfg) | ||
self.bbox_head = build_head(bbox_head) | ||
self.n_voxels = n_voxels | ||
self.anchor_generator = build_anchor_generator(anchor_generator) | ||
self.train_cfg = train_cfg | ||
self.test_cfg = test_cfg | ||
self.init_weights(pretrained=pretrained) | ||
|
||
def init_weights(self, pretrained=None): | ||
"""Initialize the weights in detector. | ||
Args: | ||
pretrained (str, optional): Path to pre-trained weights. | ||
Defaults to None. | ||
""" | ||
super().init_weights(pretrained) | ||
self.backbone.init_weights(pretrained=pretrained) | ||
self.neck.init_weights() | ||
self.neck_3d.init_weights() | ||
self.bbox_head.init_weights() | ||
|
||
def extract_feat(self, img, img_metas): | ||
"""Extract 3d features from the backbone -> fpn -> 3d projection. | ||
Args: | ||
img (torch.Tensor): Input images of shape (N, C_in, H, W). | ||
img_metas (list): Image metas. | ||
Returns: | ||
torch.Tensor: of shape (N, C_out, N_x, N_y, N_z) | ||
""" | ||
x = self.backbone(img) | ||
x = self.neck(x)[0] | ||
points = self.anchor_generator.grid_anchors( | ||
[self.n_voxels[::-1]], device=img.device)[0][:, :3] | ||
volumes = [] | ||
for feature, img_meta in zip(x, img_metas): | ||
img_scale_factor = ( | ||
points.new_tensor(img_meta['scale_factor'][:2]) | ||
if 'scale_factor' in img_meta.keys() else 1) | ||
img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False | ||
img_crop_offset = ( | ||
points.new_tensor(img_meta['img_crop_offset']) | ||
if 'img_crop_offset' in img_meta.keys() else 0) | ||
volume = point_sample( | ||
img_meta, | ||
img_features=feature[None, ...], | ||
points=points, | ||
lidar2img_rt=points.new_tensor(img_meta['lidar2img']), | ||
img_scale_factor=img_scale_factor, | ||
img_crop_offset=img_crop_offset, | ||
img_flip=img_flip, | ||
img_pad_shape=img.shape[-2:], | ||
img_shape=img_meta['img_shape'][:2], | ||
aligned=False) | ||
volumes.append( | ||
volume.reshape(self.n_voxels[::-1] + [-1]).permute(3, 2, 1, 0)) | ||
x = torch.stack(volumes) | ||
x = self.neck_3d(x) | ||
return x | ||
|
||
def forward_train(self, img, img_metas, gt_bboxes_3d, gt_labels_3d, | ||
**kwargs): | ||
"""Forward of training. | ||
Args: | ||
img (torch.Tensor): Input images of shape (N, C_in, H, W). | ||
img_metas (list): Image metas. | ||
gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. | ||
gt_labels_3d (list[torch.Tensor]): gt class labels of each batch. | ||
Returns: | ||
dict[str, torch.Tensor]: A dictionary of loss components. | ||
""" | ||
x = self.extract_feat(img, img_metas) | ||
x = self.bbox_head(x) | ||
losses = self.bbox_head.loss(*x, gt_bboxes_3d, gt_labels_3d, img_metas) | ||
return losses | ||
|
||
def forward_test(self, img, img_metas, **kwargs): | ||
"""Forward of testing. | ||
Args: | ||
img (torch.Tensor): Input images of shape (N, C_in, H, W). | ||
img_metas (list): Image metas. | ||
Returns: | ||
list[dict]: Predicted 3d boxes. | ||
""" | ||
# not supporting aug_test for now | ||
return self.simple_test(img, img_metas) | ||
|
||
def simple_test(self, img, img_metas): | ||
"""Test without augmentations. | ||
Args: | ||
img (torch.Tensor): Input images of shape (N, C_in, H, W). | ||
img_metas (list): Image metas. | ||
Returns: | ||
list[dict]: Predicted 3d boxes. | ||
""" | ||
x = self.extract_feat(img, img_metas) | ||
x = self.bbox_head(x) | ||
bbox_list = self.bbox_head.get_bboxes(*x, img_metas) | ||
bbox_results = [ | ||
bbox3d2result(det_bboxes, det_scores, det_labels) | ||
for det_bboxes, det_scores, det_labels in bbox_list | ||
] | ||
return bbox_results | ||
|
||
def aug_test(self, imgs, img_metas, **kwargs): | ||
"""Test with augmentations. | ||
Args: | ||
imgs (list[torch.Tensor]): Input images of shape (N, C_in, H, W). | ||
img_metas (list): Image metas. | ||
Returns: | ||
list[dict]: Predicted 3d boxes. | ||
""" | ||
raise NotImplementedError |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from mmdet.models.necks.fpn import FPN | ||
from .imvoxel_neck import OutdoorImVoxelNeck | ||
from .second_fpn import SECONDFPN | ||
|
||
__all__ = ['FPN', 'SECONDFPN'] | ||
__all__ = ['FPN', 'SECONDFPN', 'OutdoorImVoxelNeck'] |
Oops, something went wrong.