diff --git a/configs/_base_/models/mvx-fpn_second.py b/configs/_base_/models/mvx-fpn_second.py new file mode 100644 index 0000000000..e3eb3b6aa6 --- /dev/null +++ b/configs/_base_/models/mvx-fpn_second.py @@ -0,0 +1,126 @@ +voxel_size = [0.05, 0.05, 0.1] +point_cloud_range = [0, -40, -3, 70.4, 40, 1] + +model = dict( + type='DynamicMVXFasterRCNN', + img_backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='caffe'), + img_neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=5), + pts_voxel_layer=dict( + max_num_points=-1, + point_cloud_range=point_cloud_range, + voxel_size=voxel_size, + max_voxels=(-1, -1)), + pts_voxel_encoder=dict( + type='DynamicVFE', + in_channels=4, + feat_channels=[64, 64], + with_distance=False, + voxel_size=voxel_size, + with_cluster_center=True, + with_voxel_center=True, + point_cloud_range=point_cloud_range, + fusion_layer=dict( + type='PointFusion', + img_channels=256, + pts_channels=64, + mid_channels=128, + out_channels=128, + img_levels=[0, 1, 2, 3, 4], + align_corners=False, + activate_out=True, + fuse_out=False)), + pts_middle_encoder=dict( + type='SparseEncoder', + in_channels=128, + sparse_shape=[41, 1600, 1408], + order=('conv', 'norm', 'act')), + pts_backbone=dict( + type='SECOND', + in_channels=256, + layer_nums=[5, 5], + layer_strides=[1, 2], + out_channels=[128, 256]), + pts_neck=dict( + type='SECONDFPN', + in_channels=[128, 256], + upsample_strides=[1, 2], + out_channels=[256, 256]), + pts_bbox_head=dict( + type='Anchor3DHead', + num_classes=3, + in_channels=512, + feat_channels=512, + use_direction_classifier=True, + anchor_generator=dict( + type='Anchor3DRangeGenerator', + ranges=[ + [0, -40.0, -0.6, 70.4, 40.0, -0.6], + [0, -40.0, -0.6, 70.4, 40.0, -0.6], + [0, -40.0, -1.78, 70.4, 40.0, -1.78], + ], + sizes=[[0.6, 0.8, 1.73], [0.6, 1.76, 1.73], [1.6, 3.9, 1.56]], + rotations=[0, 1.57], + reshape_out=False), + assigner_per_size=True, + diff_rad_by_sin=True, + assign_per_class=True, + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), + loss_dir=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), + # model training and testing settings + train_cfg=dict( + pts=dict( + assigner=[ + dict( # for Pedestrian + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.35, + neg_iou_thr=0.2, + min_pos_iou=0.2, + ignore_iof_thr=-1), + dict( # for Cyclist + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.35, + neg_iou_thr=0.2, + min_pos_iou=0.2, + ignore_iof_thr=-1), + dict( # for Car + type='MaxIoUAssigner', + iou_calculator=dict(type='BboxOverlapsNearest3D'), + pos_iou_thr=0.6, + neg_iou_thr=0.45, + min_pos_iou=0.45, + ignore_iof_thr=-1), + ], + allowed_border=0, + pos_weight=-1, + debug=False)), + test_cfg=dict( + pts=dict( + use_rotate_nms=True, + nms_across_levels=False, + nms_thr=0.01, + score_thr=0.1, + min_bbox_size=0, + nms_pre=100, + max_num=50))) \ No newline at end of file diff --git a/configs/moca/dv_mvx-fpn_second_secfpn_hybrid_moca_2x8_80e_kitti-3d-3class.py b/configs/moca/dv_mvx-fpn_second_secfpn_hybrid_moca_2x8_80e_kitti-3d-3class.py new file mode 100644 index 0000000000..b987f19d35 --- /dev/null +++ 
b/configs/moca/dv_mvx-fpn_second_secfpn_hybrid_moca_2x8_80e_kitti-3d-3class.py @@ -0,0 +1,165 @@ +_base_ = '../_base_/models/mvx-fpn_second.py' +# model settings +voxel_size = [0.05, 0.05, 0.1] +point_cloud_range = [0, -40, -3, 70.4, 40, 1] + +# dataset settings +dataset_type = 'KittiDataset' +data_root = 'data/kitti/' +class_names = ['Pedestrian', 'Cyclist', 'Car'] +img_norm_cfg = dict( + mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) + +db_sampler = dict( + type='MMDataBaseSampler', + data_root=data_root, + info_path=data_root + 'kitti_mm_dbinfos_train.pkl', + rate=1.0, + blending_type=['box', 'gaussian', 'poisson'], + depth_consistent=True, + check_2D_collision=True, + collision_thr=[0, 0.3, 0.5, 0.7], + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=10, Cyclist=10)), + classes=class_names, + sample_groups=dict(Car=12, Pedestrian=6, Cyclist=6)) + +input_modality = dict(use_lidar=True, use_camera=True) +train_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict(type='LoadImageFromFile'), + dict( + type='LoadAnnotations3D', + with_bbox_3d=True, + with_label_3d=True, + with_bbox=True, + with_label=True), + dict(type='ObjectSample', db_sampler=db_sampler, sample_2d=True), + dict( + type='Resize', + img_scale=[(640, 192), (2560, 768)], + multiscale_mode='range', + keep_ratio=True), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + scale_ratio_range=[0.95, 1.05], + translation_std=[0.2, 0.2, 0.2]), + dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='PointShuffle'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict(type='DefaultFormatBundle3D', class_names=class_names), + dict( + type='Collect3D', + keys=['points', 'img', 'gt_bboxes_3d', 'gt_labels_3d']), +] +test_pipeline = [ + dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1280, 384), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict(type='Resize', multiscale_mode='value', keep_ratio=True), + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size_divisor=32), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict( + type='DefaultFormatBundle3D', + class_names=class_names, + with_label=False), + dict(type='Collect3D', keys=['points', 'img']) + ]) +] + +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type='RepeatDataset', + times=2, + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'kitti_infos_train.pkl', + split='training', + pts_prefix='velodyne_reduced', + pipeline=train_pipeline, + modality=input_modality, + classes=class_names, + test_mode=False)), + val=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'kitti_infos_val.pkl', + split='training', + pts_prefix='velodyne_reduced', + pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True), + test=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + 'kitti_infos_val.pkl', + split='training', + pts_prefix='velodyne_reduced', + 
pipeline=test_pipeline, + modality=input_modality, + classes=class_names, + test_mode=True)) +# Training settings +optimizer = dict( + constructor='HybridOptimizerConstructor', + pts=dict( + type='AdamW', + lr=0.003, + betas=(0.95, 0.99), + weight_decay=0.01, + step_interval=1), + img=dict( + type='SGD', + lr=0.005, + momentum=0.9, + weight_decay=0.0001, + step_interval=1)) +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +lr_config = dict( + policy='CosineAnnealing', + warmup='linear', + warmup_iters=1000, + warmup_ratio=1.0 / 10, + min_lr_ratio=1e-5) +momentum_config = None +checkpoint_config = dict(interval=1) +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +evaluation = dict(interval=1) +# runtime settings +total_epochs = 40 +dist_params = dict(backend='nccl') +log_level = 'INFO' +work_dir = None +# You may need to download the model first is the network is unstable +load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa +resume_from = None +workflow = [('train', 1)] diff --git a/configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py b/configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py index 4ea320cd3f..23c9473761 100644 --- a/configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py +++ b/configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py @@ -1,132 +1,9 @@ # model settings +_base_ = '../_base_/models/mvx-fpn_second.py' + voxel_size = [0.05, 0.05, 0.1] point_cloud_range = [0, -40, -3, 70.4, 40, 1] -model = dict( - type='DynamicMVXFasterRCNN', - img_backbone=dict( - type='ResNet', - depth=50, - num_stages=4, - out_indices=(0, 1, 2, 3), - frozen_stages=1, - norm_cfg=dict(type='BN', requires_grad=False), - norm_eval=True, - style='caffe'), - img_neck=dict( - type='FPN', - in_channels=[256, 512, 1024, 2048], - out_channels=256, - num_outs=5), - pts_voxel_layer=dict( - max_num_points=-1, - point_cloud_range=point_cloud_range, - voxel_size=voxel_size, - max_voxels=(-1, -1), - ), - pts_voxel_encoder=dict( - type='DynamicVFE', - in_channels=4, - feat_channels=[64, 64], - with_distance=False, - voxel_size=voxel_size, - with_cluster_center=True, - with_voxel_center=True, - point_cloud_range=point_cloud_range, - fusion_layer=dict( - type='PointFusion', - img_channels=256, - pts_channels=64, - mid_channels=128, - out_channels=128, - img_levels=[0, 1, 2, 3, 4], - align_corners=False, - activate_out=True, - fuse_out=False)), - pts_middle_encoder=dict( - type='SparseEncoder', - in_channels=128, - sparse_shape=[41, 1600, 1408], - order=('conv', 'norm', 'act')), - pts_backbone=dict( - type='SECOND', - in_channels=256, - layer_nums=[5, 5], - layer_strides=[1, 2], - out_channels=[128, 256]), - pts_neck=dict( - type='SECONDFPN', - in_channels=[128, 256], - upsample_strides=[1, 2], - out_channels=[256, 256]), - pts_bbox_head=dict( - type='Anchor3DHead', - num_classes=3, - in_channels=512, - feat_channels=512, - use_direction_classifier=True, - anchor_generator=dict( - type='Anchor3DRangeGenerator', - ranges=[ - [0, -40.0, -0.6, 70.4, 40.0, -0.6], - [0, -40.0, -0.6, 70.4, 40.0, -0.6], - [0, -40.0, -1.78, 70.4, 40.0, -1.78], - ], - sizes=[[0.6, 0.8, 1.73], [0.6, 1.76, 1.73], [1.6, 3.9, 1.56]], - rotations=[0, 1.57], - reshape_out=False), - assigner_per_size=True, - 
diff_rad_by_sin=True, - assign_per_class=True, - bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), - loss_cls=dict( - type='FocalLoss', - use_sigmoid=True, - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), - loss_dir=dict( - type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)), - # model training and testing settings - train_cfg=dict( - pts=dict( - assigner=[ - dict( # for Pedestrian - type='MaxIoUAssigner', - iou_calculator=dict(type='BboxOverlapsNearest3D'), - pos_iou_thr=0.35, - neg_iou_thr=0.2, - min_pos_iou=0.2, - ignore_iof_thr=-1), - dict( # for Cyclist - type='MaxIoUAssigner', - iou_calculator=dict(type='BboxOverlapsNearest3D'), - pos_iou_thr=0.35, - neg_iou_thr=0.2, - min_pos_iou=0.2, - ignore_iof_thr=-1), - dict( # for Car - type='MaxIoUAssigner', - iou_calculator=dict(type='BboxOverlapsNearest3D'), - pos_iou_thr=0.6, - neg_iou_thr=0.45, - min_pos_iou=0.45, - ignore_iof_thr=-1), - ], - allowed_border=0, - pos_weight=-1, - debug=False)), - test_cfg=dict( - pts=dict( - use_rotate_nms=True, - nms_across_levels=False, - nms_thr=0.01, - score_thr=0.1, - min_bbox_size=0, - nms_pre=100, - max_num=50))) - # dataset settings dataset_type = 'KittiDataset' data_root = 'data/kitti/' diff --git a/mmdet3d/core/__init__.py b/mmdet3d/core/__init__.py index 9d77b76987..b8a439f4f3 100644 --- a/mmdet3d/core/__init__.py +++ b/mmdet3d/core/__init__.py @@ -6,3 +6,4 @@ from .utils import * # noqa: F401, F403 from .visualizer import * # noqa: F401, F403 from .voxel import * # noqa: F401, F403 +from .optimizer import * # noqa: F401, F403 diff --git a/mmdet3d/core/optimizer/__init__.py b/mmdet3d/core/optimizer/__init__.py new file mode 100644 index 0000000000..bf2711f333 --- /dev/null +++ b/mmdet3d/core/optimizer/__init__.py @@ -0,0 +1,4 @@ +from .hybrid_constructor import HybridOptimizerConstructor +from .hybrid_optimizer import HybridOptimizer + +__all__ = ['HybridOptimizerConstructor', 'HybridOptimizer'] diff --git a/mmdet3d/core/optimizer/hybrid_constructor.py b/mmdet3d/core/optimizer/hybrid_constructor.py new file mode 100644 index 0000000000..6f2cb620b9 --- /dev/null +++ b/mmdet3d/core/optimizer/hybrid_constructor.py @@ -0,0 +1,116 @@ +from mmcv.runner.optimizer import OPTIMIZER_BUILDERS, OPTIMIZERS +from mmcv.utils import build_from_cfg + +from mmdet3d.utils import get_root_logger +from .hybrid_optimizer import HybridOptimizer + + +@OPTIMIZER_BUILDERS.register_module() +class HybridOptimizerConstructor(object): + """Special constructor for hybrid optimizers. + This constructor constructs hybrid optimizer for multi-modality + detectors. It builds separate optimizers for separate branchs for + different modalities. More details can be found in the ECCV submission + (to be release). + Attributes: + model (:obj:`nn.Module`): The model with parameters to be optimized. + optimizer_cfg (dict): The config dict of the optimizer. The keys of + the dict are used to search for the corresponding keys in the + model, and the value if a dict that really defines the optimizer. + See example below for the usage. + paramwise_cfg (dict): The dict for paramwise options. This is not + supported in the current version. But it should be supported in + the future release. 
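+        Note that each top-level key of ``optimizer_cfg`` (e.g. ``pts``, ``img``) must occur in the names of the parameters it is meant to optimize; a parameter whose name matches none of the keys raises an assertion error when the optimizer is built.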
+ Example: + >>> import torch + >>> import torch.nn as nn + >>> model = nn.ModuleDict({ + >>> 'pts': nn.modules.Conv1d(1, 1, 1, bias=False), + >>> 'img': nn.modules.Conv1d(1, 1, 1, bias=False) + >>> }) + >>> optimizer_cfg = dict( + >>> pts=dict(type='AdamW', lr=0.001, + >>> weight_decay=0.01, step_interval=1), + >>> img=dict(type='SGD', lr=0.02, momentum=0.9, + >>> weight_decay=0.0001, step_interval=2)) + >>> optim_builder = HybridOptimizerConstructor(optimizer_cfg) + >>> optimizer = optim_builder(model) + >>> print(optimizer) + HybridOptimizer ( + Update interval: 1 + AdamW ( + Parameter Group 0 + amsgrad: False + betas: (0.9, 0.999) + eps: 1e-08 + lr: 0.001 + weight_decay: 0.01 + ), + Update interval: 2 + SGD ( + Parameter Group 0 + dampening: 0 + lr: 0.02 + momentum: 0.9 + nesterov: False + weight_decay: 0.0001 + ), + ) + """ + + def __init__(self, optimizer_cfg, paramwise_cfg=None): + if not isinstance(optimizer_cfg, dict): + raise TypeError('optimizer_cfg should be a dict', + 'but got {}'.format(type(optimizer_cfg))) + # assert paramwise_cfg is None, \ + # 'Parameter wise config is not supported in Hybrid Optimizer' + self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg + self.optimizer_cfg = optimizer_cfg + self.base_lr = {x: optimizer_cfg[x].get('lr', None) for x in optimizer_cfg} + + def __call__(self, model): + if hasattr(model, 'module'): + model = model.module + + # get param-wise options + custom_keys = self.paramwise_cfg.get('custom_keys', {}) + # first sort with alphabet order and then sort with reversed len of str + sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True) + + optimizer_cfg = self.optimizer_cfg.copy() + logger = get_root_logger() + keys_prefix = [key_prefix for key_prefix in optimizer_cfg.keys()] + keys_params = {key: [] for key in keys_prefix} + keys_params_name = {key: [] for key in keys_prefix} + keys_optimizer = [] + for name, param in model.named_parameters(): + param_group = {'params': [param]} + find_flag = False + for key in keys_prefix: + if key in name: + # if the parameter match one of the custom keys, ignore other rules + for custom_key in sorted_keys: + if custom_key in name: + lr_mult = custom_keys[custom_key].get('lr_mult', 1.) 
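+                            # scale this branch's base lr by the matched multiplier; e.g. a (hypothetical) custom_keys={'img_backbone': dict(lr_mult=0.1)} would train those weights at a tenth of the img branch's base rate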
+ param_group['lr'] = self.base_lr[key] * lr_mult + logger.info(f'learning rate of {name} is decreased by {lr_mult}') + break + + keys_params[key].append(param_group) + keys_params_name[key].append(name) + find_flag = True + break + assert find_flag, 'key {} is not matched to any optimizer'.format( + name) + + step_intervals = [] + for key, single_cfg in optimizer_cfg.items(): + step_intervals.append(single_cfg.pop('step_interval', 1)) + single_cfg['params'] = keys_params[key] + single_optim = build_from_cfg(single_cfg, OPTIMIZERS) + keys_optimizer.append(single_optim) + logger.info('{} optimizes key:\n {}\n'.format( + single_cfg['type'], keys_params_name[key])) + + hybrid_optimizer = HybridOptimizer(keys_optimizer, step_intervals) + return hybrid_optimizer diff --git a/mmdet3d/core/optimizer/hybrid_optimizer.py b/mmdet3d/core/optimizer/hybrid_optimizer.py new file mode 100644 index 0000000000..13a5f80e34 --- /dev/null +++ b/mmdet3d/core/optimizer/hybrid_optimizer.py @@ -0,0 +1,99 @@ +from mmcv.runner.optimizer import OPTIMIZERS +from torch.optim import Optimizer + + +@OPTIMIZERS.register_module() +class HybridOptimizer(Optimizer): + """Hybrid Optimizer that contains multiple optimizers This optimizer + applies the hybrid optimzation for multi-modality models.""" + + def __init__(self, optimizers, step_intervals=None): + self.optimizers = optimizers + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups += optimizer.param_groups + if not isinstance(step_intervals, list): + step_intervals = [1] * len(self.optimizers) + self.step_intervals = step_intervals + self.num_step_updated = 0 + + def __getstate__(self): + return { + 'num_step_updated': + self.num_step_updated, + 'defaults': [optimizer.defaults for optimizer in self.optimizers], + 'state': [optimizer.state for optimizer in self.optimizers], + 'param_groups': + [optimizer.param_groups for optimizer in self.optimizers], + } + + def __setstate__(self, state): + self.__dict__.update(state) + + def __repr__(self): + format_string = self.__class__.__name__ + ' (\n' + for optimizer, interval in zip(self.optimizers, self.step_intervals): + format_string += 'Update interval: {}\n'.format(interval) + format_string += optimizer.__repr__().replace('\n', '\n ') + ',\n' + format_string += ')' + return format_string + + def state_dict(self): + state_dicts = [optimizer.state_dict() for optimizer in self.optimizers] + return { + 'num_step_updated': + self.num_step_updated, + 'state': [state_dict['state'] for state_dict in state_dicts], + 'param_groups': + [state_dict['param_groups'] for state_dict in state_dicts], + } + + def load_state_dict(self, state_dict): + r"""Loads the optimizer state. + Arguments: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. 
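+                The 'state' and 'param_groups' entries must hold one item per wrapped optimizer, in the same order as ``self.optimizers``.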
+ """ + assert len(state_dict['state']) == len(self.optimizers) + assert len(state_dict['param_groups']) == len(self.optimizers) + for i, (single_state, single_param_groups) in enumerate( + zip(state_dict['state'], state_dict['param_groups'])): + single_state_dict = dict( + state=single_state, param_groups=single_param_groups) + self.optimizers[i].load_state_dict(single_state_dict) + + self.param_groups = [] + for optimizer in self.optimizers: + self.param_groups += optimizer.param_groups + self.num_step_updated = state_dict['num_step_updated'] + + def zero_grad(self): + r"""Clears the gradients of all optimized :class:`torch.Tensor` s.""" + for optimizer in self.optimizers: + optimizer.zero_grad() + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Arguments: + closure (callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + Returns: + Tensor or None: calculated loss if `closure` is not None. + If `closure` is None, None will be returned + """ + loss = None + if closure is not None: + loss = closure() + + self.num_step_updated += 1 + for step_interval, optimizer in zip(self.step_intervals, + self.optimizers): + if self.num_step_updated % step_interval == 0: + optimizer.step() + + return loss + + def add_param_group(self, param_group): + raise NotImplementedError diff --git a/mmdet3d/datasets/pipelines/__init__.py b/mmdet3d/datasets/pipelines/__init__.py index 724e99d8aa..69720d09fb 100644 --- a/mmdet3d/datasets/pipelines/__init__.py +++ b/mmdet3d/datasets/pipelines/__init__.py @@ -1,6 +1,6 @@ from mmdet.datasets.pipelines import Compose from .dbsampler import DataBaseSampler -from .formating import Collect3D, DefaultFormatBundle, DefaultFormatBundle3D +from .formating import Collect3D, DefaultFormatBundle3D from .loading import (LoadAnnotations3D, LoadMultiViewImageFromFiles, LoadPointsFromFile, LoadPointsFromMultiSweeps, NormalizePointsColor, PointSegClassMapping) @@ -14,7 +14,7 @@ 'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D', 'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile', - 'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler', + 'DefaultFormatBundle3D', 'DataBaseSampler', 'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter', 'VoxelBasedPointSampler' diff --git a/mmdet3d/datasets/pipelines/dbsampler.py b/mmdet3d/datasets/pipelines/dbsampler.py index b5f57f5594..686771e17c 100644 --- a/mmdet3d/datasets/pipelines/dbsampler.py +++ b/mmdet3d/datasets/pipelines/dbsampler.py @@ -1,4 +1,5 @@ import copy +import cv2 import mmcv import numpy as np import os @@ -314,3 +315,269 @@ def sample_class_v2(self, name, num, gt_bboxes): else: valid_samples.append(sampled[i - num_gt]) return valid_samples + + +@OBJECTSAMPLERS.register_module() +class MMDataBaseSampler(DataBaseSampler): + + def __init__(self, + info_path, + data_root, + rate, + prepare, + sample_groups, + classes=None, + check_2D_collision=False, + collision_thr=0, + collision_in_classes=False, + depth_consistent=False, + blending_type=None, + img_loader=dict(type='LoadImageFromFile'), + mask_loader=dict( + type='LoadImageFromFile', color_type='grayscale'), + points_loader=dict( + type='LoadPointsFromFile', + load_dim=4, + coord_type='LIDAR', + use_dim=[0, 1, 2, 3])): + super(MMDataBaseSampler, self).__init__( 
+ info_path=info_path, + data_root=data_root, + rate=rate, + prepare=prepare, + sample_groups=sample_groups, + classes=classes, + points_loader=points_loader) + + self.blending_type = blending_type + self.depth_consistent = depth_consistent + self.check_2D_collision = check_2D_collision + self.collision_thr = collision_thr + self.collision_in_classes = collision_in_classes + self.img_loader = mmcv.build_from_cfg(img_loader, PIPELINES) + self.mask_loader = mmcv.build_from_cfg(mask_loader, PIPELINES) + + def sample_all(self, gt_bboxes_3d, gt_names, gt_bboxes_2d=None, img=None): + sampled_num_dict = {} + sample_num_per_class = [] + for class_name, max_sample_num in zip(self.sample_classes, + self.sample_max_nums): + sampled_num = int(max_sample_num - + np.sum([n == class_name for n in gt_names])) + sampled_num = np.round(self.rate * sampled_num).astype(np.int64) + sampled_num_dict[class_name] = sampled_num + sample_num_per_class.append(sampled_num) + + sampled = [] + sampled_gt_bboxes_3d = [] + sampled_gt_bboxes_2d = [] + avoid_coll_boxes_3d = gt_bboxes_3d + avoid_coll_boxes_2d = gt_bboxes_2d + + for class_name, sampled_num in zip(self.sample_classes, + sample_num_per_class): + if sampled_num > 0: + sampled_cls = self.sample_class_v2(class_name, sampled_num, + avoid_coll_boxes_3d, + avoid_coll_boxes_2d) + + sampled += sampled_cls + if len(sampled_cls) > 0: + if len(sampled_cls) == 1: + sampled_gt_box_3d = sampled_cls[0]['box3d_lidar'][ + np.newaxis, ...] + sampled_gt_box_2d = sampled_cls[0]['box2d_camera'][ + np.newaxis, ...] + else: + sampled_gt_box_3d = np.stack( + [s['box3d_lidar'] for s in sampled_cls], axis=0) + sampled_gt_box_2d = np.stack( + [s['box2d_camera'] for s in sampled_cls], axis=0) + + sampled_gt_bboxes_3d += [sampled_gt_box_3d] + sampled_gt_bboxes_2d += [sampled_gt_box_2d] + if self.collision_in_classes: + # TODO: check whether check collision check among + # classes is necessary + avoid_coll_boxes_3d = np.concatenate( + [avoid_coll_boxes_3d, sampled_gt_box_3d], axis=0) + avoid_coll_boxes_2d = np.concatenate( + [avoid_coll_boxes_2d, sampled_gt_box_2d], axis=0) + + ret = None + if len(sampled) > 0: + sampled_gt_bboxes_3d = np.concatenate(sampled_gt_bboxes_3d, axis=0) + sampled_gt_bboxes_2d = np.concatenate(sampled_gt_bboxes_2d, axis=0) + + s_points_list = [] + count = 0 + + if self.depth_consistent: + # change the paster order based on distance + center = sampled_gt_bboxes_3d[:, 0:3] + paste_order = np.argsort( + -np.power(np.sum(np.power(center, 2), axis=-1), 1 / 2), + axis=-1) + + for idx in range(len(sampled)): + if self.depth_consistent: + inds = np.where(paste_order == idx)[0][0] + info = sampled[inds] + else: + info = sampled[idx] + + pcd_file_path = os.path.join( + self.data_root, + info['path']) if self.data_root else info['path'] + img_file_path = pcd_file_path + '.png' + mask_file_path = pcd_file_path + '.mask.png' + + results = dict(pts_filename=pcd_file_path) + s_points = self.points_loader(results)['points'] + # perform points sampling before pasting + + patch_results = dict( + img_prefix=None, img_info=dict(filename=img_file_path)) + mask_results = dict( + img_prefix=None, img_info=dict(filename=mask_file_path)) + s_patch = self.img_loader(patch_results)['img'] + s_mask = self.mask_loader(mask_results)['img'] + + # the points of each sample already minus the object center + # so this time it needs to add the offset back + s_points.translate(info['box3d_lidar'][:3]) + img = self.paste_obj( + img, + s_patch, + s_mask, + 
bbox_2d=info['box2d_camera'].astype(np.int32)) + + count += 1 + s_points_list.append(s_points) + + gt_labels = np.array([self.cat2label[s['name']] for s in sampled]) + ret = dict( + img=img, + gt_labels=gt_labels, + gt_labels_3d=copy.deepcopy(gt_labels), + gt_bboxes_3d=sampled_gt_bboxes_3d, + gt_bboxes_2d=sampled_gt_bboxes_2d, + points=s_points_list[0].cat(s_points_list), + group_ids=np.arange(gt_bboxes_3d.shape[0], + gt_bboxes_3d.shape[0] + len(sampled))) + + return ret + + def paste_obj(self, img, obj_img, obj_mask, bbox_2d): + # paste the image patch back + x1, y1, x2, y2 = bbox_2d + # the bbox might exceed the img size because the img is different + img_h, img_w = img.shape[:2] + w = np.maximum(min(x2, img_w - 1) - x1 + 1, 1) + h = np.maximum(min(y2, img_h - 1) - y1 + 1, 1) + obj_mask = obj_mask[:h, :w] + obj_img = obj_img[:h, :w] + + # choose a blend option + if not self.blending_type: + blending_op = 'none' + + else: + blending_choice = np.random.randint(len(self.blending_type)) + blending_op = self.blending_type[blending_choice] + + if blending_op.find('poisson') != -1: + # options: cv2.NORMAL_CLONE=1, or cv2.MONOCHROME_TRANSFER=3 + # cv2.MIXED_CLONE mixed the texture, thus is not used. + if blending_op == 'poisson': + mode = np.random.choice([1, 3], 1)[0] + elif blending_op == 'poisson_normal': + mode = cv2.NORMAL_CLONE + elif blending_op == 'poisson_transfer': + mode = cv2.MONOCHROME_TRANSFER + else: + raise NotImplementedError + center = (int(x1 + w / 2), int(y1 + h / 2)) + img = cv2.seamlessClone(obj_img, img, obj_mask * 255, center, mode) + else: + if blending_op == 'gaussian': + obj_mask = cv2.GaussianBlur( + obj_mask.astype(np.float32), (5, 5), 2) + elif blending_op == 'box': + obj_mask = cv2.blur(obj_mask.astype(np.float32), (3, 3)) + paste_mask = 1 - obj_mask + img[y1:y1 + h, + x1:x1 + w] = (img[y1:y1 + h, x1:x1 + w].astype(np.float32) * + paste_mask[..., None]).astype(np.uint8) + img[y1:y1 + h, x1:x1 + w] += (obj_img.astype(np.float32) * + obj_mask[..., None]).astype(np.uint8) + + return img + + def sample_class_v2(self, name, num, gt_bboxes_3d, gt_bboxes_2d): + sampled = self.sampler_dict[name].sample(num) + sampled = copy.deepcopy(sampled) + num_gt = gt_bboxes_3d.shape[0] + num_sampled = len(sampled) + # avoid collision in BEV first + gt_bboxes_bv = box_np_ops.center_to_corner_box2d( + gt_bboxes_3d[:, 0:2], gt_bboxes_3d[:, 3:5], gt_bboxes_3d[:, 6]) + sp_boxes = np.stack([i['box3d_lidar'] for i in sampled], axis=0) + sp_boxes_bv = box_np_ops.center_to_corner_box2d( + sp_boxes[:, 0:2], sp_boxes[:, 3:5], sp_boxes[:, 6]) + total_bv = np.concatenate([gt_bboxes_bv, sp_boxes_bv], axis=0) + coll_mat = data_augment_utils.box_collision_test(total_bv, total_bv) + + # Then avoid collision in 2D space + if self.check_2D_collision: + sp_boxes_2d = np.stack([i['box2d_camera'] for i in sampled], + axis=0) + total_bbox_2d = np.concatenate([gt_bboxes_2d, sp_boxes_2d], + axis=0) # Nx4 + # random select a collision threshold + if isinstance(self.collision_thr, float): + collision_thr = self.collision_thr + elif isinstance(self.collision_thr, list): + collision_thr = np.random.choice(self.collision_thr) + elif isinstance(self.collision_thr, dict): + mode = self.collision_thr.get('mode', 'value') + if mode == 'value': + collision_thr = np.random.choice( + self.collision_thr['thr_range']) + elif mode == 'range': + collision_thr = np.random.uniform( + self.collision_thr['thr_range'][0], + self.collision_thr['thr_range'][1]) + + if collision_thr == 0: + # use similar collision test as BEV did + 
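+                    # expand each axis-aligned 2D box into its four corners so the same corner-based box_collision_test used for the BEV boxes can be applied in image space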
# Nx4 (x1, y1, x2, y2) -> corners: Nx4x2 + # ((x1, y1), (x2, y1), (x1, y2), (x2, y2)) + x1y1 = total_bbox_2d[:, :2] + x2y2 = total_bbox_2d[:, 2:] + x1y2 = np.stack([total_bbox_2d[:, 0], total_bbox_2d[:, 3]], + axis=-1) + x2y1 = np.stack([total_bbox_2d[:, 2], total_bbox_2d[:, 1]], + axis=-1) + total_2d = np.stack([x1y1, x2y1, x1y2, x2y2], axis=1) + coll_mat_2d = data_augment_utils.box_collision_test( + total_2d, total_2d) + else: + # use iof rather than iou to protect the foreground + overlaps = box_np_ops.iou_jit(total_bbox_2d, total_bbox_2d, + 'iof') + coll_mat_2d = overlaps > collision_thr + coll_mat = coll_mat + coll_mat_2d + + diag = np.arange(total_bv.shape[0]) + coll_mat[diag, diag] = False + + valid_samples = [] + for i in range(num_gt, num_gt + num_sampled): + if coll_mat[i].any(): + coll_mat[i] = False + coll_mat[:, i] = False + else: + valid_samples.append(sampled[i - num_gt]) + + return valid_samples diff --git a/mmdet3d/datasets/pipelines/formating.py b/mmdet3d/datasets/pipelines/formating.py index f5b2def535..3444c6bfdb 100644 --- a/mmdet3d/datasets/pipelines/formating.py +++ b/mmdet3d/datasets/pipelines/formating.py @@ -6,78 +6,6 @@ from mmdet.datasets.builder import PIPELINES from mmdet.datasets.pipelines import to_tensor -PIPELINES._module_dict.pop('DefaultFormatBundle') - - -@PIPELINES.register_module() -class DefaultFormatBundle(object): - """Default formatting bundle. - - It simplifies the pipeline of formatting common fields, including "img", - "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg". - These fields are formatted as follows. - - - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True) - - proposals: (1)to tensor, (2)to DataContainer - - gt_bboxes: (1)to tensor, (2)to DataContainer - - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer - - gt_labels: (1)to tensor, (2)to DataContainer - - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True) - - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \ - (3)to DataContainer (stack=True) - """ - - def __init__(self, ): - return - - def __call__(self, results): - """Call function to transform and format common fields in results. - - Args: - results (dict): Result dict contains the data to convert. - - Returns: - dict: The result dict contains the data that is formatted with - default bundle. 
- """ - if 'img' in results: - if isinstance(results['img'], list): - # process multiple imgs in single frame - imgs = [img.transpose(2, 0, 1) for img in results['img']] - imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) - results['img'] = DC(to_tensor(imgs), stack=True) - else: - img = np.ascontiguousarray(results['img'].transpose(2, 0, 1)) - results['img'] = DC(to_tensor(img), stack=True) - for key in [ - 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', - 'gt_labels_3d', 'pts_instance_mask', 'pts_semantic_mask' - ]: - if key not in results: - continue - if isinstance(results[key], list): - results[key] = DC([to_tensor(res) for res in results[key]]) - else: - results[key] = DC(to_tensor(results[key])) - if 'gt_bboxes_3d' in results: - if isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): - results['gt_bboxes_3d'] = DC( - results['gt_bboxes_3d'], cpu_only=True) - else: - results['gt_bboxes_3d'] = DC( - to_tensor(results['gt_bboxes_3d'])) - - if 'gt_masks' in results: - results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) - if 'gt_semantic_seg' in results: - results['gt_semantic_seg'] = DC( - to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) - - return results - - def __repr__(self): - return self.__class__.__name__ - @PIPELINES.register_module() class Collect3D(object): @@ -172,7 +100,7 @@ def __repr__(self): @PIPELINES.register_module() -class DefaultFormatBundle3D(DefaultFormatBundle): +class DefaultFormatBundle3D(object): """Default formatting bundle. It simplifies the pipeline of formatting common fields for voxels, @@ -203,6 +131,21 @@ def __call__(self, results): dict: The result dict contains the data that is formatted with default bundle. """ + # 2D format bundle + img_fields = results.get('img_fields', []) + # img_fields must contain key 'img' if its length is 1 + if len(img_fields) == 1 and 'img' in img_fields: + img = results['img'] + if len(img.shape) < 3: + img = np.expand_dims(img, -1) + img = np.ascontiguousarray(img.transpose(2, 0, 1)) + results['img'] = DC(to_tensor(img), stack=True) + elif len(img_fields) > 1: + # process multiple imgs in single frame + imgs = [results[key].transpose(2, 0, 1) for key in img_fields] + imgs = np.ascontiguousarray(np.stack(imgs, axis=0)) + results['img'] = DC(to_tensor(imgs), stack=True) + # Format 3D data if 'points' in results: assert isinstance(results['points'], BasePoints) @@ -213,44 +156,29 @@ def __call__(self, results): continue results[key] = DC(to_tensor(results[key]), stack=False) - if self.with_gt: - # Clean GT bboxes in the final - if 'gt_bboxes_3d_mask' in results: - gt_bboxes_3d_mask = results['gt_bboxes_3d_mask'] - results['gt_bboxes_3d'] = results['gt_bboxes_3d'][ - gt_bboxes_3d_mask] - if 'gt_names_3d' in results: - results['gt_names_3d'] = results['gt_names_3d'][ - gt_bboxes_3d_mask] - if 'gt_bboxes_mask' in results: - gt_bboxes_mask = results['gt_bboxes_mask'] - if 'gt_bboxes' in results: - results['gt_bboxes'] = results['gt_bboxes'][gt_bboxes_mask] - results['gt_names'] = results['gt_names'][gt_bboxes_mask] - if self.with_label: - if 'gt_names' in results and len(results['gt_names']) == 0: - results['gt_labels'] = np.array([], dtype=np.int64) - elif 'gt_names' in results and isinstance( - results['gt_names'][0], list): - # gt_labels might be a list of list in multi-view setting - results['gt_labels'] = [ - np.array([self.class_names.index(n) for n in res], - dtype=np.int64) for res in results['gt_names'] - ] - elif 'gt_names' in results: - results['gt_labels'] = np.array([ - 
self.class_names.index(n) for n in results['gt_names'] - ], - dtype=np.int64) - # we still assume one pipeline for one frame LiDAR - # thus, the 3D name is list[string] - if 'gt_names_3d' in results: - results['gt_labels_3d'] = np.array([ - self.class_names.index(n) - for n in results['gt_names_3d'] - ], - dtype=np.int64) - results = super(DefaultFormatBundle3D, self).__call__(results) + for key in [ + 'proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels', + 'gt_labels_3d', 'pts_instance_mask', 'pts_semantic_mask' + ]: + if key not in results: + continue + if isinstance(results[key], list): + results[key] = DC([to_tensor(res) for res in results[key]]) + else: + results[key] = DC(to_tensor(results[key])) + if 'gt_bboxes_3d' in results: + if isinstance(results['gt_bboxes_3d'], BaseInstance3DBoxes): + results['gt_bboxes_3d'] = DC( + results['gt_bboxes_3d'], cpu_only=True) + else: + results['gt_bboxes_3d'] = DC( + to_tensor(results['gt_bboxes_3d'])) + + if 'gt_masks' in results: + results['gt_masks'] = DC(results['gt_masks'], cpu_only=True) + if 'gt_semantic_seg' in results: + results['gt_semantic_seg'] = DC( + to_tensor(results['gt_semantic_seg'][None, ...]), stack=True) return results def __repr__(self): diff --git a/mmdet3d/datasets/pipelines/loading.py b/mmdet3d/datasets/pipelines/loading.py index 4e9bfd2a53..3670bdb65d 100644 --- a/mmdet3d/datasets/pipelines/loading.py +++ b/mmdet3d/datasets/pipelines/loading.py @@ -8,7 +8,7 @@ @PIPELINES.register_module() class LoadMultiViewImageFromFiles(object): - """Load multi channel images from a list of separate channel files. + """Load multi view images from a dict that contains image information. Expects results['img_filename'] to be a list of filenames. @@ -18,45 +18,35 @@ class LoadMultiViewImageFromFiles(object): color_type (str): Color type of the file. Defaults to 'unchanged'. """ - def __init__(self, to_float32=False, color_type='unchanged'): + def __init__(self, + to_float32=False, + color_type='unchanged', + file_client_args=dict(backend='disk')): self.to_float32 = to_float32 self.color_type = color_type + self.file_client_args = file_client_args.copy() + self.file_client = None - def __call__(self, results): - """Call function to load multi-view image from files. - - Args: - results (dict): Result dict containing multi-view image filenames. - - Returns: - dict: The result dict containing the multi-view image data. \ - Added keys and values are described below. - - - filename (str): Multi-view image filenames. - - img (np.ndarray): Multi-view image arrays. - - img_shape (tuple[int]): Shape of multi-view image arrays. - - ori_shape (tuple[int]): Shape of original image arrays. - - pad_shape (tuple[int]): Shape of padded image arrays. - - scale_factor (float): Scale factor. - - img_norm_cfg (dict): Normalization configuration of images. 
- """ - filename = results['img_filename'] - img = np.stack( - [mmcv.imread(name, self.color_type) for name in filename], axis=-1) + def _load_img(self, filename): + img_bytes = self.file_client.get(filename) + img = mmcv.imfrombytes(img_bytes, flag=self.color_type) if self.to_float32: img = img.astype(np.float32) - results['filename'] = filename - results['img'] = img - results['img_shape'] = img.shape - results['ori_shape'] = img.shape - # Set initial values for default meta_keys - results['pad_shape'] = img.shape - results['scale_factor'] = 1.0 - num_channels = 1 if len(img.shape) < 3 else img.shape[2] - results['img_norm_cfg'] = dict( - mean=np.zeros(num_channels, dtype=np.float32), - std=np.ones(num_channels, dtype=np.float32), - to_rgb=False) + return img + + def __call__(self, results): + if self.file_client is None: + self.file_client = mmcv.FileClient(**self.file_client_args) + if 'img_fields' not in results: + results['img_fields'] = [] + for key in results['img_info']['filename']: + img = self._load_img(results['img_info']['filename'][key]) + results['img_fields'].append(key) + results[key] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + # Set initial values for default meta_keys + results['pad_shape'] = img.shape return results def __repr__(self): diff --git a/mmdet3d/datasets/pipelines/test_time_aug.py b/mmdet3d/datasets/pipelines/test_time_aug.py index f0de62b284..4ce3e58b7b 100644 --- a/mmdet3d/datasets/pipelines/test_time_aug.py +++ b/mmdet3d/datasets/pipelines/test_time_aug.py @@ -15,7 +15,8 @@ class MultiScaleFlipAug3D(object): img_scale (tuple | list[tuple]: Images scales for resizing. pts_scale_ratio (float | list[float]): Points scale ratios for resizing. - flip (bool): Whether apply flip augmentation. Defaults to False. + flip (bool): Whether apply flip augmentation to images. + Defaults to False. flip_direction (str | list[str]): Flip augmentation directions for images, options are "horizontal" and "vertical". 
If flip_direction is list, multiple flip augmentations will @@ -77,11 +78,11 @@ def __call__(self, results): # modified from `flip_aug = [False, True] if self.flip else [False]` # to reduce unnecessary scenes when using double flip augmentation # during test time - flip_aug = [True] if self.flip else [False] + flip_aug = [False, True] if self.flip else [False] pcd_horizontal_flip_aug = [False, True] \ - if self.flip and self.pcd_horizontal_flip else [False] + if self.pcd_horizontal_flip else [False] pcd_vertical_flip_aug = [False, True] \ - if self.flip and self.pcd_vertical_flip else [False] + if self.pcd_vertical_flip else [False] for scale in self.img_scale: for pts_scale_ratio in self.pts_scale_ratio: for flip in flip_aug: diff --git a/mmdet3d/models/detectors/centerpoint.py b/mmdet3d/models/detectors/centerpoint.py index 7705ce1a94..e5cc6b129d 100644 --- a/mmdet3d/models/detectors/centerpoint.py +++ b/mmdet3d/models/detectors/centerpoint.py @@ -31,20 +31,6 @@ def __init__(self, pts_bbox_head, img_roi_head, img_rpn_head, train_cfg, test_cfg, pretrained) - def extract_pts_feat(self, pts, img_feats, img_metas): - """Extract features of points.""" - if not self.with_pts_bbox: - return None - voxels, num_points, coors = self.voxelize(pts) - - voxel_features = self.pts_voxel_encoder(voxels, num_points, coors) - batch_size = coors[-1, 0] + 1 - x = self.pts_middle_encoder(voxel_features, coors, batch_size) - x = self.pts_backbone(x) - if self.with_pts_neck: - x = self.pts_neck(x) - return x - def forward_pts_train(self, pts_feats, gt_bboxes_3d, diff --git a/mmdet3d/models/detectors/mvx_two_stage.py b/mmdet3d/models/detectors/mvx_two_stage.py index 8297f2e746..57b600cb50 100644 --- a/mmdet3d/models/detectors/mvx_two_stage.py +++ b/mmdet3d/models/detectors/mvx_two_stage.py @@ -500,3 +500,32 @@ def show_results(self, data, result, out_dir): pred_bboxes = pred_bboxes.tensor.cpu().numpy() show_result(points, None, pred_bboxes, out_dir, file_name) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + # override the _load_from_state_dict function + # convert the backbone weights pre-trained in Mask R-CNN + # use list(state_dict.keys()) to avoid + # RuntimeError: OrderedDict mutated during iteration + for key_name in list(state_dict.keys()): + key_changed = True + if key_name.startswith('backbone.'): + new_key_name = f'img_backbone{key_name[8:]}' + elif key_name.startswith('neck.'): + new_key_name = f'img_neck{key_name[4:]}' + elif key_name.startswith('rpn_head.'): + new_key_name = f'img_rpn_head{key_name[8:]}' + elif key_name.startswith('roi_head.'): + new_key_name = f'img_roi_head{key_name[8:]}' + else: + key_changed = False + + if key_changed: + logger = get_root_logger() + print_log( + f'{key_name} renamed to be {new_key_name}', logger=logger) + state_dict[new_key_name] = state_dict.pop(key_name) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, + error_msgs) \ No newline at end of file diff --git a/mmdet3d/models/fusion_layers/point_fusion.py b/mmdet3d/models/fusion_layers/point_fusion.py index 388a225f8a..eef7599259 100644 --- a/mmdet3d/models/fusion_layers/point_fusion.py +++ b/mmdet3d/models/fusion_layers/point_fusion.py @@ -28,9 +28,9 @@ def point_sample( img_features (torch.Tensor): 1 x C x H x W image features. points (torch.Tensor): Nx3 point cloud in LiDAR coordinates. lidar2img_rt (torch.Tensor): 4x4 transformation matrix. 
- img_scale_factor (torch.Tensor): Scale factor with shape of \ + img_scale_factor (torch.Tensor): Scale factor with shape of (w_scale, h_scale). - img_crop_offset (torch.Tensor): Crop offset used to crop \ + img_crop_offset (torch.Tensor): Crop offset used to crop image during data augmentation with shape of (w_offset, h_offset). img_flip (bool): Whether the image is flipped. img_pad_shape (tuple[int]): int tuple indicates the h & w after @@ -47,7 +47,6 @@ def point_sample( Returns: torch.Tensor: NxC image features sampled by point coordinates. """ - # apply transformation based on info in img_meta points = apply_3d_transformation(points, 'LIDAR', img_meta, reverse=True) @@ -180,16 +179,16 @@ def __init__(self, self.lateral_convs.append(l_conv) self.img_transform = nn.Sequential( nn.Linear(mid_channels * len(img_channels), out_channels), - nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01), + nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01) ) else: self.img_transform = nn.Sequential( nn.Linear(sum(img_channels), out_channels), - nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01), + nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01) ) self.pts_transform = nn.Sequential( nn.Linear(pts_channels, out_channels), - nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01), + nn.BatchNorm1d(out_channels, eps=1e-3, momentum=0.01) ) if self.fuse_out: @@ -285,6 +284,7 @@ def sample_single(self, img_feats, pts, img_meta): img_scale_factor = ( pts.new_tensor(img_meta['scale_factor'][:2]) if 'scale_factor' in img_meta.keys() else 1) + img_flip = img_meta['flip'] if 'flip' in img_meta.keys() else False img_crop_offset = ( pts.new_tensor(img_meta['img_crop_offset']) diff --git a/mmdet3d/models/voxel_encoders/voxel_encoder.py b/mmdet3d/models/voxel_encoders/voxel_encoder.py index 9dfd59ede6..d61aadebed 100644 --- a/mmdet3d/models/voxel_encoders/voxel_encoder.py +++ b/mmdet3d/models/voxel_encoders/voxel_encoder.py @@ -25,7 +25,7 @@ def __init__(self, num_features=4): self.fp16_enabled = False @force_fp32(out_fp16=True) - def forward(self, features, num_points, coors): + def forward(self, features, num_points, coors, *args, **kwargs): """Forward function. Args: @@ -65,7 +65,7 @@ def __init__(self, @torch.no_grad() @force_fp32(out_fp16=True) - def forward(self, features, coors): + def forward(self, features, coors, *args, **kwargs): """Forward function. Args: diff --git a/mmdet3d/ops/norm.py b/mmdet3d/ops/norm.py index e9db8fb579..7d7eb08dd3 100644 --- a/mmdet3d/ops/norm.py +++ b/mmdet3d/ops/norm.py @@ -58,8 +58,17 @@ def forward(self, input): return super().forward(input) assert input.shape[0] > 0, 'SyncBN does not support empty inputs' C = input.shape[1] - mean = torch.mean(input, dim=[0, 2]) - meansqr = torch.mean(input * input, dim=[0, 2]) + if input.dim() == 3: + expected_dim = [0, 2] + expected_shape = [1, -1, 1] + elif input.dim() == 2: + expected_dim = [0] + expected_shape = [1, -1] + else: + raise ValueError( + f'expected 2D or 3D input (got {input.dim()}D input)') + mean = torch.mean(input, dim=expected_dim) + meansqr = torch.mean(input * input, dim=expected_dim) vec = torch.cat([mean, meansqr], dim=0) vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) @@ -73,8 +82,8 @@ def forward(self, input): invstd = torch.rsqrt(var + self.eps) scale = self.weight * invstd bias = self.bias - mean * scale - scale = scale.reshape(1, -1, 1) - bias = bias.reshape(1, -1, 1) + scale = scale.reshape(*expected_shape) + bias = bias.reshape(*expected_shape) return input * scale + bias
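A minimal standalone sketch of the update-interval behaviour that `HybridOptimizer.step()` implements above: each wrapped optimizer only updates its parameters every `step_interval` calls, so one modality branch can be stepped less frequently than the other. The toy two-branch model and the hyper-parameters below are placeholders, not values taken from this PR.

import torch
import torch.nn as nn

model = nn.ModuleDict({
    'pts': nn.Linear(4, 4, bias=False),  # stands in for the point-cloud branch
    'img': nn.Linear(4, 4, bias=False),  # stands in for the image branch
})
optimizers = [
    torch.optim.AdamW(model['pts'].parameters(), lr=3e-3, weight_decay=0.01),
    torch.optim.SGD(model['img'].parameters(), lr=5e-3, momentum=0.9),
]
step_intervals = [1, 2]  # pts updates every iteration, img every second one

num_step_updated = 0
for it in range(4):
    loss = sum(branch(torch.randn(2, 4)).sum() for branch in model.values())
    for opt in optimizers:
        opt.zero_grad()
    loss.backward()
    num_step_updated += 1
    for interval, opt in zip(step_intervals, optimizers):
        if num_step_updated % interval == 0:
            opt.step()  # the img branch is skipped on odd iterations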