diff --git a/README.md b/README.md
index 52f6293..05c919c 100644
--- a/README.md
+++ b/README.md
@@ -1,16 +1,5 @@
# CenterFormer
-Implementation for CenterFormer: Center-based Transformer for 3D Object Detection (ECCV 2022)
-
-Code is coming soon!
-
-
-
-## Abstract
-Query-based transformer has shown great potential in constructing long-range attention in many image-domain tasks, but has rarely been considered in LiDAR-based 3D object detection due to the overwhelming size of the point cloud data. In this paper, we propose **CenterFormer**, a center-based transformer network for 3D object detection. CenterFormer first uses a center heatmap to select center candidates on top of a standard voxel-based point cloud encoder. It then uses the feature of the center candidate as the query embedding in the transformer. To further aggregate features from multiple frames, we design an approach to fuse features through cross-attention. Lastly, regression heads are added to predict the bounding box on the output center feature representation. Our design reduces the convergence difficulty and computational complexity of the transformer structure. The results show significant improvements over the strong baseline of anchor-free object detection networks. CenterFormer achieves state-of-the-art performance for a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation set and 75.6% mAPH on the test set, significantly outperforming all previously published CNN and transformer-based methods.
-
-[arXiv](https://arxiv.org/abs/2209.05588)
-
-## Citation
+Official implementation for [**CenterFormer: Center-based Transformer for 3D Object Detection**](https://arxiv.org/abs/2209.05588) (ECCV 2022 Oral)
```
@InProceedings{Zhou_centerformer,
title = {CenterFormer: Center-based Transformer for 3D Object Detection},
@@ -19,3 +8,51 @@ booktitle = {ECCV},
year = {2022}
}
```
+
+## Highlights
+- **Center Transformer** We introduce a center-based transformer network for 3D object detection.
+
+- **Fast and Easy to Train** We use the center feature as the initial query embedding to facilitate learning of the transformer. We propose a multi-scale cross-attention layer to efficiently aggregate neighboring features without significantly increasing the computational complexity.
+
+- **Temporal information**: We propose using the cross-attention transformer to fuse object features from past frames.
+
+
+
+## NEWS
+[2022-09-30] CenterFormer source code is released.
+
+## Abstract
+Query-based transformer has shown great potential in constructing long-range attention in many image-domain tasks, but has rarely been considered in LiDAR-based 3D object detection due to the overwhelming size of the point cloud data. In this paper, we propose **CenterFormer**, a center-based transformer network for 3D object detection. CenterFormer first uses a center heatmap to select center candidates on top of a standard voxel-based point cloud encoder. It then uses the feature of the center candidate as the query embedding in the transformer. To further aggregate features from multiple frames, we design an approach to fuse features through cross-attention. Lastly, regression heads are added to predict the bounding box on the output center feature representation. Our design reduces the convergence difficulty and computational complexity of the transformer structure. The results show significant improvements over the strong baseline of anchor-free object detection networks. CenterFormer achieves state-of-the-art performance for a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation set and 75.6% mAPH on the test set, significantly outperforming all previously published CNN and transformer-based methods.
+
+## Result
+
+#### 3D detection on Waymo test set
+
+| | #Frame | Veh_L2 | Ped_L2 | Cyc_L2 | Mean |
+|---------|---------|--------|--------|---------|---------|
+| CenterFormer| 8 | 77.7 | 76.6 | 72.4 | 75.6 |
+| CenterFormer| 16 | 78.3 | 77.4 | 73.2 | 76.3 |
+
+#### 3D detection on Waymo val set
+
+| | #Frame | Veh_L2 | Ped_L2 | Cyc_L2 | Mean |
+|---------|---------|--------|--------|---------|---------|
+| [CenterFormer](voxelnet/waymo_centerformer.py)| 1 | 69.4 | 67.7 | 70.2 | 69.1 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_deformable.py)| 1 | 69.7 | 68.3 | 68.8 | 69.0 |
+| [CenterFormer](voxelnet/waymo_centerformer_multiframe_2frames.py)| 2 | 71.7 | 73.0 | 72.7 | 72.5 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_2frames.py)| 2 | 71.6 | 73.4 | 73.3 | 72.8 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_4frames.py)| 4 | 72.9 | 74.2 | 72.6 | 73.2 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_8frames.py)| 8 | 73.8 | 75.0 | 72.3 | 73.7 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_16frames.py)| 16 | 74.6 | 75.6 | 72.7 | 74.3 |
+
+The training and evaluation configs of the above models are provided in [Configs](configs/waymo/README.md).
+
+## Installation
+Please refer to [INSTALL](docs/INSTALL.md) to set up libraries needed for distributed training and sparse convolution.
+
+## Training and Evaluation
+Please refer to [WAYMO](docs/WAYMO.md) to prepare the data, training and evaluation.
+
+
+## Acknowlegement
+This project is developed based on the [CenterPoint](https://github.com/tianweiy/CenterPoint) codebase. We use the deformable cross-attention implementation from [Deformable-DETR](https://github.com/fundamentalvision/Deformable-DETR).
diff --git a/configs/waymo/README.md b/configs/waymo/README.md
new file mode 100644
index 0000000..2ce0ee4
--- /dev/null
+++ b/configs/waymo/README.md
@@ -0,0 +1,22 @@
+# Configs
+
+### Common settings and notes
+
+- The experiments are run with PyTorch 1.9 and CUDA 11.1.
+- The training is conducted on 8 A100 GPUs.
+- Training on GPU with less memory would likely cause GPU out-of-memory. In this case, you can try configs with smaller batch size or frames.
+
+
+### Waymo Validation Results
+
+We provide the training and validation configs for the model in our paper. Let us know if you have trouble reproducing the results.
+
+| | #Frame | Veh_L2 | Ped_L2 | Cyc_L2 | Mean |
+|---------|---------|--------|--------|---------|---------|
+| [CenterFormer](voxelnet/waymo_centerformer.py)| 1 | 69.4 | 67.7 | 70.2 | 69.1 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_deformable.py)| 1 | 69.7 | 68.3 | 68.8 | 69.0 |
+| [CenterFormer](voxelnet/waymo_centerformer_multiframe_2frames.py)| 2 | 71.7 | 73.0 | 72.7 | 72.5 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_2frames.py)| 2 | 71.6 | 73.4 | 73.3 | 72.8 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_4frames.py)| 4 | 72.9 | 74.2 | 72.6 | 73.2 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_8frames.py)| 8 | 73.8 | 75.0 | 72.3 | 73.7 |
+| [CenterFormer deformable](voxelnet/waymo_centerformer_multiframe_deformable_16frames.py)| 16 | 74.6 | 75.6 | 72.7 | 74.3 |
\ No newline at end of file
diff --git a/configs/waymo/voxelnet/waymo_centerformer.py b/configs/waymo/voxelnet/waymo_centerformer.py
new file mode 100644
index 0000000..4d7f5f5
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer.py
@@ -0,0 +1,232 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=5, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ obj_num= 500,
+ assign_label_window_size=window_size,
+ transformer_config=dict(
+ depth = 3,
+ heads = 4,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ cross_attention_kernel_size = [3,3,3]
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss=True,
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2), 'iou':(1,2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num= 1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+nsweeps = 1
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_1sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[150000, 200000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset=dataset_type),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset=dataset_type),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_01sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_01sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_01sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_deformable.py b/configs/waymo/voxelnet/waymo_centerformer_deformable.py
new file mode 100644
index 0000000..cf083e7
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_deformable.py
@@ -0,0 +1,232 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=5, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_deformable",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ assign_label_window_size=window_size,
+ obj_num=500,
+ transformer_config=dict(
+ depth = 2,
+ heads = 6,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ n_points = 15,
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ corner_loss=True,
+ iou_loss=True,
+ assign_label_window_size=window_size,
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2),'iou':(1,2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num=1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+nsweeps = 1
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_1sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[150000, 200000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset=dataset_type),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset=dataset_type),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_01sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_01sweeps_filter_zero_gt.pkl"
+test_anno = None
+
+data = dict(
+ samples_per_gpu=4,
+ workers_per_gpu=4,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_multiframe_2frames.py b/configs/waymo/voxelnet/waymo_centerformer_multiframe_2frames.py
new file mode 100644
index 0000000..1471f42
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_multiframe_2frames.py
@@ -0,0 +1,233 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=5, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_multiframe",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ obj_num= 500,
+ assign_label_window_size=window_size,
+ frame=2,
+ transformer_config=dict(
+ depth = 3,
+ heads = 4,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ cross_attention_kernel_size = [3,3,3]
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss = True,
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2), 'iou':(1, 2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num= 1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+nsweeps = 2
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_1sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[150000, 200000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame"),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame"),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_02sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_02sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_02sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_16frames.py b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_16frames.py
new file mode 100644
index 0000000..00e33b6
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_16frames.py
@@ -0,0 +1,238 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=6, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_deformable_mtf",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ frame=4,
+ obj_num=500,
+ assign_label_window_size=window_size,
+ transformer_config=dict(
+ depth = 2,
+ heads = 6,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ n_points = 15,
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss = True,
+ iou_factor=[1,1,1],
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2),'iou':(1, 2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num=1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+
+
+nsweeps = 16
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_4sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+ combine_frame=True,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+ combine_frame=True,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[250000, 600000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame",combine=4),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame",combine=4),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_16sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_16sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_16sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_2frames.py b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_2frames.py
new file mode 100644
index 0000000..c98d342
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_2frames.py
@@ -0,0 +1,235 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=5, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_deformable_mtf",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ frame=2,
+ obj_num= 500,
+ assign_label_window_size=window_size,
+ transformer_config=dict(
+ depth = 2,
+ heads = 6,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ n_points = 15,
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss = True,
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2), 'iou':(1, 2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num=1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+
+
+nsweeps = 2
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_1sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[150000, 200000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame"),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame"),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_02sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_02sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_02sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_4frames.py b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_4frames.py
new file mode 100644
index 0000000..e3e774d
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_4frames.py
@@ -0,0 +1,238 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=6, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_deformable_mtf",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ obj_num= 500,
+ score_threshold=0.1,
+ frame=2,
+ assign_label_window_size=window_size,
+ transformer_config=dict(
+ depth = 2,
+ heads = 6,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ n_points = 15,
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss = True,
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2), 'iou':(1, 2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ # score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num=1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+
+
+nsweeps = 4
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_2sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+ combine_frame=True,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+ combine_frame=True,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[180000, 400000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame", combine=2),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame", combine=2),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_04sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_04sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_04sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_8frames.py b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_8frames.py
new file mode 100644
index 0000000..ed0c3c0
--- /dev/null
+++ b/configs/waymo/voxelnet/waymo_centerformer_multiframe_deformable_8frames.py
@@ -0,0 +1,238 @@
+import itertools
+import logging
+
+from det3d.utils.config_tool import get_downsample_factor
+
+tasks = [
+ dict(num_class=3, class_names=['VEHICLE', 'PEDESTRIAN', 'CYCLIST']),
+]
+
+class_names = list(itertools.chain(*[t["class_names"] for t in tasks]))
+
+# training and testing settings
+target_assigner = dict(
+ tasks=tasks,
+)
+
+# use expanded gt label assigner
+window_size = 1
+
+# model settings
+model = dict(
+ type="VoxelNet_dynamic",
+ pretrained=None,
+ reader=dict(
+ type="DynamicVoxelEncoder",
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ ),
+ backbone=dict(
+ type="SpMiddleResNetFHD", num_input_features=6, ds_factor=8),
+ neck=dict(
+ type="RPN_transformer_deformable_mtf",
+ layer_nums=[5, 5, 1],
+ ds_num_filters=[256, 256, 128],
+ num_input_features=256,
+ use_gt_training=True,
+ corner = True,
+ frame=2,
+ obj_num=500,
+ assign_label_window_size=window_size,
+ transformer_config=dict(
+ depth = 2,
+ heads = 6,
+ dim_head = 64,
+ MLP_dim = 256,
+ DP_rate=0.3,
+ out_att = False,
+ n_points = 15,
+ ),
+ logger=logging.getLogger("RPN"),
+ ),
+ bbox_head=dict(
+ type="CenterHeadIoU_1d",
+ in_channels=256,
+ tasks=tasks,
+ dataset='waymo',
+ weight=2,
+ assign_label_window_size=window_size,
+ corner_loss=True,
+ iou_loss = True,
+ iou_factor=[1,1,1],
+ code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+ common_heads={'reg': (2, 2), 'height': (1, 2), 'dim':(3, 2), 'rot':(2, 2),'iou':(1, 2)}, # (output_channel, num_conv)
+ ),
+)
+
+assigner = dict(
+ target_assigner=target_assigner,
+ out_size_factor=4,
+ dense_reg=1,
+ gaussian_overlap=0.1,
+ max_objs=500,
+ min_radius=2,
+ gt_kernel_size=window_size,
+ corner_prediction=True,
+ pc_range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+)
+
+
+train_cfg = dict(assigner=assigner)
+
+
+test_cfg = dict(
+ post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0],
+ nms=dict(
+ use_rotate_nms=False,
+ use_multi_class_nms=True,
+ nms_pre_max_size=[1600,1600,800],
+ nms_post_max_size=[200,200,100],
+ nms_iou_threshold=[0.8,0.55,0.55],
+ ),
+ score_threshold=0.1,
+ pc_range=[-75.2, -75.2],
+ out_size_factor=4,
+ voxel_size=[0.1, 0.1],
+ obj_num=1000,
+)
+
+
+# dataset settings
+dataset_type = "WaymoDataset"
+
+
+nsweeps = 8
+data_root = "data/Waymo"
+
+db_sampler = dict(
+ type="GT-AUG",
+ enable=False,
+ db_info_path="data/Waymo/dbinfos_train_4sweeps_withvelo.pkl",
+ sample_groups=[
+ dict(VEHICLE=15),
+ dict(PEDESTRIAN=10),
+ dict(CYCLIST=10),
+ ],
+ db_prep_steps=[
+ dict(
+ filter_by_min_num_points=dict(
+ VEHICLE=5,
+ PEDESTRIAN=5,
+ CYCLIST=5,
+ )
+ ),
+ dict(filter_by_difficulty=[-1],),
+ ],
+ global_random_rotation_range_per_object=[0, 0],
+ rate=1.0,
+)
+
+train_preprocessor = dict(
+ mode="train",
+ shuffle_points=True,
+ global_rot_noise=[-0.78539816, 0.78539816],
+ global_scale_noise=[0.95, 1.05],
+ global_translate_noise=0.5,
+ db_sampler=db_sampler,
+ class_names=class_names,
+ combine_frame=True,
+)
+val_preprocessor = dict(
+ mode="val",
+ shuffle_points=False,
+ combine_frame=True,
+)
+
+voxel_generator = dict(
+ range=[-75.2, -75.2, -2, 75.2, 75.2, 4],
+ voxel_size=[0.1, 0.1, 0.15],
+ max_points_in_voxel=5,
+ max_voxel_num=[250000, 600000],
+)
+
+train_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame",combine=4),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=train_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+test_pipeline = [
+ dict(type="LoadPointCloudFromFile", dataset="WaymoDataset_multi_frame",combine=4),
+ dict(type="LoadPointCloudAnnotations", with_bbox=True),
+ dict(type="Preprocess_multiframe", cfg=val_preprocessor),
+ dict(type="AssignLabel", cfg=train_cfg["assigner"]),
+ dict(type="Reformat"),
+]
+
+train_anno = "data/Waymo/infos_train_08sweeps_filter_zero_gt.pkl"
+val_anno = "data/Waymo/infos_val_08sweeps_filter_zero_gt.pkl"
+test_anno = 'data/Waymo/infos_test_08sweeps_filter_zero_gt.pkl'
+
+data = dict(
+ samples_per_gpu=2,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=train_anno,
+ ann_file=train_anno,
+ nsweeps=nsweeps,
+ # load_interval=5,
+ class_names=class_names,
+ pipeline=train_pipeline,
+ ),
+ val=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=val_anno,
+ test_mode=True,
+ ann_file=val_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+ test=dict(
+ type=dataset_type,
+ root_path=data_root,
+ info_path=test_anno,
+ ann_file=test_anno,
+ nsweeps=nsweeps,
+ class_names=class_names,
+ pipeline=test_pipeline,
+ ),
+)
+
+
+
+optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2))
+
+# optimizer
+optimizer = dict(
+ type="adam", amsgrad=0.0, wd=0.01, fixed_wd=True, moving_average=False,
+)
+lr_config = dict(
+ type="one_cycle", lr_max=0.003, moms=[0.95, 0.85], div_factor=10.0, pct_start=0.4,
+)
+
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=5,
+ hooks=[
+ dict(type="TextLoggerHook"),
+ # dict(type='TensorboardLoggerHook')
+ ],
+)
+# yapf:enable
+# runtime settings
+total_epochs = 20
+disable_dbsampler_after_epoch = 15
+device_ids = range(8)
+dist_params = dict(backend="nccl", init_method="env://")
+log_level = "INFO"
+work_dir = './work_dirs/{}/'.format(__file__[__file__.rfind('/') + 1:-3])
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/det3d/__init__.py b/det3d/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/det3d/builder.py b/det3d/builder.py
new file mode 100644
index 0000000..f1bf49f
--- /dev/null
+++ b/det3d/builder.py
@@ -0,0 +1,222 @@
+import logging
+import pickle
+from functools import partial
+
+import det3d.core.sampler.preprocess as prep
+import numpy as np
+import torch
+from det3d.core.input.voxel_generator import VoxelGenerator
+from det3d.core.sampler.preprocess import DataBasePreprocessor
+from det3d.core.sampler.sample_ops import DataBaseSamplerV2
+from det3d.solver import learning_schedules
+from det3d.solver import learning_schedules_fastai as lsf
+from det3d.solver import optim
+from det3d.solver.fastai_optim import FastAIMixedOptim, OptimWrapper
+from torch import nn
+
+
+def build_voxel_generator(voxel_config):
+
+ voxel_generator = VoxelGenerator(
+ voxel_size=voxel_config.VOXEL_SIZE,
+ point_cloud_range=voxel_config.RANGE,
+ max_num_points=voxel_config.MAX_POINTS_NUM_PER_VOXEL,
+ max_voxels=20000,
+ )
+
+ return voxel_generator
+
+def build_db_preprocess(db_prep_config, logger=None):
+ logger = logging.getLogger("build_db_preprocess")
+ cfg = db_prep_config
+ if "filter_by_difficulty" in cfg:
+ v = cfg["filter_by_difficulty"]
+ return prep.DBFilterByDifficulty(v, logger=logger)
+ elif "filter_by_min_num_points" in cfg:
+ v = cfg["filter_by_min_num_points"]
+ return prep.DBFilterByMinNumPoint(v, logger=logger)
+ else:
+ raise ValueError("unknown database prep type")
+
+
+def children(m: nn.Module):
+ "Get children of `m`."
+ return list(m.children())
+
+
+def num_children(m: nn.Module) -> int:
+ "Get number of children modules in `m`."
+ return len(children(m))
+
+
+def flatten_model(m: nn.Module):
+ return sum(map(flatten_model, m.children()), []) if num_children(m) else [m]
+
+
+def get_layer_groups(m: nn.Module):
+ return [nn.Sequential(*flatten_model(m))]
+
+
+def build_optimizer(optimizer_config, net, name=None, mixed=False, loss_scale=512.0):
+ """Create optimizer based on config.
+
+ Args:
+ optimizer_config: A Optimizer proto message.
+
+ Returns:
+ An optimizer and a list of variables for summary.
+
+ Raises:
+ ValueError: when using an unsupported input data type.
+ """
+ optimizer_type = optimizer_config.TYPE
+ config = optimizer_config.VALUE
+
+ if optimizer_type == "rms_prop_optimizer":
+ optimizer_func = partial(
+ torch.optim.RMSprop,
+ alpha=config.decay,
+ momentum=config.momentum_optimizer_value,
+ eps=config.epsilon,
+ )
+ elif optimizer_type == "momentum_optimizer":
+ optimizer_func = partial(
+ torch.optim.SGD,
+ momentum=config.momentum_optimizer_value,
+ eps=config.epsilon,
+ )
+ elif optimizer_type == "adam":
+ if optimizer_config.FIXED_WD:
+ optimizer_func = partial(
+ torch.optim.Adam, betas=(0.9, 0.99), amsgrad=config.amsgrad
+ )
+ else:
+ # regular adam
+ optimizer_func = partial(torch.optim.Adam, amsgrad=config.amsgrad)
+
+ optimizer = OptimWrapper.create(
+ optimizer_func,
+ 3e-3,
+ get_layer_groups(net),
+ wd=config.WD,
+ true_wd=optimizer_config.FIXED_WD,
+ bn_wd=True,
+ )
+
+ if optimizer is None:
+ raise ValueError("Optimizer %s not supported." % optimizer_type)
+
+ if optimizer_config.MOVING_AVERAGE:
+ raise ValueError("torch don't support moving average")
+
+ if name is None:
+ # assign a name to optimizer for checkpoint system
+ optimizer.name = optimizer_type
+ else:
+ optimizer.name = name
+
+ return optimizer
+
+
+def build_lr_scheduler(optimizer, optimizer_config, total_step):
+ """Create lr scheduler based on config. note that
+ lr_scheduler must accept a optimizer that has been restored.
+
+ Args:
+ optimizer_config: A Optimizer proto message.
+
+ Returns:
+ An optimizer and a list of variables for summary.
+
+ Raises:
+ ValueError: when using an unsupported input data type.
+ """
+ optimizer_type = optimizer_config.type
+ config = optimizer_config
+
+ if optimizer_type == "rms_prop_optimizer":
+ lr_scheduler = _create_learning_rate_scheduler(
+ config, optimizer, total_step=total_step
+ )
+ elif optimizer_type == "momentum_optimizer":
+ lr_scheduler = _create_learning_rate_scheduler(
+ config, optimizer, total_step=total_step
+ )
+ elif optimizer_type == "adam":
+ lr_scheduler = _create_learning_rate_scheduler(
+ config, optimizer, total_step=total_step
+ )
+
+ return lr_scheduler
+
+
+def _create_learning_rate_scheduler(optimizer, learning_rate_config, total_step):
+ """Create optimizer learning rate scheduler based on config.
+
+ Args:
+ learning_rate_config: A LearningRate proto message.
+
+ Returns:
+ A learning rate.
+
+ Raises:
+ ValueError: when using an unsupported input data type.
+ """
+ lr_scheduler = None
+ learning_rate_type = learning_rate_config.type
+ config = learning_rate_config
+
+ if learning_rate_type == "multi_phase":
+ lr_phases = []
+ mom_phases = []
+ for phase_cfg in config.phases:
+ lr_phases.append((phase_cfg.start, phase_cfg.lambda_func))
+ mom_phases.append((phase_cfg.start, phase_cfg.momentum_lambda_func))
+ lr_scheduler = lsf.LRSchedulerStep(optimizer, total_step, lr_phases, mom_phases)
+ elif learning_rate_type == "one_cycle":
+ lr_scheduler = lsf.OneCycle(
+ optimizer,
+ total_step,
+ config.lr_max,
+ config.moms,
+ config.div_factor,
+ config.pct_start,
+ )
+ elif learning_rate_type == "exponential_decay":
+ lr_scheduler = lsf.ExponentialDecay(
+ optimizer,
+ total_step,
+ config.initial_learning_rate,
+ config.decay_length,
+ config.decay_factor,
+ config.staircase,
+ )
+ elif learning_rate_type == "manual_stepping":
+ lr_scheduler = lsf.ManualStepping(
+ optimizer, total_step, config.boundaries, config.rates
+ )
+ elif lr_scheduler is None:
+ raise ValueError("Learning_rate %s not supported." % learning_rate_type)
+
+ return lr_scheduler
+
+
+def build_dbsampler(cfg, logger=None):
+ logger = logging.getLogger("build_dbsampler")
+ prepors = [build_db_preprocess(c, logger=logger) for c in cfg.db_prep_steps]
+ db_prepor = DataBasePreprocessor(prepors)
+ rate = cfg.rate
+ grot_range = cfg.global_random_rotation_range_per_object
+ groups = cfg.sample_groups
+ # groups = [dict(g.name_to_max_num) for g in groups]
+ info_path = cfg.db_info_path
+ with open(info_path, "rb") as f:
+ db_infos = pickle.load(f)
+ grot_range = list(grot_range)
+ if len(grot_range) == 0:
+ grot_range = None
+ sampler = DataBaseSamplerV2(
+ db_infos, groups, db_prepor, rate, grot_range, logger=logger
+ )
+
+ return sampler
diff --git a/det3d/core/__init__.py b/det3d/core/__init__.py
new file mode 100644
index 0000000..d05014b
--- /dev/null
+++ b/det3d/core/__init__.py
@@ -0,0 +1,4 @@
+from .utils import *
+from .bbox import *
+from .input import *
+from .sampler import *
diff --git a/det3d/core/bbox/__init__.py b/det3d/core/bbox/__init__.py
new file mode 100644
index 0000000..11c3613
--- /dev/null
+++ b/det3d/core/bbox/__init__.py
@@ -0,0 +1 @@
+from . import box_np_ops, box_torch_ops, geometry
diff --git a/det3d/core/bbox/box_np_ops.py b/det3d/core/bbox/box_np_ops.py
new file mode 100644
index 0000000..5d7e5c8
--- /dev/null
+++ b/det3d/core/bbox/box_np_ops.py
@@ -0,0 +1,803 @@
+from pathlib import Path
+
+import numba
+import numpy as np
+from det3d.core.bbox.geometry import (
+ points_count_convex_polygon_3d_jit,
+ points_in_convex_polygon_3d_jit,
+)
+try:
+ from spconv.utils import rbbox_intersection, rbbox_iou
+except:
+ print("Import spconv fail, no support for sparse convolution!")
+
+
+def points_count_rbbox(points, rbbox, z_axis=2, origin=(0.5, 0.5, 0.5)):
+ rbbox_corners = center_to_corner_box3d(
+ rbbox[:, :3], rbbox[:, 3:6], rbbox[:, -1], origin=origin, axis=z_axis
+ )
+ surfaces = corner_to_surfaces_3d(rbbox_corners)
+ return points_count_convex_polygon_3d_jit(points[:, :3], surfaces)
+
+
+def riou_cc(rbboxes, qrbboxes, standup_thresh=0.0):
+ # less than 50ms when used in second one thread. 10x slower than gpu
+ boxes_corners = center_to_corner_box2d(
+ rbboxes[:, :2], rbboxes[:, 2:4], rbboxes[:, 4]
+ )
+ boxes_standup = corner_to_standup_nd(boxes_corners)
+ qboxes_corners = center_to_corner_box2d(
+ qrbboxes[:, :2], qrbboxes[:, 2:4], qrbboxes[:, 4]
+ )
+ qboxes_standup = corner_to_standup_nd(qboxes_corners)
+ # if standup box not overlapped, rbbox not overlapped too.
+ standup_iou = iou_jit(boxes_standup, qboxes_standup, eps=0.0)
+ return rbbox_iou(boxes_corners, qboxes_corners, standup_iou, standup_thresh)
+
+
+def rinter_cc(rbboxes, qrbboxes, standup_thresh=0.0):
+ # less than 50ms when used in second one thread. 10x slower than gpu
+ boxes_corners = center_to_corner_box2d(
+ rbboxes[:, :2], rbboxes[:, 2:4], rbboxes[:, 4]
+ )
+ boxes_standup = corner_to_standup_nd(boxes_corners)
+ qboxes_corners = center_to_corner_box2d(
+ qrbboxes[:, :2], qrbboxes[:, 2:4], qrbboxes[:, 4]
+ )
+ qboxes_standup = corner_to_standup_nd(qboxes_corners)
+ # if standup box not overlapped, rbbox not overlapped too.
+ standup_iou = iou_jit(boxes_standup, qboxes_standup, eps=0.0)
+ return rbbox_intersection(
+ boxes_corners, qboxes_corners, standup_iou, standup_thresh
+ )
+
+
+def corners_nd(dims, origin=0.5):
+ """generate relative box corners based on length per dim and
+ origin point.
+
+ Args:
+ dims (float array, shape=[N, ndim]): array of length per dim
+ origin (list or array or float): origin point relate to smallest point.
+
+ Returns:
+ float array, shape=[N, 2 ** ndim, ndim]: returned corners.
+ point layout example: (2d) x0y0, x0y1, x1y0, x1y1;
+ (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
+ where x0 < x1, y0 < y1, z0 < z1
+ """
+ ndim = int(dims.shape[1])
+ corners_norm = np.stack(
+ np.unravel_index(np.arange(2 ** ndim), [2] * ndim), axis=1
+ ).astype(dims.dtype)
+ # now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
+ # (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
+ # so need to convert to a format which is convenient to do other computing.
+ # for 2d boxes, format is clockwise start with minimum point
+ # for 3d boxes, please draw lines by your hand.
+ if ndim == 2:
+ # generate clockwise box corners
+ corners_norm = corners_norm[[0, 1, 3, 2]]
+ elif ndim == 3:
+ corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
+ corners_norm = corners_norm - np.array(origin, dtype=dims.dtype)
+ corners = dims.reshape([-1, 1, ndim]) * corners_norm.reshape([1, 2 ** ndim, ndim])
+ return corners
+
+
+@numba.njit
+def corners_2d_jit(dims, origin=0.5):
+ ndim = 2
+ corners_norm = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=dims.dtype)
+ corners_norm = corners_norm - np.array(origin, dtype=dims.dtype)
+ corners = dims.reshape((-1, 1, ndim)) * corners_norm.reshape((1, 2 ** ndim, ndim))
+ return corners
+
+
+@numba.njit
+def corners_3d_jit(dims, origin=0.5):
+ ndim = 3
+ corners_norm = np.array(
+ [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1],
+ dtype=dims.dtype,
+ ).reshape((8, 3))
+ corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
+ corners_norm = corners_norm - np.array(origin, dtype=dims.dtype)
+ corners = dims.reshape((-1, 1, ndim)) * corners_norm.reshape((1, 2 ** ndim, ndim))
+ return corners
+
+
+@numba.njit
+def corner_to_standup_nd_jit(boxes_corner):
+ num_boxes = boxes_corner.shape[0]
+ ndim = boxes_corner.shape[-1]
+ result = np.zeros((num_boxes, ndim * 2), dtype=boxes_corner.dtype)
+ for i in range(num_boxes):
+ for j in range(ndim):
+ result[i, j] = np.min(boxes_corner[i, :, j])
+ for j in range(ndim):
+ result[i, j + ndim] = np.max(boxes_corner[i, :, j])
+ return result
+
+
+def corner_to_standup_nd(boxes_corner):
+ assert len(boxes_corner.shape) == 3
+ standup_boxes = []
+ standup_boxes.append(np.min(boxes_corner, axis=1))
+ standup_boxes.append(np.max(boxes_corner, axis=1))
+ return np.concatenate(standup_boxes, -1)
+
+
+def rbbox2d_to_near_bbox(rbboxes):
+ """convert rotated bbox to nearest 'standing' or 'lying' bbox.
+ Args:
+ rbboxes: [N, 5(x, y, xdim, ydim, rad)] rotated bboxes
+ Returns:
+ bboxes: [N, 4(xmin, ymin, xmax, ymax)] bboxes
+ """
+ rots = rbboxes[..., -1]
+ rots_0_pi_div_2 = np.abs(limit_period(rots, 0.5, np.pi))
+ cond = (rots_0_pi_div_2 > np.pi / 4)[..., np.newaxis]
+ bboxes_center = np.where(cond, rbboxes[:, [0, 1, 3, 2]], rbboxes[:, :4])
+ bboxes = center_to_minmax_2d(bboxes_center[:, :2], bboxes_center[:, 2:])
+ return bboxes
+
+
+def rotation_3d_in_axis(points, angles, axis=0):
+ # points: [N, point_size, 3]
+ rot_sin = np.sin(angles)
+ rot_cos = np.cos(angles)
+ ones = np.ones_like(rot_cos)
+ zeros = np.zeros_like(rot_cos)
+ if axis == 1:
+ rot_mat_T = np.stack(
+ [
+ [rot_cos, zeros, -rot_sin],
+ [zeros, ones, zeros],
+ [rot_sin, zeros, rot_cos],
+ ]
+ )
+ elif axis == 2 or axis == -1:
+ rot_mat_T = np.stack(
+ [
+ [rot_cos, -rot_sin, zeros],
+ [rot_sin, rot_cos, zeros],
+ [zeros, zeros, ones],
+ ]
+ )
+ elif axis == 0:
+ rot_mat_T = np.stack(
+ [
+ [zeros, rot_cos, -rot_sin],
+ [zeros, rot_sin, rot_cos],
+ [ones, zeros, zeros],
+ ]
+ )
+ else:
+ raise ValueError("axis should in range")
+
+ return np.einsum("aij,jka->aik", points, rot_mat_T)
+
+
+def rotation_points_single_angle(points, angle, axis=0):
+ # points: [N, 3]
+ rot_sin = np.sin(angle)
+ rot_cos = np.cos(angle)
+ if axis == 1:
+ rot_mat_T = np.array(
+ [[rot_cos, 0, -rot_sin], [0, 1, 0], [rot_sin, 0, rot_cos]],
+ dtype=points.dtype,
+ )
+ elif axis == 2 or axis == -1:
+ rot_mat_T = np.array(
+ [[rot_cos, -rot_sin, 0], [rot_sin, rot_cos, 0], [0, 0, 1]],
+ dtype=points.dtype,
+ )
+ elif axis == 0:
+ rot_mat_T = np.array(
+ [[1, 0, 0], [0, rot_cos, -rot_sin], [0, rot_sin, rot_cos]],
+ dtype=points.dtype,
+ )
+ else:
+ raise ValueError("axis should in range")
+
+ return points @ rot_mat_T
+
+
+def rotation_2d(points, angles):
+ """rotation 2d points based on origin point clockwise when angle positive.
+
+ Args:
+ points (float array, shape=[N, point_size, 2]): points to be rotated.
+ angles (float array, shape=[N]): rotation angle.
+
+ Returns:
+ float array: same shape as points
+ """
+ rot_sin = np.sin(angles)
+ rot_cos = np.cos(angles)
+ rot_mat_T = np.stack([[rot_cos, -rot_sin], [rot_sin, rot_cos]])
+ return np.einsum("aij,jka->aik", points, rot_mat_T)
+
+
+def rotation_box(box_corners, angle):
+ """rotation 2d points based on origin point clockwise when angle positive.
+
+ Args:
+ points (float array, shape=[N, point_size, 2]): points to be rotated.
+ angle (float): rotation angle.
+
+ Returns:
+ float array: same shape as points
+ """
+ rot_sin = np.sin(angle)
+ rot_cos = np.cos(angle)
+ rot_mat_T = np.array(
+ [[rot_cos, -rot_sin], [rot_sin, rot_cos]], dtype=box_corners.dtype
+ )
+ return box_corners @ rot_mat_T
+
+
+def center_to_corner_box3d(centers, dims, angles=None, origin=(0.5, 0.5, 0.5), axis=2):
+ """convert kitti locations, dimensions and angles to corners
+
+ Args:
+ centers (float array, shape=[N, 3]): locations in kitti label file.
+ dims (float array, shape=[N, 3]): dimensions in kitti label file.
+ angles (float array, shape=[N]): rotation_y in kitti label file.
+ origin (list or array or float): origin point relate to smallest point.
+ use [0.5, 1.0, 0.5] in camera and [0.5, 0.5, 0] in lidar.
+ axis (int): rotation axis. 1 for camera and 2 for lidar.
+ Returns:
+ [type]: [description]
+ """
+ # 'length' in kitti format is in x axis.
+ # yzx(hwl)(kitti label file)<->xyz(lhw)(camera)<->z(-x)(-y)(wlh)(lidar)
+ # center in kitti format is [0.5, 1.0, 0.5] in xyz.
+ corners = corners_nd(dims, origin=origin)
+ # corners: [N, 8, 3]
+ if angles is not None:
+ corners = rotation_3d_in_axis(corners, angles, axis=axis)
+ corners += centers.reshape([-1, 1, 3])
+ return corners
+
+
+def center_to_corner_box2d(centers, dims, angles=None, origin=0.5):
+ """convert kitti locations, dimensions and angles to corners.
+ format: center(xy), dims(xy), angles(clockwise when positive)
+
+ Args:
+ centers (float array, shape=[N, 2]): locations in kitti label file.
+ dims (float array, shape=[N, 2]): dimensions in kitti label file.
+ angles (float array, shape=[N]): rotation_y in kitti label file.
+
+ Returns:
+ [type]: [description]
+ """
+ # 'length' in kitti format is in x axis.
+ # xyz(hwl)(kitti label file)<->xyz(lhw)(camera)<->z(-x)(-y)(wlh)(lidar)
+ # center in kitti format is [0.5, 1.0, 0.5] in xyz.
+ corners = corners_nd(dims, origin=origin)
+ # corners: [N, 4, 2]
+ if angles is not None:
+ corners = rotation_2d(corners, angles)
+ corners += centers.reshape([-1, 1, 2])
+ return corners
+
+
+@numba.jit(nopython=True)
+def box2d_to_corner_jit(boxes):
+ num_box = boxes.shape[0]
+ corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
+ corners_norm[1, 1] = 1.0
+ corners_norm[2] = 1.0
+ corners_norm[3, 0] = 1.0
+ corners_norm -= np.array([0.5, 0.5], dtype=boxes.dtype)
+ corners = boxes.reshape(num_box, 1, 5)[:, :, 2:4] * corners_norm.reshape(1, 4, 2)
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ box_corners = np.zeros((num_box, 4, 2), dtype=boxes.dtype)
+ for i in range(num_box):
+ rot_sin = np.sin(boxes[i, -1])
+ rot_cos = np.cos(boxes[i, -1])
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ box_corners[i] = corners[i] @ rot_mat_T + boxes[i, :2]
+ return box_corners
+
+
+def rbbox3d_to_corners(rbboxes, origin=[0.5, 0.5, 0.5], axis=2):
+ return center_to_corner_box3d(
+ rbboxes[..., :3], rbboxes[..., 3:6], rbboxes[..., 6], origin, axis=axis
+ )
+
+
+def rbbox3d_to_bev_corners(rbboxes, origin=0.5):
+ return center_to_corner_box2d(
+ rbboxes[..., :2], rbboxes[..., 3:5], rbboxes[..., 6], origin
+ )
+
+
+def minmax_to_corner_2d(minmax_box):
+ ndim = minmax_box.shape[-1] // 2
+ center = minmax_box[..., :ndim]
+ dims = minmax_box[..., ndim:] - center
+ return center_to_corner_box2d(center, dims, origin=0.0)
+
+
+def minmax_to_corner_2d_v2(minmax_box):
+ # N, 4 -> N 4 2
+ return minmax_box[..., [0, 1, 0, 3, 2, 3, 2, 1]].reshape(-1, 4, 2)
+
+
+def minmax_to_corner_3d(minmax_box):
+ ndim = minmax_box.shape[-1] // 2
+ center = minmax_box[..., :ndim]
+ dims = minmax_box[..., ndim:] - center
+ return center_to_corner_box3d(center, dims, origin=0.0)
+
+
+def minmax_to_center_2d(minmax_box):
+ ndim = minmax_box.shape[-1] // 2
+ center_min = minmax_box[..., :ndim]
+ dims = minmax_box[..., ndim:] - center_min
+ center = center_min + 0.5 * dims
+ return np.concatenate([center, dims], axis=-1)
+
+
+def center_to_minmax_2d_0_5(centers, dims):
+ return np.concatenate([centers - dims / 2, centers + dims / 2], axis=-1)
+
+
+def center_to_minmax_2d(centers, dims, origin=0.5):
+ if origin == 0.5:
+ return center_to_minmax_2d_0_5(centers, dims)
+ corners = center_to_corner_box2d(centers, dims, origin=origin)
+ return corners[:, [0, 2]].reshape([-1, 4])
+
+
+def limit_period(val, offset=0.5, period=np.pi):
+ return val - np.floor(val / period + offset) * period
+
+
+def projection_matrix_to_CRT_kitti(proj):
+ # P = C @ [R|T]
+ # C is upper triangular matrix, so we need to inverse CR and use QR
+ # stable for all kitti camera projection matrix
+ CR = proj[0:3, 0:3]
+ CT = proj[0:3, 3]
+ RinvCinv = np.linalg.inv(CR)
+ Rinv, Cinv = np.linalg.qr(RinvCinv)
+ C = np.linalg.inv(Cinv)
+ R = np.linalg.inv(Rinv)
+ T = Cinv @ CT
+ return C, R, T
+
+
+def get_frustum(bbox_image, C, near_clip=0.001, far_clip=100):
+ fku = C[0, 0]
+ fkv = -C[1, 1]
+ u0v0 = C[0:2, 2]
+ z_points = np.array([near_clip] * 4 + [far_clip] * 4, dtype=C.dtype)[:, np.newaxis]
+ b = bbox_image
+ box_corners = np.array(
+ [[b[0], b[1]], [b[0], b[3]], [b[2], b[3]], [b[2], b[1]]], dtype=C.dtype
+ )
+ near_box_corners = (box_corners - u0v0) / np.array(
+ [fku / near_clip, -fkv / near_clip], dtype=C.dtype
+ )
+ far_box_corners = (box_corners - u0v0) / np.array(
+ [fku / far_clip, -fkv / far_clip], dtype=C.dtype
+ )
+ ret_xy = np.concatenate([near_box_corners, far_box_corners], axis=0) # [8, 2]
+ ret_xyz = np.concatenate([ret_xy, z_points], axis=1)
+ return ret_xyz
+
+
+def get_frustum_v2(bboxes, C, near_clip=0.001, far_clip=100):
+ fku = C[0, 0]
+ fkv = -C[1, 1]
+ u0v0 = C[0:2, 2]
+ num_box = bboxes.shape[0]
+ z_points = np.array([near_clip] * 4 + [far_clip] * 4, dtype=C.dtype)[
+ np.newaxis, :, np.newaxis
+ ]
+ z_points = np.tile(z_points, [num_box, 1, 1])
+ box_corners = minmax_to_corner_2d_v2(bboxes)
+ near_box_corners = (box_corners - u0v0) / np.array(
+ [fku / near_clip, -fkv / near_clip], dtype=C.dtype
+ )
+ far_box_corners = (box_corners - u0v0) / np.array(
+ [fku / far_clip, -fkv / far_clip], dtype=C.dtype
+ )
+ ret_xy = np.concatenate([near_box_corners, far_box_corners], axis=1) # [8, 2]
+ ret_xyz = np.concatenate([ret_xy, z_points], axis=-1)
+ return ret_xyz
+
+
+@numba.njit
+def _add_rgb_to_points_kernel(points_2d, image, points_rgb):
+ num_points = points_2d.shape[0]
+ image_h, image_w = image.shape[:2]
+ for i in range(num_points):
+ img_pos = np.floor(points_2d[i]).astype(np.int32)
+ if img_pos[0] >= 0 and img_pos[0] < image_w:
+ if img_pos[1] >= 0 and img_pos[1] < image_h:
+ points_rgb[i, :] = image[img_pos[1], img_pos[0], :]
+ # image[img_pos[1], img_pos[0]] = 0
+
+
+def add_rgb_to_points(points, image, rect, Trv2c, P2, mean_size=[5, 5]):
+ kernel = np.ones(mean_size, np.float32) / np.prod(mean_size)
+ # image = cv2.filter2D(image, -1, kernel)
+ points_cam = lidar_to_camera(points[:, :3], rect, Trv2c)
+ points_2d = project_to_image(points_cam, P2)
+ points_rgb = np.zeros([points_cam.shape[0], 3], dtype=points.dtype)
+ _add_rgb_to_points_kernel(points_2d, image, points_rgb)
+ return points_rgb
+
+
+def project_to_image(points_3d, proj_mat):
+ points_shape = list(points_3d.shape)
+ points_shape[-1] = 1
+ points_4 = np.concatenate([points_3d, np.ones(points_shape)], axis=-1)
+ point_2d = points_4 @ proj_mat.T
+ point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]
+ return point_2d_res
+
+
+def camera_to_lidar(points, r_rect, velo2cam):
+ points_shape = list(points.shape[0:-1])
+ if points.shape[-1] == 3:
+ points = np.concatenate([points, np.ones(points_shape + [1])], axis=-1)
+ lidar_points = points @ np.linalg.inv((r_rect @ velo2cam).T)
+ return lidar_points[..., :3]
+
+
+def lidar_to_camera(points, r_rect, velo2cam):
+ points_shape = list(points.shape[:-1])
+ if points.shape[-1] == 3:
+ points = np.concatenate([points, np.ones(points_shape + [1])], axis=-1)
+ camera_points = points @ (r_rect @ velo2cam).T
+ return camera_points[..., :3]
+
+
+def box_camera_to_lidar(data, r_rect, velo2cam):
+ xyz = data[:, 0:3]
+ l, h, w = data[:, 3:4], data[:, 4:5], data[:, 5:6]
+ r = data[:, 6:7]
+ xyz_lidar = camera_to_lidar(xyz, r_rect, velo2cam)
+ return np.concatenate([xyz_lidar, w, l, h, r], axis=1)
+
+
+def box_lidar_to_camera(data, r_rect, velo2cam):
+ xyz_lidar = data[:, 0:3]
+ w, l, h = data[:, 3:4], data[:, 4:5], data[:, 5:6]
+ r = data[:, 6:7]
+ xyz = lidar_to_camera(xyz_lidar, r_rect, velo2cam)
+ return np.concatenate([xyz, l, h, w, r], axis=1)
+
+
+def remove_outside_points(points, rect, Trv2c, P2, image_shape):
+ # 5x faster than remove_outside_points_v1(2ms vs 10ms)
+ C, R, T = projection_matrix_to_CRT_kitti(P2)
+ image_bbox = [0, 0, image_shape[1], image_shape[0]]
+ frustum = get_frustum(image_bbox, C)
+ frustum -= T
+ frustum = np.linalg.inv(R) @ frustum.T
+ frustum = camera_to_lidar(frustum.T, rect, Trv2c)
+ frustum_surfaces = corner_to_surfaces_3d_jit(frustum[np.newaxis, ...])
+ indices = points_in_convex_polygon_3d_jit(points[:, :3], frustum_surfaces)
+ points = points[indices.reshape([-1])]
+ return points
+
+
+@numba.jit(nopython=True)
+def iou_jit(boxes, query_boxes, eps=1.0):
+ """calculate box iou. note that jit version runs 2x faster than cython in
+ my machine!
+ Parameters
+ ----------
+ boxes: (N, 4) ndarray of float
+ query_boxes: (K, 4) ndarray of float
+ Returns
+ -------
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
+ """
+ N = boxes.shape[0]
+ K = query_boxes.shape[0]
+ overlaps = np.zeros((N, K), dtype=boxes.dtype)
+ for k in range(K):
+ box_area = (query_boxes[k, 2] - query_boxes[k, 0] + eps) * (
+ query_boxes[k, 3] - query_boxes[k, 1] + eps
+ )
+ for n in range(N):
+ iw = (
+ min(boxes[n, 2], query_boxes[k, 2])
+ - max(boxes[n, 0], query_boxes[k, 0])
+ + eps
+ )
+ if iw > 0:
+ ih = (
+ min(boxes[n, 3], query_boxes[k, 3])
+ - max(boxes[n, 1], query_boxes[k, 1])
+ + eps
+ )
+ if ih > 0:
+ ua = (
+ (boxes[n, 2] - boxes[n, 0] + eps)
+ * (boxes[n, 3] - boxes[n, 1] + eps)
+ + box_area
+ - iw * ih
+ )
+ overlaps[n, k] = iw * ih / ua
+ return overlaps
+
+
+@numba.jit(nopython=True)
+def iou_3d_jit(boxes, query_boxes, add1=True):
+ """calculate box iou3d,
+ ----------
+ boxes: (N, 6) ndarray of float
+ query_boxes: (K, 6) ndarray of float
+ Returns
+ -------
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
+ """
+ N = boxes.shape[0]
+ K = query_boxes.shape[0]
+ overlaps = np.zeros((N, K), dtype=boxes.dtype)
+ if add1:
+ add1 = 1.0
+ else:
+ add1 = 0.0
+ for k in range(K):
+ box_area = (
+ (query_boxes[k, 3] - query_boxes[k, 0] + add1)
+ * (query_boxes[k, 4] - query_boxes[k, 1] + add1)
+ * (query_boxes[k, 5] - query_boxes[k, 2] + add1)
+ )
+ for n in range(N):
+ iw = (
+ min(boxes[n, 3], query_boxes[k, 3])
+ - max(boxes[n, 0], query_boxes[k, 0])
+ + add1
+ )
+ if iw > 0:
+ ih = (
+ min(boxes[n, 4], query_boxes[k, 4])
+ - max(boxes[n, 1], query_boxes[k, 1])
+ + add1
+ )
+ if ih > 0:
+ il = (
+ min(boxes[n, 5], query_boxes[k, 5])
+ - max(boxes[n, 2], query_boxes[k, 2])
+ + add1
+ )
+ if il > 0:
+ ua = float(
+ (boxes[n, 3] - boxes[n, 0] + add1)
+ * (boxes[n, 4] - boxes[n, 1] + add1)
+ * (boxes[n, 5] - boxes[n, 2] + add1)
+ + box_area
+ - iw * ih * il
+ )
+ overlaps[n, k] = iw * ih * il / ua
+ return overlaps
+
+
+@numba.jit(nopython=True)
+def iou_nd_jit(boxes, query_boxes, add1=True):
+ """calculate box iou nd, 2x slower than iou_jit.
+ ----------
+ boxes: (N, ndim * 2) ndarray of float
+ query_boxes: (K, ndim * 2) ndarray of float
+ Returns
+ -------
+ overlaps: (N, K) ndarray of overlap between boxes and query_boxes
+ """
+ N = boxes.shape[0]
+ K = query_boxes.shape[0]
+ ndim = boxes.shape[1] // 2
+ overlaps = np.zeros((N, K), dtype=boxes.dtype)
+ side_lengths = np.zeros((ndim,), dtype=boxes.dtype)
+ if add1:
+ add1 = 1.0
+ else:
+ add1 = 0.0
+ invalid = False
+ for k in range(K):
+ qbox_area = query_boxes[k, ndim] - query_boxes[k, 0] + add1
+ for i in range(1, ndim):
+ qbox_area *= query_boxes[k, ndim + i] - query_boxes[k, i] + add1
+ for n in range(N):
+ invalid = False
+ for i in range(ndim):
+ side_length = (
+ min(boxes[n, i + ndim], query_boxes[k, i + ndim])
+ - max(boxes[n, i], query_boxes[k, i])
+ + add1
+ )
+ if side_length <= 0:
+ invalid = True
+ break
+ side_lengths[i] = side_length
+ if not invalid:
+ box_area = boxes[n, ndim] - boxes[n, 0] + add1
+ for i in range(1, ndim):
+ box_area *= boxes[n, ndim + i] - boxes[n, i] + add1
+ inter = side_lengths[0]
+ for i in range(1, ndim):
+ inter *= side_lengths[i]
+ # inter = np.prod(side_lengths)
+ ua = float(box_area + qbox_area - inter)
+ overlaps[n, k] = inter / ua
+
+ return overlaps
+
+
+def points_in_rbbox(points, rbbox, z_axis=2, origin=(0.5, 0.5, 0.5)):
+ rbbox_corners = center_to_corner_box3d(
+ rbbox[:, :3], rbbox[:, 3:6], rbbox[:, -1], origin=origin, axis=z_axis
+ )
+ surfaces = corner_to_surfaces_3d(rbbox_corners)
+ indices = points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
+ return indices
+
+
+def corner_to_surfaces_3d(corners):
+ """convert 3d box corners from corner function above
+ to surfaces that normal vectors all direct to internal.
+
+ Args:
+ corners (float array, [N, 8, 3]): 3d box corners.
+ Returns:
+ surfaces (float array, [N, 6, 4, 3]):
+ """
+ # box_corners: [N, 8, 3], must from corner functions in this module
+ surfaces = np.array(
+ [
+ [corners[:, 0], corners[:, 1], corners[:, 2], corners[:, 3]],
+ [corners[:, 7], corners[:, 6], corners[:, 5], corners[:, 4]],
+ [corners[:, 0], corners[:, 3], corners[:, 7], corners[:, 4]],
+ [corners[:, 1], corners[:, 5], corners[:, 6], corners[:, 2]],
+ [corners[:, 0], corners[:, 4], corners[:, 5], corners[:, 1]],
+ [corners[:, 3], corners[:, 2], corners[:, 6], corners[:, 7]],
+ ]
+ ).transpose([2, 0, 1, 3])
+ return surfaces
+
+
+@numba.jit(nopython=True)
+def corner_to_surfaces_3d_jit(corners):
+ """convert 3d box corners from corner function above
+ to surfaces that normal vectors all direct to internal.
+
+ Args:
+ corners (float array, [N, 8, 3]): 3d box corners.
+ Returns:
+ surfaces (float array, [N, 6, 4, 3]):
+ """
+ # box_corners: [N, 8, 3], must from corner functions in this module
+ num_boxes = corners.shape[0]
+ surfaces = np.zeros((num_boxes, 6, 4, 3), dtype=corners.dtype)
+ corner_idxes = np.array(
+ [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 1, 5, 6, 2, 0, 4, 5, 1, 3, 2, 6, 7]
+ ).reshape(6, 4)
+ for i in range(num_boxes):
+ for j in range(6):
+ for k in range(4):
+ surfaces[i, j, k] = corners[i, corner_idxes[j, k]]
+ return surfaces
+
+
+def assign_label_to_voxel(gt_boxes, coors, voxel_size, coors_range):
+ """assign a 0/1 label to each voxel based on whether
+ the center of voxel is in gt_box. LIDAR.
+ """
+ voxel_size = np.array(voxel_size, dtype=gt_boxes.dtype)
+ coors_range = np.array(coors_range, dtype=gt_boxes.dtype)
+ shift = coors_range[:3]
+ voxel_origins = coors[:, ::-1] * voxel_size + shift
+ voxel_centers = voxel_origins + voxel_size * 0.5
+ gt_box_corners = center_to_corner_box3d(
+ gt_boxes[:, :3] - voxel_size * 0.5,
+ gt_boxes[:, 3:6] + voxel_size,
+ gt_boxes[:, 6],
+ origin=[0.5, 0.5, 0.5],
+ axis=2,
+ )
+ gt_surfaces = corner_to_surfaces_3d(gt_box_corners)
+ ret = points_in_convex_polygon_3d_jit(voxel_centers, gt_surfaces)
+ return np.any(ret, axis=1).astype(np.int64)
+
+
+def assign_label_to_voxel_v3(gt_boxes, coors, voxel_size, coors_range):
+ """assign a 0/1 label to each voxel based on whether
+ the center of voxel is in gt_box. LIDAR.
+ """
+ voxel_size = np.array(voxel_size, dtype=gt_boxes.dtype)
+ coors_range = np.array(coors_range, dtype=gt_boxes.dtype)
+ shift = coors_range[:3]
+ voxel_origins = coors[:, ::-1] * voxel_size + shift
+ voxel_maxes = voxel_origins + voxel_size
+ voxel_minmax = np.concatenate([voxel_origins, voxel_maxes], axis=-1)
+ voxel_corners = minmax_to_corner_3d(voxel_minmax)
+ gt_box_corners = center_to_corner_box3d(
+ gt_boxes[:, :3],
+ gt_boxes[:, 3:6],
+ gt_boxes[:, 6],
+ origin=[0.5, 0.5, 0.5],
+ axis=2,
+ )
+ gt_surfaces = corner_to_surfaces_3d(gt_box_corners)
+ voxel_corners_flat = voxel_corners.reshape([-1, 3])
+ ret = points_in_convex_polygon_3d_jit(voxel_corners_flat, gt_surfaces)
+ ret = ret.reshape([-1, 8, ret.shape[-1]])
+ return ret.any(-1).any(-1).astype(np.int64)
+
+
+def image_box_region_area(img_cumsum, bbox):
+ """check a 2d voxel is contained by a box. used to filter empty
+ anchors.
+ Summed-area table algorithm:
+ ==> W
+ ------------------
+ | | |
+ |------A---------B
+ | | |
+ | | |
+ |----- C---------D
+ Iabcd = ID-IB-IC+IA
+ Args:
+ img_cumsum: [M, H, W](yx) cumsumed image.
+ bbox: [N, 4](xyxy) bounding box,
+ """
+ N = bbox.shape[0]
+ M = img_cumsum.shape[0]
+ ret = np.zeros([N, M], dtype=img_cumsum.dtype)
+ ID = img_cumsum[:, bbox[:, 3], bbox[:, 2]]
+ IA = img_cumsum[:, bbox[:, 1], bbox[:, 0]]
+ IB = img_cumsum[:, bbox[:, 3], bbox[:, 0]]
+ IC = img_cumsum[:, bbox[:, 1], bbox[:, 2]]
+ ret = ID - IB - IC + IA
+ return ret
+
+
+def get_minimum_bounding_box_bv(points, voxel_size, bound, downsample=8, margin=1.6):
+ x_vsize = voxel_size[0]
+ y_vsize = voxel_size[1]
+ max_x = points[:, 0].max()
+ max_y = points[:, 1].max()
+ min_x = points[:, 0].min()
+ min_y = points[:, 1].min()
+ max_x = np.floor(max_x / (x_vsize * downsample) + 1) * (x_vsize * downsample)
+ max_y = np.floor(max_y / (y_vsize * downsample) + 1) * (y_vsize * downsample)
+ min_x = np.floor(min_x / (x_vsize * downsample)) * (x_vsize * downsample)
+ min_y = np.floor(min_y / (y_vsize * downsample)) * (y_vsize * downsample)
+ max_x = np.minimum(max_x + margin, bound[2])
+ max_y = np.minimum(max_y + margin, bound[3])
+ min_x = np.maximum(min_x - margin, bound[0])
+ min_y = np.maximum(min_y - margin, bound[1])
+ return np.array([min_x, min_y, max_x, max_y])
+
+
+def box3d_to_bbox(box3d, rect, Trv2c, P2):
+ box3d_to_cam = box_lidar_to_camera(box3d, rect, Trv2c)
+ box_corners = center_to_corner_box3d(
+ box3d[:, :3], box3d[:, 3:6], box3d[:, 6], [0.5, 1.0, 0.5], axis=1
+ )
+ box_corners_in_image = project_to_image(box_corners, P2)
+ # box_corners_in_image: [N, 8, 2]
+ minxy = np.min(box_corners_in_image, axis=1)
+ maxxy = np.max(box_corners_in_image, axis=1)
+ bbox = np.concatenate([minxy, maxxy], axis=1)
+ return bbox
+
+
+def change_box3d_center_(box3d, src, dst):
+ dst = np.array(dst, dtype=box3d.dtype)
+ src = np.array(src, dtype=box3d.dtype)
+ box3d[..., :3] += box3d[..., 3:6] * (dst - src)
diff --git a/det3d/core/bbox/box_torch_ops.py b/det3d/core/bbox/box_torch_ops.py
new file mode 100644
index 0000000..18a3b12
--- /dev/null
+++ b/det3d/core/bbox/box_torch_ops.py
@@ -0,0 +1,277 @@
+import math
+from functools import reduce
+
+import numpy as np
+import torch
+from torch import stack as tstack
+try:
+ from det3d.ops.iou3d_nms import iou3d_nms_cuda, iou3d_nms_utils
+except:
+ print("iou3d cuda not built. You don't need this if you use circle_nms. Otherwise, refer to the advanced installation part to build this cuda extension")
+
+def torch_to_np_dtype(ttype):
+ type_map = {
+ torch.float16: np.dtype(np.float16),
+ torch.float32: np.dtype(np.float32),
+ torch.float16: np.dtype(np.float64),
+ torch.int32: np.dtype(np.int32),
+ torch.int64: np.dtype(np.int64),
+ torch.uint8: np.dtype(np.uint8),
+ }
+ return type_map[ttype]
+
+
+def corners_nd(dims, origin=0.5):
+ """generate relative box corners based on length per dim and
+ origin point.
+
+ Args:
+ dims (float array, shape=[N, ndim]): array of length per dim
+ origin (list or array or float): origin point relate to smallest point.
+ dtype (output dtype, optional): Defaults to np.float32
+
+ Returns:
+ float array, shape=[N, 2 ** ndim, ndim]: returned corners.
+ point layout example: (2d) x0y0, x0y1, x1y0, x1y1;
+ (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
+ where x0 < x1, y0 < y1, z0 < z1
+ """
+ ndim = int(dims.shape[1])
+ dtype = torch_to_np_dtype(dims.dtype)
+ if isinstance(origin, float):
+ origin = [origin] * ndim
+ corners_norm = np.stack(
+ np.unravel_index(np.arange(2 ** ndim), [2] * ndim), axis=1
+ ).astype(dtype)
+ # now corners_norm has format: (2d) x0y0, x0y1, x1y0, x1y1
+ # (3d) x0y0z0, x0y0z1, x0y1z0, x0y1z1, x1y0z0, x1y0z1, x1y1z0, x1y1z1
+ # so need to convert to a format which is convenient to do other computing.
+ # for 2d boxes, format is clockwise start from minimum point
+ # for 3d boxes, please draw them by your hand.
+ if ndim == 2:
+ # generate clockwise box corners
+ corners_norm = corners_norm[[0, 1, 3, 2]]
+ elif ndim == 3:
+ corners_norm = corners_norm[[0, 1, 3, 2, 4, 5, 7, 6]]
+ corners_norm = corners_norm - np.array(origin, dtype=dtype)
+ corners_norm = torch.from_numpy(corners_norm).type_as(dims)
+ corners = dims.view(-1, 1, ndim) * corners_norm.view(1, 2 ** ndim, ndim)
+ return corners
+
+
+def corners_2d(dims, origin=0.5):
+ """generate relative 2d box corners based on length per dim and
+ origin point.
+
+ Args:
+ dims (float array, shape=[N, 2]): array of length per dim
+ origin (list or array or float): origin point relate to smallest point.
+ dtype (output dtype, optional): Defaults to np.float32
+
+ Returns:
+ float array, shape=[N, 4, 2]: returned corners.
+ point layout: x0y0, x0y1, x1y1, x1y0
+ """
+ return corners_nd(dims, origin)
+
+
+def corner_to_standup_nd(boxes_corner):
+ ndim = boxes_corner.shape[2]
+ standup_boxes = []
+ for i in range(ndim):
+ standup_boxes.append(torch.min(boxes_corner[:, :, i], dim=1)[0])
+ for i in range(ndim):
+ standup_boxes.append(torch.max(boxes_corner[:, :, i], dim=1)[0])
+ return torch.stack(standup_boxes, dim=1)
+
+
+def rotation_3d_in_axis(points, angles, axis=0):
+ # points: [N, point_size, 3]
+ # angles: [N]
+ rot_sin = torch.sin(angles)
+ rot_cos = torch.cos(angles)
+ ones = torch.ones_like(rot_cos)
+ zeros = torch.zeros_like(rot_cos)
+ if axis == 1:
+ rot_mat_T = tstack(
+ [
+ tstack([rot_cos, zeros, -rot_sin]),
+ tstack([zeros, ones, zeros]),
+ tstack([rot_sin, zeros, rot_cos]),
+ ]
+ )
+ elif axis == 2 or axis == -1:
+ rot_mat_T = tstack(
+ [
+ tstack([rot_cos, -rot_sin, zeros]),
+ tstack([rot_sin, rot_cos, zeros]),
+ tstack([zeros, zeros, ones]),
+ ]
+ )
+ elif axis == 0:
+ rot_mat_T = tstack(
+ [
+ tstack([zeros, rot_cos, -rot_sin]),
+ tstack([zeros, rot_sin, rot_cos]),
+ tstack([ones, zeros, zeros]),
+ ]
+ )
+ else:
+ raise ValueError("axis should in range")
+ # print(points.shape, rot_mat_T.shape)
+ return torch.einsum("aij,jka->aik", points, rot_mat_T)
+
+def rotate_points_along_z(points, angle):
+ """
+ Args:
+ points: (B, N, 3 + C)
+ angle: (B), angle along z-axis, angle increases x ==> y
+ Returns:
+ """
+ cosa = torch.cos(angle)
+ sina = torch.sin(angle)
+ zeros = angle.new_zeros(points.shape[0])
+ ones = angle.new_ones(points.shape[0])
+ rot_matrix = torch.stack((
+ cosa, -sina, zeros,
+ sina, cosa, zeros,
+ zeros, zeros, ones
+ ), dim=1).view(-1, 3, 3).float()
+ points_rot = torch.matmul(points[:, :, 0:3], rot_matrix)
+ points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1)
+ return points_rot
+
+
+def rotation_2d(points, angles):
+ """rotation 2d points based on origin point clockwise when angle positive.
+
+ Args:
+ points (float array, shape=[N, point_size, 2]): points to be rotated.
+ angles (float array, shape=[N]): rotation angle.
+
+ Returns:
+ float array: same shape as points
+ """
+ rot_sin = torch.sin(angles)
+ rot_cos = torch.cos(angles)
+ rot_mat_T = torch.stack([tstack([rot_cos, -rot_sin]), tstack([rot_sin, rot_cos])])
+ return torch.einsum("aij,jka->aik", (points, rot_mat_T))
+
+
+def center_to_corner_box3d(centers, dims, angles, origin=(0.5, 0.5, 0.5), axis=1):
+ """convert kitti locations, dimensions and angles to corners
+
+ Args:
+ centers (float array, shape=[N, 3]): locations in kitti label file.
+ dims (float array, shape=[N, 3]): dimensions in kitti label file.
+ angles (float array, shape=[N]): rotation_y in kitti label file.
+ origin (list or array or float): origin point relate to smallest point.
+ use [0.5, 1.0, 0.5] in camera and [0.5, 0.5, 0] in lidar.
+ axis (int): rotation axis. 1 for camera and 2 for lidar.
+ Returns:
+ [type]: [description]
+ """
+ # 'length' in kitti format is in x axis.
+ # yzx(hwl)(kitti label file)<->xyz(lhw)(camera)<->z(-x)(-y)(wlh)(lidar)
+ # center in kitti format is [0.5, 1.0, 0.5] in xyz.
+ corners = corners_nd(dims, origin=origin)
+ # corners: [N, 8, 3]
+ corners = rotation_3d_in_axis(corners, angles, axis=axis)
+ corners += centers.view(-1, 1, 3)
+ return corners
+
+
+def center_to_corner_box2d(centers, dims, angles=None, origin=0.5):
+ """convert kitti locations, dimensions and angles to corners
+
+ Args:
+ centers (float array, shape=[N, 2]): locations in kitti label file.
+ dims (float array, shape=[N, 2]): dimensions in kitti label file.
+ angles (float array, shape=[N]): rotation_y in kitti label file.
+
+ Returns:
+ [type]: [description]
+ """
+ # 'length' in kitti format is in x axis.
+ # xyz(hwl)(kitti label file)<->xyz(lhw)(camera)<->z(-x)(-y)(wlh)(lidar)
+ # center in kitti format is [0.5, 1.0, 0.5] in xyz.
+ corners = corners_nd(dims, origin=origin)
+ # corners: [N, 4, 2]
+ if angles is not None:
+ corners = rotation_2d(corners, angles)
+ corners += centers.view(-1, 1, 2)
+ return corners
+
+
+def project_to_image(points_3d, proj_mat):
+ points_num = list(points_3d.shape)[:-1]
+ points_shape = np.concatenate([points_num, [1]], axis=0).tolist()
+ points_4 = torch.cat(
+ [points_3d, torch.ones(*points_shape).type_as(points_3d)], dim=-1
+ )
+ # point_2d = points_4 @ tf.transpose(proj_mat, [1, 0])
+ point_2d = torch.matmul(points_4, proj_mat.t())
+ point_2d_res = point_2d[..., :2] / point_2d[..., 2:3]
+ return point_2d_res
+
+
+def camera_to_lidar(points, r_rect, velo2cam):
+ num_points = points.shape[0]
+ points = torch.cat([points, torch.ones(num_points, 1).type_as(points)], dim=-1)
+ lidar_points = points @ torch.inverse((r_rect @ velo2cam).t())
+ return lidar_points[..., :3]
+
+
+def lidar_to_camera(points, r_rect, velo2cam):
+ num_points = points.shape[0]
+ points = torch.cat([points, torch.ones(num_points, 1).type_as(points)], dim=-1)
+ camera_points = points @ (r_rect @ velo2cam).t()
+ return camera_points[..., :3]
+
+
+def box_camera_to_lidar(data, r_rect, velo2cam):
+ xyz = data[..., 0:3]
+ l, h, w = data[..., 3:4], data[..., 4:5], data[..., 5:6]
+ r = data[..., 6:7]
+ xyz_lidar = camera_to_lidar(xyz, r_rect, velo2cam)
+ return torch.cat([xyz_lidar, w, l, h, r], dim=-1)
+
+
+def box_lidar_to_camera(data, r_rect, velo2cam):
+ xyz_lidar = data[..., 0:3]
+ w, l, h = data[..., 3:4], data[..., 4:5], data[..., 5:6]
+ r = data[..., 6:7]
+ xyz = lidar_to_camera(xyz_lidar, r_rect, velo2cam)
+ return torch.cat([xyz, l, h, w, r], dim=-1)
+
+
+def rotate_nms_pcdet(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
+ """
+ :param boxes: (N, 5) [x, y, z, l, w, h, theta]
+ :param scores: (N)
+ :param thresh:
+ :return:
+ """
+ # transform back to pcdet's coordinate
+ boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
+ boxes[:, -1] = -boxes[:, -1] - np.pi /2
+
+ order = scores.sort(0, descending=True)[1]
+ if pre_maxsize is not None:
+ order = order[:pre_maxsize]
+
+ boxes = boxes[order].contiguous()
+
+ keep = torch.LongTensor(boxes.size(0))
+
+ if len(boxes) == 0:
+ num_out =0
+ else:
+ num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)
+
+ selected = order[keep[:num_out].cuda()].contiguous()
+
+ if post_max_size is not None:
+ selected = selected[:post_max_size]
+
+ return selected
\ No newline at end of file
diff --git a/det3d/core/bbox/geometry.py b/det3d/core/bbox/geometry.py
new file mode 100644
index 0000000..a62ee7b
--- /dev/null
+++ b/det3d/core/bbox/geometry.py
@@ -0,0 +1,457 @@
+import numba
+import numpy as np
+
+
+@numba.njit
+def _points_count_convex_polygon_3d_jit(
+ points, polygon_surfaces, normal_vec, d, num_surfaces=None
+):
+ """count points in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_polygon] array.
+ """
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ ret = np.full((num_polygons,), num_points, dtype=np.int64)
+ sign = 0.0
+ for i in range(num_points):
+ for j in range(num_polygons):
+ for k in range(max_num_surfaces):
+ if k > num_surfaces[j]:
+ break
+ sign = (
+ points[i, 0] * normal_vec[j, k, 0]
+ + points[i, 1] * normal_vec[j, k, 1]
+ + points[i, 2] * normal_vec[j, k, 2]
+ + d[j, k]
+ )
+ if sign >= 0:
+ ret[j] -= 1
+ break
+ return ret
+
+
+def points_count_convex_polygon_3d_jit(points, polygon_surfaces, num_surfaces=None):
+ """check points is in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_polygon] array.
+ """
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ if num_surfaces is None:
+ num_surfaces = np.full((num_polygons,), 9999999, dtype=np.int64)
+ normal_vec, d = surface_equ_3d_jitv2(polygon_surfaces[:, :, :3, :])
+ # normal_vec: [num_polygon, max_num_surfaces, 3]
+ # d: [num_polygon, max_num_surfaces]
+ return _points_count_convex_polygon_3d_jit(
+ points, polygon_surfaces, normal_vec, d, num_surfaces
+ )
+
+
+@numba.njit
+def is_line_segment_intersection_jit(lines1, lines2):
+ """check if line segments1 and line segments2 have cross point
+
+ Args:
+ lines1 (float, [N, 2, 2]): [description]
+ lines2 (float, [M, 2, 2]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+
+ # Return true if line segments AB and CD intersect
+ N = lines1.shape[0]
+ M = lines2.shape[0]
+ ret = np.zeros((N, M), dtype=np.bool_)
+ for i in range(N):
+ for j in range(M):
+ A = lines1[i, 0]
+ B = lines1[i, 1]
+ C = lines2[j, 0]
+ D = lines2[j, 1]
+ acd = (D[1] - A[1]) * (C[0] - A[0]) > (C[1] - A[1]) * (D[0] - A[0])
+ bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
+ if acd != bcd:
+ abc = (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (C[0] - A[0])
+ abd = (D[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (D[0] - A[0])
+ if abc != abd:
+ ret[i, j] = True
+ return ret
+
+
+@numba.njit
+def line_segment_intersection(line1, line2, intersection):
+ A = line1[0]
+ B = line1[1]
+ C = line2[0]
+ D = line2[1]
+ BA0 = B[0] - A[0]
+ BA1 = B[1] - A[1]
+ DA0 = D[0] - A[0]
+ CA0 = C[0] - A[0]
+ DA1 = D[1] - A[1]
+ CA1 = C[1] - A[1]
+ acd = DA1 * CA0 > CA1 * DA0
+ bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (D[0] - B[0])
+ if acd != bcd:
+ abc = CA1 * BA0 > BA1 * CA0
+ abd = DA1 * BA0 > BA1 * DA0
+ if abc != abd:
+ DC0 = D[0] - C[0]
+ DC1 = D[1] - C[1]
+ ABBA = A[0] * B[1] - B[0] * A[1]
+ CDDC = C[0] * D[1] - D[0] * C[1]
+ DH = BA1 * DC0 - BA0 * DC1
+ intersection[0] = (ABBA * DC0 - BA0 * CDDC) / DH
+ intersection[1] = (ABBA * DC1 - BA1 * CDDC) / DH
+ return True
+ return False
+
+
+def _ccw(A, B, C):
+ return (C[..., 1] - A[..., 1]) * (B[..., 0] - A[..., 0]) > (
+ B[..., 1] - A[..., 1]
+ ) * (C[..., 0] - A[..., 0])
+
+
+def is_line_segment_cross(lines1, lines2):
+ # 10x slower than jit version with 1000-1000 random lines input.
+ # lines1, [N, 2, 2]
+ # lines2, [M, 2, 2]
+ A = lines1[:, 0, :][:, np.newaxis, :]
+ B = lines1[:, 1, :][:, np.newaxis, :]
+ C = lines2[:, 0, :][np.newaxis, :, :]
+ D = lines2[:, 1, :][np.newaxis, :, :]
+ return np.logical_and(
+ _ccw(A, C, D) != _ccw(B, C, D), _ccw(A, B, C) != _ccw(A, B, D)
+ )
+
+
+@numba.jit(nopython=False)
+def surface_equ_3d_jit(polygon_surfaces):
+ # return [a, b, c], d in ax+by+cz+d=0
+ # polygon_surfaces: [num_polygon, num_surfaces, num_points_of_polygon, 3]
+ surface_v = polygon_surfaces[:, :, :2, :] - polygon_surfaces[:, :, 1:3, :]
+ # normal_vec: [..., 3]
+ normal_v = np.cross(surface_v[:, :, 0, :], surface_v[:, :, 1, :])
+ # print(normal_vec.shape, points[..., 0, :].shape)
+ # d = -np.inner(normal_vec, points[..., 0, :])
+ d = np.einsum("aij, aij->ai", normal_v, polygon_surfaces[:, :, 0, :])
+ return normal_vec, -d
+
+
+@numba.jit(nopython=False)
+def points_in_convex_polygon_3d_jit_v1(points, polygon_surfaces, num_surfaces=None):
+ """check points is in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ if num_surfaces is None:
+ num_surfaces = np.full((num_polygons,), 9999999, dtype=np.int64)
+ normal_vec, d = surface_equ_3d_jit(polygon_surfaces[:, :, :3, :])
+ # normal_vec: [num_polygon, max_num_surfaces, 3]
+ # d: [num_polygon, max_num_surfaces]
+ ret = np.ones((num_points, num_polygons), dtype=np.bool_)
+ sign = 0.0
+ for i in range(num_points):
+ for j in range(num_polygons):
+ for k in range(max_num_surfaces):
+ if k > num_surfaces[j]:
+ break
+ sign = (
+ points[i, 0] * normal_vec[j, k, 0]
+ + points[i, 1] * normal_vec[j, k, 1]
+ + points[i, 2] * normal_vec[j, k, 2]
+ + d[j, k]
+ )
+ if sign >= 0:
+ ret[i, j] = False
+ break
+ return ret
+
+
+def surface_equ_3d(polygon_surfaces):
+ # return [a, b, c], d in ax+by+cz+d=0
+ # polygon_surfaces: [num_polygon, num_surfaces, num_points_of_polygon, 3]
+ surface_v = polygon_surfaces[:, :, :2, :] - polygon_surfaces[:, :, 1:3, :]
+ # normal_vec: [..., 3]
+ normal_v = np.cross(surface_v[:, :, 0, :], surface_v[:, :, 1, :])
+ # print(normal_vec.shape, points[..., 0, :].shape)
+ # d = -np.inner(normal_vec, points[..., 0, :])
+ d = np.einsum("aij, aij->ai", normal_v, polygon_surfaces[:, :, 0, :])
+ return normal_v, -d
+
+
+def points_in_convex_polygon_3d_jit(points, polygon_surfaces, num_surfaces=None):
+ """check points is in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ if num_surfaces is None:
+ num_surfaces = np.full((num_polygons,), 9999999, dtype=np.int64)
+ normal_vec, d = surface_equ_3d_jitv2(polygon_surfaces[:, :, :3, :])
+ # normal_vec: [num_polygon, max_num_surfaces, 3]
+ # d: [num_polygon, max_num_surfaces]
+ return _points_in_convex_polygon_3d_jit(
+ points, polygon_surfaces, normal_vec, d, num_surfaces
+ )
+
+
+@numba.njit
+def _points_in_convex_polygon_3d_jit(
+ points, polygon_surfaces, normal_vec, d, num_surfaces=None
+):
+ """check points is in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ ret = np.ones((num_points, num_polygons), dtype=np.bool_)
+ sign = 0.0
+ for i in range(num_points):
+ for j in range(num_polygons):
+ for k in range(max_num_surfaces):
+ if k > num_surfaces[j]:
+ break
+ sign = (
+ points[i, 0] * normal_vec[j, k, 0]
+ + points[i, 1] * normal_vec[j, k, 1]
+ + points[i, 2] * normal_vec[j, k, 2]
+ + d[j, k]
+ )
+ if sign >= 0:
+ ret[i, j] = False
+ break
+ return ret
+
+
+@numba.jit
+def points_in_convex_polygon_jit(points, polygon, clockwise=True):
+ """check points is in 2d convex polygons. True when point in polygon
+ Args:
+ points: [num_points, 2] array.
+ polygon: [num_polygon, num_points_of_polygon, 2] array.
+ clockwise: bool. indicate polygon is clockwise.
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ # first convert polygon to directed lines
+ num_points_of_polygon = polygon.shape[1]
+ num_points = points.shape[0]
+ num_polygons = polygon.shape[0]
+ if clockwise:
+ vec1 = (
+ polygon
+ - polygon[
+ :,
+ [num_points_of_polygon - 1] + list(range(num_points_of_polygon - 1)),
+ :,
+ ]
+ )
+ else:
+ vec1 = (
+ polygon[
+ :,
+ [num_points_of_polygon - 1] + list(range(num_points_of_polygon - 1)),
+ :,
+ ]
+ - polygon
+ )
+ # vec1: [num_polygon, num_points_of_polygon, 2]
+ ret = np.zeros((num_points, num_polygons), dtype=np.bool_)
+ success = True
+ cross = 0.0
+ for i in range(num_points):
+ for j in range(num_polygons):
+ success = True
+ for k in range(num_points_of_polygon):
+ cross = vec1[j, k, 1] * (polygon[j, k, 0] - points[i, 0])
+ cross -= vec1[j, k, 0] * (polygon[j, k, 1] - points[i, 1])
+ if cross >= 0:
+ success = False
+ break
+ ret[i, j] = success
+ return ret
+
+
+def points_in_convex_polygon(points, polygon, clockwise=True):
+ """check points is in convex polygons. may run 2x faster when write in
+ cython(don't need to calculate all cross-product between edge and point)
+ Args:
+ points: [num_points, 2] array.
+ polygon: [num_polygon, num_points_of_polygon, 2] array.
+ clockwise: bool. indicate polygon is clockwise.
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ # first convert polygon to directed lines
+ num_lines = polygon.shape[1]
+ polygon_next = polygon[:, [num_lines - 1] + list(range(num_lines - 1)), :]
+ if clockwise:
+ vec1 = (polygon - polygon_next)[np.newaxis, ...]
+ else:
+ vec1 = (polygon_next - polygon)[np.newaxis, ...]
+ vec2 = polygon[np.newaxis, ...] - points[:, np.newaxis, np.newaxis, :]
+ # [num_points, num_polygon, num_points_of_polygon, 2]
+ cross = np.cross(vec1, vec2)
+ return np.all(cross > 0, axis=2)
+
+
+@numba.njit
+def surface_equ_3d_jitv2(surfaces):
+ # polygon_surfaces: [num_polygon, num_surfaces, num_points_of_polygon, 3]
+ num_polygon = surfaces.shape[0]
+ max_num_surfaces = surfaces.shape[1]
+ normal_vec = np.zeros((num_polygon, max_num_surfaces, 3), dtype=surfaces.dtype)
+ d = np.zeros((num_polygon, max_num_surfaces), dtype=surfaces.dtype)
+ sv0 = surfaces[0, 0, 0] - surfaces[0, 0, 1]
+ sv1 = surfaces[0, 0, 0] - surfaces[0, 0, 1]
+ for i in range(num_polygon):
+ for j in range(max_num_surfaces):
+ sv0[0] = surfaces[i, j, 0, 0] - surfaces[i, j, 1, 0]
+ sv0[1] = surfaces[i, j, 0, 1] - surfaces[i, j, 1, 1]
+ sv0[2] = surfaces[i, j, 0, 2] - surfaces[i, j, 1, 2]
+ sv1[0] = surfaces[i, j, 1, 0] - surfaces[i, j, 2, 0]
+ sv1[1] = surfaces[i, j, 1, 1] - surfaces[i, j, 2, 1]
+ sv1[2] = surfaces[i, j, 1, 2] - surfaces[i, j, 2, 2]
+ normal_vec[i, j, 0] = sv0[1] * sv1[2] - sv0[2] * sv1[1]
+ normal_vec[i, j, 1] = sv0[2] * sv1[0] - sv0[0] * sv1[2]
+ normal_vec[i, j, 2] = sv0[0] * sv1[1] - sv0[1] * sv1[0]
+
+ d[i, j] = (
+ -surfaces[i, j, 0, 0] * normal_vec[i, j, 0]
+ - surfaces[i, j, 0, 1] * normal_vec[i, j, 1]
+ - surfaces[i, j, 0, 2] * normal_vec[i, j, 2]
+ )
+ return normal_vec, d
+
+
+@numba.njit
+def _points_in_convex_polygon_3d_jit_v2(points, surfaces):
+ max_num_surfaces, max_num_points_of_surface = polygon_surfaces.shape[1:3]
+ num_points = points.shape[0]
+ num_polygons = polygon_surfaces.shape[0]
+ ret = np.ones((num_points, num_polygons), dtype=np.bool_)
+ sign = 0.0
+ for i in range(num_points):
+ for j in range(num_polygons):
+ for k in range(max_num_surfaces):
+ if k > num_surfaces[j]:
+ break
+ sign = (
+ points[i, 0] * normal_vec[j, k, 0]
+ + points[i, 1] * normal_vec[j, k, 1]
+ + points[i, 2] * normal_vec[j, k, 2]
+ + d[j, k]
+ )
+ if sign >= 0:
+ ret[i, j] = False
+ break
+ return ret
+
+
+@numba.njit
+def points_in_convex_polygon_3d_jit_v2(points, surfaces, num_surfaces=None):
+ """check points is in 3d convex polygons.
+ Args:
+ points: [num_points, 3] array.
+ polygon_surfaces: [num_polygon, max_num_surfaces,
+ max_num_points_of_surface, 3]
+ array. all surfaces' normal vector must direct to internal.
+ max_num_points_of_surface must at least 3.
+ num_surfaces: [num_polygon] array. indicate how many surfaces
+ a polygon contain
+ Returns:
+ [num_points, num_polygon] bool array.
+ """
+ num_polygon = surfaces.shape[0]
+ max_num_surfaces = surfaces.shape[1]
+ num_points = points.shape[0]
+ normal_vec = np.zeros((num_polygon, max_num_surfaces, 3), dtype=surfaces.dtype)
+ d = np.zeros((num_polygon, max_num_surfaces), dtype=surfaces.dtype)
+ sv0 = surfaces[0, 0, 0] - surfaces[0, 0, 1]
+ sv1 = surfaces[0, 0, 0] - surfaces[0, 0, 1]
+ ret = np.ones((num_points, num_polygon), dtype=np.bool_)
+ for i in range(num_polygon):
+ for j in range(max_num_surfaces):
+ sv0[0] = surfaces[i, j, 0, 0] - surfaces[i, j, 1, 0]
+ sv0[1] = surfaces[i, j, 0, 1] - surfaces[i, j, 1, 1]
+ sv0[2] = surfaces[i, j, 0, 2] - surfaces[i, j, 1, 2]
+ sv1[0] = surfaces[i, j, 1, 0] - surfaces[i, j, 2, 0]
+ sv1[1] = surfaces[i, j, 1, 1] - surfaces[i, j, 2, 1]
+ sv1[2] = surfaces[i, j, 1, 2] - surfaces[i, j, 2, 2]
+ normal_vec[i, j, 0] = sv0[1] * sv1[2] - sv0[2] * sv1[1]
+ normal_vec[i, j, 1] = sv0[2] * sv1[0] - sv0[0] * sv1[2]
+ normal_vec[i, j, 2] = sv0[0] * sv1[1] - sv0[1] * sv1[0]
+
+ d[i, j] = (
+ -surfaces[i, j, 0, 0] * normal_vec[i, j, 0]
+ - surfaces[i, j, 0, 1] * normal_vec[i, j, 1]
+ - surfaces[i, j, 0, 2] * normal_vec[i, j, 2]
+ )
+
+ sign = 0.0
+ for i in range(num_points):
+ for j in range(num_polygon):
+ for k in range(max_num_surfaces):
+ sign = (
+ points[i, 0] * normal_vec[j, k, 0]
+ + points[i, 1] * normal_vec[j, k, 1]
+ + points[i, 2] * normal_vec[j, k, 2]
+ + d[j, k]
+ )
+ if sign >= 0:
+ ret[i, j] = False
+ break
+ return ret
diff --git a/det3d/core/input/__init__.py b/det3d/core/input/__init__.py
new file mode 100644
index 0000000..4ae403f
--- /dev/null
+++ b/det3d/core/input/__init__.py
@@ -0,0 +1 @@
+from . import voxel_generator
diff --git a/det3d/core/input/voxel_generator.py b/det3d/core/input/voxel_generator.py
new file mode 100644
index 0000000..4164469
--- /dev/null
+++ b/det3d/core/input/voxel_generator.py
@@ -0,0 +1,46 @@
+import numpy as np
+from det3d.ops.point_cloud.point_cloud_ops import points_to_voxel
+
+
+class VoxelGenerator:
+ def __init__(self, voxel_size, point_cloud_range, max_num_points, max_voxels=20000):
+ point_cloud_range = np.array(point_cloud_range, dtype=np.float32)
+ # [0, -40, -3, 70.4, 40, 1]
+ voxel_size = np.array(voxel_size, dtype=np.float32)
+ grid_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size
+ grid_size = np.round(grid_size).astype(np.int64)
+
+ self._voxel_size = voxel_size
+ self._point_cloud_range = point_cloud_range
+ self._max_num_points = max_num_points
+ self._max_voxels = max_voxels
+ self._grid_size = grid_size
+
+ def generate(self, points, max_voxels=-1):
+ if max_voxels == -1:
+ max_voxels=self._max_voxels
+
+ return points_to_voxel(
+ points,
+ self._voxel_size,
+ self._point_cloud_range,
+ self._max_num_points,
+ True,
+ max_voxels,
+ )
+
+ @property
+ def voxel_size(self):
+ return self._voxel_size
+
+ @property
+ def max_num_points_per_voxel(self):
+ return self._max_num_points
+
+ @property
+ def point_cloud_range(self):
+ return self._point_cloud_range
+
+ @property
+ def grid_size(self):
+ return self._grid_size
diff --git a/det3d/core/sampler/__init__.py b/det3d/core/sampler/__init__.py
new file mode 100644
index 0000000..7c3af1e
--- /dev/null
+++ b/det3d/core/sampler/__init__.py
@@ -0,0 +1,2 @@
+from . import preprocess
+from . import sample_ops
diff --git a/det3d/core/sampler/preprocess.py b/det3d/core/sampler/preprocess.py
new file mode 100644
index 0000000..f21cb09
--- /dev/null
+++ b/det3d/core/sampler/preprocess.py
@@ -0,0 +1,999 @@
+import abc
+import sys
+import time
+from collections import OrderedDict
+from functools import reduce
+
+import numba
+import numpy as np
+
+from det3d.core.bbox import box_np_ops
+from det3d.core.bbox.geometry import (
+ is_line_segment_intersection_jit,
+ points_in_convex_polygon_3d_jit,
+ points_in_convex_polygon_jit,
+)
+import copy
+
+
+class BatchSampler:
+ def __init__(
+ self, sampled_list, name=None, epoch=None, shuffle=True, drop_reminder=False
+ ):
+ self._sampled_list = sampled_list
+ self._indices = np.arange(len(sampled_list))
+ if shuffle:
+ np.random.shuffle(self._indices)
+ self._idx = 0
+ self._example_num = len(sampled_list)
+ self._name = name
+ self._shuffle = shuffle
+ self._epoch = epoch
+ self._epoch_counter = 0
+ self._drop_reminder = drop_reminder
+
+ def _sample(self, num):
+ if self._idx + num >= self._example_num:
+ ret = self._indices[self._idx :].copy()
+ self._reset()
+ else:
+ ret = self._indices[self._idx : self._idx + num]
+ self._idx += num
+ return ret
+
+ def _reset(self):
+ # if self._name is not None:
+ # print("reset", self._name)
+ if self._shuffle:
+ np.random.shuffle(self._indices)
+ self._idx = 0
+
+ def sample(self, num):
+ indices = self._sample(num)
+ return [self._sampled_list[i] for i in indices]
+ # return np.random.choice(self._sampled_list, num)
+
+
+class DataBasePreprocessing:
+ def __call__(self, db_infos):
+ return self._preprocess(db_infos)
+
+ @abc.abstractclassmethod
+ def _preprocess(self, db_infos):
+ pass
+
+
+class DBFilterByDifficulty(DataBasePreprocessing):
+ def __init__(self, removed_difficulties, logger=None):
+ self._removed_difficulties = removed_difficulties
+ logger.info(f"{removed_difficulties}")
+
+ def _preprocess(self, db_infos):
+ new_db_infos = {}
+ for key, dinfos in db_infos.items():
+ new_db_infos[key] = [
+ info
+ for info in dinfos
+ if info["difficulty"] not in self._removed_difficulties
+ ]
+ return new_db_infos
+
+
+class DBFilterByMinNumPoint(DataBasePreprocessing):
+ def __init__(self, min_gt_point_dict, logger=None):
+ self._min_gt_point_dict = min_gt_point_dict
+ logger.info(f"{min_gt_point_dict}")
+
+ def _preprocess(self, db_infos):
+ for name, min_num in self._min_gt_point_dict.items():
+ if min_num > 0:
+ filtered_infos = []
+ for info in db_infos[name]:
+ if info["num_points_in_gt"] >= min_num:
+ filtered_infos.append(info)
+ db_infos[name] = filtered_infos
+ return db_infos
+
+
+class DataBasePreprocessor:
+ def __init__(self, preprocessors):
+ self._preprocessors = preprocessors
+
+ def __call__(self, db_infos):
+ for prepor in self._preprocessors:
+ db_infos = prepor(db_infos)
+ return db_infos
+
+
+def filter_gt_box_outside_range(gt_boxes, limit_range):
+ """remove gtbox outside training range.
+ this function should be applied after other prep functions
+ Args:
+ gt_boxes ([type]): [description]
+ limit_range ([type]): [description]
+ """
+ gt_boxes_bv = box_np_ops.center_to_corner_box2d(
+ gt_boxes[:, [0, 1]], gt_boxes[:, [3, 3 + 1]], gt_boxes[:, -1]
+ )
+ bounding_box = box_np_ops.minmax_to_corner_2d(
+ np.asarray(limit_range)[np.newaxis, ...]
+ )
+ ret = points_in_convex_polygon_jit(gt_boxes_bv.reshape(-1, 2), bounding_box)
+ return np.any(ret.reshape(-1, 4), axis=1)
+
+
+def filter_gt_box_outside_range_by_center(gt_boxes, limit_range):
+ """remove gtbox outside training range.
+ this function should be applied after other prep functions
+ Args:
+ gt_boxes ([type]): [description]
+ limit_range ([type]): [description]
+ """
+ gt_box_centers = gt_boxes[:, :2]
+ bounding_box = box_np_ops.minmax_to_corner_2d(
+ np.asarray(limit_range)[np.newaxis, ...]
+ )
+ ret = points_in_convex_polygon_jit(gt_box_centers, bounding_box)
+ return ret.reshape(-1)
+
+
+def filter_gt_low_points(gt_boxes, points, num_gt_points, point_num_threshold=2):
+ points_mask = np.ones([points.shape[0]], np.bool)
+ gt_boxes_mask = np.ones([gt_boxes.shape[0]], np.bool)
+ for i, num in enumerate(num_gt_points):
+ if num <= point_num_threshold:
+ masks = box_np_ops.points_in_rbbox(points, gt_boxes[i : i + 1])
+ masks = masks.reshape([-1])
+ points_mask &= np.logical_not(masks)
+ gt_boxes_mask[i] = False
+ return gt_boxes[gt_boxes_mask], points[points_mask]
+
+
+def mask_points_in_corners(points, box_corners):
+ surfaces = box_np_ops.corner_to_surfaces_3d(box_corners)
+ mask = points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
+ return mask
+
+
+@numba.njit
+def _rotation_matrix_3d_(rot_mat_T, angle, axis):
+ rot_sin = np.sin(angle)
+ rot_cos = np.cos(angle)
+ rot_mat_T[:] = np.eye(3)
+ if axis == 1:
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 2] = -rot_sin
+ rot_mat_T[2, 0] = rot_sin
+ rot_mat_T[2, 2] = rot_cos
+ elif axis == 2 or axis == -1:
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ elif axis == 0:
+ rot_mat_T[1, 1] = rot_cos
+ rot_mat_T[1, 2] = -rot_sin
+ rot_mat_T[2, 1] = rot_sin
+ rot_mat_T[2, 2] = rot_cos
+
+
+@numba.njit
+def _rotation_box2d_jit_(corners, angle, rot_mat_T):
+ rot_sin = np.sin(angle)
+ rot_cos = np.cos(angle)
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ corners[:] = corners @ rot_mat_T
+
+
+@numba.jit(nopython=True)
+def _box_single_to_corner_jit(boxes):
+ num_box = boxes.shape[0]
+ corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
+ corners_norm[1, 1] = 1.0
+ corners_norm[2] = 1.0
+ corners_norm[3, 0] = 1.0
+ corners_norm -= np.array([0.5, 0.5], dtype=boxes.dtype)
+ corners = boxes.reshape(num_box, 1, 5)[:, :, 2:4] * corners_norm.reshape(1, 4, 2)
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ box_corners = np.zeros((num_box, 4, 2), dtype=boxes.dtype)
+ for i in range(num_box):
+ rot_sin = np.sin(boxes[i, -1])
+ rot_cos = np.cos(boxes[i, -1])
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ box_corners[i] = corners[i] @ rot_mat_T + boxes[i, :2]
+ return box_corners
+
+
+@numba.njit
+def noise_per_box(boxes, valid_mask, loc_noises, rot_noises):
+ # boxes: [N, 5]
+ # valid_mask: [N]
+ # loc_noises: [N, M, 3]
+ # rot_noises: [N, M]
+ num_boxes = boxes.shape[0]
+ num_tests = loc_noises.shape[1]
+ box_corners = box_np_ops.box2d_to_corner_jit(boxes)
+ current_corners = np.zeros((4, 2), dtype=boxes.dtype)
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ success_mask = -np.ones((num_boxes,), dtype=np.int64)
+ # print(valid_mask)
+ for i in range(num_boxes):
+ if valid_mask[i]:
+ for j in range(num_tests):
+ current_corners[:] = box_corners[i]
+ current_corners -= boxes[i, :2]
+ _rotation_box2d_jit_(current_corners, rot_noises[i, j], rot_mat_T)
+ current_corners += boxes[i, :2] + loc_noises[i, j, :2]
+ coll_mat = box_collision_test(
+ current_corners.reshape(1, 4, 2), box_corners
+ )
+ coll_mat[0, i] = False
+ # print(coll_mat)
+ if not coll_mat.any():
+ success_mask[i] = j
+ box_corners[i] = current_corners
+ break
+ return success_mask
+
+
+@numba.njit
+def noise_per_box_group(boxes, valid_mask, loc_noises, rot_noises, group_nums):
+ # WARNING: this function need boxes to be sorted by group id.
+ # boxes: [N, 5]
+ # valid_mask: [N]
+ # loc_noises: [N, M, 3]
+ # rot_noises: [N, M]
+ num_groups = group_nums.shape[0]
+ num_boxes = boxes.shape[0]
+ num_tests = loc_noises.shape[1]
+ box_corners = box_np_ops.box2d_to_corner_jit(boxes)
+ max_group_num = group_nums.max()
+ current_corners = np.zeros((max_group_num, 4, 2), dtype=boxes.dtype)
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ success_mask = -np.ones((num_boxes,), dtype=np.int64)
+ # print(valid_mask)
+ idx = 0
+ for num in group_nums:
+ if valid_mask[idx]:
+ for j in range(num_tests):
+ for i in range(num):
+ current_corners[i] = box_corners[i + idx]
+ current_corners[i] -= boxes[i + idx, :2]
+ _rotation_box2d_jit_(
+ current_corners[i], rot_noises[idx + i, j], rot_mat_T
+ )
+ current_corners[i] += (
+ boxes[i + idx, :2] + loc_noises[i + idx, j, :2]
+ )
+ coll_mat = box_collision_test(
+ current_corners[:num].reshape(num, 4, 2), box_corners
+ )
+ for i in range(num): # remove self-coll
+ coll_mat[i, idx : idx + num] = False
+ if not coll_mat.any():
+ for i in range(num):
+ success_mask[i + idx] = j
+ box_corners[i + idx] = current_corners[i]
+ break
+ idx += num
+ return success_mask
+
+
+@numba.njit
+def noise_per_box_group_v2_(
+ boxes, valid_mask, loc_noises, rot_noises, group_nums, global_rot_noises
+):
+ # WARNING: this function need boxes to be sorted by group id.
+ # boxes: [N, 5]
+ # valid_mask: [N]
+ # loc_noises: [N, M, 3]
+ # rot_noises: [N, M]
+ num_boxes = boxes.shape[0]
+ num_tests = loc_noises.shape[1]
+ box_corners = box_np_ops.box2d_to_corner_jit(boxes)
+ max_group_num = group_nums.max()
+ current_box = np.zeros((1, 5), dtype=boxes.dtype)
+ current_corners = np.zeros((max_group_num, 4, 2), dtype=boxes.dtype)
+ dst_pos = np.zeros((max_group_num, 2), dtype=boxes.dtype)
+
+ current_grot = np.zeros((max_group_num,), dtype=boxes.dtype)
+ dst_grot = np.zeros((max_group_num,), dtype=boxes.dtype)
+
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ success_mask = -np.ones((num_boxes,), dtype=np.int64)
+ corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
+ corners_norm[1, 1] = 1.0
+ corners_norm[2] = 1.0
+ corners_norm[3, 0] = 1.0
+ corners_norm -= np.array([0.5, 0.5], dtype=boxes.dtype)
+ corners_norm = corners_norm.reshape(4, 2)
+
+ # print(valid_mask)
+ idx = 0
+ for num in group_nums:
+ if valid_mask[idx]:
+ for j in range(num_tests):
+ for i in range(num):
+ current_box[0, :] = boxes[i + idx]
+ current_radius = np.sqrt(
+ current_box[0, 0] ** 2 + current_box[0, 1] ** 2
+ )
+ current_grot[i] = np.arctan2(current_box[0, 0], current_box[0, 1])
+ dst_grot[i] = current_grot[i] + global_rot_noises[idx + i, j]
+ dst_pos[i, 0] = current_radius * np.sin(dst_grot[i])
+ dst_pos[i, 1] = current_radius * np.cos(dst_grot[i])
+ current_box[0, :2] = dst_pos[i]
+ current_box[0, -1] += dst_grot[i] - current_grot[i]
+
+ rot_sin = np.sin(current_box[0, -1])
+ rot_cos = np.cos(current_box[0, -1])
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ current_corners[i] = (
+ current_box[0, 2:4] * corners_norm @ rot_mat_T
+ + current_box[0, :2]
+ )
+ current_corners[i] -= current_box[0, :2]
+
+ _rotation_box2d_jit_(
+ current_corners[i], rot_noises[idx + i, j], rot_mat_T
+ )
+ current_corners[i] += (
+ current_box[0, :2] + loc_noises[i + idx, j, :2]
+ )
+ coll_mat = box_collision_test(
+ current_corners[:num].reshape(num, 4, 2), box_corners
+ )
+ for i in range(num): # remove self-coll
+ coll_mat[i, idx : idx + num] = False
+ if not coll_mat.any():
+ for i in range(num):
+ success_mask[i + idx] = j
+ box_corners[i + idx] = current_corners[i]
+ loc_noises[i + idx, j, :2] += dst_pos[i] - boxes[i + idx, :2]
+ rot_noises[i + idx, j] += dst_grot[i] - current_grot[i]
+ break
+ idx += num
+ return success_mask
+
+
+@numba.njit
+def noise_per_box_v2_(boxes, valid_mask, loc_noises, rot_noises, global_rot_noises):
+ # boxes: [N, 5]
+ # valid_mask: [N]
+ # loc_noises: [N, M, 3]
+ # rot_noises: [N, M]
+ num_boxes = boxes.shape[0]
+ num_tests = loc_noises.shape[1]
+ box_corners = box_np_ops.box2d_to_corner_jit(boxes)
+ current_corners = np.zeros((4, 2), dtype=boxes.dtype)
+ current_box = np.zeros((1, 5), dtype=boxes.dtype)
+ rot_mat_T = np.zeros((2, 2), dtype=boxes.dtype)
+ dst_pos = np.zeros((2,), dtype=boxes.dtype)
+ success_mask = -np.ones((num_boxes,), dtype=np.int64)
+ corners_norm = np.zeros((4, 2), dtype=boxes.dtype)
+ corners_norm[1, 1] = 1.0
+ corners_norm[2] = 1.0
+ corners_norm[3, 0] = 1.0
+ corners_norm -= np.array([0.5, 0.5], dtype=boxes.dtype)
+ corners_norm = corners_norm.reshape(4, 2)
+ for i in range(num_boxes):
+ if valid_mask[i]:
+ for j in range(num_tests):
+ current_box[0, :] = boxes[i]
+ current_radius = np.sqrt(boxes[i, 0] ** 2 + boxes[i, 1] ** 2)
+ current_grot = np.arctan2(boxes[i, 0], boxes[i, 1])
+ dst_grot = current_grot + global_rot_noises[i, j]
+ dst_pos[0] = current_radius * np.sin(dst_grot)
+ dst_pos[1] = current_radius * np.cos(dst_grot)
+ current_box[0, :2] = dst_pos
+ current_box[0, -1] += dst_grot - current_grot
+
+ rot_sin = np.sin(current_box[0, -1])
+ rot_cos = np.cos(current_box[0, -1])
+ rot_mat_T[0, 0] = rot_cos
+ rot_mat_T[0, 1] = -rot_sin
+ rot_mat_T[1, 0] = rot_sin
+ rot_mat_T[1, 1] = rot_cos
+ current_corners[:] = (
+ current_box[0, 2:4] * corners_norm @ rot_mat_T + current_box[0, :2]
+ )
+ current_corners -= current_box[0, :2]
+ _rotation_box2d_jit_(current_corners, rot_noises[i, j], rot_mat_T)
+ current_corners += current_box[0, :2] + loc_noises[i, j, :2]
+ coll_mat = box_collision_test(
+ current_corners.reshape(1, 4, 2), box_corners
+ )
+ coll_mat[0, i] = False
+ if not coll_mat.any():
+ success_mask[i] = j
+ box_corners[i] = current_corners
+ loc_noises[i, j, :2] += dst_pos - boxes[i, :2]
+ rot_noises[i, j] += dst_grot - current_grot
+ break
+ return success_mask
+
+
+@numba.njit
+def points_transform_(
+ points, centers, point_masks, loc_transform, rot_transform, valid_mask
+):
+ num_box = centers.shape[0]
+ num_points = points.shape[0]
+ rot_mat_T = np.zeros((num_box, 3, 3), dtype=points.dtype)
+ for i in range(num_box):
+ _rotation_matrix_3d_(rot_mat_T[i], rot_transform[i], 2)
+ for i in range(num_points):
+ for j in range(num_box):
+ if valid_mask[j]:
+ if point_masks[i, j] == 1:
+ points[i, :3] -= centers[j, :3]
+ points[i : i + 1, :3] = points[i : i + 1, :3] @ rot_mat_T[j]
+ points[i, :3] += centers[j, :3]
+ points[i, :3] += loc_transform[j]
+ break # only apply first box's transform
+
+
+@numba.njit
+def box3d_transform_(boxes, loc_transform, rot_transform, valid_mask):
+ num_box = boxes.shape[0]
+ for i in range(num_box):
+ if valid_mask[i]:
+ boxes[i, :3] += loc_transform[i]
+ boxes[i, 6] += rot_transform[i]
+
+
+def _select_transform(transform, indices):
+ result = np.zeros((transform.shape[0], *transform.shape[2:]), dtype=transform.dtype)
+ for i in range(transform.shape[0]):
+ if indices[i] != -1:
+ result[i] = transform[i, indices[i]]
+ return result
+
+
+@numba.njit
+def group_transform_(loc_noise, rot_noise, locs, rots, group_center, valid_mask):
+ # loc_noise: [N, M, 3], locs: [N, 3]
+ # rot_noise: [N, M]
+ # group_center: [N, 3]
+ num_try = loc_noise.shape[1]
+ r = 0.0
+ x = 0.0
+ y = 0.0
+ rot_center = 0.0
+ for i in range(loc_noise.shape[0]):
+ if valid_mask[i]:
+ x = locs[i, 0] - group_center[i, 0]
+ y = locs[i, 1] - group_center[i, 1]
+ r = np.sqrt(x ** 2 + y ** 2)
+ # calculate rots related to group center
+ rot_center = np.arctan2(x, y)
+ for j in range(num_try):
+ loc_noise[i, j, 0] += r * (
+ np.sin(rot_center + rot_noise[i, j]) - np.sin(rot_center)
+ )
+ loc_noise[i, j, 1] += r * (
+ np.cos(rot_center + rot_noise[i, j]) - np.cos(rot_center)
+ )
+
+
+@numba.njit
+def group_transform_v2_(
+ loc_noise, rot_noise, locs, rots, group_center, grot_noise, valid_mask
+):
+ # loc_noise: [N, M, 3], locs: [N, 3]
+ # rot_noise: [N, M]
+ # group_center: [N, 3]
+ num_try = loc_noise.shape[1]
+ r = 0.0
+ x = 0.0
+ y = 0.0
+ rot_center = 0.0
+ for i in range(loc_noise.shape[0]):
+ if valid_mask[i]:
+ x = locs[i, 0] - group_center[i, 0]
+ y = locs[i, 1] - group_center[i, 1]
+ r = np.sqrt(x ** 2 + y ** 2)
+ # calculate rots related to group center
+ rot_center = np.arctan2(x, y)
+ for j in range(num_try):
+ loc_noise[i, j, 0] += r * (
+ np.sin(rot_center + rot_noise[i, j] + grot_noise[i, j])
+ - np.sin(rot_center + grot_noise[i, j])
+ )
+ loc_noise[i, j, 1] += r * (
+ np.cos(rot_center + rot_noise[i, j] + grot_noise[i, j])
+ - np.cos(rot_center + grot_noise[i, j])
+ )
+
+
+def set_group_noise_same_(loc_noise, rot_noise, group_ids):
+ gid_to_index_dict = {}
+ for i, gid in enumerate(group_ids):
+ if gid not in gid_to_index_dict:
+ gid_to_index_dict[gid] = i
+ for i in range(loc_noise.shape[0]):
+ loc_noise[i] = loc_noise[gid_to_index_dict[group_ids[i]]]
+ rot_noise[i] = rot_noise[gid_to_index_dict[group_ids[i]]]
+
+
+def set_group_noise_same_v2_(loc_noise, rot_noise, grot_noise, group_ids):
+ gid_to_index_dict = {}
+ for i, gid in enumerate(group_ids):
+ if gid not in gid_to_index_dict:
+ gid_to_index_dict[gid] = i
+ for i in range(loc_noise.shape[0]):
+ loc_noise[i] = loc_noise[gid_to_index_dict[group_ids[i]]]
+ rot_noise[i] = rot_noise[gid_to_index_dict[group_ids[i]]]
+ grot_noise[i] = grot_noise[gid_to_index_dict[group_ids[i]]]
+
+
+def get_group_center(locs, group_ids):
+ num_groups = 0
+ group_centers = np.zeros_like(locs)
+ group_centers_ret = np.zeros_like(locs)
+ group_id_dict = {}
+ group_id_num_dict = OrderedDict()
+ for i, gid in enumerate(group_ids):
+ if gid >= 0:
+ if gid in group_id_dict:
+ group_centers[group_id_dict[gid]] += locs[i]
+ group_id_num_dict[gid] += 1
+ else:
+ group_id_dict[gid] = num_groups
+ num_groups += 1
+ group_id_num_dict[gid] = 1
+ group_centers[group_id_dict[gid]] = locs[i]
+ for i, gid in enumerate(group_ids):
+ group_centers_ret[i] = (
+ group_centers[group_id_dict[gid]] / group_id_num_dict[gid]
+ )
+ return group_centers_ret, group_id_num_dict
+
+
+def noise_per_object_v3_(
+ gt_boxes,
+ points=None,
+ valid_mask=None,
+ rotation_perturb=np.pi / 4,
+ center_noise_std=1.0,
+ global_random_rot_range=np.pi / 4,
+ num_try=5,
+ group_ids=None,
+):
+ """random rotate or remove each groundtrutn independently.
+ use kitti viewer to test this function points_transform_
+
+ Args:
+ gt_boxes: [N, 7], gt box in lidar.points_transform_
+ points: [M, 4], point cloud in lidar.
+ """
+ num_boxes = gt_boxes.shape[0]
+ if not isinstance(rotation_perturb, (list, tuple, np.ndarray)):
+ rotation_perturb = [-rotation_perturb, rotation_perturb]
+ if not isinstance(global_random_rot_range, (list, tuple, np.ndarray)):
+ global_random_rot_range = [-global_random_rot_range, global_random_rot_range]
+ enable_grot = (
+ np.abs(global_random_rot_range[0] - global_random_rot_range[1]) >= 1e-3
+ )
+ if not isinstance(center_noise_std, (list, tuple, np.ndarray)):
+ center_noise_std = [center_noise_std, center_noise_std, center_noise_std]
+ if valid_mask is None:
+ valid_mask = np.ones((num_boxes,), dtype=np.bool_)
+ center_noise_std = np.array(center_noise_std, dtype=gt_boxes.dtype)
+ loc_noises = np.random.normal(scale=center_noise_std, size=[num_boxes, num_try, 3])
+ # loc_noises = np.random.uniform(
+ # -center_noise_std, center_noise_std, size=[num_boxes, num_try, 3])
+ rot_noises = np.random.uniform(
+ rotation_perturb[0], rotation_perturb[1], size=[num_boxes, num_try]
+ )
+ gt_grots = np.arctan2(gt_boxes[:, 0], gt_boxes[:, 1])
+ grot_lowers = global_random_rot_range[0] - gt_grots
+ grot_uppers = global_random_rot_range[1] - gt_grots
+ global_rot_noises = np.random.uniform(
+ grot_lowers[..., np.newaxis],
+ grot_uppers[..., np.newaxis],
+ size=[num_boxes, num_try],
+ )
+ if group_ids is not None:
+ if enable_grot:
+ set_group_noise_same_v2_(
+ loc_noises, rot_noises, global_rot_noises, group_ids
+ )
+ else:
+ set_group_noise_same_(loc_noises, rot_noises, group_ids)
+ group_centers, group_id_num_dict = get_group_center(gt_boxes[:, :3], group_ids)
+ if enable_grot:
+ group_transform_v2_(
+ loc_noises,
+ rot_noises,
+ gt_boxes[:, :3],
+ gt_boxes[:, 6],
+ group_centers,
+ global_rot_noises,
+ valid_mask,
+ )
+ else:
+ group_transform_(
+ loc_noises,
+ rot_noises,
+ gt_boxes[:, :3],
+ gt_boxes[:, 6],
+ group_centers,
+ valid_mask,
+ )
+ group_nums = np.array(list(group_id_num_dict.values()), dtype=np.int64)
+
+ origin = [0.5, 0.5, 0.5]
+ gt_box_corners = box_np_ops.center_to_corner_box3d(
+ gt_boxes[:, :3], gt_boxes[:, 3:6], gt_boxes[:, 6], origin=origin, axis=2
+ )
+ if group_ids is not None:
+ if not enable_grot:
+ selected_noise = noise_per_box_group(
+ gt_boxes[:, [0, 1, 3, 4, 6]],
+ valid_mask,
+ loc_noises,
+ rot_noises,
+ group_nums,
+ )
+ else:
+ selected_noise = noise_per_box_group_v2_(
+ gt_boxes[:, [0, 1, 3, 4, 6]],
+ valid_mask,
+ loc_noises,
+ rot_noises,
+ group_nums,
+ global_rot_noises,
+ )
+ else:
+ if not enable_grot:
+ selected_noise = noise_per_box(
+ gt_boxes[:, [0, 1, 3, 4, 6]], valid_mask, loc_noises, rot_noises
+ )
+ else:
+ selected_noise = noise_per_box_v2_(
+ gt_boxes[:, [0, 1, 3, 4, 6]],
+ valid_mask,
+ loc_noises,
+ rot_noises,
+ global_rot_noises,
+ )
+ loc_transforms = _select_transform(loc_noises, selected_noise)
+ rot_transforms = _select_transform(rot_noises, selected_noise)
+ surfaces = box_np_ops.corner_to_surfaces_3d_jit(gt_box_corners)
+ if points is not None:
+ point_masks = points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
+ points_transform_(
+ points,
+ gt_boxes[:, :3],
+ point_masks,
+ loc_transforms,
+ rot_transforms,
+ valid_mask,
+ )
+
+ box3d_transform_(gt_boxes, loc_transforms, rot_transforms, valid_mask)
+
+
+def noise_per_object_v2_(
+ gt_boxes,
+ points=None,
+ valid_mask=None,
+ rotation_perturb=np.pi / 4,
+ center_noise_std=1.0,
+ global_random_rot_range=np.pi / 4,
+ num_try=100,
+):
+ """random rotate or remove each groundtrutn independently.
+ use kitti viewer to test this function points_transform_
+
+ Args:
+ gt_boxes: [N, 7], gt box in lidar.points_transform_
+ points: [M, 4], point cloud in lidar.
+ """
+ num_boxes = gt_boxes.shape[0]
+ if not isinstance(rotation_perturb, (list, tuple, np.ndarray)):
+ rotation_perturb = [-rotation_perturb, rotation_perturb]
+ if not isinstance(global_random_rot_range, (list, tuple, np.ndarray)):
+ global_random_rot_range = [-global_random_rot_range, global_random_rot_range]
+
+ if not isinstance(center_noise_std, (list, tuple, np.ndarray)):
+ center_noise_std = [center_noise_std, center_noise_std, center_noise_std]
+ if valid_mask is None:
+ valid_mask = np.ones((num_boxes,), dtype=np.bool_)
+ center_noise_std = np.array(center_noise_std, dtype=gt_boxes.dtype)
+ loc_noises = np.random.normal(scale=center_noise_std, size=[num_boxes, num_try, 3])
+ # loc_noises = np.random.uniform(
+ # -center_noise_std, center_noise_std, size=[num_boxes, num_try, 3])
+ rot_noises = np.random.uniform(
+ rotation_perturb[0], rotation_perturb[1], size=[num_boxes, num_try]
+ )
+ gt_grots = np.arctan2(gt_boxes[:, 0], gt_boxes[:, 1])
+ grot_lowers = global_random_rot_range[0] - gt_grots
+ grot_uppers = global_random_rot_range[1] - gt_grots
+ global_rot_noises = np.random.uniform(
+ grot_lowers[..., np.newaxis],
+ grot_uppers[..., np.newaxis],
+ size=[num_boxes, num_try],
+ )
+
+ origin = [0.5, 0.5, 0]
+ gt_box_corners = box_np_ops.center_to_corner_box3d(
+ gt_boxes[:, :3], gt_boxes[:, 3:6], gt_boxes[:, 6], origin=origin, axis=2
+ )
+ if np.abs(global_random_rot_range[0] - global_random_rot_range[1]) < 1e-3:
+ selected_noise = noise_per_box(
+ gt_boxes[:, [0, 1, 3, 4, 6]], valid_mask, loc_noises, rot_noises
+ )
+ else:
+ selected_noise = noise_per_box_v2_(
+ gt_boxes[:, [0, 1, 3, 4, 6]],
+ valid_mask,
+ loc_noises,
+ rot_noises,
+ global_rot_noises,
+ )
+ loc_transforms = _select_transform(loc_noises, selected_noise)
+ rot_transforms = _select_transform(rot_noises, selected_noise)
+ if points is not None:
+ surfaces = box_np_ops.corner_to_surfaces_3d_jit(gt_box_corners)
+ point_masks = points_in_convex_polygon_3d_jit(points[:, :3], surfaces)
+ points_transform_(
+ points,
+ gt_boxes[:, :3],
+ point_masks,
+ loc_transforms,
+ rot_transforms,
+ valid_mask,
+ )
+
+ box3d_transform_(gt_boxes, loc_transforms, rot_transforms, valid_mask)
+
+
+def global_scaling(gt_boxes, points, scale=0.05):
+ if not isinstance(scale, list):
+ scale = [-scale, scale]
+ noise_scale = np.random.uniform(scale[0] + 1, scale[1] + 1)
+ points[:, :3] *= noise_scale
+ gt_boxes[:, :6] *= noise_scale
+ return gt_boxes, points
+
+
+def global_rotation(gt_boxes, points, rotation=np.pi / 4):
+ if not isinstance(rotation, list):
+ rotation = [-rotation, rotation]
+ noise_rotation = np.random.uniform(rotation[0], rotation[1])
+ points[:, :3] = box_np_ops.rotation_points_single_angle(
+ points[:, :3], noise_rotation, axis=2
+ )
+ gt_boxes[:, :3] = box_np_ops.rotation_points_single_angle(
+ gt_boxes[:, :3], noise_rotation, axis=2
+ )
+ if gt_boxes.shape[1] > 7:
+ gt_boxes[:, 6:8] = box_np_ops.rotation_points_single_angle(
+ np.hstack([gt_boxes[:, 6:8], np.zeros((gt_boxes.shape[0], 1))]),
+ noise_rotation,
+ axis=2,
+ )[:, :2]
+ gt_boxes[:, -1] += noise_rotation
+ return gt_boxes, points
+
+
+def random_flip(gt_boxes, points, probability=0.5):
+ enable = np.random.choice(
+ [False, True], replace=False, p=[1 - probability, probability]
+ )
+ if enable:
+ gt_boxes[:, 1] = -gt_boxes[:, 1]
+ gt_boxes[:, -1] = -gt_boxes[:, -1] + np.pi
+ points[:, 1] = -points[:, 1]
+ if gt_boxes.shape[1] > 7: # y axis: x, y, z, w, h, l, vx, vy, r
+ gt_boxes[:, 7] = -gt_boxes[:, 7]
+ return gt_boxes, points
+
+def random_flip_both(gt_boxes, points, probability=0.5, flip_coor=None):
+ # x flip
+ enable = np.random.choice(
+ [False, True], replace=False, p=[1 - probability, probability]
+ )
+ if enable:
+ gt_boxes[:, 1] = -gt_boxes[:, 1]
+ gt_boxes[:, -1] = -gt_boxes[:, -1] + np.pi
+ points[:, 1] = -points[:, 1]
+ if gt_boxes.shape[1] > 7: # y axis: x, y, z, w, h, l, vx, vy, r
+ gt_boxes[:, 7] = -gt_boxes[:, 7]
+
+ # y flip
+ enable = np.random.choice(
+ [False, True], replace=False, p=[1 - probability, probability]
+ )
+ if enable:
+ if flip_coor is None:
+ gt_boxes[:, 0] = -gt_boxes[:, 0]
+ points[:, 0] = -points[:, 0]
+ else:
+ gt_boxes[:, 0] = flip_coor * 2 - gt_boxes[:, 0]
+ points[:, 0] = flip_coor * 2 - points[:, 0]
+
+ gt_boxes[:, -1] = -gt_boxes[:, -1] + 2*np.pi # TODO: CHECK THIS
+
+ if gt_boxes.shape[1] > 7: # y axis: x, y, z, w, h, l, vx, vy, r
+ gt_boxes[:, 6] = -gt_boxes[:, 6]
+
+ return gt_boxes, points
+
+
+def global_scaling_v2(gt_boxes, points, min_scale=0.95, max_scale=1.05):
+ noise_scale = np.random.uniform(min_scale, max_scale)
+ points[:, :3] *= noise_scale
+ gt_boxes[:, :-1] *= noise_scale
+ return gt_boxes, points
+
+
+def global_rotation_v2(gt_boxes, points, min_rad=-np.pi / 4, max_rad=np.pi / 4):
+ noise_rotation = np.random.uniform(min_rad, max_rad)
+ points[:, :3] = box_np_ops.rotation_points_single_angle(
+ points[:, :3], noise_rotation, axis=2
+ )
+ gt_boxes[:, :3] = box_np_ops.rotation_points_single_angle(
+ gt_boxes[:, :3], noise_rotation, axis=2
+ )
+ gt_boxes[:, -1] += noise_rotation
+ return gt_boxes, points
+
+
+@numba.jit(nopython=True)
+def box_collision_test(boxes, qboxes, clockwise=True):
+ N = boxes.shape[0]
+ K = qboxes.shape[0]
+ ret = np.zeros((N, K), dtype=np.bool_)
+ slices = np.array([1, 2, 3, 0])
+ lines_boxes = np.stack(
+ (boxes, boxes[:, slices, :]), axis=2
+ ) # [N, 4, 2(line), 2(xy)]
+ lines_qboxes = np.stack((qboxes, qboxes[:, slices, :]), axis=2)
+ # vec = np.zeros((2,), dtype=boxes.dtype)
+ boxes_standup = box_np_ops.corner_to_standup_nd_jit(boxes)
+ qboxes_standup = box_np_ops.corner_to_standup_nd_jit(qboxes)
+ for i in range(N):
+ for j in range(K):
+ # calculate standup first
+ iw = min(boxes_standup[i, 2], qboxes_standup[j, 2]) - max(
+ boxes_standup[i, 0], qboxes_standup[j, 0]
+ )
+ if iw > 0:
+ ih = min(boxes_standup[i, 3], qboxes_standup[j, 3]) - max(
+ boxes_standup[i, 1], qboxes_standup[j, 1]
+ )
+ if ih > 0:
+ for k in range(4):
+ for l in range(4):
+ A = lines_boxes[i, k, 0]
+ B = lines_boxes[i, k, 1]
+ C = lines_qboxes[j, l, 0]
+ D = lines_qboxes[j, l, 1]
+ acd = (D[1] - A[1]) * (C[0] - A[0]) > (C[1] - A[1]) * (
+ D[0] - A[0]
+ )
+ bcd = (D[1] - B[1]) * (C[0] - B[0]) > (C[1] - B[1]) * (
+ D[0] - B[0]
+ )
+ if acd != bcd:
+ abc = (C[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (
+ C[0] - A[0]
+ )
+ abd = (D[1] - A[1]) * (B[0] - A[0]) > (B[1] - A[1]) * (
+ D[0] - A[0]
+ )
+ if abc != abd:
+ ret[i, j] = True # collision.
+ break
+ if ret[i, j] is True:
+ break
+ if ret[i, j] is False:
+ # now check complete overlap.
+ # box overlap qbox:
+ box_overlap_qbox = True
+ for l in range(4): # point l in qboxes
+ for k in range(4): # corner k in boxes
+ vec = boxes[i, k] - boxes[i, (k + 1) % 4]
+ if clockwise:
+ vec = -vec
+ cross = vec[1] * (boxes[i, k, 0] - qboxes[j, l, 0])
+ cross -= vec[0] * (boxes[i, k, 1] - qboxes[j, l, 1])
+ if cross >= 0:
+ box_overlap_qbox = False
+ break
+ if box_overlap_qbox is False:
+ break
+
+ if box_overlap_qbox is False:
+ qbox_overlap_box = True
+ for l in range(4): # point l in boxes
+ for k in range(4): # corner k in qboxes
+ vec = qboxes[j, k] - qboxes[j, (k + 1) % 4]
+ if clockwise:
+ vec = -vec
+ cross = vec[1] * (qboxes[j, k, 0] - boxes[i, l, 0])
+ cross -= vec[0] * (qboxes[j, k, 1] - boxes[i, l, 1])
+ if cross >= 0: #
+ qbox_overlap_box = False
+ break
+ if qbox_overlap_box is False:
+ break
+ if qbox_overlap_box:
+ ret[i, j] = True # collision.
+ else:
+ ret[i, j] = True # collision.
+ return ret
+
+
+def global_translate_(gt_boxes, points, noise_translate_std):
+ """
+ Apply global translation to gt_boxes and points.
+ """
+
+ if not isinstance(noise_translate_std, (list, tuple, np.ndarray)):
+ noise_translate_std = np.array(
+ [noise_translate_std, noise_translate_std, noise_translate_std]
+ )
+ if all([e == 0 for e in noise_translate_std]):
+ return gt_boxes, points
+ noise_translate = np.array(
+ [
+ np.random.normal(0, noise_translate_std[0], 1),
+ np.random.normal(0, noise_translate_std[1], 1),
+ np.random.normal(0, noise_translate_std[0], 1),
+ ]
+ ).T
+
+ points[:, :3] += noise_translate
+ gt_boxes[:, :3] += noise_translate
+
+ return gt_boxes, points
+
+def global_translate_v2(gt_boxes, points, noise_translate):
+ """
+ Apply global translation to gt_boxes and points.
+ """
+
+ if not isinstance(noise_translate, (list, tuple, np.ndarray)):
+ noise_translate = np.array(
+ [noise_translate, noise_translate]
+ )
+ if all([e == 0 for e in noise_translate]):
+ return gt_boxes, points
+ noise_translate = np.array(
+ [
+ np.random.uniform(-noise_translate[0],noise_translate[0]),
+ np.random.uniform(-noise_translate[1],noise_translate[1]),
+ ]
+ ).T
+
+ points[:, :2] += noise_translate
+ gt_boxes[:, :2] += noise_translate
+
+ return gt_boxes, points
+
+
+if __name__ == "__main__":
+ bboxes = np.array(
+ [
+ [0.0, 0.0, 0.5, 0.5],
+ [0.2, 0.2, 0.6, 0.6],
+ [0.7, 0.7, 0.9, 0.9],
+ [0.55, 0.55, 0.8, 0.8],
+ ]
+ )
+ bbox_corners = box_np_ops.minmax_to_corner_2d(bboxes)
+ print(bbox_corners.shape)
+ print(box_collision_test(bbox_corners, bbox_corners))
diff --git a/det3d/core/sampler/sample_ops.py b/det3d/core/sampler/sample_ops.py
new file mode 100644
index 0000000..d507460
--- /dev/null
+++ b/det3d/core/sampler/sample_ops.py
@@ -0,0 +1,369 @@
+import copy
+import pathlib
+import pickle
+import time
+from functools import partial, reduce
+
+import numpy as np
+from det3d.core.bbox import box_np_ops
+from det3d.core.sampler import preprocess as prep
+from det3d.utils.check import shape_mergeable
+
+
+class DataBaseSamplerV2:
+ def __init__(
+ self,
+ db_infos,
+ groups,
+ db_prepor=None,
+ rate=1.0,
+ global_rot_range=None,
+ logger=None,
+ ):
+ for k, v in db_infos.items():
+ logger.info(f"load {len(v)} {k} database infos")
+
+ if db_prepor is not None:
+ db_infos = db_prepor(db_infos)
+ logger.info("After filter database:")
+ for k, v in db_infos.items():
+ logger.info(f"load {len(v)} {k} database infos")
+
+ self.db_infos = db_infos
+ self._rate = rate
+ self._groups = groups
+ self._group_db_infos = {}
+ self._group_name_to_names = []
+ self._sample_classes = []
+ self._sample_max_nums = []
+ self._use_group_sampling = False # slower
+ if any([len(g) > 1 for g in groups]):
+ self._use_group_sampling = True
+ if not self._use_group_sampling:
+ self._group_db_infos = self.db_infos # just use db_infos
+ for group_info in groups:
+ group_names = list(group_info.keys())
+ self._sample_classes += group_names
+ self._sample_max_nums += list(group_info.values())
+ else:
+ for group_info in groups:
+ group_dict = {}
+ group_names = list(group_info.keys())
+ group_name = ", ".join(group_names)
+ self._sample_classes += group_names
+ self._sample_max_nums += list(group_info.values())
+ self._group_name_to_names.append((group_name, group_names))
+ # self._group_name_to_names[group_name] = group_names
+ for name in group_names:
+ for item in db_infos[name]:
+ gid = item["group_id"]
+ if gid not in group_dict:
+ group_dict[gid] = [item]
+ else:
+ group_dict[gid] += [item]
+ if group_name in self._group_db_infos:
+ raise ValueError("group must be unique")
+ group_data = list(group_dict.values())
+ self._group_db_infos[group_name] = group_data
+ info_dict = {}
+ if len(group_info) > 1:
+ for group in group_data:
+ names = [item["name"] for item in group]
+ names = sorted(names)
+ group_name = ", ".join(names)
+ if group_name in info_dict:
+ info_dict[group_name] += 1
+ else:
+ info_dict[group_name] = 1
+ print(info_dict)
+
+ self._sampler_dict = {}
+ for k, v in self._group_db_infos.items():
+ self._sampler_dict[k] = prep.BatchSampler(v, k)
+ self._enable_global_rot = False
+ if global_rot_range is not None:
+ if not isinstance(global_rot_range, (list, tuple, np.ndarray)):
+ global_rot_range = [-global_rot_range, global_rot_range]
+ else:
+ assert shape_mergeable(global_rot_range, [2])
+ if np.abs(global_rot_range[0] - global_rot_range[1]) >= 1e-3:
+ self._enable_global_rot = True
+ self._global_rot_range = global_rot_range
+
+ @property
+ def use_group_sampling(self):
+ return self._use_group_sampling
+
+ def sample_all(
+ self,
+ root_path,
+ gt_boxes,
+ gt_names,
+ num_point_features,
+ random_crop=False,
+ gt_group_ids=None,
+ calib=None,
+ road_planes=None,
+ ):
+ sampled_num_dict = {}
+ sample_num_per_class = []
+ for class_name, max_sample_num in zip(
+ self._sample_classes, self._sample_max_nums
+ ):
+ sampled_num = int(
+ max_sample_num - np.sum([n == class_name for n in gt_names])
+ )
+
+ sampled_num = np.round(self._rate * sampled_num).astype(np.int64)
+ sampled_num_dict[class_name] = sampled_num
+ sample_num_per_class.append(sampled_num)
+
+ sampled_groups = self._sample_classes
+ if self._use_group_sampling:
+ assert gt_group_ids is not None
+ sampled_groups = []
+ sample_num_per_class = []
+ for group_name, class_names in self._group_name_to_names:
+ sampled_nums_group = [sampled_num_dict[n] for n in class_names]
+ sampled_num = np.max(sampled_nums_group)
+ sample_num_per_class.append(sampled_num)
+ sampled_groups.append(group_name)
+ total_group_ids = gt_group_ids
+ sampled = []
+ sampled_gt_boxes = []
+ avoid_coll_boxes = gt_boxes
+
+ for class_name, sampled_num in zip(sampled_groups, sample_num_per_class):
+ if sampled_num > 0:
+ if self._use_group_sampling:
+ sampled_cls = self.sample_group(
+ class_name, sampled_num, avoid_coll_boxes, total_group_ids
+ )
+ else:
+ sampled_cls = self.sample_class_v2(
+ class_name, sampled_num, avoid_coll_boxes
+ )
+
+ sampled += sampled_cls
+ if len(sampled_cls) > 0:
+ if len(sampled_cls) == 1:
+ sampled_gt_box = sampled_cls[0]["box3d_lidar"][np.newaxis, ...]
+ else:
+ sampled_gt_box = np.stack(
+ [s["box3d_lidar"] for s in sampled_cls], axis=0
+ )
+
+ sampled_gt_boxes += [sampled_gt_box]
+ avoid_coll_boxes = np.concatenate(
+ [avoid_coll_boxes, sampled_gt_box], axis=0
+ )
+ if self._use_group_sampling:
+ if len(sampled_cls) == 1:
+ sampled_group_ids = np.array(sampled_cls[0]["group_id"])[
+ np.newaxis, ...
+ ]
+ else:
+ sampled_group_ids = np.stack(
+ [s["group_id"] for s in sampled_cls], axis=0
+ )
+ total_group_ids = np.concatenate(
+ [total_group_ids, sampled_group_ids], axis=0
+ )
+
+ if len(sampled) > 0:
+ sampled_gt_boxes = np.concatenate(sampled_gt_boxes, axis=0)
+
+ num_sampled = len(sampled)
+ s_points_list = []
+ for info in sampled:
+ try:
+ s_points = np.fromfile(
+ str(pathlib.Path(root_path) / info["path"]), dtype=np.float32
+ ).reshape(-1, num_point_features)
+
+ if "rot_transform" in info:
+ rot = info["rot_transform"]
+ s_points[:, :3] = box_np_ops.rotation_points_single_angle(
+ s_points[:, :4], rot, axis=2
+ )
+ s_points[:, :3] += info["box3d_lidar"][:3]
+ s_points_list.append(s_points)
+ # print(pathlib.Path(info["path"]).stem)
+ except Exception:
+ print(str(pathlib.Path(root_path) / info["path"]))
+ continue
+ if random_crop:
+ s_points_list_new = []
+ assert calib is not None
+ rect = calib["rect"]
+ Trv2c = calib["Trv2c"]
+ P2 = calib["P2"]
+ gt_bboxes = box_np_ops.box3d_to_bbox(sampled_gt_boxes, rect, Trv2c, P2)
+ crop_frustums = prep.random_crop_frustum(gt_bboxes, rect, Trv2c, P2)
+ for i in range(crop_frustums.shape[0]):
+ s_points = s_points_list[i]
+ mask = prep.mask_points_in_corners(
+ s_points, crop_frustums[i : i + 1]
+ ).reshape(-1)
+ num_remove = np.sum(mask)
+ if num_remove > 0 and (s_points.shape[0] - num_remove) > 15:
+ s_points = s_points[np.logical_not(mask)]
+ s_points_list_new.append(s_points)
+ s_points_list = s_points_list_new
+ ret = {
+ "gt_names": np.array([s["name"] for s in sampled]),
+ "difficulty": np.array([s["difficulty"] for s in sampled]),
+ "gt_boxes": sampled_gt_boxes,
+ "points": np.concatenate(s_points_list, axis=0),
+ "gt_masks": np.ones((num_sampled,), dtype=np.bool_),
+ }
+ if self._use_group_sampling:
+ ret["group_ids"] = np.array([s["group_id"] for s in sampled])
+ else:
+ ret["group_ids"] = np.arange(
+ gt_boxes.shape[0], gt_boxes.shape[0] + len(sampled)
+ )
+ else:
+ ret = None
+ return ret
+
+ def sample(self, name, num):
+ if self._use_group_sampling:
+ group_name = name
+ ret = self._sampler_dict[group_name].sample(num)
+ groups_num = [len(l) for l in ret]
+ return reduce(lambda x, y: x + y, ret), groups_num
+ else:
+ ret = self._sampler_dict[name].sample(num)
+ return ret, np.ones((len(ret),), dtype=np.int64)
+
+ def sample_v1(self, name, num):
+ if isinstance(name, (list, tuple)):
+ group_name = ", ".join(name)
+ ret = self._sampler_dict[group_name].sample(num)
+ groups_num = [len(l) for l in ret]
+ return reduce(lambda x, y: x + y, ret), groups_num
+ else:
+ ret = self._sampler_dict[name].sample(num)
+ return ret, np.ones((len(ret),), dtype=np.int64)
+
+ def sample_class_v2(self, name, num, gt_boxes):
+ sampled = self._sampler_dict[name].sample(num)
+ sampled = copy.deepcopy(sampled)
+ num_gt = gt_boxes.shape[0]
+ num_sampled = len(sampled)
+ gt_boxes_bv = box_np_ops.center_to_corner_box2d(
+ gt_boxes[:, 0:2], gt_boxes[:, 3:5], gt_boxes[:, -1]
+ )
+
+ sp_boxes = np.stack([i["box3d_lidar"] for i in sampled], axis=0)
+
+ valid_mask = np.zeros([gt_boxes.shape[0]], dtype=np.bool_)
+ valid_mask = np.concatenate(
+ [valid_mask, np.ones([sp_boxes.shape[0]], dtype=np.bool_)], axis=0
+ )
+ boxes = np.concatenate([gt_boxes, sp_boxes], axis=0).copy()
+ if self._enable_global_rot:
+ # place samples to any place in a circle.
+ prep.noise_per_object_v3_(
+ boxes, None, valid_mask, 0, 0, self._global_rot_range, num_try=100
+ )
+
+ sp_boxes_new = boxes[gt_boxes.shape[0] :]
+ sp_boxes_bv = box_np_ops.center_to_corner_box2d(
+ sp_boxes_new[:, 0:2], sp_boxes_new[:, 3:5], sp_boxes_new[:, -1]
+ )
+
+ total_bv = np.concatenate([gt_boxes_bv, sp_boxes_bv], axis=0)
+ # coll_mat = collision_test_allbox(total_bv)
+ coll_mat = prep.box_collision_test(total_bv, total_bv)
+ diag = np.arange(total_bv.shape[0])
+ coll_mat[diag, diag] = False
+
+ valid_samples = []
+ for i in range(num_gt, num_gt + num_sampled):
+ if coll_mat[i].any():
+ coll_mat[i] = False
+ coll_mat[:, i] = False
+ else:
+ if self._enable_global_rot:
+ sampled[i - num_gt]["box3d_lidar"][:2] = boxes[i, :2]
+ sampled[i - num_gt]["box3d_lidar"][-1] = boxes[i, -1]
+ sampled[i - num_gt]["rot_transform"] = (
+ boxes[i, -1] - sp_boxes[i - num_gt, -1]
+ )
+ valid_samples.append(sampled[i - num_gt])
+ return valid_samples
+
+ def sample_group(self, name, num, gt_boxes, gt_group_ids):
+ sampled, group_num = self.sample(name, num)
+ sampled = copy.deepcopy(sampled)
+ # rewrite sampled group id to avoid duplicated with gt group ids
+ gid_map = {}
+ max_gt_gid = np.max(gt_group_ids)
+ sampled_gid = max_gt_gid + 1
+ for s in sampled:
+ gid = s["group_id"]
+ if gid in gid_map:
+ s["group_id"] = gid_map[gid]
+ else:
+ gid_map[gid] = sampled_gid
+ s["group_id"] = sampled_gid
+ sampled_gid += 1
+
+ num_gt = gt_boxes.shape[0]
+ gt_boxes_bv = box_np_ops.center_to_corner_box2d(
+ gt_boxes[:, 0:2], gt_boxes[:, 3:5], gt_boxes[:, -1]
+ )
+
+ sp_boxes = np.stack([i["box3d_lidar"] for i in sampled], axis=0)
+ sp_group_ids = np.stack([i["group_id"] for i in sampled], axis=0)
+ valid_mask = np.zeros([gt_boxes.shape[0]], dtype=np.bool_)
+ valid_mask = np.concatenate(
+ [valid_mask, np.ones([sp_boxes.shape[0]], dtype=np.bool_)], axis=0
+ )
+ boxes = np.concatenate([gt_boxes, sp_boxes], axis=0).copy()
+ group_ids = np.concatenate([gt_group_ids, sp_group_ids], axis=0)
+ if self._enable_global_rot:
+ # place samples to any place in a circle.
+ prep.noise_per_object_v3_(
+ boxes,
+ None,
+ valid_mask,
+ 0,
+ 0,
+ self._global_rot_range,
+ group_ids=group_ids,
+ num_try=100,
+ )
+ sp_boxes_new = boxes[gt_boxes.shape[0] :]
+ sp_boxes_bv = box_np_ops.center_to_corner_box2d(
+ sp_boxes_new[:, 0:2], sp_boxes_new[:, 3:5], sp_boxes_new[:, -1]
+ )
+ total_bv = np.concatenate([gt_boxes_bv, sp_boxes_bv], axis=0)
+ # coll_mat = collision_test_allbox(total_bv)
+ coll_mat = prep.box_collision_test(total_bv, total_bv)
+ diag = np.arange(total_bv.shape[0])
+ coll_mat[diag, diag] = False
+ valid_samples = []
+ idx = num_gt
+ for num in group_num:
+ if coll_mat[idx : idx + num].any():
+ coll_mat[idx : idx + num] = False
+ coll_mat[:, idx : idx + num] = False
+ else:
+ for i in range(num):
+ if self._enable_global_rot:
+ sampled[idx - num_gt + i]["box3d_lidar"][:2] = boxes[
+ idx + i, :2
+ ]
+ sampled[idx - num_gt + i]["box3d_lidar"][-1] = boxes[
+ idx + i, -1
+ ]
+ sampled[idx - num_gt + i]["rot_transform"] = (
+ boxes[idx + i, -1] - sp_boxes[idx + i - num_gt, -1]
+ )
+
+ valid_samples.append(sampled[idx - num_gt + i])
+ idx += num
+ return valid_samples
diff --git a/det3d/core/utils/__init__.py b/det3d/core/utils/__init__.py
new file mode 100644
index 0000000..154357c
--- /dev/null
+++ b/det3d/core/utils/__init__.py
@@ -0,0 +1,4 @@
+from .dist_utils import *
+from .misc import *
+from .center_utils import *
+from .circle_nms_jit import *
\ No newline at end of file
diff --git a/det3d/core/utils/center_utils.py b/det3d/core/utils/center_utils.py
new file mode 100644
index 0000000..8edc421
--- /dev/null
+++ b/det3d/core/utils/center_utils.py
@@ -0,0 +1,121 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (Bin.Xiao@microsoft.com)
+# Modified by Xingyi Zhou and Tianwei Yin
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import torch
+from torch import nn
+from .circle_nms_jit import circle_nms
+
+def gaussian_radius(det_size, min_overlap=0.5):
+ height, width = det_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = np.sqrt(b1 ** 2 - 4 * a1 * c1)
+ r1 = (b1 + sq1) / 2
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = np.sqrt(b2 ** 2 - 4 * a2 * c2)
+ r2 = (b2 + sq2) / 2
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = np.sqrt(b3 ** 2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / 2
+ return min(r1, r2, r3)
+
+def gaussian2D(shape, sigma=1):
+ m, n = [(ss - 1.) / 2. for ss in shape]
+ y, x = np.ogrid[-m:m+1,-n:n+1]
+
+ h = np.exp(-(x * x + y * y) / (2 * sigma * sigma))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def draw_umich_gaussian(heatmap, center, radius, k=1):
+ diameter = 2 * radius + 1
+ gaussian = gaussian2D((diameter, diameter), sigma=diameter / 6)
+
+ x, y = int(center[0]), int(center[1])
+
+ height, width = heatmap.shape[0:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:radius + right]
+ if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0: # TODO debug
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+ return heatmap
+
+def _gather_feat(feat, ind, mask=None):
+ dim = feat.size(2)
+ ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
+ feat = feat.gather(1, ind)
+ if mask is not None:
+ mask = mask.unsqueeze(2).expand_as(feat)
+ feat = feat[mask]
+ feat = feat.view(-1, dim)
+ return feat
+
+def _transpose_and_gather_feat(feat, ind):
+ feat = feat.permute(0, 2, 3, 1).contiguous()
+ feat = feat.view(feat.size(0), -1, feat.size(3))
+ feat = _gather_feat(feat, ind)
+ return feat
+
+def _circle_nms(boxes, min_radius, post_max_size=83):
+ """
+ NMS according to center distance
+ """
+ keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]
+
+ keep = torch.from_numpy(keep).long().to(boxes.device)
+
+ return keep
+
+
+def bilinear_interpolate_torch(im, x, y):
+ """
+ Args:
+ im: (H, W, C) [y, x]
+ x: (N)
+ y: (N)
+ Returns:
+ """
+ x0 = torch.floor(x).long()
+ x1 = x0 + 1
+
+ y0 = torch.floor(y).long()
+ y1 = y0 + 1
+
+ x0 = torch.clamp(x0, 0, im.shape[1] - 1)
+ x1 = torch.clamp(x1, 0, im.shape[1] - 1)
+ y0 = torch.clamp(y0, 0, im.shape[0] - 1)
+ y1 = torch.clamp(y1, 0, im.shape[0] - 1)
+
+ Ia = im[y0, x0]
+ Ib = im[y1, x0]
+ Ic = im[y0, x1]
+ Id = im[y1, x1]
+
+ wa = (x1.type_as(x) - x) * (y1.type_as(y) - y)
+ wb = (x1.type_as(x) - x) * (y - y0.type_as(y))
+ wc = (x - x0.type_as(x)) * (y1.type_as(y) - y)
+ wd = (x - x0.type_as(x)) * (y - y0.type_as(y))
+ ans = torch.t((torch.t(Ia) * wa)) + torch.t(torch.t(Ib) * wb) + torch.t(torch.t(Ic) * wc) + torch.t(torch.t(Id) * wd)
+ return ans
diff --git a/det3d/core/utils/circle_nms_jit.py b/det3d/core/utils/circle_nms_jit.py
new file mode 100644
index 0000000..d19cf1a
--- /dev/null
+++ b/det3d/core/utils/circle_nms_jit.py
@@ -0,0 +1,28 @@
+import numba
+import numpy as np
+
+@numba.jit(nopython=True)
+def circle_nms(dets, thresh):
+ x1 = dets[:, 0]
+ y1 = dets[:, 1]
+ scores = dets[:, 2]
+ order = scores.argsort()[::-1].astype(np.int32) # highest->lowest
+ ndets = dets.shape[0]
+ suppressed = np.zeros((ndets), dtype=np.int32)
+ keep = []
+ for _i in range(ndets):
+ i = order[_i] # start with highest score box
+ if suppressed[i] == 1: # if any box have enough iou with this, remove it
+ continue
+ keep.append(i)
+ for _j in range(_i + 1, ndets):
+ j = order[_j]
+ if suppressed[j] == 1:
+ continue
+ # calculate center distance between i and j box
+ dist = (x1[i]-x1[j])**2 + (y1[i]-y1[j])**2
+
+ # ovr = inter / areas[j]
+ if dist <= thresh:
+ suppressed[j] = 1
+ return keep
diff --git a/det3d/core/utils/dist_utils.py b/det3d/core/utils/dist_utils.py
new file mode 100644
index 0000000..68f6670
--- /dev/null
+++ b/det3d/core/utils/dist_utils.py
@@ -0,0 +1,57 @@
+from collections import OrderedDict
+
+import torch.distributed as dist
+from det3d.torchie.trainer import OptimizerHook
+from torch._utils import _flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)
+ ):
+ tensor.copy_(synced)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ grads = [
+ param.grad.data
+ for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ world_size = dist.get_world_size()
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
+
+class DistOptimizerHook(OptimizerHook):
+ def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+
+ def after_train_iter(self, runner):
+ runner.optimizer.zero_grad()
+ runner.outputs["loss"].backward()
+ allreduce_grads(runner.model.parameters(), self.coalesce, self.bucket_size_mb)
+ if self.grad_clip is not None:
+ self.clip_grads(runner.model.parameters())
+ runner.optimizer.step()
diff --git a/det3d/core/utils/misc.py b/det3d/core/utils/misc.py
new file mode 100644
index 0000000..f65f5c4
--- /dev/null
+++ b/det3d/core/utils/misc.py
@@ -0,0 +1,36 @@
+from functools import partial
+
+import numpy as np
+from det3d import torchie
+from six.moves import map, zip
+
+
+def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True):
+ num_imgs = tensor.size(0)
+ mean = np.array(mean, dtype=np.float32)
+ std = np.array(std, dtype=np.float32)
+ imgs = []
+ for img_id in range(num_imgs):
+ img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0)
+ img = torchie.imdenormalize(img, mean, std, to_bgr=to_rgb).astype(np.uint8)
+ imgs.append(np.ascontiguousarray(img))
+ return imgs
+
+
+def multi_apply(func, *args, **kwargs):
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
+
+def unmap(data, count, inds, fill=0):
+ """ Unmap a subset of item (data) back to the original set of items (of
+ size count) """
+ if data.dim() == 1:
+ ret = data.new_full((count,), fill)
+ ret[inds] = data
+ else:
+ new_size = (count,) + data.size()[1:]
+ ret = data.new_full(new_size, fill)
+ ret[inds, :] = data
+ return ret
diff --git a/det3d/core/utils/scatter.py b/det3d/core/utils/scatter.py
new file mode 100644
index 0000000..6ee13b7
--- /dev/null
+++ b/det3d/core/utils/scatter.py
@@ -0,0 +1,60 @@
+# The following code are copied from pytorch_scatter https://github.com/rusty1s/pytorch_scatter
+# Copyright (c) 2020 Matthias Fey
+# MIT License
+from typing import Optional, Tuple
+import torch
+
+@torch.jit.script
+def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+ if dim < 0:
+ dim = other.dim() + dim
+ if src.dim() == 1:
+ for _ in range(dim):
+ src = src.unsqueeze(0)
+ for _ in range(other.dim()-src.dim()):
+ src = src.unsqueeze(-1)
+ src = src.expand_as(other)
+ return src
+
+@torch.jit.script
+def scatter_sum(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None) -> torch.Tensor:
+ index = broadcast(index, src, dim)
+ if out is None:
+ size = list(src.size())
+ if dim_size is not None:
+ size[dim] = dim_size
+ elif index.numel() == 0:
+ size[dim] = 0
+ else:
+ size[dim] = int(index.max()) + 1
+ out = torch.zeros(size, dtype=src.dtype, device=src.device)
+ return out.scatter_add_(dim, index, src)
+ else:
+ return out.scatter_add_(dim, index, src)
+
+@torch.jit.script
+def scatter_mean(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
+ out: Optional[torch.Tensor] = None,
+ dim_size: Optional[int] = None) -> torch.Tensor:
+
+ out = scatter_sum(src, index, dim, out, dim_size)
+ dim_size = out.size(dim)
+
+ index_dim = dim
+ if index_dim < 0:
+ index_dim = index_dim + src.dim()
+ if index.dim() <= index_dim:
+ index_dim = index.dim() - 1
+
+ ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+ count = scatter_sum(ones, index, index_dim, None, dim_size)
+ count.clamp_(1)
+ count = broadcast(count, out, dim)
+ if torch.is_floating_point(out):
+ out.div_(count)
+ else:
+ assert 0
+ # out.floor_divide_(count)
+ return out
\ No newline at end of file
diff --git a/det3d/datasets/__init__.py b/det3d/datasets/__init__.py
new file mode 100644
index 0000000..1c0f34a
--- /dev/null
+++ b/det3d/datasets/__init__.py
@@ -0,0 +1,27 @@
+from .builder import build_dataset
+
+# from .cityscapes import CityscapesDataset
+from .nuscenes import NuScenesDataset
+from .waymo import WaymoDataset
+
+# from .custom import CustomDataset
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+
+# from .extra_aug import ExtraAugmentation
+from .loader import DistributedGroupSampler, GroupSampler, build_dataloader
+from .registry import DATASETS
+
+# from .voc import VOCDataset
+# from .wider_face import WIDERFaceDataset
+# from .xml_style import XMLDataset
+#
+__all__ = [
+ "CustomDataset",
+ "GroupSampler",
+ "DistributedGroupSampler",
+ "build_dataloader",
+ "ConcatDataset",
+ "RepeatDataset",
+ "DATASETS",
+ "build_dataset",
+]
diff --git a/det3d/datasets/builder.py b/det3d/datasets/builder.py
new file mode 100644
index 0000000..9405d9f
--- /dev/null
+++ b/det3d/datasets/builder.py
@@ -0,0 +1,43 @@
+import copy
+
+from det3d.utils import build_from_cfg
+
+from .dataset_wrappers import ConcatDataset, RepeatDataset
+from .registry import DATASETS
+
+
+def _concat_dataset(cfg, default_args=None):
+ ann_files = cfg["ann_file"]
+ img_prefixes = cfg.get("img_prefix", None)
+ seg_prefixes = cfg.get("seg_prefixes", None)
+ proposal_files = cfg.get("proposal_file", None)
+
+ datasets = []
+ num_dset = len(ann_files)
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ data_cfg["ann_file"] = ann_files[i]
+ if isinstance(img_prefixes, (list, tuple)):
+ data_cfg["img_prefix"] = img_prefixes[i]
+ if isinstance(seg_prefixes, (list, tuple)):
+ data_cfg["seg_prefix"] = seg_prefixes[i]
+ if isinstance(proposal_files, (list, tuple)):
+ data_cfg["proposal_file"] = proposal_files[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets)
+
+
+def build_dataset(cfg, default_args=None):
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg["type"] == "RepeatDataset":
+ dataset = RepeatDataset(
+ build_dataset(cfg["dataset"], default_args), cfg["times"]
+ )
+ # elif isinstance(cfg['ann_file'], (list, tuple)):
+ # dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
diff --git a/det3d/datasets/custom.py b/det3d/datasets/custom.py
new file mode 100644
index 0000000..1df6ff3
--- /dev/null
+++ b/det3d/datasets/custom.py
@@ -0,0 +1,190 @@
+import os.path as osp
+from pathlib import Path
+
+import numpy as np
+from torch.utils.data import Dataset
+
+from .registry import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module
+class PointCloudDataset(Dataset):
+ """An abstract class representing a pytorch-like Dataset.
+ All other datasets should subclass it. All subclasses should override
+ ``__len__``, that provides the size of the dataset, and ``__getitem__``,
+ supporting integer indexing in range from 0 to len(self) exclusive.
+ """
+
+ NumPointFeatures = -1
+ CLASSES = None
+
+ def __init__(
+ self,
+ root_path,
+ info_path,
+ pipeline=None,
+ test_mode=False,
+ class_names=None,
+ **kwrags
+ ):
+ self._info_path = info_path
+ self._root_path = Path(root_path)
+ self._class_names = class_names
+
+ self.test_mode = test_mode
+
+ self._set_group_flag()
+
+ if pipeline is None:
+ self.pipeline = None
+ else:
+ self.pipeline = Compose(pipeline)
+
+ def __getitem__(self, index):
+ """This function is used for preprocess.
+ you need to create a input dict in this function for network inference.
+ format: {
+ anchors
+ voxels
+ num_points
+ coordinates
+ if training:
+ labels
+ reg_targets
+ [optional]anchors_mask, slow in SECOND v1.5, don't use this.
+ [optional]metadata, in kitti, image index is saved in metadata
+ }
+ """
+ raise NotImplementedError
+
+ def __len__(self):
+ raise NotImplementedError
+
+ def get_sensor_data(self, query):
+ """Dataset must provide a unified function to get data.
+ Args:
+ query: int or dict. this param must support int for training.
+ if dict, should have this format (no example yet):
+ {
+ sensor_name: {
+ sensor_meta
+ }
+ }
+ if int, will return all sensor data.
+ (TODO: how to deal with unsynchronized data?)
+ Returns:
+ sensor_data: dict.
+ if query is int (return all), return a dict with all sensors:
+ {
+ sensor_name: sensor_data
+ ...
+ metadata: ... (for kitti, contains image_idx)
+ }
+
+ if sensor is lidar (all lidar point cloud must be concatenated to one array):
+ e.g. If your dataset have two lidar sensor, you need to return a single dict:
+ {
+ "lidar": {
+ "points": ...
+ ...
+ }
+ }
+ sensor_data: {
+ points: [N, 3+]
+ [optional]annotations: {
+ "boxes": [N, 7] locs, dims, yaw, in lidar coord system. must tested
+ in provided visualization tools such as second.utils.simplevis
+ or web tool.
+ "names": array of string.
+ }
+ }
+ if sensor is camera (not used yet):
+ sensor_data: {
+ data: image string (array is too large)
+ [optional]annotations: {
+ "boxes": [N, 4] 2d bbox
+ "names": array of string.
+ }
+ }
+ metadata: {
+ # dataset-specific information.
+ # for kitti, must have image_idx for label file generation.
+ image_idx: ...
+ }
+ [optional]calib # only used for kitti
+ """
+ raise NotImplementedError
+
+ def evaluation(self, dt_annos, output_dir):
+ """Dataset must provide a evaluation function to evaluate model."""
+ raise NotImplementedError
+
+ @property
+ def ground_truth_annotations(self):
+ """
+ If you want to eval by my KITTI eval function, you must
+ provide the correct format annotations.
+ ground_truth_annotations format:
+ {
+ bbox: [N, 4], if you fill fake data, MUST HAVE >25 HEIGHT!!!!!!
+ alpha: [N], you can use -10 to ignore it.
+ occluded: [N], you can use zero.
+ truncated: [N], you can use zero.
+ name: [N]
+ location: [N, 3] center of 3d box.
+ dimensions: [N, 3] dim of 3d box.
+ rotation_y: [N] angle.
+ }
+ all fields must be filled, but some fields can fill
+ zero.
+ """
+ raise NotImplementedError
+
+ def pre_pipeline(self, results):
+ results["img_prefix"] = self.img_prefix
+ results["seg_prefix"] = self.seg_prefix
+ results["proposal_file"] = self.proposal_file
+ results["bbox_fields"] = []
+ results["mask_fields"] = []
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small."""
+ valid_inds = []
+ for i, img_info in enumerate(self.img_infos):
+ if min(img_info["width"], img_info["height"]) >= min_size:
+ valid_inds.append(i)
+ return valid_inds
+
+ def _set_group_flag(self):
+ """Set flag according to image aspect ratio.
+ Images with aspect ratio greater than 1 will be set as group 1,
+ otherwise group 0.
+ """
+ self.flag = np.ones(len(self), dtype=np.uint8)
+ # self.flag = np.zeros(len(self), dtype=np.uint8)
+ # for i in range(len(self)):
+ # img_info = self.img_infos[i]
+ # if img_info['width'] / img_info['height'] > 1:
+ # self.flag[i] = 1
+
+ def prepare_train_input(self, idx):
+ raise NotImplementedError
+
+ # img_info = self.img_infos[idx]
+ # ann_info = self.get_ann_info(idx)
+ # results = dict(img_info=img_info, ann_info=ann_info)
+ # if self.proposals is not None:
+ # results['proposals'] = self.proposals[idx]
+ # self.pre_pipeline(results)
+ # return self.pipeline(results)
+
+ def prepare_test_input(self, idx):
+ raise NotImplementedError
+
+ # img_info = self.img_infos[idx]
+ # results = dict(img_info=img_info)
+ # if self.proposals is not None:
+ # results['proposals'] = self.proposals[idx]
+ # self.pre_pipeline(results)
+ # return self.pipeline(results)
diff --git a/det3d/datasets/dataset_factory.py b/det3d/datasets/dataset_factory.py
new file mode 100644
index 0000000..225632f
--- /dev/null
+++ b/det3d/datasets/dataset_factory.py
@@ -0,0 +1,11 @@
+from .nuscenes import NuScenesDataset
+from .waymo import WaymoDataset
+
+dataset_factory = {
+ "NUSC": NuScenesDataset,
+ "WAYMO": WaymoDataset
+}
+
+
+def get_dataset(dataset_name):
+ return dataset_factory[dataset_name]
diff --git a/det3d/datasets/dataset_wrappers.py b/det3d/datasets/dataset_wrappers.py
new file mode 100644
index 0000000..2d7f17f
--- /dev/null
+++ b/det3d/datasets/dataset_wrappers.py
@@ -0,0 +1,55 @@
+import numpy as np
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .registry import DATASETS
+
+
+@DATASETS.register_module
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ concat the group flag for image aspect ratio.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ """
+
+ def __init__(self, datasets):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ if hasattr(datasets[0], "flag"):
+ flags = []
+ for i in range(0, len(datasets)):
+ flags.append(datasets[i].flag)
+ self.flag = np.concatenate(flags)
+
+
+@DATASETS.register_module
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ if hasattr(self.dataset, "flag"):
+ self.flag = np.tile(self.dataset.flag, times)
+
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
+ return self.dataset[idx % self._ori_len]
+
+ def __len__(self):
+ return self.times * self._ori_len
diff --git a/det3d/datasets/loader/__init__.py b/det3d/datasets/loader/__init__.py
new file mode 100644
index 0000000..0d00da9
--- /dev/null
+++ b/det3d/datasets/loader/__init__.py
@@ -0,0 +1,4 @@
+from .build_loader import build_dataloader
+from .sampler import DistributedGroupSampler, GroupSampler
+
+__all__ = ["GroupSampler", "DistributedGroupSampler", "build_dataloader"]
diff --git a/det3d/datasets/loader/build_loader.py b/det3d/datasets/loader/build_loader.py
new file mode 100644
index 0000000..4e2ff9e
--- /dev/null
+++ b/det3d/datasets/loader/build_loader.py
@@ -0,0 +1,57 @@
+import platform
+from functools import partial
+
+from det3d.torchie.parallel import collate, collate_kitti
+from det3d.torchie.trainer import get_dist_info
+from torch.utils.data import DataLoader
+
+from .sampler import (
+ DistributedGroupSampler,
+ DistributedSampler,
+ DistributedSamplerV2,
+ GroupSampler,
+)
+
+if platform.system() != "Windows":
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+
+def build_dataloader(
+ dataset, batch_size, workers_per_gpu, num_gpus=1, dist=True, **kwargs
+):
+ shuffle = kwargs.get("shuffle", True)
+ if dist:
+ rank, world_size = get_dist_info()
+ # sampler = DistributedSamplerV2(dataset,
+ # num_replicas=world_size,
+ # rank=rank,
+ # shuffle=shuffle)
+ if shuffle:
+ sampler = DistributedGroupSampler(dataset, batch_size, world_size, rank)
+ else:
+ sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
+ batch_size = batch_size
+ num_workers = workers_per_gpu
+ else:
+ sampler = GroupSampler(dataset, batch_size) if shuffle else None
+ sampler = None
+ batch_size = num_gpus * batch_size
+ num_workers = num_gpus * workers_per_gpu
+
+ # TODO change pin_memory
+ data_loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ shuffle=(sampler is None),
+ num_workers=num_workers,
+ collate_fn=collate_kitti,
+ # pin_memory=True,
+ pin_memory=False,
+ )
+
+ return data_loader
diff --git a/det3d/datasets/loader/sampler.py b/det3d/datasets/loader/sampler.py
new file mode 100644
index 0000000..60ae2cf
--- /dev/null
+++ b/det3d/datasets/loader/sampler.py
@@ -0,0 +1,223 @@
+from __future__ import division
+import math
+
+import numpy as np
+import torch
+import math
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+
+from det3d.torchie.trainer import get_dist_info
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+# from torch.utils.data import Sampler
+
+
+class DistributedSamplerV2(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
+
+
+class DistributedSampler(_DistributedSampler):
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+ super().__init__(dataset, num_replicas=num_replicas, rank=rank)
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+
+class GroupSampler(Sampler):
+ def __init__(self, dataset, samples_per_gpu=1):
+ assert hasattr(dataset, "flag")
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.flag = dataset.flag.astype(np.int64)
+ self.group_sizes = np.bincount(self.flag)
+ self.num_samples = 0
+ for i, size in enumerate(self.group_sizes):
+ self.num_samples += (
+ int(np.ceil(size / self.samples_per_gpu)) * self.samples_per_gpu
+ )
+
+ def __iter__(self):
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size == 0:
+ continue
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ np.random.shuffle(indice)
+ num_extra = int(
+ np.ceil(size / self.samples_per_gpu)
+ ) * self.samples_per_gpu - len(indice)
+ indice = np.concatenate([indice, indice[:num_extra]])
+ indices.append(indice)
+ indices = np.concatenate(indices)
+ indices = [
+ indices[i * self.samples_per_gpu : (i + 1) * self.samples_per_gpu]
+ for i in np.random.permutation(range(len(indices) // self.samples_per_gpu))
+ ]
+ indices = np.concatenate(indices)
+ indices = indices.astype(np.int64).tolist()
+ assert len(indices) == self.num_samples
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
+class DistributedGroupSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, samples_per_gpu=1, num_replicas=None, rank=None):
+ _rank, _num_replicas = get_dist_info()
+ if num_replicas is None:
+ num_replicas = _num_replicas
+ if rank is None:
+ rank = _rank
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+
+ assert hasattr(self.dataset, "flag")
+ self.flag = self.dataset.flag
+ self.group_sizes = np.bincount(self.flag)
+
+ self.num_samples = 0
+ for i, j in enumerate(self.group_sizes):
+ self.num_samples += (
+ int(
+ math.ceil(
+ self.group_sizes[i]
+ * 1.0
+ / self.samples_per_gpu
+ / self.num_replicas
+ )
+ )
+ * self.samples_per_gpu
+ )
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size > 0:
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ indice = indice[list(torch.randperm(int(size), generator=g))].tolist()
+ extra = int(
+ math.ceil(size * 1.0 / self.samples_per_gpu / self.num_replicas)
+ ) * self.samples_per_gpu * self.num_replicas - len(indice)
+ indice += indice[:extra]
+ indices += indice
+
+ assert len(indices) == self.total_size
+
+ indices = [
+ indices[j]
+ for i in list(
+ torch.randperm(len(indices) // self.samples_per_gpu, generator=g)
+ )
+ for j in range(i * self.samples_per_gpu, (i + 1) * self.samples_per_gpu)
+ ]
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset : offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/det3d/datasets/nuscenes/__init__.py b/det3d/datasets/nuscenes/__init__.py
new file mode 100644
index 0000000..02b035e
--- /dev/null
+++ b/det3d/datasets/nuscenes/__init__.py
@@ -0,0 +1,4 @@
+from .nuscenes import NuScenesDataset
+from .nusc_common import *
+
+__all__ = ["NuScenesDataset"]
diff --git a/det3d/datasets/nuscenes/nusc_common.py b/det3d/datasets/nuscenes/nusc_common.py
new file mode 100644
index 0000000..84e07c4
--- /dev/null
+++ b/det3d/datasets/nuscenes/nusc_common.py
@@ -0,0 +1,521 @@
+import numpy as np
+import pickle
+
+from pathlib import Path
+from functools import reduce
+from typing import List
+
+from tqdm import tqdm
+from pyquaternion import Quaternion
+
+try:
+ from nuscenes import NuScenes
+ from nuscenes.utils import splits
+ from nuscenes.utils.data_classes import Box
+ from nuscenes.eval.detection.config import config_factory
+ from nuscenes.eval.detection.evaluate import NuScenesEval
+except:
+ print("nuScenes devkit not Found!")
+
+general_to_detection = {
+ "human.pedestrian.adult": "pedestrian",
+ "human.pedestrian.child": "pedestrian",
+ "human.pedestrian.wheelchair": "ignore",
+ "human.pedestrian.stroller": "ignore",
+ "human.pedestrian.personal_mobility": "ignore",
+ "human.pedestrian.police_officer": "pedestrian",
+ "human.pedestrian.construction_worker": "pedestrian",
+ "animal": "ignore",
+ "vehicle.car": "car",
+ "vehicle.motorcycle": "motorcycle",
+ "vehicle.bicycle": "bicycle",
+ "vehicle.bus.bendy": "bus",
+ "vehicle.bus.rigid": "bus",
+ "vehicle.truck": "truck",
+ "vehicle.construction": "construction_vehicle",
+ "vehicle.emergency.ambulance": "ignore",
+ "vehicle.emergency.police": "ignore",
+ "vehicle.trailer": "trailer",
+ "movable_object.barrier": "barrier",
+ "movable_object.trafficcone": "traffic_cone",
+ "movable_object.pushable_pullable": "ignore",
+ "movable_object.debris": "ignore",
+ "static_object.bicycle_rack": "ignore",
+}
+
+cls_attr_dist = {
+ "barrier": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 0,
+ "vehicle.parked": 0,
+ "vehicle.stopped": 0,
+ },
+ "bicycle": {
+ "cycle.with_rider": 2791,
+ "cycle.without_rider": 8946,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 0,
+ "vehicle.parked": 0,
+ "vehicle.stopped": 0,
+ },
+ "bus": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 9092,
+ "vehicle.parked": 3294,
+ "vehicle.stopped": 3881,
+ },
+ "car": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 114304,
+ "vehicle.parked": 330133,
+ "vehicle.stopped": 46898,
+ },
+ "construction_vehicle": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 882,
+ "vehicle.parked": 11549,
+ "vehicle.stopped": 2102,
+ },
+ "ignore": {
+ "cycle.with_rider": 307,
+ "cycle.without_rider": 73,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 165,
+ "vehicle.parked": 400,
+ "vehicle.stopped": 102,
+ },
+ "motorcycle": {
+ "cycle.with_rider": 4233,
+ "cycle.without_rider": 8326,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 0,
+ "vehicle.parked": 0,
+ "vehicle.stopped": 0,
+ },
+ "pedestrian": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 157444,
+ "pedestrian.sitting_lying_down": 13939,
+ "pedestrian.standing": 46530,
+ "vehicle.moving": 0,
+ "vehicle.parked": 0,
+ "vehicle.stopped": 0,
+ },
+ "traffic_cone": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 0,
+ "vehicle.parked": 0,
+ "vehicle.stopped": 0,
+ },
+ "trailer": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 3421,
+ "vehicle.parked": 19224,
+ "vehicle.stopped": 1895,
+ },
+ "truck": {
+ "cycle.with_rider": 0,
+ "cycle.without_rider": 0,
+ "pedestrian.moving": 0,
+ "pedestrian.sitting_lying_down": 0,
+ "pedestrian.standing": 0,
+ "vehicle.moving": 21339,
+ "vehicle.parked": 55626,
+ "vehicle.stopped": 11097,
+ },
+}
+
+def _second_det_to_nusc_box(detection):
+ box3d = detection["box3d_lidar"].detach().cpu().numpy()
+ scores = detection["scores"].detach().cpu().numpy()
+ labels = detection["label_preds"].detach().cpu().numpy()
+ box3d[:, -1] = -box3d[:, -1] - np.pi / 2
+ box_list = []
+ for i in range(box3d.shape[0]):
+ quat = Quaternion(axis=[0, 0, 1], radians=box3d[i, -1])
+ velocity = (*box3d[i, 6:8], 0.0)
+ box = Box(
+ box3d[i, :3],
+ box3d[i, 3:6],
+ quat,
+ label=labels[i],
+ score=scores[i],
+ velocity=velocity,
+ )
+ box_list.append(box)
+ return box_list
+
+
+def _lidar_nusc_box_to_global(nusc, boxes, sample_token):
+ try:
+ s_record = nusc.get("sample", sample_token)
+ sample_data_token = s_record["data"]["LIDAR_TOP"]
+ except:
+ sample_data_token = sample_token
+
+ sd_record = nusc.get("sample_data", sample_data_token)
+ cs_record = nusc.get("calibrated_sensor", sd_record["calibrated_sensor_token"])
+ pose_record = nusc.get("ego_pose", sd_record["ego_pose_token"])
+
+ box_list = []
+ for box in boxes:
+ # Move box to ego vehicle coord system
+ box.rotate(Quaternion(cs_record["rotation"]))
+ box.translate(np.array(cs_record["translation"]))
+ # Move box to global coord system
+ box.rotate(Quaternion(pose_record["rotation"]))
+ box.translate(np.array(pose_record["translation"]))
+ box_list.append(box)
+ return box_list
+
+
+def _get_available_scenes(nusc):
+ available_scenes = []
+ print("total scene num:", len(nusc.scene))
+ for scene in nusc.scene:
+ scene_token = scene["token"]
+ scene_rec = nusc.get("scene", scene_token)
+ sample_rec = nusc.get("sample", scene_rec["first_sample_token"])
+ sd_rec = nusc.get("sample_data", sample_rec["data"]["LIDAR_TOP"])
+ has_more_frames = True
+ scene_not_exist = False
+ while has_more_frames:
+ lidar_path, boxes, _ = nusc.get_sample_data(sd_rec["token"])
+ if not Path(lidar_path).exists():
+ scene_not_exist = True
+ break
+ else:
+ break
+ if scene_not_exist:
+ continue
+ available_scenes.append(scene)
+ print("exist scene num:", len(available_scenes))
+ return available_scenes
+
+
+def get_sample_data(
+ nusc, sample_data_token: str, selected_anntokens: List[str] = None
+):
+ """
+ Returns the data path as well as all annotations related to that sample_data.
+ Note that the boxes are transformed into the current sensor's coordinate frame.
+ :param sample_data_token: Sample_data token.
+ :param selected_anntokens: If provided only return the selected annotation.
+ :return: (data_path, boxes, camera_intrinsic )
+ """
+
+ # Retrieve sensor & pose records
+ sd_record = nusc.get("sample_data", sample_data_token)
+ cs_record = nusc.get("calibrated_sensor", sd_record["calibrated_sensor_token"])
+ sensor_record = nusc.get("sensor", cs_record["sensor_token"])
+ pose_record = nusc.get("ego_pose", sd_record["ego_pose_token"])
+
+ data_path = nusc.get_sample_data_path(sample_data_token)
+
+ if sensor_record["modality"] == "camera":
+ cam_intrinsic = np.array(cs_record["camera_intrinsic"])
+ else:
+ cam_intrinsic = None
+
+ # Retrieve all sample annotations and map to sensor coordinate system.
+ if selected_anntokens is not None:
+ boxes = list(map(nusc.get_box, selected_anntokens))
+ else:
+ boxes = nusc.get_boxes(sample_data_token)
+
+ # Make list of Box objects including coord system transforms.
+ box_list = []
+ for box in boxes:
+ box.velocity = nusc.box_velocity(box.token)
+ # Move box to ego vehicle coord system
+ box.translate(-np.array(pose_record["translation"]))
+ box.rotate(Quaternion(pose_record["rotation"]).inverse)
+
+ # Move box to sensor coord system
+ box.translate(-np.array(cs_record["translation"]))
+ box.rotate(Quaternion(cs_record["rotation"]).inverse)
+
+ box_list.append(box)
+
+ return data_path, box_list, cam_intrinsic
+
+
+
+def _fill_trainval_infos(nusc, train_scenes, val_scenes, test=False, nsweeps=10, filter_zero=True):
+ from nuscenes.utils.geometry_utils import transform_matrix
+
+ train_nusc_infos = []
+ val_nusc_infos = []
+
+ ref_chan = "LIDAR_TOP" # The radar channel from which we track back n sweeps to aggregate the point cloud.
+ chan = "LIDAR_TOP" # The reference channel of the current sample_rec that the point clouds are mapped to.
+
+ for sample in tqdm(nusc.sample):
+ """ Manual save info["sweeps"] """
+ # Get reference pose and timestamp
+ # ref_chan == "LIDAR_TOP"
+ ref_sd_token = sample["data"][ref_chan]
+ ref_sd_rec = nusc.get("sample_data", ref_sd_token)
+ ref_cs_rec = nusc.get(
+ "calibrated_sensor", ref_sd_rec["calibrated_sensor_token"]
+ )
+ ref_pose_rec = nusc.get("ego_pose", ref_sd_rec["ego_pose_token"])
+ ref_time = 1e-6 * ref_sd_rec["timestamp"]
+
+ ref_lidar_path, ref_boxes, _ = get_sample_data(nusc, ref_sd_token)
+
+ ref_cam_front_token = sample["data"]["CAM_FRONT"]
+ ref_cam_path, _, ref_cam_intrinsic = nusc.get_sample_data(ref_cam_front_token)
+
+ # Homogeneous transform from ego car frame to reference frame
+ ref_from_car = transform_matrix(
+ ref_cs_rec["translation"], Quaternion(ref_cs_rec["rotation"]), inverse=True
+ )
+
+ # Homogeneous transformation matrix from global to _current_ ego car frame
+ car_from_global = transform_matrix(
+ ref_pose_rec["translation"],
+ Quaternion(ref_pose_rec["rotation"]),
+ inverse=True,
+ )
+
+ info = {
+ "lidar_path": ref_lidar_path,
+ "cam_front_path": ref_cam_path,
+ "cam_intrinsic": ref_cam_intrinsic,
+ "token": sample["token"],
+ "sweeps": [],
+ "ref_from_car": ref_from_car,
+ "car_from_global": car_from_global,
+ "timestamp": ref_time,
+ }
+
+ sample_data_token = sample["data"][chan]
+ curr_sd_rec = nusc.get("sample_data", sample_data_token)
+ sweeps = []
+ while len(sweeps) < nsweeps - 1:
+ if curr_sd_rec["prev"] == "":
+ if len(sweeps) == 0:
+ sweep = {
+ "lidar_path": ref_lidar_path,
+ "sample_data_token": curr_sd_rec["token"],
+ "transform_matrix": None,
+ "time_lag": curr_sd_rec["timestamp"] * 0,
+ # time_lag: 0,
+ }
+ sweeps.append(sweep)
+ else:
+ sweeps.append(sweeps[-1])
+ else:
+ curr_sd_rec = nusc.get("sample_data", curr_sd_rec["prev"])
+
+ # Get past pose
+ current_pose_rec = nusc.get("ego_pose", curr_sd_rec["ego_pose_token"])
+ global_from_car = transform_matrix(
+ current_pose_rec["translation"],
+ Quaternion(current_pose_rec["rotation"]),
+ inverse=False,
+ )
+
+ # Homogeneous transformation matrix from sensor coordinate frame to ego car frame.
+ current_cs_rec = nusc.get(
+ "calibrated_sensor", curr_sd_rec["calibrated_sensor_token"]
+ )
+ car_from_current = transform_matrix(
+ current_cs_rec["translation"],
+ Quaternion(current_cs_rec["rotation"]),
+ inverse=False,
+ )
+
+ tm = reduce(
+ np.dot,
+ [ref_from_car, car_from_global, global_from_car, car_from_current],
+ )
+
+ lidar_path = nusc.get_sample_data_path(curr_sd_rec["token"])
+
+ time_lag = ref_time - 1e-6 * curr_sd_rec["timestamp"]
+
+ sweep = {
+ "lidar_path": lidar_path,
+ "sample_data_token": curr_sd_rec["token"],
+ "transform_matrix": tm,
+ "global_from_car": global_from_car,
+ "car_from_current": car_from_current,
+ "time_lag": time_lag,
+ }
+ sweeps.append(sweep)
+
+ info["sweeps"] = sweeps
+
+ assert (
+ len(info["sweeps"]) == nsweeps - 1
+ ), f"sweep {curr_sd_rec['token']} only has {len(info['sweeps'])} sweeps, you should duplicate to sweep num {nsweeps-1}"
+ """ read from api """
+
+ if not test:
+ annotations = [
+ nusc.get("sample_annotation", token) for token in sample["anns"]
+ ]
+
+ mask = np.array([(anno['num_lidar_pts'] + anno['num_radar_pts'])>0 for anno in annotations], dtype=bool).reshape(-1)
+
+ locs = np.array([b.center for b in ref_boxes]).reshape(-1, 3)
+ dims = np.array([b.wlh for b in ref_boxes]).reshape(-1, 3)
+ # rots = np.array([b.orientation.yaw_pitch_roll[0] for b in ref_boxes]).reshape(-1, 1)
+ velocity = np.array([b.velocity for b in ref_boxes]).reshape(-1, 3)
+ rots = np.array([quaternion_yaw(b.orientation) for b in ref_boxes]).reshape(
+ -1, 1
+ )
+ names = np.array([b.name for b in ref_boxes])
+ tokens = np.array([b.token for b in ref_boxes])
+ gt_boxes = np.concatenate(
+ [locs, dims, velocity[:, :2], -rots - np.pi / 2], axis=1
+ )
+ # gt_boxes = np.concatenate([locs, dims, rots], axis=1)
+
+ assert len(annotations) == len(gt_boxes) == len(velocity)
+
+ if not filter_zero:
+ info["gt_boxes"] = gt_boxes
+ info["gt_boxes_velocity"] = velocity
+ info["gt_names"] = np.array([general_to_detection[name] for name in names])
+ info["gt_boxes_token"] = tokens
+ else:
+ info["gt_boxes"] = gt_boxes[mask, :]
+ info["gt_boxes_velocity"] = velocity[mask, :]
+ info["gt_names"] = np.array([general_to_detection[name] for name in names])[mask]
+ info["gt_boxes_token"] = tokens[mask]
+
+ if sample["scene_token"] in train_scenes:
+ train_nusc_infos.append(info)
+ else:
+ val_nusc_infos.append(info)
+
+ return train_nusc_infos, val_nusc_infos
+
+
+def quaternion_yaw(q: Quaternion) -> float:
+ """
+ Calculate the yaw angle from a quaternion.
+ Note that this only works for a quaternion that represents a box in lidar or global coordinate frame.
+ It does not work for a box in the camera frame.
+ :param q: Quaternion of interest.
+ :return: Yaw angle in radians.
+ """
+
+ # Project into xy plane.
+ v = np.dot(q.rotation_matrix, np.array([1, 0, 0]))
+
+ # Measure yaw using arctan.
+ yaw = np.arctan2(v[1], v[0])
+
+ return yaw
+
+
+def create_nuscenes_infos(root_path, version="v1.0-trainval", nsweeps=10, filter_zero=True):
+ nusc = NuScenes(version=version, dataroot=root_path, verbose=True)
+ available_vers = ["v1.0-trainval", "v1.0-test", "v1.0-mini"]
+ assert version in available_vers
+ if version == "v1.0-trainval":
+ train_scenes = splits.train
+ # random.shuffle(train_scenes)
+ # train_scenes = train_scenes[:int(len(train_scenes)*0.2)]
+ val_scenes = splits.val
+ elif version == "v1.0-test":
+ train_scenes = splits.test
+ val_scenes = []
+ elif version == "v1.0-mini":
+ train_scenes = splits.mini_train
+ val_scenes = splits.mini_val
+ else:
+ raise ValueError("unknown")
+ test = "test" in version
+ root_path = Path(root_path)
+ # filter exist scenes. you may only download part of dataset.
+ available_scenes = _get_available_scenes(nusc)
+ available_scene_names = [s["name"] for s in available_scenes]
+ train_scenes = list(filter(lambda x: x in available_scene_names, train_scenes))
+ val_scenes = list(filter(lambda x: x in available_scene_names, val_scenes))
+ train_scenes = set(
+ [
+ available_scenes[available_scene_names.index(s)]["token"]
+ for s in train_scenes
+ ]
+ )
+ val_scenes = set(
+ [available_scenes[available_scene_names.index(s)]["token"] for s in val_scenes]
+ )
+ if test:
+ print(f"test scene: {len(train_scenes)}")
+ else:
+ print(f"train scene: {len(train_scenes)}, val scene: {len(val_scenes)}")
+
+ train_nusc_infos, val_nusc_infos = _fill_trainval_infos(
+ nusc, train_scenes, val_scenes, test, nsweeps=nsweeps, filter_zero=filter_zero
+ )
+
+ if test:
+ print(f"test sample: {len(train_nusc_infos)}")
+ with open(
+ root_path / "infos_test_{:02d}sweeps_withvelo.pkl".format(nsweeps), "wb"
+ ) as f:
+ pickle.dump(train_nusc_infos, f)
+ else:
+ print(
+ f"train sample: {len(train_nusc_infos)}, val sample: {len(val_nusc_infos)}"
+ )
+ with open(
+ root_path / "infos_train_{:02d}sweeps_withvelo_filter_{}.pkl".format(nsweeps, filter_zero), "wb"
+ ) as f:
+ pickle.dump(train_nusc_infos, f)
+ with open(
+ root_path / "infos_val_{:02d}sweeps_withvelo_filter_{}.pkl".format(nsweeps, filter_zero), "wb"
+ ) as f:
+ pickle.dump(val_nusc_infos, f)
+
+
+def eval_main(nusc, eval_version, res_path, eval_set, output_dir):
+ # nusc = NuScenes(version=version, dataroot=str(root_path), verbose=True)
+ cfg = config_factory(eval_version)
+
+ nusc_eval = NuScenesEval(
+ nusc,
+ config=cfg,
+ result_path=res_path,
+ eval_set=eval_set,
+ output_dir=output_dir,
+ verbose=True,
+ )
+ metrics_summary = nusc_eval.main(plot_examples=10,)
diff --git a/det3d/datasets/nuscenes/nuscenes.py b/det3d/datasets/nuscenes/nuscenes.py
new file mode 100644
index 0000000..bad91b6
--- /dev/null
+++ b/det3d/datasets/nuscenes/nuscenes.py
@@ -0,0 +1,327 @@
+import sys
+import pickle
+import json
+import random
+import operator
+import numpy as np
+
+from functools import reduce
+from pathlib import Path
+from copy import deepcopy
+
+try:
+ from nuscenes.nuscenes import NuScenes
+ from nuscenes.eval.detection.config import config_factory
+except:
+ print("nuScenes devkit not found!")
+
+from det3d.datasets.custom import PointCloudDataset
+from det3d.datasets.nuscenes.nusc_common import (
+ general_to_detection,
+ cls_attr_dist,
+ _second_det_to_nusc_box,
+ _lidar_nusc_box_to_global,
+ eval_main
+)
+from det3d.datasets.registry import DATASETS
+
+
+@DATASETS.register_module
+class NuScenesDataset(PointCloudDataset):
+ NumPointFeatures = 5 # x, y, z, intensity, ring_index
+
+ def __init__(
+ self,
+ info_path,
+ root_path,
+ nsweeps=0, # here set to zero to catch unset nsweep
+ cfg=None,
+ pipeline=None,
+ class_names=None,
+ test_mode=False,
+ version="v1.0-trainval",
+ **kwargs,
+ ):
+ super(NuScenesDataset, self).__init__(
+ root_path, info_path, pipeline, test_mode=test_mode, class_names=class_names
+ )
+
+ self.nsweeps = nsweeps
+ assert self.nsweeps > 0, "At least input one sweep please!"
+ print(self.nsweeps)
+
+ self._info_path = info_path
+ self._class_names = class_names
+
+ if not hasattr(self, "_nusc_infos"):
+ self.load_infos(self._info_path)
+
+ self._num_point_features = NuScenesDataset.NumPointFeatures
+ self._name_mapping = general_to_detection
+
+ self.painted = kwargs.get('painted', False)
+ if self.painted:
+ self._num_point_features += 10
+
+ self.version = version
+ self.eval_version = "detection_cvpr_2019"
+
+ def reset(self):
+ self.logger.info(f"re-sample {self.frac} frames from full set")
+ random.shuffle(self._nusc_infos_all)
+ self._nusc_infos = self._nusc_infos_all[: self.frac]
+
+ def load_infos(self, info_path):
+
+ with open(self._info_path, "rb") as f:
+ _nusc_infos_all = pickle.load(f)
+
+ if not self.test_mode: # if training
+ self.frac = int(len(_nusc_infos_all) * 0.25)
+
+ _cls_infos = {name: [] for name in self._class_names}
+ for info in _nusc_infos_all:
+ for name in set(info["gt_names"]):
+ if name in self._class_names:
+ _cls_infos[name].append(info)
+
+ duplicated_samples = sum([len(v) for _, v in _cls_infos.items()])
+ _cls_dist = {k: len(v) / max(duplicated_samples, 1) for k, v in _cls_infos.items()}
+
+ self._nusc_infos = []
+
+ frac = 1.0 / len(self._class_names)
+ ratios = [frac / v for v in _cls_dist.values()]
+
+ for cls_infos, ratio in zip(list(_cls_infos.values()), ratios):
+ self._nusc_infos += np.random.choice(
+ cls_infos, int(len(cls_infos) * ratio)
+ ).tolist()
+
+ _cls_infos = {name: [] for name in self._class_names}
+ for info in self._nusc_infos:
+ for name in set(info["gt_names"]):
+ if name in self._class_names:
+ _cls_infos[name].append(info)
+
+ _cls_dist = {
+ k: len(v) / len(self._nusc_infos) for k, v in _cls_infos.items()
+ }
+ else:
+ if isinstance(_nusc_infos_all, dict):
+ self._nusc_infos = []
+ for v in _nusc_infos_all.values():
+ self._nusc_infos.extend(v)
+ else:
+ self._nusc_infos = _nusc_infos_all
+
+ def __len__(self):
+
+ if not hasattr(self, "_nusc_infos"):
+ self.load_infos(self._info_path)
+
+ return len(self._nusc_infos)
+
+ @property
+ def ground_truth_annotations(self):
+ if "gt_boxes" not in self._nusc_infos[0]:
+ return None
+ cls_range_map = config_factory(self.eval_version).serialize()['class_range']
+ gt_annos = []
+ for info in self._nusc_infos:
+ gt_names = np.array(info["gt_names"])
+ gt_boxes = info["gt_boxes"]
+ mask = np.array([n != "ignore" for n in gt_names], dtype=np.bool_)
+ gt_names = gt_names[mask]
+ gt_boxes = gt_boxes[mask]
+ # det_range = np.array([cls_range_map[n] for n in gt_names_mapped])
+ det_range = np.array([cls_range_map[n] for n in gt_names])
+ det_range = det_range[..., np.newaxis] @ np.array([[-1, -1, 1, 1]])
+ mask = (gt_boxes[:, :2] >= det_range[:, :2]).all(1)
+ mask &= (gt_boxes[:, :2] <= det_range[:, 2:]).all(1)
+ N = int(np.sum(mask))
+ gt_annos.append(
+ {
+ "bbox": np.tile(np.array([[0, 0, 50, 50]]), [N, 1]),
+ "alpha": np.full(N, -10),
+ "occluded": np.zeros(N),
+ "truncated": np.zeros(N),
+ "name": gt_names[mask],
+ "location": gt_boxes[mask][:, :3],
+ "dimensions": gt_boxes[mask][:, 3:6],
+ "rotation_y": gt_boxes[mask][:, 6],
+ "token": info["token"],
+ }
+ )
+ return gt_annos
+
+ def get_sensor_data(self, idx):
+
+ info = self._nusc_infos[idx]
+
+ res = {
+ "lidar": {
+ "type": "lidar",
+ "points": None,
+ "nsweeps": self.nsweeps,
+ # "ground_plane": -gp[-1] if with_gp else None,
+ "annotations": None,
+ },
+ "metadata": {
+ "image_prefix": self._root_path,
+ "num_point_features": self._num_point_features,
+ "token": info["token"],
+ },
+ "calib": None,
+ "cam": {},
+ "mode": "val" if self.test_mode else "train",
+ "painted": self.painted
+ }
+
+ data, _ = self.pipeline(res, info)
+
+ return data
+
+ def __getitem__(self, idx):
+ return self.get_sensor_data(idx)
+
+ def evaluation(self, detections, output_dir=None, testset=False):
+ version = self.version
+ eval_set_map = {
+ "v1.0-mini": "mini_val",
+ "v1.0-trainval": "val",
+ "v1.0-test": "test",
+ }
+
+ if not testset:
+ dets = []
+ gt_annos = self.ground_truth_annotations
+ assert gt_annos is not None
+
+ miss = 0
+ for gt in gt_annos:
+ try:
+ dets.append(detections[gt["token"]])
+ except Exception:
+ miss += 1
+
+ assert miss == 0
+ else:
+ dets = [v for _, v in detections.items()]
+ assert len(detections) == 6008
+
+ nusc_annos = {
+ "results": {},
+ "meta": None,
+ }
+
+ nusc = NuScenes(version=version, dataroot=str(self._root_path), verbose=True)
+
+ mapped_class_names = []
+ for n in self._class_names:
+ if n in self._name_mapping:
+ mapped_class_names.append(self._name_mapping[n])
+ else:
+ mapped_class_names.append(n)
+
+ for det in dets:
+ annos = []
+ boxes = _second_det_to_nusc_box(det)
+ boxes = _lidar_nusc_box_to_global(nusc, boxes, det["metadata"]["token"])
+ for i, box in enumerate(boxes):
+ name = mapped_class_names[box.label]
+ if np.sqrt(box.velocity[0] ** 2 + box.velocity[1] ** 2) > 0.2:
+ if name in [
+ "car",
+ "construction_vehicle",
+ "bus",
+ "truck",
+ "trailer",
+ ]:
+ attr = "vehicle.moving"
+ elif name in ["bicycle", "motorcycle"]:
+ attr = "cycle.with_rider"
+ else:
+ attr = None
+ else:
+ if name in ["pedestrian"]:
+ attr = "pedestrian.standing"
+ elif name in ["bus"]:
+ attr = "vehicle.stopped"
+ else:
+ attr = None
+
+ nusc_anno = {
+ "sample_token": det["metadata"]["token"],
+ "translation": box.center.tolist(),
+ "size": box.wlh.tolist(),
+ "rotation": box.orientation.elements.tolist(),
+ "velocity": box.velocity[:2].tolist(),
+ "detection_name": name,
+ "detection_score": box.score,
+ "attribute_name": attr
+ if attr is not None
+ else max(cls_attr_dist[name].items(), key=operator.itemgetter(1))[
+ 0
+ ],
+ }
+ annos.append(nusc_anno)
+ nusc_annos["results"].update({det["metadata"]["token"]: annos})
+
+ nusc_annos["meta"] = {
+ "use_camera": False,
+ "use_lidar": True,
+ "use_radar": False,
+ "use_map": False,
+ "use_external": False,
+ }
+
+ name = self._info_path.split("/")[-1].split(".")[0]
+ res_path = str(Path(output_dir) / Path(name + ".json"))
+ with open(res_path, "w") as f:
+ json.dump(nusc_annos, f)
+
+ print(f"Finish generate predictions for testset, save to {res_path}")
+
+ if not testset:
+ eval_main(
+ nusc,
+ self.eval_version,
+ res_path,
+ eval_set_map[self.version],
+ output_dir,
+ )
+
+ with open(Path(output_dir) / "metrics_summary.json", "r") as f:
+ metrics = json.load(f)
+
+ detail = {}
+ result = f"Nusc {version} Evaluation\n"
+ for name in mapped_class_names:
+ detail[name] = {}
+ for k, v in metrics["label_aps"][name].items():
+ detail[name][f"dist@{k}"] = v
+ threshs = ", ".join(list(metrics["label_aps"][name].keys()))
+ scores = list(metrics["label_aps"][name].values())
+ mean = sum(scores) / len(scores)
+ scores = ", ".join([f"{s * 100:.2f}" for s in scores])
+ result += f"{name} Nusc dist AP@{threshs}\n"
+ result += scores
+ result += f" mean AP: {mean}"
+ result += "\n"
+ res_nusc = {
+ "results": {"nusc": result},
+ "detail": {"nusc": detail},
+ }
+ else:
+ res_nusc = None
+
+ if res_nusc is not None:
+ res = {
+ "results": {"nusc": res_nusc["results"]["nusc"],},
+ "detail": {"eval.nusc": res_nusc["detail"]["nusc"],},
+ }
+ else:
+ res = None
+
+ return res, None
diff --git a/det3d/datasets/pipelines/__init__.py b/det3d/datasets/pipelines/__init__.py
new file mode 100644
index 0000000..c6c233b
--- /dev/null
+++ b/det3d/datasets/pipelines/__init__.py
@@ -0,0 +1,26 @@
+from .compose import Compose
+from .formating import Reformat
+
+# from .loading import LoadAnnotations, LoadImageFromFile, LoadProposals
+from .loading import *
+from .test_aug import DoubleFlip
+from .preprocess import Preprocess, Voxelization
+from .preprocess_multiframe import Preprocess_multiframe
+
+__all__ = [
+ "Compose",
+ "to_tensor",
+ "ToTensor",
+ "ImageToTensor",
+ "ToDataContainer",
+ "Transpose",
+ "Collect",
+ "LoadImageAnnotations",
+ "LoadImageFromFile",
+ "LoadProposals",
+ "PhotoMetricDistortion",
+ "Preprocess",
+ "Voxelization",
+ "AssignTarget",
+ "AssignLabel"
+]
diff --git a/det3d/datasets/pipelines/compose.py b/det3d/datasets/pipelines/compose.py
new file mode 100644
index 0000000..f9856ad
--- /dev/null
+++ b/det3d/datasets/pipelines/compose.py
@@ -0,0 +1,37 @@
+import collections
+
+from det3d.utils import build_from_cfg
+from ..registry import PIPELINES
+
+
+@PIPELINES.register_module
+class Compose(object):
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ if transform['type'] == 'Empty':
+ continue
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError("transform must be callable or a dict")
+
+ def __call__(self, res, info):
+ for t in self.transforms:
+ res, info = t(res, info)
+ if res is None:
+ return None
+ return res, info
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += "\n"
+ format_string += " {0}".format(t)
+ format_string += "\n)"
+ return format_string
+
diff --git a/det3d/datasets/pipelines/formating.py b/det3d/datasets/pipelines/formating.py
new file mode 100644
index 0000000..94f7cdc
--- /dev/null
+++ b/det3d/datasets/pipelines/formating.py
@@ -0,0 +1,105 @@
+from det3d import torchie
+import numpy as np
+import torch
+
+from ..registry import PIPELINES
+
+
+class DataBundle(object):
+ def __init__(self, data):
+ self.data = data
+
+
+@PIPELINES.register_module
+class Reformat(object):
+ def __init__(self, **kwargs):
+ double_flip = kwargs.get('double_flip', False)
+ self.double_flip = double_flip
+
+ def __call__(self, res, info):
+ meta = res["metadata"]
+
+ # voxels = res["lidar"]["voxels"]
+
+ data_bundle = dict(
+ metadata=meta,
+ points_num=res["lidar"]["points_num"]
+ # points=points,
+ # voxels=voxels["voxels"],
+ # shape=voxels["shape"],
+ # num_points=voxels["num_points"],
+ # num_voxels=voxels["num_voxels"],
+ # coordinates=voxels["coordinates"]
+ )
+
+ if 'multi_points' in res["lidar"]:
+ multi_points = res["lidar"]["multi_points"]
+ data_bundle.update(multi_points=multi_points)
+ data_bundle.update(times=res["lidar"]["times"])
+ else:
+ points = res["lidar"]["points"]
+ if points is not None:
+ data_bundle.update(points=points)
+
+ if 'voxels' in res["lidar"]:
+ voxels = res["lidar"]["voxels"]
+
+ data_bundle.update(
+ voxels=voxels["voxels"],
+ shape=voxels["shape"],
+ num_points=voxels["num_points"],
+ num_voxels=voxels["num_voxels"],
+ coordinates=voxels["coordinates"],
+ )
+
+ if res["mode"] == "train":
+ data_bundle.update(res["lidar"]["targets"])
+ elif res["mode"] == "val":
+ data_bundle.update(dict(metadata=meta, ))
+
+ if self.double_flip:
+ # y axis
+ yflip_points = res["lidar"]["yflip_points"]
+ yflip_voxels = res["lidar"]["yflip_voxels"]
+ yflip_data_bundle = dict(
+ metadata=meta,
+ points=yflip_points,
+ voxels=yflip_voxels["voxels"],
+ shape=yflip_voxels["shape"],
+ num_points=yflip_voxels["num_points"],
+ num_voxels=yflip_voxels["num_voxels"],
+ coordinates=yflip_voxels["coordinates"],
+ )
+
+ # x axis
+ xflip_points = res["lidar"]["xflip_points"]
+ xflip_voxels = res["lidar"]["xflip_voxels"]
+ xflip_data_bundle = dict(
+ metadata=meta,
+ points=xflip_points,
+ voxels=xflip_voxels["voxels"],
+ shape=xflip_voxels["shape"],
+ num_points=xflip_voxels["num_points"],
+ num_voxels=xflip_voxels["num_voxels"],
+ coordinates=xflip_voxels["coordinates"],
+ )
+ # double axis flip
+ double_flip_points = res["lidar"]["double_flip_points"]
+ double_flip_voxels = res["lidar"]["double_flip_voxels"]
+ double_flip_data_bundle = dict(
+ metadata=meta,
+ points=double_flip_points,
+ voxels=double_flip_voxels["voxels"],
+ shape=double_flip_voxels["shape"],
+ num_points=double_flip_voxels["num_points"],
+ num_voxels=double_flip_voxels["num_voxels"],
+ coordinates=double_flip_voxels["coordinates"],
+ )
+
+ return [data_bundle, yflip_data_bundle, xflip_data_bundle, double_flip_data_bundle], info
+
+
+ return data_bundle, info
+
+
+
diff --git a/det3d/datasets/pipelines/loading.py b/det3d/datasets/pipelines/loading.py
new file mode 100644
index 0000000..2704638
--- /dev/null
+++ b/det3d/datasets/pipelines/loading.py
@@ -0,0 +1,245 @@
+import os.path as osp
+import warnings
+import numpy as np
+from functools import reduce
+
+# import pycocotools.mask as maskUtils
+
+from pathlib import Path
+from copy import deepcopy
+from det3d import torchie
+from det3d.core import box_np_ops
+import pickle
+import os
+from ..registry import PIPELINES
+
+def _dict_select(dict_, inds):
+ for k, v in dict_.items():
+ if isinstance(v, dict):
+ _dict_select(v, inds)
+ else:
+ dict_[k] = v[inds]
+
+def read_file(path, tries=2, num_point_feature=4, painted=False):
+ if painted:
+ dir_path = os.path.join(*path.split('/')[:-2], 'painted_'+path.split('/')[-2])
+ painted_path = os.path.join(dir_path, path.split('/')[-1]+'.npy')
+ points = np.load(painted_path)
+ points = points[:, [0, 1, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]] # remove ring_index from features
+ else:
+ points = np.fromfile(path, dtype=np.float32).reshape(-1, 5)[:, :num_point_feature]
+
+ return points
+
+
+def remove_close(points, radius: float) -> None:
+ """
+ Removes point too close within a certain radius from origin.
+ :param radius: Radius below which points are removed.
+ """
+ x_filt = np.abs(points[0, :]) < radius
+ y_filt = np.abs(points[1, :]) < radius
+ not_close = np.logical_not(np.logical_and(x_filt, y_filt))
+ points = points[:, not_close]
+ return points
+
+
+def read_sweep(sweep, painted=False):
+ min_distance = 1.0
+ points_sweep = read_file(str(sweep["lidar_path"]), painted=painted).T
+ points_sweep = remove_close(points_sweep, min_distance)
+
+ nbr_points = points_sweep.shape[1]
+ if sweep["transform_matrix"] is not None:
+ points_sweep[:3, :] = sweep["transform_matrix"].dot(
+ np.vstack((points_sweep[:3, :], np.ones(nbr_points)))
+ )[:3, :]
+ curr_times = sweep["time_lag"] * np.ones((1, points_sweep.shape[1]))
+
+ return points_sweep.T, curr_times.T
+
+def read_single_waymo(obj):
+ points_xyz = obj["lidars"]["points_xyz"]
+ points_feature = obj["lidars"]["points_feature"]
+
+ # normalize intensity
+ points_feature[:, 0] = np.tanh(points_feature[:, 0])
+
+ points = np.concatenate([points_xyz, points_feature], axis=-1)
+
+ return points
+
+def read_single_waymo_sweep(sweep):
+ obj = get_obj(sweep['path'])
+
+ points_xyz = obj["lidars"]["points_xyz"]
+ points_feature = obj["lidars"]["points_feature"]
+
+ # normalize intensity
+ points_feature[:, 0] = np.tanh(points_feature[:, 0])
+ points_sweep = np.concatenate([points_xyz, points_feature], axis=-1).T # 5 x N
+
+ nbr_points = points_sweep.shape[1]
+
+ if sweep["transform_matrix"] is not None:
+ points_sweep[:3, :] = sweep["transform_matrix"].dot(
+ np.vstack((points_sweep[:3, :], np.ones(nbr_points)))
+ )[:3, :]
+
+ curr_times = sweep["time_lag"] * np.ones((1, points_sweep.shape[1]))
+
+ return points_sweep.T, curr_times.T
+
+
+def get_obj(path):
+ with open(path, 'rb') as f:
+ obj = pickle.load(f)
+ return obj
+
+
+@PIPELINES.register_module
+class LoadPointCloudFromFile(object):
+ def __init__(self, dataset="KittiDataset", **kwargs):
+ self.type = dataset
+ self.random_select = kwargs.get("random_select", False)
+ self.npoints = kwargs.get("npoints", 16834)
+ self.combine_frames = kwargs.get("combine", 1)
+
+ def __call__(self, res, info):
+
+ res["type"] = self.type
+
+ if self.type == "NuScenesDataset":
+
+ nsweeps = res["lidar"]["nsweeps"]
+
+ lidar_path = Path(info["lidar_path"])
+ points = read_file(str(lidar_path), painted=res["painted"])
+
+ sweep_points_list = [points]
+ sweep_times_list = [np.zeros((points.shape[0], 1))]
+
+ assert (nsweeps - 1) == len(
+ info["sweeps"]
+ ), "nsweeps {} should equal to list length {}.".format(
+ nsweeps, len(info["sweeps"])
+ )
+
+ for i in np.random.choice(len(info["sweeps"]), nsweeps - 1, replace=False):
+ sweep = info["sweeps"][i]
+ points_sweep, times_sweep = read_sweep(sweep, painted=res["painted"])
+ sweep_points_list.append(points_sweep)
+ sweep_times_list.append(times_sweep)
+
+ points = np.concatenate(sweep_points_list, axis=0)
+ times = np.concatenate(sweep_times_list, axis=0).astype(points.dtype)
+
+ res["lidar"]["points"] = points
+ res["lidar"]["times"] = times
+ res["lidar"]["combined"] = np.hstack([points, times])
+
+ elif self.type == "WaymoDataset":
+ path = info['path']
+ nsweeps = res["lidar"]["nsweeps"]
+ obj = get_obj(path)
+ points = read_single_waymo(obj)
+ res["lidar"]["points"] = points
+ res["lidar"]["points_num"] = points.shape[0]
+
+ if nsweeps > 1:
+ sweep_points_list = [points]
+ sweep_times_list = [np.zeros((points.shape[0], 1))]
+
+ assert (nsweeps - 1) == len(
+ info["sweeps"]
+ ), "nsweeps {} should be equal to the list length {}.".format(
+ nsweeps, len(info["sweeps"])
+ )
+
+ for i in range(nsweeps - 1):
+ sweep = info["sweeps"][i]
+ points_sweep, times_sweep = read_single_waymo_sweep(sweep)
+ sweep_points_list.append(points_sweep)
+ sweep_times_list.append(times_sweep)
+
+ points = np.concatenate(sweep_points_list, axis=0)
+ times = np.concatenate(sweep_times_list, axis=0).astype(points.dtype)
+
+ res["lidar"]["points"] = points
+ res["lidar"]["times"] = times
+ res["lidar"]["combined"] = np.hstack([points, times])
+ elif self.type == "WaymoDataset_multi_frame":
+ path = info['path']
+ nsweeps = res["lidar"]["nsweeps"]
+ obj = get_obj(path)
+ points = read_single_waymo(obj)
+ res["lidar"]["points"] = points
+ res["lidar"]["points_num"] = points.shape[0]
+ combine = self.combine_frames
+
+ c_frame = nsweeps//combine
+ if c_frame > 0:
+ sweep_points_list = []
+ sweep_times_list = []
+ sweep_combined_list = []
+
+ combine_points_list = [points]
+ combine_times_list = [np.zeros((points.shape[0], 1))]
+ for j in range(combine-1):
+ sweep = info["sweeps"][j]
+ points_sweep, times_sweep = read_single_waymo_sweep(sweep)
+ combine_points_list.append(points_sweep)
+ combine_times_list.append(times_sweep)
+
+ sweep_points_list.append(np.concatenate(combine_points_list, axis=0))
+ sweep_times_list.append(np.concatenate(combine_times_list, axis=0).astype(points.dtype))
+ sweep_combined_list.append(np.hstack([sweep_points_list[-1], sweep_times_list[-1]]))
+
+ for i in range(c_frame - 1):
+ combine_points_list = []
+ combine_times_list = []
+ for j in range(combine):
+ sweep = info["sweeps"][(i+1)*combine+j-1]
+ points_sweep, times_sweep = read_single_waymo_sweep(sweep)
+ combine_points_list.append(points_sweep)
+ combine_times_list.append(times_sweep)
+
+ sweep_points_list.append(np.concatenate(combine_points_list, axis=0))
+ sweep_times_list.append(np.concatenate(combine_times_list, axis=0).astype(points.dtype))
+ sweep_combined_list.append(np.hstack([sweep_points_list[-1], sweep_times_list[-1]]))
+
+ res["lidar"]["points"] = sweep_points_list
+ res["lidar"]["times"] = sweep_times_list
+ res["lidar"]["combined"] = sweep_combined_list
+
+ else:
+ raise NotImplementedError
+
+ return res, info
+
+
+@PIPELINES.register_module
+class LoadPointCloudAnnotations(object):
+ def __init__(self, with_bbox=True, **kwargs):
+ pass
+
+ def __call__(self, res, info):
+
+ if res["type"] in ["NuScenesDataset"] and "gt_boxes" in info:
+ gt_boxes = info["gt_boxes"].astype(np.float32)
+ gt_boxes[np.isnan(gt_boxes)] = 0
+ res["lidar"]["annotations"] = {
+ "boxes": gt_boxes,
+ "names": info["gt_names"],
+ "tokens": info["gt_boxes_token"],
+ "velocities": info["gt_boxes_velocity"].astype(np.float32),
+ }
+ elif res["type"] in ['WaymoDataset','WaymoDataset_multi_frame'] and "gt_boxes" in info:
+ res["lidar"]["annotations"] = {
+ "boxes": info["gt_boxes"].astype(np.float32),
+ "names": info["gt_names"],
+ }
+ else:
+ pass
+
+ return res, info
diff --git a/det3d/datasets/pipelines/preprocess.py b/det3d/datasets/pipelines/preprocess.py
new file mode 100644
index 0000000..23247d6
--- /dev/null
+++ b/det3d/datasets/pipelines/preprocess.py
@@ -0,0 +1,495 @@
+import numpy as np
+
+from det3d.core.bbox import box_np_ops
+from det3d.core.sampler import preprocess as prep
+from det3d.builder import build_dbsampler
+
+from det3d.core.input.voxel_generator import VoxelGenerator
+from det3d.core.utils.center_utils import (
+ draw_umich_gaussian, gaussian_radius
+)
+from ..registry import PIPELINES
+
+
+def _dict_select(dict_, inds):
+ for k, v in dict_.items():
+ if isinstance(v, dict):
+ _dict_select(v, inds)
+ else:
+ dict_[k] = v[inds]
+
+
+def drop_arrays_by_name(gt_names, used_classes):
+ inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
+ inds = np.array(inds, dtype=np.int64)
+ return inds
+
+@PIPELINES.register_module
+class Preprocess(object):
+ def __init__(self, cfg=None, **kwargs):
+ self.shuffle_points = cfg.shuffle_points
+ self.min_points_in_gt = cfg.get("min_points_in_gt", -1)
+
+ self.mode = cfg.mode
+ if self.mode == "train":
+ self.global_rotation_noise = cfg.global_rot_noise
+ self.global_scaling_noise = cfg.global_scale_noise
+ self.global_translate_noise = cfg.get('global_translate_noise', 0)
+ self.class_names = cfg.class_names
+ if cfg.db_sampler != None:
+ self.db_sampler = build_dbsampler(cfg.db_sampler)
+ else:
+ self.db_sampler = None
+
+ self.npoints = cfg.get("npoints", -1)
+
+ self.no_augmentation = cfg.get('no_augmentation', False)
+
+ def __call__(self, res, info):
+
+ res["mode"] = self.mode
+
+ if res["type"] in ["WaymoDataset"]:
+ if "combined" in res["lidar"]:
+ points = res["lidar"]["combined"]
+ else:
+ points = res["lidar"]["points"]
+ elif res["type"] in ["NuScenesDataset"]:
+ points = res["lidar"]["combined"]
+ else:
+ raise NotImplementedError
+
+ if self.mode == "train":
+ anno_dict = res["lidar"]["annotations"]
+
+ gt_dict = {
+ "gt_boxes": anno_dict["boxes"],
+ "gt_names": np.array(anno_dict["names"]).reshape(-1),
+ }
+
+ if self.mode == "train" and not self.no_augmentation:
+ selected = drop_arrays_by_name(
+ gt_dict["gt_names"], ["DontCare", "ignore", "UNKNOWN"]
+ )
+
+ _dict_select(gt_dict, selected)
+
+ if self.min_points_in_gt > 0:
+ point_counts = box_np_ops.points_count_rbbox(
+ points, gt_dict["gt_boxes"]
+ )
+ mask = point_counts >= min_points_in_gt
+ _dict_select(gt_dict, mask)
+
+ gt_boxes_mask = np.array(
+ [n in self.class_names for n in gt_dict["gt_names"]], dtype=np.bool_
+ )
+
+ if self.db_sampler:
+ sampled_dict = self.db_sampler.sample_all(
+ res["metadata"]["image_prefix"],
+ gt_dict["gt_boxes"],
+ gt_dict["gt_names"],
+ res["metadata"]["num_point_features"],
+ False,
+ gt_group_ids=None,
+ calib=None,
+ road_planes=None
+ )
+
+ if sampled_dict is not None:
+ sampled_gt_names = sampled_dict["gt_names"]
+ sampled_gt_boxes = sampled_dict["gt_boxes"]
+ sampled_points = sampled_dict["points"]
+ sampled_gt_masks = sampled_dict["gt_masks"]
+ gt_dict["gt_names"] = np.concatenate(
+ [gt_dict["gt_names"], sampled_gt_names], axis=0
+ )
+ gt_dict["gt_boxes"] = np.concatenate(
+ [gt_dict["gt_boxes"], sampled_gt_boxes]
+ )
+ gt_boxes_mask = np.concatenate(
+ [gt_boxes_mask, sampled_gt_masks], axis=0
+ )
+
+
+ points = np.concatenate([sampled_points, points], axis=0)
+
+ _dict_select(gt_dict, gt_boxes_mask)
+
+ gt_classes = np.array(
+ [self.class_names.index(n) + 1 for n in gt_dict["gt_names"]],
+ dtype=np.int32,
+ )
+ gt_dict["gt_classes"] = gt_classes
+
+ gt_dict["gt_boxes"], points = prep.random_flip_both(gt_dict["gt_boxes"], points)
+
+ gt_dict["gt_boxes"], points = prep.global_rotation(
+ gt_dict["gt_boxes"], points, rotation=self.global_rotation_noise
+ )
+ gt_dict["gt_boxes"], points = prep.global_scaling_v2(
+ gt_dict["gt_boxes"], points, *self.global_scaling_noise
+ )
+ gt_dict["gt_boxes"], points = prep.global_translate_v2(
+ gt_dict["gt_boxes"], points, noise_translate=self.global_translate_noise
+ )
+ elif self.no_augmentation:
+ gt_boxes_mask = np.array(
+ [n in self.class_names for n in gt_dict["gt_names"]], dtype=np.bool_
+ )
+ _dict_select(gt_dict, gt_boxes_mask)
+
+ gt_classes = np.array(
+ [self.class_names.index(n) + 1 for n in gt_dict["gt_names"]],
+ dtype=np.int32,
+ )
+ gt_dict["gt_classes"] = gt_classes
+
+
+ if self.shuffle_points:
+ np.random.shuffle(points)
+
+ res["lidar"]["points"] = points
+
+ if self.mode == "train":
+ res["lidar"]["annotations"] = gt_dict
+
+ return res, info
+
+
+@PIPELINES.register_module
+class Voxelization(object):
+ def __init__(self, **kwargs):
+ cfg = kwargs.get("cfg", None)
+ self.range = cfg.range
+ self.voxel_size = cfg.voxel_size
+ self.max_points_in_voxel = cfg.max_points_in_voxel
+ self.max_voxel_num = [cfg.max_voxel_num, cfg.max_voxel_num] if isinstance(cfg.max_voxel_num, int) else cfg.max_voxel_num
+
+ self.double_flip = cfg.get('double_flip', False)
+
+ self.voxel_generator = VoxelGenerator(
+ voxel_size=self.voxel_size,
+ point_cloud_range=self.range,
+ max_num_points=self.max_points_in_voxel,
+ max_voxels=self.max_voxel_num[0],
+ )
+
+ def __call__(self, res, info):
+ voxel_size = self.voxel_generator.voxel_size
+ pc_range = self.voxel_generator.point_cloud_range
+ grid_size = self.voxel_generator.grid_size
+
+ if res["mode"] == "train":
+ gt_dict = res["lidar"]["annotations"]
+ bv_range = pc_range[[0, 1, 3, 4]]
+ mask = prep.filter_gt_box_outside_range(gt_dict["gt_boxes"], bv_range)
+ _dict_select(gt_dict, mask)
+
+ res["lidar"]["annotations"] = gt_dict
+ max_voxels = self.max_voxel_num[0]
+ else:
+ max_voxels = self.max_voxel_num[1]
+ # max_voxels = self.max_voxel_num[1]
+
+ voxels, coordinates, num_points = self.voxel_generator.generate(
+ res["lidar"]["points"], max_voxels=max_voxels
+ )
+ num_voxels = np.array([voxels.shape[0]], dtype=np.int64)
+
+ res["lidar"]["voxels"] = dict(
+ voxels=voxels,
+ coordinates=coordinates,
+ num_points=num_points,
+ num_voxels=num_voxels,
+ shape=grid_size,
+ range=pc_range,
+ size=voxel_size
+ )
+
+ double_flip = self.double_flip and (res["mode"] != 'train')
+
+ if double_flip:
+ flip_voxels, flip_coordinates, flip_num_points = self.voxel_generator.generate(
+ res["lidar"]["yflip_points"]
+ )
+ flip_num_voxels = np.array([flip_voxels.shape[0]], dtype=np.int64)
+
+ res["lidar"]["yflip_voxels"] = dict(
+ voxels=flip_voxels,
+ coordinates=flip_coordinates,
+ num_points=flip_num_points,
+ num_voxels=flip_num_voxels,
+ shape=grid_size,
+ range=pc_range,
+ size=voxel_size
+ )
+
+ flip_voxels, flip_coordinates, flip_num_points = self.voxel_generator.generate(
+ res["lidar"]["xflip_points"]
+ )
+ flip_num_voxels = np.array([flip_voxels.shape[0]], dtype=np.int64)
+
+ res["lidar"]["xflip_voxels"] = dict(
+ voxels=flip_voxels,
+ coordinates=flip_coordinates,
+ num_points=flip_num_points,
+ num_voxels=flip_num_voxels,
+ shape=grid_size,
+ range=pc_range,
+ size=voxel_size
+ )
+
+ flip_voxels, flip_coordinates, flip_num_points = self.voxel_generator.generate(
+ res["lidar"]["double_flip_points"]
+ )
+ flip_num_voxels = np.array([flip_voxels.shape[0]], dtype=np.int64)
+
+ res["lidar"]["double_flip_voxels"] = dict(
+ voxels=flip_voxels,
+ coordinates=flip_coordinates,
+ num_points=flip_num_points,
+ num_voxels=flip_num_voxels,
+ shape=grid_size,
+ range=pc_range,
+ size=voxel_size
+ )
+
+ return res, info
+
+def flatten(box):
+ return np.concatenate(box, axis=0)
+
+def merge_multi_group_label(gt_classes, num_classes_by_task):
+ num_task = len(gt_classes)
+ flag = 0
+
+ for i in range(num_task):
+ gt_classes[i] += flag
+ flag += num_classes_by_task[i]
+
+ return flatten(gt_classes)
+
+
+@PIPELINES.register_module
+class AssignLabel(object):
+ def __init__(self, **kwargs):
+ """Return CenterNet training labels like heatmap, height, offset"""
+ assigner_cfg = kwargs["cfg"]
+ self.out_size_factor = assigner_cfg.out_size_factor
+ self.tasks = assigner_cfg.target_assigner.tasks
+ self.gaussian_overlap = assigner_cfg.gaussian_overlap
+ self._max_objs = assigner_cfg.max_objs
+ self._min_radius = assigner_cfg.min_radius
+ self.corner_prediction = assigner_cfg.get('corner_prediction', False)
+ self.gt_kernel_size = assigner_cfg.get('gt_kernel_size', 1)
+ print('use gt label assigning kernel size ', self.gt_kernel_size)
+ self.cfg = assigner_cfg
+
+ def __call__(self, res, info):
+ max_objs = self._max_objs
+ gt_kernel_size = self.gt_kernel_size
+ window_size = gt_kernel_size**2
+ class_names_by_task = [t.class_names for t in self.tasks]
+ num_classes_by_task = [t.num_class for t in self.tasks]
+
+ example = {}
+
+ if res["mode"] == "train":
+ if 'pc_range' in self.cfg:
+ pc_range = np.array(self.cfg['pc_range'], dtype=np.float32)
+ voxel_size = np.array(self.cfg['voxel_size'], dtype=np.float32)
+ grid_size = (pc_range[3:] - pc_range[:3]) / voxel_size
+ grid_size = np.round(grid_size).astype(np.int64)
+ elif 'voxels' in res['lidar']:
+ # Calculate output featuremap size
+ grid_size = res["lidar"]["voxels"]["shape"]
+ pc_range = res["lidar"]["voxels"]["range"]
+ voxel_size = res["lidar"]["voxels"]["size"]
+ else:
+ raise NotImplementedError("range and size configuration are missing in the config!")
+ # BEV map down sample scale
+ ds_factor=self.out_size_factor
+ # get width and height
+ W,H=(pc_range[3] - pc_range[0]) / voxel_size[0]/ ds_factor, (pc_range[4] - pc_range[1]) / voxel_size[1]/ ds_factor
+ W,H=np.round(W).astype(int),np.round(H).astype(int)
+ feature_map_size = grid_size[:2] // self.out_size_factor
+
+ gt_dict = res["lidar"]["annotations"]
+
+ # reorganize the gt_dict by tasks
+ task_masks = []
+ flag = 0
+ for class_name in class_names_by_task:
+ task_masks.append(
+ [
+ np.where(
+ gt_dict["gt_classes"] == class_name.index(i) + 1 + flag
+ )
+ for i in class_name
+ ]
+ )
+ flag += len(class_name)
+
+ task_boxes = []
+ task_classes = []
+ task_names = []
+ flag2 = 0
+ for idx, mask in enumerate(task_masks):
+ task_box = []
+ task_class = []
+ task_name = []
+ for m in mask:
+ task_box.append(gt_dict["gt_boxes"][m])
+ task_class.append(gt_dict["gt_classes"][m] - flag2)
+ task_name.append(gt_dict["gt_names"][m])
+ task_boxes.append(np.concatenate(task_box, axis=0))
+ task_classes.append(np.concatenate(task_class))
+ task_names.append(np.concatenate(task_name))
+ flag2 += len(mask)
+
+ for task_box in task_boxes:
+ # limit rad to [-pi, pi]
+ task_box[:, -1] = box_np_ops.limit_period(
+ task_box[:, -1], offset=0.5, period=np.pi * 2
+ )
+
+ # print(gt_dict.keys())
+ gt_dict["gt_classes"] = task_classes
+ gt_dict["gt_names"] = task_names
+ gt_dict["gt_boxes"] = task_boxes
+
+ res["lidar"]["annotations"] = gt_dict
+
+ draw_gaussian = draw_umich_gaussian
+
+ hms, anno_boxs, inds, masks, cats = [], [], [], [], []
+ if self.corner_prediction:
+ corners = []
+
+ for idx, task in enumerate(self.tasks):
+ hm = np.zeros((len(class_names_by_task[idx]), feature_map_size[1], feature_map_size[0]),
+ dtype=np.float32)
+
+ if self.corner_prediction:
+ corner = np.zeros((1, feature_map_size[1], feature_map_size[0]), dtype=np.float32)
+
+ if res['type'] == 'NuScenesDataset':
+ # [reg, hei, dim, vx, vy, rots, rotc]
+ anno_box = np.zeros((max_objs*window_size, 10), dtype=np.float32)
+ elif res['type'] in ['WaymoDataset','WaymoDataset_multi_frame']:
+ anno_box = np.zeros((max_objs*window_size, 10), dtype=np.float32)
+ else:
+ raise NotImplementedError("Only Support nuScene for Now!")
+
+ ind = np.zeros((max_objs*window_size), dtype=np.int64)
+ mask = np.zeros((max_objs*window_size), dtype=np.uint8)
+ cat = np.zeros((max_objs*window_size), dtype=np.int64)
+
+ num_objs = min(gt_dict['gt_boxes'][idx].shape[0], max_objs)
+
+ for k in range(num_objs):
+ cls_id = gt_dict['gt_classes'][idx][k] - 1
+
+ w, l, h = gt_dict['gt_boxes'][idx][k][3], gt_dict['gt_boxes'][idx][k][4], \
+ gt_dict['gt_boxes'][idx][k][5]
+ w, l = w / voxel_size[0] / self.out_size_factor, l / voxel_size[1] / self.out_size_factor
+ if w > 0 and l > 0:
+ radius = gaussian_radius((l, w), min_overlap=self.gaussian_overlap)
+ radius = max(self._min_radius, int(radius))
+
+ # be really careful for the coordinate system of your box annotation.
+ x, y, z = gt_dict['gt_boxes'][idx][k][0], gt_dict['gt_boxes'][idx][k][1], \
+ gt_dict['gt_boxes'][idx][k][2]
+
+ coor_x, coor_y = (x - pc_range[0]) / voxel_size[0] / self.out_size_factor, \
+ (y - pc_range[1]) / voxel_size[1] / self.out_size_factor
+
+ ct = np.array(
+ [coor_x, coor_y], dtype=np.float32)
+ ct_int = ct.astype(np.int32)
+
+ # throw out not in range objects to avoid out of array area when creating the heatmap
+ if not (0 <= ct_int[0] < feature_map_size[0] and 0 <= ct_int[1] < feature_map_size[1]):
+ continue
+
+ draw_gaussian(hm[cls_id], ct, radius)
+ if self.corner_prediction:
+ radius = radius//2
+ # draw four corner and center
+ dim = np.array([w, l], dtype=np.float32)
+ rot = np.array([gt_dict['gt_boxes'][idx][k][8]], dtype=np.float32)
+ corner_keypoints = box_np_ops.center_to_corner_box2d(ct[np.newaxis,:],dim[np.newaxis,:],rot)
+ draw_gaussian(corner[0], ct, radius)
+ draw_gaussian(corner[0], (corner_keypoints[0, 0] + corner_keypoints[0, 1])/2, radius)
+ draw_gaussian(corner[0], (corner_keypoints[0, 2] + corner_keypoints[0, 3])/2, radius)
+ draw_gaussian(corner[0], (corner_keypoints[0, 0] + corner_keypoints[0, 3])/2, radius)
+ draw_gaussian(corner[0], (corner_keypoints[0, 1] + corner_keypoints[0, 2])/2, radius)
+
+ new_idx = k
+ x, y = np.arange(ct_int[0]-gt_kernel_size//2,ct_int[0]+1+gt_kernel_size//2), np.arange(ct_int[1]-gt_kernel_size//2,ct_int[1]+1+gt_kernel_size//2)
+ x, y = np.meshgrid(x, y)
+ x = x.reshape(-1)
+ y = y.reshape(-1)
+
+ for j in range(window_size):
+
+ cat[new_idx*window_size+j] = cls_id
+ ind[new_idx*window_size+j] = y[j] * feature_map_size[0] + x[j]
+ mask[new_idx*window_size+j] = 1
+
+ if res['type'] == 'NuScenesDataset':
+ vx, vy = gt_dict['gt_boxes'][idx][k][6:8]
+ rot = gt_dict['gt_boxes'][idx][k][8]
+ anno_box[new_idx*window_size+j] = np.concatenate(
+ (ct - (x[j], y[j]), z, np.log(gt_dict['gt_boxes'][idx][k][3:6]),
+ np.array(vx), np.array(vy), np.sin(rot), np.cos(rot)), axis=None)
+ elif res['type'] in ['WaymoDataset','WaymoDataset_multi_frame']:
+ vx, vy = gt_dict['gt_boxes'][idx][k][6:8]
+ rot = gt_dict['gt_boxes'][idx][k][-1]
+ anno_box[new_idx*window_size+j] = np.concatenate(
+ (ct - (x[j], y[j]), z, np.log(gt_dict['gt_boxes'][idx][k][3:6]),
+ np.array(vx), np.array(vy), np.sin(rot), np.cos(rot)), axis=None)
+ else:
+ raise NotImplementedError("Only Support Waymo and nuScene for Now")
+
+ hms.append(hm)
+ anno_boxs.append(anno_box)
+ masks.append(mask)
+ inds.append(ind)
+ cats.append(cat)
+ if self.corner_prediction:
+ corners.append(corner)
+
+ # used for two stage code
+ boxes = flatten(gt_dict['gt_boxes'])
+ classes = merge_multi_group_label(gt_dict['gt_classes'], num_classes_by_task)
+
+ if res["type"] == "NuScenesDataset":
+ gt_boxes_and_cls = np.zeros((max_objs, 10), dtype=np.float32)
+ elif res['type'] in ['WaymoDataset','WaymoDataset_multi_frame']:
+ gt_boxes_and_cls = np.zeros((max_objs, 10), dtype=np.float32)
+ else:
+ raise NotImplementedError()
+
+ boxes_and_cls = np.concatenate((boxes,
+ classes.reshape(-1, 1).astype(np.float32)), axis=1)
+ num_obj = len(boxes_and_cls)
+ assert num_obj <= max_objs
+ # x, y, z, w, l, h, rotation_y, velocity_x, velocity_y, class_name
+ boxes_and_cls = boxes_and_cls[:, [0, 1, 2, 3, 4, 5, 8, 6, 7, 9]]
+ gt_boxes_and_cls[:num_obj] = boxes_and_cls
+
+ example.update({'gt_boxes_and_cls': gt_boxes_and_cls})
+
+ example.update({'hm': hms, 'anno_box': anno_boxs, 'ind': inds, 'mask': masks, 'cat': cats})
+ if self.corner_prediction:
+ example.update({'corners': corners})
+ else:
+ pass
+
+ res["lidar"]["targets"] = example
+
+ return res, info
+
diff --git a/det3d/datasets/pipelines/preprocess_multiframe.py b/det3d/datasets/pipelines/preprocess_multiframe.py
new file mode 100644
index 0000000..f25dee1
--- /dev/null
+++ b/det3d/datasets/pipelines/preprocess_multiframe.py
@@ -0,0 +1,227 @@
+import numpy as np
+
+from det3d.core.bbox import box_np_ops
+from det3d.core.sampler import preprocess as prep
+from det3d.builder import build_dbsampler
+
+from det3d.core.input.voxel_generator import VoxelGenerator
+from det3d.core.utils.center_utils import (
+ draw_umich_gaussian, gaussian_radius
+)
+from ..registry import PIPELINES
+
+
+def _dict_select(dict_, inds):
+ for k, v in dict_.items():
+ if isinstance(v, dict):
+ _dict_select(v, inds)
+ else:
+ dict_[k] = v[inds]
+
+
+def drop_arrays_by_name(gt_names, used_classes):
+ inds = [i for i, x in enumerate(gt_names) if x not in used_classes]
+ inds = np.array(inds, dtype=np.int64)
+ return inds
+
+@PIPELINES.register_module
+class Preprocess_multiframe(object):
+ def __init__(self, cfg=None, **kwargs):
+ self.shuffle_points = cfg.shuffle_points
+ self.min_points_in_gt = cfg.get("min_points_in_gt", -1)
+ self.combine_frame = cfg.get("combine_frame", False)
+
+ self.mode = cfg.mode
+ if self.mode == "train":
+ self.global_rotation_noise = cfg.global_rot_noise
+ self.global_scaling_noise = cfg.global_scale_noise
+ self.global_translate_noise = cfg.get('global_translate_noise', 0)
+ self.class_names = cfg.class_names
+ if cfg.db_sampler != None:
+ self.db_sampler = build_dbsampler(cfg.db_sampler)
+ else:
+ self.db_sampler = None
+
+ self.npoints = cfg.get("npoints", -1)
+
+ self.no_augmentation = cfg.get('no_augmentation', False)
+
+ def __call__(self, res, info):
+
+ res["mode"] = self.mode
+
+ if res["type"] in ["WaymoDataset"]:
+ if "combined" in res["lidar"]:
+ points = res["lidar"]["combined"]
+ else:
+ points = res["lidar"]["points"]
+ elif res["type"] in ["WaymoDataset_multi_frame"]:
+ if self.combine_frame:
+ points = res["lidar"]["combined"][0]
+ previous_frame = res["lidar"]["combined"][1:]
+ time_frame = [time[0,0] for time in res["lidar"]["times"]]
+ else:
+ points = res["lidar"]["points"][0]
+ previous_frame = res["lidar"]["points"][1:]
+ time_frame = [time[0,0] for time in res["lidar"]["times"]]
+ elif res["type"] in ["NuScenesDataset"]:
+ points = res["lidar"]["combined"]
+ else:
+ raise NotImplementedError
+
+ if self.mode == "train":
+ anno_dict = res["lidar"]["annotations"]
+
+ gt_dict = {
+ "gt_boxes": anno_dict["boxes"],
+ "gt_names": np.array(anno_dict["names"]).reshape(-1),
+ }
+
+ points_num = []
+ points_timeframe = []
+
+ if self.mode == "train" and not self.no_augmentation:
+ selected = drop_arrays_by_name(
+ gt_dict["gt_names"], ["DontCare", "ignore", "UNKNOWN"]
+ )
+
+ _dict_select(gt_dict, selected)
+
+ if self.min_points_in_gt > 0:
+ point_counts = box_np_ops.points_count_rbbox(
+ points, gt_dict["gt_boxes"]
+ )
+ mask = point_counts >= min_points_in_gt
+ _dict_select(gt_dict, mask)
+
+ gt_boxes_mask = np.array(
+ [n in self.class_names for n in gt_dict["gt_names"]], dtype=np.bool_
+ )
+
+
+
+ if self.db_sampler:
+ sampled_dict = self.db_sampler.sample_all(
+ res["metadata"]["image_prefix"],
+ gt_dict["gt_boxes"],
+ gt_dict["gt_names"],
+ # res["metadata"]["num_point_features"],
+ points.shape[1],
+ False,
+ gt_group_ids=None,
+ calib=None,
+ road_planes=None
+ )
+
+ if sampled_dict is not None:
+ sampled_gt_names = sampled_dict["gt_names"]
+ sampled_gt_boxes = sampled_dict["gt_boxes"]
+ sampled_points = sampled_dict["points"]
+ sampled_gt_masks = sampled_dict["gt_masks"]
+ gt_dict["gt_names"] = np.concatenate(
+ [gt_dict["gt_names"], sampled_gt_names], axis=0
+ )
+ gt_dict["gt_boxes"] = np.concatenate(
+ [gt_dict["gt_boxes"], sampled_gt_boxes]
+ )
+ gt_boxes_mask = np.concatenate(
+ [gt_boxes_mask, sampled_gt_masks], axis=0
+ )
+
+
+ points = np.concatenate([sampled_points, points], axis=0)
+ points_num.append(points.shape[0])
+ points_timeframe.append(0.)
+
+ if res["type"] in ["WaymoDataset_multi_frame"]:
+ for idx, pre_points in enumerate(previous_frame):
+ pre_points = np.concatenate([sampled_points, pre_points], axis=0)
+ points_num.append(pre_points.shape[0])
+ points = np.concatenate([points, pre_points], axis=0)
+ points_timeframe.append(time_frame[idx+1])
+ else:
+ points_num.append(points.shape[0])
+ points_timeframe.append(0.)
+ if res["type"] in ["WaymoDataset_multi_frame"]:
+ for idx, pre_points in enumerate(previous_frame):
+ points_num.append(pre_points.shape[0])
+ points = np.concatenate([points, pre_points], axis=0)
+ points_timeframe.append(time_frame[idx+1])
+ else:
+ points_num.append(points.shape[0])
+ points_timeframe.append(0.)
+ if res["type"] in ["WaymoDataset_multi_frame"]:
+ for idx, pre_points in enumerate(previous_frame):
+ points_num.append(pre_points.shape[0])
+ points = np.concatenate([points, pre_points], axis=0)
+ points_timeframe.append(time_frame[idx+1])
+
+ _dict_select(gt_dict, gt_boxes_mask)
+
+ gt_classes = np.array(
+ [self.class_names.index(n) + 1 for n in gt_dict["gt_names"]],
+ dtype=np.int32,
+ )
+ gt_dict["gt_classes"] = gt_classes
+
+ gt_dict["gt_boxes"], points = prep.random_flip_both(gt_dict["gt_boxes"], points)
+
+ gt_dict["gt_boxes"], points = prep.global_rotation(
+ gt_dict["gt_boxes"], points, rotation=self.global_rotation_noise
+ )
+ gt_dict["gt_boxes"], points = prep.global_scaling_v2(
+ gt_dict["gt_boxes"], points, *self.global_scaling_noise
+ )
+ gt_dict["gt_boxes"], points = prep.global_translate_v2(
+ gt_dict["gt_boxes"], points, noise_translate=self.global_translate_noise
+ )
+ elif self.no_augmentation:
+ gt_boxes_mask = np.array(
+ [n in self.class_names for n in gt_dict["gt_names"]], dtype=np.bool_
+ )
+ _dict_select(gt_dict, gt_boxes_mask)
+
+ gt_classes = np.array(
+ [self.class_names.index(n) + 1 for n in gt_dict["gt_names"]],
+ dtype=np.int32,
+ )
+ gt_dict["gt_classes"] = gt_classes
+
+ points_num.append(points.shape[0])
+ points_timeframe.append(0.)
+ if res["type"] in ["WaymoDataset_multi_frame"]:
+ for idx, pre_points in enumerate(previous_frame):
+ points_num.append(pre_points.shape[0])
+ points = np.concatenate([points, pre_points], axis=0)
+ points_timeframe.append(time_frame[idx+1])
+ else:
+ points_num.append(points.shape[0])
+ points_timeframe.append(0.)
+ if res["type"] in ["WaymoDataset_multi_frame"]:
+ for idx, pre_points in enumerate(previous_frame):
+ points_num.append(pre_points.shape[0])
+ points = np.concatenate([points, pre_points], axis=0)
+ points_timeframe.append(time_frame[idx+1])
+
+ #disengage points in multi-frame
+ if len(points_num)>1:
+ previous_frames = []
+ counts = 0
+ for n in range(len(points_num)):
+ previous_frames.append(points[counts:counts+points_num[n]])
+ if self.shuffle_points:
+ np.random.shuffle(previous_frames[-1])
+ counts += points_num[n]
+ # points = points[:points_num[0]]
+ res["lidar"]["multi_points"] = previous_frames
+ else:
+ if self.shuffle_points:
+ np.random.shuffle(points)
+ res["lidar"]["points"] = points
+
+ res["lidar"]["times"] = np.asarray(points_timeframe)
+
+ if self.mode == "train":
+ res["lidar"]["annotations"] = gt_dict
+
+ return res, info
\ No newline at end of file
diff --git a/det3d/datasets/pipelines/test_aug.py b/det3d/datasets/pipelines/test_aug.py
new file mode 100644
index 0000000..9a34bd0
--- /dev/null
+++ b/det3d/datasets/pipelines/test_aug.py
@@ -0,0 +1,35 @@
+from det3d import torchie
+
+from ..registry import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module
+class DoubleFlip(object):
+ def __init__(self):
+ pass
+
+ def __call__(self, res, info):
+ # y flip
+ points = res["lidar"]["points"].copy()
+ points[:, 1] = -points[:, 1]
+
+ res["lidar"]['yflip_points'] = points
+
+ # x flip
+ points = res["lidar"]["points"].copy()
+ points[:, 0] = -points[:, 0]
+
+ res["lidar"]['xflip_points'] = points
+
+ # x y flip
+ points = res["lidar"]["points"].copy()
+ points[:, 0] = -points[:, 0]
+ points[:, 1] = -points[:, 1]
+
+ res["lidar"]["double_flip_points"] = points
+
+ return res, info
+
+
+
diff --git a/det3d/datasets/registry.py b/det3d/datasets/registry.py
new file mode 100644
index 0000000..3045ee2
--- /dev/null
+++ b/det3d/datasets/registry.py
@@ -0,0 +1,4 @@
+from det3d.utils.registry import Registry
+
+DATASETS = Registry("dataset")
+PIPELINES = Registry("pipeline")
diff --git a/det3d/datasets/utils/create_gt_database.py b/det3d/datasets/utils/create_gt_database.py
new file mode 100644
index 0000000..88670cc
--- /dev/null
+++ b/det3d/datasets/utils/create_gt_database.py
@@ -0,0 +1,170 @@
+import pickle
+from pathlib import Path
+import os
+import numpy as np
+
+from det3d.core import box_np_ops
+from det3d.datasets.dataset_factory import get_dataset
+from tqdm import tqdm
+
+dataset_name_map = {
+ "NUSC": "NuScenesDataset",
+ "WAYMO": "WaymoDataset"
+}
+
+
+def create_groundtruth_database(
+ dataset_class_name,
+ data_path,
+ info_path=None,
+ used_classes=None,
+ db_path=None,
+ dbinfo_path=None,
+ relative_path=True,
+ **kwargs,
+):
+ pipeline = [
+ {
+ "type": "LoadPointCloudFromFile",
+ "dataset": dataset_name_map[dataset_class_name],
+ },
+ {"type": "LoadPointCloudAnnotations", "with_bbox": True},
+ ]
+
+ if "nsweeps" in kwargs:
+ dataset = get_dataset(dataset_class_name)(
+ info_path=info_path,
+ root_path=data_path,
+ pipeline=pipeline,
+ test_mode=True,
+ nsweeps=kwargs["nsweeps"],
+ )
+ nsweeps = dataset.nsweeps
+ else:
+ dataset = get_dataset(dataset_class_name)(
+ info_path=info_path, root_path=data_path, test_mode=True, pipeline=pipeline
+ )
+ nsweeps = 1
+
+ root_path = Path(data_path)
+
+ if dataset_class_name in ["WAYMO", "NUSC"]:
+ if db_path is None:
+ db_path = root_path / f"gt_database_{nsweeps}sweeps_withvelo"
+ if dbinfo_path is None:
+ dbinfo_path = root_path / f"dbinfos_train_{nsweeps}sweeps_withvelo.pkl"
+ else:
+ raise NotImplementedError()
+
+ if dataset_class_name == "NUSC":
+ point_features = 5
+ elif dataset_class_name == "WAYMO":
+ point_features = 5 if nsweeps == 1 else 6
+ else:
+ raise NotImplementedError()
+
+ db_path.mkdir(parents=True, exist_ok=True)
+
+ all_db_infos = {}
+ group_counter = 0
+
+ for index in tqdm(range(len(dataset))):
+ image_idx = index
+ # modified to nuscenes
+ sensor_data = dataset.get_sensor_data(index)
+ if "image_idx" in sensor_data["metadata"]:
+ image_idx = sensor_data["metadata"]["image_idx"]
+
+ if nsweeps > 1:
+ points = sensor_data["lidar"]["combined"]
+ else:
+ points = sensor_data["lidar"]["points"]
+
+ annos = sensor_data["lidar"]["annotations"]
+ gt_boxes = annos["boxes"]
+ names = annos["names"]
+
+ if dataset_class_name == 'WAYMO':
+ # waymo dataset contains millions of objects and it is not possible to store
+ # all of them into a single folder
+ # we randomly sample a few objects for gt augmentation
+ # We keep all cyclist as they are rare
+ if index % 4 != 0:
+ mask = (names == 'VEHICLE')
+ mask = np.logical_not(mask)
+ names = names[mask]
+ gt_boxes = gt_boxes[mask]
+
+ if index % 2 != 0:
+ mask = (names == 'PEDESTRIAN')
+ mask = np.logical_not(mask)
+ names = names[mask]
+ gt_boxes = gt_boxes[mask]
+
+
+ group_dict = {}
+ group_ids = np.full([gt_boxes.shape[0]], -1, dtype=np.int64)
+ if "group_ids" in annos:
+ group_ids = annos["group_ids"]
+ else:
+ group_ids = np.arange(gt_boxes.shape[0], dtype=np.int64)
+ difficulty = np.zeros(gt_boxes.shape[0], dtype=np.int32)
+ if "difficulty" in annos:
+ difficulty = annos["difficulty"]
+
+ num_obj = gt_boxes.shape[0]
+ if num_obj == 0:
+ continue
+ point_indices = box_np_ops.points_in_rbbox(points, gt_boxes)
+ for i in range(num_obj):
+ if (used_classes is None) or names[i] in used_classes:
+ filename = f"{image_idx}_{names[i]}_{i}.bin"
+ dirpath = os.path.join(str(db_path), names[i])
+ os.makedirs(dirpath, exist_ok=True)
+
+ filepath = os.path.join(str(db_path), names[i], filename)
+ gt_points = points[point_indices[:, i]]
+ gt_points[:, :3] -= gt_boxes[i, :3]
+ with open(filepath, "w") as f:
+ try:
+ gt_points[:, :point_features].tofile(f)
+ except:
+ print("process {} files".format(index))
+ break
+
+ if (used_classes is None) or names[i] in used_classes:
+ if relative_path:
+ db_dump_path = os.path.join(db_path.stem, names[i], filename)
+ else:
+ db_dump_path = str(filepath)
+
+ db_info = {
+ "name": names[i],
+ "path": db_dump_path,
+ "image_idx": image_idx,
+ "gt_idx": i,
+ "box3d_lidar": gt_boxes[i],
+ "num_points_in_gt": gt_points.shape[0],
+ "difficulty": difficulty[i],
+ # "group_id": -1,
+ # "bbox": bboxes[i],
+ }
+ local_group_id = group_ids[i]
+ # if local_group_id >= 0:
+ if local_group_id not in group_dict:
+ group_dict[local_group_id] = group_counter
+ group_counter += 1
+ db_info["group_id"] = group_dict[local_group_id]
+ if "score" in annos:
+ db_info["score"] = annos["score"][i]
+ if names[i] in all_db_infos:
+ all_db_infos[names[i]].append(db_info)
+ else:
+ all_db_infos[names[i]] = [db_info]
+
+ print("dataset length: ", len(dataset))
+ for k, v in all_db_infos.items():
+ print(f"load {len(v)} {k} database infos")
+
+ with open(dbinfo_path, "wb") as f:
+ pickle.dump(all_db_infos, f)
diff --git a/det3d/datasets/utils/distributed.py b/det3d/datasets/utils/distributed.py
new file mode 100644
index 0000000..bd04f9a
--- /dev/null
+++ b/det3d/datasets/utils/distributed.py
@@ -0,0 +1,62 @@
+import math
+import torch
+import torch.distributed as dist
+from torch.utils.data.sampler import Sampler
+
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+ .. note::
+ Dataset is assumed to be of constant size.
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ """
+
+ def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
+ if num_replicas is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ num_replicas = dist.get_world_size()
+ if rank is None:
+ if not dist.is_available():
+ raise RuntimeError("Requires distributed package to be available")
+ rank = dist.get_rank()
+ self.dataset = dataset
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
+ self.total_size = self.num_samples * self.num_replicas
+ self.shuffle = shuffle
+
+ def __iter__(self):
+ if self.shuffle:
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ indices += indices[: (self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank : self.total_size : self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/det3d/datasets/utils/eval.py b/det3d/datasets/utils/eval.py
new file mode 100644
index 0000000..8154894
--- /dev/null
+++ b/det3d/datasets/utils/eval.py
@@ -0,0 +1,367 @@
+import numpy as np
+import numba
+
+from det3d.ops.nms.nms_gpu import rotate_iou_gpu_eval
+from det3d.ops.nms.nms_gpu import inter
+from det3d.core import box_np_ops
+
+
+def get_split_parts(num, num_part):
+ same_part = num // num_part
+ remain_num = num % num_part
+ if remain_num == 0:
+ return [same_part] * num_part
+ else:
+ return [same_part] * num_part + [remain_num]
+
+
+def prepare_data(gt_annos, dt_annos, current_class, difficulty=None, clean_data=None):
+ gt_datas_list = []
+ dt_datas_list = []
+ total_dc_num = []
+ ignored_gts, ignored_dets, dontcares = [], [], []
+ total_num_valid_gt = 0
+ for i in range(len(gt_annos)):
+ rets = clean_data(gt_annos[i], dt_annos[i], current_class, difficulty)
+ num_valid_gt, ignored_gt, ignored_det, dc_bboxes = rets
+ ignored_gts.append(np.array(ignored_gt, dtype=np.int64))
+ ignored_dets.append(np.array(ignored_det, dtype=np.int64))
+ if len(dc_bboxes) == 0:
+ dc_bboxes = np.zeros((0, 4)).astype(np.float64)
+ else:
+ dc_bboxes = np.stack(dc_bboxes, 0).astype(np.float64)
+ total_dc_num.append(dc_bboxes.shape[0])
+ dontcares.append(dc_bboxes)
+ total_num_valid_gt += num_valid_gt
+ gt_datas = np.concatenate(
+ [gt_annos[i]["bbox"], gt_annos[i]["alpha"][..., np.newaxis]], 1
+ )
+ dt_datas = np.concatenate(
+ [
+ dt_annos[i]["bbox"],
+ dt_annos[i]["alpha"][..., np.newaxis],
+ dt_annos[i]["score"][..., np.newaxis],
+ ],
+ 1,
+ )
+ gt_datas_list.append(gt_datas)
+ dt_datas_list.append(dt_datas)
+ total_dc_num = np.stack(total_dc_num, axis=0)
+ return (
+ gt_datas_list,
+ dt_datas_list,
+ ignored_gts,
+ ignored_dets,
+ dontcares,
+ total_dc_num,
+ total_num_valid_gt,
+ )
+
+
+def calculate_iou_partly(
+ gt_annos, dt_annos, metric, num_parts=50, z_axis=1, z_center=1.0
+):
+ """fast iou algorithm. this function can be used independently to
+ do result analysis.
+ Args:
+ gt_annos: dict, must from get_label_annos() in kitti_common.py
+ dt_annos: dict, must from get_label_annos() in kitti_common.py
+ metric: eval type. 0: bbox, 1: bev, 2: 3d
+ num_parts: int. a parameter for fast calculate algorithm
+ z_axis: height axis. kitti camera use 1, lidar use 2.
+ """
+ assert len(gt_annos) == len(dt_annos)
+ total_dt_num = np.stack([len(a["name"]) for a in dt_annos], 0)
+ total_gt_num = np.stack([len(a["name"]) for a in gt_annos], 0)
+ num_examples = len(gt_annos)
+ split_parts = get_split_parts(num_examples, num_parts)
+ parted_overlaps = []
+ example_idx = 0
+ bev_axes = list(range(3))
+ bev_axes.pop(z_axis)
+ split_parts = [i for i in split_parts if i != 0]
+ for num_part in split_parts:
+ gt_annos_part = gt_annos[example_idx : example_idx + num_part]
+ dt_annos_part = dt_annos[example_idx : example_idx + num_part]
+ if metric == 0:
+ gt_boxes = np.concatenate([a["bbox"] for a in gt_annos_part], 0)
+ dt_boxes = np.concatenate([a["bbox"] for a in dt_annos_part], 0)
+ overlap_part = image_box_overlap(gt_boxes, dt_boxes)
+ elif metric == 1:
+ loc = np.concatenate([a["location"][:, bev_axes] for a in gt_annos_part], 0)
+ dims = np.concatenate(
+ [a["dimensions"][:, bev_axes] for a in gt_annos_part], 0
+ )
+ rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0)
+ gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
+ loc = np.concatenate([a["location"][:, bev_axes] for a in dt_annos_part], 0)
+ dims = np.concatenate(
+ [a["dimensions"][:, bev_axes] for a in dt_annos_part], 0
+ )
+ rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0)
+ dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
+ overlap_part = bev_box_overlap(gt_boxes, dt_boxes).astype(np.float64)
+ elif metric == 2:
+ loc = np.concatenate([a["location"] for a in gt_annos_part], 0)
+ dims = np.concatenate([a["dimensions"] for a in gt_annos_part], 0)
+ rots = np.concatenate([a["rotation_y"] for a in gt_annos_part], 0)
+ gt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
+ loc = np.concatenate([a["location"] for a in dt_annos_part], 0)
+ dims = np.concatenate([a["dimensions"] for a in dt_annos_part], 0)
+ rots = np.concatenate([a["rotation_y"] for a in dt_annos_part], 0)
+ dt_boxes = np.concatenate([loc, dims, rots[..., np.newaxis]], axis=1)
+ overlap_part = box3d_overlap(
+ gt_boxes, dt_boxes, z_axis=z_axis, z_center=z_center
+ ).astype(np.float64)
+ else:
+ raise ValueError("unknown metric")
+ parted_overlaps.append(overlap_part)
+ example_idx += num_part
+
+ overlaps = []
+ example_idx = 0
+ for j, num_part in enumerate(split_parts):
+ gt_annos_part = gt_annos[example_idx : example_idx + num_part]
+ dt_annos_part = dt_annos[example_idx : example_idx + num_part]
+ gt_num_idx, dt_num_idx = 0, 0
+ for i in range(num_part):
+ gt_box_num = total_gt_num[example_idx + i]
+ dt_box_num = total_dt_num[example_idx + i]
+ overlaps.append(
+ parted_overlaps[j][
+ gt_num_idx : gt_num_idx + gt_box_num,
+ dt_num_idx : dt_num_idx + dt_box_num,
+ ]
+ )
+ gt_num_idx += gt_box_num
+ dt_num_idx += dt_box_num
+ example_idx += num_part
+
+ return overlaps, parted_overlaps, total_gt_num, total_dt_num
+
+
+@numba.jit(nopython=True)
+def compute_statistics_jit(
+ overlaps,
+ gt_datas,
+ dt_datas,
+ ignored_gt,
+ ignored_det,
+ dc_bboxes,
+ metric,
+ min_overlap,
+ thresh=0,
+ compute_fp=False,
+ compute_aos=False,
+):
+
+ det_size = dt_datas.shape[0]
+ gt_size = gt_datas.shape[0]
+ dt_scores = dt_datas[:, -1]
+ dt_alphas = dt_datas[:, 4]
+ gt_alphas = gt_datas[:, 4]
+ dt_bboxes = dt_datas[:, :4]
+ # gt_bboxes = gt_datas[:, :4]
+
+ assigned_detection = [False] * det_size
+ ignored_threshold = [False] * det_size
+ if compute_fp:
+ for i in range(det_size):
+ if dt_scores[i] < thresh:
+ ignored_threshold[i] = True
+ NO_DETECTION = -10000000
+ tp, fp, fn, similarity = 0, 0, 0, 0
+ # thresholds = [0.0]
+ # delta = [0.0]
+ thresholds = np.zeros((gt_size,))
+ thresh_idx = 0
+ delta = np.zeros((gt_size,))
+ delta_idx = 0
+ for i in range(gt_size):
+ if ignored_gt[i] == -1:
+ continue
+ det_idx = -1
+ valid_detection = NO_DETECTION
+ max_overlap = 0
+ assigned_ignored_det = False
+
+ for j in range(det_size):
+ if ignored_det[j] == -1:
+ continue
+ if assigned_detection[j]:
+ continue
+ if ignored_threshold[j]:
+ continue
+ overlap = overlaps[j, i]
+ dt_score = dt_scores[j]
+ if (
+ not compute_fp
+ and (overlap > min_overlap)
+ and dt_score > valid_detection
+ ):
+ det_idx = j
+ valid_detection = dt_score
+ elif (
+ compute_fp
+ and (overlap > min_overlap)
+ and (overlap > max_overlap or assigned_ignored_det)
+ and ignored_det[j] == 0
+ ):
+ max_overlap = overlap
+ det_idx = j
+ valid_detection = 1
+ assigned_ignored_det = False
+ elif (
+ compute_fp
+ and (overlap > min_overlap)
+ and (valid_detection == NO_DETECTION)
+ and ignored_det[j] == 1
+ ):
+ det_idx = j
+ valid_detection = 1
+ assigned_ignored_det = True
+
+ if (valid_detection == NO_DETECTION) and ignored_gt[i] == 0:
+ fn += 1
+ elif (valid_detection != NO_DETECTION) and (
+ ignored_gt[i] == 1 or ignored_det[det_idx] == 1
+ ):
+ assigned_detection[det_idx] = True
+ elif valid_detection != NO_DETECTION:
+ # only a tp add a threshold.
+ tp += 1
+ # thresholds.append(dt_scores[det_idx])
+ thresholds[thresh_idx] = dt_scores[det_idx]
+ thresh_idx += 1
+ if compute_aos:
+ # delta.append(gt_alphas[i] - dt_alphas[det_idx])
+ delta[delta_idx] = gt_alphas[i] - dt_alphas[det_idx]
+ delta_idx += 1
+
+ assigned_detection[det_idx] = True
+ if compute_fp:
+ for i in range(det_size):
+ if not (
+ assigned_detection[i]
+ or ignored_det[i] == -1
+ or ignored_det[i] == 1
+ or ignored_threshold[i]
+ ):
+ fp += 1
+ nstuff = 0
+ if metric == 0:
+ overlaps_dt_dc = image_box_overlap(dt_bboxes, dc_bboxes, 0)
+ for i in range(dc_bboxes.shape[0]):
+ for j in range(det_size):
+ if assigned_detection[j]:
+ continue
+ if ignored_det[j] == -1 or ignored_det[j] == 1:
+ continue
+ if ignored_threshold[j]:
+ continue
+ if overlaps_dt_dc[j, i] > min_overlap:
+ assigned_detection[j] = True
+ nstuff += 1
+ fp -= nstuff
+ if compute_aos:
+ tmp = np.zeros((fp + delta_idx,))
+ # tmp = [0] * fp
+ for i in range(delta_idx):
+ tmp[i + fp] = (1.0 + np.cos(delta[i])) / 2.0
+ # tmp.append((1.0 + np.cos(delta[i])) / 2.0)
+ # assert len(tmp) == fp + tp
+ # assert len(delta) == tp
+ if tp > 0 or fp > 0:
+ similarity = np.sum(tmp)
+ else:
+ similarity = -1
+ return tp, fp, fn, similarity, thresholds[:thresh_idx]
+
+
+@numba.jit(nopython=True)
+def image_box_overlap(boxes, query_boxes, criterion=-1):
+ N = boxes.shape[0]
+ K = query_boxes.shape[0]
+ overlaps = np.zeros((N, K), dtype=boxes.dtype)
+ for k in range(K):
+ qbox_area = (query_boxes[k, 2] - query_boxes[k, 0]) * (
+ query_boxes[k, 3] - query_boxes[k, 1]
+ )
+ for n in range(N):
+ iw = min(boxes[n, 2], query_boxes[k, 2]) - max(
+ boxes[n, 0], query_boxes[k, 0]
+ )
+ if iw > 0:
+ ih = min(boxes[n, 3], query_boxes[k, 3]) - max(
+ boxes[n, 1], query_boxes[k, 1]
+ )
+ if ih > 0:
+ if criterion == -1:
+ ua = (
+ (boxes[n, 2] - boxes[n, 0]) * (boxes[n, 3] - boxes[n, 1])
+ + qbox_area
+ - iw * ih
+ )
+ elif criterion == 0:
+ ua = (boxes[n, 2] - boxes[n, 0]) * (boxes[n, 3] - boxes[n, 1])
+ elif criterion == 1:
+ ua = qbox_area
+ else:
+ ua = 1.0
+ overlaps[n, k] = iw * ih / ua
+ return overlaps
+
+
+def bev_box_overlap(boxes, qboxes, criterion=-1, stable=False):
+ if stable:
+ riou = box_np_ops.riou_cc(boxes, qboxes)
+ else:
+ riou = rotate_iou_gpu_eval(boxes, qboxes, criterion)
+ return riou
+
+
+@numba.jit(nopython=True, parallel=True)
+def box3d_overlap_kernel(boxes, qboxes, rinc, criterion=-1, z_axis=1, z_center=1.0):
+ """
+ z_axis: the z (height) axis.
+ z_center: unified z (height) center of box.
+ """
+ N, K = boxes.shape[0], qboxes.shape[0]
+ for i in range(N):
+ for j in range(K):
+ if rinc[i, j] > 0:
+ min_z = min(
+ boxes[i, z_axis] + boxes[i, z_axis + 3] * (1 - z_center),
+ qboxes[j, z_axis] + qboxes[j, z_axis + 3] * (1 - z_center),
+ )
+ max_z = max(
+ boxes[i, z_axis] - boxes[i, z_axis + 3] * z_center,
+ qboxes[j, z_axis] - qboxes[j, z_axis + 3] * z_center,
+ )
+ iw = min_z - max_z
+ if iw > 0:
+ area1 = boxes[i, 3] * boxes[i, 4] * boxes[i, 5]
+ area2 = qboxes[j, 3] * qboxes[j, 4] * qboxes[j, 5]
+ inc = iw * rinc[i, j]
+ if criterion == -1:
+ ua = area1 + area2 - inc
+ elif criterion == 0:
+ ua = area1
+ elif criterion == 1:
+ ua = area2
+ else:
+ ua = 1.0
+ rinc[i, j] = inc / ua
+ else:
+ rinc[i, j] = 0.0
+
+
+def box3d_overlap(boxes, qboxes, criterion=-1, z_axis=1, z_center=1.0):
+ """kitti camera format z_axis=1.
+ """
+ bev_axes = list(range(7))
+ bev_axes.pop(z_axis + 3)
+ bev_axes.pop(z_axis)
+ rinc = rotate_iou_gpu_eval(boxes[:, bev_axes], qboxes[:, bev_axes], 2)
+ box3d_overlap_kernel(boxes, qboxes, rinc, criterion, z_axis, z_center)
+ return rinc
diff --git a/det3d/datasets/utils/oss.py b/det3d/datasets/utils/oss.py
new file mode 100644
index 0000000..66773b7
--- /dev/null
+++ b/det3d/datasets/utils/oss.py
@@ -0,0 +1,575 @@
+"""\
+This module offser helpers for OSS operation.
+
+Basic Use
+----------
+Create an :class:`OSSPath` object::
+
+ >>> p = OSSPath('s3://mybucket/myprefix/mykey.bin')
+ OSSPath('s3://mybucket/myprefix/mykey.bin')
+ >>> OSSPath() / "mybucket" / "myprefix" / "mykey.bin"
+ OSSPath('s3://mybucket/myprefix/mykey.bin')
+
+
+Querying object properies::
+
+ >>> p.exists()
+ True
+ >>> p.is_dir()
+ False
+ >>> p.is_file()
+ True
+ >>> p.get_size()
+ 256
+
+Access path properties::
+
+ >>> p.bucket
+ "mybucket"
+ >>> p.key
+ "myprefix/mykey.bin"
+ >>> p.name
+ "mykey.bin"
+ >>> p.stem
+ "mykey"
+ >> p.suffix
+ ".bin"
+ >> p.suffixes
+ [".bin"]
+ >>> p.parent
+ OSSPath('s3://mybucket/myprefix')
+ >>> p.root
+ OSSPath('s3://mybucket')
+
+Uploading content to an object::
+
+ >>> p.put(b"some bytes\n")
+ True
+
+Uploading file to an object::
+
+ >>> p.put(open('/path/some/image.jpg', 'rb'))
+
+Reading an object::
+
+ >>> f = p.download()
+ >>> f.read()
+ b"some bytes"
+ >>> p.download(encoding='utf-8')
+ >>> f.read()
+ "some bytes"
+
+Deleting an object::
+
+ >>> p.delete()
+ True
+
+Path manipulations::
+
+ >>> p = OSSPath('s3://mybucket/myprefix/mykey.bin')
+ >>> p.with_name('mykey2.bin')
+ OSSPath("s3://mybucket/myprefix/mykey2.bin")
+ >>> p.with_suffix('.txt')
+ OSSPath("s3://mybucket/myprefix/mykey.txt")
+ >>> p.with_bucket('some_bucket')
+ OSSPath("s3://some_bucket/myprefix/mykey.txt")
+
+ >>> q = p.parent
+ >>> q
+ OSSPath('s3://mybucket/myprefix')
+ >>> q / "subfile.txt"
+ OSSPath("s3://mybucket/myprefix/subfile.txt")
+ >>> q / "subdir" / "subfile.txt"
+ OSSPath("s3://mybucket/myprefix/subdir/subfile.txt")
+ >>> q.joinpath("a", "b", "c")
+ OSSPath('s3://mybucket/myprefix/a/b/c')
+
+Directory-level operations::
+
+ >>> list(q.list_all()) # list all subfiles in all levels
+ >>> list(q.iter_dir()) # list subdirs and subfiles in one-level
+ >>> for root, dirs, files in q.walk(): print(files) # recursively walk through directory
+ >>> q.rmtree() # remove all subkeys of p
+
+
+"""
+import os
+import io
+import codecs
+from typing import Tuple, Iterable, Optional, List
+from pathlib import PosixPath
+from urllib.parse import urlparse, urlunparse
+import re
+import socket
+import boto3
+from botocore.errorfactory import ClientError
+
+
+def get_site():
+ m = re.search(r"([^.]+)\.brainpp\.cn$", socket.getfqdn())
+ if m:
+ return m.group(1)
+
+
+OSS_ENDPOINT = os.getenv(
+ "OSS_ENDPOINT", default="http://oss.{}.brainpp.cn".format(get_site()),
+)
+
+
+class OSSPath:
+
+ __slots__ = ("_client", "bucket", "_key_parts")
+
+ def __new__(cls, s3url: Optional[str] = None, endpoint_url=OSS_ENDPOINT):
+ _client = boto3.client("s3", endpoint_url=endpoint_url)
+ bucket, parts = cls._parse_s3url(s3url)
+ return cls._create(_client, bucket, parts)
+
+ @classmethod
+ def _parse_s3url(cls, s3url: Optional[str] = None):
+ if s3url is None:
+ return "", ()
+
+ if not s3url.startswith("s3://"):
+ raise ValueError(
+ "s3url must be formated as 's3:///path/to/object'"
+ )
+
+ r = urlparse(s3url)
+ assert r.scheme == "s3"
+
+ key = r.path.lstrip("/") # remove the leading /
+
+ parts = PosixPath(key).parts
+ return r.netloc, parts
+
+ @classmethod
+ def _create(cls, client, bucket: str, key_parts: Tuple[str]):
+ assert isinstance(key_parts, tuple)
+ self = object.__new__(cls)
+ self._client = client
+ self.bucket = bucket
+ self._key_parts = key_parts
+ return self
+
+ @property
+ def key(self) -> str:
+ return "/".join(self._key_parts)
+
+ @property
+ def parent(self):
+ """The logical parent of the path."""
+
+ if not len(self._key_parts):
+ return self
+
+ return self._create(self._client, self.bucket, self._key_parts[:-1])
+
+ @property
+ def root(self):
+ return self._create(self._client, self.bucket, key_parts=())
+
+ @property
+ def name(self):
+ if len(self._key_parts) < 1:
+ return ""
+ return self._key_parts[-1]
+
+ @property
+ def suffix(self):
+ """The final component's last suffix, if any."""
+ name = self.name
+ i = name.rfind(".")
+ if 0 < i < len(name) - 1:
+ return name[i:]
+ else:
+ return ""
+
+ @property
+ def suffixes(self):
+ """A list of the final component's suffixes, if any."""
+ name = self.name
+ if name.endswith("."):
+ return []
+ name = name.lstrip(".")
+ return ["." + suffix for suffix in name.split(".")[1:]]
+
+ @property
+ def stem(self):
+ """The final path component, minus its last suffix."""
+ name = self.name
+ i = name.rfind(".")
+ if 0 < i < len(name) - 1:
+ return name[:i]
+ else:
+ return name
+
+ @property
+ def parts(self):
+ """An object providing sequence-like access to the
+ components in the filesystem path."""
+
+ return self._key_parts
+
+ def __str__(self) -> str:
+ return "s3://{}/{}".format(self.bucket, self.key)
+
+ def __eq__(self, other):
+ if not isinstance(other, OSSPath):
+ return False
+ return self.bucket == other.bucket and self.key == other.key
+
+ def __hash__(self):
+ return hash(str(self))
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, str(self))
+
+ def __lt__(self, other):
+ if not isinstance(other, OSSPath):
+ raise NotImplementedError()
+ return str(self) < str(other)
+
+ def __le__(self, other):
+ if not isinstance(other, OSSPath):
+ raise NotImplementedError()
+ return str(self) <= str(other)
+
+ def __gt__(self, other):
+ if not isinstance(other, OSSPath):
+ raise NotImplementedError()
+ return str(self) > str(other)
+
+ def __ge__(self, other):
+ if not isinstance(other, OSSPath):
+ raise NotImplementedError()
+ return str(self) >= str(other)
+
+ def with_name(self, name):
+ """Return a new path with the file name changed."""
+ if not self.name:
+ raise ValueError("%r has an empty name" % (self,))
+
+ r = urlparse(name)
+ if not (r.scheme == "" and r.netloc == "" or "/" in name):
+ raise ValueError("invalid name %r" % (name))
+
+ return self._create(self._client, self.bucket, self._key_parts[:-1] + (name,))
+
+ def with_suffix(self, suffix):
+ """Return a new path with the file suffix changed. If the path
+ has no suffix, add given suffix. If the given suffix is an empty
+ string, remove the suffix from the path.
+ """
+ if "/" in suffix:
+ raise ValueError("Invalid suffix %r" % (suffix,))
+ if suffix and not suffix.startswith(".") or suffix == ".":
+ raise ValueError("Invalid suffix %r" % (suffix))
+ name = self.name
+ if not name:
+ raise ValueError("%r has an empty name" % (self,))
+ old_suffix = self.suffix
+ if not old_suffix:
+ name = name + suffix
+ else:
+ name = name[: -len(old_suffix)] + suffix
+ return self._create(self._client, self.bucket, self._key_parts[:-1] + (name,))
+
+ def with_bucket(self, bucket):
+ if not isinstance(bucket, str):
+ raise ValueError("bucket be string")
+
+ bucket = bucket.strip("/")
+ if not bucket:
+ raise ValueError("bucket must not be empty")
+ if "/" in bucket:
+ raise ValueError("bucket_name must not contain '/'")
+ return self._create(self._client, bucket, self._key_parts)
+
+ def _make_child(self, args: Iterable[str]):
+
+ if not self.bucket:
+ bucket, *rest_args = args
+ bucket = bucket.lstrip("/")
+ bucket, *rest_parts = PosixPath(bucket).parts
+ return self.with_bucket(bucket)._make_child(rest_parts + rest_args)
+
+ parts = [p for p in self._key_parts]
+ for item in args:
+ if not isinstance(item, str):
+ raise ValueError("child must be string")
+ item = item.lstrip("/") # remove leading '/'
+ if not item:
+ raise ValueError("child must not be empty")
+ for p in PosixPath(item).parts:
+ parts.append(p)
+
+ return self._create(self._client, self.bucket, tuple(parts))
+
+ def joinpath(self, *args):
+ """Combine this path with one or several arguments, and return a
+ new path representing either a subpath (if all arguments are relative
+ paths) or a totally different path (if one of the arguments is
+ anchored).
+ """
+ return self._make_child(args)
+
+ def __truediv__(self, key):
+ return self._make_child((key,))
+
+ def __rtruediv__(self, key):
+ raise NotImplemented
+
+ def is_dir(self):
+ if not self.bucket:
+ return False
+
+ if not self.key:
+ # key empty, return whether bucket exists
+ try:
+ self._client.head_bucket(Bucket=self.bucket)
+ return True
+ except ClientError as e:
+ if e.response["Error"]["Code"] == "404":
+ return False
+
+ prefix = self.key
+ if prefix[-1] != "/":
+ prefix = prefix + "/"
+ resp = self._client.list_objects(
+ Bucket=self.bucket, Delimiter="/", Prefix=prefix
+ )
+ return "CommonPrefixes" in resp or "Contents" in resp
+
+ def is_file(self):
+ if not self.bucket:
+ return False
+ if not self.key:
+ return False
+ try:
+ self._client.head_object(Bucket=self.bucket, Key=self.key)
+ return True
+ except ClientError as e:
+ if e.response["Error"]["Code"] == "404":
+ return False
+
+ def exists(self):
+ if not self.bucket:
+ return False
+ if self.is_dir():
+ return True
+ elif self.is_file():
+ return True
+ return False
+
+ def get_size(self):
+ if not self.bucket:
+ return -1
+ if self.is_dir():
+ return 0
+ if not self.is_file():
+ return -1
+
+ key = self.key.lstrip("/")
+ return self._client.head_object(Bucket=self.bucket, Key=key)["ContentLength"]
+
+ def list_all(self, batch_size=1000):
+ """\
+ List all subkeys
+ :returns: Iterator[OSSPath]
+ """
+ if not self.is_dir():
+ return
+
+ if batch_size > 1000:
+ print(
+ "At most 1000 keys can be operated at once. Clipping batch_size to 1000."
+ )
+ batch_size = 1000
+
+ prefix = self.key
+ if prefix[-1] != "/":
+ prefix = prefix + "/"
+
+ marker = None
+ while True:
+ request = dict(
+ Bucket=self.bucket, Delimiter="", Prefix=prefix, MaxKeys=batch_size,
+ )
+ if marker:
+ request["Marker"] = marker
+
+ resp = self._client.list_objects(**request)
+
+ for p in resp.get("Contents", []):
+ yield self.root / p["Key"]
+
+ if not resp["IsTruncated"]:
+ break
+
+ print(
+ "More than {} objects are found under {}, you should avoid putting too many small objects!".format(
+ batch_size, self
+ )
+ )
+ marker = resp["NextMarker"]
+
+ def walk(self, topdown=True, recursive=True, batch_size=1000):
+ """\
+ Generate path tree by walking either top-down or bottom-up just like :func:`os.walk`.
+ For each prefix in the tree, it yields a 3-tuple (subtree-root, subdirs, subfiles).
+
+ If optional argument *topdown* is True or not specified, the triple for a directory
+ is generated before the triples for any subdirectories. If *topdown* is False,
+ the triple for a directory is generated after its subdirectries.
+
+ If *recurisve* is set to False, it only yields the top level subdirectries and subfiles.
+
+ *batch_size* is the maximum keys that OSS returns in one request-response,
+ and it cannot be set larger than 1000.
+ """
+ if not self.is_dir():
+ return
+
+ if batch_size > 1000:
+ print(
+ "At most 1000 keys can be operated at once. Clipping batch_size to 1000."
+ )
+ batch_size = 1000
+
+ prefix = self.key
+ if prefix[-1] != "/":
+ prefix = prefix + "/"
+
+ dirs, files = [], []
+ marker = None
+ while True:
+ request = dict(
+ Bucket=self.bucket, Delimiter="/", Prefix=prefix, MaxKeys=batch_size,
+ )
+ if marker:
+ request["Marker"] = marker
+
+ resp = self._client.list_objects(**request)
+
+ dirs += [self.root / p["Prefix"] for p in resp.get("CommonPrefixes", [])]
+
+ files += [self.root / p["Key"] for p in resp.get("Contents", [])]
+
+ if not resp["IsTruncated"]:
+ break
+
+ print(
+ "More than {} objects are found under {}, you should avoid putting too many small objects!".format(
+ batch_size, self
+ )
+ )
+ marker = resp["NextMarker"]
+
+ if topdown:
+ yield self, dirs, files
+
+ if recursive:
+ for subdir in dirs:
+ yield from subdir.walk(
+ recursive=True, topdown=topdown, batch_size=batch_size
+ )
+
+ if not topdown:
+ yield self, dirs, files
+
+ def iterdir(self, batch_size=1000):
+ """
+ Iterates over self directory, yields subdirs and subfiles.
+ :returns: Iterator[OSSPath]
+ """
+ for root, dirs, files in self.walk(batch_size=batch_size, recursive=False):
+ yield from dirs
+ yield from files
+
+ def download(self, encoding=None) -> Optional[io.IOBase]:
+ """
+ :param encoding: if None, it returns bytes io;
+ if an encoding (such as 'utf-8') is specified, it returns text io
+
+ :returns: file-like object which can be read out
+ """
+
+ if not self.is_file():
+ raise FileNotFoundError("{!r} is not an existing object.".format(self))
+
+ r = self._client.get_object(Bucket=self.bucket, Key=self.key)
+ b = r["Body"]
+ if encoding is not None:
+ b = codecs.getreader(encoding)(b)
+
+ return b
+
+ def put(self, bytes_or_file) -> bool:
+ """
+ :param bytes_or_file: bytes or file-like object to be uploaded to OSS
+ :returns: wheter successfully uploaded
+ """
+ if not self.bucket or not self.key:
+ raise ValueError("Invalid path to put object: {!r}".format(self))
+ if self.key.endswith("/"):
+ raise ValueError('Object key cannot endswith "/": {}'.format(self.key))
+
+ r = self._client.put_object(
+ Body=bytes_or_file, Bucket=self.bucket, Key=self.key,
+ )
+ return r["ResponseMetadata"]["HTTPStatusCode"] == 200
+
+ def delete(self) -> bool:
+ """
+ :returns: whether this object is deleted
+ """
+ if not self.is_file():
+ return True
+ r = self._client.delete_object(Bucket=self.bucket, Key=self.key)
+
+ return r["ResponseMetadata"]["HTTPStatusCode"] == 204
+
+ def rmtree(self, batch_size=1000) -> List[str]:
+ """
+ :returns: list of deleted objects
+ """
+ if not self.is_dir():
+ if self.is_file():
+ raise ValueError("{!r} is not a directory".format(self))
+ return True
+
+ if batch_size > 1000:
+ print(
+ "At most 1000 keys can be operated at once. Clipping batch_size to 1000."
+ )
+ batch_size = 1000
+
+ prefix = self.key
+ if prefix[-1] != "/":
+ prefix = prefix + "/"
+
+ ret = []
+ while True:
+ lr = self._client.list_objects(
+ Bucket=self.bucket, Delimiter="", Prefix=prefix, MaxKeys=batch_size,
+ )
+
+ dr = self._client.delete_objects(
+ Bucket=self.bucket,
+ Delete={"Objects": [{"Key": i["Key"]} for i in lr.get("Contents", [])]},
+ )
+
+ for i in dr["Deleted"]:
+ ret.append("s3://{}/{}".format(self.bucket, i["Key"]))
+
+ if not lr["IsTruncated"]:
+ break
+
+ print(
+ "More than {} objects are found under {}, you should avoid putting too many small objects!".format(
+ batch_size, self
+ )
+ )
+
+ return ret
diff --git a/det3d/datasets/waymo/__init__.py b/det3d/datasets/waymo/__init__.py
new file mode 100644
index 0000000..f710797
--- /dev/null
+++ b/det3d/datasets/waymo/__init__.py
@@ -0,0 +1,4 @@
+from .waymo import WaymoDataset
+from .waymo_common import *
+
+__all__ = ["WaymoDataset"]
diff --git a/det3d/datasets/waymo/waymo.py b/det3d/datasets/waymo/waymo.py
new file mode 100644
index 0000000..f659dba
--- /dev/null
+++ b/det3d/datasets/waymo/waymo.py
@@ -0,0 +1,105 @@
+import sys
+import pickle
+import json
+import random
+import operator
+from numba.cuda.simulator.api import detect
+import numpy as np
+
+from functools import reduce
+from pathlib import Path
+from copy import deepcopy
+
+from det3d.datasets.custom import PointCloudDataset
+
+from det3d.datasets.registry import DATASETS
+
+
+@DATASETS.register_module
+class WaymoDataset(PointCloudDataset):
+ NumPointFeatures = 5 # x, y, z, intensity, elongation
+
+ def __init__(
+ self,
+ info_path,
+ root_path,
+ cfg=None,
+ pipeline=None,
+ class_names=None,
+ test_mode=False,
+ sample=False,
+ nsweeps=1,
+ load_interval=1,
+ **kwargs,
+ ):
+ self.load_interval = load_interval
+ self.sample = sample
+ self.nsweeps = nsweeps
+ print("Using {} sweeps".format(nsweeps))
+ super(WaymoDataset, self).__init__(
+ root_path, info_path, pipeline, test_mode=test_mode, class_names=class_names
+ )
+
+ self._info_path = info_path
+ self._class_names = class_names
+ self._num_point_features = WaymoDataset.NumPointFeatures if nsweeps == 1 else WaymoDataset.NumPointFeatures+1
+
+ def reset(self):
+ assert False
+
+ def load_infos(self, info_path):
+
+ with open(self._info_path, "rb") as f:
+ _waymo_infos_all = pickle.load(f)
+
+ self._waymo_infos = _waymo_infos_all[::self.load_interval]
+
+ print("Using {} Frames".format(len(self._waymo_infos)))
+
+ def __len__(self):
+
+ if not hasattr(self, "_waymo_infos"):
+ self.load_infos(self._info_path)
+
+ return len(self._waymo_infos)
+
+ def get_sensor_data(self, idx):
+ info = self._waymo_infos[idx]
+
+ res = {
+ "lidar": {
+ "type": "lidar",
+ "points": None,
+ "annotations": None,
+ "nsweeps": self.nsweeps,
+ },
+ "metadata": {
+ "image_prefix": self._root_path,
+ "num_point_features": self._num_point_features,
+ "token": info["token"],
+ },
+ "calib": None,
+ "cam": {},
+ "mode": "val" if self.test_mode else "train",
+ "type": "WaymoDataset",
+ }
+
+ data, _ = self.pipeline(res, info)
+
+ return data
+
+ def __getitem__(self, idx):
+ return self.get_sensor_data(idx)
+
+ def evaluation(self, detections, output_dir=None, testset=False):
+ from .waymo_common import _create_pd_detection, reorganize_info
+
+ infos = self._waymo_infos
+ infos = reorganize_info(infos)
+
+ _create_pd_detection(detections, infos, output_dir)
+
+ print("use waymo devkit tool for evaluation")
+
+ return None, None
+
diff --git a/det3d/datasets/waymo/waymo_common.py b/det3d/datasets/waymo/waymo_common.py
new file mode 100644
index 0000000..13bfdd5
--- /dev/null
+++ b/det3d/datasets/waymo/waymo_common.py
@@ -0,0 +1,355 @@
+import os.path as osp
+import numpy as np
+import pickle
+import random
+
+from pathlib import Path
+from functools import reduce
+from typing import Tuple, List
+import os
+import json
+from tqdm import tqdm
+import argparse
+
+from tqdm import tqdm
+try:
+ import tensorflow as tf
+ tf.enable_eager_execution()
+except:
+ print("No Tensorflow")
+
+from nuscenes.utils.geometry_utils import transform_matrix
+from pyquaternion import Quaternion
+
+
+CAT_NAME_TO_ID = {
+ 'VEHICLE': 1,
+ 'PEDESTRIAN': 2,
+ 'SIGN': 3,
+ 'CYCLIST': 4,
+}
+TYPE_LIST = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST']
+
+def get_obj(path):
+ with open(path, 'rb') as f:
+ obj = pickle.load(f)
+ return obj
+
+# ignore sign class
+LABEL_TO_TYPE = {0: 1, 1:2, 2:4}
+
+import uuid
+
+class UUIDGeneration():
+ def __init__(self):
+ self.mapping = {}
+ def get_uuid(self,seed):
+ if seed not in self.mapping:
+ self.mapping[seed] = uuid.uuid4().hex
+ return self.mapping[seed]
+uuid_gen = UUIDGeneration()
+
+def _create_pd_detection(detections, infos, result_path, tracking=False):
+ """Creates a prediction objects file."""
+ from waymo_open_dataset import label_pb2
+ from waymo_open_dataset.protos import metrics_pb2
+
+ objects = metrics_pb2.Objects()
+
+ for token, detection in tqdm(detections.items()):
+ info = infos[token]
+ obj = get_obj(info['anno_path'])
+
+ box3d = detection["box3d_lidar"].detach().cpu().numpy()
+ scores = detection["scores"].detach().cpu().numpy()
+ labels = detection["label_preds"].detach().cpu().numpy()
+
+ # transform back to Waymo coordinate
+ # x,y,z,w,l,h,r2
+ # x,y,z,l,w,h,r1
+ # r2 = -pi/2 - r1
+ box3d[:, -1] = -box3d[:, -1] - np.pi / 2
+ box3d = box3d[:, [0, 1, 2, 4, 3, 5, -1]]
+
+ if tracking:
+ tracking_ids = detection['tracking_ids']
+
+ for i in range(box3d.shape[0]):
+ det = box3d[i]
+ score = scores[i]
+
+ label = labels[i]
+
+ o = metrics_pb2.Object()
+ o.context_name = obj['scene_name']
+ o.frame_timestamp_micros = int(obj['frame_name'].split("_")[-1])
+
+ # Populating box and score.
+ box = label_pb2.Label.Box()
+ box.center_x = det[0]
+ box.center_y = det[1]
+ box.center_z = det[2]
+ box.length = det[3]
+ box.width = det[4]
+ box.height = det[5]
+ box.heading = det[-1]
+ o.object.box.CopyFrom(box)
+ o.score = score
+ # Use correct type.
+ o.object.type = LABEL_TO_TYPE[label]
+
+ if tracking:
+ o.object.id = uuid_gen.get_uuid(int(tracking_ids[i]))
+
+ objects.objects.append(o)
+
+ # Write objects to a file.
+ if tracking:
+ path = os.path.join(result_path, 'tracking_pred.bin')
+ else:
+ path = os.path.join(result_path, 'detection_pred.bin')
+
+ print("results saved to {}".format(path))
+ f = open(path, 'wb')
+ f.write(objects.SerializeToString())
+ f.close()
+
+def _create_gt_detection(infos, tracking=True):
+ """Creates a gt prediction object file for local evaluation."""
+ from waymo_open_dataset import label_pb2
+ from waymo_open_dataset.protos import metrics_pb2
+
+ objects = metrics_pb2.Objects()
+
+ for idx in tqdm(range(len(infos))):
+ info = infos[idx]
+
+ obj = get_obj(info['anno_path'])
+ annos = obj['objects']
+ num_points_in_gt = np.array([ann['num_points'] for ann in annos])
+ box3d = np.array([ann['box'] for ann in annos])
+
+ if len(box3d) == 0:
+ continue
+
+ names = np.array([TYPE_LIST[ann['label']] for ann in annos])
+
+ box3d = box3d[:, [0, 1, 2, 3, 4, 5, -1]]
+
+ for i in range(box3d.shape[0]):
+ if num_points_in_gt[i] == 0:
+ continue
+ if names[i] == 'UNKNOWN':
+ continue
+
+ det = box3d[i]
+ score = 1.0
+ label = names[i]
+
+ o = metrics_pb2.Object()
+ o.context_name = obj['scene_name']
+ o.frame_timestamp_micros = int(obj['frame_name'].split("_")[-1])
+
+ # Populating box and score.
+ box = label_pb2.Label.Box()
+ box.center_x = det[0]
+ box.center_y = det[1]
+ box.center_z = det[2]
+ box.length = det[3]
+ box.width = det[4]
+ box.height = det[5]
+ box.heading = det[-1]
+ o.object.box.CopyFrom(box)
+ o.score = score
+ # Use correct type.
+ o.object.type = CAT_NAME_TO_ID[label]
+ o.object.num_lidar_points_in_box = num_points_in_gt[i]
+ o.object.id = annos[i]['name']
+
+ objects.objects.append(o)
+
+ # Write objects to a file.
+ f = open(os.path.join(args.result_path, 'gt_preds.bin'), 'wb')
+ f.write(objects.SerializeToString())
+ f.close()
+
+def veh_pos_to_transform(veh_pos):
+ "convert vehicle pose to two transformation matrix"
+ rotation = veh_pos[:3, :3]
+ tran = veh_pos[:3, 3]
+
+ global_from_car = transform_matrix(
+ tran, Quaternion(matrix=rotation), inverse=False
+ )
+
+ car_from_global = transform_matrix(
+ tran, Quaternion(matrix=rotation), inverse=True
+ )
+
+ return global_from_car, car_from_global
+
+def _fill_infos(root_path, frames, split='train', nsweeps=1):
+ # load all train infos
+ infos = []
+ for frame_name in tqdm(frames): # global id
+ lidar_path = os.path.join(root_path, split, 'lidar', frame_name)
+ ref_path = os.path.join(root_path, split, 'annos', frame_name)
+
+ ref_obj = get_obj(ref_path)
+ ref_time = 1e-6 * int(ref_obj['frame_name'].split("_")[-1])
+
+ ref_pose = np.reshape(ref_obj['veh_to_global'], [4, 4])
+ _, ref_from_global = veh_pos_to_transform(ref_pose)
+
+ info = {
+ "path": lidar_path,
+ "anno_path": ref_path,
+ "token": frame_name,
+ "timestamp": ref_time,
+ "sweeps": []
+ }
+
+ sequence_id = int(frame_name.split("_")[1])
+ frame_id = int(frame_name.split("_")[3][:-4]) # remove .pkl
+
+ prev_id = frame_id
+ sweeps = []
+ while len(sweeps) < nsweeps - 1:
+ if prev_id <= 0:
+ if len(sweeps) == 0:
+ sweep = {
+ "path": lidar_path,
+ "token": frame_name,
+ "transform_matrix": None,
+ "time_lag": 0
+ }
+ sweeps.append(sweep)
+ else:
+ sweeps.append(sweeps[-1])
+ else:
+ prev_id = prev_id - 1
+ # global identifier
+
+ curr_name = 'seq_{}_frame_{}.pkl'.format(sequence_id, prev_id)
+ curr_lidar_path = os.path.join(root_path, split, 'lidar', curr_name)
+ curr_label_path = os.path.join(root_path, split, 'annos', curr_name)
+
+ curr_obj = get_obj(curr_label_path)
+ curr_pose = np.reshape(curr_obj['veh_to_global'], [4, 4])
+ global_from_car, _ = veh_pos_to_transform(curr_pose)
+
+ tm = reduce(
+ np.dot,
+ [ref_from_global, global_from_car],
+ )
+
+ curr_time = int(curr_obj['frame_name'].split("_")[-1])
+ time_lag = ref_time - 1e-6 * curr_time
+
+ sweep = {
+ "path": curr_lidar_path,
+ "transform_matrix": tm,
+ "time_lag": time_lag,
+ }
+ sweeps.append(sweep)
+
+ info["sweeps"] = sweeps
+
+ if split != 'test':
+ # read boxes
+ TYPE_LIST = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST']
+ annos = ref_obj['objects']
+ num_points_in_gt = np.array([ann['num_points'] for ann in annos])
+ gt_boxes = np.array([ann['box'] for ann in annos]).reshape(-1, 9)
+
+ if len(gt_boxes) != 0:
+ # transform from Waymo to KITTI coordinate
+ # Waymo: x, y, z, length, width, height, rotation from positive x axis clockwisely
+ # KITTI: x, y, z, width, length, height, rotation from negative y axis counterclockwisely
+ gt_boxes[:, -1] = -np.pi / 2 - gt_boxes[:, -1]
+ gt_boxes[:, [3, 4]] = gt_boxes[:, [4, 3]]
+
+ gt_names = np.array([TYPE_LIST[ann['label']] for ann in annos])
+ mask_not_zero = (num_points_in_gt > 0).reshape(-1)
+
+ # filter boxes without lidar points
+ info['gt_boxes'] = gt_boxes[mask_not_zero, :].astype(np.float32)
+ info['gt_names'] = gt_names[mask_not_zero].astype(str)
+
+ infos.append(info)
+ return infos
+
+def sort_frame(frames):
+ indices = []
+
+ for f in frames:
+ seq_id = int(f.split("_")[1])
+ frame_id= int(f.split("_")[3][:-4])
+
+ idx = seq_id * 1000 + frame_id
+ indices.append(idx)
+
+ rank = list(np.argsort(np.array(indices)))
+
+ frames = [frames[r] for r in rank]
+ return frames
+
+def get_available_frames(root, split):
+ dir_path = os.path.join(root, split, 'lidar')
+ available_frames = list(os.listdir(dir_path))
+
+ sorted_frames = sort_frame(available_frames)
+
+ print(split, " split ", "exist frame num:", len(available_frames))
+ return sorted_frames
+
+
+def create_waymo_infos(root_path, split='train', nsweeps=1):
+ frames = get_available_frames(root_path, split)
+
+ waymo_infos = _fill_infos(
+ root_path, frames, split, nsweeps
+ )
+
+ print(
+ f"sample: {len(waymo_infos)}"
+ )
+ with open(
+ os.path.join(root_path, "infos_"+split+"_{:02d}sweeps_filter_zero_gt.pkl".format(nsweeps)), "wb"
+ ) as f:
+ pickle.dump(waymo_infos, f)
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Waymo 3D Extractor")
+ parser.add_argument("--path", type=str, default="data/Waymo/tfrecord_training")
+ parser.add_argument("--info_path", type=str)
+ parser.add_argument("--result_path", type=str)
+ parser.add_argument("--gt", action='store_true' )
+ parser.add_argument("--tracking", action='store_true')
+ args = parser.parse_args()
+ return args
+
+
+def reorganize_info(infos):
+ new_info = {}
+
+ for info in infos:
+ token = info['token']
+ new_info[token] = info
+
+ return new_info
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ with open(args.info_path, 'rb') as f:
+ infos = pickle.load(f)
+
+ if args.gt:
+ _create_gt_detection(infos, tracking=args.tracking)
+ exit()
+
+ infos = reorganize_info(infos)
+ with open(args.path, 'rb') as f:
+ preds = pickle.load(f)
+ _create_pd_detection(preds, infos, args.result_path, tracking=args.tracking)
diff --git a/det3d/datasets/waymo/waymo_converter.py b/det3d/datasets/waymo/waymo_converter.py
new file mode 100644
index 0000000..c239384
--- /dev/null
+++ b/det3d/datasets/waymo/waymo_converter.py
@@ -0,0 +1,71 @@
+"""Tool to convert Waymo Open Dataset to pickle files.
+ Adapted from https://github.com/WangYueFt/pillar-od
+ # Copyright (c) Massachusetts Institute of Technology and its affiliates.
+ # Licensed under MIT License
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob, argparse, tqdm, pickle, os
+
+import waymo_decoder
+import tensorflow.compat.v2 as tf
+from waymo_open_dataset import dataset_pb2
+
+from multiprocessing import Pool
+
+tf.enable_v2_behavior()
+
+fnames = None
+LIDAR_PATH = None
+ANNO_PATH = None
+
+def convert(idx):
+ global fnames
+ fname = fnames[idx]
+ dataset = tf.data.TFRecordDataset(fname, compression_type='')
+ for frame_id, data in enumerate(dataset):
+ frame = dataset_pb2.Frame()
+ frame.ParseFromString(bytearray(data.numpy()))
+ decoded_frame = waymo_decoder.decode_frame(frame, frame_id)
+ decoded_annos = waymo_decoder.decode_annos(frame, frame_id)
+
+ with open(os.path.join(LIDAR_PATH, 'seq_{}_frame_{}.pkl'.format(idx, frame_id)), 'wb') as f:
+ pickle.dump(decoded_frame, f)
+
+ with open(os.path.join(ANNO_PATH, 'seq_{}_frame_{}.pkl'.format(idx, frame_id)), 'wb') as f:
+ pickle.dump(decoded_annos, f)
+
+
+def main(args):
+ global fnames
+ fnames = sorted(list(glob.glob(args.record_path)))
+
+ print("Number of files {}".format(len(fnames)))
+
+ with Pool(64) as p: # change according to your cpu
+ r = list(tqdm.tqdm(p.imap(convert, range(len(fnames))), total=len(fnames)))
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description='Waymo Data Converter')
+ parser.add_argument('--root_path', type=str, required=True)
+ parser.add_argument('--record_path', type=str, required=True)
+
+ args = parser.parse_args()
+
+ if not os.path.isdir(args.root_path):
+ os.mkdir(args.root_path)
+
+ LIDAR_PATH = os.path.join(args.root_path, 'lidar')
+ ANNO_PATH = os.path.join(args.root_path, 'annos')
+
+ if not os.path.isdir(LIDAR_PATH):
+ os.mkdir(LIDAR_PATH)
+
+ if not os.path.isdir(ANNO_PATH):
+ os.mkdir(ANNO_PATH)
+
+ main(args)
diff --git a/det3d/datasets/waymo/waymo_decoder.py b/det3d/datasets/waymo/waymo_decoder.py
new file mode 100644
index 0000000..3255546
--- /dev/null
+++ b/det3d/datasets/waymo/waymo_decoder.py
@@ -0,0 +1,207 @@
+"""Waymo open dataset decoder.
+ Taken from https://github.com/WangYueFt/pillar-od
+ # Copyright (c) Massachusetts Institute of Technology and its affiliates.
+ # Licensed under MIT License
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import zlib
+import numpy as np
+
+import tensorflow.compat.v2 as tf
+from pyquaternion import Quaternion
+
+from waymo_open_dataset import dataset_pb2
+from waymo_open_dataset.utils import range_image_utils
+from waymo_open_dataset.utils import transform_utils
+tf.enable_v2_behavior()
+
+def decode_frame(frame, frame_id):
+ """Decodes native waymo Frame proto to tf.Examples."""
+
+ lidars = extract_points(frame.lasers,
+ frame.context.laser_calibrations,
+ frame.pose)
+
+ frame_name = '{scene_name}_{location}_{time_of_day}_{timestamp}'.format(
+ scene_name=frame.context.name,
+ location=frame.context.stats.location,
+ time_of_day=frame.context.stats.time_of_day,
+ timestamp=frame.timestamp_micros)
+
+ example_data = {
+ 'scene_name': frame.context.name,
+ 'frame_name': frame_name,
+ 'frame_id': frame_id,
+ 'lidars': lidars,
+ }
+
+ return example_data
+ # return encode_tf_example(example_data, FEATURE_SPEC)
+
+def decode_annos(frame, frame_id):
+ """Decodes some meta data (e.g. calibration matrices, frame matrices)."""
+
+ veh_to_global = np.array(frame.pose.transform)
+
+ ref_pose = np.reshape(np.array(frame.pose.transform), [4, 4])
+ global_from_ref_rotation = ref_pose[:3, :3]
+ objects = extract_objects(frame.laser_labels, global_from_ref_rotation)
+
+ frame_name = '{scene_name}_{location}_{time_of_day}_{timestamp}'.format(
+ scene_name=frame.context.name,
+ location=frame.context.stats.location,
+ time_of_day=frame.context.stats.time_of_day,
+ timestamp=frame.timestamp_micros)
+
+ annos = {
+ 'scene_name': frame.context.name,
+ 'frame_name': frame_name,
+ 'frame_id': frame_id,
+ 'veh_to_global': veh_to_global,
+ 'objects': objects,
+ }
+
+ return annos
+
+
+def extract_points_from_range_image(laser, calibration, frame_pose):
+ """Decode points from lidar."""
+ if laser.name != calibration.name:
+ raise ValueError('Laser and calibration do not match')
+ if laser.name == dataset_pb2.LaserName.TOP:
+ frame_pose = tf.convert_to_tensor(
+ np.reshape(np.array(frame_pose.transform), [4, 4]))
+ range_image_top_pose = dataset_pb2.MatrixFloat.FromString(
+ zlib.decompress(laser.ri_return1.range_image_pose_compressed))
+ # [H, W, 6]
+ range_image_top_pose_tensor = tf.reshape(
+ tf.convert_to_tensor(range_image_top_pose.data),
+ range_image_top_pose.shape.dims)
+ # [H, W, 3, 3]
+ range_image_top_pose_tensor_rotation = transform_utils.get_rotation_matrix(
+ range_image_top_pose_tensor[..., 0],
+ range_image_top_pose_tensor[..., 1], range_image_top_pose_tensor[...,
+ 2])
+ range_image_top_pose_tensor_translation = range_image_top_pose_tensor[...,
+ 3:]
+ range_image_top_pose_tensor = transform_utils.get_transform(
+ range_image_top_pose_tensor_rotation,
+ range_image_top_pose_tensor_translation)
+ frame_pose = tf.expand_dims(frame_pose, axis=0)
+ pixel_pose = tf.expand_dims(range_image_top_pose_tensor, axis=0)
+ else:
+ pixel_pose = None
+ frame_pose = None
+ first_return = zlib.decompress(
+ laser.ri_return1.range_image_compressed)
+ second_return = zlib.decompress(
+ laser.ri_return2.range_image_compressed)
+ points_list = []
+ for range_image_str in [first_return, second_return]:
+ range_image = dataset_pb2.MatrixFloat.FromString(range_image_str)
+ if not calibration.beam_inclinations:
+ beam_inclinations = range_image_utils.compute_inclination(
+ tf.constant([
+ calibration.beam_inclination_min, calibration.beam_inclination_max
+ ]),
+ height=range_image.shape.dims[0])
+ else:
+ beam_inclinations = tf.constant(calibration.beam_inclinations)
+ beam_inclinations = tf.reverse(beam_inclinations, axis=[-1])
+ extrinsic = np.reshape(np.array(calibration.extrinsic.transform), [4, 4])
+ range_image_tensor = tf.reshape(
+ tf.convert_to_tensor(range_image.data), range_image.shape.dims)
+ range_image_mask = range_image_tensor[..., 0] > 0
+ range_image_cartesian = (
+ range_image_utils.extract_point_cloud_from_range_image(
+ tf.expand_dims(range_image_tensor[..., 0], axis=0),
+ tf.expand_dims(extrinsic, axis=0),
+ tf.expand_dims(tf.convert_to_tensor(beam_inclinations), axis=0),
+ pixel_pose=pixel_pose,
+ frame_pose=frame_pose))
+ range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0)
+ points_tensor = tf.gather_nd(
+ tf.concat([range_image_cartesian, range_image_tensor[..., 1:4]],
+ axis=-1),
+ tf.where(range_image_mask))
+ points_list.append(points_tensor.numpy())
+ return points_list
+
+
+def extract_points(lasers, laser_calibrations, frame_pose):
+ """Extract point clouds."""
+ sort_lambda = lambda x: x.name
+ lasers_with_calibration = zip(
+ sorted(lasers, key=sort_lambda),
+ sorted(laser_calibrations, key=sort_lambda))
+ points_xyz = []
+ points_feature = []
+ points_nlz = []
+ for laser, calibration in lasers_with_calibration:
+ points_list = extract_points_from_range_image(laser, calibration,
+ frame_pose)
+ points = np.concatenate(points_list, axis=0)
+ points_xyz.extend(points[..., :3].astype(np.float32))
+ points_feature.extend(points[..., 3:5].astype(np.float32))
+ points_nlz.extend(points[..., 5].astype(np.float32))
+ return {
+ 'points_xyz': np.asarray(points_xyz),
+ 'points_feature': np.asarray(points_feature),
+ }
+
+def global_vel_to_ref(vel, global_from_ref_rotation):
+ # inverse means ref_from_global, rotation_matrix for normalization
+ vel = [vel[0], vel[1], 0]
+ ref = np.dot(Quaternion(matrix=global_from_ref_rotation).inverse.rotation_matrix, vel)
+ ref = [ref[0], ref[1], 0.0]
+
+ return ref
+
+def extract_objects(laser_labels, global_from_ref_rotation):
+ """Extract objects."""
+ objects = []
+ for object_id, label in enumerate(laser_labels):
+ category_label = label.type
+ box = label.box
+
+ speed = [label.metadata.speed_x, label.metadata.speed_y]
+ accel = [label.metadata.accel_x, label.metadata.accel_y]
+ num_lidar_points_in_box = label.num_lidar_points_in_box
+ # Difficulty level is 0 if labeler did not say this was LEVEL_2.
+ # Set difficulty level of "999" for boxes with no points in box.
+ if num_lidar_points_in_box <= 0:
+ combined_difficulty_level = 999
+ if label.detection_difficulty_level == 0:
+ # Use points in box to compute difficulty level.
+ if num_lidar_points_in_box >= 5:
+ combined_difficulty_level = 1
+ else:
+ combined_difficulty_level = 2
+ else:
+ combined_difficulty_level = label.detection_difficulty_level
+
+ ref_velocity = global_vel_to_ref(speed, global_from_ref_rotation)
+
+ objects.append({
+ 'id': object_id,
+ 'name': label.id,
+ 'label': category_label,
+ 'box': np.array([box.center_x, box.center_y, box.center_z,
+ box.length, box.width, box.height, ref_velocity[0],
+ ref_velocity[1], box.heading], dtype=np.float32),
+ 'num_points':
+ num_lidar_points_in_box,
+ 'detection_difficulty_level':
+ label.detection_difficulty_level,
+ 'combined_difficulty_level':
+ combined_difficulty_level,
+ 'global_speed':
+ np.array(speed, dtype=np.float32),
+ 'global_accel':
+ np.array(accel, dtype=np.float32),
+ })
+ return objects
diff --git a/det3d/models/__init__.py b/det3d/models/__init__.py
new file mode 100644
index 0000000..d24d502
--- /dev/null
+++ b/det3d/models/__init__.py
@@ -0,0 +1,43 @@
+import importlib
+spconv_spec = importlib.util.find_spec("spconv")
+found = spconv_spec is not None
+if found:
+ from .backbones import * # noqa: F401,F403
+else:
+ print("No spconv, sparse convolution disabled!")
+from .bbox_heads import * # noqa: F401,F403
+from .builder import (
+ build_backbone,
+ build_detector,
+ build_head,
+ build_loss,
+ build_neck,
+ build_roi_head
+)
+from .detectors import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .readers import *
+from .registry import (
+ BACKBONES,
+ DETECTORS,
+ HEADS,
+ LOSSES,
+ NECKS,
+ READERS,
+)
+from .second_stage import *
+from .roi_heads import *
+
+__all__ = [
+ "READERS",
+ "BACKBONES",
+ "NECKS",
+ "HEADS",
+ "LOSSES",
+ "DETECTORS",
+ "build_backbone",
+ "build_neck",
+ "build_head",
+ "build_loss",
+ "build_detector",
+]
diff --git a/det3d/models/backbones/__init__.py b/det3d/models/backbones/__init__.py
new file mode 100644
index 0000000..b50cbd9
--- /dev/null
+++ b/det3d/models/backbones/__init__.py
@@ -0,0 +1,9 @@
+import importlib
+spconv_spec = importlib.util.find_spec("spconv")
+found = spconv_spec is not None
+
+if found:
+ from .scn import SpMiddleResNetFHD
+else:
+ print("No spconv, sparse convolution disabled!")
+
diff --git a/det3d/models/backbones/scn.py b/det3d/models/backbones/scn.py
new file mode 100644
index 0000000..e5b0716
--- /dev/null
+++ b/det3d/models/backbones/scn.py
@@ -0,0 +1,258 @@
+import numpy as np
+import spconv
+from spconv import SparseConv3d, SubMConv3d
+from torch import nn
+from torch.nn import functional as F
+
+from ..registry import BACKBONES
+from ..utils import build_norm_layer
+
+
+def conv3x3(in_planes, out_planes, stride=1, indice_key=None, bias=True):
+ """3x3 convolution with padding"""
+ return spconv.SubMConv3d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=1,
+ bias=bias,
+ indice_key=indice_key,
+ )
+
+
+def conv1x1(in_planes, out_planes, stride=1, indice_key=None, bias=True):
+ """1x1 convolution"""
+ return spconv.SubMConv3d(
+ in_planes,
+ out_planes,
+ kernel_size=1,
+ stride=stride,
+ padding=1,
+ bias=bias,
+ indice_key=indice_key,
+ )
+
+
+class SparseBasicBlock(spconv.SparseModule):
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ norm_cfg=None,
+ downsample=None,
+ indice_key=None,
+ ):
+ super(SparseBasicBlock, self).__init__()
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
+
+ bias = norm_cfg is not None
+
+ self.conv1 = conv3x3(inplanes, planes, stride, indice_key=indice_key, bias=bias)
+ self.bn1 = build_norm_layer(norm_cfg, planes)[1]
+ self.relu = nn.ReLU()
+ self.conv2 = conv3x3(planes, planes, indice_key=indice_key, bias=bias)
+ self.bn2 = build_norm_layer(norm_cfg, planes)[1]
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ identity = x
+
+ out = self.conv1(x)
+ out.features = self.bn1(out.features)
+ out.features = self.relu(out.features)
+
+ out = self.conv2(out)
+ out.features = self.bn2(out.features)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out.features += identity.features
+ out.features = self.relu(out.features)
+
+ return out
+
+
+@BACKBONES.register_module
+class SpMiddleResNetFHD(nn.Module):
+ def __init__(
+ self, num_input_features=128, norm_cfg=None, name="SpMiddleResNetFHD", init_channel=16, **kwargs
+ ):
+ super(SpMiddleResNetFHD, self).__init__()
+ self.name = name
+
+ self.dcn = None
+ self.zero_init_residual = False
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
+
+ # input: # [1600, 1200, 41]
+ self.conv_input = spconv.SparseSequential(
+ SubMConv3d(num_input_features, init_channel, 3, bias=False, indice_key="res0"),
+ build_norm_layer(norm_cfg, init_channel)[1],
+ nn.ReLU(inplace=True)
+ )
+
+ self.conv1 = spconv.SparseSequential(
+ SparseBasicBlock(init_channel, init_channel, norm_cfg=norm_cfg, indice_key="res0"),
+ SparseBasicBlock(init_channel, init_channel, norm_cfg=norm_cfg, indice_key="res0"),
+ )
+
+ self.conv2 = spconv.SparseSequential(
+ SparseConv3d(
+ init_channel, 2*init_channel, 3, 2, padding=1, bias=False
+ ), # [1600, 1200, 41] -> [800, 600, 21]
+ build_norm_layer(norm_cfg, 2*init_channel)[1],
+ nn.ReLU(inplace=True),
+ SparseBasicBlock(2*init_channel, 2*init_channel, norm_cfg=norm_cfg, indice_key="res1"),
+ SparseBasicBlock(2*init_channel, 2*init_channel, norm_cfg=norm_cfg, indice_key="res1"),
+ )
+
+ self.conv3 = spconv.SparseSequential(
+ SparseConv3d(
+ 2*init_channel, 4*init_channel, 3, 2, padding=1, bias=False
+ ), # [800, 600, 21] -> [400, 300, 11]
+ build_norm_layer(norm_cfg, 4*init_channel)[1],
+ nn.ReLU(inplace=True),
+ SparseBasicBlock(4*init_channel, 4*init_channel, norm_cfg=norm_cfg, indice_key="res2"),
+ SparseBasicBlock(4*init_channel, 4*init_channel, norm_cfg=norm_cfg, indice_key="res2"),
+ )
+
+ self.conv4 = spconv.SparseSequential(
+ SparseConv3d(
+ 4*init_channel, 8*init_channel, 3, 2, padding=[0, 1, 1], bias=False
+ ), # [400, 300, 11] -> [200, 150, 5]
+ build_norm_layer(norm_cfg, 8*init_channel)[1],
+ nn.ReLU(inplace=True),
+ SparseBasicBlock(8*init_channel, 8*init_channel, norm_cfg=norm_cfg, indice_key="res3"),
+ SparseBasicBlock(8*init_channel, 8*init_channel, norm_cfg=norm_cfg, indice_key="res3"),
+ )
+
+
+ self.extra_conv = spconv.SparseSequential(
+ SparseConv3d(
+ 8*init_channel, 8*init_channel, (3, 1, 1), (2, 1, 1), bias=False
+ ), # [200, 150, 5] -> [200, 150, 2]
+ build_norm_layer(norm_cfg, 8*init_channel)[1],
+ nn.ReLU(),
+ )
+
+ def forward(self, voxel_features, coors, batch_size, input_shape):
+
+ # input: # [41, 1600, 1408]
+ sparse_shape = np.array(input_shape[::-1]) + [1, 0, 0]
+
+ coors = coors.int()
+ ret = spconv.SparseConvTensor(voxel_features, coors, sparse_shape, batch_size)
+
+ x = self.conv_input(ret)
+
+ x_conv1 = self.conv1(x)
+ x_conv2 = self.conv2(x_conv1)
+ x_conv3 = self.conv3(x_conv2)
+ x_conv4 = self.conv4(x_conv3)
+
+ ret = self.extra_conv(x_conv4)
+
+ ret = ret.dense()
+
+ N, C, D, H, W = ret.shape
+ ret = ret.view(N, C * D, H, W)
+
+ multi_scale_voxel_features = {
+ 'conv1': x_conv1,
+ 'conv2': x_conv2,
+ 'conv3': x_conv3,
+ 'conv4': x_conv4,
+ }
+
+ return ret, multi_scale_voxel_features
+
+@BACKBONES.register_module
+class SpMiddleFHD(nn.Module):
+ def __init__(
+ self, num_input_features=128, norm_cfg=None, name="SpMiddleFHD", **kwargs
+ ):
+ super(SpMiddleFHD, self).__init__()
+ self.name = name
+
+ self.dcn = None
+ self.zero_init_residual = False
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
+
+ self.middle_conv = spconv.SparseSequential(
+ SubMConv3d(num_input_features, 16, 3, bias=False, indice_key="subm0"),
+ build_norm_layer(norm_cfg, 16)[1],
+ nn.ReLU(),
+ SubMConv3d(16, 16, 3, bias=False, indice_key="subm0"),
+ build_norm_layer(norm_cfg, 16)[1],
+ nn.ReLU(),
+ SparseConv3d(
+ 16, 32, 3, 2, padding=1, bias=False
+ ), # [1600, 1200, 41] -> [800, 600, 21]
+ build_norm_layer(norm_cfg, 32)[1],
+ nn.ReLU(),
+ SubMConv3d(32, 32, 3, indice_key="subm1", bias=False),
+ build_norm_layer(norm_cfg, 32)[1],
+ nn.ReLU(),
+ SubMConv3d(32, 32, 3, indice_key="subm1", bias=False),
+ build_norm_layer(norm_cfg, 32)[1],
+ nn.ReLU(),
+ SparseConv3d(
+ 32, 64, 3, 2, padding=1, bias=False
+ ), # [800, 600, 21] -> [400, 300, 11]
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm2", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm2", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm2", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SparseConv3d(
+ 64, 64, 3, 2, padding=[0, 1, 1], bias=False
+ ), # [400, 300, 11] -> [200, 150, 5]
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm3", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm3", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SubMConv3d(64, 64, 3, indice_key="subm3", bias=False),
+ build_norm_layer(norm_cfg, 64)[1],
+ nn.ReLU(),
+ SparseConv3d(
+ 64, 128, (3, 1, 1), (2, 1, 1), bias=False
+ ), # [200, 150, 5] -> [200, 150, 2]
+ build_norm_layer(norm_cfg, 128)[1],
+ nn.ReLU(),
+ )
+
+ def forward(self, voxel_features, coors, batch_size, input_shape):
+ # input: # [41, 1600, 1408]
+ sparse_shape = np.array(input_shape[::-1]) + [1, 0, 0]
+ coors = coors.int()
+
+ ret = spconv.SparseConvTensor(voxel_features, coors, sparse_shape, batch_size)
+ ret = self.middle_conv(ret)
+ ret = ret.dense()
+
+ N, C, D, H, W = ret.shape
+ ret = ret.view(N, C * D, H, W)
+
+ return ret, None
\ No newline at end of file
diff --git a/det3d/models/bbox_heads/__init__.py b/det3d/models/bbox_heads/__init__.py
new file mode 100644
index 0000000..fc2fc32
--- /dev/null
+++ b/det3d/models/bbox_heads/__init__.py
@@ -0,0 +1,8 @@
+from .center_head import CenterHead
+from .center_head_iou import CenterHeadIoU
+from .center_head_iou_1d import CenterHeadIoU_1d
+
+__all__ = [
+ "CenterHead",
+ "CenterHeadIoU",
+ "CenterHeadIoU_1d",]
diff --git a/det3d/models/bbox_heads/center_head.py b/det3d/models/bbox_heads/center_head.py
new file mode 100644
index 0000000..e379c85
--- /dev/null
+++ b/det3d/models/bbox_heads/center_head.py
@@ -0,0 +1,544 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# det3d (https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07)
+# Copyright (c) 2019 朱本金
+# Licensed under the MIT License
+# ------------------------------------------------------------------------------
+
+import logging
+from collections import defaultdict
+from det3d.core import box_torch_ops
+import torch
+from det3d.torchie.cnn import kaiming_init
+from torch import double, nn
+from det3d.models.losses.centernet_loss import FastFocalLoss, RegLoss
+from det3d.models.utils import Sequential
+from ..registry import HEADS
+from ...ops.iou3d_nms.iou3d_nms_utils import boxes_iou3d_gpu
+import copy
+try:
+ from det3d.ops.dcn import DeformConv
+except:
+ print("Deformable Convolution not built!")
+
+from det3d.core.utils.circle_nms_jit import circle_nms
+
+class FeatureAdaption(nn.Module):
+ """Feature Adaption Module.
+
+ Feature Adaption Module is implemented based on DCN v1.
+ It uses anchor shape prediction rather than feature map to
+ predict offsets of deformable conv layer.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ out_channels (int): Number of channels in the output feature map.
+ kernel_size (int): Deformable conv kernel size.
+ deformable_groups (int): Deformable conv group size.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deformable_groups=4):
+ super(FeatureAdaption, self).__init__()
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ in_channels, deformable_groups * offset_channels, 1, bias=True)
+ self.conv_adaption = DeformConv(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deformable_groups=deformable_groups)
+ self.relu = nn.ReLU(inplace=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+
+ def forward(self, x,):
+ offset = self.conv_offset(x)
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
+
+class SepHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ heads,
+ head_conv=64,
+ final_kernel=1,
+ bn=False,
+ init_bias=-2.19,
+ **kwargs,
+ ):
+ super(SepHead, self).__init__(**kwargs)
+
+ self.heads = heads
+ for head in self.heads:
+ classes, num_conv = self.heads[head]
+
+ fc = Sequential()
+ for i in range(num_conv-1):
+ fc.add(nn.Conv2d(in_channels, head_conv,
+ kernel_size=final_kernel, stride=1,
+ padding=final_kernel // 2, bias=True))
+ if bn:
+ fc.add(nn.BatchNorm2d(head_conv))
+ fc.add(nn.ReLU())
+
+ fc.add(nn.Conv2d(head_conv, classes,
+ kernel_size=final_kernel, stride=1,
+ padding=final_kernel // 2, bias=True))
+
+ if 'hm' in head:
+ fc[-1].bias.data.fill_(init_bias)
+ else:
+ for m in fc.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+
+ self.__setattr__(head, fc)
+
+
+ def forward(self, x):
+ ret_dict = dict()
+ for head in self.heads:
+ ret_dict[head] = self.__getattr__(head)(x)
+
+ return ret_dict
+
+class DCNSepHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_cls,
+ heads,
+ head_conv=64,
+ final_kernel=1,
+ bn=False,
+ init_bias=-2.19,
+ **kwargs,
+ ):
+ super(DCNSepHead, self).__init__(**kwargs)
+
+ # feature adaptation with dcn
+ # use separate features for classification / regression
+ self.feature_adapt_cls = FeatureAdaption(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ deformable_groups=4)
+
+ self.feature_adapt_reg = FeatureAdaption(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ deformable_groups=4)
+
+ # heatmap prediction head
+ self.cls_head = Sequential(
+ nn.Conv2d(in_channels, head_conv,
+ kernel_size=3, padding=1, bias=True),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_conv, num_cls,
+ kernel_size=3, stride=1,
+ padding=1, bias=True)
+ )
+ self.cls_head[-1].bias.data.fill_(init_bias)
+
+ # other regression target
+ self.task_head = SepHead(in_channels, heads, head_conv=head_conv, bn=bn, final_kernel=final_kernel)
+
+
+ def forward(self, x):
+ center_feat = self.feature_adapt_cls(x)
+ reg_feat = self.feature_adapt_reg(x)
+
+ cls_score = self.cls_head(center_feat)
+ ret = self.task_head(reg_feat)
+ ret['hm'] = cls_score
+
+ return ret
+
+
+@HEADS.register_module
+class CenterHead(nn.Module):
+ def __init__(
+ self,
+ in_channels=[128,],
+ tasks=[],
+ dataset='nuscenes',
+ weight=0.25,
+ code_weights=[],
+ common_heads=dict(),
+ logger=None,
+ init_bias=-2.19,
+ share_conv_channel=64,
+ num_hm_conv=2,
+ dcn_head=False,
+ gt_nms=False,
+ ):
+ super(CenterHead, self).__init__()
+
+ num_classes = [len(t["class_names"]) for t in tasks]
+ self.class_names = [t["class_names"] for t in tasks]
+ self.code_weights = code_weights
+ self.weight = weight # weight between hm loss and loc loss
+ self.dataset = dataset
+
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+
+ self.crit = FastFocalLoss()
+ self.crit_reg = RegLoss()
+
+ self.box_n_dim = 9 if 'vel' in common_heads else 7
+ self.use_direction_classifier = False
+
+ self.gt_nms = gt_nms
+ if self.gt_nms:
+ print("Use gt nms!")
+
+ if not logger:
+ logger = logging.getLogger("CenterHead")
+ self.logger = logger
+
+ logger.info(
+ f"num_classes: {num_classes}"
+ )
+
+ # a shared convolution
+ self.shared_conv = nn.Sequential(
+ nn.Conv2d(in_channels, share_conv_channel,
+ kernel_size=3, padding=1, bias=True),
+ nn.BatchNorm2d(share_conv_channel),
+ nn.ReLU(inplace=True)
+ )
+
+ self.tasks = nn.ModuleList()
+ print("Use HM Bias: ", init_bias)
+
+ if dcn_head:
+ print("Use Deformable Convolution in the CenterHead!")
+
+ for num_cls in num_classes:
+ heads = copy.deepcopy(common_heads)
+ if not dcn_head:
+ heads.update(dict(hm=(num_cls, num_hm_conv)))
+ self.tasks.append(
+ SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
+ )
+ else:
+ self.tasks.append(
+ DCNSepHead(share_conv_channel, num_cls, heads, bn=True, init_bias=init_bias, final_kernel=3)
+ )
+
+ logger.info("Finish CenterHead Initialization")
+
+ def forward(self, x, *kwargs):
+ ret_dicts = []
+
+ x = self.shared_conv(x.float())
+
+ for task in self.tasks:
+ ret_dicts.append(task(x))
+
+ return ret_dicts
+
+ def _sigmoid(self, x):
+ y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
+ return y
+
+ def loss(self, example, preds_dicts, test_cfg, **kwargs):
+ rets = []
+ for task_id, preds_dict in enumerate(preds_dicts):
+ # heatmap focal loss
+ preds_dict['hm'] = self._sigmoid(preds_dict['hm'])
+
+ hm_loss = self.crit(preds_dict['hm'], example['hm'][task_id], example['ind'][task_id], example['mask'][task_id], example['cat'][task_id])
+
+ target_box = example['anno_box'][task_id]
+ # reconstruct the anno_box from multiple reg heads
+ if self.dataset in ['waymo', 'nuscenes']:
+ if 'vel' in preds_dict:
+ preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
+ preds_dict['vel'], preds_dict['rot']), dim=1)
+ else:
+ preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
+ preds_dict['rot']), dim=1)
+ target_box = target_box[..., [0, 1, 2, 3, 4, 5, -2, -1]] # remove vel target
+ else:
+ raise NotImplementedError()
+
+ ret = {}
+
+ # Regression loss for dimension, offset, height, rotation
+ box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'][task_id], example['ind'][task_id], target_box)
+
+ loc_loss = (box_loss*box_loss.new_tensor(self.code_weights)).sum()
+
+ loss = hm_loss + self.weight*loc_loss
+
+ ret.update({'loss': loss, 'hm_loss': hm_loss.detach().cpu(), 'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][task_id].float().sum()})
+
+ rets.append(ret)
+
+ """convert batch-key to key-batch
+ """
+ rets_merged = defaultdict(list)
+ for ret in rets:
+ for k, v in ret.items():
+ rets_merged[k].append(v)
+
+ return rets_merged
+
+ @torch.no_grad()
+ def predict(self, example, preds_dicts, test_cfg, **kwargs):
+ """decode, nms, then return the detection result. Additionaly support double flip testing
+ """
+ # get loss info
+ rets = []
+ metas = []
+
+ double_flip = test_cfg.get('double_flip', False)
+
+ post_center_range = test_cfg.post_center_limit_range
+ if len(post_center_range) > 0:
+ post_center_range = torch.tensor(
+ post_center_range,
+ dtype=preds_dicts[0]['hm'].dtype,
+ device=preds_dicts[0]['hm'].device,
+ )
+
+ for task_id, preds_dict in enumerate(preds_dicts):
+ # convert N C H W to N H W C
+ for key, val in preds_dict.items():
+ preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()
+
+ batch_size = preds_dict['hm'].shape[0]
+
+ if double_flip:
+ assert batch_size % 4 == 0, print(batch_size)
+ batch_size = int(batch_size / 4)
+ for k in preds_dict.keys():
+ # transform the prediction map back to their original coordinate befor flipping
+ # the flipped predictions are ordered in a group of 4. The first one is the original pointcloud
+ # the second one is X flip pointcloud(y=-y), the third one is Y flip pointcloud(x=-x), and the last one is
+ # X and Y flip pointcloud(x=-x, y=-y).
+ # Also please note that pytorch's flip function is defined on higher dimensional space, so dims=[2] means that
+ # it is flipping along the axis with H length(which is normaly the Y axis), however in our traditional word, it is flipping along
+ # the X axis. The below flip follows pytorch's definition yflip(y=-y) xflip(x=-x)
+ _, H, W, C = preds_dict[k].shape
+ preds_dict[k] = preds_dict[k].reshape(int(batch_size), 4, H, W, C)
+ preds_dict[k][:, 1] = torch.flip(preds_dict[k][:, 1], dims=[1])
+ preds_dict[k][:, 2] = torch.flip(preds_dict[k][:, 2], dims=[2])
+ preds_dict[k][:, 3] = torch.flip(preds_dict[k][:, 3], dims=[1, 2])
+
+ if "metadata" not in example or len(example["metadata"]) == 0:
+ meta_list = [None] * batch_size
+ else:
+ meta_list = example["metadata"]
+ if double_flip:
+ meta_list = meta_list[:4*int(batch_size):4]
+
+ batch_hm = torch.sigmoid(preds_dict['hm'])
+
+ batch_dim = torch.exp(preds_dict['dim'])
+
+ batch_rots = preds_dict['rot'][..., 0:1]
+ batch_rotc = preds_dict['rot'][..., 1:2]
+ batch_reg = preds_dict['reg']
+ batch_hei = preds_dict['height']
+
+ if double_flip:
+ batch_hm = batch_hm.mean(dim=1)
+ batch_hei = batch_hei.mean(dim=1)
+ batch_dim = batch_dim.mean(dim=1)
+
+ # y = -y reg_y = 1-reg_y
+ batch_reg[:, 1, ..., 1] = 1 - batch_reg[:, 1, ..., 1]
+ batch_reg[:, 2, ..., 0] = 1 - batch_reg[:, 2, ..., 0]
+
+ batch_reg[:, 3, ..., 0] = 1 - batch_reg[:, 3, ..., 0]
+ batch_reg[:, 3, ..., 1] = 1 - batch_reg[:, 3, ..., 1]
+ batch_reg = batch_reg.mean(dim=1)
+
+ # first yflip
+ # y = -y theta = pi -theta
+ # sin(pi-theta) = sin(theta) cos(pi-theta) = -cos(theta)
+ # batch_rots[:, 1] the same
+ batch_rotc[:, 1] *= -1
+
+ # then xflip x = -x theta = 2pi - theta
+ # sin(2pi - theta) = -sin(theta) cos(2pi - theta) = cos(theta)
+ # batch_rots[:, 2] the same
+ batch_rots[:, 2] *= -1
+
+ # double flip
+ batch_rots[:, 3] *= -1
+ batch_rotc[:, 3] *= -1
+
+ batch_rotc = batch_rotc.mean(dim=1)
+ batch_rots = batch_rots.mean(dim=1)
+
+ batch_rot = torch.atan2(batch_rots, batch_rotc)
+
+ batch, H, W, num_cls = batch_hm.size()
+
+ batch_reg = batch_reg.reshape(batch, H*W, 2)
+ batch_hei = batch_hei.reshape(batch, H*W, 1)
+
+ batch_rot = batch_rot.reshape(batch, H*W, 1)
+ batch_dim = batch_dim.reshape(batch, H*W, 3)
+ batch_hm = batch_hm.reshape(batch, H*W, num_cls)
+
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm)
+
+ xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
+ ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]
+
+ xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[0] + test_cfg.pc_range[0]
+ ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[1] + test_cfg.pc_range[1]
+
+ if 'vel' in preds_dict:
+ batch_vel = preds_dict['vel']
+
+ if double_flip:
+ # flip vy
+ batch_vel[:, 1, ..., 1] *= -1
+ # flip vx
+ batch_vel[:, 2, ..., 0] *= -1
+
+ batch_vel[:, 3] *= -1
+
+ batch_vel = batch_vel.mean(dim=1)
+
+ batch_vel = batch_vel.reshape(batch, H*W, 2)
+ batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
+ else:
+ batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)
+
+ metas.append(meta_list)
+
+ if test_cfg.get('per_class_nms', False):
+ pass
+ else:
+ rets.append(self.post_processing(example, batch_box_preds, batch_hm, test_cfg, post_center_range, task_id))
+
+ # Merge branches results
+ ret_list = []
+ num_samples = len(rets[0])
+
+ ret_list = []
+ for i in range(num_samples):
+ ret = {}
+ for k in rets[0][i].keys():
+ if k in ["box3d_lidar", "scores","selected_box_mask"]:
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+ elif k in ["label_preds"]:
+ flag = 0
+ for j, num_class in enumerate(self.num_classes):
+ rets[j][i][k] += flag
+ flag += num_class
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+
+ ret['metadata'] = metas[0][i]
+ ret_list.append(ret)
+
+ return ret_list
+
+ @torch.no_grad()
+ def post_processing(self, example, batch_box_preds, batch_hm, test_cfg, post_center_range, task_id):
+ batch_size = len(batch_hm)
+
+ prediction_dicts = []
+ for i in range(batch_size):
+ box_preds = batch_box_preds[i]
+ hm_preds = batch_hm[i]
+
+ scores, labels = torch.max(hm_preds, dim=-1)
+
+ score_mask = scores > test_cfg.score_threshold
+ distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
+ & (box_preds[..., :3] <= post_center_range[3:]).all(1)
+
+ mask = distance_mask & score_mask
+
+ # # gt test
+ # cur_gt = example['gt_boxes_and_cls'][i].to(box_preds)
+ # cur_gt = cur_gt[cur_gt[:,-1]>0]
+ # if cur_gt.shape[0]>0:
+ # iou3d = boxes_iou3d_gpu(box_preds, cur_gt[:, :7]) # (M, N)
+ # match_gt_mask = torch.any(iou3d>0,dim=1)
+ # gt_iou, gt_idx= torch.max(iou3d, dim=1)
+ # gt_label = cur_gt[gt_idx,-1]
+ # gt_box = cur_gt[gt_idx,:7]
+ # mask = mask & match_gt_mask
+
+ # # use gt as the prediction
+ # scores = gt_iou
+ # labels = (gt_label-1).to(labels)
+ # box_preds = gt_box
+
+ box_preds = box_preds[mask]
+ scores = scores[mask]
+ labels = labels[mask]
+
+ boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]
+
+ if test_cfg.get('circular_nms', False):
+ centers = boxes_for_nms[:, [0, 1]]
+ boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
+ selected = _circle_nms(boxes, min_radius=test_cfg.min_radius[task_id], post_max_size=test_cfg.nms.nms_post_max_size)
+ elif self.gt_nms:
+ order = scores.sort(0, descending=True)[1]
+ order = order[:500]
+ selected = order
+
+ if 'gt_boxes_and_cls' in example:
+ # use best match for each gt box
+ cur_gt = example['gt_boxes_and_cls'][i].to(boxes_for_nms)
+ iou3d = boxes_iou3d_gpu(boxes_for_nms[order], cur_gt[:, :7]) # (M, N)
+ # max_overlaps, gt_assignment = torch.max(iou3d, dim=1)
+ match_gt_mask = torch.any(iou3d>0,dim=0)
+ max_overlaps, best_matched_pred_id = torch.max(iou3d, dim=0)
+ selected_box_mask = torch.zeros_like(order)
+ selected_box_mask[best_matched_pred_id[match_gt_mask]] = 1
+ else:
+ selected = box_torch_ops.rotate_nms_pcdet(boxes_for_nms.float(), scores.float(),
+ thresh=test_cfg.nms.nms_iou_threshold,
+ pre_maxsize=test_cfg.nms.nms_pre_max_size,
+ post_max_size=test_cfg.nms.nms_post_max_size)
+
+ selected_boxes = box_preds[selected]
+ selected_scores = scores[selected]
+ selected_labels = labels[selected]
+
+ prediction_dict = {
+ 'box3d_lidar': selected_boxes,
+ 'scores': selected_scores,
+ 'label_preds': selected_labels
+ }
+ if self.gt_nms and 'gt_boxes_and_cls' in example:
+ prediction_dict['selected_box_mask']=selected_box_mask
+
+ prediction_dicts.append(prediction_dict)
+
+ return prediction_dicts
+
+import numpy as np
+def _circle_nms(boxes, min_radius, post_max_size=83):
+ """
+ NMS according to center distance
+ """
+ keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]
+
+ keep = torch.from_numpy(keep).long().to(boxes.device)
+
+ return keep
\ No newline at end of file
diff --git a/det3d/models/bbox_heads/center_head_iou.py b/det3d/models/bbox_heads/center_head_iou.py
new file mode 100644
index 0000000..5f1b91d
--- /dev/null
+++ b/det3d/models/bbox_heads/center_head_iou.py
@@ -0,0 +1,626 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# det3d (https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07)
+# Copyright (c) 2019 朱本金
+# Licensed under the MIT License
+# ------------------------------------------------------------------------------
+
+import logging
+from collections import defaultdict
+from det3d.core import box_torch_ops
+import torch
+from det3d.torchie.cnn import kaiming_init
+from torch import double, nn
+from det3d.models.losses.centernet_loss import FastFocalLoss, RegLoss
+from det3d.models.utils import Sequential
+from ..registry import HEADS
+from ...ops.iou3d_nms.iou3d_nms_utils import boxes_iou3d_gpu
+from det3d.core.utils.center_utils import _transpose_and_gather_feat
+import copy
+try:
+ from det3d.ops.dcn import DeformConv
+except:
+ print("Deformable Convolution not built!")
+
+from det3d.core.utils.circle_nms_jit import circle_nms
+
+class FeatureAdaption(nn.Module):
+ """Feature Adaption Module.
+
+ Feature Adaption Module is implemented based on DCN v1.
+ It uses anchor shape prediction rather than feature map to
+ predict offsets of deformable conv layer.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ out_channels (int): Number of channels in the output feature map.
+ kernel_size (int): Deformable conv kernel size.
+ deformable_groups (int): Deformable conv group size.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deformable_groups=4):
+ super(FeatureAdaption, self).__init__()
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ in_channels, deformable_groups * offset_channels, 1, bias=True)
+ self.conv_adaption = DeformConv(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deformable_groups=deformable_groups)
+ self.relu = nn.ReLU(inplace=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+
+ def forward(self, x,):
+ offset = self.conv_offset(x)
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
+
+class SepHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ heads,
+ head_conv=64,
+ final_kernel=1,
+ bn=False,
+ init_bias=-2.19,
+ **kwargs,
+ ):
+ super(SepHead, self).__init__(**kwargs)
+
+ self.heads = heads
+ for head in self.heads:
+ classes, num_conv = self.heads[head]
+
+ fc = Sequential()
+ for i in range(num_conv-1):
+ fc.add(nn.Conv2d(in_channels, head_conv,
+ kernel_size=final_kernel, stride=1,
+ padding=final_kernel // 2, bias=True))
+ if bn:
+ fc.add(nn.BatchNorm2d(head_conv))
+ fc.add(nn.ReLU())
+
+ fc.add(nn.Conv2d(head_conv, classes,
+ kernel_size=final_kernel, stride=1,
+ padding=final_kernel // 2, bias=True))
+
+ if 'hm' in head:
+ fc[-1].bias.data.fill_(init_bias)
+ else:
+ for m in fc.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+
+ self.__setattr__(head, fc)
+
+
+ def forward(self, x):
+ ret_dict = dict()
+ for head in self.heads:
+ ret_dict[head] = self.__getattr__(head)(x)
+
+ return ret_dict
+
+class DCNSepHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ num_cls,
+ heads,
+ head_conv=64,
+ final_kernel=1,
+ bn=False,
+ init_bias=-2.19,
+ **kwargs,
+ ):
+ super(DCNSepHead, self).__init__(**kwargs)
+
+ # feature adaptation with dcn
+ # use separate features for classification / regression
+ self.feature_adapt_cls = FeatureAdaption(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ deformable_groups=4)
+
+ self.feature_adapt_reg = FeatureAdaption(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ deformable_groups=4)
+
+ # heatmap prediction head
+ self.cls_head = Sequential(
+ nn.Conv2d(in_channels, head_conv,
+ kernel_size=3, padding=1, bias=True),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(head_conv, num_cls,
+ kernel_size=3, stride=1,
+ padding=1, bias=True)
+ )
+ self.cls_head[-1].bias.data.fill_(init_bias)
+
+ # other regression target
+ self.task_head = SepHead(in_channels, heads, head_conv=head_conv, bn=bn, final_kernel=final_kernel)
+
+
+ def forward(self, x):
+ center_feat = self.feature_adapt_cls(x)
+ reg_feat = self.feature_adapt_reg(x)
+
+ cls_score = self.cls_head(center_feat)
+ ret = self.task_head(reg_feat)
+ ret['hm'] = cls_score
+
+ return ret
+
+
+@HEADS.register_module
+class CenterHeadIoU(nn.Module):
+ def __init__(
+ self,
+ in_channels=[128,],
+ tasks=[],
+ dataset='nuscenes',
+ weight=0.25,
+ code_weights=[],
+ common_heads=dict(),
+ logger=None,
+ init_bias=-2.19,
+ share_conv_channel=64,
+ num_hm_conv=2,
+ dcn_head=False,
+ corner_loss=False,
+ iou_loss=False,
+ gt_nms=False,
+ iou_factor=[1,1,4]
+ ):
+ super(CenterHeadIoU, self).__init__()
+
+ num_classes = [len(t["class_names"]) for t in tasks]
+ self.class_names = [t["class_names"] for t in tasks]
+ self.code_weights = code_weights
+ self.weight = weight # weight between hm loss and loc loss
+ self.dataset = dataset
+
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.use_corner_loss = corner_loss
+ self.use_iou_loss = iou_loss
+ self.iou_factor = iou_factor
+
+ self.crit = FastFocalLoss()
+ self.crit_reg = RegLoss()
+ if self.use_iou_loss:
+ self.crit_iou = torch.nn.SmoothL1Loss(reduction = 'none')
+ if self.use_corner_loss:
+ self.corner_crit = torch.nn.MSELoss(reduction = 'none')
+
+ self.box_n_dim = 9 if 'vel' in common_heads else 7
+ self.use_direction_classifier = False
+
+ self.gt_nms = gt_nms
+ if self.gt_nms:
+ print("Use gt nms!")
+
+ if not logger:
+ logger = logging.getLogger("CenterHeadIoU")
+ self.logger = logger
+
+ logger.info(
+ f"num_classes: {num_classes}"
+ )
+
+ # a shared convolution
+ self.shared_conv = nn.Sequential(
+ nn.Conv2d(in_channels, share_conv_channel,
+ kernel_size=3, padding=1, bias=True),
+ nn.BatchNorm2d(share_conv_channel),
+ nn.ReLU(inplace=True)
+ )
+
+ self.tasks = nn.ModuleList()
+ print("Use HM Bias: ", init_bias)
+
+ if dcn_head:
+ print("Use Deformable Convolution in the CenterHead!")
+
+ for num_cls in num_classes:
+ heads = copy.deepcopy(common_heads)
+ if not dcn_head:
+ heads.update(dict(hm=(num_cls, num_hm_conv)))
+ if self.use_corner_loss:
+ heads.update(dict(corner=(1, num_hm_conv)))
+ self.tasks.append(
+ SepHead(share_conv_channel, heads, bn=True, init_bias=init_bias, final_kernel=3)
+ )
+ else:
+ self.tasks.append(
+ DCNSepHead(share_conv_channel, num_cls, heads, bn=True, init_bias=init_bias, final_kernel=3)
+ )
+
+ logger.info("Finish CenterHeadIoU Initialization")
+
+ def forward(self, x, *kwargs):
+ ret_dicts = []
+
+ x = self.shared_conv(x.float())
+
+ for task in self.tasks:
+ ret_dicts.append(task(x))
+
+ return ret_dicts
+
+ def _sigmoid(self, x):
+ y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4)
+ return y
+
+ def loss(self, example, preds_dicts, test_cfg, **kwargs):
+ rets = []
+ for task_id, preds_dict in enumerate(preds_dicts):
+ # heatmap focal loss
+ preds_dict['hm'] = self._sigmoid(preds_dict['hm'])
+
+ hm_loss = self.crit(preds_dict['hm'], example['hm'][task_id], example['ind'][task_id], example['mask'][task_id], example['cat'][task_id])
+
+ if self.use_corner_loss:
+ corner_loss = self.corner_crit(preds_dict['corner'], example['corners'][task_id])
+ corner_mask = (example['corners'][task_id]>0).to(corner_loss)
+ corner_loss = (corner_loss * corner_mask).sum()/(corner_mask.sum() + 1e-4)
+
+ target_box = example['anno_box'][task_id]
+ # reconstruct the anno_box from multiple reg heads
+ if self.dataset in ['waymo', 'nuscenes']:
+ if 'vel' in preds_dict:
+ preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
+ preds_dict['vel'], preds_dict['rot']), dim=1)
+ else:
+ preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
+ preds_dict['rot']), dim=1)
+ target_box = target_box[..., [0, 1, 2, 3, 4, 5, -2, -1]] # remove vel target
+ else:
+ raise NotImplementedError()
+
+ # IoU loss
+ if self.use_iou_loss:
+ with torch.no_grad():
+ preds_dict['iou'] = _transpose_and_gather_feat(preds_dict['iou'],example['ind'][task_id]) # B M 1
+ preds_box = get_box(preds_dict['anno_box'],example['ind'][task_id],test_cfg)
+ cur_gt = example['gt_boxes_and_cls'].to(preds_box)
+ gt_iou = []
+ for i in range(preds_dict['iou'].shape[0]):
+ iou3d = boxes_iou3d_gpu(preds_box[i], cur_gt[i, :, :7]) # (M, N)
+ gt_iou.append(torch.max(iou3d, dim=1)[0])
+ gt_iou = torch.stack(gt_iou,dim=0)
+ gt_iou = 2 * gt_iou - 1
+
+ iou_loss = self.crit_iou(preds_dict['iou'].squeeze(2),gt_iou) * example['mask'][task_id].float()
+ iou_loss = iou_loss.sum()/(example['mask'][task_id].float().sum()+1e-4)
+
+ ret = {}
+
+ # Regression loss for dimension, offset, height, rotation
+ box_loss = self.crit_reg(preds_dict['anno_box'], example['mask'][task_id], example['ind'][task_id], target_box)
+
+ loc_loss = (box_loss*box_loss.new_tensor(self.code_weights)).sum()
+
+ loss = hm_loss + self.weight*loc_loss
+
+ if self.use_iou_loss:
+ loss = loss + iou_loss
+ if self.use_corner_loss:
+ loss = loss + corner_loss
+
+ ret.update({'loss': loss, 'hm_loss': hm_loss.detach().cpu(),'loc_loss':loc_loss, 'loc_loss_elem': box_loss.detach().cpu(), 'num_positive': example['mask'][task_id].float().sum()})
+ if self.use_iou_loss:
+ ret.update({'iou_loss':iou_loss.detach().cpu()})
+ if self.use_corner_loss:
+ ret.update({'corner_loss':corner_loss.detach().cpu()})
+
+ rets.append(ret)
+
+ """convert batch-key to key-batch
+ """
+ rets_merged = defaultdict(list)
+ for ret in rets:
+ for k, v in ret.items():
+ rets_merged[k].append(v)
+
+ return rets_merged
+
+ @torch.no_grad()
+ def predict(self, example, preds_dicts, test_cfg, **kwargs):
+ """decode, nms, then return the detection result. Additionaly support double flip testing
+ """
+ # get loss info
+ rets = []
+ metas = []
+
+ double_flip = test_cfg.get('double_flip', False)
+
+ post_center_range = test_cfg.post_center_limit_range
+ if len(post_center_range) > 0:
+ post_center_range = torch.tensor(
+ post_center_range,
+ dtype=preds_dicts[0]['hm'].dtype,
+ device=preds_dicts[0]['hm'].device,
+ )
+
+ for task_id, preds_dict in enumerate(preds_dicts):
+
+ # preds_dict['anno_box'] = torch.cat((preds_dict['reg'], preds_dict['height'], preds_dict['dim'],
+ # preds_dict['rot']), dim=1)
+ # preds_box = get_box(preds_dict['anno_box'],example['ind'][task_id],test_cfg)
+ # cur_gt = example['gt_boxes_and_cls'].to(preds_box)
+ # gt_iou = []
+ # for i in range(preds_dict['iou'].shape[0]):
+ # iou3d = boxes_iou3d_gpu(preds_box[i], cur_gt[i, :, :7]) # (M, N)
+ # gt_iou.append(torch.max(iou3d, dim=1)[0])
+ # gt_iou = torch.stack(gt_iou,dim=0)
+
+ # convert N C H W to N H W C
+ for key, val in preds_dict.items():
+ preds_dict[key] = val.permute(0, 2, 3, 1).contiguous()
+
+ batch_size = preds_dict['hm'].shape[0]
+
+ if double_flip:
+ assert batch_size % 4 == 0, print(batch_size)
+ batch_size = int(batch_size / 4)
+ for k in preds_dict.keys():
+ # transform the prediction map back to their original coordinate befor flipping
+ # the flipped predictions are ordered in a group of 4. The first one is the original pointcloud
+ # the second one is X flip pointcloud(y=-y), the third one is Y flip pointcloud(x=-x), and the last one is
+ # X and Y flip pointcloud(x=-x, y=-y).
+ # Also please note that pytorch's flip function is defined on higher dimensional space, so dims=[2] means that
+ # it is flipping along the axis with H length(which is normaly the Y axis), however in our traditional word, it is flipping along
+ # the X axis. The below flip follows pytorch's definition yflip(y=-y) xflip(x=-x)
+ _, H, W, C = preds_dict[k].shape
+ preds_dict[k] = preds_dict[k].reshape(int(batch_size), 4, H, W, C)
+ preds_dict[k][:, 1] = torch.flip(preds_dict[k][:, 1], dims=[1])
+ preds_dict[k][:, 2] = torch.flip(preds_dict[k][:, 2], dims=[2])
+ preds_dict[k][:, 3] = torch.flip(preds_dict[k][:, 3], dims=[1, 2])
+
+ if "metadata" not in example or len(example["metadata"]) == 0:
+ meta_list = [None] * batch_size
+ else:
+ meta_list = example["metadata"]
+ if double_flip:
+ meta_list = meta_list[:4*int(batch_size):4]
+
+ batch_hm = torch.sigmoid(preds_dict['hm'])
+ if self.use_iou_loss:
+ batch_iou = preds_dict['iou']
+ else:
+ batch_iou = None
+
+ batch_dim = torch.exp(preds_dict['dim'])
+
+ batch_rots = preds_dict['rot'][..., 0:1]
+ batch_rotc = preds_dict['rot'][..., 1:2]
+ batch_reg = preds_dict['reg']
+ batch_hei = preds_dict['height']
+
+ if double_flip:
+ batch_hm = batch_hm.mean(dim=1)
+ batch_hei = batch_hei.mean(dim=1)
+ batch_dim = batch_dim.mean(dim=1)
+
+ # y = -y reg_y = 1-reg_y
+ batch_reg[:, 1, ..., 1] = 1 - batch_reg[:, 1, ..., 1]
+ batch_reg[:, 2, ..., 0] = 1 - batch_reg[:, 2, ..., 0]
+
+ batch_reg[:, 3, ..., 0] = 1 - batch_reg[:, 3, ..., 0]
+ batch_reg[:, 3, ..., 1] = 1 - batch_reg[:, 3, ..., 1]
+ batch_reg = batch_reg.mean(dim=1)
+
+ # first yflip
+ # y = -y theta = pi -theta
+ # sin(pi-theta) = sin(theta) cos(pi-theta) = -cos(theta)
+ # batch_rots[:, 1] the same
+ batch_rotc[:, 1] *= -1
+
+ # then xflip x = -x theta = 2pi - theta
+ # sin(2pi - theta) = -sin(theta) cos(2pi - theta) = cos(theta)
+ # batch_rots[:, 2] the same
+ batch_rots[:, 2] *= -1
+
+ # double flip
+ batch_rots[:, 3] *= -1
+ batch_rotc[:, 3] *= -1
+
+ batch_rotc = batch_rotc.mean(dim=1)
+ batch_rots = batch_rots.mean(dim=1)
+
+ batch_rot = torch.atan2(batch_rots, batch_rotc)
+
+ batch, H, W, num_cls = batch_hm.size()
+
+ batch_reg = batch_reg.reshape(batch, H*W, 2)
+ batch_hei = batch_hei.reshape(batch, H*W, 1)
+
+ batch_rot = batch_rot.reshape(batch, H*W, 1)
+ batch_dim = batch_dim.reshape(batch, H*W, 3)
+ batch_hm = batch_hm.reshape(batch, H*W, num_cls)
+ if self.use_iou_loss:
+ batch_iou = (batch_iou.reshape(batch, H*W) + 1) * 0.5
+ batch_iou = torch.clamp(batch_iou,min=0.,max=1.)
+
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_hm)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_hm)
+
+ xs = xs.view(batch, -1, 1) + batch_reg[:, :, 0:1]
+ ys = ys.view(batch, -1, 1) + batch_reg[:, :, 1:2]
+
+ xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[0] + test_cfg.pc_range[0]
+ ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[1] + test_cfg.pc_range[1]
+
+ if 'vel' in preds_dict:
+ batch_vel = preds_dict['vel']
+
+ if double_flip:
+ # flip vy
+ batch_vel[:, 1, ..., 1] *= -1
+ # flip vx
+ batch_vel[:, 2, ..., 0] *= -1
+
+ batch_vel[:, 3] *= -1
+
+ batch_vel = batch_vel.mean(dim=1)
+
+ batch_vel = batch_vel.reshape(batch, H*W, 2)
+ batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2)
+ else:
+ batch_box_preds = torch.cat([xs, ys, batch_hei, batch_dim, batch_rot], dim=2)
+
+ metas.append(meta_list)
+
+ if test_cfg.get('per_class_nms', False):
+ pass
+ else:
+ rets.append(self.post_processing(example, batch_box_preds, batch_hm, batch_iou, test_cfg, post_center_range, task_id))
+
+ # Merge branches results
+ ret_list = []
+ num_samples = len(rets[0])
+
+ ret_list = []
+ for i in range(num_samples):
+ ret = {}
+ for k in rets[0][i].keys():
+ if k in ["box3d_lidar", "scores","selected_box_mask","gt_scores","selected"]:
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+ elif k in ["label_preds"]:
+ flag = 0
+ for j, num_class in enumerate(self.num_classes):
+ rets[j][i][k] += flag
+ flag += num_class
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+
+ ret['metadata'] = metas[0][i]
+ ret_list.append(ret)
+
+ return ret_list
+
+ @torch.no_grad()
+ def post_processing(self, example, batch_box_preds, batch_hm, batch_iou, test_cfg, post_center_range, task_id):
+ batch_size = len(batch_hm)
+
+ prediction_dicts = []
+ for i in range(batch_size):
+ box_preds = batch_box_preds[i]
+ hm_preds = batch_hm[i]
+
+ scores, labels = torch.max(hm_preds, dim=-1)
+
+ score_mask = scores > test_cfg.score_threshold
+ distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) \
+ & (box_preds[..., :3] <= post_center_range[3:]).all(1)
+
+ mask = distance_mask & score_mask
+
+ box_preds = box_preds[mask]
+ scores = scores[mask]
+ # ious = batch_iou[i][mask]
+ labels = labels[mask]
+ # print('ious', ious.cpu().numpy())
+
+ if self.use_iou_loss:
+ iou_factor = torch.LongTensor(self.iou_factor).to(labels)
+ ious = batch_iou[i][mask]
+
+ ious = torch.pow(ious, iou_factor[labels])
+ scores = scores * ious
+
+ boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]
+
+ # cur_gt = example['gt_boxes_and_cls'][i].to(boxes_for_nms)
+ # iou3d = boxes_iou3d_gpu(boxes_for_nms, cur_gt[:, :7]) # (M, N)
+ # gt_scores, _ = torch.max(iou3d, dim=1)
+
+ if test_cfg.get('circular_nms', False):
+ centers = boxes_for_nms[:, [0, 1]]
+ boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
+ selected = _circle_nms(boxes, min_radius=test_cfg.min_radius[task_id], post_max_size=test_cfg.nms.nms_post_max_size)
+ elif self.gt_nms:
+ order = scores.sort(0, descending=True)[1]
+ order = order[:500]
+ selected = order
+
+ if 'gt_boxes_and_cls' in example:
+ # use best match for each gt box
+ cur_gt = example['gt_boxes_and_cls'][i].to(boxes_for_nms)
+ iou3d = boxes_iou3d_gpu(boxes_for_nms[order], cur_gt[:, :7]) # (M, N)
+ # max_overlaps, gt_assignment = torch.max(iou3d, dim=1)
+ match_gt_mask = torch.any(iou3d>0,dim=0)
+ max_overlaps, best_matched_pred_id = torch.max(iou3d, dim=0)
+ selected_box_mask = torch.zeros_like(order)
+ selected_box_mask[best_matched_pred_id[match_gt_mask]] = 1
+ else:
+ selected = box_torch_ops.rotate_nms_pcdet(boxes_for_nms.float(), scores.float(),
+ thresh=test_cfg.nms.nms_iou_threshold,
+ pre_maxsize=test_cfg.nms.nms_pre_max_size,
+ post_max_size=test_cfg.nms.nms_post_max_size)
+
+ selected_boxes = box_preds[selected]
+ selected_scores = scores[selected]
+ selected_labels = labels[selected]
+
+ prediction_dict = {
+ 'box3d_lidar': selected_boxes,
+ 'scores': selected_scores,
+ 'label_preds': selected_labels,
+ # 'selected': selected,
+ # 'gt_scores': gt_scores
+ }
+ if self.gt_nms and 'gt_boxes_and_cls' in example:
+ prediction_dict['selected_box_mask']=selected_box_mask
+
+ prediction_dicts.append(prediction_dict)
+
+ return prediction_dicts
+
+import numpy as np
+def _circle_nms(boxes, min_radius, post_max_size=83):
+ """
+ NMS according to center distance
+ """
+ keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]
+
+ keep = torch.from_numpy(keep).long().to(boxes.device)
+
+ return keep
+
+def get_box(pred_boxs, mask, test_cfg):
+ batch,_ , H, W= pred_boxs.size()
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
+
+ xs = xs.view(batch, 1, H, W) + pred_boxs[:, 0:1]
+ ys = ys.view(batch, 1, H, W) + pred_boxs[:, 1:2]
+
+ xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[0] + test_cfg.pc_range[0]
+ ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[1] + test_cfg.pc_range[1]
+
+ rot = torch.atan2(pred_boxs[:, 6:7], pred_boxs[:, 7:8])
+ pred = torch.cat([xs, ys, pred_boxs[:,2:3], torch.exp(pred_boxs[:,3:6]), rot], dim=1)
+
+ return _transpose_and_gather_feat(pred,mask) # B M 7
+
+
+
+
+
diff --git a/det3d/models/bbox_heads/center_head_iou_1d.py b/det3d/models/bbox_heads/center_head_iou_1d.py
new file mode 100644
index 0000000..b3d67e5
--- /dev/null
+++ b/det3d/models/bbox_heads/center_head_iou_1d.py
@@ -0,0 +1,618 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# det3d (https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07)
+# Copyright (c) 2019 朱本金
+# Licensed under the MIT License
+# ------------------------------------------------------------------------------
+
+import logging
+from collections import defaultdict
+from det3d.core import box_torch_ops
+import torch
+from det3d.torchie.cnn import kaiming_init
+from torch import double, nn
+from det3d.models.losses.centernet_loss import FastFocalLoss, RegLoss, SegLoss
+from det3d.models.utils import Sequential
+from ..registry import HEADS
+from ...ops.iou3d_nms.iou3d_nms_utils import boxes_iou3d_gpu
+from ..utils import build_norm_layer
+from det3d.core.utils.center_utils import _transpose_and_gather_feat
+import copy
+
+from det3d.core.utils.circle_nms_jit import circle_nms
+
+
+class SepHead(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ heads,
+ head_conv=64,
+ final_kernel=1,
+ bn=False,
+ init_bias=-2.19,
+ **kwargs,
+ ):
+ super(SepHead, self).__init__(**kwargs)
+
+ self.heads = heads
+ for head in self.heads:
+ classes, num_conv = self.heads[head]
+
+ fc = Sequential()
+ for i in range(num_conv - 1):
+ fc.add(
+ nn.Conv1d(
+ in_channels,
+ head_conv,
+ kernel_size=final_kernel,
+ stride=1,
+ padding=final_kernel // 2,
+ bias=True,
+ )
+ )
+ if bn:
+ fc.add(build_norm_layer(dict(type="BN1d"), head_conv)[1])
+ fc.add(nn.ReLU())
+
+ fc.add(
+ nn.Conv1d(
+ head_conv,
+ classes,
+ kernel_size=final_kernel,
+ stride=1,
+ padding=final_kernel // 2,
+ bias=True,
+ )
+ )
+
+ if "hm" in head:
+ fc[-1].bias.data.fill_(init_bias)
+ else:
+ for m in fc.modules():
+ if isinstance(m, nn.Conv1d):
+ kaiming_init(m)
+
+ self.__setattr__(head, fc)
+
+ def forward(self, x, y):
+ for head in self.heads:
+ x[head] = self.__getattr__(head)(y)
+
+ return x
+
+
+@HEADS.register_module
+class CenterHeadIoU_1d(nn.Module):
+ def __init__(
+ self,
+ in_channels=[128,],
+ tasks=[],
+ dataset="nuscenes",
+ weight=0.25,
+ iou_weight=1,
+ corner_weight=1,
+ code_weights=[],
+ common_heads=dict(),
+ logger=None,
+ init_bias=-2.19,
+ share_conv_channel=64,
+ assign_label_window_size=1,
+ iou_loss=False,
+ corner_loss=False,
+ iou_factor=[1, 1, 4],
+ ):
+ super(CenterHeadIoU_1d, self).__init__()
+
+ num_classes = [len(t["class_names"]) for t in tasks]
+ self.class_names = [t["class_names"] for t in tasks]
+ self.code_weights = code_weights
+ self.weight = weight # weight between hm loss and loc loss
+ self.iou_weight = iou_weight
+ self.corner_weight = corner_weight
+ self.dataset = dataset
+ self.iou_factor = iou_factor
+
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+
+ self.crit = FastFocalLoss(assign_label_window_size)
+ self.crit_reg = torch.nn.L1Loss(reduction="none")
+ self.use_iou_loss = iou_loss
+ if self.use_iou_loss:
+ self.crit_iou = torch.nn.SmoothL1Loss(reduction="none")
+ self.corner_loss = corner_loss
+ if self.corner_loss:
+ self.corner_crit = torch.nn.MSELoss(reduction="none")
+
+ self.box_n_dim = 9 if "vel" in common_heads else 7
+ self.use_direction_classifier = False
+
+ if not logger:
+ logger = logging.getLogger("CenterHeadIoU_1d")
+ self.logger = logger
+
+ logger.info(f"num_classes: {num_classes}")
+
+ # a shared convolution
+ self.shared_conv = nn.Sequential(
+ nn.Conv1d(in_channels, share_conv_channel, kernel_size=1, bias=True),
+ build_norm_layer(dict(type="BN1d"), share_conv_channel)[1],
+ nn.ReLU(inplace=True),
+ )
+
+ self.tasks = nn.ModuleList()
+ print("Use HM Bias: ", init_bias)
+
+ for num_cls in num_classes:
+ heads = copy.deepcopy(common_heads)
+ self.tasks.append(
+ SepHead(
+ share_conv_channel,
+ heads,
+ bn=True,
+ init_bias=init_bias,
+ final_kernel=1,
+ )
+ )
+
+ logger.info("Finish CenterHeadIoU Initialization")
+
+ def forward(self, x, *kwargs):
+ ret_dicts = []
+
+ y = self.shared_conv(x["ct_feat"].float())
+
+ for task in self.tasks:
+ ret_dicts.append(task(x, y))
+
+ return ret_dicts
+
+ def _sigmoid(self, x):
+ y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
+ return y
+
+ def loss(self, example, preds_dicts, test_cfg, **kwargs):
+ rets = []
+ for task_id, preds_dict in enumerate(preds_dicts):
+ # heatmap focal loss
+ hm_loss = self.crit(
+ preds_dict["hm"],
+ example["hm"][task_id],
+ example["ind"][task_id],
+ example["mask"][task_id],
+ example["cat"][task_id],
+ )
+
+ target_box = example["anno_box"][task_id]
+
+ if self.corner_loss:
+ corner_loss = self.corner_crit(
+ preds_dict["corner_hm"], example["corners"][task_id]
+ )
+ corner_mask = (example["corners"][task_id] > 0).to(corner_loss)
+ corner_loss = (corner_loss * corner_mask).sum() / (
+ corner_mask.sum() + 1e-4
+ )
+
+ # reconstruct the anno_box from multiple reg heads
+ if self.dataset in ["waymo", "nuscenes"]:
+ if "vel" in preds_dict:
+ preds_dict["anno_box"] = torch.cat(
+ (
+ preds_dict["reg"],
+ preds_dict["height"],
+ preds_dict["dim"],
+ preds_dict["vel"],
+ preds_dict["rot"],
+ ),
+ dim=1,
+ )
+ else:
+ preds_dict["anno_box"] = torch.cat(
+ (
+ preds_dict["reg"],
+ preds_dict["height"],
+ preds_dict["dim"],
+ preds_dict["rot"],
+ ),
+ dim=1,
+ )
+ target_box = target_box[
+ ..., [0, 1, 2, 3, 4, 5, -2, -1]
+ ] # remove vel target
+ else:
+ raise NotImplementedError()
+
+ ret = {}
+
+ # Regression loss for dimension, offset, height, rotation
+ # get corresponding gt box # B, 500
+ target_box, selected_mask, selected_cls = get_corresponding_box(
+ preds_dict["order"],
+ example["ind"][task_id],
+ example["mask"][task_id],
+ example["cat"][task_id],
+ target_box,
+ )
+ mask = selected_mask.float().unsqueeze(2)
+
+ weights = self.code_weights
+
+ box_loss = self.crit_reg(
+ preds_dict["anno_box"].transpose(1, 2) * mask, target_box * mask
+ )
+ box_loss = box_loss / (mask.sum() + 1e-4)
+ box_loss = box_loss.transpose(2, 0).sum(dim=2).sum(dim=1)
+
+ loc_loss = (box_loss * box_loss.new_tensor(weights)).sum()
+
+ if self.use_iou_loss:
+ with torch.no_grad():
+ preds_box = get_box(
+ preds_dict["anno_box"],
+ preds_dict["order"],
+ test_cfg,
+ preds_dict["hm"].shape[2],
+ preds_dict["hm"].shape[3],
+ )
+ cur_gt = get_box_gt(
+ target_box,
+ preds_dict["order"],
+ test_cfg,
+ preds_dict["hm"].shape[2],
+ preds_dict["hm"].shape[3],
+ )
+ iou_targets = boxes_iou3d_gpu(
+ preds_box.reshape(-1, 7), cur_gt.reshape(-1, 7)
+ )[
+ range(preds_box.reshape(-1, 7).shape[0]),
+ range(cur_gt.reshape(-1, 7).shape[0]),
+ ]
+ iou_targets[torch.isnan(iou_targets)] = 0
+ iou_targets = 2 * iou_targets - 1
+ iou_loss = self.crit_iou(
+ preds_dict["iou"].reshape(-1), iou_targets
+ ) * mask.reshape(-1)
+ iou_loss = iou_loss.sum() / (mask.sum() + 1e-4)
+
+ loss = hm_loss + self.weight * loc_loss
+ if self.use_iou_loss:
+ loss = loss + self.iou_weight * iou_loss
+ if self.corner_loss:
+ loss = loss + self.corner_weight * corner_loss
+ ret.update(
+ {
+ "loss": loss,
+ "hm_loss": hm_loss.detach().cpu(),
+ "loc_loss": loc_loss,
+ "loc_loss_elem": box_loss.detach().cpu(),
+ "num_positive": example["mask"][task_id].float().sum(),
+ }
+ )
+ if self.use_iou_loss:
+ ret.update({"iou_loss": iou_loss.detach().cpu()})
+ if self.corner_loss:
+ ret.update({"corner_loss": corner_loss.detach().cpu()})
+
+ rets.append(ret)
+
+ """convert batch-key to key-batch
+ """
+ rets_merged = defaultdict(list)
+ for ret in rets:
+ for k, v in ret.items():
+ rets_merged[k].append(v)
+
+ return rets_merged
+
+ @torch.no_grad()
+ def predict(self, example, preds_dicts, test_cfg, **kwargs):
+ """decode, nms, then return the detection result. Additionaly support double flip testing"""
+ # get loss info
+ rets = []
+ metas = []
+
+ post_center_range = test_cfg.post_center_limit_range
+ if len(post_center_range) > 0:
+ post_center_range = torch.tensor(
+ post_center_range,
+ dtype=preds_dicts[0]["scores"].dtype,
+ device=preds_dicts[0]["scores"].device,
+ )
+
+ for task_id, preds_dict in enumerate(preds_dicts):
+ # convert B C N to B N C
+ for key, val in preds_dict.items():
+ if torch.is_tensor(preds_dict[key]):
+ if len(preds_dict[key].shape) == 3:
+ preds_dict[key] = val.permute(0, 2, 1).contiguous()
+
+ batch_size = preds_dict["scores"].shape[0]
+
+ if "metadata" not in example or len(example["metadata"]) == 0:
+ meta_list = [None] * batch_size
+ else:
+ meta_list = example["metadata"]
+
+ batch_score = preds_dict["scores"]
+ batch_label = preds_dict["labels"]
+ batch_mask = preds_dict["mask"]
+ if self.use_iou_loss:
+ batch_iou = preds_dict["iou"].squeeze(2)
+ else:
+ batch_iou = None
+ if "corner_hm" in preds_dict:
+ batch_corner_hm = preds_dict["corner_hm"]
+ else:
+ batch_corner_hm = None
+
+ batch_dim = torch.exp(preds_dict["dim"])
+
+ batch_rots = preds_dict["rot"][..., 0:1]
+ batch_rotc = preds_dict["rot"][..., 1:2]
+
+ batch_reg = preds_dict["reg"]
+ batch_hei = preds_dict["height"]
+ batch_rot = torch.atan2(batch_rots, batch_rotc)
+ if self.use_iou_loss:
+ batch_iou = (batch_iou + 1) * 0.5
+ batch_iou = torch.clamp(batch_iou, min=0.0, max=1.0)
+
+ batch, _, H, W = preds_dict["hm"].size()
+
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_score)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_score)
+
+ obj_num = preds_dict["order"].shape[1]
+ batch_id = np.indices((batch, obj_num))[0]
+ batch_id = torch.from_numpy(batch_id).to(preds_dict["order"])
+
+ xs = (
+ xs.view(batch, -1, 1)[batch_id, preds_dict["order"]]
+ + batch_reg[:, :, 0:1]
+ )
+ ys = (
+ ys.view(batch, -1, 1)[batch_id, preds_dict["order"]]
+ + batch_reg[:, :, 1:2]
+ )
+
+ xs = (
+ xs * test_cfg.out_size_factor * test_cfg.voxel_size[0]
+ + test_cfg.pc_range[0]
+ )
+ ys = (
+ ys * test_cfg.out_size_factor * test_cfg.voxel_size[1]
+ + test_cfg.pc_range[1]
+ )
+
+ if "vel" in preds_dict:
+ batch_vel = preds_dict["vel"]
+ batch_box_preds = torch.cat(
+ [xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], dim=2
+ )
+ else:
+ batch_box_preds = torch.cat(
+ [xs, ys, batch_hei, batch_dim, batch_rot], dim=2
+ )
+
+ metas.append(meta_list)
+
+ if test_cfg.get("per_class_nms", False):
+ pass
+ else:
+ rets.append(
+ self.post_processing(
+ example,
+ batch_box_preds,
+ batch_score,
+ batch_label,
+ test_cfg,
+ post_center_range,
+ task_id,
+ batch_mask,
+ batch_iou,
+ )
+ )
+
+ # Merge branches results
+ ret_list = []
+ num_samples = len(rets[0])
+
+ ret_list = []
+ for i in range(num_samples):
+ ret = {}
+ for k in rets[0][i].keys():
+ if k in [
+ "box3d_lidar",
+ "scores",
+ "selected_box_mask",
+ "gt_scores",
+ "selected",
+ "selected_feat_ids",
+ ]:
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+ elif k in ["label_preds"]:
+ flag = 0
+ for j, num_class in enumerate(self.num_classes):
+ rets[j][i][k] += flag
+ flag += num_class
+ ret[k] = torch.cat([ret[i][k] for ret in rets])
+
+ ret["metadata"] = metas[0][i]
+ ret_list.append(ret)
+
+ return ret_list
+
+ @torch.no_grad()
+ def post_processing(
+ self,
+ example,
+ batch_box_preds,
+ batch_score,
+ batch_label,
+ test_cfg,
+ post_center_range,
+ task_id,
+ batch_mask,
+ batch_iou,
+ ):
+ batch_size = len(batch_score)
+
+ prediction_dicts = []
+ for i in range(batch_size):
+ box_preds = batch_box_preds[i]
+ scores = batch_score[i]
+ labels = batch_label[i]
+ mask = batch_mask[i]
+
+ distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all(1) & (
+ box_preds[..., :3] <= post_center_range[3:]
+ ).all(1)
+
+ mask = mask & distance_mask
+
+ box_preds = box_preds[mask]
+ scores = scores[mask]
+ labels = labels[mask]
+
+ if self.use_iou_loss:
+ iou_factor = torch.LongTensor(self.iou_factor).to(labels)
+ ious = batch_iou[i][mask]
+ ious = torch.pow(ious, iou_factor[labels])
+ scores = scores * ious
+
+ boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]]
+
+ if test_cfg.get("circular_nms", False):
+ centers = boxes_for_nms[:, [0, 1]]
+ boxes = torch.cat([centers, scores.view(-1, 1)], dim=1)
+ selected = _circle_nms(
+ boxes,
+ min_radius=test_cfg.min_radius[task_id],
+ post_max_size=test_cfg.nms.nms_post_max_size,
+ )
+ elif test_cfg.nms.get("use_multi_class_nms", False):
+ # multi class nms
+ selected = []
+ for c in range(3):
+ class_mask = labels == c
+ if class_mask.sum() > 0:
+ class_idx = class_mask.nonzero()
+ select = box_torch_ops.rotate_nms_pcdet(
+ boxes_for_nms[class_mask].float(),
+ scores[class_mask].float(),
+ thresh=test_cfg.nms.nms_iou_threshold[c],
+ pre_maxsize=test_cfg.nms.nms_pre_max_size[c],
+ post_max_size=test_cfg.nms.nms_post_max_size[c],
+ )
+ selected.append(class_idx[select, 0])
+ if len(selected) > 0:
+ selected = torch.cat(selected, dim=0)
+ else:
+ selected = box_torch_ops.rotate_nms_pcdet(
+ boxes_for_nms.float(),
+ scores.float(),
+ thresh=test_cfg.nms.nms_iou_threshold,
+ pre_maxsize=test_cfg.nms.nms_pre_max_size,
+ post_max_size=test_cfg.nms.nms_post_max_size,
+ )
+
+ selected_boxes = box_preds[selected]
+ selected_scores = scores[selected]
+ selected_labels = labels[selected]
+
+ prediction_dict = {
+ "box3d_lidar": selected_boxes,
+ "scores": selected_scores,
+ "label_preds": selected_labels,
+ }
+
+ prediction_dicts.append(prediction_dict)
+
+ return prediction_dicts
+
+
+import numpy as np
+
+
+def _circle_nms(boxes, min_radius, post_max_size=83):
+ """
+ NMS according to center distance
+ """
+ keep = np.array(circle_nms(boxes.cpu().numpy(), thresh=min_radius))[:post_max_size]
+
+ keep = torch.from_numpy(keep).long().to(boxes.device)
+
+ return keep
+
+
+def get_box(pred_boxs, order, test_cfg, H, W):
+ batch = pred_boxs.shape[0]
+ obj_num = order.shape[1]
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs)
+
+ batch_id = np.indices((batch, obj_num))[0]
+ batch_id = torch.from_numpy(batch_id).to(order)
+ xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:, 0:1]
+ ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:, 1:2]
+
+ xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[0] + test_cfg.pc_range[0]
+ ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[1] + test_cfg.pc_range[1]
+
+ rot = torch.atan2(pred_boxs[:, 6:7], pred_boxs[:, 7:8])
+ pred = torch.cat(
+ [xs, ys, pred_boxs[:, 2:3], torch.exp(pred_boxs[:, 3:6]), rot], dim=1
+ )
+
+ return torch.transpose(pred, 1, 2).contiguous() # B M 7
+
+
+def get_box_gt(gt_boxs, order, test_cfg, H, W):
+ batch = gt_boxs.shape[0]
+ obj_num = order.shape[1]
+ ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)])
+ ys = ys.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs)
+ xs = xs.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs)
+
+ batch_id = np.indices((batch, obj_num))[0]
+ batch_id = torch.from_numpy(batch_id).to(order)
+
+ batch_gt_dim = torch.exp(gt_boxs[..., 3:6])
+ batch_gt_hei = gt_boxs[..., 2:3]
+ batch_gt_rot = torch.atan2(gt_boxs[..., -2:-1], gt_boxs[..., -1:])
+ xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[..., 0:1]
+ ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[..., 1:2]
+
+ xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[0] + test_cfg.pc_range[0]
+ ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[1] + test_cfg.pc_range[1]
+
+ batch_box_targets = torch.cat(
+ [xs, ys, batch_gt_hei, batch_gt_dim, batch_gt_rot], dim=-1
+ )
+
+ return batch_box_targets # B M 7
+
+
+def get_corresponding_box(x_ind, y_ind, y_mask, y_cls, target_box):
+ # find the id in y which has the same ind in x
+ select_target = torch.zeros(x_ind.shape[0], x_ind.shape[1], target_box.shape[2]).to(
+ target_box
+ )
+ select_mask = torch.zeros_like(x_ind).to(y_mask)
+ select_cls = torch.zeros_like(x_ind).to(y_cls)
+
+ for i in range(x_ind.shape[0]):
+ idx = torch.arange(y_ind[i].shape[-1]).to(x_ind)
+ idx = idx[y_mask[i]]
+ box_cls = y_cls[i][y_mask[i]]
+ valid_y_ind = y_ind[i][y_mask[i]]
+ match = (x_ind[i].unsqueeze(1) == valid_y_ind.unsqueeze(0)).nonzero()
+ select_target[i, match[:, 0]] = target_box[i, idx[match[:, 1]]]
+ select_mask[i, match[:, 0]] = 1
+ select_cls[i, match[:, 0]] = box_cls[match[:, 1]]
+
+ return select_target, select_mask, select_cls
diff --git a/det3d/models/builder.py b/det3d/models/builder.py
new file mode 100644
index 0000000..0b15789
--- /dev/null
+++ b/det3d/models/builder.py
@@ -0,0 +1,50 @@
+from det3d.utils import build_from_cfg
+from torch import nn
+
+from .registry import (
+ BACKBONES,
+ DETECTORS,
+ HEADS,
+ LOSSES,
+ NECKS,
+ READERS,
+ SECOND_STAGE,
+ ROI_HEAD
+)
+
+
+def build(cfg, registry, default_args=None):
+ if isinstance(cfg, list):
+ modules = [build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg]
+ return nn.Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+def build_second_stage_module(cfg):
+ return build(cfg, SECOND_STAGE)
+
+def build_roi_head(cfg):
+ return build(cfg, ROI_HEAD)
+
+
+def build_reader(cfg):
+ return build(cfg, READERS)
+
+
+def build_backbone(cfg):
+ return build(cfg, BACKBONES)
+
+
+def build_neck(cfg):
+ return build(cfg, NECKS)
+
+def build_head(cfg):
+ return build(cfg, HEADS)
+
+
+def build_loss(cfg):
+ return build(cfg, LOSSES)
+
+
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+ return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
diff --git a/det3d/models/detectors/__init__.py b/det3d/models/detectors/__init__.py
new file mode 100644
index 0000000..1af185d
--- /dev/null
+++ b/det3d/models/detectors/__init__.py
@@ -0,0 +1,14 @@
+from .base import BaseDetector
+from .point_pillars import PointPillars
+from .single_stage import SingleStageDetector
+from .voxelnet import VoxelNet
+from .two_stage import TwoStageDetector
+from .voxelnet_dynamic import VoxelNet_dynamic
+
+__all__ = [
+ "BaseDetector",
+ "SingleStageDetector",
+ "VoxelNet",
+ "PointPillars",
+ 'VoxelNet_dynamic',
+]
\ No newline at end of file
diff --git a/det3d/models/detectors/base.py b/det3d/models/detectors/base.py
new file mode 100644
index 0000000..ba3109a
--- /dev/null
+++ b/det3d/models/detectors/base.py
@@ -0,0 +1,70 @@
+import logging
+from abc import ABCMeta, abstractmethod
+
+import numpy as np
+# import pycocotools.mask as maskUtils
+import torch.nn as nn
+from det3d import torchie
+
+
+class BaseDetector(nn.Module):
+ """Base class for detectors"""
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self):
+ super(BaseDetector, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_reader(self):
+ # Whether input data need to be processed by Input Feature Extractor
+ return hasattr(self, "reader") and self.reader is not None
+
+ @property
+ def with_neck(self):
+ return hasattr(self, "neck") and self.neck is not None
+
+ @property
+ def with_shared_head(self):
+ return hasattr(self, "shared_head") and self.shared_head is not None
+
+ @property
+ def with_bbox(self):
+ return hasattr(self, "bbox_head") and self.bbox_head is not None
+
+ @property
+ def with_mask(self):
+ return hasattr(self, "mask_head") and self.mask_head is not None
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ pass
+
+ def extract_feats(self, imgs):
+ assert isinstance(imgs, list)
+ for img in imgs:
+ yield self.extract_feat(img)
+
+ @abstractmethod
+ def forward_train(self, imgs, **kwargs):
+ pass
+
+ @abstractmethod
+ def simple_test(self, img, **kwargs):
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, **kwargs):
+ pass
+
+ def init_weights(self, pretrained=None):
+ if pretrained is not None:
+ logger = logging.getLogger()
+ logger.info("load model from: {}".format(pretrained))
+
+ def forward_test(self, imgs, **kwargs):
+ pass
+
+ def forward(self, example, return_loss=True, **kwargs):
+ pass
diff --git a/det3d/models/detectors/point_pillars.py b/det3d/models/detectors/point_pillars.py
new file mode 100644
index 0000000..a1e0c90
--- /dev/null
+++ b/det3d/models/detectors/point_pillars.py
@@ -0,0 +1,90 @@
+from ..registry import DETECTORS
+from .single_stage import SingleStageDetector
+from copy import deepcopy
+
+@DETECTORS.register_module
+class PointPillars(SingleStageDetector):
+ def __init__(
+ self,
+ reader,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ ):
+ super(PointPillars, self).__init__(
+ reader, backbone, neck, bbox_head, train_cfg, test_cfg, pretrained
+ )
+
+ def extract_feat(self, data):
+ input_features = self.reader(
+ data["features"], data["num_voxels"], data["coors"]
+ )
+ x = self.backbone(
+ input_features, data["coors"], data["batch_size"], data["input_shape"]
+ )
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward(self, example, return_loss=True, **kwargs):
+ voxels = example["voxels"]
+ coordinates = example["coordinates"]
+ num_points_in_voxel = example["num_points"]
+ num_voxels = example["num_voxels"]
+
+ batch_size = len(num_voxels)
+
+ data = dict(
+ features=voxels,
+ num_voxels=num_points_in_voxel,
+ coors=coordinates,
+ batch_size=batch_size,
+ input_shape=example["shape"][0],
+ )
+
+ x = self.extract_feat(data)
+ preds = self.bbox_head(x)
+
+ if return_loss:
+ return self.bbox_head.loss(example, preds)
+ else:
+ return self.bbox_head.predict(example, preds, self.test_cfg)
+
+ def forward_two_stage(self, example, return_loss=True, **kwargs):
+ voxels = example["voxels"]
+ coordinates = example["coordinates"]
+ num_points_in_voxel = example["num_points"]
+ num_voxels = example["num_voxels"]
+
+ batch_size = len(num_voxels)
+
+ data = dict(
+ features=voxels,
+ num_voxels=num_points_in_voxel,
+ coors=coordinates,
+ batch_size=batch_size,
+ input_shape=example["shape"][0],
+ )
+
+ x = self.extract_feat(data)
+ bev_feature = x
+ preds = self.bbox_head(x)
+
+ # manual deepcopy ...
+ new_preds = []
+ for pred in preds:
+ new_pred = {}
+ for k, v in pred.items():
+ new_pred[k] = v.detach()
+
+ new_preds.append(new_pred)
+
+ boxes = self.bbox_head.predict(example, new_preds, self.test_cfg)
+
+ if return_loss:
+ return boxes, bev_feature, self.bbox_head.loss(example, preds)
+ else:
+ return boxes, bev_feature, None
\ No newline at end of file
diff --git a/det3d/models/detectors/single_stage.py b/det3d/models/detectors/single_stage.py
new file mode 100644
index 0000000..67bb8c9
--- /dev/null
+++ b/det3d/models/detectors/single_stage.py
@@ -0,0 +1,62 @@
+import torch.nn as nn
+
+from .. import builder
+from ..registry import DETECTORS
+from .base import BaseDetector
+from ..utils.finetune_utils import FrozenBatchNorm2d
+from det3d.torchie.trainer import load_checkpoint
+
+
+@DETECTORS.register_module
+class SingleStageDetector(BaseDetector):
+ def __init__(
+ self,
+ reader,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ ):
+ super(SingleStageDetector, self).__init__()
+ self.reader = builder.build_reader(reader)
+ self.backbone = builder.build_backbone(backbone)
+ if neck is not None:
+ self.neck = builder.build_neck(neck)
+ self.bbox_head = builder.build_head(bbox_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.init_weights(pretrained=pretrained)
+
+ def init_weights(self, pretrained=None):
+ if pretrained is None:
+ return
+ try:
+ load_checkpoint(self, pretrained, map_location="cpu", strict=False)
+ print("init weight from {}".format(pretrained))
+ except:
+ print("no pretrained model at {}".format(pretrained))
+
+ def extract_feat(self, data):
+ input_features = self.reader(data)
+ x = self.backbone(input_features)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def aug_test(self, example, rescale=False):
+ raise NotImplementedError
+
+ def forward(self, example, return_loss=True, **kwargs):
+ pass
+
+ def predict(self, example, preds_dicts):
+ pass
+
+ def freeze(self):
+ for p in self.parameters():
+ p.requires_grad = False
+ FrozenBatchNorm2d.convert_frozen_batchnorm(self)
+ return self
\ No newline at end of file
diff --git a/det3d/models/detectors/two_stage.py b/det3d/models/detectors/two_stage.py
new file mode 100644
index 0000000..2c58df5
--- /dev/null
+++ b/det3d/models/detectors/two_stage.py
@@ -0,0 +1,215 @@
+from det3d.core.bbox import box_torch_ops
+from ...ops.iou3d_nms.iou3d_nms_utils import boxes_iou3d_gpu
+from det3d.core import box_torch_ops
+from ..registry import DETECTORS
+from .base import BaseDetector
+from .. import builder
+import torch
+from torch import nn
+import numpy as np
+
+@DETECTORS.register_module
+class TwoStageDetector(BaseDetector):
+ def __init__(
+ self,
+ first_stage_cfg,
+ second_stage_modules,
+ roi_head,
+ NMS_POST_MAXSIZE,
+ num_point=1,
+ freeze=False,
+ **kwargs
+ ):
+ super(TwoStageDetector, self).__init__()
+ self.single_det = builder.build_detector(first_stage_cfg, **kwargs)
+ self.NMS_POST_MAXSIZE = NMS_POST_MAXSIZE
+
+ if freeze:
+ print("Freeze First Stage Network")
+ # we train the model in two steps
+ self.single_det = self.single_det.freeze()
+ self.bbox_head = self.single_det.bbox_head
+
+ self.second_stage = nn.ModuleList()
+ # can be any number of modules
+ # bird eye view, cylindrical view, image, multiple timesteps, etc..
+ for module in second_stage_modules:
+ self.second_stage.append(builder.build_second_stage_module(module))
+
+ self.roi_head = builder.build_roi_head(roi_head)
+
+ self.num_point = num_point
+
+ def combine_loss(self, one_stage_loss, roi_loss, tb_dict):
+ one_stage_loss['loss'][0] += (roi_loss)
+
+ for i in range(len(one_stage_loss['loss'])):
+ one_stage_loss['roi_reg_loss'].append(tb_dict['rcnn_loss_reg'])
+ one_stage_loss['roi_cls_loss'].append(tb_dict['rcnn_loss_cls'])
+
+ return one_stage_loss
+
+ def get_box_center(self, boxes, example):
+ # box [List]
+ centers = []
+ for batch_id,box in enumerate(boxes):
+ if self.num_point == 1 or len(box['box3d_lidar']) == 0:
+ centers.append(box['box3d_lidar'][:, :3])
+ elif self.num_point == 5:
+ center2d = box['box3d_lidar'][:, :2]
+ height = box['box3d_lidar'][:, 2:3]
+ dim2d = box['box3d_lidar'][:, 3:5]
+ rotation_y = box['box3d_lidar'][:, -1]
+
+ corners = box_torch_ops.center_to_corner_box2d(center2d, dim2d, rotation_y)
+
+ front_middle = torch.cat([(corners[:, 0] + corners[:, 1])/2, height], dim=-1)
+ back_middle = torch.cat([(corners[:, 2] + corners[:, 3])/2, height], dim=-1)
+ left_middle = torch.cat([(corners[:, 0] + corners[:, 3])/2, height], dim=-1)
+ right_middle = torch.cat([(corners[:, 1] + corners[:, 2])/2, height], dim=-1)
+
+ points = torch.cat([box['box3d_lidar'][:, :3], front_middle, back_middle, left_middle, \
+ right_middle], dim=0)
+
+ centers.append(points)
+ elif self.num_point == 9:
+ center2d = box['box3d_lidar'][:, :2]
+ height = box['box3d_lidar'][:, 2:3]
+ dim2d = box['box3d_lidar'][:, 3:5]
+ rotation_y = box['box3d_lidar'][:, -1]
+
+ corners = box_torch_ops.center_to_corner_box2d(center2d, dim2d, rotation_y)
+ front_left = torch.cat([(center2d + corners[:, 0])/2, height], dim=-1)
+ front_middle = torch.cat([(corners[:, 0] + corners[:, 1] + center2d)/3, height], dim=-1)
+ back_left = torch.cat([(center2d + corners[:, 1])/2, height], dim=-1)
+ back_middle = torch.cat([(corners[:, 1] + corners[:, 2] + center2d)/3, height], dim=-1)
+ left_left = torch.cat([(center2d + corners[:, 2])/2, height], dim=-1)
+ left_middle = torch.cat([(corners[:, 2] + corners[:, 3] + center2d)/3, height], dim=-1)
+ right_left = torch.cat([(center2d + corners[:, 3])/2, height], dim=-1)
+ right_middle = torch.cat([(corners[:, 3] + corners[:, 1] + center2d)/3, height], dim=-1)
+
+ points = torch.cat([box['box3d_lidar'][:, :3], front_left, front_middle, back_left, back_middle, left_left, left_middle, \
+ right_left, right_middle], dim=0)
+ centers.append(points)
+ else:
+ raise NotImplementedError()
+
+ return centers
+
+ def reorder_first_stage_pred_and_feature(self, first_pred, example, features, point_locs):
+ batch_size = len(first_pred)
+ box_length = first_pred[0]['box3d_lidar'].shape[1]
+ feature_vector_length = sum([feat[0].shape[-1] for feat in features])
+
+ rois = first_pred[0]['box3d_lidar'].new_zeros((batch_size,
+ self.NMS_POST_MAXSIZE, box_length
+ ))
+ roi_scores = first_pred[0]['scores'].new_zeros((batch_size,
+ self.NMS_POST_MAXSIZE
+ ))
+ roi_labels = first_pred[0]['label_preds'].new_zeros((batch_size,
+ self.NMS_POST_MAXSIZE), dtype=torch.long
+ )
+ roi_features = features[0][0].new_zeros((batch_size,
+ self.NMS_POST_MAXSIZE, feature_vector_length
+ ))
+
+
+ for i in range(batch_size):
+ num_obj = features[0][i].shape[0]
+ # basically move rotation to position 6, so now the box is 7 + C . C is 2 for nuscenes to
+ # include velocity target
+
+ box_preds = first_pred[i]['box3d_lidar']
+
+ if self.roi_head.code_size == 9:
+ # x, y, z, w, l, h, rotation_y, velocity_x, velocity_y
+ box_preds = box_preds[:, [0, 1, 2, 3, 4, 5, 8, 6, 7]]
+
+ rois[i, :num_obj] = box_preds
+ roi_labels[i, :num_obj] = first_pred[i]['label_preds'] + 1
+ roi_scores[i, :num_obj] = first_pred[i]['scores']
+ roi_features[i, :num_obj] = torch.cat([feat[i] for feat in features], dim=-1)
+
+ example['rois'] = rois
+ example['roi_labels'] = roi_labels
+ example['roi_scores'] = roi_scores
+ example['roi_features'] = roi_features
+
+ example['has_class_labels']= True
+
+ return example
+
+ def post_process(self, batch_dict):
+ batch_size = batch_dict['batch_size']
+ pred_dicts = []
+
+ for index in range(batch_size):
+ box_preds = batch_dict['batch_box_preds'][index]
+ cls_preds = batch_dict['batch_cls_preds'][index] # this is the predicted iou
+ label_preds = batch_dict['roi_labels'][index]
+
+ if box_preds.shape[-1] == 9:
+ # move rotation to the end (the create submission file will take elements from 0:6 and -1)
+ box_preds = box_preds[:, [0, 1, 2, 3, 4, 5, 7, 8, 6]]
+
+ scores = torch.sqrt(torch.sigmoid(cls_preds).reshape(-1) * batch_dict['roi_scores'][index].reshape(-1))
+ mask = (label_preds != 0).reshape(-1)
+
+ box_preds = box_preds[mask, :]
+ scores = scores[mask]
+ labels = label_preds[mask]-1
+
+ # currently don't need nms
+ pred_dict = {
+ 'box3d_lidar': box_preds,
+ 'scores': scores,
+ 'label_preds': labels,
+ "metadata": batch_dict["metadata"][index],
+ }
+
+ pred_dicts.append(pred_dict)
+
+ return pred_dicts
+
+
+ def forward(self, example, return_loss=True, **kwargs):
+ out = self.single_det.forward_two_stage(example,
+ return_loss, **kwargs)
+ if len(out) == 4:
+ one_stage_pred, bev_feature, voxel_feature, one_stage_loss = out
+ example['voxel_feature'] = voxel_feature
+ elif len(out) == 3:
+ one_stage_pred, bev_feature, one_stage_loss = out
+ else:
+ raise NotImplementedError
+
+ # N C H W -> N H W C
+ example['bev_feature'] = bev_feature.permute(0, 2, 3, 1).contiguous()
+
+ if self.roi_head.code_size == 7 and return_loss is True:
+ # drop velocity
+ example['gt_boxes_and_cls'] = example['gt_boxes_and_cls'][:, :, [0, 1, 2, 3, 4, 5, 6, -1]]
+
+ centers_vehicle_frame = self.get_box_center(one_stage_pred,example)
+ features = []
+
+ for module in self.second_stage:
+ feature = module.forward(example, centers_vehicle_frame, self.num_point)
+ features.append(feature)
+ # feature is two level list
+ # first level is number of two stage information streams
+ # second level is batch
+
+ example = self.reorder_first_stage_pred_and_feature(first_pred=one_stage_pred, example=example, features=features, point_locs = centers_vehicle_frame)
+
+ # final classification / regression
+ batch_dict = self.roi_head(example, training=return_loss)
+
+
+ if return_loss:
+ roi_loss, tb_dict = self.roi_head.get_loss()
+
+ return self.combine_loss(one_stage_loss, roi_loss, tb_dict)
+ else:
+ return self.post_process(batch_dict)
diff --git a/det3d/models/detectors/voxelnet.py b/det3d/models/detectors/voxelnet.py
new file mode 100644
index 0000000..0c18f65
--- /dev/null
+++ b/det3d/models/detectors/voxelnet.py
@@ -0,0 +1,97 @@
+from ..registry import DETECTORS
+from .single_stage import SingleStageDetector
+from det3d.torchie.trainer import load_checkpoint
+import torch
+from copy import deepcopy
+from torch.cuda.amp import autocast as autocast
+
+@DETECTORS.register_module
+class VoxelNet(SingleStageDetector):
+ def __init__(
+ self,
+ reader,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ ):
+ super(VoxelNet, self).__init__(
+ reader, backbone, neck, bbox_head, train_cfg, test_cfg, pretrained
+ )
+
+ def extract_feat(self, example, data):
+ input_features = self.reader(data["features"], data["num_voxels"])
+ x, voxel_feature = self.backbone(
+ input_features, data["coors"], data["batch_size"], data["input_shape"]
+ )
+ if self.with_neck:
+ x = self.neck(x, example)
+
+ return x, voxel_feature
+
+ def forward(self, example, return_loss=True, **kwargs):
+ voxels = example["voxels"]
+ coordinates = example["coordinates"]
+ num_points_in_voxel = example["num_points"]
+ num_voxels = example["num_voxels"]
+
+ batch_size = len(num_voxels)
+
+ data = dict(
+ features=voxels,
+ num_voxels=num_points_in_voxel,
+ coors=coordinates,
+ batch_size=batch_size,
+ input_shape=example["shape"][0],
+ )
+
+ # if self.bbox_head.training:
+ # x, _ = self.extract_feat(example, data)
+ # else:
+ # with autocast():
+ # x, _ = self.extract_feat(example, data)
+ x, _ = self.extract_feat(example, data)
+ preds = self.bbox_head(x)
+
+ if return_loss:
+ return self.bbox_head.loss(example, preds, self.test_cfg)
+ else:
+ return self.bbox_head.predict(example, preds, self.test_cfg)
+
+ def forward_two_stage(self, example, return_loss=True, **kwargs):
+ voxels = example["voxels"]
+ coordinates = example["coordinates"]
+ num_points_in_voxel = example["num_points"]
+ num_voxels = example["num_voxels"]
+
+ batch_size = len(num_voxels)
+
+ data = dict(
+ features=voxels,
+ num_voxels=num_points_in_voxel,
+ coors=coordinates,
+ batch_size=batch_size,
+ input_shape=example["shape"][0],
+ )
+
+ x, _ = self.extract_feat(example, data)
+ bev_feature = x['BEV_feat']
+ preds = self.bbox_head(x)
+
+ # manual deepcopy ...
+ new_preds = []
+ for pred in preds:
+ new_pred = {}
+ for k, v in pred.items():
+ new_pred[k] = v.detach()
+
+ new_preds.append(new_pred)
+
+ boxes = self.bbox_head.predict(example, new_preds, self.test_cfg)
+
+ if return_loss:
+ return boxes, bev_feature, self.bbox_head.loss(example, preds, self.test_cfg)
+ else:
+ return boxes, bev_feature, None
diff --git a/det3d/models/detectors/voxelnet_dynamic.py b/det3d/models/detectors/voxelnet_dynamic.py
new file mode 100644
index 0000000..2848f56
--- /dev/null
+++ b/det3d/models/detectors/voxelnet_dynamic.py
@@ -0,0 +1,94 @@
+from ..registry import DETECTORS
+from .single_stage import SingleStageDetector
+from det3d.torchie.trainer import load_checkpoint
+import torch
+from copy import deepcopy
+from torch.cuda.amp import autocast as autocast
+
+@DETECTORS.register_module
+class VoxelNet_dynamic(SingleStageDetector):
+ def __init__(
+ self,
+ reader,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None,
+ ):
+ super(VoxelNet_dynamic, self).__init__(
+ reader, backbone, neck, bbox_head, train_cfg, test_cfg, pretrained
+ )
+
+ def extract_feat(self, example):
+ if 'voxels' not in example:
+ output = self.reader(example['points'])
+ voxels, coors, shape = output
+
+ data = dict(
+ features=voxels,
+ coors=coors,
+ batch_size=len(example['points']),
+ input_shape=shape,
+ voxels=voxels
+ )
+
+ x, voxel_feature = self.backbone(
+ data['voxels'], data["coors"], data["batch_size"], data["input_shape"]
+ )
+
+ if self.with_neck:
+ x = self.neck(x, example)
+
+ return x, voxel_feature
+
+ def forward(self, example, return_loss=True, **kwargs):
+ # if self.bbox_head.training:
+ # x, _ = self.extract_feat(example)
+ # else:
+ # with autocast():
+ # x, _ = self.extract_feat(example)
+ x, _ = self.extract_feat(example)
+ preds = self.bbox_head(x)
+
+ if return_loss:
+ return self.bbox_head.loss(example, preds, self.test_cfg)
+ else:
+ return self.bbox_head.predict(example, preds, self.test_cfg)
+
+ def forward_two_stage(self, example, return_loss=True, **kwargs):
+ voxels = example["voxels"]
+ coordinates = example["coordinates"]
+ num_points_in_voxel = example["num_points"]
+ num_voxels = example["num_voxels"]
+
+ batch_size = len(num_voxels)
+
+ data = dict(
+ features=voxels,
+ num_voxels=num_points_in_voxel,
+ coors=coordinates,
+ batch_size=batch_size,
+ input_shape=example["shape"][0],
+ )
+
+ x, _ = self.extract_feat(example, data)
+ bev_feature = x['BEV_feat']
+ preds = self.bbox_head(x)
+
+ # manual deepcopy ...
+ new_preds = []
+ for pred in preds:
+ new_pred = {}
+ for k, v in pred.items():
+ new_pred[k] = v.detach()
+
+ new_preds.append(new_pred)
+
+ boxes = self.bbox_head.predict(example, new_preds, self.test_cfg)
+
+ if return_loss:
+ return boxes, bev_feature, self.bbox_head.loss(example, preds, self.test_cfg)
+ else:
+ return boxes, bev_feature, None
diff --git a/det3d/models/losses/__init__.py b/det3d/models/losses/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/det3d/models/losses/centernet_loss.py b/det3d/models/losses/centernet_loss.py
new file mode 100644
index 0000000..1fc8856
--- /dev/null
+++ b/det3d/models/losses/centernet_loss.py
@@ -0,0 +1,109 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from det3d.core.utils.center_utils import _transpose_and_gather_feat
+
+class RegLoss(nn.Module):
+ '''Regression loss for an output tensor
+ Arguments:
+ output (batch x dim x h x w)
+ mask (batch x max_objects)
+ ind (batch x max_objects)
+ target (batch x max_objects x dim)
+ '''
+ def __init__(self):
+ super(RegLoss, self).__init__()
+
+ def forward(self, output, mask, ind, target):
+ pred = _transpose_and_gather_feat(output, ind)
+ mask = mask.float().unsqueeze(2)
+
+ loss = F.l1_loss(pred*mask, target*mask, reduction='none')
+ loss = loss / (mask.sum() + 1e-4)
+ loss = loss.transpose(2 ,0).sum(dim=2).sum(dim=1)
+ return loss
+
+class SegLoss(nn.Module):
+ '''segmentation loss for an output tensor
+ Arguments:
+ mask (batch x dim x h x w)
+ offset (batch x dim x h x w)
+ gt_mask (batch x dim x h x w)
+ gt_offset (batch x dim x h x w)
+ '''
+ def __init__(self, offset_weight =0.1):
+ super(SegLoss, self).__init__()
+ self.offset_weight = offset_weight
+
+ def forward(self, mask, offset, gt_mask, gt_offset):
+ loss = F.binary_cross_entropy(torch.sigmoid(mask), gt_mask)
+ offset_loss = F.l1_loss(offset*gt_mask, gt_offset*gt_mask, reduction='none')
+ offset_loss = offset_loss.sum() / (gt_mask.sum() + 1e-4)
+ loss += self.offset_weight * offset_loss
+ return loss
+
+class SegLossV2(nn.Module):
+ '''segmentation loss for an output tensor
+ Arguments:
+ mask (batch x dim x h x w)
+ offset (batch x dim x h x w)
+ grid_offset (batch x dim x h x w)
+ gt_mask (batch x dim x h x w)
+ gt_offset (batch x dim x h x w)
+ gt_grid_offset (batch x dim x h x w)
+ '''
+ def __init__(self):
+ super(SegLossV2, self).__init__()
+
+ def forward(self, mask, offset, grid_offset, gt_mask, gt_offset, gt_grid_offset):
+ loss = F.cross_entropy(mask, gt_mask.squeeze(1))
+ offset_mask = (gt_mask>0).to(gt_offset)
+ offset_loss = F.l1_loss(offset, gt_offset, reduction='none')*offset_mask
+ offset_loss = offset_loss.sum() / (offset_mask.sum() + 1e-4)
+ loss += offset_loss
+ grid_offset_mask = (gt_mask==1).to(gt_offset)
+ grid_offset_loss = F.l1_loss(F.sigmoid(grid_offset), gt_grid_offset, reduction='none')*grid_offset_mask
+ grid_offset_loss = grid_offset_loss.sum() / (grid_offset_mask.sum() + 1e-4)
+ loss += grid_offset_loss
+ return loss
+
+class FastFocalLoss(nn.Module):
+ '''
+ Reimplemented focal loss, exactly the same as the CornerNet version.
+ Faster and costs much less memory.
+ '''
+ def __init__(self, window_size=1, focal_factor=2):
+ super(FastFocalLoss, self).__init__()
+ self.window_size = window_size**2
+ self.focal_factor = focal_factor
+
+ def forward(self, out, target, ind, mask, cat):
+ '''
+ Arguments:
+ out, target: B x C x H x W
+ ind, mask: B x M
+ cat (category id for peaks): B x M
+ '''
+ mask = mask.float()
+ gt = torch.pow(1 - target, 4)
+ neg_loss = torch.log(1 - out) * torch.pow(out, self.focal_factor) * gt
+ neg_loss = neg_loss.sum()
+
+ if self.window_size>1:
+ ct_ind = ind[:,(self.window_size//2)::self.window_size]
+ ct_mask = mask[:,(self.window_size//2)::self.window_size]
+ ct_cat = cat[:,(self.window_size//2)::self.window_size]
+ else:
+ ct_ind = ind
+ ct_mask = mask
+ ct_cat = cat
+
+ pos_pred_pix = _transpose_and_gather_feat(out, ct_ind) # B x M x C
+ pos_pred = pos_pred_pix.gather(2, ct_cat.unsqueeze(2)) # B x M
+ num_pos = ct_mask.sum()
+ pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, self.focal_factor) * \
+ ct_mask.unsqueeze(2)
+ pos_loss = pos_loss.sum()
+ if num_pos == 0:
+ return - neg_loss
+ return - (pos_loss + neg_loss) / num_pos
diff --git a/det3d/models/necks/__init__.py b/det3d/models/necks/__init__.py
new file mode 100644
index 0000000..74ee9d0
--- /dev/null
+++ b/det3d/models/necks/__init__.py
@@ -0,0 +1,8 @@
+from .rpn import RPN
+from .rpn_transformer import RPN_transformer, RPN_transformer_deformable, RPN_transformer_multiframe, RPN_transformer_deformable_mtf
+
+__all__ = ["RPN",
+ "RPN_transformer",
+ "RPN_transformer_deformable",
+ "RPN_transformer_multiframe",
+ "RPN_transformer_deformable_mtf",]
diff --git a/det3d/models/necks/rpn.py b/det3d/models/necks/rpn.py
new file mode 100644
index 0000000..2b98089
--- /dev/null
+++ b/det3d/models/necks/rpn.py
@@ -0,0 +1,160 @@
+import time
+import numpy as np
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+from torchvision.models import resnet
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from det3d.torchie.cnn import constant_init, kaiming_init, xavier_init
+from det3d.torchie.trainer import load_checkpoint
+from det3d.models.utils import Empty, GroupNorm, Sequential
+from det3d.models.utils import change_default_args
+
+from .. import builder
+from ..registry import NECKS
+from ..utils import build_norm_layer
+
+
+@NECKS.register_module
+class RPN(nn.Module):
+ def __init__(
+ self,
+ layer_nums,
+ ds_layer_strides,
+ ds_num_filters,
+ us_layer_strides,
+ us_num_filters,
+ num_input_features,
+ norm_cfg=None,
+ name="rpn",
+ logger=None,
+ **kwargs
+ ):
+ super(RPN, self).__init__()
+ self._layer_strides = ds_layer_strides
+ self._num_filters = ds_num_filters
+ self._layer_nums = layer_nums
+ self._upsample_strides = us_layer_strides
+ self._num_upsample_filters = us_num_filters
+ self._num_input_features = num_input_features
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN", eps=1e-3, momentum=0.01)
+ self._norm_cfg = norm_cfg
+
+ assert len(self._layer_strides) == len(self._layer_nums)
+ assert len(self._num_filters) == len(self._layer_nums)
+ assert len(self._num_upsample_filters) == len(self._upsample_strides)
+
+ self._upsample_start_idx = len(self._layer_nums) - len(self._upsample_strides)
+
+ must_equal_list = []
+ for i in range(len(self._upsample_strides)):
+ # print(upsample_strides[i])
+ must_equal_list.append(
+ self._upsample_strides[i]
+ / np.prod(self._layer_strides[: i + self._upsample_start_idx + 1])
+ )
+
+ for val in must_equal_list:
+ assert val == must_equal_list[0]
+
+ in_filters = [self._num_input_features, *self._num_filters[:-1]]
+ blocks = []
+ deblocks = []
+
+ for i, layer_num in enumerate(self._layer_nums):
+ block, num_out_filters = self._make_layer(
+ in_filters[i],
+ self._num_filters[i],
+ layer_num,
+ stride=self._layer_strides[i],
+ )
+ blocks.append(block)
+ if i - self._upsample_start_idx >= 0:
+ stride = (self._upsample_strides[i - self._upsample_start_idx])
+ if stride > 1:
+ deblock = Sequential(
+ nn.ConvTranspose2d(
+ num_out_filters,
+ self._num_upsample_filters[i - self._upsample_start_idx],
+ stride,
+ stride=stride,
+ bias=False,
+ ),
+ build_norm_layer(
+ self._norm_cfg,
+ self._num_upsample_filters[i - self._upsample_start_idx],
+ )[1],
+ nn.ReLU(),
+ )
+ else:
+ stride = np.round(1 / stride).astype(np.int64)
+ deblock = Sequential(
+ nn.Conv2d(
+ num_out_filters,
+ self._num_upsample_filters[i - self._upsample_start_idx],
+ stride,
+ stride=stride,
+ bias=False,
+ ),
+ build_norm_layer(
+ self._norm_cfg,
+ self._num_upsample_filters[i - self._upsample_start_idx],
+ )[1],
+ nn.ReLU(),
+ )
+ deblocks.append(deblock)
+ self.blocks = nn.ModuleList(blocks)
+ self.deblocks = nn.ModuleList(deblocks)
+
+ logger.info("Finish RPN Initialization")
+
+ @property
+ def downsample_factor(self):
+ factor = np.prod(self._layer_strides)
+ if len(self._upsample_strides) > 0:
+ factor /= self._upsample_strides[-1]
+ return factor
+
+ def _make_layer(self, inplanes, planes, num_blocks, stride=1):
+
+ block = Sequential(
+ nn.ZeroPad2d(1),
+ nn.Conv2d(inplanes, planes, 3, stride=stride, bias=False),
+ build_norm_layer(self._norm_cfg, planes)[1],
+ # nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01),
+ nn.ReLU(),
+ )
+
+ for j in range(num_blocks):
+ block.add(nn.Conv2d(planes, planes, 3, padding=1, bias=False))
+ block.add(
+ build_norm_layer(self._norm_cfg, planes)[1],
+ # nn.BatchNorm2d(planes, eps=1e-3, momentum=0.01)
+ )
+ block.add(nn.ReLU())
+
+ return block, planes
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution="uniform")
+
+ def forward(self, x, example=None):
+ ups = []
+ for i in range(len(self.blocks)):
+ x = self.blocks[i](x)
+ if i - self._upsample_start_idx >= 0:
+ ups.append(self.deblocks[i - self._upsample_start_idx](x))
+ if len(ups) > 0:
+ x = torch.cat(ups, dim=1)
+
+ return x
+
diff --git a/det3d/models/necks/rpn_transformer.py b/det3d/models/necks/rpn_transformer.py
new file mode 100644
index 0000000..1c84c64
--- /dev/null
+++ b/det3d/models/necks/rpn_transformer.py
@@ -0,0 +1,1092 @@
+import time
+import numpy as np
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+
+from det3d.torchie.cnn import xavier_init
+from det3d.models.utils import Sequential
+from det3d.models.utils import Transformer, Deform_Transformer
+
+from .. import builder
+from ..registry import NECKS
+from ..utils import build_norm_layer
+
+
+class ChannelAttention(nn.Module):
+ def __init__(self, in_planes, ratio=16):
+ super(ChannelAttention, self).__init__()
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+ self.fc = nn.Sequential(
+ nn.Conv2d(in_planes, in_planes // 16, 1, bias=False),
+ nn.ReLU(),
+ nn.Conv2d(in_planes // 16, in_planes, 1, bias=False),
+ )
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = self.fc(self.avg_pool(x))
+ max_out = self.fc(self.max_pool(x))
+ out = avg_out + max_out
+ return self.sigmoid(out) * x
+
+
+class SpatialAttention(nn.Module):
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention, self).__init__()
+
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, x):
+ avg_out = torch.mean(x, dim=1, keepdim=True)
+ max_out, _ = torch.max(x, dim=1, keepdim=True)
+ y = torch.cat([avg_out, max_out], dim=1)
+ y = self.conv1(y)
+ return self.sigmoid(y) * x
+
+
+class SpatialAttention_mtf(nn.Module):
+ def __init__(self, kernel_size=7):
+ super(SpatialAttention_mtf, self).__init__()
+
+ self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
+ self.sigmoid = nn.Sigmoid()
+
+ def forward(self, curr, prev):
+ avg_out = torch.mean(curr, dim=1, keepdim=True)
+ max_out, _ = torch.max(curr, dim=1, keepdim=True)
+ y = torch.cat([avg_out, max_out], dim=1)
+ y = self.conv1(y)
+ return self.sigmoid(y) * prev
+
+
+@NECKS.register_module
+class RPN_transformer_base(nn.Module):
+ def __init__(
+ self,
+ layer_nums, # [2,2,2]
+ ds_num_filters, # [128,256,64]
+ num_input_features, # 256
+ transformer_config=None,
+ hm_head_layer=2,
+ corner_head_layer=2,
+ corner=False,
+ assign_label_window_size=1,
+ classes=3,
+ use_gt_training=False,
+ norm_cfg=None,
+ logger=None,
+ init_bias=-2.19,
+ score_threshold=0.1,
+ obj_num=500,
+ **kwargs
+ ):
+ super(RPN_transformer_base, self).__init__()
+ self._layer_strides = [1, 2, -4]
+ self._num_filters = ds_num_filters
+ self._layer_nums = layer_nums
+ self._num_input_features = num_input_features
+ self.score_threshold = score_threshold
+ self.transformer_config = transformer_config
+ self.corner = corner
+ self.obj_num = obj_num
+ self.use_gt_training = use_gt_training
+ self.window_size = assign_label_window_size**2
+ self.cross_attention_kernel_size = [3, 3, 3]
+ self.batch_id = None
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN", eps=1e-3, momentum=0.01)
+ self._norm_cfg = norm_cfg
+
+ assert len(self._layer_strides) == len(self._layer_nums)
+ assert len(self._num_filters) == len(self._layer_nums)
+ assert self.transformer_config is not None
+
+ in_filters = [
+ self._num_input_features,
+ self._num_filters[0],
+ self._num_filters[1],
+ ]
+ blocks = []
+
+ for i, layer_num in enumerate(self._layer_nums):
+ block, num_out_filters = self._make_layer(
+ in_filters[i],
+ self._num_filters[i],
+ layer_num,
+ stride=self._layer_strides[i],
+ )
+ blocks.append(block)
+ self.blocks = nn.ModuleList(blocks)
+ self.up = Sequential(
+ nn.ConvTranspose2d(
+ self._num_filters[0], self._num_filters[2], 2, stride=2, bias=False
+ ),
+ build_norm_layer(self._norm_cfg, self._num_filters[2])[1],
+ nn.ReLU(),
+ )
+ # heatmap prediction
+ self.hm_head = Sequential()
+ for i in range(hm_head_layer - 1):
+ self.hm_head.add(
+ nn.Conv2d(
+ self._num_filters[-1] * 2,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ )
+ self.hm_head.add(build_norm_layer(self._norm_cfg, 64)[1])
+ self.hm_head.add(nn.ReLU())
+
+ self.hm_head.add(
+ nn.Conv2d(64, classes, kernel_size=3, stride=1, padding=1, bias=True)
+ )
+ self.hm_head[-1].bias.data.fill_(init_bias)
+
+ if self.corner:
+ self.corner_head = Sequential()
+ for i in range(corner_head_layer - 1):
+ self.corner_head.add(
+ nn.Conv2d(
+ self._num_filters[-1] * 2,
+ 64,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=True,
+ )
+ )
+ self.corner_head.add(build_norm_layer(self._norm_cfg, 64)[1])
+ self.corner_head.add(nn.ReLU())
+
+ self.corner_head.add(
+ nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1, bias=True)
+ )
+ self.corner_head[-1].bias.data.fill_(init_bias)
+
+ def _make_layer(self, inplanes, planes, num_blocks, stride=1):
+
+ if stride > 0:
+ block = Sequential(
+ nn.ZeroPad2d(1),
+ nn.Conv2d(inplanes, planes, 3, stride=stride, bias=False),
+ build_norm_layer(self._norm_cfg, planes)[1],
+ nn.ReLU(),
+ )
+ else:
+ block = Sequential(
+ nn.ConvTranspose2d(
+ inplanes, planes, -stride, stride=-stride, bias=False
+ ),
+ build_norm_layer(self._norm_cfg, planes)[1],
+ nn.ReLU(),
+ )
+
+ for j in range(num_blocks):
+ block.add(nn.Conv2d(planes, planes, 3, padding=1, bias=False))
+ block.add(
+ build_norm_layer(self._norm_cfg, planes)[1],
+ )
+ block.add(nn.ReLU())
+
+ block.add(ChannelAttention(planes))
+ block.add(SpatialAttention())
+
+ return block, planes
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution="uniform")
+
+ def forward(self, x, example=None):
+ pass
+
+ def get_multi_scale_feature(self, center_pos, feats):
+ """
+ Args:
+ center_pos: center coor at the lowest scale feature map [B 500 2]
+ feats: multi scale BEV feature 3*[B C H W]
+ Returns:
+ neighbor_feat: [B 500 K C]
+ neighbor_pos: [B 500 K 2]
+ """
+ kernel_size = self.cross_attention_kernel_size
+ batch, num_cls, H, W = feats[0].size()
+
+ center_num = center_pos.shape[1]
+
+ relative_pos_list = []
+ neighbor_feat_list = []
+ for i, k in enumerate(kernel_size):
+ neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
+ neighbor_coords = torch.flatten(
+ torch.stack(torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
+ 1,
+ ) # [2, k]
+ neighbor_coords = (
+ neighbor_coords.permute(1, 0).contiguous().to(center_pos)
+ ) # relative coordinate [k, 2]
+ neighbor_coords = (
+ center_pos[:, :, None, :] // (2**i)
+ + neighbor_coords[None, None, :, :]
+ ) # coordinates [B, 500, k, 2]
+ neighbor_coords = torch.clamp(
+ neighbor_coords, min=0, max=H // (2**i) - 1
+ ) # prevent out of bound
+ feat_id = (
+ neighbor_coords[:, :, :, 1] * (W // (2**i))
+ + neighbor_coords[:, :, :, 0]
+ ) # pixel id [B, 500, k]
+ feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
+ # selected_feat = torch.gather(feats[i].reshape(batch, num_cls,(H*W)//(4**i)).permute(0, 2, 1).contiguous(),1,feat_id)
+ selected_feat = (
+ feats[i]
+ .reshape(batch, num_cls, (H * W) // (4**i))
+ .permute(0, 2, 1)
+ .contiguous()[self.batch_id.repeat(1, k**2), feat_id]
+ ) # B, 500*k, C
+ neighbor_feat_list.append(
+ selected_feat.reshape(batch, center_num, -1, num_cls)
+ ) # B, 500, k, C
+ relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
+ # relative_pos_list.append(F.pad(neighbor_coords*(2**i), (0,1), "constant", i)) # B, 500, k, 3
+
+ neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
+ neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
+ return neighbor_feats, neighbor_pos
+
+ def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe):
+ """
+ Args:
+ center_pos: center coor at the lowest scale feature map [B 500 2]
+ feats: multi scale BEV feature (3+k)*[B C H W]
+ timeframe: timeframe [B,k]
+ Returns:
+ neighbor_feat: [B 500 K C]
+ neighbor_pos: [B 500 K 2]
+ neighbor_time: [B 500 K 1]
+ """
+ kernel_size = self.cross_attention_kernel_size
+ batch, num_cls, H, W = feats[0].size()
+
+ center_num = center_pos.shape[1]
+
+ relative_pos_list = []
+ neighbor_feat_list = []
+ timeframe_list = []
+ for i, k in enumerate(kernel_size):
+ neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1)
+ neighbor_coords = torch.flatten(
+ torch.stack(torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0),
+ 1,
+ ) # [2, k]
+ neighbor_coords = (
+ neighbor_coords.permute(1, 0).contiguous().to(center_pos)
+ ) # relative coordinate [k, 2]
+ neighbor_coords = (
+ center_pos[:, :, None, :] // (2**i)
+ + neighbor_coords[None, None, :, :]
+ ) # coordinates [B, 500, k, 2]
+ neighbor_coords = torch.clamp(
+ neighbor_coords, min=0, max=H // (2**i) - 1
+ ) # prevent out of bound
+ feat_id = (
+ neighbor_coords[:, :, :, 1] * (W // (2**i))
+ + neighbor_coords[:, :, :, 0]
+ ) # pixel id [B, 500, k]
+ feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k]
+ selected_feat = (
+ feats[i]
+ .reshape(batch, num_cls, (H * W) // (4**i))
+ .permute(0, 2, 1)
+ .contiguous()[self.batch_id.repeat(1, k**2), feat_id]
+ ) # B, 500*k, C
+ neighbor_feat_list.append(
+ selected_feat.reshape(batch, center_num, -1, num_cls)
+ ) # B, 500, k, C
+ relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2
+ timeframe_list.append(
+ torch.full_like(neighbor_coords[:, :, :, 0:1], 0)
+ ) # B, 500, k
+ if i == 0:
+ # add previous frame feature
+ for frame_num in range(feats[-1].shape[1]):
+ selected_feat = (
+ feats[-1][:, frame_num, :, :, :]
+ .reshape(batch, num_cls, (H * W) // (4**i))
+ .permute(0, 2, 1)
+ .contiguous()[self.batch_id.repeat(1, k**2), feat_id]
+ ) # B, 500*k, C
+ neighbor_feat_list.append(
+ selected_feat.reshape(batch, center_num, -1, num_cls)
+ )
+ relative_pos_list.append(neighbor_coords * (2**i))
+ time = timeframe[:, frame_num + 1].to(selected_feat) # B
+ timeframe_list.append(
+ time[:, None, None, None]
+ * torch.full_like(neighbor_coords[:, :, :, 0:1], 1)
+ ) # B, 500, k
+
+ neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3
+ neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C
+ neighbor_time = torch.cat(timeframe_list, dim=2) # B, 500, K, 1
+
+ return neighbor_feats, neighbor_pos, neighbor_time
+
+
+@NECKS.register_module
+class RPN_transformer(RPN_transformer_base):
+ def __init__(
+ self,
+ layer_nums, # [2,2,2]
+ ds_num_filters, # [128,256,64]
+ num_input_features, # 256
+ transformer_config=None,
+ hm_head_layer=2,
+ corner_head_layer=2,
+ corner=False,
+ assign_label_window_size=1,
+ classes=3,
+ use_gt_training=False,
+ norm_cfg=None,
+ name="rpn_transformer",
+ logger=None,
+ init_bias=-2.19,
+ score_threshold=0.1,
+ obj_num=500,
+ parametric_embedding=False,
+ **kwargs
+ ):
+ super(RPN_transformer, self).__init__(
+ layer_nums,
+ ds_num_filters,
+ num_input_features,
+ transformer_config,
+ hm_head_layer,
+ corner_head_layer,
+ corner,
+ assign_label_window_size,
+ classes,
+ use_gt_training,
+ norm_cfg,
+ logger,
+ init_bias,
+ score_threshold,
+ obj_num,
+ )
+
+ self.transformer_layer = Transformer(
+ self._num_filters[-1] * 2,
+ depth=transformer_config.depth,
+ heads=transformer_config.heads,
+ dim_head=transformer_config.dim_head,
+ mlp_dim=transformer_config.MLP_dim,
+ dropout=transformer_config.DP_rate,
+ out_attention=transformer_config.out_att,
+ )
+ self.pos_embedding_type = transformer_config.get(
+ "pos_embedding_type", "linear"
+ )
+ if self.pos_embedding_type == "linear":
+ self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
+ elif self.pos_embedding_type == "none":
+ self.pos_embedding = None
+ else:
+ raise NotImplementedError()
+ self.cross_attention_kernel_size = transformer_config.cross_attention_kernel_size
+ self.parametric_embedding = parametric_embedding
+ if self.parametric_embedding:
+ self.query_embed = nn.Embedding(self.obj_num, self._num_filters[-1] * 2)
+ nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
+
+ logger.info("Finish RPN_transformer Initialization")
+
+ def forward(self, x, example=None):
+
+ # FPN
+ x = self.blocks[0](x)
+ x_down = self.blocks[1](x)
+ x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
+
+ # heatmap head
+ hm = self.hm_head(x_up)
+
+ if self.corner and self.corner_head.training:
+ corner_hm = self.corner_head(x_up)
+ corner_hm = torch.sigmoid(corner_hm)
+
+ # find top K center location
+ hm = torch.sigmoid(hm)
+ batch, num_cls, H, W = hm.size()
+
+ scores, labels = torch.max(hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
+
+ self.batch_id = torch.from_numpy(np.indices((batch, self.obj_num))[0]).to(
+ labels
+ )
+
+ if self.use_gt_training and self.hm_head.training:
+ gt_inds = example["ind"][0][:, (self.window_size // 2) :: self.window_size]
+ gt_masks = example["mask"][0][
+ :, (self.window_size // 2) :: self.window_size
+ ]
+ batch_id_gt = torch.from_numpy(np.indices((batch, gt_inds.shape[1]))[0]).to(
+ labels
+ )
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
+ else:
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+
+ scores = torch.gather(scores, 1, order)
+ labels = torch.gather(labels, 1, order)
+ mask = scores > self.score_threshold
+
+ ct_feat = (
+ x_up.reshape(batch, -1, H * W)
+ .transpose(2, 1)
+ .contiguous()[self.batch_id, order]
+ ) # B, 500, C
+
+ # create position embedding for each center
+ y_coor = order // W
+ # x_coor = order - y_coor*W
+ x_coor = order % W
+ pos_features = torch.stack([x_coor, y_coor], dim=2)
+
+ if self.parametric_embedding:
+ ct_feat = self.query_embed.weight
+ ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
+
+ # run transformer
+ neighbor_feat, neighbor_pos = self.get_multi_scale_feature(
+ pos_features, [x_up, x, x_down]
+ )
+
+ transformer_out = self.transformer_layer(
+ ct_feat,
+ pos_embedding=self.pos_embedding,
+ center_pos=pos_features.to(ct_feat),
+ y=neighbor_feat,
+ neighbor_pos=neighbor_pos.to(ct_feat),
+ ) # (B,N,C)
+
+ ct_feat = (
+ transformer_out["ct_feat"].transpose(2, 1).contiguous()
+ ) # B, C, 500
+
+ out_dict = {}
+ out_dict.update(
+ {
+ "hm": hm,
+ "scores": scores,
+ "labels": labels,
+ "order": order,
+ "ct_feat": ct_feat,
+ "mask": mask,
+ "BEV_feat": x_up,
+ "H": H,
+ "W": W,
+ }
+ )
+ if self.corner and self.corner_head.training:
+ out_dict.update({"corner_hm": corner_hm})
+
+ return out_dict
+
+
+@NECKS.register_module
+class RPN_transformer_multiframe(RPN_transformer_base):
+ def __init__(
+ self,
+ layer_nums, # [2,2,2]
+ ds_num_filters, # [128,256,64]
+ num_input_features, # 256
+ transformer_config=None,
+ hm_head_layer=2,
+ corner_head_layer=2,
+ corner=False,
+ assign_label_window_size=1,
+ classes=3,
+ use_gt_training=False,
+ norm_cfg=None,
+ name="rpn_transformer_multiframe",
+ logger=None,
+ init_bias=-2.19,
+ score_threshold=0.1,
+ obj_num=500,
+ frame=1,
+ **kwargs
+ ):
+ super(RPN_transformer_multiframe, self).__init__(
+ layer_nums,
+ ds_num_filters,
+ num_input_features,
+ transformer_config,
+ hm_head_layer,
+ corner_head_layer,
+ corner,
+ assign_label_window_size,
+ classes,
+ use_gt_training,
+ norm_cfg,
+ logger,
+ init_bias,
+ score_threshold,
+ obj_num,
+ )
+ self.frame = frame
+
+ self.out = Sequential(
+ nn.Conv2d(
+ self._num_filters[-1] * 2 * frame,
+ self._num_filters[-1] * 2,
+ 3,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self._norm_cfg, self._num_filters[-1] * 2)[1],
+ nn.ReLU(),
+ )
+ self.mtf_attention = SpatialAttention_mtf()
+ self.time_embedding = nn.Linear(1, self._num_filters[-1] * 2)
+
+ self.transformer_layer = Transformer(
+ self._num_filters[-1] * 2,
+ depth=transformer_config.depth,
+ heads=transformer_config.heads,
+ dim_head=transformer_config.dim_head,
+ mlp_dim=transformer_config.MLP_dim,
+ dropout=transformer_config.DP_rate,
+ out_attention=transformer_config.out_att,
+ )
+ self.pos_embedding_type = transformer_config.get(
+ "pos_embedding_type", "linear"
+ )
+ if self.pos_embedding_type == "linear":
+ self.pos_embedding = nn.Linear(3, self._num_filters[-1] * 2)
+ else:
+ raise NotImplementedError()
+ self.cross_attention_kernel_size = transformer_config.cross_attention_kernel_size
+
+ logger.info("Finish RPN_transformer Initialization")
+
+ def forward(self, x, example=None):
+ # FPN
+ x = self.blocks[0](x)
+ x_down = self.blocks[1](x)
+ x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
+
+ # take out the BEV feature on current frame
+ x = torch.split(x, self.frame)
+ x_up = torch.split(x_up, self.frame)
+ x_down = torch.split(x_down, self.frame)
+ x_prev = torch.stack([t[1:] for t in x_up], dim=0) # B,K,C,H,W
+ x = torch.stack([t[0] for t in x], dim=0)
+ x_down = torch.stack([t[0] for t in x_down], dim=0)
+
+ x_up = torch.stack([t[0] for t in x_up], dim=0) # B,C,H,W
+ # use spatial attention in current frame on previous feature
+ x_prev_cat = self.mtf_attention(
+ x_up, x_prev.reshape(x_up.shape[0], -1, x_up.shape[2], x_up.shape[3])
+ ) # B,K*C,H,W
+ # time embedding
+ x_up_fuse = torch.cat((x_up, x_prev_cat), dim=1) + self.time_embedding(
+ example["times"][:, :, None].to(x_up)
+ ).reshape(x_up.shape[0], -1, 1, 1)
+ # fuse mtf feature
+ x_up_fuse = self.out(x_up_fuse)
+
+ # heatmap head
+ hm = self.hm_head(x_up_fuse)
+
+ if self.corner and self.corner_head.training:
+ corner_hm = self.corner_head(x_up_fuse)
+ corner_hm = torch.sigmoid(corner_hm)
+
+ # find top K center location
+ hm = torch.sigmoid(hm)
+ batch, num_cls, H, W = hm.size()
+
+ scores, labels = torch.max(hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
+
+ self.batch_id = torch.from_numpy(np.indices((batch, self.obj_num))[0]).to(
+ labels
+ )
+
+ if self.use_gt_training and self.hm_head.training:
+ gt_inds = example["ind"][0][:, (self.window_size // 2) :: self.window_size]
+ gt_masks = example["mask"][0][
+ :, (self.window_size // 2) :: self.window_size
+ ]
+ batch_id_gt = torch.from_numpy(np.indices((batch, gt_inds.shape[1]))[0]).to(
+ labels
+ )
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
+ else:
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+
+ scores = torch.gather(scores, 1, order)
+ labels = torch.gather(labels, 1, order)
+ mask = scores > self.score_threshold
+
+ ct_feat = (
+ x_up.reshape(batch, -1, H * W)
+ .transpose(2, 1)
+ .contiguous()[self.batch_id, order]
+ ) # B, 500, C
+
+ # create position embedding for each center
+ y_coor = order // W
+ x_coor = order % W
+ pos_features = torch.stack([x_coor, y_coor], dim=2)
+
+ # run transformer
+ neighbor_feat, neighbor_pos, neighbor_time = self.get_multi_scale_feature_multiframe(
+ pos_features, [x_up, x, x_down, x_prev], example["times"]
+ )
+ neighbor_pos = torch.cat((neighbor_pos,neighbor_time),dim=3)
+ pos_features = F.pad(pos_features,(0,1),"constant", 0)
+
+ transformer_out = self.transformer_layer(
+ ct_feat,
+ pos_embedding=self.pos_embedding,
+ center_pos=pos_features.to(ct_feat),
+ y=neighbor_feat,
+ neighbor_pos=neighbor_pos.to(ct_feat),
+ ) # (B,N,C)
+
+ ct_feat = (
+ transformer_out["ct_feat"].transpose(2, 1).contiguous()
+ ) # B, C, 500
+
+ out_dict = {}
+ out_dict.update(
+ {
+ "hm": hm,
+ "scores": scores,
+ "labels": labels,
+ "order": order,
+ "ct_feat": ct_feat,
+ "mask": mask,
+ "BEV_feat": x_up,
+ }
+ )
+ if self.corner and self.corner_head.training:
+ out_dict.update({"corner_hm": corner_hm})
+
+ return out_dict
+
+
+@NECKS.register_module
+class RPN_transformer_deformable(RPN_transformer_base):
+ def __init__(
+ self,
+ layer_nums, # [2,2,2]
+ ds_num_filters, # [128,256,64]
+ num_input_features, # 256
+ transformer_config=None,
+ hm_head_layer=2,
+ corner_head_layer=2,
+ corner=False,
+ parametric_embedding=False,
+ assign_label_window_size=1,
+ classes=3,
+ use_gt_training=False,
+ norm_cfg=None,
+ name="rpn_transformer_deformable",
+ logger=None,
+ init_bias=-2.19,
+ score_threshold=0.1,
+ obj_num=500,
+ **kwargs
+ ):
+ super(RPN_transformer_deformable, self).__init__(
+ layer_nums,
+ ds_num_filters,
+ num_input_features,
+ transformer_config,
+ hm_head_layer,
+ corner_head_layer,
+ corner,
+ assign_label_window_size,
+ classes,
+ use_gt_training,
+ norm_cfg,
+ logger,
+ init_bias,
+ score_threshold,
+ obj_num,
+ )
+
+ self.transformer_layer = Deform_Transformer(
+ self._num_filters[-1] * 2,
+ depth=transformer_config.depth,
+ heads=transformer_config.heads,
+ dim_head=transformer_config.dim_head,
+ mlp_dim=transformer_config.MLP_dim,
+ dropout=transformer_config.DP_rate,
+ out_attention=transformer_config.out_att,
+ n_points=transformer_config.get("n_points", 9),
+ )
+ self.pos_embedding_type = transformer_config.get(
+ "pos_embedding_type", "linear"
+ )
+ if self.pos_embedding_type == "linear":
+ self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
+ else:
+ raise NotImplementedError()
+ self.parametric_embedding = parametric_embedding
+ if self.parametric_embedding:
+ self.query_embed = nn.Embedding(self.obj_num, self._num_filters[-1] * 2)
+ nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
+
+ logger.info("Finish RPN_transformer_deformable Initialization")
+
+ def forward(self, x, example=None):
+
+ # FPN
+ x = self.blocks[0](x)
+ x_down = self.blocks[1](x)
+ x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
+
+ # heatmap head
+ hm = self.hm_head(x_up)
+
+ if self.corner and self.corner_head.training:
+ corner_hm = self.corner_head(x_up)
+ corner_hm = torch.sigmoid(corner_hm)
+
+ # find top K center location
+ hm = torch.sigmoid(hm)
+ batch, num_cls, H, W = hm.size()
+
+ scores, labels = torch.max(hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
+ self.batch_id = torch.from_numpy(np.indices((batch, self.obj_num))[0]).to(
+ labels
+ )
+
+ if self.use_gt_training and self.hm_head.training:
+ gt_inds = example["ind"][0][:, (self.window_size // 2) :: self.window_size]
+ gt_masks = example["mask"][0][
+ :, (self.window_size // 2) :: self.window_size
+ ]
+ batch_id_gt = torch.from_numpy(np.indices((batch, gt_inds.shape[1]))[0]).to(
+ labels
+ )
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
+ else:
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+
+ scores = torch.gather(scores, 1, order)
+ labels = torch.gather(labels, 1, order)
+ mask = scores > self.score_threshold
+
+ ct_feat = (
+ x_up.reshape(batch, -1, H * W)
+ .transpose(2, 1)
+ .contiguous()[self.batch_id, order]
+ ) # B, 500, C
+
+ # create position embedding for each center
+ y_coor = order // W
+ x_coor = order - y_coor * W
+ y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
+ y_coor, x_coor = y_coor / H, x_coor / W
+ pos_features = torch.stack([x_coor, y_coor], dim=2)
+
+ if self.parametric_embedding:
+ ct_feat = self.query_embed.weight
+ ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
+
+ # run transformer
+ src = torch.cat(
+ (
+ x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
+ x.reshape(batch, -1, (H * W) // 4).transpose(2, 1).contiguous(),
+ x_down.reshape(batch, -1, (H * W) // 16)
+ .transpose(2, 1)
+ .contiguous(),
+ ),
+ dim=1,
+ ) # B ,sum(H*W), C
+ spatial_shapes = torch.as_tensor(
+ [(H, W), (H // 2, W // 2), (H // 4, W // 4)],
+ dtype=torch.long,
+ device=ct_feat.device,
+ )
+ level_start_index = torch.cat(
+ (
+ spatial_shapes.new_zeros((1,)),
+ spatial_shapes.prod(1).cumsum(0)[:-1],
+ )
+ )
+
+ transformer_out = self.transformer_layer(
+ ct_feat,
+ self.pos_embedding,
+ src,
+ spatial_shapes,
+ level_start_index,
+ center_pos=pos_features,
+ ) # (B,N,C)
+
+ ct_feat = (
+ transformer_out["ct_feat"].transpose(2, 1).contiguous()
+ ) # B, C, 500
+
+ out_dict = {
+ "hm": hm,
+ "scores": scores,
+ "labels": labels,
+ "order": order,
+ "ct_feat": ct_feat,
+ "mask": mask,
+ }
+ if "out_attention" in transformer_out:
+ out_dict.update({"out_attention": transformer_out["out_attention"]})
+ if self.corner and self.corner_head.training:
+ out_dict.update({"corner_hm": corner_hm})
+
+ return out_dict
+
+
+@NECKS.register_module
+class RPN_transformer_deformable_mtf(RPN_transformer_base):
+ def __init__(
+ self,
+ layer_nums, # [2,2,2]
+ ds_num_filters, # [128,256,64]
+ num_input_features, # 256
+ transformer_config=None,
+ hm_head_layer=2,
+ corner_head_layer=2,
+ corner=False,
+ parametric_embedding=False,
+ assign_label_window_size=1,
+ classes=3,
+ use_gt_training=False,
+ norm_cfg=None,
+ name="rpn_transformer_deformable_mtf",
+ logger=None,
+ init_bias=-2.19,
+ score_threshold=0.1,
+ obj_num=500,
+ frame=1,
+ **kwargs
+ ):
+ super(RPN_transformer_deformable_mtf, self).__init__(
+ layer_nums,
+ ds_num_filters,
+ num_input_features,
+ transformer_config,
+ hm_head_layer,
+ corner_head_layer,
+ corner,
+ assign_label_window_size,
+ classes,
+ use_gt_training,
+ norm_cfg,
+ logger,
+ init_bias,
+ score_threshold,
+ obj_num,
+ )
+ self.frame = frame
+
+ self.out = Sequential(
+ nn.Conv2d(
+ self._num_filters[0] * frame,
+ self._num_filters[0],
+ 3,
+ padding=1,
+ bias=False,
+ ),
+ build_norm_layer(self._norm_cfg, self._num_filters[0])[1],
+ nn.ReLU(),
+ )
+ self.mtf_attention = SpatialAttention_mtf()
+ self.time_embedding = nn.Linear(1, self._num_filters[0])
+
+ self.transformer_layer = Deform_Transformer(
+ self._num_filters[-1] * 2,
+ depth=transformer_config.depth,
+ heads=transformer_config.heads,
+ levels=2 + self.frame,
+ dim_head=transformer_config.dim_head,
+ mlp_dim=transformer_config.MLP_dim,
+ dropout=transformer_config.DP_rate,
+ out_attention=transformer_config.out_att,
+ n_points=transformer_config.get("n_points", 9),
+ )
+ self.pos_embedding_type = transformer_config.get(
+ "pos_embedding_type", "linear"
+ )
+ if self.pos_embedding_type == "linear":
+ self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2)
+ else:
+ raise NotImplementedError()
+ self.parametric_embedding = parametric_embedding
+ if self.parametric_embedding:
+ self.query_embed = nn.Embedding(self.obj_num, self._num_filters[-1] * 2)
+ nn.init.uniform_(self.query_embed.weight, -1.0, 1.0)
+
+ logger.info("Finish RPN_transformer_deformable Initialization")
+
+ def forward(self, x, example=None):
+
+ # FPN
+ x = self.blocks[0](x)
+ x_down = self.blocks[1](x)
+ x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1)
+
+ # take out the BEV feature on current frame
+ x = torch.split(x, self.frame)
+ x_up = torch.split(x_up, self.frame)
+ x_down = torch.split(x_down, self.frame)
+ x_prev = torch.stack([t[1:] for t in x_up], dim=0) # B,K,C,H,W
+ x = torch.stack([t[0] for t in x], dim=0)
+ x_down = torch.stack([t[0] for t in x_down], dim=0)
+
+ x_up = torch.stack([t[0] for t in x_up], dim=0) # B,C,H,W
+ # use spatial attention in current frame on previous feature
+ x_prev_cat = self.mtf_attention(
+ x_up, x_prev.reshape(x_up.shape[0], -1, x_up.shape[2], x_up.shape[3])
+ ) # B,K*C,H,W
+ # time embedding
+ x_up_fuse = torch.cat((x_up, x_prev_cat), dim=1) + self.time_embedding(
+ example["times"][:, :, None].to(x_up)
+ ).reshape(x_up.shape[0], -1, 1, 1)
+ # fuse mtf feature
+ x_up_fuse = self.out(x_up_fuse)
+
+ # heatmap head
+ hm = self.hm_head(x_up_fuse)
+
+ if self.corner and self.corner_head.training:
+ corner_hm = self.corner_head(x_up_fuse)
+ corner_hm = torch.sigmoid(corner_hm)
+
+ # find top K center location
+ hm = torch.sigmoid(hm)
+ batch, num_cls, H, W = hm.size()
+
+ scores, labels = torch.max(hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W
+ self.batch_id = torch.from_numpy(np.indices((batch, self.obj_num))[0]).to(
+ labels
+ )
+
+ if self.use_gt_training and self.hm_head.training:
+ gt_inds = example["ind"][0][:, (self.window_size // 2) :: self.window_size]
+ gt_masks = example["mask"][0][
+ :, (self.window_size // 2) :: self.window_size
+ ]
+ batch_id_gt = torch.from_numpy(np.indices((batch, gt_inds.shape[1]))[0]).to(
+ labels
+ )
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+ scores[batch_id_gt, gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks
+ else:
+ order = scores.sort(1, descending=True)[1]
+ order = order[:, : self.obj_num]
+
+ scores = torch.gather(scores, 1, order)
+ labels = torch.gather(labels, 1, order)
+ mask = scores > self.score_threshold
+
+ ct_feat = (
+ x_up.reshape(batch, -1, H * W)
+ .transpose(2, 1)
+ .contiguous()[self.batch_id, order]
+ ) # B, 500, C
+
+ # create position embedding for each center
+ y_coor = order // W
+ x_coor = order - y_coor * W
+ y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat)
+ y_coor, x_coor = y_coor / H, x_coor / W
+ pos_features = torch.stack([x_coor, y_coor], dim=2)
+
+ if self.parametric_embedding:
+ ct_feat = self.query_embed.weight
+ ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1)
+
+ # run transformer
+ src_list = [
+ x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(),
+ x.reshape(batch, -1, (H * W) // 4).transpose(2, 1).contiguous(),
+ x_down.reshape(batch, -1, (H * W) // 16)
+ .transpose(2, 1)
+ .contiguous(),
+ ]
+ for frame in range(x_prev.shape[1]):
+ src_list.append(
+ x_prev[:, frame]
+ .reshape(batch, -1, (H * W))
+ .transpose(2, 1)
+ .contiguous()
+ )
+ src = torch.cat(src_list, dim=1) # B ,sum(H*W), C
+ spatial_list = [(H, W), (H // 2, W // 2), (H // 4, W // 4)]
+ spatial_list += [(H, W) for frame in range(x_prev.shape[1])]
+ spatial_shapes = torch.as_tensor(
+ spatial_list, dtype=torch.long, device=ct_feat.device
+ )
+ level_start_index = torch.cat(
+ (
+ spatial_shapes.new_zeros((1,)),
+ spatial_shapes.prod(1).cumsum(0)[:-1],
+ )
+ )
+
+ transformer_out = self.transformer_layer(
+ ct_feat,
+ self.pos_embedding,
+ src,
+ spatial_shapes,
+ level_start_index,
+ center_pos=pos_features,
+ ) # (B,N,C)
+
+ ct_feat = (
+ transformer_out["ct_feat"].transpose(2, 1).contiguous()
+ ) # B, C, 500
+
+ out_dict = {
+ "hm": hm,
+ "scores": scores,
+ "labels": labels,
+ "order": order,
+ "ct_feat": ct_feat,
+ "mask": mask,
+ }
+ if "out_attention" in transformer_out:
+ out_dict.update({"out_attention": transformer_out["out_attention"]})
+ if self.corner and self.corner_head.training:
+ out_dict.update({"corner_hm": corner_hm})
+
+ return out_dict
diff --git a/det3d/models/ops/functions/__init__.py b/det3d/models/ops/functions/__init__.py
new file mode 100644
index 0000000..8a2197b
--- /dev/null
+++ b/det3d/models/ops/functions/__init__.py
@@ -0,0 +1,10 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn_func import MSDeformAttnFunction
+
diff --git a/det3d/models/ops/functions/ms_deform_attn_func.py b/det3d/models/ops/functions/ms_deform_attn_func.py
new file mode 100644
index 0000000..59d7f1c
--- /dev/null
+++ b/det3d/models/ops/functions/ms_deform_attn_func.py
@@ -0,0 +1,61 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+import MultiScaleDeformableAttention as MSDA
+
+
+class MSDeformAttnFunction(Function):
+ @staticmethod
+ def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
+ ctx.im2col_step = im2col_step
+ output = MSDA.ms_deform_attn_forward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
+ ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
+ grad_value, grad_sampling_loc, grad_attn_weight = \
+ MSDA.ms_deform_attn_backward(
+ value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
+
+ return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
+
+
+def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
+ # for debug and test only,
+ # need to use cuda version instead
+ N_, S_, M_, D_ = value.shape
+ _, Lq_, M_, L_, P_, _ = sampling_locations.shape
+ value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+ sampling_grids = 2 * sampling_locations - 1
+ sampling_value_list = []
+ for lid_, (H_, W_) in enumerate(value_spatial_shapes):
+ # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+ value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
+ # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+ sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
+ # N_*M_, D_, Lq_, P_
+ sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
+ mode='bilinear', padding_mode='zeros')
+ sampling_value_list.append(sampling_value_l_)
+ # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
+ attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
+ output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
+ return output.transpose(1, 2).contiguous()
diff --git a/det3d/models/ops/make.sh b/det3d/models/ops/make.sh
new file mode 100644
index 0000000..106b685
--- /dev/null
+++ b/det3d/models/ops/make.sh
@@ -0,0 +1,10 @@
+#!/usr/bin/env bash
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+python setup.py build install
diff --git a/det3d/models/ops/modules/__init__.py b/det3d/models/ops/modules/__init__.py
new file mode 100644
index 0000000..f82cb1a
--- /dev/null
+++ b/det3d/models/ops/modules/__init__.py
@@ -0,0 +1,9 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from .ms_deform_attn import MSDeformAttn
diff --git a/det3d/models/ops/modules/ms_deform_attn.py b/det3d/models/ops/modules/ms_deform_attn.py
new file mode 100644
index 0000000..112ecfa
--- /dev/null
+++ b/det3d/models/ops/modules/ms_deform_attn.py
@@ -0,0 +1,122 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import warnings
+import math
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_uniform_, constant_
+
+from ..functions import MSDeformAttnFunction
+
+
+def _is_power_of_2(n):
+ if (not isinstance(n, int)) or (n < 0):
+ raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
+ return (n & (n-1) == 0) and n != 0
+
+
+class MSDeformAttn(nn.Module):
+ def __init__(self, d_model=256, d_head = 64, n_levels=4, n_heads=8, n_points=4, out_sample_loc=False):
+ """
+ Multi-Scale Deformable Attention Module
+ :param d_model hidden dimension
+ :param n_levels number of feature levels
+ :param n_heads number of attention heads
+ :param n_points number of sampling points per attention head per feature level
+ """
+ super().__init__()
+ # if d_model % n_heads != 0:
+ # raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
+ # _d_per_head = d_model // n_heads
+ # # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation
+ # if not _is_power_of_2(_d_per_head):
+ # warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 "
+ # "which is more efficient in our CUDA implementation.")
+
+ self.im2col_step = 64
+
+ self.d_model = d_model
+ self.d_head = d_head
+ self.n_levels = n_levels
+ self.n_heads = n_heads
+ self.n_points = n_points
+
+ self.out_sample_loc = out_sample_loc
+
+ self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
+ self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
+ self.value_proj = nn.Linear(d_model, d_head*n_heads)
+ self.output_proj = nn.Linear(d_head*n_heads, d_model)
+
+ self._reset_parameters()
+
+ def _reset_parameters(self):
+ constant_(self.sampling_offsets.weight.data, 0.)
+ thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+ grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
+ for i in range(self.n_points):
+ grid_init[:, :, i, :] *= i + 1
+ with torch.no_grad():
+ self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+ constant_(self.attention_weights.weight.data, 0.)
+ constant_(self.attention_weights.bias.data, 0.)
+ xavier_uniform_(self.value_proj.weight.data)
+ constant_(self.value_proj.bias.data, 0.)
+ xavier_uniform_(self.output_proj.weight.data)
+ constant_(self.output_proj.bias.data, 0.)
+
+ def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
+ """
+ :param query (N, Length_{query}, C)
+ :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
+ or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
+ :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
+ :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+ :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
+ :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
+
+ :return output (N, Length_{query}, C)
+ """
+ N, Len_q, _ = query.shape
+ N, Len_in, _ = input_flatten.shape
+ assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
+
+ value = self.value_proj(input_flatten)
+ if input_padding_mask is not None:
+ value = value.masked_fill(input_padding_mask[..., None], float(0))
+ value = value.view(N, Len_in, self.n_heads, self.d_head)
+ sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
+ attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
+ attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
+ # N, Len_q, n_heads, n_levels, n_points, 2
+ if reference_points.shape[-1] == 2:
+ offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1).to(sampling_offsets)
+
+ sampling_locations = reference_points[:, :, None, :, None, :] \
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
+ elif reference_points.shape[-1] == 4:
+ sampling_locations = reference_points[:, :, None, :, None, :2] \
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
+ else:
+ raise ValueError(
+ 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
+ output = MSDeformAttnFunction.apply(
+ value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
+ output = self.output_proj(output)
+ if self.out_sample_loc:
+ return output, torch.cat((sampling_locations,attention_weights[:,:,:,:,:,None]),dim=-1)
+ else:
+ return output, None
diff --git a/det3d/models/ops/setup.py b/det3d/models/ops/setup.py
new file mode 100644
index 0000000..a0131bc
--- /dev/null
+++ b/det3d/models/ops/setup.py
@@ -0,0 +1,71 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+import os
+import glob
+
+import torch
+
+from torch.utils.cpp_extension import CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": []}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-DCUDA_HAS_FP16=1",
+ "-D__CUDA_NO_HALF_OPERATORS__",
+ "-D__CUDA_NO_HALF_CONVERSIONS__",
+ "-D__CUDA_NO_HALF2_OPERATORS__",
+ ]
+ else:
+ raise NotImplementedError('Cuda is not availabel')
+
+ sources = [os.path.join(extensions_dir, s) for s in sources]
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "MultiScaleDeformableAttention",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+ return ext_modules
+
+setup(
+ name="MultiScaleDeformableAttention",
+ version="1.0",
+ author="Weijie Su",
+ url="https://github.com/fundamentalvision/Deformable-DETR",
+ description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
+ packages=find_packages(exclude=("configs", "tests",)),
+ ext_modules=get_extensions(),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/det3d/models/ops/src/cpu/ms_deform_attn_cpu.cpp b/det3d/models/ops/src/cpu/ms_deform_attn_cpu.cpp
new file mode 100644
index 0000000..e1bf854
--- /dev/null
+++ b/det3d/models/ops/src/cpu/ms_deform_attn_cpu.cpp
@@ -0,0 +1,41 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+
+#include
+#include
+
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ AT_ERROR("Not implement on cpu");
+}
+
diff --git a/det3d/models/ops/src/cpu/ms_deform_attn_cpu.h b/det3d/models/ops/src/cpu/ms_deform_attn_cpu.h
new file mode 100644
index 0000000..81b7b58
--- /dev/null
+++ b/det3d/models/ops/src/cpu/ms_deform_attn_cpu.h
@@ -0,0 +1,33 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+at::Tensor
+ms_deform_attn_cpu_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector
+ms_deform_attn_cpu_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
+
diff --git a/det3d/models/ops/src/cuda/ms_deform_attn_cuda.cu b/det3d/models/ops/src/cuda/ms_deform_attn_cuda.cu
new file mode 100644
index 0000000..d6d5836
--- /dev/null
+++ b/det3d/models/ops/src/cuda/ms_deform_attn_cuda.cu
@@ -0,0 +1,153 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include
+#include "cuda/ms_deform_im2col_cuda.cuh"
+
+#include
+#include
+#include
+#include
+
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
+
+ const int batch_n = im2col_step_;
+ auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto columns = output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
+ ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ columns.data());
+
+ }));
+ }
+
+ output = output.view({batch, num_query, num_heads*channels});
+
+ return output;
+}
+
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+
+ AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
+ AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
+ AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
+ AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
+ AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
+ AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
+
+ AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
+ AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
+ AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
+ AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
+ AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
+ AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
+
+ const int batch = value.size(0);
+ const int spatial_size = value.size(1);
+ const int num_heads = value.size(2);
+ const int channels = value.size(3);
+
+ const int num_levels = spatial_shapes.size(0);
+
+ const int num_query = sampling_loc.size(1);
+ const int num_point = sampling_loc.size(4);
+
+ const int im2col_step_ = std::min(batch, im2col_step);
+
+ AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
+
+ auto grad_value = at::zeros_like(value);
+ auto grad_sampling_loc = at::zeros_like(sampling_loc);
+ auto grad_attn_weight = at::zeros_like(attn_weight);
+
+ const int batch_n = im2col_step_;
+ auto per_value_size = spatial_size * num_heads * channels;
+ auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
+ auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
+ auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
+
+ for (int n = 0; n < batch/im2col_step_; ++n)
+ {
+ auto grad_output_g = grad_output_n.select(0, n);
+ AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
+ ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
+ grad_output_g.data(),
+ value.data() + n * im2col_step_ * per_value_size,
+ spatial_shapes.data(),
+ level_start_index.data(),
+ sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ attn_weight.data() + n * im2col_step_ * per_attn_weight_size,
+ batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
+ grad_value.data() + n * im2col_step_ * per_value_size,
+ grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size,
+ grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size);
+
+ }));
+ }
+
+ return {
+ grad_value, grad_sampling_loc, grad_attn_weight
+ };
+}
\ No newline at end of file
diff --git a/det3d/models/ops/src/cuda/ms_deform_attn_cuda.h b/det3d/models/ops/src/cuda/ms_deform_attn_cuda.h
new file mode 100644
index 0000000..c7ae53f
--- /dev/null
+++ b/det3d/models/ops/src/cuda/ms_deform_attn_cuda.h
@@ -0,0 +1,30 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+#include
+
+at::Tensor ms_deform_attn_cuda_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step);
+
+std::vector ms_deform_attn_cuda_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step);
+
diff --git a/det3d/models/ops/src/cuda/ms_deform_im2col_cuda.cuh b/det3d/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
new file mode 100644
index 0000000..6bc2acb
--- /dev/null
+++ b/det3d/models/ops/src/cuda/ms_deform_im2col_cuda.cuh
@@ -0,0 +1,1327 @@
+/*!
+**************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************
+* Modified from DCN (https://github.com/msracver/Deformable-ConvNets)
+* Copyright (c) 2018 Microsoft
+**************************************************************************
+*/
+
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads)
+{
+ return (N + num_threads - 1) / num_threads;
+}
+
+
+template
+__device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ }
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ *grad_attn_weight = top_grad * val;
+ *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+ *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+
+template
+__device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data,
+ const int &height, const int &width, const int &nheads, const int &channels,
+ const scalar_t &h, const scalar_t &w, const int &m, const int &c,
+ const scalar_t &top_grad,
+ const scalar_t &attn_weight,
+ scalar_t* &grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int h_low = floor(h);
+ const int w_low = floor(w);
+ const int h_high = h_low + 1;
+ const int w_high = w_low + 1;
+
+ const scalar_t lh = h - h_low;
+ const scalar_t lw = w - w_low;
+ const scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ const int w_stride = nheads * channels;
+ const int h_stride = width * w_stride;
+ const int h_low_ptr_offset = h_low * h_stride;
+ const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+ const int w_low_ptr_offset = w_low * w_stride;
+ const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+ const int base_ptr = m * channels + c;
+
+ const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+ const scalar_t top_grad_value = top_grad * attn_weight;
+ scalar_t grad_h_weight = 0, grad_w_weight = 0;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ {
+ const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+ v1 = bottom_data[ptr1];
+ grad_h_weight -= hw * v1;
+ grad_w_weight -= hh * v1;
+ atomicAdd(grad_value+ptr1, w1*top_grad_value);
+ }
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ {
+ const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+ v2 = bottom_data[ptr2];
+ grad_h_weight -= lw * v2;
+ grad_w_weight += hh * v2;
+ atomicAdd(grad_value+ptr2, w2*top_grad_value);
+ }
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ {
+ const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+ v3 = bottom_data[ptr3];
+ grad_h_weight += hw * v3;
+ grad_w_weight -= lh * v3;
+ atomicAdd(grad_value+ptr3, w3*top_grad_value);
+ }
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ {
+ const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+ v4 = bottom_data[ptr4];
+ grad_h_weight += lw * v4;
+ grad_w_weight += lh * v4;
+ atomicAdd(grad_value+ptr4, w4*top_grad_value);
+ }
+
+ const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ atomicAdd(grad_attn_weight, top_grad * val);
+ atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+ atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+
+template
+__global__ void ms_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ scalar_t *data_col_ptr = data_col + index;
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+ scalar_t col = 0;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride);
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight;
+ }
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ }
+ }
+ *data_col_ptr = col;
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockSize; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2];
+ __shared__ scalar_t cache_grad_attn_weight[blockSize];
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockSize/2; s>0; s>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+ if (tid == 0)
+ {
+ scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0];
+ int sid=2;
+ for (unsigned int tid = 1; tid < blockDim.x; ++tid)
+ {
+ _grad_w += cache_grad_sampling_loc[sid];
+ _grad_h += cache_grad_sampling_loc[sid + 1];
+ _grad_a += cache_grad_attn_weight[tid];
+ sid += 2;
+ }
+
+
+ *grad_sampling_loc = _grad_w;
+ *(grad_sampling_loc + 1) = _grad_h;
+ *grad_attn_weight = _grad_a;
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ *grad_sampling_loc = cache_grad_sampling_loc[0];
+ *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+ *grad_attn_weight = cache_grad_attn_weight[0];
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ extern __shared__ int _s[];
+ scalar_t* cache_grad_sampling_loc = (scalar_t*)_s;
+ scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+ unsigned int tid = threadIdx.x;
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0;
+ *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0;
+ *(cache_grad_attn_weight+threadIdx.x)=0;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x);
+ }
+
+ __syncthreads();
+
+ for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1)
+ {
+ if (tid < s) {
+ const unsigned int xid1 = tid << 1;
+ const unsigned int xid2 = (tid + s) << 1;
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1];
+ if (tid + (s << 1) < spre)
+ {
+ cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)];
+ cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)];
+ cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+ }
+ }
+ __syncthreads();
+ }
+
+ if (tid == 0)
+ {
+ atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+ atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+ atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+ }
+ __syncthreads();
+
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+__global__ void ms_deformable_col2im_gpu_kernel_gm(const int n,
+ const scalar_t *grad_col,
+ const scalar_t *data_value,
+ const int64_t *data_spatial_shapes,
+ const int64_t *data_level_start_index,
+ const scalar_t *data_sampling_loc,
+ const scalar_t *data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t *grad_value,
+ scalar_t *grad_sampling_loc,
+ scalar_t *grad_attn_weight)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ int _temp = index;
+ const int c_col = _temp % channels;
+ _temp /= channels;
+ const int sampling_index = _temp;
+ const int m_col = _temp % num_heads;
+ _temp /= num_heads;
+ const int q_col = _temp % num_query;
+ _temp /= num_query;
+ const int b_col = _temp;
+
+ const scalar_t top_grad = grad_col[index];
+
+ int data_weight_ptr = sampling_index * num_levels * num_point;
+ int data_loc_w_ptr = data_weight_ptr << 1;
+ const int grad_sampling_ptr = data_weight_ptr;
+ grad_sampling_loc += grad_sampling_ptr << 1;
+ grad_attn_weight += grad_sampling_ptr;
+ const int grad_weight_stride = 1;
+ const int grad_loc_stride = 2;
+ const int qid_stride = num_heads * channels;
+ const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
+
+ for (int l_col=0; l_col < num_levels; ++l_col)
+ {
+ const int level_start_id = data_level_start_index[l_col];
+ const int spatial_h_ptr = l_col << 1;
+ const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+ const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+ const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
+ const scalar_t *data_value_ptr = data_value + value_ptr_offset;
+ scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+ for (int p_col=0; p_col < num_point; ++p_col)
+ {
+ const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
+ const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+ const scalar_t weight = data_attn_weight[data_weight_ptr];
+
+ const scalar_t h_im = loc_h * spatial_h - 0.5;
+ const scalar_t w_im = loc_w * spatial_w - 0.5;
+ if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
+ {
+ ms_deform_attn_col2im_bilinear_gm(
+ data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
+ top_grad, weight, grad_value_ptr,
+ grad_sampling_loc, grad_attn_weight);
+ }
+ data_weight_ptr += 1;
+ data_loc_w_ptr += 2;
+ grad_attn_weight += grad_weight_stride;
+ grad_sampling_loc += grad_loc_stride;
+ }
+ }
+ }
+}
+
+
+template
+void ms_deformable_im2col_cuda(cudaStream_t stream,
+ const scalar_t* data_value,
+ const int64_t* data_spatial_shapes,
+ const int64_t* data_level_start_index,
+ const scalar_t* data_sampling_loc,
+ const scalar_t* data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* data_col)
+{
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ const int num_threads = CUDA_NUM_THREADS;
+ ms_deformable_im2col_gpu_kernel
+ <<>>(
+ num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
+ batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
+
+template
+void ms_deformable_col2im_cuda(cudaStream_t stream,
+ const scalar_t* grad_col,
+ const scalar_t* data_value,
+ const int64_t * data_spatial_shapes,
+ const int64_t * data_level_start_index,
+ const scalar_t * data_sampling_loc,
+ const scalar_t * data_attn_weight,
+ const int batch_size,
+ const int spatial_size,
+ const int num_heads,
+ const int channels,
+ const int num_levels,
+ const int num_query,
+ const int num_point,
+ scalar_t* grad_value,
+ scalar_t* grad_sampling_loc,
+ scalar_t* grad_attn_weight)
+{
+ const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
+ const int num_kernels = batch_size * num_query * num_heads * channels;
+ const int num_actual_kernels = batch_size * num_query * num_heads * channels;
+ if (channels > 1024)
+ {
+ if ((channels & 1023) == 0)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_gm
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ else{
+ switch(channels)
+ {
+ case 1:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 2:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 4:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 8:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 16:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 32:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 64:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 128:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 256:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 512:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ case 1024:
+ ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ break;
+ default:
+ if (channels < 64)
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v1
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ else
+ {
+ ms_deformable_col2im_gpu_kernel_shm_reduce_v2
+ <<>>(
+ num_kernels,
+ grad_col,
+ data_value,
+ data_spatial_shapes,
+ data_level_start_index,
+ data_sampling_loc,
+ data_attn_weight,
+ batch_size,
+ spatial_size,
+ num_heads,
+ channels,
+ num_levels,
+ num_query,
+ num_point,
+ grad_value,
+ grad_sampling_loc,
+ grad_attn_weight);
+ }
+ }
+ }
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+
+}
\ No newline at end of file
diff --git a/det3d/models/ops/src/ms_deform_attn.h b/det3d/models/ops/src/ms_deform_attn.h
new file mode 100644
index 0000000..ac0ef2e
--- /dev/null
+++ b/det3d/models/ops/src/ms_deform_attn.h
@@ -0,0 +1,62 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#pragma once
+
+#include "cpu/ms_deform_attn_cpu.h"
+
+#ifdef WITH_CUDA
+#include "cuda/ms_deform_attn_cuda.h"
+#endif
+
+
+at::Tensor
+ms_deform_attn_forward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_forward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector
+ms_deform_attn_backward(
+ const at::Tensor &value,
+ const at::Tensor &spatial_shapes,
+ const at::Tensor &level_start_index,
+ const at::Tensor &sampling_loc,
+ const at::Tensor &attn_weight,
+ const at::Tensor &grad_output,
+ const int im2col_step)
+{
+ if (value.type().is_cuda())
+ {
+#ifdef WITH_CUDA
+ return ms_deform_attn_cuda_backward(
+ value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
+#else
+ AT_ERROR("Not compiled with GPU support");
+#endif
+ }
+ AT_ERROR("Not implemented on the CPU");
+}
+
diff --git a/det3d/models/ops/src/vision.cpp b/det3d/models/ops/src/vision.cpp
new file mode 100644
index 0000000..2201f63
--- /dev/null
+++ b/det3d/models/ops/src/vision.cpp
@@ -0,0 +1,16 @@
+/*!
+**************************************************************************************************
+* Deformable DETR
+* Copyright (c) 2020 SenseTime. All Rights Reserved.
+* Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+**************************************************************************************************
+* Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+**************************************************************************************************
+*/
+
+#include "ms_deform_attn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
+ m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
+}
diff --git a/det3d/models/ops/test.py b/det3d/models/ops/test.py
new file mode 100644
index 0000000..363a0c1
--- /dev/null
+++ b/det3d/models/ops/test.py
@@ -0,0 +1,89 @@
+# ------------------------------------------------------------------------------------------------
+# Deformable DETR
+# Copyright (c) 2020 SenseTime. All Rights Reserved.
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
+# ------------------------------------------------------------------------------------------------
+# Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
+# ------------------------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
+
+
+N, M, D = 1, 2, 2
+Lq, L, P = 2, 2, 2
+shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
+level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
+S = sum([(H*W).item() for H, W in shapes])
+
+
+torch.manual_seed(3)
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_double():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+@torch.no_grad()
+def check_forward_equal_with_pytorch_float():
+ value = torch.rand(N, S, M, D).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
+ output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
+ fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
+ max_abs_err = (output_cuda - output_pytorch).abs().max()
+ max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
+
+ print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
+
+
+def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
+
+ value = torch.rand(N, S, M, channels).cuda() * 0.01
+ sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
+ attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
+ attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
+ im2col_step = 2
+ func = MSDeformAttnFunction.apply
+
+ value.requires_grad = grad_value
+ sampling_locations.requires_grad = grad_sampling_loc
+ attention_weights.requires_grad = grad_attn_weight
+
+ gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
+
+ print(f'* {gradok} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+ check_forward_equal_with_pytorch_double()
+ check_forward_equal_with_pytorch_float()
+
+ for channels in [30, 32, 64, 71]:
+ check_gradient_numerical(channels, True, True, True)
+
+
+
diff --git a/det3d/models/readers/__init__.py b/det3d/models/readers/__init__.py
new file mode 100644
index 0000000..225998e
--- /dev/null
+++ b/det3d/models/readers/__init__.py
@@ -0,0 +1,10 @@
+from .pillar_encoder import PillarFeatureNet, PointPillarsScatter
+from .voxel_encoder import VoxelFeatureExtractorV3
+from .dynamic_voxel_encoder import DynamicVoxelEncoder
+
+__all__ = [
+ "VoxelFeatureExtractorV3",
+ "PillarFeatureNet",
+ "PointPillarsScatter",
+ 'DynamicVoxelEncoder',
+]
diff --git a/det3d/models/readers/dynamic_voxel_encoder.py b/det3d/models/readers/dynamic_voxel_encoder.py
new file mode 100644
index 0000000..ecd910b
--- /dev/null
+++ b/det3d/models/readers/dynamic_voxel_encoder.py
@@ -0,0 +1,47 @@
+from det3d.core.utils.scatter import scatter_mean
+from torch.nn import functional as F
+from ..registry import READERS
+from torch import nn
+import numpy as np
+import torch
+
+def voxelization(points, pc_range, voxel_size):
+ keep = (points[:, 0] >= pc_range[0]) & (points[:, 0] <= pc_range[3]) & \
+ (points[:, 1] >= pc_range[1]) & (points[:, 1] <= pc_range[4]) & \
+ (points[:, 2] >= pc_range[2]) & (points[:, 2] <= pc_range[5])
+ points = points[keep, :]
+ coords = ((points[:, [2, 1, 0]] - pc_range[[2, 1, 0]]) / voxel_size[[2, 1, 0]]).to(torch.int64)
+ unique_coords, inverse_indices = coords.unique(return_inverse=True, dim=0)
+
+ voxels = scatter_mean(points, inverse_indices, dim=0)
+ return voxels, unique_coords
+
+@READERS.register_module
+class DynamicVoxelEncoder(nn.Module):
+ def __init__(
+ self, pc_range, voxel_size
+ ):
+ super(DynamicVoxelEncoder, self).__init__()
+ self.pc_range = torch.tensor(pc_range)
+ self.voxel_size = torch.tensor(voxel_size)
+ self.shape = torch.round((self.pc_range[3:] - self.pc_range[:3]) / self.voxel_size)
+ self.shape_np = self.shape.numpy().astype(np.int32)
+
+ @torch.no_grad()
+ def forward(self, points):
+ # points list[torch.Tensor]
+ coors = []
+ voxels = []
+ for res in points:
+ voxel, coor = voxelization(res, self.pc_range.to(res.device), self.voxel_size.to(res.device))
+ voxels.append(voxel)
+ coors.append(coor)
+
+ coors_batch = []
+ for i in range(len(voxels)):
+ coor_pad = F.pad(coors[i], (1, 0), mode='constant', value=i)
+ coors_batch.append(coor_pad)
+
+ coors_batch = torch.cat(coors_batch, dim=0)
+ voxels_batch = torch.cat(voxels, dim=0)
+ return voxels_batch, coors_batch, self.shape_np
diff --git a/det3d/models/readers/pillar_encoder.py b/det3d/models/readers/pillar_encoder.py
new file mode 100644
index 0000000..d155035
--- /dev/null
+++ b/det3d/models/readers/pillar_encoder.py
@@ -0,0 +1,209 @@
+"""
+PointPillars fork from SECOND.
+Code written by Alex Lang and Oscar Beijbom, 2018.
+Licensed under MIT License [see LICENSE].
+"""
+
+import torch
+from det3d.models.utils import get_paddings_indicator
+from torch import nn
+from torch.nn import functional as F
+from ..registry import BACKBONES, READERS
+from ..utils import build_norm_layer
+
+
+class PFNLayer(nn.Module):
+ def __init__(self, in_channels, out_channels, norm_cfg=None, last_layer=False):
+ """
+ Pillar Feature Net Layer.
+ The Pillar Feature Net could be composed of a series of these layers, but the PointPillars paper results only
+ used a single PFNLayer. This layer performs a similar role as second.pytorch.voxelnet.VFELayer.
+ :param in_channels: . Number of input channels.
+ :param out_channels: . Number of output channels.
+ :param last_layer: . If last_layer, there is no concatenation of features.
+ """
+
+ super().__init__()
+ self.name = "PFNLayer"
+ self.last_vfe = last_layer
+ if not self.last_vfe:
+ out_channels = out_channels // 2
+ self.units = out_channels
+
+ if norm_cfg is None:
+ norm_cfg = dict(type="BN1d", eps=1e-3, momentum=0.01)
+ self.norm_cfg = norm_cfg
+
+ self.linear = nn.Linear(in_channels, self.units, bias=False)
+ self.norm = build_norm_layer(self.norm_cfg, self.units)[1]
+
+ def forward(self, inputs):
+
+ x = self.linear(inputs)
+ torch.backends.cudnn.enabled = False
+ x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2, 1).contiguous()
+ torch.backends.cudnn.enabled = True
+ x = F.relu(x)
+
+ x_max = torch.max(x, dim=1, keepdim=True)[0]
+
+ if self.last_vfe:
+ return x_max
+ else:
+ x_repeat = x_max.repeat(1, inputs.shape[1], 1)
+ x_concatenated = torch.cat([x, x_repeat], dim=2)
+ return x_concatenated
+
+
+@READERS.register_module
+class PillarFeatureNet(nn.Module):
+ def __init__(
+ self,
+ num_input_features=4,
+ num_filters=(64,),
+ with_distance=False,
+ voxel_size=(0.2, 0.2, 4),
+ pc_range=(0, -40, -3, 70.4, 40, 1),
+ norm_cfg=None,
+ ):
+ """
+ Pillar Feature Net.
+ The network prepares the pillar features and performs forward pass through PFNLayers. This net performs a
+ similar role to SECOND's second.pytorch.voxelnet.VoxelFeatureExtractor.
+ :param num_input_features: . Number of input features, either x, y, z or x, y, z, r.
+ :param num_filters: (: N). Number of features in each of the N PFNLayers.
+ :param with_distance: . Whether to include Euclidean distance to points.
+ :param voxel_size: (: 3). Size of voxels, only utilize x and y size.
+ :param pc_range: (: 6). Point cloud range, only utilize x and y min.
+ """
+
+ super().__init__()
+ self.name = "PillarFeatureNet"
+ assert len(num_filters) > 0
+
+ self.num_input = num_input_features
+ num_input_features += 5
+ if with_distance:
+ num_input_features += 1
+ self._with_distance = with_distance
+
+ # Create PillarFeatureNet layers
+ num_filters = [num_input_features] + list(num_filters)
+ pfn_layers = []
+ for i in range(len(num_filters) - 1):
+ in_filters = num_filters[i]
+ out_filters = num_filters[i + 1]
+ if i < len(num_filters) - 2:
+ last_layer = False
+ else:
+ last_layer = True
+ pfn_layers.append(
+ PFNLayer(
+ in_filters, out_filters, norm_cfg=norm_cfg, last_layer=last_layer
+ )
+ )
+ self.pfn_layers = nn.ModuleList(pfn_layers)
+
+ # Need pillar (voxel) size and x/y offset in order to calculate pillar offset
+ self.vx = voxel_size[0]
+ self.vy = voxel_size[1]
+ self.x_offset = self.vx / 2 + pc_range[0]
+ self.y_offset = self.vy / 2 + pc_range[1]
+
+ def forward(self, features, num_voxels, coors):
+ device = features.device
+
+ dtype = features.dtype
+
+ # Find distance of x, y, and z from cluster center
+ # features = features[:, :, :self.num_input]
+ points_mean = features[:, :, :3].sum(dim=1, keepdim=True) / num_voxels.type_as(
+ features
+ ).view(-1, 1, 1)
+ f_cluster = features[:, :, :3] - points_mean
+
+ # Find distance of x, y, and z from pillar center
+ # f_center = features[:, :, :2]
+ f_center = torch.zeros_like(features[:, :, :2])
+ f_center[:, :, 0] = features[:, :, 0] - (
+ coors[:, 3].to(dtype).unsqueeze(1) * self.vx + self.x_offset
+ )
+ f_center[:, :, 1] = features[:, :, 1] - (
+ coors[:, 2].to(dtype).unsqueeze(1) * self.vy + self.y_offset
+ )
+
+ # Combine together feature decorations
+ features_ls = [features, f_cluster, f_center]
+ if self._with_distance:
+ points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
+ features_ls.append(points_dist)
+ features = torch.cat(features_ls, dim=-1)
+
+ # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
+ # empty pillars remain set to zeros.
+ voxel_count = features.shape[1]
+ mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
+ mask = torch.unsqueeze(mask, -1).type_as(features)
+ features *= mask
+
+ # Forward pass through PFNLayers
+ for pfn in self.pfn_layers:
+ features = pfn(features)
+
+ return features.squeeze()
+
+
+@BACKBONES.register_module
+class PointPillarsScatter(nn.Module):
+ def __init__(
+ self, num_input_features=64, norm_cfg=None, name="PointPillarsScatter", **kwargs
+ ):
+ """
+ Point Pillar's Scatter.
+ Converts learned features from dense tensor to sparse pseudo image. This replaces SECOND's
+ second.pytorch.voxelnet.SparseMiddleExtractor.
+ :param output_shape: ([int]: 4). Required output shape of features.
+ :param num_input_features: . Number of input features.
+ """
+
+ super().__init__()
+ self.name = "PointPillarsScatter"
+ self.nchannels = num_input_features
+
+ def forward(self, voxel_features, coords, batch_size, input_shape):
+
+ self.nx = input_shape[0]
+ self.ny = input_shape[1]
+
+ # batch_canvas will be the final output.
+ batch_canvas = []
+ for batch_itt in range(batch_size):
+ # Create the canvas for this sample
+ canvas = torch.zeros(
+ self.nchannels,
+ self.nx * self.ny,
+ dtype=voxel_features.dtype,
+ device=voxel_features.device,
+ )
+
+ # Only include non-empty pillars
+ batch_mask = coords[:, 0] == batch_itt
+
+ this_coords = coords[batch_mask, :]
+ indices = this_coords[:, 2] * self.nx + this_coords[:, 3]
+ indices = indices.type(torch.long)
+ voxels = voxel_features[batch_mask, :]
+ voxels = voxels.t()
+
+ # Now scatter the blob back to the canvas.
+ canvas[:, indices] = voxels
+
+ # Append to a list for later stacking.
+ batch_canvas.append(canvas)
+
+ # Stack to 3-dim tensor (batch-size, nchannels, nrows*ncols)
+ batch_canvas = torch.stack(batch_canvas, 0)
+
+ # Undo the column stacking to final 4-dim tensor
+ batch_canvas = batch_canvas.view(batch_size, self.nchannels, self.ny, self.nx)
+ return batch_canvas
diff --git a/det3d/models/readers/voxel_encoder.py b/det3d/models/readers/voxel_encoder.py
new file mode 100644
index 0000000..b889314
--- /dev/null
+++ b/det3d/models/readers/voxel_encoder.py
@@ -0,0 +1,24 @@
+from torch import nn
+from torch.nn import functional as F
+
+from ..registry import READERS
+
+
+
+@READERS.register_module
+class VoxelFeatureExtractorV3(nn.Module):
+ def __init__(
+ self, num_input_features=4, norm_cfg=None, name="VoxelFeatureExtractorV3"
+ ):
+ super(VoxelFeatureExtractorV3, self).__init__()
+ self.name = name
+ self.num_input_features = num_input_features
+
+ def forward(self, features, num_voxels, coors=None):
+ assert self.num_input_features == features.shape[-1]
+
+ points_mean = features[:, :, : self.num_input_features].sum(
+ dim=1, keepdim=False
+ ) / num_voxels.type_as(features).view(-1, 1)
+
+ return points_mean.contiguous()
diff --git a/det3d/models/registry.py b/det3d/models/registry.py
new file mode 100644
index 0000000..de7c71e
--- /dev/null
+++ b/det3d/models/registry.py
@@ -0,0 +1,10 @@
+from det3d.utils import Registry
+
+READERS = Registry("reader")
+BACKBONES = Registry("backbone")
+NECKS = Registry("neck")
+HEADS = Registry("head")
+LOSSES = Registry("loss")
+DETECTORS = Registry("detector")
+SECOND_STAGE = Registry("second_stage")
+ROI_HEAD = Registry("roi_head")
\ No newline at end of file
diff --git a/det3d/models/roi_heads/__init__.py b/det3d/models/roi_heads/__init__.py
new file mode 100644
index 0000000..6389403
--- /dev/null
+++ b/det3d/models/roi_heads/__init__.py
@@ -0,0 +1,7 @@
+from .roi_head_template import RoIHeadTemplate
+from .roi_head import RoIHead
+
+__all__ = [
+ 'RoIHeadTemplate',
+ 'RoIHead',
+]
diff --git a/det3d/models/roi_heads/roi_head.py b/det3d/models/roi_heads/roi_head.py
new file mode 100644
index 0000000..9ed6bc7
--- /dev/null
+++ b/det3d/models/roi_heads/roi_head.py
@@ -0,0 +1,106 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# OpenPCDet (https://github.com/open-mmlab/OpenPCDet)
+# Licensed under the Apache License.
+# ------------------------------------------------------------------------------
+
+from torch import batch_norm
+import torch.nn as nn
+
+from .roi_head_template import RoIHeadTemplate
+
+from det3d.core import box_torch_ops
+
+from ..registry import ROI_HEAD
+
+@ROI_HEAD.register_module
+class RoIHead(RoIHeadTemplate):
+ def __init__(self, input_channels, model_cfg, num_class=1, code_size=7, test_cfg=None):
+ super().__init__(num_class=num_class, model_cfg=model_cfg)
+ self.model_cfg = model_cfg
+ self.test_cfg = test_cfg
+ self.code_size = code_size
+
+ pre_channel = input_channels
+
+ shared_fc_list = []
+ for k in range(0, self.model_cfg.SHARED_FC.__len__()):
+ shared_fc_list.extend([
+ nn.Conv1d(pre_channel, self.model_cfg.SHARED_FC[k], kernel_size=1, bias=False),
+ nn.BatchNorm1d(self.model_cfg.SHARED_FC[k]),
+ nn.ReLU()
+ ])
+ pre_channel = self.model_cfg.SHARED_FC[k]
+
+ if k != self.model_cfg.SHARED_FC.__len__() - 1 and self.model_cfg.DP_RATIO > 0:
+ shared_fc_list.append(nn.Dropout(self.model_cfg.DP_RATIO))
+
+ self.shared_fc_layer = nn.Sequential(*shared_fc_list)
+
+ self.cls_layers = self.make_fc_layers(
+ input_channels=pre_channel, output_channels=self.num_class, fc_list=self.model_cfg.CLS_FC
+ )
+ self.reg_layers = self.make_fc_layers(
+ input_channels=pre_channel,
+ output_channels=code_size,
+ fc_list=self.model_cfg.REG_FC
+ )
+ self.init_weights(weight_init='xavier')
+
+ def init_weights(self, weight_init='xavier'):
+ if weight_init == 'kaiming':
+ init_func = nn.init.kaiming_normal_
+ elif weight_init == 'xavier':
+ init_func = nn.init.xavier_normal_
+ elif weight_init == 'normal':
+ init_func = nn.init.normal_
+ else:
+ raise NotImplementedError
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
+ if weight_init == 'normal':
+ init_func(m.weight, mean=0, std=0.001)
+ else:
+ init_func(m.weight)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ nn.init.normal_(self.reg_layers[-1].weight, mean=0, std=0.001)
+
+ def forward(self, batch_dict, training=True):
+ """
+ :param input_data: input dict
+ :return:
+ """
+ batch_dict['batch_size'] = len(batch_dict['rois'])
+ if training:
+ targets_dict = self.assign_targets(batch_dict)
+ batch_dict['rois'] = targets_dict['rois']
+ batch_dict['roi_labels'] = targets_dict['roi_labels']
+ batch_dict['roi_features'] = targets_dict['roi_features']
+
+ # RoI aware pooling
+ pooled_features = batch_dict['roi_features'].reshape(-1, 1,
+ batch_dict['roi_features'].shape[-1]).contiguous() # (BxN, 1, C)
+
+ batch_size_rcnn = pooled_features.shape[0]
+ pooled_features = pooled_features.permute(0, 2, 1).contiguous() # (BxN, C, 1)
+
+ shared_features = self.shared_fc_layer(pooled_features.view(batch_size_rcnn, -1, 1))
+ rcnn_cls = self.cls_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, 1 or 2)
+ rcnn_reg = self.reg_layers(shared_features).transpose(1, 2).contiguous().squeeze(dim=1) # (B, C)
+
+ if not training:
+ batch_cls_preds, batch_box_preds = self.generate_predicted_boxes(
+ batch_size=batch_dict['batch_size'], rois=batch_dict['rois'], cls_preds=rcnn_cls, box_preds=rcnn_reg
+ )
+ batch_dict['batch_cls_preds'] = batch_cls_preds
+ batch_dict['batch_box_preds'] = batch_box_preds
+ batch_dict['cls_preds_normalized'] = False
+ else:
+ targets_dict['rcnn_cls'] = rcnn_cls
+ targets_dict['rcnn_reg'] = rcnn_reg
+
+ self.forward_ret_dict = targets_dict
+
+ return batch_dict
\ No newline at end of file
diff --git a/det3d/models/roi_heads/roi_head_template.py b/det3d/models/roi_heads/roi_head_template.py
new file mode 100644
index 0000000..9d7011f
--- /dev/null
+++ b/det3d/models/roi_heads/roi_head_template.py
@@ -0,0 +1,215 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# OpenPCDet (https://github.com/open-mmlab/OpenPCDet)
+# Licensed under the Apache License.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from det3d.core.bbox import box_torch_ops
+from .target_assigner.proposal_target_layer import ProposalTargetLayer
+
+def limit_period(val, offset=0.5, period=np.pi):
+ return val - torch.floor(val / period + offset) * period
+
+
+class RoIHeadTemplate(nn.Module):
+ def __init__(self, num_class, model_cfg):
+ super().__init__()
+ self.model_cfg = model_cfg
+ self.num_class = num_class
+ self.proposal_target_layer = ProposalTargetLayer(roi_sampler_cfg=self.model_cfg.TARGET_CONFIG)
+
+ self.forward_ret_dict = None
+
+ def make_fc_layers(self, input_channels, output_channels, fc_list, separate_head = 1):
+ fc_layers = []
+ pre_channel = input_channels
+ for k in range(0, fc_list.__len__()):
+ fc_layers.extend([
+ nn.Conv1d(pre_channel, fc_list[k]*separate_head, groups=separate_head, kernel_size=1, bias=False),
+ nn.BatchNorm1d(fc_list[k]*separate_head),
+ nn.ReLU()
+ ])
+ pre_channel = fc_list[k]*separate_head
+ if self.model_cfg.DP_RATIO >= 0 and k == 0:
+ fc_layers.append(nn.Dropout(self.model_cfg.DP_RATIO))
+ fc_layers.append(nn.Conv1d(pre_channel, output_channels, kernel_size=1, bias=True))
+ fc_layers = nn.Sequential(*fc_layers)
+ return fc_layers
+
+ def assign_targets(self, batch_dict):
+ batch_size = batch_dict['batch_size']
+ with torch.no_grad():
+ targets_dict = self.proposal_target_layer.forward(batch_dict)
+
+ rois = targets_dict['rois'] # (B, N, 7 + C)
+ gt_of_rois = targets_dict['gt_of_rois'] # (B, N, 7 + C + 1)
+ targets_dict['gt_of_rois_src'] = gt_of_rois.clone().detach()
+
+ roi_ry = limit_period(rois[:, :, 6], offset=0.5, period=np.pi*2)
+
+ gt_of_rois[:, :, :6] = gt_of_rois[:, :, :6] - rois[:, :, :6]
+ gt_of_rois[:, :, 6] = gt_of_rois[:, :, 6] - roi_ry
+
+ gt_of_rois = box_torch_ops.rotate_points_along_z(
+ points=gt_of_rois.view(-1, 1, gt_of_rois.shape[-1]), angle=-roi_ry.view(-1)
+ ).view(batch_size, -1, gt_of_rois.shape[-1])
+
+ if rois.shape[-1] == 9:
+ # rotate velocity
+ gt_of_rois[:, :, 7:-1] = gt_of_rois[:, :, 7:-1] - rois[:, :, 7:]
+
+ """
+ roi_vel = gt_of_rois[:, :, 7:-1]
+ roi_vel = torch.cat([roi_vel, torch.zeros([roi_vel.shape[0], roi_vel.shape[1], 1]).to(roi_vel)], dim=-1)
+
+ gt_of_rois[:, :, 7:-1] = box_torch_ops.rotate_points_along_z(
+ points=roi_vel.view(-1, 1, 3), angle=-roi_ry.view(-1)
+ ).view(batch_size, -1, 3)[..., :2]
+ """
+
+ # flip orientation if rois have opposite orientation
+ heading_label = gt_of_rois[:, :, 6] % (2 * np.pi) # 0 ~ 2pi
+ opposite_flag = (heading_label > np.pi * 0.5) & (heading_label < np.pi * 1.5)
+ heading_label[opposite_flag] = (heading_label[opposite_flag] + np.pi) % (2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi)
+ flag = heading_label > np.pi
+ heading_label[flag] = heading_label[flag] - np.pi * 2 # (-pi/2, pi/2)
+ heading_label = torch.clamp(heading_label, min=-np.pi / 2, max=np.pi / 2)
+
+ gt_of_rois[:, :, 6] = heading_label
+
+
+ targets_dict['gt_of_rois'] = gt_of_rois
+ return targets_dict
+
+ def get_box_reg_layer_loss(self, forward_ret_dict):
+ loss_cfgs = self.model_cfg.LOSS_CONFIG
+ code_size = forward_ret_dict['rcnn_reg'].shape[-1]
+ reg_valid_mask = forward_ret_dict['reg_valid_mask'].view(-1)
+ gt_boxes3d_ct = forward_ret_dict['gt_of_rois'][..., 0:code_size]
+ rcnn_reg = forward_ret_dict['rcnn_reg'] # (rcnn_batch_size, C)
+ rcnn_batch_size = gt_boxes3d_ct.view(-1, code_size).shape[0]
+
+ fg_mask = (reg_valid_mask > 0)
+ fg_sum = fg_mask.long().sum().item()
+
+ tb_dict = {}
+
+ if loss_cfgs.REG_LOSS == 'L1':
+ reg_targets = gt_boxes3d_ct.view(rcnn_batch_size, -1)
+ rcnn_loss_reg = F.l1_loss(
+ rcnn_reg.view(rcnn_batch_size, -1),
+ reg_targets,
+ reduction='none'
+ ) # [B, M, 7]
+
+ rcnn_loss_reg = rcnn_loss_reg * rcnn_loss_reg.new_tensor(\
+ loss_cfgs.LOSS_WEIGHTS['code_weights'])
+
+ rcnn_loss_reg = (rcnn_loss_reg.view(rcnn_batch_size, -1) * fg_mask.unsqueeze(dim=-1).float()).sum() / max(fg_sum, 1)
+ rcnn_loss_reg = rcnn_loss_reg * loss_cfgs.LOSS_WEIGHTS['rcnn_reg_weight']
+ tb_dict['rcnn_loss_reg'] = rcnn_loss_reg.detach()
+ else:
+ raise NotImplementedError
+
+ return rcnn_loss_reg, tb_dict
+
+ def get_box_cls_layer_loss(self, forward_ret_dict):
+ loss_cfgs = self.model_cfg.LOSS_CONFIG
+ rcnn_cls = forward_ret_dict['rcnn_cls']
+ rcnn_cls_labels = forward_ret_dict['rcnn_cls_labels'].view(-1)
+ if loss_cfgs.CLS_LOSS == 'BinaryCrossEntropy':
+ rcnn_cls_flat = rcnn_cls.view(-1)
+ batch_loss_cls = F.binary_cross_entropy(torch.sigmoid(rcnn_cls_flat), rcnn_cls_labels.float(), reduction='none')
+ cls_valid_mask = (rcnn_cls_labels >= 0).float()
+ rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
+ elif loss_cfgs.CLS_LOSS == 'CrossEntropy':
+ batch_loss_cls = F.cross_entropy(rcnn_cls, rcnn_cls_labels, reduction='none', ignore_index=-1)
+ cls_valid_mask = (rcnn_cls_labels >= 0).float()
+ rcnn_loss_cls = (batch_loss_cls * cls_valid_mask).sum() / torch.clamp(cls_valid_mask.sum(), min=1.0)
+ else:
+ raise NotImplementedError
+
+ rcnn_loss_cls = rcnn_loss_cls * loss_cfgs.LOSS_WEIGHTS['rcnn_cls_weight']
+ tb_dict = {'rcnn_loss_cls': rcnn_loss_cls.detach()}
+ return rcnn_loss_cls, tb_dict
+
+ def get_box_nms_layer_loss(self, forward_ret_dict):
+ loss_cfgs = self.model_cfg.LOSS_CONFIG
+ rcnn_nms = forward_ret_dict['rcnn_nms']
+ rcnn_nms_labels = forward_ret_dict['roi_box_mask'].view(-1)
+ rcnn_cls_labels = forward_ret_dict['rcnn_cls_labels'].view(-1)
+ rcnn_nms_flat = rcnn_nms.view(-1)
+ batch_loss_nms = F.binary_cross_entropy(torch.sigmoid(rcnn_nms_flat), rcnn_nms_labels.float(), reduction='none')
+ nms_valid_mask = (rcnn_cls_labels >= 0).float()
+ rcnn_loss_nms = (batch_loss_nms * nms_valid_mask).sum() / torch.clamp(nms_valid_mask.sum(), min=1.0)
+
+ rcnn_loss_nms = rcnn_loss_nms * loss_cfgs.LOSS_WEIGHTS['rcnn_nms_weight']
+ tb_dict = {'rcnn_loss_nms': rcnn_loss_nms.detach()}
+ return rcnn_loss_nms, tb_dict
+
+ def get_loss_nms(self, tb_dict=None):
+ tb_dict = {} if tb_dict is None else tb_dict
+ rcnn_loss = 0
+ rcnn_loss_cls, cls_tb_dict = self.get_box_cls_layer_loss(self.forward_ret_dict)
+ rcnn_loss += rcnn_loss_cls
+ tb_dict.update(cls_tb_dict)
+
+ rcnn_loss_reg, reg_tb_dict = self.get_box_reg_layer_loss(self.forward_ret_dict)
+ rcnn_loss += rcnn_loss_reg
+ tb_dict.update(reg_tb_dict)
+
+ rcnn_loss_nms, nms_tb_dict = self.get_box_nms_layer_loss(self.forward_ret_dict)
+ rcnn_loss += rcnn_loss_nms
+ tb_dict.update(nms_tb_dict)
+ tb_dict['rcnn_loss'] = rcnn_loss.item()
+ return rcnn_loss, tb_dict
+
+ def get_loss(self, tb_dict=None):
+ tb_dict = {} if tb_dict is None else tb_dict
+ rcnn_loss = 0
+ rcnn_loss_cls, cls_tb_dict = self.get_box_cls_layer_loss(self.forward_ret_dict)
+ rcnn_loss += rcnn_loss_cls
+ tb_dict.update(cls_tb_dict)
+
+ rcnn_loss_reg, reg_tb_dict = self.get_box_reg_layer_loss(self.forward_ret_dict)
+ rcnn_loss += rcnn_loss_reg
+ tb_dict.update(reg_tb_dict)
+ tb_dict['rcnn_loss'] = rcnn_loss.item()
+ return rcnn_loss, tb_dict
+
+ def generate_predicted_boxes(self, batch_size, rois, cls_preds, box_preds):
+ """
+ Args:
+ batch_size:
+ rois: (B, N, 7)
+ cls_preds: (BN, num_class)
+ box_preds: (BN, code_size)
+
+ Returns:
+
+ """
+ code_size = box_preds.shape[-1]
+ # batch_cls_preds: (B, N, num_class or 1)
+ batch_cls_preds = cls_preds.view(batch_size, -1, cls_preds.shape[-1])
+ batch_box_preds = box_preds.view(batch_size, -1, code_size)
+
+ roi_ry = rois[:, :, 6].view(-1)
+ roi_xyz = rois[:, :, 0:3].view(-1, 3)
+
+ local_rois = rois.clone().detach()
+ local_rois[:, :, 0:3] = 0
+
+ batch_box_preds = (batch_box_preds + local_rois).view(-1, code_size)
+ batch_box_preds = box_torch_ops.rotate_points_along_z(
+ batch_box_preds.unsqueeze(dim=1), roi_ry
+ ).squeeze(dim=1)
+
+ batch_box_preds[:, 0:3] += roi_xyz
+ # batch_box_preds[:, 3:] = rois[:, :, 3:].view(-1, code_size-3)
+ batch_box_preds = batch_box_preds.view(batch_size, -1, code_size)
+
+ return batch_cls_preds, batch_box_preds
diff --git a/det3d/models/roi_heads/target_assigner/proposal_target_layer.py b/det3d/models/roi_heads/target_assigner/proposal_target_layer.py
new file mode 100644
index 0000000..c04839a
--- /dev/null
+++ b/det3d/models/roi_heads/target_assigner/proposal_target_layer.py
@@ -0,0 +1,271 @@
+# ------------------------------------------------------------------------------
+# Portions of this code are from
+# OpenPCDet (https://github.com/open-mmlab/OpenPCDet)
+# Licensed under the Apache License.
+# ------------------------------------------------------------------------------
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ....ops.iou3d_nms.iou3d_nms_utils import boxes_iou3d_gpu
+
+
+class ProposalTargetLayer(nn.Module):
+ def __init__(self, roi_sampler_cfg):
+ super().__init__()
+ self.roi_sampler_cfg = roi_sampler_cfg
+
+ def forward(self, batch_dict):
+ """
+ Args:
+ batch_dict:
+ batch_size:
+ rois: (B, num_rois, 7 + C)
+ roi_scores: (B, num_rois)
+ gt_boxes: (B, N, 7 + C + 1)
+ roi_labels: (B, num_rois)
+ Returns:
+ batch_dict:
+ rois: (B, M, 7 + C)
+ gt_of_rois: (B, M, 7 + C)
+ gt_iou_of_rois: (B, M)
+ roi_scores: (B, M)
+ roi_labels: (B, M)
+ reg_valid_mask: (B, M)
+ rcnn_cls_labels: (B, M)
+ roi_points_loc: (B, M, K, 3)
+ """
+ if 'roi_box_mask' in batch_dict:
+ batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, \
+ batch_roi_features, batch_roi_points_loc, batch_roi_box_mask = self.sample_rois_for_rcnn(
+ batch_dict=batch_dict
+ )
+ else:
+ batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, \
+ batch_roi_features, batch_roi_points_loc = self.sample_rois_for_rcnn(
+ batch_dict=batch_dict
+ )
+ # regression valid mask
+ reg_valid_mask = (batch_roi_ious > self.roi_sampler_cfg.REG_FG_THRESH).long()
+
+ # classification label
+ if self.roi_sampler_cfg.CLS_SCORE_TYPE == 'cls':
+ batch_cls_labels = (batch_roi_ious > self.roi_sampler_cfg.CLS_FG_THRESH).long()
+ ignore_mask = (batch_roi_ious > self.roi_sampler_cfg.CLS_BG_THRESH) & \
+ (batch_roi_ious < self.roi_sampler_cfg.CLS_FG_THRESH)
+ batch_cls_labels[ignore_mask > 0] = -1
+ elif self.roi_sampler_cfg.CLS_SCORE_TYPE == 'roi_iou':
+ # padding_mask = (torch.isclose(batch_rois.sum(dim=-1), batch_rois.new_zeros(1)))
+
+ iou_bg_thresh = self.roi_sampler_cfg.CLS_BG_THRESH
+ iou_fg_thresh = self.roi_sampler_cfg.CLS_FG_THRESH
+ fg_mask = batch_roi_ious > iou_fg_thresh
+ bg_mask = batch_roi_ious < iou_bg_thresh
+ interval_mask = (fg_mask == 0) & (bg_mask == 0)
+
+ batch_cls_labels = (fg_mask > 0).float()
+ batch_cls_labels[interval_mask] = \
+ (batch_roi_ious[interval_mask] - iou_bg_thresh) / (iou_fg_thresh - iou_bg_thresh)
+ # batch_cls_labels[padding_mask > 0] = -1
+ else:
+ raise NotImplementedError
+
+ targets_dict = {'rois': batch_rois, 'gt_of_rois': batch_gt_of_rois, 'gt_iou_of_rois': batch_roi_ious,
+ 'roi_scores': batch_roi_scores, 'roi_labels': batch_roi_labels,
+ 'roi_features': batch_roi_features, 'reg_valid_mask': reg_valid_mask,
+ 'rcnn_cls_labels': batch_cls_labels, 'roi_points_loc': batch_roi_points_loc}
+
+ if 'roi_box_mask' in batch_dict:
+ targets_dict['roi_box_mask'] = batch_roi_box_mask
+
+ return targets_dict
+
+ def sample_rois_for_rcnn(self, batch_dict):
+ """
+ Args:
+ batch_dict:
+ batch_size:
+ rois: (B, num_rois, 7 + C)
+ roi_scores: (B, num_rois)
+ gt_boxes: (B, N, 7 + C + 1)
+ roi_labels: (B, num_rois)
+ Returns:
+
+ """
+ batch_size = batch_dict['batch_size']
+ rois = batch_dict['rois']
+ roi_scores = batch_dict['roi_scores']
+ roi_labels = batch_dict['roi_labels']
+ gt_boxes = batch_dict['gt_boxes_and_cls']
+ roi_features = batch_dict['roi_features']
+ roi_points_loc = batch_dict['roi_points_loc']
+
+
+ code_size = rois.shape[-1]
+ batch_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size)
+ batch_gt_of_rois = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, code_size + 1)
+ batch_roi_ious = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
+ batch_roi_scores = rois.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE)
+ batch_roi_labels = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long)
+ batch_roi_features = roi_features.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE,
+ roi_features.shape[-1])
+ batch_roi_points_loc = roi_points_loc.new_zeros(batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE, roi_points_loc.shape[-2],
+ roi_points_loc.shape[-1])
+ if 'roi_box_mask' in batch_dict:
+ roi_box_mask = batch_dict['roi_box_mask']
+ batch_roi_box_mask = rois.new_zeros((batch_size, self.roi_sampler_cfg.ROI_PER_IMAGE), dtype=torch.long)
+
+ for index in range(batch_size):
+ cur_roi, cur_gt, cur_roi_labels, cur_roi_scores, cur_roi_features, cur_roi_points_loc = \
+ rois[index], gt_boxes[index], roi_labels[index], roi_scores[index], \
+ roi_features[index], roi_points_loc[index]
+
+ if 'roi_box_mask' in batch_dict:
+ cur_roi_box_mask = roi_box_mask[index]
+
+ k = cur_gt.__len__() - 1
+ while k > 0 and cur_gt[k].sum() == 0:
+ k -= 1
+ cur_gt = cur_gt[:k + 1]
+ cur_gt = cur_gt.new_zeros((1, cur_gt.shape[1])) if len(cur_gt) == 0 else cur_gt
+
+ if self.roi_sampler_cfg.get('SAMPLE_ROI_BY_EACH_CLASS', False):
+ max_overlaps, gt_assignment = self.get_max_iou_with_same_class(
+ rois=cur_roi[:, :7], roi_labels=cur_roi_labels,
+ gt_boxes=cur_gt[:, 0:7], gt_labels=cur_gt[:, -1].long()
+ )
+ else:
+ iou3d = boxes_iou3d_gpu(cur_roi, cur_gt[:, 0:7]) # (M, N)
+ max_overlaps, gt_assignment = torch.max(iou3d, dim=1)
+
+ sampled_inds = self.subsample_rois(max_overlaps=max_overlaps)
+ # sampled_inds = np.arange(max_overlaps.shape[0]) # bypass subsample
+
+ batch_rois[index] = cur_roi[sampled_inds]
+ batch_roi_labels[index] = cur_roi_labels[sampled_inds]
+ batch_roi_ious[index] = max_overlaps[sampled_inds]
+ batch_roi_scores[index] = cur_roi_scores[sampled_inds]
+ batch_gt_of_rois[index] = cur_gt[gt_assignment[sampled_inds]]
+ batch_roi_features[index] = cur_roi_features[sampled_inds]
+ batch_roi_points_loc[index] = cur_roi_points_loc[sampled_inds]
+ if 'roi_box_mask' in batch_dict:
+ batch_roi_box_mask[index] = cur_roi_box_mask[sampled_inds]
+
+ if 'roi_box_mask' in batch_dict:
+ return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, batch_roi_features, batch_roi_points_loc, batch_roi_box_mask
+ else:
+ return batch_rois, batch_gt_of_rois, batch_roi_ious, batch_roi_scores, batch_roi_labels, batch_roi_features, batch_roi_points_loc
+
+ def subsample_rois(self, max_overlaps):
+ # sample fg, easy_bg, hard_bg
+ fg_rois_per_image = int(np.round(self.roi_sampler_cfg.FG_RATIO * self.roi_sampler_cfg.ROI_PER_IMAGE))
+ fg_thresh = min(self.roi_sampler_cfg.REG_FG_THRESH, self.roi_sampler_cfg.CLS_FG_THRESH)
+
+ fg_inds = ((max_overlaps >= fg_thresh)).nonzero().view(-1)
+ easy_bg_inds = ((max_overlaps < self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)
+ hard_bg_inds = ((max_overlaps < self.roi_sampler_cfg.REG_FG_THRESH) &
+ (max_overlaps >= self.roi_sampler_cfg.CLS_BG_THRESH_LO)).nonzero().view(-1)
+
+ fg_num_rois = fg_inds.numel()
+ bg_num_rois = hard_bg_inds.numel() + easy_bg_inds.numel()
+
+ if fg_num_rois > 0 and bg_num_rois > 0:
+ # sampling fg
+ fg_rois_per_this_image = min(fg_rois_per_image, fg_num_rois)
+
+ rand_num = torch.from_numpy(np.random.permutation(fg_num_rois)).type_as(max_overlaps).long()
+ fg_inds = fg_inds[rand_num[:fg_rois_per_this_image]]
+
+ # sampling bg
+ bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE - fg_rois_per_this_image
+ bg_inds = self.sample_bg_inds(
+ hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
+ )
+
+ elif fg_num_rois > 0 and bg_num_rois == 0:
+ # sampling fg
+ rand_num = np.floor(np.random.rand(self.roi_sampler_cfg.ROI_PER_IMAGE) * fg_num_rois)
+ rand_num = torch.from_numpy(rand_num).type_as(max_overlaps).long()
+ fg_inds = fg_inds[rand_num]
+ bg_inds = []
+
+ elif bg_num_rois > 0 and fg_num_rois == 0:
+ # sampling bg
+ bg_rois_per_this_image = self.roi_sampler_cfg.ROI_PER_IMAGE
+ bg_inds = self.sample_bg_inds(
+ hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, self.roi_sampler_cfg.HARD_BG_RATIO
+ )
+ else:
+ print('maxoverlaps:(min=%f, max=%f)' % (max_overlaps.min().item(), max_overlaps.max().item()))
+ print('ERROR: FG=%d, BG=%d' % (fg_num_rois, bg_num_rois))
+ raise NotImplementedError
+
+ sampled_inds = torch.cat((fg_inds, bg_inds), dim=0)
+ return sampled_inds
+
+ @staticmethod
+ def sample_bg_inds(hard_bg_inds, easy_bg_inds, bg_rois_per_this_image, hard_bg_ratio):
+ if hard_bg_inds.numel() > 0 and easy_bg_inds.numel() > 0:
+ hard_bg_rois_num = min(int(bg_rois_per_this_image * hard_bg_ratio), len(hard_bg_inds))
+ easy_bg_rois_num = bg_rois_per_this_image - hard_bg_rois_num
+
+ # sampling hard bg
+ rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
+ hard_bg_inds = hard_bg_inds[rand_idx]
+
+ # sampling easy bg
+ rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
+ easy_bg_inds = easy_bg_inds[rand_idx]
+
+ bg_inds = torch.cat([hard_bg_inds, easy_bg_inds], dim=0)
+ elif hard_bg_inds.numel() > 0 and easy_bg_inds.numel() == 0:
+ hard_bg_rois_num = bg_rois_per_this_image
+ # sampling hard bg
+ rand_idx = torch.randint(low=0, high=hard_bg_inds.numel(), size=(hard_bg_rois_num,)).long()
+ bg_inds = hard_bg_inds[rand_idx]
+ elif hard_bg_inds.numel() == 0 and easy_bg_inds.numel() > 0:
+ easy_bg_rois_num = bg_rois_per_this_image
+ # sampling easy bg
+ rand_idx = torch.randint(low=0, high=easy_bg_inds.numel(), size=(easy_bg_rois_num,)).long()
+ bg_inds = easy_bg_inds[rand_idx]
+ else:
+ raise NotImplementedError
+
+ return bg_inds
+
+ @staticmethod
+ def get_max_iou_with_same_class(rois, roi_labels, gt_boxes, gt_labels):
+ """
+ Args:
+ rois: (N, 7)
+ roi_labels: (N)
+ gt_boxes: (N, )
+ gt_labels:
+
+ Returns:
+
+ """
+ """
+ :param rois: (N, 7)
+ :param roi_labels: (N)
+ :param gt_boxes: (N, 8)
+ :return:
+ """
+ max_overlaps = rois.new_zeros(rois.shape[0])
+ gt_assignment = roi_labels.new_zeros(roi_labels.shape[0])
+
+ for k in range(gt_labels.min().item(), gt_labels.max().item() + 1):
+ roi_mask = (roi_labels == k)
+ gt_mask = (gt_labels == k)
+ if roi_mask.sum() > 0 and gt_mask.sum() > 0:
+ cur_roi = rois[roi_mask]
+ cur_gt = gt_boxes[gt_mask]
+ original_gt_assignment = gt_mask.nonzero().view(-1)
+
+ iou3d = boxes_iou3d_gpu(cur_roi, cur_gt) # (M, N)
+ cur_max_overlaps, cur_gt_assignment = torch.max(iou3d, dim=1)
+ max_overlaps[roi_mask] = cur_max_overlaps
+ gt_assignment[roi_mask] = original_gt_assignment[cur_gt_assignment]
+
+ return max_overlaps, gt_assignment
diff --git a/det3d/models/second_stage/__init__.py b/det3d/models/second_stage/__init__.py
new file mode 100644
index 0000000..d5db279
--- /dev/null
+++ b/det3d/models/second_stage/__init__.py
@@ -0,0 +1 @@
+from .bird_eye_view import BEVFeatureExtractor
diff --git a/det3d/models/second_stage/bird_eye_view.py b/det3d/models/second_stage/bird_eye_view.py
new file mode 100644
index 0000000..3cbff6d
--- /dev/null
+++ b/det3d/models/second_stage/bird_eye_view.py
@@ -0,0 +1,41 @@
+import torch
+from torch import nn
+
+from ..registry import SECOND_STAGE
+from det3d.core.utils.center_utils import (
+ bilinear_interpolate_torch,
+)
+
+@SECOND_STAGE.register_module
+class BEVFeatureExtractor(nn.Module):
+ def __init__(self, pc_start,
+ voxel_size, out_stride):
+ super().__init__()
+ self.pc_start = pc_start
+ self.voxel_size = voxel_size
+ self.out_stride = out_stride
+
+ def absl_to_relative(self, absolute):
+ a1 = (absolute[..., 0] - self.pc_start[0]) / self.voxel_size[0] / self.out_stride
+ a2 = (absolute[..., 1] - self.pc_start[1]) / self.voxel_size[1] / self.out_stride
+
+ return a1, a2
+
+ def forward(self, example, batch_centers, num_point):
+ batch_size = len(example['bev_feature'])
+ ret_maps = []
+
+ for batch_idx in range(batch_size):
+ xs, ys = self.absl_to_relative(batch_centers[batch_idx])
+
+ # N x C
+ feature_map = bilinear_interpolate_torch(example['bev_feature'][batch_idx],
+ xs, ys)
+
+ if num_point > 1:
+ section_size = len(feature_map) // num_point
+ feature_map = torch.cat([feature_map[i*section_size: (i+1)*section_size] for i in range(num_point)], dim=1)
+
+ ret_maps.append(feature_map)
+
+ return ret_maps
\ No newline at end of file
diff --git a/det3d/models/utils/__init__.py b/det3d/models/utils/__init__.py
new file mode 100644
index 0000000..c960b10
--- /dev/null
+++ b/det3d/models/utils/__init__.py
@@ -0,0 +1,48 @@
+from .conv_module import ConvModule, build_conv_layer
+from .conv_ws import ConvWS2d, conv_ws_2d
+from .misc import (
+ Empty,
+ GroupNorm,
+ Sequential,
+ change_default_args,
+ get_kw_to_default_map,
+ get_paddings_indicator,
+ get_pos_to_kw_map,
+ get_printer,
+ register_hook,
+)
+from .norm import build_norm_layer
+from .scale import Scale
+from .weight_init import (
+ bias_init_with_prob,
+ kaiming_init,
+ normal_init,
+ uniform_init,
+ xavier_init,
+)
+from .transformer import Transformer, Deform_Transformer
+
+__all__ = [
+ "conv_ws_2d",
+ "ConvWS2d",
+ "build_conv_layer",
+ "ConvModule",
+ "build_norm_layer",
+ "xavier_init",
+ "normal_init",
+ "uniform_init",
+ "kaiming_init",
+ "bias_init_with_prob",
+ "Scale",
+ "Sequential",
+ "GroupNorm",
+ "Empty",
+ "get_pos_to_kw_map",
+ "get_kw_to_default_map",
+ "change_default_args",
+ "get_printer",
+ "register_hook",
+ "get_paddings_indicator",
+ "Transformer",
+ "Deform_Transformer"
+]
diff --git a/det3d/models/utils/conv_module.py b/det3d/models/utils/conv_module.py
new file mode 100644
index 0000000..ae659a5
--- /dev/null
+++ b/det3d/models/utils/conv_module.py
@@ -0,0 +1,165 @@
+import warnings
+
+import torch.nn as nn
+from det3d.torchie.cnn import constant_init, kaiming_init
+
+from .conv_ws import ConvWS2d
+from .norm import build_norm_layer
+
+conv_cfg = {
+ "Conv": nn.Conv2d,
+ "ConvWS": ConvWS2d,
+ # TODO: octave conv
+}
+
+
+def build_conv_layer(cfg, *args, **kwargs):
+ """ Build convolution layer
+ Args:
+ cfg (None or dict): cfg should contain:
+ type (str): identify conv layer type.
+ layer args: args needed to instantiate a conv layer.
+ Returns:
+ layer (nn.Module): created conv layer
+ """
+ if cfg is None:
+ cfg_ = dict(type="Conv")
+ else:
+ assert isinstance(cfg, dict) and "type" in cfg
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop("type")
+ if layer_type not in conv_cfg:
+ raise KeyError("Unrecognized norm type {}".format(layer_type))
+ else:
+ conv_layer = conv_cfg[layer_type]
+
+ layer = conv_layer(*args, **kwargs, **cfg_)
+
+ return layer
+
+
+class ConvModule(nn.Module):
+ """A conv block that contains conv/norm/activation layers.
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ conv_cfg (dict): Config dict for convolution layer.
+ norm_cfg (dict): Config dict for normalization layer.
+ activation (str or None): Activation type, "ReLU" by default.
+ inplace (bool): Whether to use inplace mode for activation.
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
+ sequence of "conv", "norm" and "act". Examples are
+ ("conv", "norm", "act") and ("act", "conv", "norm").
+ """
+
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias="auto",
+ conv_cfg=None,
+ norm_cfg=None,
+ activation="relu",
+ inplace=True,
+ order=("conv", "norm", "act"),
+ ):
+ super(ConvModule, self).__init__()
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.activation = activation
+ self.inplace = inplace
+ self.order = order
+ assert isinstance(self.order, tuple) and len(self.order) == 3
+ assert set(order) == set(["conv", "norm", "act"])
+
+ self.with_norm = norm_cfg is not None
+ self.with_activatation = activation is not None
+ # if the conv layer is before a norm layer, bias is unnecessary.
+ if bias == "auto":
+ bias = False if self.with_norm else True
+ self.with_bias = bias
+
+ if self.with_norm and self.with_bias:
+ warnings.warn("ConvModule has norm and bias at the same time")
+
+ # build convolution layer
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ # export the attributes of self.conv to a higher level for convenience
+ self.in_channels = self.conv.in_channels
+ self.out_channels = self.conv.out_channels
+ self.kernel_size = self.conv.kernel_size
+ self.stride = self.conv.stride
+ self.padding = self.conv.padding
+ self.dilation = self.conv.dilation
+ self.transposed = self.conv.transposed
+ self.output_padding = self.conv.output_padding
+ self.groups = self.conv.groups
+
+ # build normalization layers
+ if self.with_norm:
+ # norm layer is after conv layer
+ if order.index("norm") > order.index("conv"):
+ norm_channels = out_channels
+ else:
+ norm_channels = in_channels
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
+ self.add_module(self.norm_name, norm)
+
+ # build activation layer
+ if self.with_activatation:
+ # TODO: introduce `act_cfg` and supports more activation layers
+ if self.activation not in ["relu"]:
+ raise ValueError(
+ "{} is currently not supported.".format(self.activation)
+ )
+ if self.activation == "relu":
+ self.activate = nn.ReLU(inplace=inplace)
+
+ # Use msra init by default
+ self.init_weights()
+
+ @property
+ def norm(self):
+ return getattr(self, self.norm_name)
+
+ def init_weights(self):
+ nonlinearity = "relu" if self.activation is None else self.activation
+ kaiming_init(self.conv, nonlinearity=nonlinearity)
+ if self.with_norm:
+ constant_init(self.norm, 1, bias=0)
+
+ def forward(self, x, activate=True, norm=True):
+ for layer in self.order:
+ if layer == "conv":
+ x = self.conv(x)
+ elif layer == "norm" and norm and self.with_norm:
+ x = self.norm(x)
+ elif layer == "act" and activate and self.with_activatation:
+ x = self.activate(x)
+ return x
diff --git a/det3d/models/utils/conv_ws.py b/det3d/models/utils/conv_ws.py
new file mode 100644
index 0000000..d7abd92
--- /dev/null
+++ b/det3d/models/utils/conv_ws.py
@@ -0,0 +1,51 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def conv_ws_2d(
+ input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, eps=1e-5
+):
+ c_in = weight.size(0)
+ weight_flat = weight.view(c_in, -1)
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
+ weight = (weight - mean) / (std + eps)
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
+
+
+class ConvWS2d(nn.Conv2d):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ bias=True,
+ eps=1e-5,
+ ):
+ super(ConvWS2d, self).__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ )
+ self.eps = eps
+
+ def forward(self, x):
+ return conv_ws_2d(
+ x,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.eps,
+ )
diff --git a/det3d/models/utils/finetune_utils.py b/det3d/models/utils/finetune_utils.py
new file mode 100644
index 0000000..e06cad8
--- /dev/null
+++ b/det3d/models/utils/finetune_utils.py
@@ -0,0 +1,111 @@
+import torch
+import torch.distributed as dist
+from torch import nn
+from torch.autograd.function import Function
+from torch.nn import functional as F
+import logging
+
+class FrozenBatchNorm2d(nn.Module):
+ """
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
+ It contains non-trainable buffers called
+ "weight" and "bias", "running_mean", "running_var",
+ initialized to perform identity transformation.
+ The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+ which are computed from the original four parameters of BN.
+ The affine transform `x * weight + bias` will perform the equivalent
+ computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+ When loading a backbone model from Caffe2, "running_mean" and "running_var"
+ will be left unchanged as identity transformation.
+ Other pre-trained backbone models may contain all 4 parameters.
+ The forward is implemented by `F.batch_norm(..., training=False)`.
+ """
+
+ _version = 3
+
+ def __init__(self, num_features, eps=1e-5):
+ super().__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.register_buffer("weight", torch.ones(num_features))
+ self.register_buffer("bias", torch.zeros(num_features))
+ self.register_buffer("running_mean", torch.zeros(num_features))
+ self.register_buffer("running_var", torch.ones(num_features) - eps)
+
+ def forward(self, x):
+ if x.requires_grad:
+ # When gradients are needed, F.batch_norm will use extra memory
+ # because its backward op computes gradients for weight/bias as well.
+ scale = self.weight * (self.running_var + self.eps).rsqrt()
+ bias = self.bias - self.running_mean * scale
+ scale = scale.reshape(1, -1, 1, 1)
+ bias = bias.reshape(1, -1, 1, 1)
+ return x * scale + bias
+ else:
+ # When gradients are not needed, F.batch_norm is a single fused op
+ # and provide more optimization opportunities.
+ return F.batch_norm(
+ x,
+ self.running_mean,
+ self.running_var,
+ self.weight,
+ self.bias,
+ training=False,
+ eps=self.eps,
+ )
+
+ def _load_from_state_dict(
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ ):
+ version = local_metadata.get("version", None)
+
+ if version is None or version < 2:
+ # No running_mean/var in early versions
+ # This will silent the warnings
+ if prefix + "running_mean" not in state_dict:
+ state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+ if prefix + "running_var" not in state_dict:
+ state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+ if version is not None and version < 3:
+ logger = logging.getLogger(__name__)
+ logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
+ # In version < 3, running_var are used without +eps.
+ state_dict[prefix + "running_var"] -= self.eps
+
+ super()._load_from_state_dict(
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+ )
+
+ def __repr__(self):
+ return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)
+
+ @classmethod
+ def convert_frozen_batchnorm(cls, module):
+ """
+ Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+ Args:
+ module (torch.nn.Module):
+ Returns:
+ If module is BatchNorm/SyncBatchNorm, returns a new module.
+ Otherwise, in-place convert module and return it.
+ Similar to convert_sync_batchnorm in
+ https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+ """
+ bn_module = nn.modules.batchnorm
+ bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+ res = module
+ if isinstance(module, bn_module):
+ res = cls(module.num_features)
+ if module.affine:
+ res.weight.data = module.weight.data.clone().detach()
+ res.bias.data = module.bias.data.clone().detach()
+ res.running_mean.data = module.running_mean.data
+ res.running_var.data = module.running_var.data
+ res.eps = module.eps
+ else:
+ for name, child in module.named_children():
+ new_child = cls.convert_frozen_batchnorm(child)
+ if new_child is not child:
+ res.add_module(name, new_child)
+ return res
\ No newline at end of file
diff --git a/det3d/models/utils/misc.py b/det3d/models/utils/misc.py
new file mode 100644
index 0000000..6ba6e75
--- /dev/null
+++ b/det3d/models/utils/misc.py
@@ -0,0 +1,202 @@
+import functools
+import inspect
+import sys
+from collections import OrderedDict
+
+import numba
+import numpy as np
+import torch
+
+# from lib.models.backbone.utils import Registry
+#
+# BACKBONES = Registry()
+# RPN_HEADS = Registry()
+# ROI_BOX_FEATURE_EXTRACTORS = Registry()
+# ROI_BOX_PREDICTOR = Registry()
+# ROI_KEYPOINT_FEATURE_EXTRACTORS = Registry()
+# ROI_KEYPOINT_PREDICTOR = Registry()
+# ROI_MASK_FEATURE_EXTRACTORS = Registry()
+# ROI_MASK_PREDICTOR = Registry()
+
+
+class Sequential(torch.nn.Module):
+ r"""A sequential container.
+ Modules will be added to it in the order they are passed in the constructor.
+ Alternatively, an ordered dict of modules can also be passed in.
+
+ To make it easier to understand, given is a small example::
+
+ # Example of using Sequential
+ model = Sequential(
+ nn.Conv2d(1,20,5),
+ nn.ReLU(),
+ nn.Conv2d(20,64,5),
+ nn.ReLU()
+ )
+
+ # Example of using Sequential with OrderedDict
+ model = Sequential(OrderedDict([
+ ('conv1', nn.Conv2d(1,20,5)),
+ ('relu1', nn.ReLU()),
+ ('conv2', nn.Conv2d(20,64,5)),
+ ('relu2', nn.ReLU())
+ ]))
+
+ # Example of using Sequential with kwargs(python 3.6+)
+ model = Sequential(
+ conv1=nn.Conv2d(1,20,5),
+ relu1=nn.ReLU(),
+ conv2=nn.Conv2d(20,64,5),
+ relu2=nn.ReLU()
+ )
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(Sequential, self).__init__()
+ if len(args) == 1 and isinstance(args[0], OrderedDict):
+ for key, module in args[0].items():
+ self.add_module(key, module)
+ else:
+ for idx, module in enumerate(args):
+ self.add_module(str(idx), module)
+ for name, module in kwargs.items():
+ if sys.version_info < (3, 6):
+ raise ValueError("kwargs only supported in py36+")
+ if name in self._modules:
+ raise ValueError("name exists.")
+ self.add_module(name, module)
+
+ def __getitem__(self, idx):
+ if not (-len(self) <= idx < len(self)):
+ raise IndexError("index {} is out of range".format(idx))
+ if idx < 0:
+ idx += len(self)
+ it = iter(self._modules.values())
+ for i in range(idx):
+ next(it)
+ return next(it)
+
+ def __len__(self):
+ return len(self._modules)
+
+ def add(self, module, name=None):
+ if name is None:
+ name = str(len(self._modules))
+ if name in self._modules:
+ raise KeyError("name exists")
+ self.add_module(name, module)
+
+ def forward(self, input):
+ # i = 0
+ for module in self._modules.values():
+ # print(i)
+ input = module(input)
+ # i += 1
+ return input
+
+
+class GroupNorm(torch.nn.GroupNorm):
+ def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
+ super().__init__(
+ num_groups=num_groups, num_channels=num_channels, eps=eps, affine=affine
+ )
+
+
+class Empty(torch.nn.Module):
+ def __init__(self, *args, **kwargs):
+ super(Empty, self).__init__()
+
+ def forward(self, *args, **kwargs):
+ if len(args) == 1:
+ return args[0]
+ elif len(args) == 0:
+ return None
+ return args
+
+
+def get_pos_to_kw_map(func):
+ pos_to_kw = {}
+ fsig = inspect.signature(func)
+ pos = 0
+ for name, info in fsig.parameters.items():
+ if info.kind is info.POSITIONAL_OR_KEYWORD:
+ pos_to_kw[pos] = name
+ pos += 1
+ return pos_to_kw
+
+
+def get_kw_to_default_map(func):
+ kw_to_default = {}
+ fsig = inspect.signature(func)
+ for name, info in fsig.parameters.items():
+ if info.kind is info.POSITIONAL_OR_KEYWORD:
+ if info.default is not info.empty:
+ kw_to_default[name] = info.default
+ return kw_to_default
+
+
+def change_default_args(**kwargs):
+ def layer_wrapper(layer_class):
+ class DefaultArgLayer(layer_class):
+ def __init__(self, *args, **kw):
+ pos_to_kw = get_pos_to_kw_map(layer_class.__init__)
+ kw_to_pos = {kw: pos for pos, kw in pos_to_kw.items()}
+ for key, val in kwargs.items():
+ if key not in kw and kw_to_pos[key] > len(args):
+ kw[key] = val
+ super().__init__(*args, **kw)
+
+ return DefaultArgLayer
+
+ return layer_wrapper
+
+
+def get_printer(msg):
+ """This function returns a printer function, that prints information about a tensor's
+ gradient. Used by register_hook in the backward pass.
+ """
+
+ def printer(tensor):
+ if tensor.nelement() == 1:
+ print(f"{msg} {tensor}")
+ else:
+ print(
+ f"{msg} shape: {tensor.shape}"
+ f" max: {tensor.max()} min: {tensor.min()}"
+ f" mean: {tensor.mean()}"
+ )
+
+ return printer
+
+
+def register_hook(tensor, msg):
+ """Utility function to call retain_grad and Pytorch's register_hook
+ in a single line
+ """
+ tensor.retain_grad()
+ tensor.register_hook(get_printer(msg))
+
+
+def get_paddings_indicator(actual_num, max_num, axis=0):
+ """Create boolean mask by actually number of a padded tensor.
+
+ Args:
+ actual_num ([type]): [description]
+ max_num ([type]): [description]
+
+ Returns:
+ [type]: [description]
+ """
+
+ actual_num = torch.unsqueeze(actual_num, axis + 1)
+ # tiled_actual_num: [N, M, 1]
+ max_num_shape = [1] * len(actual_num.shape)
+ max_num_shape[axis + 1] = -1
+ max_num = torch.arange(max_num, dtype=torch.int, device=actual_num.device).view(
+ max_num_shape
+ )
+ # tiled_actual_num: [[3,3,3,3,3], [4,4,4,4,4], [2,2,2,2,2]]
+ # tiled_max_num: [[0,1,2,3,4], [0,1,2,3,4], [0,1,2,3,4]]
+ paddings_indicator = actual_num.int() > max_num
+ # paddings_indicator shape: [batch_size, max_num]
+ return paddings_indicator
diff --git a/det3d/models/utils/norm.py b/det3d/models/utils/norm.py
new file mode 100644
index 0000000..106cd09
--- /dev/null
+++ b/det3d/models/utils/norm.py
@@ -0,0 +1,114 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from det3d.utils.dist import dist_common as comm
+from torch.autograd.function import Function
+from torch.nn import BatchNorm2d
+
+
+class AllReduce(Function):
+ @staticmethod
+ def forward(ctx, input):
+ input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
+ # Use allgather instead of allreduce since I don't trust in-place operations ..
+ dist.all_gather(input_list, input, async_op=False)
+ inputs = torch.stack(input_list, dim=0)
+ return torch.sum(inputs, dim=0)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ dist.all_reduce(grad_output, async_op=False)
+ return grad_output
+
+
+class NaiveSyncBatchNorm(BatchNorm2d):
+ """
+ `torch.nn.SyncBatchNorm` has known unknown bugs.
+ It produces significantly worse AP (and sometimes goes NaN)
+ when the batch size on each worker is quite different
+ (e.g., when scale augmentation is used, or when it is applied to mask head).
+ Use this implementation before `nn.SyncBatchNorm` is fixed.
+ It is slower than `nn.SyncBatchNorm`.
+ """
+
+ def forward(self, input):
+ if comm.get_world_size() == 1 or not self.training:
+ return super().forward(input)
+
+ assert input.shape[0] > 0, "SyncBatchNorm does not support empty input"
+ C = input.shape[1]
+ mean = torch.mean(input, dim=[0, 2, 3])
+ meansqr = torch.mean(input * input, dim=[0, 2, 3])
+
+ vec = torch.cat([mean, meansqr], dim=0)
+ vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
+
+ mean, meansqr = torch.split(vec, C)
+ var = meansqr - mean * mean
+ self.running_mean += self.momentum * (mean.detach() - self.running_mean)
+ self.running_var += self.momentum * (var.detach() - self.running_var)
+
+ invstd = torch.rsqrt(var + self.eps)
+ scale = self.weight * invstd
+ bias = self.bias - mean * scale
+ scale = scale.reshape(1, -1, 1, 1)
+ bias = bias.reshape(1, -1, 1, 1)
+ return input * scale + bias
+
+
+norm_cfg = {
+ # format: layer_type: (abbreviation, module)
+ "BN": ("bn", nn.BatchNorm2d),
+ "BN1d": ("bn1d", nn.BatchNorm1d),
+ "GN": ("gn", nn.GroupNorm),
+}
+
+
+def build_norm_layer(cfg, num_features, postfix=""):
+ """ Build normalization layer
+ Args:
+ cfg (dict): cfg should contain:
+ type (str): identify norm layer type.
+ layer args: args needed to instantiate a norm layer.
+ requires_grad (bool): [optional] whether stop gradient updates
+ num_features (int): number of channels from input.
+ postfix (int, str): appended into norm abbreviation to
+ create named layer.
+ Returns:
+ name (str): abbreviation + postfix
+ layer (nn.Module): created norm layer
+ """
+ assert isinstance(cfg, dict) and "type" in cfg
+ cfg_ = cfg.copy()
+
+ layer_type = cfg_.pop("type")
+ if layer_type not in norm_cfg:
+ raise KeyError("Unrecognized norm type {}".format(layer_type))
+ else:
+ abbr, norm_layer = norm_cfg[layer_type]
+ if norm_layer is None:
+ raise NotImplementedError
+
+ assert isinstance(postfix, (int, str))
+ name = abbr + str(postfix)
+
+ requires_grad = cfg_.pop("requires_grad", True)
+ cfg_.setdefault("eps", 1e-5)
+ if layer_type != "GN":
+ layer = norm_layer(num_features, **cfg_)
+ # if layer_type == 'SyncBN':
+ # layer._specify_ddp_gpu_num(1)
+ else:
+ assert "num_groups" in cfg_
+ layer = norm_layer(num_channels=num_features, **cfg_)
+
+ for param in layer.parameters():
+ param.requires_grad = requires_grad
+
+ layer.apply(bn_weight_init)
+
+ return name, layer
+
+def bn_weight_init(m):
+ if m.weight is not None:
+ torch.nn.init.uniform_(m.weight)
diff --git a/det3d/models/utils/scale.py b/det3d/models/utils/scale.py
new file mode 100644
index 0000000..01501b5
--- /dev/null
+++ b/det3d/models/utils/scale.py
@@ -0,0 +1,11 @@
+import torch
+import torch.nn as nn
+
+
+class Scale(nn.Module):
+ def __init__(self, scale=1.0):
+ super(Scale, self).__init__()
+ self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
+
+ def forward(self, x):
+ return x * self.scale
diff --git a/det3d/models/utils/transformer.py b/det3d/models/utils/transformer.py
new file mode 100644
index 0000000..4ba4a18
--- /dev/null
+++ b/det3d/models/utils/transformer.py
@@ -0,0 +1,406 @@
+import math
+
+import torch
+
+from torch import nn
+from torch.nn import functional as F
+from torch import batch_norm, einsum
+
+from einops import rearrange, repeat
+from det3d.models.ops.modules import MSDeformAttn
+
+
+class MLP(nn.Module):
+ """Very simple multi-layer perceptron (also called FFN)"""
+
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+ super().__init__()
+ self.num_layers = num_layers
+ h = [hidden_dim] * (num_layers - 1)
+ self.layers = nn.ModuleList(
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+ )
+
+ def forward(self, x):
+ for i, layer in enumerate(self.layers):
+ x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+ return x
+
+
+class GELU(nn.Module):
+ def forward(self, x):
+ return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
+# transformer layer
+class PreNorm(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, **kwargs):
+ return self.fn(self.norm(x), **kwargs)
+
+
+class PreNorm_CA(nn.Module):
+ def __init__(self, dim, fn):
+ super().__init__()
+ self.norm = nn.LayerNorm(dim)
+ self.fn = fn
+
+ def forward(self, x, y, **kwargs):
+ return self.fn(self.norm(x), self.norm(y), **kwargs)
+
+
+class FeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim, dropout=0.0):
+ super().__init__()
+ self.net = nn.Sequential(
+ nn.Linear(dim, hidden_dim),
+ GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(hidden_dim, dim),
+ nn.Dropout(dropout),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+class Attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, out_attention=False):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ self.out_attention = out_attention
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x):
+ b, n, _, h = *x.shape, self.heads
+ qkv = self.to_qkv(x).chunk(3, dim=-1)
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv)
+
+ dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+
+ attn = self.attend(dots)
+
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
+ out = rearrange(out, "b h n d -> b n (h d)")
+
+ if self.out_attention:
+ return self.to_out(out), attn
+ else:
+ return self.to_out(out)
+
+
+class Cross_attention(nn.Module):
+ def __init__(self, dim, heads=8, dim_head=64, dropout=0.0, out_attention=False):
+ super().__init__()
+ inner_dim = dim_head * heads
+ project_out = not (heads == 1 and dim_head == dim)
+
+ self.heads = heads
+ self.scale = dim_head**-0.5
+ self.out_attention = out_attention
+
+ self.attend = nn.Softmax(dim=-1)
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
+
+ self.to_out = (
+ nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))
+ if project_out
+ else nn.Identity()
+ )
+
+ def forward(self, x, y):
+ b, n, m, _, h = *y.shape, self.heads
+ q = self.to_q(x)
+ kv = self.to_kv(y).chunk(2, dim=-1)
+ q = rearrange(q, "b n (h d) -> (b n) h 1 d", h=h)
+ k, v = map(lambda t: rearrange(t, "b n m (h d) -> (b n) h m d", h=h), kv)
+
+ dots = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
+
+ attn = self.attend(dots)
+
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
+ out = rearrange(out, "(b n) h 1 d -> b n (h d)", b=b)
+
+ if self.out_attention:
+ return self.to_out(out), rearrange(attn, "(b n) h i j -> b n h (i j)", b=b)
+ else:
+ return self.to_out(out)
+
+
+class DeformableTransformerCrossAttention(nn.Module):
+ def __init__(
+ self,
+ d_model=256,
+ d_head=64,
+ dropout=0.3,
+ n_levels=3,
+ n_heads=6,
+ n_points=9,
+ out_sample_loc=False,
+ ):
+ super().__init__()
+
+ # cross attention
+ self.cross_attn = MSDeformAttn(
+ d_model, d_head, n_levels, n_heads, n_points, out_sample_loc=out_sample_loc
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.out_sample_loc = out_sample_loc
+
+ @staticmethod
+ def with_pos_embed(tensor, pos):
+ return tensor if pos is None else tensor + pos
+
+ def forward(
+ self,
+ tgt,
+ src,
+ query_pos=None,
+ reference_points=None,
+ src_spatial_shapes=None,
+ level_start_index=None,
+ src_padding_mask=None,
+ ):
+ # cross attention
+ tgt2, sampling_locations = self.cross_attn(
+ self.with_pos_embed(tgt, query_pos),
+ reference_points,
+ src,
+ src_spatial_shapes,
+ level_start_index,
+ src_padding_mask,
+ )
+ tgt = self.dropout(tgt2)
+
+ if self.out_sample_loc:
+ return tgt, sampling_locations
+ else:
+ return tgt
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ depth=2,
+ heads=4,
+ dim_head=64,
+ mlp_dim=256,
+ dropout=0.0,
+ out_attention=False,
+ ):
+ super().__init__()
+ self.out_attention = out_attention
+ self.layers = nn.ModuleList([])
+ self.depth = depth
+
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(
+ dim,
+ Attention(
+ dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ out_attention=self.out_attention,
+ ),
+ ),
+ PreNorm_CA(
+ dim,
+ Cross_attention(
+ dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ out_attention=self.out_attention,
+ ),
+ ),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
+ ]
+ )
+ )
+
+ def forward(
+ self, x, pos_embedding=None, center_pos=None, y=None, neighbor_pos=None
+ ):
+ if self.out_attention:
+ out_cross_attention_list = []
+ out_self_attention_list = []
+ if center_pos is not None and pos_embedding is not None:
+ center_pos_embedding = pos_embedding(center_pos)
+ if neighbor_pos is not None and pos_embedding is not None:
+ neighbor_pos_embedding = pos_embedding(neighbor_pos)
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+ if self.out_attention:
+ if pos_embedding is not None:
+ x_att, self_att = self_attn(x + center_pos_embedding)
+ x = x_att + x
+ x_att, cross_att = cross_attn(
+ x + center_pos_embedding, y + neighbor_pos_embedding
+ )
+ else:
+ x_att, self_att = self_attn(x)
+ x = x_att + x
+ x_att, cross_att = cross_attn(x, y)
+ out_cross_attention_list.append(cross_att)
+ else:
+ if pos_embedding is not None:
+ x_att = self_attn(x + center_pos_embedding)
+ x = x_att + x
+ x_att = cross_attn(
+ x + center_pos_embedding, y + neighbor_pos_embedding
+ )
+ else:
+ x_att = self_attn(x)
+ x = x_att + x
+ x_att = cross_attn(x, y)
+
+ x = x_att + x
+ x = ff(x) + x
+
+ out_dict = {"ct_feat": x}
+ if self.out_attention:
+ out_dict.update(
+ {"out_attention": torch.stack(out_cross_attention_list, dim=2)}
+ )
+ return out_dict
+
+
+class Deform_Transformer(nn.Module):
+ def __init__(
+ self,
+ dim,
+ levels=3,
+ depth=2,
+ heads=4,
+ dim_head=32,
+ mlp_dim=256,
+ dropout=0.0,
+ out_attention=False,
+ n_points=9,
+ ):
+ super().__init__()
+ self.out_attention = out_attention
+ self.layers = nn.ModuleList([])
+ self.depth = depth
+ self.levels = levels
+ self.n_points = n_points
+
+ for _ in range(depth):
+ self.layers.append(
+ nn.ModuleList(
+ [
+ PreNorm(
+ dim,
+ Attention(
+ dim,
+ heads=heads,
+ dim_head=dim_head,
+ dropout=dropout,
+ out_attention=self.out_attention,
+ ),
+ ),
+ PreNorm_CA(
+ dim,
+ DeformableTransformerCrossAttention(
+ dim,
+ dim_head,
+ n_levels=levels,
+ n_heads=heads,
+ dropout=dropout,
+ n_points=n_points,
+ out_sample_loc=self.out_attention,
+ ),
+ ),
+ PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)),
+ ]
+ )
+ )
+
+ def forward(
+ self, x, pos_embedding, src, src_spatial_shapes, level_start_index, center_pos
+ ):
+ if self.out_attention:
+ out_cross_attention_list = []
+ out_self_attention_list = []
+ if pos_embedding is not None:
+ center_pos_embedding = pos_embedding(center_pos)
+ reference_points = center_pos[:, :, None, :].repeat(1, 1, self.levels, 1)
+ for i, (self_attn, cross_attn, ff) in enumerate(self.layers):
+ if self.out_attention:
+ if center_pos_embedding is not None:
+ x_att, self_att = self_attn(x + center_pos_embedding)
+ x = x_att + x
+ x_att, cross_att = cross_attn(
+ x,
+ src,
+ query_pos=center_pos_embedding,
+ reference_points=reference_points,
+ src_spatial_shapes=src_spatial_shapes,
+ level_start_index=level_start_index,
+ )
+ else:
+ x_att, self_att = self_attn(x)
+ x = x_att + x
+ x_att, cross_att = cross_attn(
+ x,
+ src,
+ query_pos=None,
+ reference_points=reference_points,
+ src_spatial_shapes=src_spatial_shapes,
+ level_start_index=level_start_index,
+ )
+ out_cross_attention_list.append(cross_att)
+ else:
+ if center_pos_embedding is not None:
+ x_att = self_attn(x + center_pos_embedding)
+ x = x_att + x
+ x_att = cross_attn(
+ x,
+ src,
+ query_pos=center_pos_embedding,
+ reference_points=reference_points,
+ src_spatial_shapes=src_spatial_shapes,
+ level_start_index=level_start_index,
+ )
+ else:
+ x_att = self_attn(x)
+ x = x_att + x
+ x_att = cross_attn(
+ x,
+ src,
+ query_pos=None,
+ reference_points=reference_points,
+ src_spatial_shapes=src_spatial_shapes,
+ level_start_index=level_start_index,
+ )
+
+ x = x_att + x
+ x = ff(x) + x
+
+ out_dict = {"ct_feat": x}
+ if self.out_attention:
+ out_dict.update(
+ {"out_attention": torch.stack(out_cross_attention_list, dim=2)}
+ )
+ return out_dict
diff --git a/det3d/models/utils/weight_init.py b/det3d/models/utils/weight_init.py
new file mode 100644
index 0000000..c629cbb
--- /dev/null
+++ b/det3d/models/utils/weight_init.py
@@ -0,0 +1,42 @@
+import numpy as np
+import torch.nn as nn
+
+
+def xavier_init(module, gain=1, bias=0, distribution="normal"):
+ assert distribution in ["uniform", "normal"]
+ if distribution == "uniform":
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.bias, bias)
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(
+ module, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
+):
+ assert distribution in ["uniform", "normal"]
+ if distribution == "uniform":
+ nn.init.kaiming_uniform_(module.weight, mode=mode, nonlinearity=nonlinearity)
+ else:
+ nn.init.kaiming_normal_(module.weight, mode=mode, nonlinearity=nonlinearity)
+ if hasattr(module, "bias"):
+ nn.init.constant_(module.bias, bias)
+
+
+def bias_init_with_prob(prior_prob):
+ """ initialize conv/fc bias value according to giving probablity"""
+ bias_init = float(-np.log((1 - prior_prob) / prior_prob))
+ return bias_init
diff --git a/det3d/ops/dcn/__init__.py b/det3d/ops/dcn/__init__.py
new file mode 100644
index 0000000..d3cbb67
--- /dev/null
+++ b/det3d/ops/dcn/__init__.py
@@ -0,0 +1,8 @@
+from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv,
+ ModulatedDeformConvPack, deform_conv,
+ modulated_deform_conv)
+
+__all__ = [
+ 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv',
+ 'ModulatedDeformConvPack', 'deform_conv', 'modulated_deform_conv',
+]
diff --git a/det3d/ops/dcn/deform_conv.py b/det3d/ops/dcn/deform_conv.py
new file mode 100644
index 0000000..4680cb0
--- /dev/null
+++ b/det3d/ops/dcn/deform_conv.py
@@ -0,0 +1,446 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+from torch.nn.modules.utils import _pair, _single
+
+# from mmdet.utils import print_log
+from . import deform_conv_cuda
+
+
+class DeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ weight,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ im2col_step=64):
+ if input is not None and input.dim() != 4:
+ raise ValueError(
+ 'Expected 4D tensor as input, got {}D tensor instead.'.format(
+ input.dim()))
+ ctx.stride = _pair(stride)
+ ctx.padding = _pair(padding)
+ ctx.dilation = _pair(dilation)
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.im2col_step = im2col_step
+
+ ctx.save_for_backward(input, offset, weight)
+
+ output = input.new_empty(
+ DeformConvFunction._output_size(input, weight, ctx.padding,
+ ctx.dilation, ctx.stride))
+
+ ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
+
+ if not input.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+ deform_conv_cuda.deform_conv_forward_cuda(
+ input, weight, offset, output, ctx.bufs_[0], ctx.bufs_[1],
+ weight.size(3), weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+ cur_im2col_step)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ input, offset, weight = ctx.saved_tensors
+
+ grad_input = grad_offset = grad_weight = None
+
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ else:
+ cur_im2col_step = min(ctx.im2col_step, input.shape[0])
+ assert (input.shape[0] %
+ cur_im2col_step) == 0, 'im2col step must divide batchsize'
+
+ if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ deform_conv_cuda.deform_conv_backward_input_cuda(
+ input, offset, grad_output, grad_input,
+ grad_offset, weight, ctx.bufs_[0], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups,
+ cur_im2col_step)
+
+ if ctx.needs_input_grad[2]:
+ grad_weight = torch.zeros_like(weight)
+ deform_conv_cuda.deform_conv_backward_parameters_cuda(
+ input, offset, grad_output,
+ grad_weight, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
+ weight.size(2), ctx.stride[1], ctx.stride[0],
+ ctx.padding[1], ctx.padding[0], ctx.dilation[1],
+ ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
+ cur_im2col_step)
+
+ return (grad_input, grad_offset, grad_weight, None, None, None, None,
+ None)
+
+ @staticmethod
+ def _output_size(input, weight, padding, dilation, stride):
+ channels = weight.size(0)
+ output_size = (input.size(0), channels)
+ for d in range(input.dim() - 2):
+ in_size = input.size(d + 2)
+ pad = padding[d]
+ kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
+ stride_ = stride[d]
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ if not all(map(lambda s: s > 0, output_size)):
+ raise ValueError(
+ 'convolution input is too small (output would be {})'.format(
+ 'x'.join(map(str, output_size))))
+ return output_size
+
+
+class ModulatedDeformConvFunction(Function):
+
+ @staticmethod
+ def forward(ctx,
+ input,
+ offset,
+ mask,
+ weight,
+ bias=None,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1):
+ ctx.stride = stride
+ ctx.padding = padding
+ ctx.dilation = dilation
+ ctx.groups = groups
+ ctx.deformable_groups = deformable_groups
+ ctx.with_bias = bias is not None
+ if not ctx.with_bias:
+ bias = input.new_empty(1) # fake tensor
+ if not input.is_cuda:
+ raise NotImplementedError
+ if weight.requires_grad or mask.requires_grad or offset.requires_grad \
+ or input.requires_grad:
+ ctx.save_for_backward(input, offset, mask, weight, bias)
+ output = input.new_empty(
+ ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
+ ctx._bufs = [input.new_empty(0), input.new_empty(0)]
+ deform_conv_cuda.modulated_deform_conv_cuda_forward(
+ input, weight, bias, ctx._bufs[0], offset, mask, output,
+ ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ return output
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_output):
+ if not grad_output.is_cuda:
+ raise NotImplementedError
+ input, offset, mask, weight, bias = ctx.saved_tensors
+ grad_input = torch.zeros_like(input)
+ grad_offset = torch.zeros_like(offset)
+ grad_mask = torch.zeros_like(mask)
+ grad_weight = torch.zeros_like(weight)
+ grad_bias = torch.zeros_like(bias)
+ deform_conv_cuda.modulated_deform_conv_cuda_backward(
+ input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
+ grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
+ grad_output, weight.shape[2], weight.shape[3], ctx.stride,
+ ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
+ ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ if not ctx.with_bias:
+ grad_bias = None
+
+ return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias,
+ None, None, None, None, None)
+
+ @staticmethod
+ def _infer_shape(ctx, input, weight):
+ n = input.size(0)
+ channels_out = weight.size(0)
+ height, width = input.shape[2:4]
+ kernel_h, kernel_w = weight.shape[2:4]
+ height_out = (height + 2 * ctx.padding -
+ (ctx.dilation * (kernel_h - 1) + 1)) // ctx.stride + 1
+ width_out = (width + 2 * ctx.padding -
+ (ctx.dilation * (kernel_w - 1) + 1)) // ctx.stride + 1
+ return n, channels_out, height_out, width_out
+
+
+deform_conv = DeformConvFunction.apply
+modulated_deform_conv = ModulatedDeformConvFunction.apply
+
+
+class DeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False):
+ super(DeformConv, self).__init__()
+
+ assert not bias
+ assert in_channels % groups == 0, \
+ 'in_channels {} cannot be divisible by groups {}'.format(
+ in_channels, groups)
+ assert out_channels % groups == 0, \
+ 'out_channels {} cannot be divisible by groups {}'.format(
+ out_channels, groups)
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = _pair(stride)
+ self.padding = _pair(padding)
+ self.dilation = _pair(dilation)
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // self.groups,
+ *self.kernel_size))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+
+ def forward(self, x, offset):
+ # To fix an assert error in deform_conv_cuda.cpp:128
+ # input image is smaller than kernel
+ input_pad = (
+ x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ if input_pad:
+ pad_h = max(self.kernel_size[0] - x.size(2), 0)
+ pad_w = max(self.kernel_size[1] - x.size(3), 0)
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
+ offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant',
+ 0).contiguous()
+ out = deform_conv(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+ if input_pad:
+ out = out[:, :, :out.size(2) - pad_h, :out.size(3) -
+ pad_w].contiguous()
+ return out
+
+
+class DeformConvPack(DeformConv):
+ """A Deformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(DeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 2 * self.kernel_size[0] *
+ self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ offset = self.conv_offset(x)
+ return deform_conv(x, offset, self.weight, self.stride, self.padding,
+ self.dilation, self.groups, self.deformable_groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, DeformConvPack loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ 'DeformConvPack {} is upgraded to version 2.'.format(
+ prefix.rstrip('.')),
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
+
+
+class ModulatedDeformConv(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True):
+ super(ModulatedDeformConv, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deformable_groups = deformable_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels // groups,
+ *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+
+class ModulatedDeformConvPack(ModulatedDeformConv):
+ """A ModulatedDeformable Conv Encapsulation that acts as normal Conv layers.
+
+ Args:
+ in_channels (int): Same as nn.Conv2d.
+ out_channels (int): Same as nn.Conv2d.
+ kernel_size (int or tuple[int]): Same as nn.Conv2d.
+ stride (int or tuple[int]): Same as nn.Conv2d.
+ padding (int or tuple[int]): Same as nn.Conv2d.
+ dilation (int or tuple[int]): Same as nn.Conv2d.
+ groups (int): Same as nn.Conv2d.
+ bias (bool or str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias will be set as True if norm_cfg is None, otherwise
+ False.
+ """
+
+ _version = 2
+
+ def __init__(self, *args, **kwargs):
+ super(ModulatedDeformConvPack, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Conv2d(
+ self.in_channels,
+ self.deformable_groups * 3 * self.kernel_size[0] *
+ self.kernel_size[1],
+ kernel_size=self.kernel_size,
+ stride=_pair(self.stride),
+ padding=_pair(self.padding),
+ bias=True)
+ self.init_offset()
+
+ def init_offset(self):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x):
+ out = self.conv_offset(x)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+ offset = torch.cat((o1, o2), dim=1)
+ mask = torch.sigmoid(mask)
+ return modulated_deform_conv(x, offset, mask, self.weight, self.bias,
+ self.stride, self.padding, self.dilation,
+ self.groups, self.deformable_groups)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ version = local_metadata.get('version', None)
+
+ if version is None or version < 2:
+ # the key is different in early versions
+ # In version < 2, ModulatedDeformConvPack
+ # loads previous benchmark models.
+ if (prefix + 'conv_offset.weight' not in state_dict
+ and prefix[:-1] + '_offset.weight' in state_dict):
+ state_dict[prefix + 'conv_offset.weight'] = state_dict.pop(
+ prefix[:-1] + '_offset.weight')
+ if (prefix + 'conv_offset.bias' not in state_dict
+ and prefix[:-1] + '_offset.bias' in state_dict):
+ state_dict[prefix +
+ 'conv_offset.bias'] = state_dict.pop(prefix[:-1] +
+ '_offset.bias')
+
+ if version is not None and version > 1:
+ print_log(
+ 'ModulatedDeformConvPack {} is upgraded to version 2.'.format(
+ prefix.rstrip('.')),
+ logger='root')
+
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
diff --git a/det3d/ops/dcn/setup.py b/det3d/ops/dcn/setup.py
new file mode 100644
index 0000000..0f56393
--- /dev/null
+++ b/det3d/ops/dcn/setup.py
@@ -0,0 +1,20 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='masked_conv',
+ ext_modules=[
+ CUDAExtension('deform_conv_cuda', [
+ 'src/deform_conv_cuda.cpp',
+ 'src/deform_conv_cuda_kernel.cu',
+ ],
+ define_macros=[('WITH_CUDA', None)],
+ extra_compile_args={
+ 'cxx': [],
+ 'nvcc': [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+ ]})],
+ cmdclass={'build_ext': BuildExtension})
+
diff --git a/det3d/ops/dcn/src/deform_conv_cuda.cpp b/det3d/ops/dcn/src/deform_conv_cuda.cpp
new file mode 100644
index 0000000..2321e02
--- /dev/null
+++ b/det3d/ops/dcn/src/deform_conv_cuda.cpp
@@ -0,0 +1,701 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c
+
+#include
+#include
+
+#include
+#include
+
+void deformable_im2col(const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor data_col);
+
+void deformable_col2im(const at::Tensor data_col, const at::Tensor data_offset,
+ const int channels, const int height, const int width,
+ const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im);
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const int channels, const int height,
+ const int width, const int ksize_h, const int ksize_w, const int pad_h,
+ const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor grad_offset);
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor data_col);
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset,
+ const at::Tensor data_mask, const int batch_size, const int channels,
+ const int height_im, const int width_im, const int height_col,
+ const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int deformable_group,
+ at::Tensor grad_im);
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im,
+ const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im,
+ const int width_im, const int height_col, const int width_col,
+ const int kernel_h, const int kenerl_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w, const int dilation_h,
+ const int dilation_w, const int deformable_group, at::Tensor grad_offset,
+ at::Tensor grad_mask);
+
+void shape_check(at::Tensor input, at::Tensor offset, at::Tensor *gradOutput,
+ at::Tensor weight, int kH, int kW, int dH, int dW, int padH,
+ int padW, int dilationH, int dilationW, int group,
+ int deformable_group) {
+ AT_CHECK(weight.ndimension() == 4,
+ "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, "
+ "but got: %s",
+ weight.ndimension());
+
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+
+ AT_CHECK(kW > 0 && kH > 0,
+ "kernel size should be greater than zero, but got kH: %d kW: %d", kH,
+ kW);
+
+ AT_CHECK((weight.size(2) == kH && weight.size(3) == kW),
+ "kernel size should be consistent with weight, ",
+ "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", kH,
+ kW, weight.size(2), weight.size(3));
+
+ AT_CHECK(dW > 0 && dH > 0,
+ "stride should be greater than zero, but got dH: %d dW: %d", dH, dW);
+
+ AT_CHECK(
+ dilationW > 0 && dilationH > 0,
+ "dilation should be greater than 0, but got dilationH: %d dilationW: %d",
+ dilationH, dilationW);
+
+ int ndim = input.ndimension();
+ int dimf = 0;
+ int dimh = 1;
+ int dimw = 2;
+
+ if (ndim == 4) {
+ dimf++;
+ dimh++;
+ dimw++;
+ }
+
+ AT_CHECK(ndim == 3 || ndim == 4, "3D or 4D input tensor expected but got: %s",
+ ndim);
+
+ long nInputPlane = weight.size(1) * group;
+ long inputHeight = input.size(dimh);
+ long inputWidth = input.size(dimw);
+ long nOutputPlane = weight.size(0);
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+
+ AT_CHECK(nInputPlane % deformable_group == 0,
+ "input channels must divide deformable group size");
+
+ if (outputWidth < 1 || outputHeight < 1)
+ AT_ERROR(
+ "Given input size: (%ld x %ld x %ld). "
+ "Calculated output size: (%ld x %ld x %ld). Output size is too small",
+ nInputPlane, inputHeight, inputWidth, nOutputPlane, outputHeight,
+ outputWidth);
+
+ AT_CHECK(input.size(1) == nInputPlane,
+ "invalid number of input planes, expected: %d, but got: %d",
+ nInputPlane, input.size(1));
+
+ AT_CHECK((inputHeight >= kH && inputWidth >= kW),
+ "input image is smaller than kernel");
+
+ AT_CHECK((offset.size(2) == outputHeight && offset.size(3) == outputWidth),
+ "invalid spatial size of offset, expected height: %d width: %d, but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, offset.size(2), offset.size(3));
+
+ AT_CHECK((offset.size(1) == deformable_group * 2 * kH * kW),
+ "invalid number of channels of offset");
+
+ if (gradOutput != NULL) {
+ AT_CHECK(gradOutput->size(dimf) == nOutputPlane,
+ "invalid number of gradOutput planes, expected: %d, but got: %d",
+ nOutputPlane, gradOutput->size(dimf));
+
+ AT_CHECK((gradOutput->size(dimh) == outputHeight &&
+ gradOutput->size(dimw) == outputWidth),
+ "invalid size of gradOutput, expected height: %d width: %d , but "
+ "got height: %d width: %d",
+ outputHeight, outputWidth, gradOutput->size(dimh),
+ gradOutput->size(dimw));
+ }
+}
+
+int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight,
+ at::Tensor offset, at::Tensor output,
+ at::Tensor columns, at::Tensor ones, int kW,
+ int kH, int dW, int dH, int padW, int padH,
+ int dilationW, int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ // todo: resize columns to include im2col: done
+ // todo: add im2col_step as input
+ // todo: add new output buffer and transpose it to output (or directly
+ // transpose output) todo: possibly change data indexing because of
+ // parallel_imgs
+
+ shape_check(input, offset, NULL, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input.unsqueeze_(0);
+ offset.unsqueeze_(0);
+ }
+
+ // todo: assert batchsize dividable by im2col_step
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ output = output.view({batchSize / im2col_step, im2col_step, nOutputPlane,
+ outputHeight, outputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < outputHeight * outputWidth) {
+ ones = at::ones({outputHeight, outputWidth}, input.options());
+ }
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ at::Tensor output_buffer =
+ at::zeros({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth},
+ output.options());
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), group, output_buffer.size(1) / group,
+ output_buffer.size(2), output_buffer.size(3)});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ output_buffer[elt][g] = output_buffer[elt][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output_buffer[elt][g]);
+ }
+ }
+
+ output_buffer = output_buffer.view(
+ {output_buffer.size(0), output_buffer.size(1) * output_buffer.size(2),
+ output_buffer.size(3), output_buffer.size(4)});
+
+ output_buffer = output_buffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step, outputHeight, outputWidth});
+ output_buffer.transpose_(1, 2);
+ output.copy_(output_buffer);
+ output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ output = output.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset,
+ at::Tensor gradOutput, at::Tensor gradInput,
+ at::Tensor gradOffset, at::Tensor weight,
+ at::Tensor columns, int kW, int kH, int dW,
+ int dH, int padW, int padH, int dilationW,
+ int dilationH, int group,
+ int deformable_group, int im2col_step) {
+ shape_check(input, offset, &gradOutput, weight, kH, kW, dH, dW, padH, padW,
+ dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+ weight = weight.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view({1, input.size(0), input.size(1), input.size(2)});
+ offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)});
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = weight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset");
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ // change order of grad output
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ gradInput = gradInput.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ gradOffset = gradOffset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight,
+ outputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ // divide into groups
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), group, gradOutput.size(1) / group,
+ gradOutput.size(2), gradOutput.size(3), gradOutput.size(4)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g] = columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ gradOutput[elt][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradOutput = gradOutput.view(
+ {gradOutput.size(0), gradOutput.size(1) * gradOutput.size(2),
+ gradOutput.size(3), gradOutput.size(4), gradOutput.size(5)});
+
+ deformable_col2im_coord(columns, input[elt], offset[elt], nInputPlane,
+ inputHeight, inputWidth, kH, kW, padH, padW, dH, dW,
+ dilationH, dilationW, im2col_step, deformable_group,
+ gradOffset[elt]);
+
+ deformable_col2im(columns, offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, gradInput[elt]);
+ }
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ gradOffset = gradOffset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth});
+ offset = offset.view({offset.size(1), offset.size(2), offset.size(3)});
+ gradOffset =
+ gradOffset.view({offset.size(1), offset.size(2), offset.size(3)});
+ }
+
+ return 1;
+}
+
+int deform_conv_backward_parameters_cuda(
+ at::Tensor input, at::Tensor offset, at::Tensor gradOutput,
+ at::Tensor gradWeight, // at::Tensor gradBias,
+ at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH,
+ int padW, int padH, int dilationW, int dilationH, int group,
+ int deformable_group, float scale, int im2col_step) {
+ // todo: transpose and reshape outGrad
+ // todo: reshape columns
+ // todo: add im2col_step as input
+
+ shape_check(input, offset, &gradOutput, gradWeight, kH, kW, dH, dW, padH,
+ padW, dilationH, dilationW, group, deformable_group);
+ at::DeviceGuard guard(input.device());
+
+ input = input.contiguous();
+ offset = offset.contiguous();
+ gradOutput = gradOutput.contiguous();
+
+ int batch = 1;
+
+ if (input.ndimension() == 3) {
+ // Force batch
+ batch = 0;
+ input = input.view(
+ at::IntList({1, input.size(0), input.size(1), input.size(2)}));
+ gradOutput = gradOutput.view(
+ {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)});
+ }
+
+ long batchSize = input.size(0);
+ long nInputPlane = input.size(1);
+ long inputHeight = input.size(2);
+ long inputWidth = input.size(3);
+
+ long nOutputPlane = gradWeight.size(0);
+
+ long outputWidth =
+ (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1;
+ long outputHeight =
+ (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1;
+
+ AT_CHECK((offset.size(0) == batchSize), "invalid batch size of offset");
+
+ columns = at::zeros(
+ {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth},
+ input.options());
+
+ gradOutput = gradOutput.view({batchSize / im2col_step, im2col_step,
+ nOutputPlane, outputHeight, outputWidth});
+ gradOutput.transpose_(1, 2);
+
+ at::Tensor gradOutputBuffer = at::zeros_like(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane, im2col_step,
+ outputHeight, outputWidth});
+ gradOutputBuffer.copy_(gradOutput);
+ gradOutputBuffer =
+ gradOutputBuffer.view({batchSize / im2col_step, nOutputPlane,
+ im2col_step * outputHeight, outputWidth});
+
+ gradOutput.transpose_(1, 2);
+ gradOutput =
+ gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth});
+
+ input = input.view({batchSize / im2col_step, im2col_step, nInputPlane,
+ inputHeight, inputWidth});
+ offset =
+ offset.view({batchSize / im2col_step, im2col_step,
+ deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ for (int elt = 0; elt < batchSize / im2col_step; elt++) {
+ deformable_im2col(input[elt], offset[elt], nInputPlane, inputHeight,
+ inputWidth, kH, kW, padH, padW, dH, dW, dilationH,
+ dilationW, im2col_step, deformable_group, columns);
+
+ // divide into group
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0), group, gradOutputBuffer.size(1) / group,
+ gradOutputBuffer.size(2), gradOutputBuffer.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ gradWeight =
+ gradWeight.view({group, gradWeight.size(0) / group, gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ gradWeight[g] = gradWeight[g]
+ .flatten(1)
+ .addmm_(gradOutputBuffer[elt][g].flatten(1),
+ columns[g].transpose(1, 0), 1.0, scale)
+ .view_as(gradWeight[g]);
+ }
+ gradOutputBuffer = gradOutputBuffer.view(
+ {gradOutputBuffer.size(0),
+ gradOutputBuffer.size(1) * gradOutputBuffer.size(2),
+ gradOutputBuffer.size(3), gradOutputBuffer.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ gradWeight = gradWeight.view({gradWeight.size(0) * gradWeight.size(1),
+ gradWeight.size(2), gradWeight.size(3),
+ gradWeight.size(4)});
+ }
+
+ input = input.view({batchSize, nInputPlane, inputHeight, inputWidth});
+ offset = offset.view(
+ {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth});
+
+ if (batch == 0) {
+ gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth});
+ input = input.view({nInputPlane, inputHeight, inputWidth});
+ }
+
+ return 1;
+}
+
+void modulated_deform_conv_cuda_forward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns,
+ int kernel_h, int kernel_w, const int stride_h, const int stride_w,
+ const int pad_h, const int pad_w, const int dilation_h,
+ const int dilation_w, const int group, const int deformable_group,
+ const bool with_bias) {
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_out = weight.size(0);
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ // resize output
+ output = output.view({batch, channels_out, height_out, width_out}).zero_();
+ // resize temporary columns
+ columns =
+ at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out},
+ input.options());
+
+ output = output.view({output.size(0), group, output.size(1) / group,
+ output.size(2), output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ // divide into group
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+
+ for (int g = 0; g < group; g++) {
+ output[b][g] = output[b][g]
+ .flatten(1)
+ .addmm_(weight[g].flatten(1), columns[g])
+ .view_as(output[b][g]);
+ }
+
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ }
+
+ output = output.view({output.size(0), output.size(1) * output.size(2),
+ output.size(3), output.size(4)});
+
+ if (with_bias) {
+ output += bias.view({1, bias.size(0), 1, 1});
+ }
+}
+
+void modulated_deform_conv_cuda_backward(
+ at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones,
+ at::Tensor offset, at::Tensor mask, at::Tensor columns,
+ at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias,
+ at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output,
+ int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h,
+ int pad_w, int dilation_h, int dilation_w, int group, int deformable_group,
+ const bool with_bias) {
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ AT_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+
+ const int channels_kernel = weight.size(1);
+ const int kernel_h_ = weight.size(2);
+ const int kernel_w_ = weight.size(3);
+ if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+ AT_ERROR("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
+ kernel_h_, kernel_w, kernel_h_, kernel_w_);
+ if (channels != channels_kernel * group)
+ AT_ERROR("Input shape and kernel channels wont match: (%d vs %d).",
+ channels, channels_kernel * group);
+
+ const int height_out =
+ (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+ const int width_out =
+ (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+ if (ones.ndimension() != 2 ||
+ ones.size(0) * ones.size(1) < height_out * width_out) {
+ // Resize plane and fill with ones...
+ ones = at::ones({height_out, width_out}, input.options());
+ }
+
+ grad_input = grad_input.view({batch, channels, height, width});
+ columns = at::zeros({channels * kernel_h * kernel_w, height_out * width_out},
+ input.options());
+
+ grad_output =
+ grad_output.view({grad_output.size(0), group, grad_output.size(1) / group,
+ grad_output.size(2), grad_output.size(3)});
+
+ for (int b = 0; b < batch; b++) {
+ // divide int group
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ weight = weight.view({group, weight.size(0) / group, weight.size(1),
+ weight.size(2), weight.size(3)});
+
+ for (int g = 0; g < group; g++) {
+ columns[g].addmm_(weight[g].flatten(1).transpose(0, 1),
+ grad_output[b][g].flatten(1), 0.0f, 1.0f);
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ weight = weight.view({weight.size(0) * weight.size(1), weight.size(2),
+ weight.size(3), weight.size(4)});
+
+ // gradient w.r.t. input coordinate data
+ modulated_deformable_col2im_coord_cuda(
+ columns, input[b], offset[b], mask[b], 1, channels, height, width,
+ height_out, width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h,
+ stride_w, dilation_h, dilation_w, deformable_group, grad_offset[b],
+ grad_mask[b]);
+ // gradient w.r.t. input data
+ modulated_deformable_col2im_cuda(
+ columns, offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, grad_input[b]);
+
+ // gradient w.r.t. weight, dWeight should accumulate across the batch and
+ // group
+ modulated_deformable_im2col_cuda(
+ input[b], offset[b], mask[b], 1, channels, height, width, height_out,
+ width_out, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, deformable_group, columns);
+
+ columns = columns.view({group, columns.size(0) / group, columns.size(1)});
+ grad_weight = grad_weight.view({group, grad_weight.size(0) / group,
+ grad_weight.size(1), grad_weight.size(2),
+ grad_weight.size(3)});
+ if (with_bias)
+ grad_bias = grad_bias.view({group, grad_bias.size(0) / group});
+
+ for (int g = 0; g < group; g++) {
+ grad_weight[g] =
+ grad_weight[g]
+ .flatten(1)
+ .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1))
+ .view_as(grad_weight[g]);
+ if (with_bias) {
+ grad_bias[g] =
+ grad_bias[g]
+ .view({-1, 1})
+ .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1}))
+ .view(-1);
+ }
+ }
+
+ columns =
+ columns.view({columns.size(0) * columns.size(1), columns.size(2)});
+ grad_weight = grad_weight.view({grad_weight.size(0) * grad_weight.size(1),
+ grad_weight.size(2), grad_weight.size(3),
+ grad_weight.size(4)});
+ if (with_bias)
+ grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)});
+ }
+ grad_output = grad_output.view({grad_output.size(0) * grad_output.size(1),
+ grad_output.size(2), grad_output.size(3),
+ grad_output.size(4)});
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_conv_forward_cuda", &deform_conv_forward_cuda,
+ "deform forward (CUDA)");
+ m.def("deform_conv_backward_input_cuda", &deform_conv_backward_input_cuda,
+ "deform_conv_backward_input (CUDA)");
+ m.def("deform_conv_backward_parameters_cuda",
+ &deform_conv_backward_parameters_cuda,
+ "deform_conv_backward_parameters (CUDA)");
+ m.def("modulated_deform_conv_cuda_forward",
+ &modulated_deform_conv_cuda_forward,
+ "modulated deform conv forward (CUDA)");
+ m.def("modulated_deform_conv_cuda_backward",
+ &modulated_deform_conv_cuda_backward,
+ "modulated deform conv backward (CUDA)");
+}
diff --git a/det3d/ops/dcn/src/deform_conv_cuda_kernel.cu b/det3d/ops/dcn/src/deform_conv_cuda_kernel.cu
new file mode 100644
index 0000000..e7a26f2
--- /dev/null
+++ b/det3d/ops/dcn/src/deform_conv_cuda_kernel.cu
@@ -0,0 +1,867 @@
+/*!
+ ******************* BEGIN Caffe Copyright Notice and Disclaimer ****************
+ *
+ * COPYRIGHT
+ *
+ * All contributions by the University of California:
+ * Copyright (c) 2014-2017 The Regents of the University of California (Regents)
+ * All rights reserved.
+ *
+ * All other contributions:
+ * Copyright (c) 2014-2017, the respective contributors
+ * All rights reserved.
+ *
+ * Caffe uses a shared copyright model: each contributor holds copyright over
+ * their contributions to Caffe. The project versioning records all such
+ * contribution and copyright details. If a contributor wants to further mark
+ * their specific copyright on a particular contribution, they should indicate
+ * their copyright solely in the commit message of the change when it is
+ * committed.
+ *
+ * LICENSE
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ * CONTRIBUTION AGREEMENT
+ *
+ * By contributing to the BVLC/caffe repository through pull-request, comment,
+ * or otherwise, the contributor releases their content to the
+ * license and copyright terms herein.
+ *
+ ***************** END Caffe Copyright Notice and Disclaimer ********************
+ *
+ * Copyright (c) 2018 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file modulated_deformable_im2col.cuh
+ * \brief Function definitions of converting an image to
+ * column matrix based on kernel, padding, dilation, and offset.
+ * These functions are mainly used in deformable convolution operators.
+ * \ref: https://arxiv.org/abs/1703.06211
+ * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng
+ */
+
+// modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+const int kMaxGridNum = 65535;
+
+inline int GET_BLOCKS(const int N)
+{
+ return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
+}
+
+template
+__device__ scalar_t deformable_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template
+__device__ scalar_t get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template
+__device__ scalar_t get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template
+__global__ void deformable_im2col_gpu_kernel(const int n, const scalar_t *data_im, const scalar_t *data_offset,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t val = static_cast(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const scalar_t map_h = i * dilation_h + offset_h;
+ //const scalar_t map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = deformable_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val;
+ data_col_ptr += batch_size * height_col * width_col;
+ }
+ }
+ }
+}
+
+void deformable_im2col(
+ const at::Tensor data_im, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h, const int ksize_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w, const int parallel_imgs,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ // todo: check parallel_imgs is correctly passed in
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ scalar_t *data_col_ = data_col.data();
+
+ deformable_im2col_gpu_kernel<<>>(
+ num_kernels, data_im_, data_offset_, height, width, ksize_h, ksize_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+ channel_per_deformable_group, parallel_imgs, channels, deformable_group,
+ height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_im2col: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template
+__global__ void deformable_col2im_gpu_kernel(
+ const int n, const scalar_t *data_col, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) *
+ 2 * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index];
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+void deformable_col2im(
+ const at::Tensor data_col, const at::Tensor data_offset, const int channels,
+ const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group,
+ at::Tensor grad_im)
+{
+
+ // todo: make sure parallel_imgs is passed in correctly
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs;
+ int channel_per_deformable_group = channels / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ scalar_t *grad_im_ = grad_im.data();
+
+ deformable_col2im_gpu_kernel<<>>(
+ num_kernels, data_col_, data_offset_, channels, height, width, ksize_h,
+ ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in deformable_col2im: %s\n", cudaGetErrorString(err));
+ }
+}
+
+template
+__global__ void deformable_col2im_coord_gpu_kernel(const int n, const scalar_t *data_col,
+ const scalar_t *data_im, const scalar_t *data_offset,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col, scalar_t *grad_offset)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group *
+ batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) *
+ channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 *
+ kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ const scalar_t weight = get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos];
+ cnt += 1;
+ }
+
+ grad_offset[index] = val;
+ }
+}
+
+void deformable_col2im_coord(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset,
+ const int channels, const int height, const int width, const int ksize_h,
+ const int ksize_w, const int pad_h, const int pad_w, const int stride_h,
+ const int stride_w, const int dilation_h, const int dilation_w,
+ const int parallel_imgs, const int deformable_group, at::Tensor grad_offset)
+{
+
+ int height_col = (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1;
+ int width_col = (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1;
+ int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * deformable_group * parallel_imgs;
+ int channel_per_deformable_group = channels * ksize_h * ksize_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data();
+ const scalar_t *data_im_ = data_im.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ scalar_t *grad_offset_ = grad_offset.data();
+
+ deformable_col2im_coord_gpu_kernel<<>>(
+ num_kernels, data_col_, data_im_, data_offset_, channels, height, width,
+ ksize_h, ksize_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ parallel_imgs, 2 * ksize_h * ksize_w * deformable_group, deformable_group,
+ height_col, width_col, grad_offset_);
+ }));
+}
+
+template
+__device__ scalar_t dmcn_im2col_bilinear(const scalar_t *bottom_data, const int data_width,
+ const int height, const int width, scalar_t h, scalar_t w)
+{
+ int h_low = floor(h);
+ int w_low = floor(w);
+ int h_high = h_low + 1;
+ int w_high = w_low + 1;
+
+ scalar_t lh = h - h_low;
+ scalar_t lw = w - w_low;
+ scalar_t hh = 1 - lh, hw = 1 - lw;
+
+ scalar_t v1 = 0;
+ if (h_low >= 0 && w_low >= 0)
+ v1 = bottom_data[h_low * data_width + w_low];
+ scalar_t v2 = 0;
+ if (h_low >= 0 && w_high <= width - 1)
+ v2 = bottom_data[h_low * data_width + w_high];
+ scalar_t v3 = 0;
+ if (h_high <= height - 1 && w_low >= 0)
+ v3 = bottom_data[h_high * data_width + w_low];
+ scalar_t v4 = 0;
+ if (h_high <= height - 1 && w_high <= width - 1)
+ v4 = bottom_data[h_high * data_width + w_high];
+
+ scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+ scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+ return val;
+}
+
+template
+__device__ scalar_t dmcn_get_gradient_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int h, const int w, const int height, const int width)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+ if (h == argmax_h_low && w == argmax_w_low)
+ weight = (h + 1 - argmax_h) * (w + 1 - argmax_w);
+ if (h == argmax_h_low && w == argmax_w_high)
+ weight = (h + 1 - argmax_h) * (argmax_w + 1 - w);
+ if (h == argmax_h_high && w == argmax_w_low)
+ weight = (argmax_h + 1 - h) * (w + 1 - argmax_w);
+ if (h == argmax_h_high && w == argmax_w_high)
+ weight = (argmax_h + 1 - h) * (argmax_w + 1 - w);
+ return weight;
+}
+
+template
+__device__ scalar_t dmcn_get_coordinate_weight(scalar_t argmax_h, scalar_t argmax_w,
+ const int height, const int width, const scalar_t *im_data,
+ const int data_width, const int bp_dir)
+{
+ if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width)
+ {
+ //empty
+ return 0;
+ }
+
+ int argmax_h_low = floor(argmax_h);
+ int argmax_w_low = floor(argmax_w);
+ int argmax_h_high = argmax_h_low + 1;
+ int argmax_w_high = argmax_w_low + 1;
+
+ scalar_t weight = 0;
+
+ if (bp_dir == 0)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+ else if (bp_dir == 1)
+ {
+ if (argmax_h_low >= 0 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low];
+ if (argmax_h_low >= 0 && argmax_w_high <= width - 1)
+ weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high];
+ if (argmax_h_high <= height - 1 && argmax_w_low >= 0)
+ weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low];
+ if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1)
+ weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high];
+ }
+
+ return weight;
+}
+
+template
+__global__ void modulated_deformable_im2col_gpu_kernel(const int n,
+ const scalar_t *data_im, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int height, const int width, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int num_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *data_col)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ // index index of output matrix
+ const int w_col = index % width_col;
+ const int h_col = (index / width_col) % height_col;
+ const int b_col = (index / width_col / height_col) % batch_size;
+ const int c_im = (index / width_col / height_col) / batch_size;
+ const int c_col = c_im * kernel_h * kernel_w;
+
+ // compute deformable group index
+ const int deformable_group_index = c_im / channel_per_deformable_group;
+
+ const int h_in = h_col * stride_h - pad_h;
+ const int w_in = w_col * stride_w - pad_w;
+
+ scalar_t *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col;
+ //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in;
+ const scalar_t *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+
+ const scalar_t *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ for (int i = 0; i < kernel_h; ++i)
+ {
+ for (int j = 0; j < kernel_w; ++j)
+ {
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t val = static_cast(0);
+ const scalar_t h_im = h_in + i * dilation_h + offset_h;
+ const scalar_t w_im = w_in + j * dilation_w + offset_w;
+ //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) {
+ if (h_im > -1 && w_im > -1 && h_im < height && w_im < width)
+ {
+ //const float map_h = i * dilation_h + offset_h;
+ //const float map_w = j * dilation_w + offset_w;
+ //const int cur_height = height - h_in;
+ //const int cur_width = width - w_in;
+ //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w);
+ val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im);
+ }
+ *data_col_ptr = val * mask;
+ data_col_ptr += batch_size * height_col * width_col;
+ //data_col_ptr += height_col * width_col;
+ }
+ }
+ }
+}
+
+template
+__global__ void modulated_deformable_col2im_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_im)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ const int j = (index / width_col / height_col / batch_size) % kernel_w;
+ const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / channel_per_deformable_group;
+
+ int w_out = index % width_col;
+ int h_out = (index / width_col) % height_col;
+ int b = (index / width_col / height_col) % batch_size;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+ const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out;
+ const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out;
+ const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out;
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h;
+ const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w;
+
+ const scalar_t cur_top_grad = data_col[index] * mask;
+ const int cur_h = (int)cur_inv_h_data;
+ const int cur_w = (int)cur_inv_w_data;
+ for (int dy = -2; dy <= 2; dy++)
+ {
+ for (int dx = -2; dx <= 2; dx++)
+ {
+ if (cur_h + dy >= 0 && cur_h + dy < height &&
+ cur_w + dx >= 0 && cur_w + dx < width &&
+ abs(cur_inv_h_data - (cur_h + dy)) < 1 &&
+ abs(cur_inv_w_data - (cur_w + dx)) < 1)
+ {
+ int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx;
+ scalar_t weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width);
+ atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad);
+ }
+ }
+ }
+ }
+}
+
+template
+__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n,
+ const scalar_t *data_col, const scalar_t *data_im,
+ const scalar_t *data_offset, const scalar_t *data_mask,
+ const int channels, const int height, const int width,
+ const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w,
+ const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int channel_per_deformable_group,
+ const int batch_size, const int offset_channels, const int deformable_group,
+ const int height_col, const int width_col,
+ scalar_t *grad_offset, scalar_t *grad_mask)
+{
+ CUDA_KERNEL_LOOP(index, n)
+ {
+ scalar_t val = 0, mval = 0;
+ int w = index % width_col;
+ int h = (index / width_col) % height_col;
+ int c = (index / width_col / height_col) % offset_channels;
+ int b = (index / width_col / height_col) / offset_channels;
+ // compute the start and end of the output
+
+ const int deformable_group_index = c / (2 * kernel_h * kernel_w);
+ const int col_step = kernel_h * kernel_w;
+ int cnt = 0;
+ const scalar_t *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col;
+ const scalar_t *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width;
+ const scalar_t *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col;
+ const scalar_t *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col;
+
+ const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w;
+
+ for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step)
+ {
+ const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w;
+ const int bp_dir = offset_c % 2;
+
+ int j = (col_pos / width_col / height_col / batch_size) % kernel_w;
+ int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h;
+ int w_out = col_pos % width_col;
+ int h_out = (col_pos / width_col) % height_col;
+ int w_in = w_out * stride_w - pad_w;
+ int h_in = h_out * stride_h - pad_h;
+ const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out);
+ const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out);
+ const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out);
+ const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr];
+ const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr];
+ const scalar_t mask = data_mask_ptr[data_mask_hw_ptr];
+ scalar_t inv_h = h_in + i * dilation_h + offset_h;
+ scalar_t inv_w = w_in + j * dilation_w + offset_w;
+ if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width)
+ {
+ inv_h = inv_w = -2;
+ }
+ else
+ {
+ mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w);
+ }
+ const scalar_t weight = dmcn_get_coordinate_weight(
+ inv_h, inv_w,
+ height, width, data_im_ptr + cnt * height * width, width, bp_dir);
+ val += weight * data_col_ptr[col_pos] * mask;
+ cnt += 1;
+ }
+ // KERNEL_ASSIGN(grad_offset[index], offset_req, val);
+ grad_offset[index] = val;
+ if (offset_c % 2 == 0)
+ // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval);
+ grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval;
+ }
+}
+
+void modulated_deformable_im2col_cuda(
+ const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor data_col)
+{
+ // num_axes should be smaller than block size
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] {
+ const scalar_t *data_im_ = data_im.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ const scalar_t *data_mask_ = data_mask.data();
+ scalar_t *data_col_ = data_col.data();
+
+ modulated_deformable_im2col_gpu_kernel<<>>(
+ num_kernels, data_im_, data_offset_, data_mask_, height_im, width_im, kernel_h, kenerl_w,
+ pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, channels, deformable_group, height_col, width_col, data_col_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_cuda(
+ const at::Tensor data_col, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group, at::Tensor grad_im)
+{
+
+ const int channel_per_deformable_group = channels / deformable_group;
+ const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ const scalar_t *data_mask_ = data_mask.data();
+ scalar_t *grad_im_ = grad_im.data();
+
+ modulated_deformable_col2im_gpu_kernel<<>>(
+ num_kernels, data_col_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, deformable_group, height_col, width_col, grad_im_);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void modulated_deformable_col2im_coord_cuda(
+ const at::Tensor data_col, const at::Tensor data_im, const at::Tensor data_offset, const at::Tensor data_mask,
+ const int batch_size, const int channels, const int height_im, const int width_im,
+ const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+ const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+ const int dilation_h, const int dilation_w,
+ const int deformable_group,
+ at::Tensor grad_offset, at::Tensor grad_mask)
+{
+ const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+ const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] {
+ const scalar_t *data_col_ = data_col.data();
+ const scalar_t *data_im_ = data_im.data();
+ const scalar_t *data_offset_ = data_offset.data();
+ const scalar_t *data_mask_ = data_mask.data();
+ scalar_t *grad_offset_ = grad_offset.data();
+ scalar_t *grad_mask_ = grad_mask.data();
+
+ modulated_deformable_col2im_coord_gpu_kernel<<>>(
+ num_kernels, data_col_, data_im_, data_offset_, data_mask_, channels, height_im, width_im,
+ kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+ dilation_h, dilation_w, channel_per_deformable_group,
+ batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+ grad_offset_, grad_mask_);
+ }));
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/det3d/ops/dcn/src/deform_pool_cuda.cpp b/det3d/ops/dcn/src/deform_pool_cuda.cpp
new file mode 100644
index 0000000..9e0e3ff
--- /dev/null
+++ b/det3d/ops/dcn/src/deform_pool_cuda.cpp
@@ -0,0 +1,90 @@
+// modify from
+// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/modulated_dcn_cuda.c
+
+// based on
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+#include
+#include
+
+#include
+#include
+
+void DeformablePSROIPoolForward(
+ const at::Tensor data, const at::Tensor bbox, const at::Tensor trans,
+ at::Tensor out, at::Tensor top_count, const int batch, const int channels,
+ const int height, const int width, const int num_bbox,
+ const int channels_trans, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std);
+
+void DeformablePSROIPoolBackwardAcc(
+ const at::Tensor out_grad, const at::Tensor data, const at::Tensor bbox,
+ const at::Tensor trans, const at::Tensor top_count, at::Tensor in_grad,
+ at::Tensor trans_grad, const int batch, const int channels,
+ const int height, const int width, const int num_bbox,
+ const int channels_trans, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std);
+
+void deform_psroi_pooling_cuda_forward(
+ at::Tensor input, at::Tensor bbox, at::Tensor trans, at::Tensor out,
+ at::Tensor top_count, const int no_trans, const float spatial_scale,
+ const int output_dim, const int group_size, const int pooled_size,
+ const int part_size, const int sample_per_part, const float trans_std) {
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+
+ const int num_bbox = bbox.size(0);
+ if (num_bbox != out.size(0))
+ AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+ out.size(0), num_bbox);
+
+ DeformablePSROIPoolForward(
+ input, bbox, trans, out, top_count, batch, channels, height, width,
+ num_bbox, channels_trans, no_trans, spatial_scale, output_dim, group_size,
+ pooled_size, part_size, sample_per_part, trans_std);
+}
+
+void deform_psroi_pooling_cuda_backward(
+ at::Tensor out_grad, at::Tensor input, at::Tensor bbox, at::Tensor trans,
+ at::Tensor top_count, at::Tensor input_grad, at::Tensor trans_grad,
+ const int no_trans, const float spatial_scale, const int output_dim,
+ const int group_size, const int pooled_size, const int part_size,
+ const int sample_per_part, const float trans_std) {
+ AT_CHECK(out_grad.is_contiguous(), "out_grad tensor has to be contiguous");
+ AT_CHECK(input.is_contiguous(), "input tensor has to be contiguous");
+ at::DeviceGuard guard(input.device());
+
+ const int batch = input.size(0);
+ const int channels = input.size(1);
+ const int height = input.size(2);
+ const int width = input.size(3);
+ const int channels_trans = no_trans ? 2 : trans.size(1);
+
+ const int num_bbox = bbox.size(0);
+ if (num_bbox != out_grad.size(0))
+ AT_ERROR("Output shape and bbox number wont match: (%d vs %d).",
+ out_grad.size(0), num_bbox);
+
+ DeformablePSROIPoolBackwardAcc(
+ out_grad, input, bbox, trans, top_count, input_grad, trans_grad, batch,
+ channels, height, width, num_bbox, channels_trans, no_trans,
+ spatial_scale, output_dim, group_size, pooled_size, part_size,
+ sample_per_part, trans_std);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("deform_psroi_pooling_cuda_forward", &deform_psroi_pooling_cuda_forward,
+ "deform psroi pooling forward(CUDA)");
+ m.def("deform_psroi_pooling_cuda_backward",
+ &deform_psroi_pooling_cuda_backward,
+ "deform psroi pooling backward(CUDA)");
+}
diff --git a/det3d/ops/dcn/src/deform_pool_cuda_kernel.cu b/det3d/ops/dcn/src/deform_pool_cuda_kernel.cu
new file mode 100644
index 0000000..05b00d4
--- /dev/null
+++ b/det3d/ops/dcn/src/deform_pool_cuda_kernel.cu
@@ -0,0 +1,364 @@
+/*!
+ * Copyright (c) 2017 Microsoft
+ * Licensed under The MIT License [see LICENSE for details]
+ * \file deformable_psroi_pooling.cu
+ * \brief
+ * \author Yi Li, Guodong Zhang, Jifeng Dai
+*/
+/***************** Adapted by Charles Shang *********************/
+// modify from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/cuda/deform_psroi_pooling_cuda.cu
+
+#include
+#include
+#include
+#include
+#include
+
+using namespace at;
+
+#define CUDA_KERNEL_LOOP(i, n) \
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
+ i < (n); \
+ i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N)
+{
+ return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
+}
+
+template
+__device__ scalar_t bilinear_interp(
+ const scalar_t *data,
+ const scalar_t x,
+ const scalar_t y,
+ const int width,
+ const int height)
+{
+ int x1 = floor(x);
+ int x2 = ceil(x);
+ int y1 = floor(y);
+ int y2 = ceil(y);
+ scalar_t dist_x = (scalar_t)(x - x1);
+ scalar_t dist_y = (scalar_t)(y - y1);
+ scalar_t value11 = data[y1 * width + x1];
+ scalar_t value12 = data[y2 * width + x1];
+ scalar_t value21 = data[y1 * width + x2];
+ scalar_t value22 = data[y2 * width + x2];
+ scalar_t value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
+ return value;
+}
+
+template
+__global__ void DeformablePSROIPoolForwardKernel(
+ const int count,
+ const scalar_t *bottom_data,
+ const scalar_t spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const scalar_t *bottom_rois, const scalar_t *bottom_trans,
+ const int no_trans,
+ const scalar_t trans_std,
+ const int sample_per_part,
+ const int output_dim,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class,
+ scalar_t *top_data,
+ scalar_t *top_count)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+ scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+ scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+ scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+ scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+ scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+ scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+ scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+ scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+ int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+ int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+ int class_id = ctop / channels_each_class;
+ scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+ scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+ scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+ wstart += trans_x * roi_width;
+ scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+ hstart += trans_y * roi_height;
+
+ scalar_t sum = 0;
+ int count = 0;
+ int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+ int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ const scalar_t *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ scalar_t w = wstart + iw * sub_bin_size_w;
+ scalar_t h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ scalar_t val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+ sum += val;
+ count++;
+ }
+ }
+ top_data[index] = count == 0 ? (scalar_t)(0) : sum / count;
+ top_count[index] = count;
+ }
+}
+
+template
+__global__ void DeformablePSROIPoolBackwardAccKernel(
+ const int count,
+ const scalar_t *top_diff,
+ const scalar_t *top_count,
+ const int num_rois,
+ const scalar_t spatial_scale,
+ const int channels,
+ const int height, const int width,
+ const int pooled_height, const int pooled_width,
+ const int output_dim,
+ scalar_t *bottom_data_diff, scalar_t *bottom_trans_diff,
+ const scalar_t *bottom_data,
+ const scalar_t *bottom_rois,
+ const scalar_t *bottom_trans,
+ const int no_trans,
+ const scalar_t trans_std,
+ const int sample_per_part,
+ const int group_size,
+ const int part_size,
+ const int num_classes,
+ const int channels_each_class)
+{
+ CUDA_KERNEL_LOOP(index, count)
+ {
+ // The output is in order (n, ctop, ph, pw)
+ int pw = index % pooled_width;
+ int ph = (index / pooled_width) % pooled_height;
+ int ctop = (index / pooled_width / pooled_height) % output_dim;
+ int n = index / pooled_width / pooled_height / output_dim;
+
+ // [start, end) interval for spatial sampling
+ const scalar_t *offset_bottom_rois = bottom_rois + n * 5;
+ int roi_batch_ind = offset_bottom_rois[0];
+ scalar_t roi_start_w = (scalar_t)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+ scalar_t roi_start_h = (scalar_t)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+ scalar_t roi_end_w = (scalar_t)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+ scalar_t roi_end_h = (scalar_t)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+
+ // Force too small ROIs to be 1x1
+ scalar_t roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+ scalar_t roi_height = max(roi_end_h - roi_start_h, 0.1);
+
+ // Compute w and h at bottom
+ scalar_t bin_size_h = roi_height / (scalar_t)(pooled_height);
+ scalar_t bin_size_w = roi_width / (scalar_t)(pooled_width);
+
+ scalar_t sub_bin_size_h = bin_size_h / (scalar_t)(sample_per_part);
+ scalar_t sub_bin_size_w = bin_size_w / (scalar_t)(sample_per_part);
+
+ int part_h = floor((scalar_t)(ph) / pooled_height * part_size);
+ int part_w = floor((scalar_t)(pw) / pooled_width * part_size);
+ int class_id = ctop / channels_each_class;
+ scalar_t trans_x = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+ scalar_t trans_y = no_trans ? (scalar_t)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * (scalar_t)trans_std;
+
+ scalar_t wstart = (scalar_t)(pw)*bin_size_w + roi_start_w;
+ wstart += trans_x * roi_width;
+ scalar_t hstart = (scalar_t)(ph)*bin_size_h + roi_start_h;
+ hstart += trans_y * roi_height;
+
+ if (top_count[index] <= 0)
+ {
+ continue;
+ }
+ scalar_t diff_val = top_diff[index] / top_count[index];
+ const scalar_t *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+ scalar_t *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+ int gw = floor((scalar_t)(pw)*group_size / pooled_width);
+ int gh = floor((scalar_t)(ph)*group_size / pooled_height);
+ gw = min(max(gw, 0), group_size - 1);
+ gh = min(max(gh, 0), group_size - 1);
+
+ for (int ih = 0; ih < sample_per_part; ih++)
+ {
+ for (int iw = 0; iw < sample_per_part; iw++)
+ {
+ scalar_t w = wstart + iw * sub_bin_size_w;
+ scalar_t h = hstart + ih * sub_bin_size_h;
+ // bilinear interpolation
+ if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
+ {
+ continue;
+ }
+ w = min(max(w, 0.), width - 1.);
+ h = min(max(h, 0.), height - 1.);
+ int c = (ctop * group_size + gh) * group_size + gw;
+ // backward on feature
+ int x0 = floor(w);
+ int x1 = ceil(w);
+ int y0 = floor(h);
+ int y1 = ceil(h);
+ scalar_t dist_x = w - x0, dist_y = h - y0;
+ scalar_t q00 = (1 - dist_x) * (1 - dist_y);
+ scalar_t q01 = (1 - dist_x) * dist_y;
+ scalar_t q10 = dist_x * (1 - dist_y);
+ scalar_t q11 = dist_x * dist_y;
+ int bottom_index_base = c * height * width;
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
+ atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
+
+ if (no_trans)
+ {
+ continue;
+ }
+ scalar_t U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+ scalar_t U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+ scalar_t U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+ scalar_t U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+ scalar_t diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+ diff_x *= roi_width;
+ scalar_t diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+ diff_y *= roi_height;
+
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
+ atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y);
+ }
+ }
+ }
+}
+
+void DeformablePSROIPoolForward(const at::Tensor data,
+ const at::Tensor bbox,
+ const at::Tensor trans,
+ at::Tensor out,
+ at::Tensor top_count,
+ const int batch,
+ const int channels,
+ const int height,
+ const int width,
+ const int num_bbox,
+ const int channels_trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ const int pooled_height = pooled_size;
+ const int pooled_width = pooled_size;
+ const int count = num_bbox * output_dim * pooled_height * pooled_width;
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ data.scalar_type(), "deformable_psroi_pool_forward", ([&] {
+ const scalar_t *bottom_data = data.data();
+ const scalar_t *bottom_rois = bbox.data();
+ const scalar_t *bottom_trans = no_trans ? NULL : trans.data();
+ scalar_t *top_data = out.data();
+ scalar_t *top_count_data = top_count.data();
+
+ DeformablePSROIPoolForwardKernel<<>>(
+ count, bottom_data, (scalar_t)spatial_scale, channels, height, width, pooled_height, pooled_width,
+ bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part, output_dim,
+ group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+ }
+}
+
+void DeformablePSROIPoolBackwardAcc(const at::Tensor out_grad,
+ const at::Tensor data,
+ const at::Tensor bbox,
+ const at::Tensor trans,
+ const at::Tensor top_count,
+ at::Tensor in_grad,
+ at::Tensor trans_grad,
+ const int batch,
+ const int channels,
+ const int height,
+ const int width,
+ const int num_bbox,
+ const int channels_trans,
+ const int no_trans,
+ const float spatial_scale,
+ const int output_dim,
+ const int group_size,
+ const int pooled_size,
+ const int part_size,
+ const int sample_per_part,
+ const float trans_std)
+{
+ // LOG(INFO) << "DeformablePSROIPoolBackward";
+ const int num_rois = num_bbox;
+ const int pooled_height = pooled_size;
+ const int pooled_width = pooled_size;
+ const int count = num_bbox * output_dim * pooled_height * pooled_width;
+ const int num_classes = no_trans ? 1 : channels_trans / 2;
+ const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
+
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+ out_grad.scalar_type(), "deformable_psroi_pool_backward_acc", ([&] {
+ const scalar_t *top_diff = out_grad.data();
+ const scalar_t *bottom_data = data.data();
+ const scalar_t *bottom_rois = bbox.data();
+ const scalar_t *bottom_trans = no_trans ? NULL : trans.data();
+ scalar_t *bottom_data_diff = in_grad.data();
+ scalar_t *bottom_trans_diff = no_trans ? NULL : trans_grad.data();
+ const scalar_t *top_count_data = top_count.data();
+
+ DeformablePSROIPoolBackwardAccKernel<<>>(
+ count, top_diff, top_count_data, num_rois, (scalar_t)spatial_scale, channels, height, width,
+ pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
+ bottom_data, bottom_rois, bottom_trans, no_trans, (scalar_t)trans_std, sample_per_part,
+ group_size, part_size, num_classes, channels_each_class);
+ }));
+
+ cudaError_t err = cudaGetLastError();
+ if (err != cudaSuccess)
+ {
+ printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+ }
+}
diff --git a/det3d/ops/iou3d_nms/__init__.py b/det3d/ops/iou3d_nms/__init__.py
new file mode 100644
index 0000000..c267f07
--- /dev/null
+++ b/det3d/ops/iou3d_nms/__init__.py
@@ -0,0 +1 @@
+from det3d.ops.iou3d_nms import iou3d_nms_cuda, iou3d_nms_utils
diff --git a/det3d/ops/iou3d_nms/iou3d_nms_utils.py b/det3d/ops/iou3d_nms/iou3d_nms_utils.py
new file mode 100644
index 0000000..4d71e33
--- /dev/null
+++ b/det3d/ops/iou3d_nms/iou3d_nms_utils.py
@@ -0,0 +1,107 @@
+"""
+3D IoU Calculation and Rotated NMS
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+"""
+import torch
+
+from . import iou3d_nms_cuda
+import numpy as np
+
+
+
+def boxes_iou_bev(boxes_a, boxes_b):
+ """
+ Args:
+ boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
+
+ Returns:
+ ans_iou: (N, M)
+ """
+ assert boxes_a.shape[1] == boxes_b.shape[1] == 7
+ ans_iou = torch.cuda.FloatTensor(torch.Size((boxes_a.shape[0], boxes_b.shape[0]))).zero_()
+
+ iou3d_nms_cuda.boxes_iou_bev_gpu(boxes_a.contiguous(), boxes_b.contiguous(), ans_iou)
+
+ return ans_iou
+
+def to_pcdet(boxes):
+ # transform back to pcdet's coordinate
+ boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]]
+ boxes[:, -1] = -boxes[:, -1] - np.pi/2
+ return boxes
+
+def boxes_iou3d_gpu(boxes_a, boxes_b):
+ """
+ Args:
+ boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading]
+
+ Returns:
+ ans_iou: (N, M)
+ """
+ assert boxes_a.shape[1] == boxes_b.shape[1] == 7
+
+ # transform back to pcdet's coordinate
+ boxes_a = to_pcdet(boxes_a)
+ boxes_b = to_pcdet(boxes_b)
+
+ # height overlap
+ boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5] / 2).view(-1, 1)
+ boxes_a_height_min = (boxes_a[:, 2] - boxes_a[:, 5] / 2).view(-1, 1)
+ boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5] / 2).view(1, -1)
+ boxes_b_height_min = (boxes_b[:, 2] - boxes_b[:, 5] / 2).view(1, -1)
+
+ # bev overlap
+ overlaps_bev = torch.cuda.FloatTensor(torch.Size((boxes_a.shape[0], boxes_b.shape[0]))).zero_() # (N, M)
+ iou3d_nms_cuda.boxes_overlap_bev_gpu(boxes_a.contiguous(), boxes_b.contiguous(), overlaps_bev)
+
+ max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min)
+ min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max)
+ overlaps_h = torch.clamp(min_of_max - max_of_min, min=0)
+
+ # 3d iou
+ overlaps_3d = overlaps_bev * overlaps_h
+
+ vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1)
+ vol_b = (boxes_b[:, 3] * boxes_b[:, 4] * boxes_b[:, 5]).view(1, -1)
+
+ iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6)
+
+ return iou3d
+
+
+def nms_gpu(boxes, scores, thresh, pre_maxsize=None, **kwargs):
+ """
+ :param boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+ :param scores: (N)
+ :param thresh:
+ :return:
+ """
+ assert boxes.shape[1] == 7
+ order = scores.sort(0, descending=True)[1]
+ if pre_maxsize is not None:
+ order = order[:pre_maxsize]
+
+ boxes = boxes[order].contiguous()
+ keep = torch.LongTensor(boxes.size(0))
+ num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh)
+ return order[keep[:num_out].cuda()].contiguous(), None
+
+
+def nms_normal_gpu(boxes, scores, thresh, **kwargs):
+ """
+ :param boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+ :param scores: (N)
+ :param thresh:
+ :return:
+ """
+ assert boxes.shape[1] == 7
+ order = scores.sort(0, descending=True)[1]
+
+ boxes = boxes[order].contiguous()
+
+ keep = torch.LongTensor(boxes.size(0))
+ num_out = iou3d_nms_cuda.nms_normal_gpu(boxes, keep, thresh)
+ return order[keep[:num_out].cuda()].contiguous(), None
\ No newline at end of file
diff --git a/det3d/ops/iou3d_nms/setup.py b/det3d/ops/iou3d_nms/setup.py
new file mode 100644
index 0000000..74b89a8
--- /dev/null
+++ b/det3d/ops/iou3d_nms/setup.py
@@ -0,0 +1,16 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+
+setup(
+ name='iou3d_nms',
+ ext_modules=[
+ CUDAExtension('iou3d_nms_cuda', [
+ 'src/iou3d_cpu.cpp',
+ 'src/iou3d_nms_api.cpp',
+ 'src/iou3d_nms.cpp',
+ 'src/iou3d_nms_kernel.cu',
+ ],
+ extra_compile_args={'cxx': ['-g', '-I /usr/local/cuda/include'],
+ 'nvcc': ['-O2']})
+ ],
+ cmdclass={'build_ext': BuildExtension})
diff --git a/det3d/ops/iou3d_nms/src/iou3d_cpu.cpp b/det3d/ops/iou3d_nms/src/iou3d_cpu.cpp
new file mode 100644
index 0000000..d528ad9
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_cpu.cpp
@@ -0,0 +1,252 @@
+/*
+3D Rotated IoU Calculation (CPU)
+Written by Shaoshuai Shi
+All Rights Reserved 2020.
+*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "iou3d_cpu.h"
+
+#define CHECK_CUDA(x) do { \
+ if (!x.type().is_cuda()) { \
+ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+ exit(-1); \
+ } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+ if (!x.is_contiguous()) { \
+ fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+ exit(-1); \
+ } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+inline float min(float a, float b){
+ return a > b ? b : a;
+}
+
+inline float max(float a, float b){
+ return a > b ? a : b;
+}
+
+const float EPS = 1e-8;
+struct Point {
+ float x, y;
+ __device__ Point() {}
+ __device__ Point(double _x, double _y){
+ x = _x, y = _y;
+ }
+
+ __device__ void set(float _x, float _y){
+ x = _x; y = _y;
+ }
+
+ __device__ Point operator +(const Point &b)const{
+ return Point(x + b.x, y + b.y);
+ }
+
+ __device__ Point operator -(const Point &b)const{
+ return Point(x - b.x, y - b.y);
+ }
+};
+
+inline float cross(const Point &a, const Point &b){
+ return a.x * b.y - a.y * b.x;
+}
+
+inline float cross(const Point &p1, const Point &p2, const Point &p0){
+ return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
+}
+
+inline int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2){
+ int ret = min(p1.x,p2.x) <= max(q1.x,q2.x) &&
+ min(q1.x,q2.x) <= max(p1.x,p2.x) &&
+ min(p1.y,p2.y) <= max(q1.y,q2.y) &&
+ min(q1.y,q2.y) <= max(p1.y,p2.y);
+ return ret;
+}
+
+inline int check_in_box2d(const float *box, const Point &p){
+ //params: (7) [x, y, z, dx, dy, dz, heading]
+ const float MARGIN = 1e-2;
+
+ float center_x = box[0], center_y = box[1];
+ float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]); // rotate the point in the opposite direction of box
+ float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
+ float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
+
+ return (fabs(rot_x) < box[3] / 2 + MARGIN && fabs(rot_y) < box[4] / 2 + MARGIN);
+}
+
+inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans){
+ // fast exclusion
+ if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
+
+ // check cross standing
+ float s1 = cross(q0, p1, p0);
+ float s2 = cross(p1, q1, p0);
+ float s3 = cross(p0, q1, q0);
+ float s4 = cross(q1, p1, q0);
+
+ if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
+
+ // calculate intersection of two lines
+ float s5 = cross(q1, p1, p0);
+ if(fabs(s5 - s1) > EPS){
+ ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
+ ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
+
+ }
+ else{
+ float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
+ float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
+ float D = a0 * b1 - a1 * b0;
+
+ ans.x = (b0 * c1 - b1 * c0) / D;
+ ans.y = (a1 * c0 - a0 * c1) / D;
+ }
+
+ return 1;
+}
+
+inline void rotate_around_center(const Point ¢er, const float angle_cos, const float angle_sin, Point &p){
+ float new_x = (p.x - center.x) * angle_cos + (p.y - center.y) * (-angle_sin) + center.x;
+ float new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
+ p.set(new_x, new_y);
+}
+
+inline int point_cmp(const Point &a, const Point &b, const Point ¢er){
+ return atan2(a.y - center.y, a.x - center.x) > atan2(b.y - center.y, b.x - center.x);
+}
+
+inline float box_overlap(const float *box_a, const float *box_b){
+ // params: box_a (7) [x, y, z, dx, dy, dz, heading]
+ // params: box_b (7) [x, y, z, dx, dy, dz, heading]
+
+// float a_x1 = box_a[0], a_y1 = box_a[1], a_x2 = box_a[2], a_y2 = box_a[3], a_angle = box_a[4];
+// float b_x1 = box_b[0], b_y1 = box_b[1], b_x2 = box_b[2], b_y2 = box_b[3], b_angle = box_b[4];
+ float a_angle = box_a[6], b_angle = box_b[6];
+ float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2, a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
+ float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
+ float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
+ float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
+ float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
+
+ Point center_a(box_a[0], box_a[1]);
+ Point center_b(box_b[0], box_b[1]);
+
+ Point box_a_corners[5];
+ box_a_corners[0].set(a_x1, a_y1);
+ box_a_corners[1].set(a_x2, a_y1);
+ box_a_corners[2].set(a_x2, a_y2);
+ box_a_corners[3].set(a_x1, a_y2);
+
+ Point box_b_corners[5];
+ box_b_corners[0].set(b_x1, b_y1);
+ box_b_corners[1].set(b_x2, b_y1);
+ box_b_corners[2].set(b_x2, b_y2);
+ box_b_corners[3].set(b_x1, b_y2);
+
+ // get oriented corners
+ float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
+ float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
+
+ for (int k = 0; k < 4; k++){
+ rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
+ rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
+ }
+
+ box_a_corners[4] = box_a_corners[0];
+ box_b_corners[4] = box_b_corners[0];
+
+ // get intersection of lines
+ Point cross_points[16];
+ Point poly_center;
+ int cnt = 0, flag = 0;
+
+ poly_center.set(0, 0);
+ for (int i = 0; i < 4; i++){
+ for (int j = 0; j < 4; j++){
+ flag = intersection(box_a_corners[i + 1], box_a_corners[i], box_b_corners[j + 1], box_b_corners[j], cross_points[cnt]);
+ if (flag){
+ poly_center = poly_center + cross_points[cnt];
+ cnt++;
+ }
+ }
+ }
+
+ // check corners
+ for (int k = 0; k < 4; k++){
+ if (check_in_box2d(box_a, box_b_corners[k])){
+ poly_center = poly_center + box_b_corners[k];
+ cross_points[cnt] = box_b_corners[k];
+ cnt++;
+ }
+ if (check_in_box2d(box_b, box_a_corners[k])){
+ poly_center = poly_center + box_a_corners[k];
+ cross_points[cnt] = box_a_corners[k];
+ cnt++;
+ }
+ }
+
+ poly_center.x /= cnt;
+ poly_center.y /= cnt;
+
+ // sort the points of polygon
+ Point temp;
+ for (int j = 0; j < cnt - 1; j++){
+ for (int i = 0; i < cnt - j - 1; i++){
+ if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)){
+ temp = cross_points[i];
+ cross_points[i] = cross_points[i + 1];
+ cross_points[i + 1] = temp;
+ }
+ }
+ }
+
+ // get the overlap areas
+ float area = 0;
+ for (int k = 0; k < cnt - 1; k++){
+ area += cross(cross_points[k] - cross_points[0], cross_points[k + 1] - cross_points[0]);
+ }
+
+ return fabs(area) / 2.0;
+}
+
+inline float iou_bev(const float *box_a, const float *box_b){
+ // params: box_a (7) [x, y, z, dx, dy, dz, heading]
+ // params: box_b (7) [x, y, z, dx, dy, dz, heading]
+ float sa = box_a[3] * box_a[4];
+ float sb = box_b[3] * box_b[4];
+ float s_overlap = box_overlap(box_a, box_b);
+ return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
+}
+
+
+int boxes_iou_bev_cpu(at::Tensor boxes_a_tensor, at::Tensor boxes_b_tensor, at::Tensor ans_iou_tensor){
+ // params boxes_a_tensor: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params boxes_b_tensor: (M, 7) [x, y, z, dx, dy, dz, heading]
+ // params ans_iou_tensor: (N, M)
+
+ CHECK_CONTIGUOUS(boxes_a_tensor);
+ CHECK_CONTIGUOUS(boxes_b_tensor);
+
+ int num_boxes_a = boxes_a_tensor.size(0);
+ int num_boxes_b = boxes_b_tensor.size(0);
+ const float *boxes_a = boxes_a_tensor.data();
+ const float *boxes_b = boxes_b_tensor.data();
+ float *ans_iou = ans_iou_tensor.data();
+
+ for (int i = 0; i < num_boxes_a; i++){
+ for (int j = 0; j < num_boxes_b; j++){
+ ans_iou[i * num_boxes_b + j] = iou_bev(boxes_a + i * 7, boxes_b + j * 7);
+ }
+ }
+ return 1;
+}
diff --git a/det3d/ops/iou3d_nms/src/iou3d_cpu.h b/det3d/ops/iou3d_nms/src/iou3d_cpu.h
new file mode 100644
index 0000000..8835ee7
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_cpu.h
@@ -0,0 +1,11 @@
+#ifndef IOU3D_CPU_H
+#define IOU3D_CPU_H
+
+#include
+#include
+#include
+#include
+
+int boxes_iou_bev_cpu(at::Tensor boxes_a_tensor, at::Tensor boxes_b_tensor, at::Tensor ans_iou_tensor);
+
+#endif
diff --git a/det3d/ops/iou3d_nms/src/iou3d_nms.cpp b/det3d/ops/iou3d_nms/src/iou3d_nms.cpp
new file mode 100644
index 0000000..d41da8a
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_nms.cpp
@@ -0,0 +1,188 @@
+/*
+3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+*/
+
+#include
+#include
+#include
+#include
+#include
+#include "iou3d_nms.h"
+
+#define CHECK_CUDA(x) do { \
+ if (!x.type().is_cuda()) { \
+ fprintf(stderr, "%s must be CUDA tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+ exit(-1); \
+ } \
+} while (0)
+#define CHECK_CONTIGUOUS(x) do { \
+ if (!x.is_contiguous()) { \
+ fprintf(stderr, "%s must be contiguous tensor at %s:%d\n", #x, __FILE__, __LINE__); \
+ exit(-1); \
+ } \
+} while (0)
+#define CHECK_INPUT(x) CHECK_CUDA(x);CHECK_CONTIGUOUS(x)
+
+#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0))
+
+#define CHECK_ERROR(ans) { gpuAssert((ans), __FILE__, __LINE__); }
+inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true)
+{
+ if (code != cudaSuccess)
+ {
+ fprintf(stderr,"GPUassert: %s %s %d\n", cudaGetErrorString(code), file, line);
+ if (abort) exit(code);
+ }
+}
+
+const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
+
+
+void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap);
+void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou);
+void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
+void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh);
+
+
+int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap){
+ // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+ // params ans_overlap: (N, M)
+
+ CHECK_INPUT(boxes_a);
+ CHECK_INPUT(boxes_b);
+ CHECK_INPUT(ans_overlap);
+
+ int num_a = boxes_a.size(0);
+ int num_b = boxes_b.size(0);
+
+ const float * boxes_a_data = boxes_a.data();
+ const float * boxes_b_data = boxes_b.data();
+ float * ans_overlap_data = ans_overlap.data();
+
+ boxesoverlapLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_overlap_data);
+
+ return 1;
+}
+
+int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou){
+ // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+ // params ans_overlap: (N, M)
+ CHECK_INPUT(boxes_a);
+ CHECK_INPUT(boxes_b);
+ CHECK_INPUT(ans_iou);
+
+ int num_a = boxes_a.size(0);
+ int num_b = boxes_b.size(0);
+
+ const float * boxes_a_data = boxes_a.data();
+ const float * boxes_b_data = boxes_b.data();
+ float * ans_iou_data = ans_iou.data();
+
+ boxesioubevLauncher(num_a, boxes_a_data, num_b, boxes_b_data, ans_iou_data);
+
+ return 1;
+}
+
+int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){
+ // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params keep: (N)
+ CHECK_INPUT(boxes);
+ CHECK_CONTIGUOUS(keep);
+
+ int boxes_num = boxes.size(0);
+ const float * boxes_data = boxes.data();
+ long * keep_data = keep.data();
+
+ const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+ unsigned long long *mask_data = NULL;
+ CHECK_ERROR(cudaMalloc((void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long)));
+ nmsLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
+
+ // unsigned long long mask_cpu[boxes_num * col_blocks];
+ // unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks];
+ std::vector mask_cpu(boxes_num * col_blocks);
+
+// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
+ CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
+ cudaMemcpyDeviceToHost));
+
+ cudaFree(mask_data);
+
+ unsigned long long remv_cpu[col_blocks];
+ memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+
+ int num_to_keep = 0;
+
+ for (int i = 0; i < boxes_num; i++){
+ int nblock = i / THREADS_PER_BLOCK_NMS;
+ int inblock = i % THREADS_PER_BLOCK_NMS;
+
+ if (!(remv_cpu[nblock] & (1ULL << inblock))){
+ keep_data[num_to_keep++] = i;
+ unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+ for (int j = nblock; j < col_blocks; j++){
+ remv_cpu[j] |= p[j];
+ }
+ }
+ }
+ if ( cudaSuccess != cudaGetLastError() ) printf( "Error!\n" );
+
+ return num_to_keep;
+}
+
+
+int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh){
+ // params boxes: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params keep: (N)
+
+ CHECK_INPUT(boxes);
+ CHECK_CONTIGUOUS(keep);
+
+ int boxes_num = boxes.size(0);
+ const float * boxes_data = boxes.data();
+ long * keep_data = keep.data();
+
+ const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+
+ unsigned long long *mask_data = NULL;
+ CHECK_ERROR(cudaMalloc((void**)&mask_data, boxes_num * col_blocks * sizeof(unsigned long long)));
+ nmsNormalLauncher(boxes_data, mask_data, boxes_num, nms_overlap_thresh);
+
+ // unsigned long long mask_cpu[boxes_num * col_blocks];
+ // unsigned long long *mask_cpu = new unsigned long long [boxes_num * col_blocks];
+ std::vector mask_cpu(boxes_num * col_blocks);
+
+// printf("boxes_num=%d, col_blocks=%d\n", boxes_num, col_blocks);
+ CHECK_ERROR(cudaMemcpy(&mask_cpu[0], mask_data, boxes_num * col_blocks * sizeof(unsigned long long),
+ cudaMemcpyDeviceToHost));
+
+ cudaFree(mask_data);
+
+ unsigned long long remv_cpu[col_blocks];
+ memset(remv_cpu, 0, col_blocks * sizeof(unsigned long long));
+
+ int num_to_keep = 0;
+
+ for (int i = 0; i < boxes_num; i++){
+ int nblock = i / THREADS_PER_BLOCK_NMS;
+ int inblock = i % THREADS_PER_BLOCK_NMS;
+
+ if (!(remv_cpu[nblock] & (1ULL << inblock))){
+ keep_data[num_to_keep++] = i;
+ unsigned long long *p = &mask_cpu[0] + i * col_blocks;
+ for (int j = nblock; j < col_blocks; j++){
+ remv_cpu[j] |= p[j];
+ }
+ }
+ }
+ if ( cudaSuccess != cudaGetLastError() ) printf( "Error!\n" );
+
+ return num_to_keep;
+}
+
+
diff --git a/det3d/ops/iou3d_nms/src/iou3d_nms.h b/det3d/ops/iou3d_nms/src/iou3d_nms.h
new file mode 100644
index 0000000..aa7ae0e
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_nms.h
@@ -0,0 +1,14 @@
+#ifndef IOU3D_NMS_H
+#define IOU3D_NMS_H
+
+#include
+#include
+#include
+#include
+
+int boxes_overlap_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_overlap);
+int boxes_iou_bev_gpu(at::Tensor boxes_a, at::Tensor boxes_b, at::Tensor ans_iou);
+int nms_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh);
+int nms_normal_gpu(at::Tensor boxes, at::Tensor keep, float nms_overlap_thresh);
+
+#endif
diff --git a/det3d/ops/iou3d_nms/src/iou3d_nms_api.cpp b/det3d/ops/iou3d_nms/src/iou3d_nms_api.cpp
new file mode 100644
index 0000000..5a2d3a3
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_nms_api.cpp
@@ -0,0 +1,17 @@
+#include
+#include
+#include
+#include
+#include
+
+#include "iou3d_cpu.h"
+#include "iou3d_nms.h"
+
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+ m.def("boxes_overlap_bev_gpu", &boxes_overlap_bev_gpu, "oriented boxes overlap");
+ m.def("boxes_iou_bev_gpu", &boxes_iou_bev_gpu, "oriented boxes iou");
+ m.def("nms_gpu", &nms_gpu, "oriented nms gpu");
+ m.def("nms_normal_gpu", &nms_normal_gpu, "nms gpu");
+ m.def("boxes_iou_bev_cpu", &boxes_iou_bev_cpu, "oriented boxes iou");
+}
diff --git a/det3d/ops/iou3d_nms/src/iou3d_nms_kernel.cu b/det3d/ops/iou3d_nms/src/iou3d_nms_kernel.cu
new file mode 100644
index 0000000..e5e305c
--- /dev/null
+++ b/det3d/ops/iou3d_nms/src/iou3d_nms_kernel.cu
@@ -0,0 +1,414 @@
+/*
+3D IoU Calculation and Rotated NMS(modified from 2D NMS written by others)
+Written by Shaoshuai Shi
+All Rights Reserved 2019-2020.
+*/
+
+
+#include
+#define THREADS_PER_BLOCK 16
+#define DIVUP(m, n) ((m) / (n) + ((m) % (n) > 0))
+
+// #define DEBUG
+const int THREADS_PER_BLOCK_NMS = sizeof(unsigned long long) * 8;
+const float EPS = 1e-8;
+struct Point {
+ float x, y;
+ __device__ Point() {}
+ __device__ Point(double _x, double _y){
+ x = _x, y = _y;
+ }
+
+ __device__ void set(float _x, float _y){
+ x = _x; y = _y;
+ }
+
+ __device__ Point operator +(const Point &b)const{
+ return Point(x + b.x, y + b.y);
+ }
+
+ __device__ Point operator -(const Point &b)const{
+ return Point(x - b.x, y - b.y);
+ }
+};
+
+__device__ inline float cross(const Point &a, const Point &b){
+ return a.x * b.y - a.y * b.x;
+}
+
+__device__ inline float cross(const Point &p1, const Point &p2, const Point &p0){
+ return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y);
+}
+
+__device__ int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2){
+ int ret = min(p1.x,p2.x) <= max(q1.x,q2.x) &&
+ min(q1.x,q2.x) <= max(p1.x,p2.x) &&
+ min(p1.y,p2.y) <= max(q1.y,q2.y) &&
+ min(q1.y,q2.y) <= max(p1.y,p2.y);
+ return ret;
+}
+
+__device__ inline int check_in_box2d(const float *box, const Point &p){
+ //params: (7) [x, y, z, dx, dy, dz, heading]
+ const float MARGIN = 1e-2;
+
+ float center_x = box[0], center_y = box[1];
+ float angle_cos = cos(-box[6]), angle_sin = sin(-box[6]); // rotate the point in the opposite direction of box
+ float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin);
+ float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos;
+
+ return (fabs(rot_x) < box[3] / 2 + MARGIN && fabs(rot_y) < box[4] / 2 + MARGIN);
+}
+
+__device__ inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans){
+ // fast exclusion
+ if (check_rect_cross(p0, p1, q0, q1) == 0) return 0;
+
+ // check cross standing
+ float s1 = cross(q0, p1, p0);
+ float s2 = cross(p1, q1, p0);
+ float s3 = cross(p0, q1, q0);
+ float s4 = cross(q1, p1, q0);
+
+ if (!(s1 * s2 > 0 && s3 * s4 > 0)) return 0;
+
+ // calculate intersection of two lines
+ float s5 = cross(q1, p1, p0);
+ if(fabs(s5 - s1) > EPS){
+ ans.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1);
+ ans.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1);
+
+ }
+ else{
+ float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y;
+ float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y;
+ float D = a0 * b1 - a1 * b0;
+
+ ans.x = (b0 * c1 - b1 * c0) / D;
+ ans.y = (a1 * c0 - a0 * c1) / D;
+ }
+
+ return 1;
+}
+
+__device__ inline void rotate_around_center(const Point ¢er, const float angle_cos, const float angle_sin, Point &p){
+ float new_x = (p.x - center.x) * angle_cos + (p.y - center.y) * (-angle_sin) + center.x;
+ float new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y;
+ p.set(new_x, new_y);
+}
+
+__device__ inline int point_cmp(const Point &a, const Point &b, const Point ¢er){
+ return atan2(a.y - center.y, a.x - center.x) > atan2(b.y - center.y, b.x - center.x);
+}
+
+__device__ inline float box_overlap(const float *box_a, const float *box_b){
+ // params box_a: [x, y, z, dx, dy, dz, heading]
+ // params box_b: [x, y, z, dx, dy, dz, heading]
+
+ float a_angle = box_a[6], b_angle = box_b[6];
+ float a_dx_half = box_a[3] / 2, b_dx_half = box_b[3] / 2, a_dy_half = box_a[4] / 2, b_dy_half = box_b[4] / 2;
+ float a_x1 = box_a[0] - a_dx_half, a_y1 = box_a[1] - a_dy_half;
+ float a_x2 = box_a[0] + a_dx_half, a_y2 = box_a[1] + a_dy_half;
+ float b_x1 = box_b[0] - b_dx_half, b_y1 = box_b[1] - b_dy_half;
+ float b_x2 = box_b[0] + b_dx_half, b_y2 = box_b[1] + b_dy_half;
+
+ Point center_a(box_a[0], box_a[1]);
+ Point center_b(box_b[0], box_b[1]);
+
+#ifdef DEBUG
+ printf("a: (%.3f, %.3f, %.3f, %.3f, %.3f), b: (%.3f, %.3f, %.3f, %.3f, %.3f)\n", a_x1, a_y1, a_x2, a_y2, a_angle,
+ b_x1, b_y1, b_x2, b_y2, b_angle);
+ printf("center a: (%.3f, %.3f), b: (%.3f, %.3f)\n", center_a.x, center_a.y, center_b.x, center_b.y);
+#endif
+
+ Point box_a_corners[5];
+ box_a_corners[0].set(a_x1, a_y1);
+ box_a_corners[1].set(a_x2, a_y1);
+ box_a_corners[2].set(a_x2, a_y2);
+ box_a_corners[3].set(a_x1, a_y2);
+
+ Point box_b_corners[5];
+ box_b_corners[0].set(b_x1, b_y1);
+ box_b_corners[1].set(b_x2, b_y1);
+ box_b_corners[2].set(b_x2, b_y2);
+ box_b_corners[3].set(b_x1, b_y2);
+
+ // get oriented corners
+ float a_angle_cos = cos(a_angle), a_angle_sin = sin(a_angle);
+ float b_angle_cos = cos(b_angle), b_angle_sin = sin(b_angle);
+
+ for (int k = 0; k < 4; k++){
+#ifdef DEBUG
+ printf("before corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
+#endif
+ rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]);
+ rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]);
+#ifdef DEBUG
+ printf("corner %d: a(%.3f, %.3f), b(%.3f, %.3f) \n", k, box_a_corners[k].x, box_a_corners[k].y, box_b_corners[k].x, box_b_corners[k].y);
+#endif
+ }
+
+ box_a_corners[4] = box_a_corners[0];
+ box_b_corners[4] = box_b_corners[0];
+
+ // get intersection of lines
+ Point cross_points[16];
+ Point poly_center;
+ int cnt = 0, flag = 0;
+
+ poly_center.set(0, 0);
+ for (int i = 0; i < 4; i++){
+ for (int j = 0; j < 4; j++){
+ flag = intersection(box_a_corners[i + 1], box_a_corners[i], box_b_corners[j + 1], box_b_corners[j], cross_points[cnt]);
+ if (flag){
+ poly_center = poly_center + cross_points[cnt];
+ cnt++;
+#ifdef DEBUG
+ printf("Cross points (%.3f, %.3f): a(%.3f, %.3f)->(%.3f, %.3f), b(%.3f, %.3f)->(%.3f, %.3f) \n",
+ cross_points[cnt - 1].x, cross_points[cnt - 1].y,
+ box_a_corners[i].x, box_a_corners[i].y, box_a_corners[i + 1].x, box_a_corners[i + 1].y,
+ box_b_corners[i].x, box_b_corners[i].y, box_b_corners[i + 1].x, box_b_corners[i + 1].y);
+#endif
+ }
+ }
+ }
+
+ // check corners
+ for (int k = 0; k < 4; k++){
+ if (check_in_box2d(box_a, box_b_corners[k])){
+ poly_center = poly_center + box_b_corners[k];
+ cross_points[cnt] = box_b_corners[k];
+ cnt++;
+#ifdef DEBUG
+ printf("b corners in a: corner_b(%.3f, %.3f)", cross_points[cnt - 1].x, cross_points[cnt - 1].y);
+#endif
+ }
+ if (check_in_box2d(box_b, box_a_corners[k])){
+ poly_center = poly_center + box_a_corners[k];
+ cross_points[cnt] = box_a_corners[k];
+ cnt++;
+#ifdef DEBUG
+ printf("a corners in b: corner_a(%.3f, %.3f)", cross_points[cnt - 1].x, cross_points[cnt - 1].y);
+#endif
+ }
+ }
+
+ poly_center.x /= cnt;
+ poly_center.y /= cnt;
+
+ // sort the points of polygon
+ Point temp;
+ for (int j = 0; j < cnt - 1; j++){
+ for (int i = 0; i < cnt - j - 1; i++){
+ if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)){
+ temp = cross_points[i];
+ cross_points[i] = cross_points[i + 1];
+ cross_points[i + 1] = temp;
+ }
+ }
+ }
+
+#ifdef DEBUG
+ printf("cnt=%d\n", cnt);
+ for (int i = 0; i < cnt; i++){
+ printf("All cross point %d: (%.3f, %.3f)\n", i, cross_points[i].x, cross_points[i].y);
+ }
+#endif
+
+ // get the overlap areas
+ float area = 0;
+ for (int k = 0; k < cnt - 1; k++){
+ area += cross(cross_points[k] - cross_points[0], cross_points[k + 1] - cross_points[0]);
+ }
+
+ return fabs(area) / 2.0;
+}
+
+__device__ inline float iou_bev(const float *box_a, const float *box_b){
+ // params box_a: [x, y, z, dx, dy, dz, heading]
+ // params box_b: [x, y, z, dx, dy, dz, heading]
+ float sa = box_a[3] * box_a[4];
+ float sb = box_b[3] * box_b[4];
+ float s_overlap = box_overlap(box_a, box_b);
+ return s_overlap / fmaxf(sa + sb - s_overlap, EPS);
+}
+
+__global__ void boxes_overlap_kernel(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap){
+ // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+ const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
+ const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
+
+ if (a_idx >= num_a || b_idx >= num_b){
+ return;
+ }
+ const float * cur_box_a = boxes_a + a_idx * 7;
+ const float * cur_box_b = boxes_b + b_idx * 7;
+ float s_overlap = box_overlap(cur_box_a, cur_box_b);
+ ans_overlap[a_idx * num_b + b_idx] = s_overlap;
+}
+
+__global__ void boxes_iou_bev_kernel(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){
+ // params boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading]
+ // params boxes_b: (M, 7) [x, y, z, dx, dy, dz, heading]
+ const int a_idx = blockIdx.y * THREADS_PER_BLOCK + threadIdx.y;
+ const int b_idx = blockIdx.x * THREADS_PER_BLOCK + threadIdx.x;
+
+ if (a_idx >= num_a || b_idx >= num_b){
+ return;
+ }
+
+ const float * cur_box_a = boxes_a + a_idx * 7;
+ const float * cur_box_b = boxes_b + b_idx * 7;
+ float cur_iou_bev = iou_bev(cur_box_a, cur_box_b);
+ ans_iou[a_idx * num_b + b_idx] = cur_iou_bev;
+}
+
+__global__ void nms_kernel(const int boxes_num, const float nms_overlap_thresh,
+ const float *boxes, unsigned long long *mask){
+ //params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
+ //params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+ const int row_start = blockIdx.y;
+ const int col_start = blockIdx.x;
+
+ // if (row_start > col_start) return;
+
+ const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
+ const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
+
+ __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
+
+ if (threadIdx.x < col_size) {
+ block_boxes[threadIdx.x * 7 + 0] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
+ block_boxes[threadIdx.x * 7 + 1] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
+ block_boxes[threadIdx.x * 7 + 2] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
+ block_boxes[threadIdx.x * 7 + 3] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
+ block_boxes[threadIdx.x * 7 + 4] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
+ block_boxes[threadIdx.x * 7 + 5] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
+ block_boxes[threadIdx.x * 7 + 6] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
+ }
+ __syncthreads();
+
+ if (threadIdx.x < row_size) {
+ const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+ const float *cur_box = boxes + cur_box_idx * 7;
+
+ int i = 0;
+ unsigned long long t = 0;
+ int start = 0;
+ if (row_start == col_start) {
+ start = threadIdx.x + 1;
+ }
+ for (i = start; i < col_size; i++) {
+ if (iou_bev(cur_box, block_boxes + i * 7) > nms_overlap_thresh){
+ t |= 1ULL << i;
+ }
+ }
+ const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+ mask[cur_box_idx * col_blocks + col_start] = t;
+ }
+}
+
+
+__device__ inline float iou_normal(float const * const a, float const * const b) {
+ //params: a: [x, y, z, dx, dy, dz, heading]
+ //params: b: [x, y, z, dx, dy, dz, heading]
+
+ float left = fmaxf(a[0] - a[3] / 2, b[0] - b[3] / 2), right = fminf(a[0] + a[3] / 2, b[0] + b[3] / 2);
+ float top = fmaxf(a[1] - a[4] / 2, b[1] - b[4] / 2), bottom = fminf(a[1] + a[4] / 2, b[1] + b[4] / 2);
+ float width = fmaxf(right - left, 0.f), height = fmaxf(bottom - top, 0.f);
+ float interS = width * height;
+ float Sa = a[3] * a[4];
+ float Sb = b[3] * b[4];
+ return interS / fmaxf(Sa + Sb - interS, EPS);
+}
+
+
+__global__ void nms_normal_kernel(const int boxes_num, const float nms_overlap_thresh,
+ const float *boxes, unsigned long long *mask){
+ //params: boxes (N, 7) [x, y, z, dx, dy, dz, heading]
+ //params: mask (N, N/THREADS_PER_BLOCK_NMS)
+
+ const int row_start = blockIdx.y;
+ const int col_start = blockIdx.x;
+
+ // if (row_start > col_start) return;
+
+ const int row_size = fminf(boxes_num - row_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
+ const int col_size = fminf(boxes_num - col_start * THREADS_PER_BLOCK_NMS, THREADS_PER_BLOCK_NMS);
+
+ __shared__ float block_boxes[THREADS_PER_BLOCK_NMS * 7];
+
+ if (threadIdx.x < col_size) {
+ block_boxes[threadIdx.x * 7 + 0] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 0];
+ block_boxes[threadIdx.x * 7 + 1] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 1];
+ block_boxes[threadIdx.x * 7 + 2] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 2];
+ block_boxes[threadIdx.x * 7 + 3] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 3];
+ block_boxes[threadIdx.x * 7 + 4] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 4];
+ block_boxes[threadIdx.x * 7 + 5] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 5];
+ block_boxes[threadIdx.x * 7 + 6] = boxes[(THREADS_PER_BLOCK_NMS * col_start + threadIdx.x) * 7 + 6];
+ }
+ __syncthreads();
+
+ if (threadIdx.x < row_size) {
+ const int cur_box_idx = THREADS_PER_BLOCK_NMS * row_start + threadIdx.x;
+ const float *cur_box = boxes + cur_box_idx * 7;
+
+ int i = 0;
+ unsigned long long t = 0;
+ int start = 0;
+ if (row_start == col_start) {
+ start = threadIdx.x + 1;
+ }
+ for (i = start; i < col_size; i++) {
+ if (iou_normal(cur_box, block_boxes + i * 7) > nms_overlap_thresh){
+ t |= 1ULL << i;
+ }
+ }
+ const int col_blocks = DIVUP(boxes_num, THREADS_PER_BLOCK_NMS);
+ mask[cur_box_idx * col_blocks + col_start] = t;
+ }
+}
+
+
+
+
+
+void boxesoverlapLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_overlap){
+
+ dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK), DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+ boxes_overlap_kernel<<>>(num_a, boxes_a, num_b, boxes_b, ans_overlap);
+#ifdef DEBUG
+ cudaDeviceSynchronize(); // for using printf in kernel function
+#endif
+}
+
+void boxesioubevLauncher(const int num_a, const float *boxes_a, const int num_b, const float *boxes_b, float *ans_iou){
+
+ dim3 blocks(DIVUP(num_b, THREADS_PER_BLOCK), DIVUP(num_a, THREADS_PER_BLOCK)); // blockIdx.x(col), blockIdx.y(row)
+ dim3 threads(THREADS_PER_BLOCK, THREADS_PER_BLOCK);
+
+ boxes_iou_bev_kernel<<>>(num_a, boxes_a, num_b, boxes_b, ans_iou);
+#ifdef DEBUG
+ cudaDeviceSynchronize(); // for using printf in kernel function
+#endif
+}
+
+
+void nmsLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh){
+ dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+ DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+ dim3 threads(THREADS_PER_BLOCK_NMS);
+ nms_kernel<<>>(boxes_num, nms_overlap_thresh, boxes, mask);
+}
+
+
+void nmsNormalLauncher(const float *boxes, unsigned long long * mask, int boxes_num, float nms_overlap_thresh){
+ dim3 blocks(DIVUP(boxes_num, THREADS_PER_BLOCK_NMS),
+ DIVUP(boxes_num, THREADS_PER_BLOCK_NMS));
+ dim3 threads(THREADS_PER_BLOCK_NMS);
+ nms_normal_kernel<<>>(boxes_num, nms_overlap_thresh, boxes, mask);
+}
diff --git a/det3d/ops/point_cloud/__init__.py b/det3d/ops/point_cloud/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/det3d/ops/point_cloud/bev_ops.py b/det3d/ops/point_cloud/bev_ops.py
new file mode 100644
index 0000000..c56d7fa
--- /dev/null
+++ b/det3d/ops/point_cloud/bev_ops.py
@@ -0,0 +1,117 @@
+import math
+
+import numba
+import numpy as np
+
+
+@numba.jit(nopython=True)
+def _points_to_bevmap_reverse_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ coor_to_voxelidx,
+ # coors_2d,
+ bev_map,
+ height_lowers,
+ # density_norm_num=16,
+ with_reflectivity=False,
+ max_voxels=40000,
+):
+ # put all computations to one loop.
+ # we shouldn't create large array in main jit code, otherwise
+ # reduce performance
+ N = points.shape[0]
+ ndim = points.shape[1] - 1
+ # ndim = 3
+ ndim_minus_1 = ndim - 1
+ grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size
+ # np.round(grid_size)
+ # grid_size = np.round(grid_size).astype(np.int64)(np.int32)
+ grid_size = np.round(grid_size, 0, grid_size).astype(np.int32)
+ height_slice_size = voxel_size[-1]
+ coor = np.zeros(shape=(3,), dtype=np.int32) # DHW
+ voxel_num = 0
+ failed = False
+ for i in range(N):
+ failed = False
+ for j in range(ndim):
+ c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j])
+ if c < 0 or c >= grid_size[j]:
+ failed = True
+ break
+ coor[ndim_minus_1 - j] = c
+ if failed:
+ continue
+ voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]]
+ if voxelidx == -1:
+ voxelidx = voxel_num
+ if voxel_num >= max_voxels:
+ break
+ voxel_num += 1
+ coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx
+ # coors_2d[voxelidx] = coor[1:]
+ bev_map[-1, coor[1], coor[2]] += 1
+ height_norm = bev_map[coor[0], coor[1], coor[2]]
+ incomimg_height_norm = (
+ points[i, 2] - height_lowers[coor[0]]
+ ) / height_slice_size
+ if incomimg_height_norm > height_norm:
+ bev_map[coor[0], coor[1], coor[2]] = incomimg_height_norm
+ if with_reflectivity:
+ bev_map[-2, coor[1], coor[2]] = points[i, 3]
+ # return voxel_num
+
+
+def points_to_bev(
+ points,
+ voxel_size,
+ coors_range,
+ with_reflectivity=False,
+ density_norm_num=16,
+ max_voxels=40000,
+):
+ """convert kitti points(N, 4) to a bev map. return [C, H, W] map.
+ this function based on algorithm in points_to_voxel.
+ takes 5ms in a reduced pointcloud with voxel_size=[0.1, 0.1, 0.8]
+
+ Args:
+ points: [N, ndim] float tensor. points[:, :3] contain xyz points and
+ points[:, 3] contain reflectivity.
+ voxel_size: [3] list/tuple or array, float. xyz, indicate voxel size
+ coors_range: [6] list/tuple or array, float. indicate voxel range.
+ format: xyzxyz, minmax
+ with_reflectivity: bool. if True, will add a intensity map to bev map.
+ Returns:
+ bev_map: [num_height_maps + 1(2), H, W] float tensor.
+ `WARNING`: bev_map[-1] is num_points map, NOT density map,
+ because calculate density map need more time in cpu rather than gpu.
+ if with_reflectivity is True, bev_map[-2] is intensity map.
+ """
+ if not isinstance(voxel_size, np.ndarray):
+ voxel_size = np.array(voxel_size, dtype=points.dtype)
+ if not isinstance(coors_range, np.ndarray):
+ coors_range = np.array(coors_range, dtype=points.dtype)
+ voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size
+ voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist())
+ voxelmap_shape = voxelmap_shape[::-1] # DHW format
+ coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32)
+ # coors_2d = np.zeros(shape=(max_voxels, 2), dtype=np.int32)
+ bev_map_shape = list(voxelmap_shape)
+ bev_map_shape[0] += 1
+ height_lowers = np.linspace(
+ coors_range[2], coors_range[5], voxelmap_shape[0], endpoint=False
+ )
+ if with_reflectivity:
+ bev_map_shape[0] += 1
+ bev_map = np.zeros(shape=bev_map_shape, dtype=points.dtype)
+ _points_to_bevmap_reverse_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ coor_to_voxelidx,
+ bev_map,
+ height_lowers,
+ with_reflectivity,
+ max_voxels,
+ )
+ return bev_map
diff --git a/det3d/ops/point_cloud/point_cloud_ops.py b/det3d/ops/point_cloud/point_cloud_ops.py
new file mode 100644
index 0000000..3583508
--- /dev/null
+++ b/det3d/ops/point_cloud/point_cloud_ops.py
@@ -0,0 +1,202 @@
+import time
+
+import numba
+import numpy as np
+
+
+@numba.jit(nopython=True)
+def _points_to_voxel_reverse_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ num_points_per_voxel,
+ coor_to_voxelidx,
+ voxels,
+ coors,
+ max_points=35,
+ max_voxels=20000,
+):
+ # put all computations to one loop.
+ # we shouldn't create large array in main jit code, otherwise
+ # reduce performance
+ N = points.shape[0]
+ # ndim = points.shape[1] - 1
+ ndim = 3
+ ndim_minus_1 = ndim - 1
+ grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size
+ # np.round(grid_size)
+ # grid_size = np.round(grid_size).astype(np.int64)(np.int32)
+ grid_size = np.round(grid_size, 0, grid_size).astype(np.int32)
+ coor = np.zeros(shape=(3,), dtype=np.int32)
+ voxel_num = 0
+ failed = False
+ for i in range(N):
+ failed = False
+ for j in range(ndim):
+ c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j])
+ if c < 0 or c >= grid_size[j]:
+ failed = True
+ break
+ coor[ndim_minus_1 - j] = c
+ if failed:
+ continue
+ voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]]
+ if voxelidx == -1:
+ voxelidx = voxel_num
+ if voxel_num >= max_voxels:
+ continue
+ voxel_num += 1
+ coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx
+ coors[voxelidx] = coor
+ num = num_points_per_voxel[voxelidx]
+ if num < max_points:
+ voxels[voxelidx, num] = points[i]
+ num_points_per_voxel[voxelidx] += 1
+ return voxel_num
+
+
+@numba.jit(nopython=True)
+def _points_to_voxel_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ num_points_per_voxel,
+ coor_to_voxelidx,
+ voxels,
+ coors,
+ max_points=35,
+ max_voxels=20000,
+):
+ # need mutex if write in cuda, but numba.cuda don't support mutex.
+ # in addition, pytorch don't support cuda in dataloader(tensorflow support this).
+ # put all computations to one loop.
+ # we shouldn't create large array in main jit code, otherwise
+ # decrease performance
+ N = points.shape[0]
+ # ndim = points.shape[1] - 1
+ ndim = 3
+ grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size
+ # grid_size = np.round(grid_size).astype(np.int64)(np.int32)
+ grid_size = np.round(grid_size, 0, grid_size).astype(np.int32)
+
+ lower_bound = coors_range[:3]
+ upper_bound = coors_range[3:]
+ coor = np.zeros(shape=(3,), dtype=np.int32)
+ voxel_num = 0
+ failed = False
+ for i in range(N):
+ failed = False
+ for j in range(ndim):
+ c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j])
+ if c < 0 or c >= grid_size[j]:
+ failed = True
+ break
+ coor[j] = c
+ if failed:
+ continue
+ voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]]
+ if voxelidx == -1:
+ voxelidx = voxel_num
+ if voxel_num >= max_voxels:
+ continue
+ voxel_num += 1
+ coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx
+ coors[voxelidx] = coor
+ num = num_points_per_voxel[voxelidx]
+ if num < max_points:
+ voxels[voxelidx, num] = points[i]
+ num_points_per_voxel[voxelidx] += 1
+ return voxel_num
+
+
+def points_to_voxel(
+ points, voxel_size, coors_range, max_points=35, reverse_index=True, max_voxels=20000
+):
+ """convert kitti points(N, >=3) to voxels. This version calculate
+ everything in one loop. now it takes only 4.2ms(complete point cloud)
+ with jit and 3.2ghz cpu.(don't calculate other features)
+ Note: this function in ubuntu seems faster than windows 10.
+
+ Args:
+ points: [N, ndim] float tensor. points[:, :3] contain xyz points and
+ points[:, 3:] contain other information such as reflectivity.
+ voxel_size: [3] list/tuple or array, float. xyz, indicate voxel size
+ coors_range: [6] list/tuple or array, float. indicate voxel range.
+ format: xyzxyz, minmax
+ max_points: int. indicate maximum points contained in a voxel.
+ reverse_index: boolean. indicate whether return reversed coordinates.
+ if points has xyz format and reverse_index is True, output
+ coordinates will be zyx format, but points in features always
+ xyz format.
+ max_voxels: int. indicate maximum voxels this function create.
+ for second, 20000 is a good choice. you should shuffle points
+ before call this function because max_voxels may drop some points.
+
+ Returns:
+ voxels: [M, max_points, ndim] float tensor. only contain points.
+ coordinates: [M, 3] int32 tensor.
+ num_points_per_voxel: [M] int32 tensor.
+ """
+ if not isinstance(voxel_size, np.ndarray):
+ voxel_size = np.array(voxel_size, dtype=points.dtype)
+ if not isinstance(coors_range, np.ndarray):
+ coors_range = np.array(coors_range, dtype=points.dtype)
+ voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size
+ voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist())
+ if reverse_index:
+ voxelmap_shape = voxelmap_shape[::-1]
+ # don't create large array in jit(nopython=True) code.
+ num_points_per_voxel = np.zeros(shape=(max_voxels,), dtype=np.int32)
+ coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32)
+ voxels = np.zeros(
+ shape=(max_voxels, max_points, points.shape[-1]), dtype=points.dtype
+ )
+ coors = np.zeros(shape=(max_voxels, 3), dtype=np.int32)
+ if reverse_index:
+ voxel_num = _points_to_voxel_reverse_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ num_points_per_voxel,
+ coor_to_voxelidx,
+ voxels,
+ coors,
+ max_points,
+ max_voxels,
+ )
+
+ else:
+ voxel_num = _points_to_voxel_kernel(
+ points,
+ voxel_size,
+ coors_range,
+ num_points_per_voxel,
+ coor_to_voxelidx,
+ voxels,
+ coors,
+ max_points,
+ max_voxels,
+ )
+
+ coors = coors[:voxel_num]
+ voxels = voxels[:voxel_num]
+ num_points_per_voxel = num_points_per_voxel[:voxel_num]
+ return voxels, coors, num_points_per_voxel
+
+
+@numba.jit(nopython=True)
+def bound_points_jit(points, upper_bound, lower_bound):
+ # to use nopython=True, np.bool is not supported. so you need
+ # convert result to np.bool after this function.
+ N = points.shape[0]
+ ndim = points.shape[1]
+ keep_indices = np.zeros((N,), dtype=np.int32)
+ success = 0
+ for i in range(N):
+ success = 1
+ for j in range(ndim):
+ if points[i, j] < lower_bound[j] or points[i, j] >= upper_bound[j]:
+ success = 0
+ break
+ keep_indices[i] = success
+ return keep_indices
diff --git a/det3d/solver/__init__.py b/det3d/solver/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/det3d/solver/background.py b/det3d/solver/background.py
new file mode 100644
index 0000000..2c7d082
--- /dev/null
+++ b/det3d/solver/background.py
@@ -0,0 +1,28 @@
+import threading, queue
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, max_prefetch=1):
+ threading.Thread.__init__(self)
+ self.queue = queue.Queue(max_prefetch)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ # Python 3 compatibility
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
diff --git a/det3d/solver/fastai_optim.py b/det3d/solver/fastai_optim.py
new file mode 100644
index 0000000..a543447
--- /dev/null
+++ b/det3d/solver/fastai_optim.py
@@ -0,0 +1,309 @@
+from collections import Iterable, defaultdict
+from copy import deepcopy
+from itertools import chain
+
+import torch
+from torch import nn
+from torch._utils import _unflatten_dense_tensors
+from torch.autograd import Variable
+from torch.nn.utils import parameters_to_vector
+try:
+ from apex.parallel.optimized_sync_batchnorm import SyncBatchNorm
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.modules.batchnorm._BatchNorm, SyncBatchNorm)
+except:
+ print('no apex')
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,nn.modules.batchnorm._BatchNorm)
+
+def split_bn_bias(layer_groups):
+ "Split the layers in `layer_groups` into batchnorm (`bn_types`) and non-batchnorm groups."
+ split_groups = []
+ for l in layer_groups:
+ l1, l2 = [], []
+ for c in l.children():
+ if isinstance(c, bn_types):
+ l2.append(c)
+ else:
+ l1.append(c)
+ split_groups += [nn.Sequential(*l1), nn.Sequential(*l2)]
+ return split_groups
+
+
+def get_master(layer_groups, flat_master: bool = False):
+ "Return two lists, one for the model parameters in FP16 and one for the master parameters in FP32."
+ split_groups = split_bn_bias(layer_groups)
+ model_params = [
+ [param for param in lg.parameters() if param.requires_grad]
+ for lg in split_groups
+ ]
+ if flat_master:
+ master_params = []
+ for lg in model_params:
+ if len(lg) != 0:
+ mp = parameters_to_vector([param.data.float() for param in lg])
+ mp = torch.nn.Parameter(mp, requires_grad=True)
+ if mp.grad is None:
+ mp.grad = mp.new(*mp.size())
+ master_params.append([mp])
+ else:
+ master_params.append([])
+ return model_params, master_params
+ else:
+ master_params = [
+ [param.clone().float().detach() for param in lg] for lg in model_params
+ ]
+ for mp in master_params:
+ for param in mp:
+ param.requires_grad = True
+ return model_params, master_params
+
+
+def model_g2master_g(model_params, master_params, flat_master: bool = False) -> None:
+ "Copy the `model_params` gradients to `master_params` for the optimizer step."
+ if flat_master:
+ for model_group, master_group in zip(model_params, master_params):
+ if len(master_group) != 0:
+ master_group[0].grad.data.copy_(
+ parameters_to_vector([p.grad.data.float() for p in model_group])
+ )
+ else:
+ for model_group, master_group in zip(model_params, master_params):
+ for model, master in zip(model_group, master_group):
+ if model.grad is not None:
+ if master.grad is None:
+ master.grad = master.data.new(*master.data.size())
+ master.grad.data.copy_(model.grad.data)
+ else:
+ master.grad = None
+
+
+def master2model(model_params, master_params, flat_master: bool = False) -> None:
+ "Copy `master_params` to `model_params`."
+ if flat_master:
+ for model_group, master_group in zip(model_params, master_params):
+ if len(model_group) != 0:
+ for model, master in zip(
+ model_group,
+ _unflatten_dense_tensors(master_group[0].data, model_group),
+ ):
+ model.data.copy_(master)
+ else:
+ for model_group, master_group in zip(model_params, master_params):
+ for model, master in zip(model_group, master_group):
+ model.data.copy_(master.data)
+
+
+def listify(p=None, q=None):
+ "Make `p` listy and the same length as `q`."
+ if p is None:
+ p = []
+ elif isinstance(p, str):
+ p = [p]
+ elif not isinstance(p, Iterable):
+ p = [p]
+ n = q if type(q) == int else len(p) if q is None else len(q)
+ if len(p) == 1:
+ p = p * n
+ assert len(p) == n, f"List len mismatch ({len(p)} vs {n})"
+ return list(p)
+
+
+def trainable_params(m: nn.Module):
+ "Return list of trainable params in `m`."
+ res = filter(lambda p: p.requires_grad, m.parameters())
+ return res
+
+
+def is_tuple(x) -> bool:
+ return isinstance(x, tuple)
+
+
+# copy from fastai.
+class OptimWrapper:
+ "Basic wrapper around `opt` to simplify hyper-parameters changes."
+
+ def __init__(self, opt, wd, true_wd: bool = False, bn_wd: bool = True):
+ self.opt, self.true_wd, self.bn_wd = opt, true_wd, bn_wd
+ self.opt_keys = list(self.opt.param_groups[0].keys())
+ self.opt_keys.remove("params")
+ self.read_defaults()
+ self.wd = wd
+
+ @classmethod
+ def create(cls, opt_func, lr, layer_groups, **kwargs):
+ "Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
+ split_groups = split_bn_bias(layer_groups)
+ opt = opt_func([{"params": trainable_params(l), "lr": 0} for l in split_groups])
+ opt = cls(opt, **kwargs)
+ opt.lr, opt.opt_func = listify(lr, layer_groups), opt_func
+ return opt
+
+ def new(self, layer_groups):
+ "Create a new `OptimWrapper` from `self` with another `layer_groups` but the same hyper-parameters."
+ opt_func = getattr(self, "opt_func", self.opt.__class__)
+ split_groups = split_bn_bias(layer_groups)
+ opt = opt_func([{"params": trainable_params(l), "lr": 0} for l in split_groups])
+ return self.create(
+ opt_func,
+ self.lr,
+ layer_groups,
+ wd=self.wd,
+ true_wd=self.true_wd,
+ bn_wd=self.bn_wd,
+ )
+
+ def __repr__(self) -> str:
+ return f"OptimWrapper over {repr(self.opt)}.\nTrue weight decay: {self.true_wd}"
+
+ # Pytorch optimizer methods
+ def step(self) -> None:
+ "Set weight decay and step optimizer."
+ # weight decay outside of optimizer step (AdamW)
+ if self.true_wd:
+ for lr, wd, pg1, pg2 in zip(
+ self._lr,
+ self._wd,
+ self.opt.param_groups[::2],
+ self.opt.param_groups[1::2],
+ ):
+ for p in pg1["params"]:
+ p.data.mul_(1 - wd * lr)
+ if self.bn_wd:
+ for p in pg2["params"]:
+ p.data.mul_(1 - wd * lr)
+ self.set_val("weight_decay", listify(0, self._wd))
+ self.opt.step()
+
+ def zero_grad(self) -> None:
+ "Clear optimizer gradients."
+ self.opt.zero_grad()
+
+ # Passthrough to the inner opt.
+ def __getattr__(self, k: str):
+ return getattr(self.opt, k, None)
+
+ def clear(self):
+ "Reset the state of the inner optimizer."
+ sd = self.state_dict()
+ sd["state"] = {}
+ self.load_state_dict(sd)
+
+ # Hyperparameters as properties
+ @property
+ def lr(self) -> float:
+ return self._lr[-1]
+
+ @lr.setter
+ def lr(self, val: float) -> None:
+ self._lr = self.set_val("lr", listify(val, self._lr))
+
+ @property
+ def mom(self) -> float:
+ return self._mom[-1]
+
+ @mom.setter
+ def mom(self, val: float) -> None:
+ if "momentum" in self.opt_keys:
+ self.set_val("momentum", listify(val, self._mom))
+ elif "betas" in self.opt_keys:
+ self.set_val("betas", (listify(val, self._mom), self._beta))
+ self._mom = listify(val, self._mom)
+
+ @property
+ def beta(self) -> float:
+ return None if self._beta is None else self._beta[-1]
+
+ @beta.setter
+ def beta(self, val: float) -> None:
+ "Set beta (or alpha as makes sense for given optimizer)."
+ if val is None:
+ return
+ if "betas" in self.opt_keys:
+ self.set_val("betas", (self._mom, listify(val, self._beta)))
+ elif "alpha" in self.opt_keys:
+ self.set_val("alpha", listify(val, self._beta))
+ self._beta = listify(val, self._beta)
+
+ @property
+ def wd(self) -> float:
+ return self._wd[-1]
+
+ @wd.setter
+ def wd(self, val: float) -> None:
+ "Set weight decay."
+ if not self.true_wd:
+ self.set_val("weight_decay", listify(val, self._wd), bn_groups=self.bn_wd)
+ self._wd = listify(val, self._wd)
+
+ # Helper functions
+ def read_defaults(self) -> None:
+ "Read the values inside the optimizer for the hyper-parameters."
+ self._beta = None
+ if "lr" in self.opt_keys:
+ self._lr = self.read_val("lr")
+ if "momentum" in self.opt_keys:
+ self._mom = self.read_val("momentum")
+ if "alpha" in self.opt_keys:
+ self._beta = self.read_val("alpha")
+ if "betas" in self.opt_keys:
+ self._mom, self._beta = self.read_val("betas")
+ if "weight_decay" in self.opt_keys:
+ self._wd = self.read_val("weight_decay")
+
+ def set_val(self, key: str, val, bn_groups: bool = True):
+ "Set `val` inside the optimizer dictionary at `key`."
+ if is_tuple(val):
+ val = [(v1, v2) for v1, v2 in zip(*val)]
+ for v, pg1, pg2 in zip(
+ val, self.opt.param_groups[::2], self.opt.param_groups[1::2]
+ ):
+ pg1[key] = v
+ if bn_groups:
+ pg2[key] = v
+ return val
+
+ def read_val(self, key: str):
+ "Read a hyperparameter `key` in the optimizer dictionary."
+ val = [pg[key] for pg in self.opt.param_groups[::2]]
+ if is_tuple(val[0]):
+ val = [o[0] for o in val], [o[1] for o in val]
+ return val
+
+
+class FastAIMixedOptim(OptimWrapper):
+ @classmethod
+ def create(
+ cls,
+ opt_func,
+ lr,
+ layer_groups,
+ model,
+ flat_master=False,
+ loss_scale=512.0,
+ **kwargs,
+ ):
+ "Create an `optim.Optimizer` from `opt_func` with `lr`. Set lr on `layer_groups`."
+ opt = OptimWrapper.create(opt_func, lr, layer_groups, **kwargs)
+ opt.model_params, opt.master_params = get_master(layer_groups, flat_master)
+ opt.flat_master = flat_master
+ opt.loss_scale = loss_scale
+ opt.model = model
+ # Changes the optimizer so that the optimization step is done in FP32.
+ # opt = self.learn.opt
+ mom, wd, beta = opt.mom, opt.wd, opt.beta
+ lrs = [lr for lr in opt._lr for _ in range(2)]
+ opt_params = [
+ {"params": mp, "lr": lr} for mp, lr in zip(opt.master_params, lrs)
+ ]
+ opt.opt = opt_func(opt_params)
+ opt.mom, opt.wd, opt.beta = mom, wd, beta
+ return opt
+
+ def step(self):
+ model_g2master_g(self.model_params, self.master_params, self.flat_master)
+ for group in self.master_params:
+ for param in group:
+ param.grad.div_(self.loss_scale)
+ super(FastAIMixedOptim, self).step()
+ self.model.zero_grad()
+ # Update the params from master to model.
+ master2model(self.model_params, self.master_params, self.flat_master)
diff --git a/det3d/solver/learning_schedules.py b/det3d/solver/learning_schedules.py
new file mode 100644
index 0000000..cd1bf55
--- /dev/null
+++ b/det3d/solver/learning_schedules.py
@@ -0,0 +1,192 @@
+"""PyTorch edition of TensorFlow learning schedule in tensorflow object
+detection API.
+"""
+import numpy as np
+from torch.optim.optimizer import Optimizer
+
+
+class _LRSchedulerStep(object):
+ def __init__(self, optimizer, last_step=-1):
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ if last_step == -1:
+ for group in optimizer.param_groups:
+ group.setdefault("initial_lr", group["lr"])
+ else:
+ for i, group in enumerate(optimizer.param_groups):
+ if "initial_lr" not in group:
+ raise KeyError(
+ "param 'initial_lr' is not specified "
+ "in param_groups[{}] when resuming an optimizer".format(i)
+ )
+ self.base_lrs = list(
+ map(lambda group: group["initial_lr"], optimizer.param_groups)
+ )
+ self.step(last_step + 1)
+ self.last_step = last_step
+
+ """
+ def get_lr(self):
+ raise NotImplementedError
+ """
+
+ def get_lr(self):
+ ret = [self._get_lr_per_group(base_lr) for base_lr in self.base_lrs]
+ return ret
+
+ def _get_lr_per_group(self, base_lr):
+ raise NotImplementedError
+
+ def step(self, step=None):
+ if step is None:
+ step = self.last_step + 1
+ self.last_step = step
+ for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+ param_group["lr"] = lr
+
+
+class Constant(_LRSchedulerStep):
+ def __init__(self, optimizer, last_step=-1):
+ super().__init__(optimizer, last_step)
+
+ def _get_lr_per_group(self, base_lr):
+ return base_lr
+
+
+class ManualStepping(_LRSchedulerStep):
+ """Pytorch edition of manual_stepping in tensorflow.
+ DON'T SUPPORT PARAM GROUPS.
+ """
+
+ def __init__(self, optimizer, boundaries, rates, last_step=-1):
+ self._boundaries = boundaries
+ self._num_boundaries = len(boundaries)
+ self._learning_rates = rates
+
+ if any([b < 0 for b in boundaries]) or any(
+ [not isinstance(b, int) for b in boundaries]
+ ):
+ raise ValueError("boundaries must be a list of positive integers")
+ if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):
+ raise ValueError("Entries in boundaries must be strictly increasing.")
+ if any([not isinstance(r, float) for r in rates]):
+ raise ValueError("Learning rates must be floats")
+ if len(rates) != len(boundaries) + 1:
+ raise ValueError(
+ "Number of provided learning rates must exceed "
+ "number of boundary points by exactly 1."
+ )
+ super().__init__(optimizer, last_step)
+
+ def _get_lr_per_group(self, base_lr):
+ step = self.last_step
+ ret = None
+ for i, bound in enumerate(self._boundaries):
+ if step > bound:
+ ret = self._learning_rates[i + 1]
+ if ret is not None:
+ return ret
+ return self._learning_rates[0]
+
+
+class ExponentialDecayWithBurnin(_LRSchedulerStep):
+ """Pytorch edition of manual_stepping in tensorflow.
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ learning_rate_decay_steps,
+ learning_rate_decay_factor,
+ burnin_learning_rate,
+ burnin_steps,
+ last_step=-1,
+ ):
+ self._decay_steps = learning_rate_decay_steps
+ self._decay_factor = learning_rate_decay_factor
+ self._burnin_learning_rate = burnin_learning_rate
+ self._burnin_steps = burnin_steps
+
+ super().__init__(optimizer, last_step)
+
+ def _get_lr_per_group(self, base_lr):
+ if self._burnin_learning_rate == 0:
+ burnin_learning_rate = base_lr
+ step = self.last_step
+ post_burnin_learning_rate = base_lr * self._decay_factor ^ (
+ step // self._decay_steps
+ )
+ if step < self._burnin_steps:
+ return burnin_learning_rate
+ else:
+ return post_burnin_learning_rate
+
+
+class ExponentialDecay(_LRSchedulerStep):
+ def __init__(
+ self,
+ optimizer,
+ learning_rate_decay_steps,
+ learning_rate_decay_factor,
+ staircase=True,
+ last_step=-1,
+ ):
+ self._decay_steps = learning_rate_decay_steps
+ self._decay_factor = learning_rate_decay_factor
+ self._staircase = staircase
+
+ super().__init__(optimizer, last_step)
+
+ def _get_lr_per_group(self, base_lr):
+ step = self.last_step
+ if self._staircase:
+ post_burnin_learning_rate = base_lr * pow(
+ self._decay_factor, (step // self._decay_steps)
+ )
+ else:
+ post_burnin_learning_rate = base_lr * pow(
+ self._decay_factor, (step / self._decay_steps)
+ )
+
+ return post_burnin_learning_rate
+
+
+class CosineDecayWithWarmup(_LRSchedulerStep):
+ def __init__(
+ self, optimizer, total_steps, warmup_learning_rate, warmup_steps, last_step=-1
+ ):
+ if total_steps < warmup_steps:
+ raise ValueError("total_steps must be larger or equal to " "warmup_steps.")
+ self._total_steps = total_steps
+ self._warmup_learning_rate = warmup_learning_rate
+ self._warmup_steps = warmup_steps
+
+ super().__init__(optimizer, last_step)
+
+ def _get_lr_per_group(self, base_lr):
+ if base_lr < self._warmup_learning_rate:
+ raise ValueError(
+ "learning_rate_base must be larger " "or equal to warmup_learning_rate."
+ )
+
+ step = self.last_step
+ learning_rate = (
+ 0.5
+ * base_lr
+ * (
+ 1
+ + np.cos(
+ np.pi
+ * (float(step) - self._warmup_steps)
+ / float(self._total_steps - self._warmup_steps)
+ )
+ )
+ )
+ if self._warmup_steps > 0:
+ slope = (base_lr - self._warmup_learning_rate) / self._warmup_steps
+ pre_cosine_learning_rate = slope * float(step) + self._warmup_learning_rate
+ if step < self._warmup_steps:
+ return pre_cosine_learning_rate
+ else:
+ return learning_rate
diff --git a/det3d/solver/learning_schedules_fastai.py b/det3d/solver/learning_schedules_fastai.py
new file mode 100644
index 0000000..79fff8e
--- /dev/null
+++ b/det3d/solver/learning_schedules_fastai.py
@@ -0,0 +1,168 @@
+import math
+from functools import partial
+
+import numpy as np
+
+
+class LRSchedulerStep(object):
+ def __init__(self, fai_optimizer, total_step, lr_phases, mom_phases):
+ self.optimizer = fai_optimizer
+ self.total_step = total_step
+ self.lr_phases = []
+
+ for i, (start, lambda_func) in enumerate(lr_phases):
+ if len(self.lr_phases) != 0:
+ assert self.lr_phases[-1][0] < int(start * total_step)
+ if isinstance(lambda_func, str):
+ lambda_func = eval(lambda_func)
+ if i < len(lr_phases) - 1:
+ self.lr_phases.append(
+ (
+ int(start * total_step),
+ int(lr_phases[i + 1][0] * total_step),
+ lambda_func,
+ )
+ )
+ else:
+ self.lr_phases.append(
+ (int(start * total_step), total_step, lambda_func)
+ )
+ assert self.lr_phases[0][0] == 0
+ self.mom_phases = []
+ for i, (start, lambda_func) in enumerate(mom_phases):
+ if len(self.mom_phases) != 0:
+ assert self.mom_phases[-1][0] < start
+ if isinstance(lambda_func, str):
+ lambda_func = eval(lambda_func)
+ if i < len(mom_phases) - 1:
+ self.mom_phases.append(
+ (
+ int(start * total_step),
+ int(mom_phases[i + 1][0] * total_step),
+ lambda_func,
+ )
+ )
+ else:
+ self.mom_phases.append(
+ (int(start * total_step), total_step, lambda_func)
+ )
+ # assert self.mom_phases[0][0] == 0
+ if len(mom_phases) > 0:
+ assert self.mom_phases[0][0] == 0
+
+ def step(self, step):
+ lrs, moms = [], []
+
+ for start, end, func in self.lr_phases:
+ if step >= start:
+ # self.optimizer.lr = func((step - start) / (end - start))
+ lrs.append(func((step - start) / (end - start)))
+ if len(lrs) > 0:
+ self.optimizer.lr = lrs[-1]
+ for start, end, func in self.mom_phases:
+ if step >= start:
+ moms.append(func((step - start) / (end - start)))
+ self.optimizer.mom = func((step - start) / (end - start))
+ if len(moms) > 0:
+ self.optimizer.mom = moms[-1]
+
+
+def annealing_cos(start, end, pct):
+ # print(pct, start, end)
+ "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."
+ cos_out = np.cos(np.pi * pct) + 1
+ return end + (start - end) / 2 * cos_out
+
+
+class OneCycle(LRSchedulerStep):
+ def __init__(self, fai_optimizer, total_step, lr_max, moms, div_factor, pct_start):
+ self.lr_max = lr_max
+ self.moms = moms
+ self.div_factor = div_factor
+ self.pct_start = pct_start
+ a1 = int(total_step * self.pct_start)
+ a2 = total_step - a1
+ low_lr = self.lr_max / self.div_factor
+ lr_phases = (
+ (0, partial(annealing_cos, low_lr, self.lr_max)),
+ (self.pct_start, partial(annealing_cos, self.lr_max, low_lr / 1e4)),
+ )
+ mom_phases = (
+ (0, partial(annealing_cos, *self.moms)),
+ (self.pct_start, partial(annealing_cos, *self.moms[::-1])),
+ )
+ fai_optimizer.lr, fai_optimizer.mom = low_lr, self.moms[0]
+ super().__init__(fai_optimizer, total_step, lr_phases, mom_phases)
+
+
+class ExponentialDecay(LRSchedulerStep):
+ def __init__(
+ self,
+ fai_optimizer,
+ total_step,
+ initial_learning_rate,
+ decay_length,
+ decay_factor,
+ staircase=True,
+ ):
+ """
+ Args:
+ decay_length: must in (0, 1)
+ """
+ assert decay_length > 0
+ assert decay_length < 1
+ self._decay_steps_unified = decay_length
+ self._decay_factor = decay_factor
+ self._staircase = staircase
+ step = 0
+ stage = 1
+ lr_phases = []
+ if staircase:
+ while step <= total_step:
+ func = lambda p, _d=initial_learning_rate * stage: _d
+ lr_phases.append((step / total_step, func))
+ stage *= decay_factor
+ step += int(decay_length * total_step)
+ else:
+ func = lambda p: pow(decay_factor, (p / decay_length))
+ lr_phases.append((0, func))
+ super().__init__(fai_optimizer, total_step, lr_phases, [])
+
+
+class ManualStepping(LRSchedulerStep):
+ def __init__(self, fai_optimizer, total_step, boundaries, rates):
+ assert all([b > 0 and b < 1 for b in boundaries])
+ assert len(boundaries) + 1 == len(rates)
+ boundaries.insert(0, 0.0)
+ lr_phases = []
+ for start, rate in zip(boundaries, rates):
+ func = lambda p, _d=rate: _d
+ lr_phases.append((start, func))
+ super().__init__(fai_optimizer, total_step, lr_phases, [])
+
+
+class FakeOptim:
+ def __init__(self):
+ self.lr = 0
+ self.mom = 0
+
+
+if __name__ == "__main__":
+ import matplotlib.pyplot as plt
+
+ opt = FakeOptim() # 3e-3, wd=0.4, div_factor=10
+ # schd = OneCycle(opt, 100, 3e-3, (0.95, 0.85), 10.0, 0.1)
+ schd = ExponentialDecay(opt, 100, 3e-4, 0.1, 0.8, staircase=True)
+ schd = ManualStepping(opt, 100, [0.8, 0.9], [0.001, 0.0001, 0.00005])
+
+ lrs = []
+ moms = []
+ for i in range(100):
+ schd.step(i)
+ lrs.append(opt.lr)
+ moms.append(opt.mom)
+ plt.plot(lrs)
+ # plt.plot(moms)
+ # plt.show()
+ # plt.plot(moms)
+ plt.show()
diff --git a/det3d/solver/optim.py b/det3d/solver/optim.py
new file mode 100644
index 0000000..224ada0
--- /dev/null
+++ b/det3d/solver/optim.py
@@ -0,0 +1,116 @@
+from collections import Iterable, defaultdict
+from copy import deepcopy
+from itertools import chain
+
+import torch
+from torch.autograd import Variable
+
+required = object()
+
+
+def param_fp32_copy(params):
+ param_copy = [
+ param.clone().type(torch.cuda.FloatTensor).detach() for param in params
+ ]
+ for param in param_copy:
+ param.requires_grad = True
+ return param_copy
+
+
+def set_grad(params, params_with_grad, scale=1.0):
+ for param, param_w_grad in zip(params, params_with_grad):
+ if param.grad is None:
+ param.grad = torch.nn.Parameter(
+ param.data.new().resize_(*param.data.size())
+ )
+ grad = param_w_grad.grad.data
+ if scale is not None:
+ grad /= scale
+ if torch.isnan(grad).any() or torch.isinf(grad).any():
+ return True # invalid grad
+ param.grad.data.copy_(grad)
+ return False
+
+
+class MixedPrecisionWrapper(object):
+ """mixed precision optimizer wrapper.
+ Arguments:
+ optimizer (torch.optim.Optimizer): an instance of
+ :class:`torch.optim.Optimizer`
+ scale: (float): a scalar for grad scale.
+ auto_scale: (bool): whether enable auto scale.
+ The algorihm of auto scale is discribled in
+ http://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
+ """
+
+ def __init__(
+ self,
+ optimizer,
+ scale=None,
+ auto_scale=True,
+ inc_factor=2.0,
+ dec_factor=0.5,
+ num_iters_be_stable=500,
+ ):
+ if not isinstance(optimizer, torch.optim.Optimizer):
+ raise ValueError("must provide a torch.optim.Optimizer")
+ self.optimizer = optimizer
+ if hasattr(self.optimizer, "name"):
+ self.name = self.optimizer.name # for ckpt system
+ param_groups_copy = []
+ for i, group in enumerate(optimizer.param_groups):
+ group_copy = {n: v for n, v in group.items() if n != "params"}
+ group_copy["params"] = param_fp32_copy(group["params"])
+ param_groups_copy.append(group_copy)
+
+ # switch param_groups, may be dangerous
+ self.param_groups = optimizer.param_groups
+ optimizer.param_groups = param_groups_copy
+ self.grad_scale = scale
+ self.auto_scale = auto_scale
+ self.inc_factor = inc_factor
+ self.dec_factor = dec_factor
+ self.stable_iter_count = 0
+ self.num_iters_be_stable = num_iters_be_stable
+
+ def __getstate__(self):
+ return self.optimizer.__getstate__()
+
+ def __setstate__(self, state):
+ return self.optimizer.__setstate__(state)
+
+ def __repr__(self):
+ return self.optimizer.__repr__()
+
+ def state_dict(self):
+ return self.optimizer.state_dict()
+
+ def load_state_dict(self, state_dict):
+ return self.optimizer.load_state_dict(state_dict)
+
+ def zero_grad(self):
+ return self.optimizer.zero_grad()
+
+ def step(self, closure=None):
+ for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):
+ invalid = set_grad(g_copy["params"], g["params"], self.grad_scale)
+ if invalid:
+ if self.grad_scale is None or self.auto_scale is False:
+ raise ValueError("nan/inf detected but auto_scale disabled.")
+ self.grad_scale *= self.dec_factor
+ print("scale decay to {}".format(self.grad_scale))
+ return
+ if self.auto_scale is True:
+ self.stable_iter_count += 1
+ if self.stable_iter_count > self.num_iters_be_stable:
+ if self.grad_scale is not None:
+ self.grad_scale *= self.inc_factor
+ self.stable_iter_count = 0
+
+ if closure is None:
+ self.optimizer.step()
+ else:
+ self.optimizer.step(closure)
+ for g, g_copy in zip(self.param_groups, self.optimizer.param_groups):
+ for p_copy, p in zip(g_copy["params"], g["params"]):
+ p.data.copy_(p_copy.data)
diff --git a/det3d/torchie/__init__.py b/det3d/torchie/__init__.py
new file mode 100644
index 0000000..e5df6bc
--- /dev/null
+++ b/det3d/torchie/__init__.py
@@ -0,0 +1,6 @@
+# from .apis import *
+from .cnn import *
+from .fileio import *
+from .parallel import *
+from .trainer import *
+from .utils import *
diff --git a/det3d/torchie/apis/__init__.py b/det3d/torchie/apis/__init__.py
new file mode 100644
index 0000000..952d978
--- /dev/null
+++ b/det3d/torchie/apis/__init__.py
@@ -0,0 +1,14 @@
+from .env import get_root_logger, init_dist, set_random_seed
+from .train import batch_processor, batch_processor_ensemble, build_optimizer, train_detector
+
+# from .inference import init_detector, inference_detector, show_result
+
+__all__ = [
+ "init_dist",
+ "get_root_logger",
+ "set_random_seed",
+ "train_detector",
+ "build_optimizer",
+ "batch_processor",
+ # 'init_detector', 'inference_detector', 'show_result'
+]
diff --git a/det3d/torchie/apis/env.py b/det3d/torchie/apis/env.py
new file mode 100644
index 0000000..75dc44e
--- /dev/null
+++ b/det3d/torchie/apis/env.py
@@ -0,0 +1,67 @@
+import logging
+import os
+import random
+import subprocess
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from det3d.torchie.trainer import get_dist_info
+
+
+def init_dist(launcher, backend="nccl", **kwargs):
+ if mp.get_start_method(allow_none=True) is None:
+ mp.set_start_method("spawn")
+ if launcher == "pytorch":
+ _init_dist_pytorch(backend, **kwargs)
+ elif launcher == "mpi":
+ _init_dist_mpi(backend, **kwargs)
+ elif launcher == "slurm":
+ _init_dist_slurm(backend, **kwargs)
+ else:
+ raise ValueError("Invalid launcher type: {}".format(launcher))
+
+
+def _init_dist_pytorch(backend, **kwargs):
+ torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+ dist.init_process_group(backend=backend, **kwargs)
+
+
+def _init_dist_mpi(backend, **kwargs):
+ raise NotImplementedError
+
+
+def _init_dist_slurm(backend, port=29500, **kwargs):
+ proc_id = int(os.environ["SLURM_PROCID"])
+ ntasks = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(proc_id % num_gpus)
+ addr = subprocess.getoutput(
+ "scontrol show hostname {} | head -n1".format(node_list)
+ )
+ os.environ["MASTER_PORT"] = str(port)
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(ntasks)
+ os.environ["RANK"] = str(proc_id)
+ dist.init_process_group(backend=backend)
+
+
+def set_random_seed(seed):
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_root_logger(log_level=logging.INFO):
+ logger = logging.getLogger()
+ if not logger.hasHandlers():
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(message)s", level=log_level
+ )
+ rank, _ = get_dist_info()
+ if rank != 0:
+ logger.setLevel("ERROR")
+ return logger
diff --git a/det3d/torchie/apis/train.py b/det3d/torchie/apis/train.py
new file mode 100644
index 0000000..4580c5a
--- /dev/null
+++ b/det3d/torchie/apis/train.py
@@ -0,0 +1,335 @@
+from __future__ import division
+
+import re
+from collections import OrderedDict, defaultdict
+from functools import partial
+
+# try:
+# import apex
+# except:
+# print("No APEX!")
+
+import numpy as np
+import torch
+from det3d.builder import _create_learning_rate_scheduler
+
+# from det3d.datasets.kitti.eval_hooks import KittiDistEvalmAPHook, KittiEvalmAPHookV2
+from det3d.core import DistOptimizerHook
+from det3d.datasets import DATASETS, build_dataloader
+from det3d.solver.fastai_optim import OptimWrapper
+from det3d.torchie.trainer import DistSamplerSeedHook, DisableDBSamplerHook, Trainer, obj_from_dict
+from det3d.utils.print_utils import metric_to_str
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel
+
+from .env import get_root_logger
+
+
+def example_to_device(example, device=None, non_blocking=False) -> dict:
+ assert device is not None
+
+ example_torch = {}
+ float_names = ["voxels", "bev_map"]
+ for k, v in example.items():
+ if k in ["anchors", "anchors_mask", "reg_targets", "reg_weights", "labels", "hm",
+ "anno_box", "ind", "mask", 'cat', 'corners', 'points']:
+ example_torch[k] = [res.to(device, non_blocking=non_blocking) for res in v]
+ elif k in [
+ "voxels",
+ "bev_map",
+ "coordinates",
+ "num_points",
+ # "points",
+ "num_voxels",
+ "cyv_voxels",
+ "cyv_num_voxels",
+ "cyv_coordinates",
+ "cyv_num_points",
+ "gt_boxes_and_cls",
+ "gt_boxes_mask",
+ "gt_offset",
+ 'times'
+ ]:
+ example_torch[k] = v.to(device, non_blocking=non_blocking)
+ elif k == "calib":
+ calib = {}
+ for k1, v1 in v.items():
+ # calib[k1] = torch.tensor(v1, dtype=dtype, device=device)
+ calib[k1] = torch.tensor(v1).to(device, non_blocking=non_blocking)
+ example_torch[k] = calib
+ else:
+ example_torch[k] = v
+
+ return example_torch
+
+
+def parse_losses(losses):
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError("{} is not a tensor or list of tensors".format(loss_name))
+
+ loss = sum(_value for _key, _value in log_vars.items() if "loss" in _key)
+
+ log_vars["loss"] = loss
+ for name in log_vars:
+ log_vars[name] = log_vars[name].item()
+
+ return loss, log_vars
+
+
+def parse_second_losses(losses):
+
+ log_vars = OrderedDict()
+ loss = sum(losses["loss"])
+ for loss_name, loss_value in losses.items():
+ if loss_name == "loc_loss_elem":
+ log_vars[loss_name] = [[i.item() for i in j] for j in loss_value]
+ else:
+ log_vars[loss_name] = [i.item() for i in loss_value]
+
+ return loss, log_vars
+
+
+def batch_processor(model, data, train_mode, **kwargs):
+
+ if "local_rank" in kwargs:
+ device = torch.device(kwargs["local_rank"])
+ else:
+ device = None
+
+ # data = example_convert_to_torch(data, device=device)
+ example = example_to_device(data, device, non_blocking=False)
+
+ del data
+
+ if train_mode:
+ losses = model(example, return_loss=True)
+ loss, log_vars = parse_second_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(example["anchors"][0])
+ )
+ return outputs
+ else:
+ return model(example, return_loss=False)
+
+def batch_processor_ensemble(model1, model2, data, train_mode, **kwargs):
+ assert 0, 'deprecated'
+ if "local_rank" in kwargs:
+ device = torch.device(kwargs["local_rank"])
+ else:
+ device = None
+
+ assert train_mode is False
+
+ example = example_to_device(data, device, non_blocking=False)
+ del data
+
+ preds_dicts1 = model1.pred_hm(example)
+ preds_dicts2 = model2.pred_hm(example)
+
+ num_task = len(preds_dicts1)
+
+ merge_list = []
+
+ # take the average
+ for task_id in range(num_task):
+ preds_dict1 = preds_dicts1[task_id]
+ preds_dict2 = preds_dicts2[task_id]
+
+ for key in preds_dict1.keys():
+ preds_dict1[key] = (preds_dict1[key] + preds_dict2[key]) / 2
+
+ merge_list.append(preds_dict1)
+
+ # now get the final prediciton
+ return model1.pred_result(example, merge_list)
+
+
+def flatten_model(m):
+ return sum(map(flatten_model, m.children()), []) if len(list(m.children())) else [m]
+
+
+def get_layer_groups(m):
+ return [nn.Sequential(*flatten_model(m))]
+
+
+def build_one_cycle_optimizer(model, optimizer_config):
+ if optimizer_config.fixed_wd:
+ optimizer_func = partial(
+ torch.optim.Adam, betas=(0.9, 0.99), amsgrad=optimizer_config.amsgrad
+ )
+ else:
+ optimizer_func = partial(torch.optim.Adam, amsgrad=optimizer_cfg.amsgrad)
+
+ optimizer = OptimWrapper.create(
+ optimizer_func,
+ 3e-3, # TODO: CHECKING LR HERE !!!
+ get_layer_groups(model),
+ wd=optimizer_config.wd,
+ true_wd=optimizer_config.fixed_wd,
+ bn_wd=True,
+ )
+
+ return optimizer
+
+
+def build_optimizer(model, optimizer_cfg):
+ """Build optimizer from configs.
+ Args:
+ model (:obj:`nn.Module`): The model with parameters to be optimized.
+ optimizer_cfg (dict): The config dict of the optimizer.
+ Positional fields are:
+ - type: class name of the optimizer.
+ - lr: base learning rate.
+ Optional fields are:
+ - any arguments of the corresponding optimizer type, e.g.,
+ weight_decay, momentum, etc.
+ - paramwise_options: a dict with 3 accepted fileds
+ (bias_lr_mult, bias_decay_mult, norm_decay_mult).
+ `bias_lr_mult` and `bias_decay_mult` will be multiplied to
+ the lr and weight decay respectively for all bias parameters
+ (except for the normalization layers), and
+ `norm_decay_mult` will be multiplied to the weight decay
+ for all weight and bias parameters of normalization layers.
+ Returns:
+ torch.optim.Optimizer: The initialized optimizer.
+ """
+ if hasattr(model, "module"):
+ model = model.module
+
+ optimizer_cfg = optimizer_cfg.copy()
+ paramwise_options = optimizer_cfg.pop("paramwise_options", None)
+ # if no paramwise option is specified, just use the global setting
+ if paramwise_options is None:
+ return obj_from_dict(
+ optimizer_cfg, torch.optim, dict(params=model.parameters())
+ )
+ else:
+ assert isinstance(paramwise_options, dict)
+ # get base lr and weight decay
+ base_lr = optimizer_cfg["lr"]
+ base_wd = optimizer_cfg.get("weight_decay", None)
+ # weight_decay must be explicitly specified if mult is specified
+ if (
+ "bias_decay_mult" in paramwise_options
+ or "norm_decay_mult" in paramwise_options
+ ):
+ assert base_wd is not None
+ # get param-wise options
+ bias_lr_mult = paramwise_options.get("bias_lr_mult", 1.0)
+ bias_decay_mult = paramwise_options.get("bias_decay_mult", 1.0)
+ norm_decay_mult = paramwise_options.get("norm_decay_mult", 1.0)
+ # set param-wise lr and weight decay
+ params = []
+ for name, param in model.named_parameters():
+ param_group = {"params": [param]}
+ if not param.requires_grad:
+ # FP16 training needs to copy gradient/weight between master
+ # weight copy and model weight, it is convenient to keep all
+ # parameters here to align with model.parameters()
+ params.append(param_group)
+ continue
+
+ # for norm layers, overwrite the weight decay of weight and bias
+ # TODO: obtain the norm layer prefixes dynamically
+ if re.search(r"(bn|gn)(\d+)?.(weight|bias)", name):
+ if base_wd is not None:
+ param_group["weight_decay"] = base_wd * norm_decay_mult
+ # for other layers, overwrite both lr and weight decay of bias
+ elif name.endswith(".bias"):
+ param_group["lr"] = base_lr * bias_lr_mult
+ if base_wd is not None:
+ param_group["weight_decay"] = base_wd * bias_decay_mult
+ # otherwise use the global settings
+
+ params.append(param_group)
+
+ optimizer_cls = getattr(torch.optim, optimizer_cfg.pop("type"))
+ return optimizer_cls(params, **optimizer_cfg)
+
+
+def train_detector(model, dataset, cfg, distributed=False, validate=False, logger=None):
+ if logger is None:
+ logger = get_root_logger(cfg.log_level)
+
+ # start training
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ data_loaders = [
+ build_dataloader(
+ ds, cfg.data.samples_per_gpu, cfg.data.workers_per_gpu, dist=distributed
+ )
+ for ds in dataset
+ ]
+
+ total_steps = cfg.total_epochs * len(data_loaders[0])
+ # print(f"total_steps: {total_steps}")
+ if distributed:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ if cfg.lr_config.type == "one_cycle":
+ # build trainer
+ optimizer = build_one_cycle_optimizer(model, cfg.optimizer)
+ lr_scheduler = _create_learning_rate_scheduler(
+ optimizer, cfg.lr_config, total_steps
+ )
+ cfg.lr_config = None
+ else:
+ optimizer = build_optimizer(model, cfg.optimizer)
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.drop_step, gamma=.1)
+ # lr_scheduler = None
+ cfg.lr_config = None
+
+ # put model on gpus
+ if distributed:
+ model = DistributedDataParallel(
+ model.cuda(cfg.local_rank),
+ device_ids=[cfg.local_rank],
+ output_device=cfg.local_rank,
+ # broadcast_buffers=False,
+ find_unused_parameters=True,
+ )
+ else:
+ model = model.cuda()
+
+ logger.info(f"model structure: {model}")
+
+ trainer = Trainer(
+ model, batch_processor, optimizer, lr_scheduler, cfg.work_dir, cfg.log_level
+ )
+
+ if distributed:
+ optimizer_config = DistOptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ trainer.register_training_hooks(
+ cfg.lr_config, optimizer_config, cfg.checkpoint_config, cfg.log_config
+ )
+
+ if distributed:
+ trainer.register_hook(DistSamplerSeedHook())
+
+ if "disable_dbsampler_after_epoch" in cfg:
+ trainer.register_hook(DisableDBSamplerHook(cfg.disable_dbsampler_after_epoch))
+
+ # # register eval hooks
+ # if validate:
+ # val_dataset_cfg = cfg.data.val
+ # eval_cfg = cfg.get('evaluation', {})
+ # dataset_type = DATASETS.get(val_dataset_cfg.type)
+ # trainer.register_hook(
+ # KittiEvalmAPHookV2(val_dataset_cfg, **eval_cfg))
+
+ if cfg.resume_from:
+ trainer.resume(cfg.resume_from)
+ elif cfg.load_from:
+ trainer.load_checkpoint(cfg.load_from)
+
+ trainer.run(data_loaders, cfg.workflow, cfg.total_epochs, local_rank=cfg.local_rank)
diff --git a/det3d/torchie/cnn/__init__.py b/det3d/torchie/cnn/__init__.py
new file mode 100644
index 0000000..ef83ac5
--- /dev/null
+++ b/det3d/torchie/cnn/__init__.py
@@ -0,0 +1,25 @@
+from .alexnet import AlexNet
+from .resnet import ResNet, make_res_layer
+from .vgg import VGG, make_vgg_layer
+from .weight_init import (
+ caffe2_xavier_init,
+ constant_init,
+ kaiming_init,
+ normal_init,
+ uniform_init,
+ xavier_init,
+)
+
+__all__ = [
+ "AlexNet",
+ "VGG",
+ "make_vgg_layer",
+ "ResNet",
+ "make_res_layer",
+ "constant_init",
+ "xavier_init",
+ "normal_init",
+ "uniform_init",
+ "kaiming_init",
+ "caffe2_xavier_init",
+]
diff --git a/det3d/torchie/cnn/alexnet.py b/det3d/torchie/cnn/alexnet.py
new file mode 100644
index 0000000..28db641
--- /dev/null
+++ b/det3d/torchie/cnn/alexnet.py
@@ -0,0 +1,61 @@
+import logging
+
+import torch.nn as nn
+
+from ..trainer import load_checkpoint
+
+
+class AlexNet(nn.Module):
+ """AlexNet backbone.
+
+ Args:
+ num_classes (int): number of classes for classification.
+ """
+
+ def __init__(self, num_classes=-1):
+ super(AlexNet, self).__init__()
+ self.num_classes = num_classes
+ self.features = nn.Sequential(
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ )
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Dropout(),
+ nn.Linear(256 * 6 * 6, 4096),
+ nn.ReLU(inplace=True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(inplace=True),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ # use default initializer
+ pass
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ def forward(self, x):
+
+ x = self.features(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), 256 * 6 * 6)
+ x = self.classifier(x)
+
+ return x
diff --git a/det3d/torchie/cnn/resnet.py b/det3d/torchie/cnn/resnet.py
new file mode 100644
index 0000000..2af2aec
--- /dev/null
+++ b/det3d/torchie/cnn/resnet.py
@@ -0,0 +1,323 @@
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+
+from ..trainer import load_checkpoint
+from .weight_init import constant_init, kaiming_init
+
+
+def conv3x3(in_planes, out_planes, stride=1, dilation=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(
+ in_planes,
+ out_planes,
+ kernel_size=3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ )
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style="pytorch",
+ with_cp=False,
+ ):
+ super(BasicBlock, self).__init__()
+ self.conv1 = conv3x3(inplanes, planes, stride, dilation)
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = conv3x3(planes, planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ assert not with_cp
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style="pytorch",
+ with_cp=False,
+ ):
+ """Bottleneck block.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer,
+ if it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__()
+ assert style in ["pytorch", "caffe"]
+ if style == "pytorch":
+ conv1_stride = 1
+ conv2_stride = stride
+ else:
+ conv1_stride = stride
+ conv2_stride = 1
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=conv1_stride, bias=False
+ )
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False,
+ )
+
+ self.bn1 = nn.BatchNorm2d(planes)
+ self.bn2 = nn.BatchNorm2d(planes)
+ self.conv3 = nn.Conv2d(
+ planes, planes * self.expansion, kernel_size=1, bias=False
+ )
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ def forward(self, x):
+ def _inner_forward(x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+def make_res_layer(
+ block,
+ inplanes,
+ planes,
+ blocks,
+ stride=1,
+ dilation=1,
+ style="pytorch",
+ with_cp=False,
+):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes, planes, stride, dilation, downsample, style=style, with_cp=with_cp
+ )
+ )
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(inplanes, planes, 1, dilation, style=style, with_cp=with_cp)
+ )
+
+ return nn.Sequential(*layers)
+
+
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ num_stages (int): Resnet stages, normally 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ }
+
+ def __init__(
+ self,
+ depth,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style="pytorch",
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ with_cp=False,
+ ):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError("invalid depth {} for resnet".format(depth))
+ assert num_stages >= 1 and num_stages <= 4
+ block, stage_blocks = self.arch_settings[depth]
+ stage_blocks = stage_blocks[:num_stages]
+ assert len(strides) == len(dilations) == num_stages
+ assert max(out_indices) < num_stages
+
+ self.out_indices = out_indices
+ self.style = style
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+ self.with_cp = with_cp
+
+ self.inplanes = 64
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ planes = 64 * 2 ** i
+ res_layer = make_res_layer(
+ block,
+ self.inplanes,
+ planes,
+ num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ with_cp=with_cp,
+ )
+ self.inplanes = planes * block.expansion
+ layer_name = "layer{}".format(i + 1)
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self.feat_dim = block.expansion * 64 * 2 ** (len(stage_blocks) - 1)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(ResNet, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ if mode and self.frozen_stages >= 0:
+ for param in self.conv1.parameters():
+ param.requires_grad = False
+ for param in self.bn1.parameters():
+ param.requires_grad = False
+ self.bn1.eval()
+ self.bn1.weight.requires_grad = False
+ self.bn1.bias.requires_grad = False
+ for i in range(1, self.frozen_stages + 1):
+ mod = getattr(self, "layer{}".format(i))
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/det3d/torchie/cnn/vgg.py b/det3d/torchie/cnn/vgg.py
new file mode 100644
index 0000000..21e389c
--- /dev/null
+++ b/det3d/torchie/cnn/vgg.py
@@ -0,0 +1,171 @@
+import logging
+
+import torch.nn as nn
+
+from ..trainer import load_checkpoint
+from .weight_init import constant_init, kaiming_init, normal_init
+
+
+def conv3x3(in_planes, out_planes, dilation=1):
+ "3x3 convolution with padding"
+ return nn.Conv2d(
+ in_planes, out_planes, kernel_size=3, padding=dilation, dilation=dilation
+ )
+
+
+def make_vgg_layer(
+ inplanes, planes, num_blocks, dilation=1, with_bn=False, ceil_mode=False
+):
+ layers = []
+ for _ in range(num_blocks):
+ layers.append(conv3x3(inplanes, planes, dilation))
+ if with_bn:
+ layers.append(nn.BatchNorm2d(planes))
+ layers.append(nn.ReLU(inplace=True))
+ inplanes = planes
+ layers.append(nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=ceil_mode))
+
+ return layers
+
+
+class VGG(nn.Module):
+ """VGG backbone.
+
+ Args:
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ with_bn (bool): Use BatchNorm or not.
+ num_classes (int): number of classes for classification.
+ num_stages (int): VGG stages, normally 5.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+ running stats (mean and var).
+ bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+ """
+
+ arch_settings = {
+ 11: (1, 1, 2, 2, 2),
+ 13: (2, 2, 2, 2, 2),
+ 16: (2, 2, 3, 3, 3),
+ 19: (2, 2, 4, 4, 4),
+ }
+
+ def __init__(
+ self,
+ depth,
+ with_bn=False,
+ num_classes=-1,
+ num_stages=5,
+ dilations=(1, 1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3, 4),
+ frozen_stages=-1,
+ bn_eval=True,
+ bn_frozen=False,
+ ceil_mode=False,
+ with_last_pool=True,
+ ):
+ super(VGG, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError("invalid depth {} for vgg".format(depth))
+ assert num_stages >= 1 and num_stages <= 5
+ stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ assert len(dilations) == num_stages
+ assert max(out_indices) <= num_stages
+
+ self.num_classes = num_classes
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.bn_eval = bn_eval
+ self.bn_frozen = bn_frozen
+
+ self.inplanes = 3
+ start_idx = 0
+ vgg_layers = []
+ self.range_sub_modules = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ num_modules = num_blocks * (2 + with_bn) + 1
+ end_idx = start_idx + num_modules
+ dilation = dilations[i]
+ planes = 64 * 2 ** i if i < 4 else 512
+ vgg_layer = make_vgg_layer(
+ self.inplanes,
+ planes,
+ num_blocks,
+ dilation=dilation,
+ with_bn=with_bn,
+ ceil_mode=ceil_mode,
+ )
+ vgg_layers.extend(vgg_layer)
+ self.inplanes = planes
+ self.range_sub_modules.append([start_idx, end_idx])
+ start_idx = end_idx
+ if not with_last_pool:
+ vgg_layers.pop(-1)
+ self.range_sub_modules[-1][1] -= 1
+ self.module_name = "features"
+ self.add_module(self.module_name, nn.Sequential(*vgg_layers))
+
+ if self.num_classes > 0:
+ self.classifier = nn.Sequential(
+ nn.Linear(512 * 7 * 7, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, 4096),
+ nn.ReLU(True),
+ nn.Dropout(),
+ nn.Linear(4096, num_classes),
+ )
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ elif isinstance(m, nn.Linear):
+ normal_init(m, std=0.01)
+ else:
+ raise TypeError("pretrained must be a str or None")
+
+ def forward(self, x):
+ outs = []
+ vgg_layers = getattr(self, self.module_name)
+ for i, num_blocks in enumerate(self.stage_blocks):
+ for j in range(*self.range_sub_modules[i]):
+ vgg_layer = vgg_layers[j]
+ x = vgg_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ if self.num_classes > 0:
+ x = x.view(x.size(0), -1)
+ x = self.classifier(x)
+ outs.append(x)
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def train(self, mode=True):
+ super(VGG, self).train(mode)
+ if self.bn_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+ if self.bn_frozen:
+ for params in m.parameters():
+ params.requires_grad = False
+ vgg_layers = getattr(self, self.module_name)
+ if mode and self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ for j in range(*self.range_sub_modules[i]):
+ mod = vgg_layers[j]
+ mod.eval()
+ for param in mod.parameters():
+ param.requires_grad = False
diff --git a/det3d/torchie/cnn/weight_init.py b/det3d/torchie/cnn/weight_init.py
new file mode 100644
index 0000000..c876d79
--- /dev/null
+++ b/det3d/torchie/cnn/weight_init.py
@@ -0,0 +1,53 @@
+import torch.nn as nn
+
+
+def constant_init(module, val, bias=0):
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, "bias") and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def xavier_init(module, gain=1, bias=0, distribution="normal"):
+ assert distribution in ["uniform", "normal"]
+ if distribution == "uniform":
+ nn.init.xavier_uniform_(module.weight, gain=gain)
+ else:
+ nn.init.xavier_normal_(module.weight, gain=gain)
+ if hasattr(module, "bias") and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def normal_init(module, mean=0, std=1, bias=0):
+ nn.init.normal_(module.weight, mean, std)
+ if hasattr(module, "bias") and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def uniform_init(module, a=0, b=1, bias=0):
+ nn.init.uniform_(module.weight, a, b)
+ if hasattr(module, "bias") and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def kaiming_init(
+ module, a=0, mode="fan_out", nonlinearity="relu", bias=0, distribution="normal"
+):
+ assert distribution in ["uniform", "normal"]
+ if distribution == "uniform":
+ nn.init.kaiming_uniform_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity
+ )
+ else:
+ nn.init.kaiming_normal_(
+ module.weight, a=a, mode=mode, nonlinearity=nonlinearity
+ )
+ if hasattr(module, "bias") and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+
+def caffe2_xavier_init(module, bias=0):
+ # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
+ # Acknowledgment to FAIR's internal code
+ kaiming_init(
+ module, a=1, mode="fan_in", nonlinearity="leaky_relu", distribution="uniform"
+ )
diff --git a/det3d/torchie/fileio/__init__.py b/det3d/torchie/fileio/__init__.py
new file mode 100644
index 0000000..a93b684
--- /dev/null
+++ b/det3d/torchie/fileio/__init__.py
@@ -0,0 +1,15 @@
+from .io import load, dump, register_handler
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+from .parse import list_from_file, dict_from_file
+
+__all__ = [
+ "load",
+ "dump",
+ "register_handler",
+ "BaseFileHandler",
+ "JsonHandler",
+ "PickleHandler",
+ "YamlHandler",
+ "list_from_file",
+ "dict_from_file",
+]
diff --git a/det3d/torchie/fileio/handlers/__init__.py b/det3d/torchie/fileio/handlers/__init__.py
new file mode 100644
index 0000000..b808563
--- /dev/null
+++ b/det3d/torchie/fileio/handlers/__init__.py
@@ -0,0 +1,6 @@
+from .base import BaseFileHandler
+from .json_handler import JsonHandler
+from .pickle_handler import PickleHandler
+from .yaml_handler import YamlHandler
+
+__all__ = ["BaseFileHandler", "JsonHandler", "PickleHandler", "YamlHandler"]
diff --git a/det3d/torchie/fileio/handlers/base.py b/det3d/torchie/fileio/handlers/base.py
new file mode 100644
index 0000000..413502d
--- /dev/null
+++ b/det3d/torchie/fileio/handlers/base.py
@@ -0,0 +1,26 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseFileHandler(object):
+
+ __metaclass__ = ABCMeta # python 2 compatibility
+
+ @abstractmethod
+ def load_from_fileobj(self, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ pass
+
+ @abstractmethod
+ def dump_to_str(self, obj, **kwargs):
+ pass
+
+ def load_from_path(self, filepath, mode="r", **kwargs):
+ with open(filepath, mode) as f:
+ return self.load_from_fileobj(f, **kwargs)
+
+ def dump_to_path(self, obj, filepath, mode="w", **kwargs):
+ with open(filepath, mode) as f:
+ self.dump_to_fileobj(obj, f, **kwargs)
diff --git a/det3d/torchie/fileio/handlers/json_handler.py b/det3d/torchie/fileio/handlers/json_handler.py
new file mode 100644
index 0000000..567244b
--- /dev/null
+++ b/det3d/torchie/fileio/handlers/json_handler.py
@@ -0,0 +1,14 @@
+import json
+
+from .base import BaseFileHandler
+
+
+class JsonHandler(BaseFileHandler):
+ def load_from_fileobj(self, file):
+ return json.load(file)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ json.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ return json.dumps(obj, **kwargs)
diff --git a/det3d/torchie/fileio/handlers/pickle_handler.py b/det3d/torchie/fileio/handlers/pickle_handler.py
new file mode 100644
index 0000000..001da2d
--- /dev/null
+++ b/det3d/torchie/fileio/handlers/pickle_handler.py
@@ -0,0 +1,22 @@
+from six.moves import cPickle as pickle
+
+from .base import BaseFileHandler
+
+
+class PickleHandler(BaseFileHandler):
+ def load_from_fileobj(self, file, **kwargs):
+ return pickle.load(file, **kwargs)
+
+ def load_from_path(self, filepath, **kwargs):
+ return super(PickleHandler, self).load_from_path(filepath, mode="rb", **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault("protocol", 2)
+ return pickle.dumps(obj, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault("protocol", 2)
+ pickle.dump(obj, file, **kwargs)
+
+ def dump_to_path(self, obj, filepath, **kwargs):
+ super(PickleHandler, self).dump_to_path(obj, filepath, mode="wb", **kwargs)
diff --git a/det3d/torchie/fileio/handlers/yaml_handler.py b/det3d/torchie/fileio/handlers/yaml_handler.py
new file mode 100644
index 0000000..013ae01
--- /dev/null
+++ b/det3d/torchie/fileio/handlers/yaml_handler.py
@@ -0,0 +1,22 @@
+import yaml
+
+try:
+ from yaml import CLoader as Loader, CDumper as Dumper
+except ImportError:
+ from yaml import Loader, Dumper
+
+from .base import BaseFileHandler # isort:skip
+
+
+class YamlHandler(BaseFileHandler):
+ def load_from_fileobj(self, file, **kwargs):
+ kwargs.setdefault("Loader", Loader)
+ return yaml.load(file, **kwargs)
+
+ def dump_to_fileobj(self, obj, file, **kwargs):
+ kwargs.setdefault("Dumper", Dumper)
+ yaml.dump(obj, file, **kwargs)
+
+ def dump_to_str(self, obj, **kwargs):
+ kwargs.setdefault("Dumper", Dumper)
+ return yaml.dump(obj, **kwargs)
diff --git a/det3d/torchie/fileio/io.py b/det3d/torchie/fileio/io.py
new file mode 100644
index 0000000..349d104
--- /dev/null
+++ b/det3d/torchie/fileio/io.py
@@ -0,0 +1,110 @@
+from pathlib import Path
+
+from ..utils import is_list_of, is_str
+from .handlers import BaseFileHandler, JsonHandler, PickleHandler, YamlHandler
+
+file_handlers = {
+ "json": JsonHandler(),
+ "yaml": YamlHandler(),
+ "yml": YamlHandler(),
+ "pickle": PickleHandler(),
+ "pkl": PickleHandler(),
+}
+
+
+def load(file, file_format=None, **kwargs):
+ """Load data from json/yaml/pickle files.
+
+ This method provides a unified api for loading data from serialized files.
+
+ Args:
+ file (str or :obj:`Path` or file-like object): Filename or a file-like
+ object.
+ file_format (str, optional): If not specified, the file format will be
+ inferred from the file extension, otherwise use the specified one.
+ Currently supported formats include "json", "yaml/yml" and
+ "pickle/pkl".
+
+ Returns:
+ The content from the file.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None and is_str(file):
+ file_format = file.split(".")[-1]
+ if file_format not in file_handlers:
+ raise TypeError("Unsupported format: {}".format(file_format))
+
+ handler = file_handlers[file_format]
+ if is_str(file):
+ obj = handler.load_from_path(file, **kwargs)
+ elif hasattr(file, "read"):
+ obj = handler.load_from_fileobj(file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filepath str or a file-object')
+ return obj
+
+
+def dump(obj, file=None, file_format=None, **kwargs):
+ """Dump data to json/yaml/pickle strings or files.
+
+ This method provides a unified api for dumping data as strings or to files,
+ and also supports custom arguments for each file format.
+
+ Args:
+ obj (any): The python object to be dumped.
+ file (str or :obj:`Path` or file-like object, optional): If not
+ specified, then the object is dump to a str, otherwise to a file
+ specified by the filename or file-like object.
+ file_format (str, optional): Same as :func:`load`.
+
+ Returns:
+ bool: True for success, False otherwise.
+ """
+ if isinstance(file, Path):
+ file = str(file)
+ if file_format is None:
+ if is_str(file):
+ file_format = file.split(".")[-1]
+ elif file is None:
+ raise ValueError("file_format must be specified since file is None")
+ if file_format not in file_handlers:
+ raise TypeError("Unsupported format: {}".format(file_format))
+
+ handler = file_handlers[file_format]
+ if file is None:
+ return handler.dump_to_str(obj, **kwargs)
+ elif is_str(file):
+ handler.dump_to_path(obj, file, **kwargs)
+ elif hasattr(file, "write"):
+ handler.dump_to_fileobj(obj, file, **kwargs)
+ else:
+ raise TypeError('"file" must be a filename str or a file-object')
+
+
+def _register_handler(handler, file_formats):
+ """Register a handler for some file extensions.
+
+ Args:
+ handler (:obj:`BaseFileHandler`): Handler to be registered.
+ file_formats (str or list[str]): File formats to be handled by this
+ handler.
+ """
+ if not isinstance(handler, BaseFileHandler):
+ raise TypeError(
+ "handler must be a child of BaseFileHandler, not {}".format(type(handler))
+ )
+ if isinstance(file_formats, str):
+ file_formats = [file_formats]
+ if not is_list_of(file_formats, str):
+ raise TypeError("file_formats must be a str or a list of str")
+ for ext in file_formats:
+ file_handlers[ext] = handler
+
+
+def register_handler(file_formats, **kwargs):
+ def wrap(cls):
+ _register_handler(cls(**kwargs), file_formats)
+ return cls
+
+ return wrap
diff --git a/det3d/torchie/fileio/parse.py b/det3d/torchie/fileio/parse.py
new file mode 100644
index 0000000..9fd1e8a
--- /dev/null
+++ b/det3d/torchie/fileio/parse.py
@@ -0,0 +1,50 @@
+def list_from_file(filename, prefix="", offset=0, max_num=0):
+ """Load a text file and parse the content as a list of strings.
+
+ Args:
+ filename (str): Filename.
+ prefix (str): The prefix to be inserted to the begining of each item.
+ offset (int): The offset of lines.
+ max_num (int): The maximum number of lines to be read,
+ zeros and negatives mean no limitation.
+
+ Returns:
+ list[str]: A list of strings.
+ """
+ cnt = 0
+ item_list = []
+ with open(filename, "r") as f:
+ for _ in range(offset):
+ f.readline()
+ for line in f:
+ if max_num > 0 and cnt >= max_num:
+ break
+ item_list.append(prefix + line.rstrip("\n"))
+ cnt += 1
+ return item_list
+
+
+def dict_from_file(filename, key_type=str):
+ """Load a text file and parse the content as a dict.
+
+ Each line of the text file will be two or more columns splited by
+ whitespaces or tabs. The first column will be parsed as dict keys, and
+ the following columns will be parsed as dict values.
+
+ Args:
+ filename(str): Filename.
+ key_type(type): Type of the dict's keys. str is user by default and
+ type conversion will be performed if specified.
+
+ Returns:
+ dict: The parsed contents.
+ """
+ mapping = {}
+ with open(filename, "r") as f:
+ for line in f:
+ items = line.rstrip("\n").split()
+ assert len(items) >= 2
+ key = key_type(items[0])
+ val = items[1:] if len(items) > 2 else items[1]
+ mapping[key] = val
+ return mapping
diff --git a/det3d/torchie/parallel/__init__.py b/det3d/torchie/parallel/__init__.py
new file mode 100644
index 0000000..fdc3dea
--- /dev/null
+++ b/det3d/torchie/parallel/__init__.py
@@ -0,0 +1,15 @@
+from .collate import collate, collate_kitti
+from .data_container import DataContainer
+from .data_parallel import MegDataParallel
+from .distributed import MegDistributedDataParallel
+from .scatter_gather import scatter, scatter_kwargs
+
+__all__ = [
+ "collate",
+ "collate_kitti",
+ "DataContainer",
+ "MegDataParallel",
+ "MegDistributedDataParallel",
+ "scatter",
+ "scatter_kwargs",
+]
diff --git a/det3d/torchie/parallel/_functions.py b/det3d/torchie/parallel/_functions.py
new file mode 100644
index 0000000..cd5fdf4
--- /dev/null
+++ b/det3d/torchie/parallel/_functions.py
@@ -0,0 +1,74 @@
+import torch
+from torch.nn.parallel._functions import _get_stream
+
+
+def scatter(input, devices, streams=None):
+ """Scatters tensor across multiple GPUs.
+ """
+ if streams is None:
+ streams = [None] * len(devices)
+
+ if isinstance(input, list):
+ chunk_size = (len(input) - 1) // len(devices) + 1
+ outputs = [
+ scatter(input[i], [devices[i // chunk_size]], [streams[i // chunk_size]])
+ for i in range(len(input))
+ ]
+ return outputs
+ elif isinstance(input, torch.Tensor):
+ output = input.contiguous()
+ # TODO: copy to a pinned buffer first (if copying from CPU)
+ stream = streams[0] if output.numel() > 0 else None
+ with torch.cuda.device(devices[0]), torch.cuda.stream(stream):
+ output = output.cuda(devices[0], non_blocking=True)
+ return output
+ else:
+ raise Exception("Unknown type {}.".format(type(input)))
+
+
+def synchronize_stream(output, devices, streams):
+ if isinstance(output, list):
+ chunk_size = len(output) // len(devices)
+ for i in range(len(devices)):
+ for j in range(chunk_size):
+ synchronize_stream(
+ output[i * chunk_size + j], [devices[i]], [streams[i]]
+ )
+ elif isinstance(output, torch.Tensor):
+ if output.numel() != 0:
+ with torch.cuda.device(devices[0]):
+ main_stream = torch.cuda.current_stream()
+ main_stream.wait_stream(streams[0])
+ output.record_stream(main_stream)
+ else:
+ raise Exception("Unknown type {}.".format(type(output)))
+
+
+def get_input_device(input):
+ if isinstance(input, list):
+ for item in input:
+ input_device = get_input_device(item)
+ if input_device != -1:
+ return input_device
+ return -1
+ elif isinstance(input, torch.Tensor):
+ return input.get_device() if input.is_cuda else -1
+ else:
+ raise Exception("Unknown type {}.".format(type(input)))
+
+
+class Scatter(object):
+ @staticmethod
+ def forward(target_gpus, input):
+ input_device = get_input_device(input)
+ streams = None
+ if input_device == -1:
+ # Perform CPU to GPU copies in a background stream
+ streams = [_get_stream(device) for device in target_gpus]
+
+ outputs = scatter(input, target_gpus, streams)
+ # Synchronize with the copy stream
+ if streams is not None:
+ synchronize_stream(outputs, target_gpus, streams)
+
+ return tuple(outputs)
diff --git a/det3d/torchie/parallel/collate.py b/det3d/torchie/parallel/collate.py
new file mode 100644
index 0000000..d2af2f8
--- /dev/null
+++ b/det3d/torchie/parallel/collate.py
@@ -0,0 +1,165 @@
+import collections
+from collections import defaultdict
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.utils.data.dataloader import default_collate
+
+from .data_container import DataContainer
+
+
+def collate(batch, samples_per_gpu=1):
+ """Puts each data field into a tensor/DataContainer with outer dimension
+ batch size.
+
+ Extend default_collate to add support for
+ :type:`~torchie.parallel.DataContainer`. There are 3 cases.
+
+ 1. cpu_only = True, e.g., meta data
+ 2. cpu_only = False, stack = True, e.g., images tensors
+ 3. cpu_only = False, stack = False, e.g., gt bboxes
+ """
+
+ if not isinstance(batch, collections.Sequence):
+ raise TypeError("{} is not supported.".format(batch.dtype))
+
+ if isinstance(batch[0], DataContainer):
+ assert len(batch) % samples_per_gpu == 0
+ stacked = []
+ if batch[0].cpu_only:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i : i + samples_per_gpu]]
+ )
+ return DataContainer(
+ stacked, batch[0].stack, batch[0].padding_value, cpu_only=True
+ )
+ elif batch[0].stack:
+ for i in range(0, len(batch), samples_per_gpu):
+ assert isinstance(batch[i].data, torch.Tensor)
+
+ if batch[i].pad_dims is not None:
+ ndim = batch[i].dim()
+ assert ndim > batch[i].pad_dims
+ max_shape = [0 for _ in range(batch[i].pad_dims)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = batch[i].size(-dim)
+ for sample in batch[i : i + samples_per_gpu]:
+ for dim in range(0, ndim - batch[i].pad_dims):
+ assert batch[i].size(dim) == sample.size(dim)
+ for dim in range(1, batch[i].pad_dims + 1):
+ max_shape[dim - 1] = max(
+ max_shape[dim - 1], sample.size(-dim)
+ )
+ padded_samples = []
+ for sample in batch[i : i + samples_per_gpu]:
+ pad = [0 for _ in range(batch[i].pad_dims * 2)]
+ for dim in range(1, batch[i].pad_dims + 1):
+ pad[2 * dim - 1] = max_shape[dim - 1] - sample.size(-dim)
+ padded_samples.append(
+ F.pad(sample.data, pad, value=sample.padding_value)
+ )
+ stacked.append(default_collate(padded_samples))
+ elif batch[i].pad_dims is None:
+ stacked.append(
+ default_collate(
+ [sample.data for sample in batch[i : i + samples_per_gpu]]
+ )
+ )
+ else:
+ raise ValueError("pad_dims should be either None or integers (1-3)")
+
+ else:
+ for i in range(0, len(batch), samples_per_gpu):
+ stacked.append(
+ [sample.data for sample in batch[i : i + samples_per_gpu]]
+ )
+ return DataContainer(stacked, batch[0].stack, batch[0].padding_value)
+ elif isinstance(batch[0], collections.Sequence):
+ transposed = zip(*batch)
+ return [collate(samples, samples_per_gpu) for samples in transposed]
+ elif isinstance(batch[0], collections.Mapping):
+ return {
+ key: collate([d[key] for d in batch], samples_per_gpu) for key in batch[0]
+ }
+ else:
+ return default_collate(batch)
+
+
+
+def collate_kitti(batch_list, samples_per_gpu=1):
+ example_merged = collections.defaultdict(list)
+ for example in batch_list:
+ if type(example) is list:
+ for subexample in example:
+ for k, v in subexample.items():
+ example_merged[k].append(v)
+ else:
+ for k, v in example.items():
+ example_merged[k].append(v)
+ batch_size = len(example_merged['metadata'])
+ ret = {}
+ # voxel_nums_list = example_merged["num_voxels"]
+ # example_merged.pop("num_voxels")
+ for key, elems in example_merged.items():
+ if key in ["voxels", "num_points", "num_gt", "voxel_labels", "num_voxels",
+ "cyv_voxels", "cyv_num_points", "cyv_num_voxels"]:
+ ret[key] = torch.tensor(np.concatenate(elems, axis=0))
+ elif key in [
+ "gt_boxes",
+ ]:
+ task_max_gts = []
+ for task_id in range(len(elems[0])):
+ max_gt = 0
+ for k in range(batch_size):
+ max_gt = max(max_gt, len(elems[k][task_id]))
+ task_max_gts.append(max_gt)
+ res = []
+ for idx, max_gt in enumerate(task_max_gts):
+ batch_task_gt_boxes3d = np.zeros((batch_size, max_gt, 7))
+ for i in range(batch_size):
+ batch_task_gt_boxes3d[i, : len(elems[i][idx]), :] = elems[i][idx]
+ res.append(batch_task_gt_boxes3d)
+ ret[key] = res
+ elif key == "metadata":
+ ret[key] = elems
+ elif key == "calib":
+ ret[key] = {}
+ for elem in elems:
+ for k1, v1 in elem.items():
+ if k1 not in ret[key]:
+ ret[key][k1] = [v1]
+ else:
+ ret[key][k1].append(v1)
+ for k1, v1 in ret[key].items():
+ ret[key][k1] = torch.tensor(np.stack(v1, axis=0))
+ elif key in ["points"]:
+ ret[key] = [torch.tensor(elem) for elem in elems]
+ elif key in ["multi_points"]:
+ ret['points'] = [torch.tensor(frame) for elem in elems for frame in elem]
+ elif key in ["coordinates","cyv_coordinates"]:
+ coors = []
+ for i, coor in enumerate(elems):
+ coor_pad = np.pad(
+ coor, ((0, 0), (1, 0)), mode="constant", constant_values=i
+ )
+ coors.append(coor_pad)
+ ret[key] = torch.tensor(np.concatenate(coors, axis=0))
+ elif key in ["anchors", "anchors_mask", "reg_targets", "reg_weights", "labels", "hm", "anno_box",
+ "ind", "mask", "cat","corners"]:
+
+ ret[key] = defaultdict(list)
+ res = []
+ for elem in elems:
+ for idx, ele in enumerate(elem):
+ ret[key][str(idx)].append(torch.tensor(ele))
+ for kk, vv in ret[key].items():
+ res.append(torch.stack(vv))
+ ret[key] = res
+ elif key in ['gt_boxes_and_cls','gt_boxes_mask','gt_offset','gt_grid_offset','times']:
+ ret[key] = torch.tensor(np.stack(elems, axis=0))
+ else:
+ ret[key] = np.stack(elems, axis=0)
+
+ return ret
diff --git a/det3d/torchie/parallel/data_container.py b/det3d/torchie/parallel/data_container.py
new file mode 100644
index 0000000..46632eb
--- /dev/null
+++ b/det3d/torchie/parallel/data_container.py
@@ -0,0 +1,81 @@
+import functools
+
+import torch
+
+
+def assert_tensor_type(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ if not isinstance(args[0].data, torch.Tensor):
+ raise AttributeError(
+ "{} has no attribute {} for type {}".format(
+ args[0].__class__.__name__, func.__name__, args[0].datatype
+ )
+ )
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+class DataContainer(object):
+ """A container for any type of objects.
+
+ Typically tensors will be stacked in the collate function and sliced along
+ some dimension in the scatter function. This behavior has some limitations.
+ 1. All tensors have to be the same size.
+ 2. Types are limited (numpy array or Tensor).
+
+ We design `DataContainer` and `MMDataParallel` to overcome these
+ limitations. The behavior can be either of the following.
+
+ - copy to GPU, pad all tensors to the same size and stack them
+ - copy to GPU without stacking
+ - leave the objects as is and pass it to the model
+ - pad_dims specifies the number of last few dimensions to do padding
+ """
+
+ def __init__(self, data, stack=False, padding_value=0, cpu_only=False, pad_dims=2):
+ self._data = data
+ self._cpu_only = cpu_only
+ self._stack = stack
+ self._padding_value = padding_value
+ assert pad_dims in [None, 1, 2, 3]
+ self._pad_dims = pad_dims
+
+ def __repr__(self):
+ return "{}({})".format(self.__class__.__name__, repr(self.data))
+
+ @property
+ def data(self):
+ return self._data
+
+ @property
+ def datatype(self):
+ if isinstance(self.data, torch.Tensor):
+ return self.data.type()
+ else:
+ return type(self.data)
+
+ @property
+ def cpu_only(self):
+ return self._cpu_only
+
+ @property
+ def stack(self):
+ return self._stack
+
+ @property
+ def padding_value(self):
+ return self._padding_value
+
+ @property
+ def pad_dims(self):
+ return self._pad_dims
+
+ @assert_tensor_type
+ def size(self, *args, **kwargs):
+ return self.data.size(*args, **kwargs)
+
+ @assert_tensor_type
+ def dim(self):
+ return self.data.dim()
diff --git a/det3d/torchie/parallel/data_parallel.py b/det3d/torchie/parallel/data_parallel.py
new file mode 100644
index 0000000..fe2869f
--- /dev/null
+++ b/det3d/torchie/parallel/data_parallel.py
@@ -0,0 +1,8 @@
+from torch.nn.parallel import DataParallel
+
+from .scatter_gather import scatter_kwargs
+
+
+class MegDataParallel(DataParallel):
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
diff --git a/det3d/torchie/parallel/distributed.py b/det3d/torchie/parallel/distributed.py
new file mode 100644
index 0000000..51b22d9
--- /dev/null
+++ b/det3d/torchie/parallel/distributed.py
@@ -0,0 +1,45 @@
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from torch._utils import _flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors
+
+from .scatter_gather import scatter_kwargs
+
+
+class MegDistributedDataParallel(nn.Module):
+ def __init__(self, module, dim=0, broadcast_buffers=True, bucket_cap_mb=25):
+ super(MegDistributedDataParallel, self).__init__()
+ self.module = module
+ self.dim = dim
+ self.broadcast_buffers = broadcast_buffers
+
+ self.broadcast_bucket_size = bucket_cap_mb * 1024 * 1024
+ self._sync_params()
+
+ def _dist_broadcast_coalesced(self, tensors, buffer_size):
+ for tensors in _take_tensors(tensors, buffer_size):
+ flat_tensors = _flatten_dense_tensors(tensors)
+ dist.broadcast(flat_tensors, 0)
+ for tensor, synced in zip(
+ tensors, _unflatten_dense_tensors(flat_tensors, tensors)
+ ):
+ tensor.copy_(synced)
+
+ def _sync_params(self):
+ module_states = list(self.module.state_dict().values())
+ if len(module_states) > 0:
+ self._dist_broadcast_coalesced(module_states, self.broadcast_bucket_size)
+ if self.broadcast_buffers:
+ if torch.__version__ < "1.0":
+ buffers = [b.data for b in self.module._all_buffers()]
+ else:
+ buffers = [b.data for b in self.module.buffers()]
+ if len(buffers) > 0:
+ self._dist_broadcast_coalesced(buffers, self.broadcast_bucket_size)
+
+ def scatter(self, inputs, kwargs, device_ids):
+ return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
+
+ def forward(self, *inputs, **kwargs):
+ inputs, kwargs = self.scatter(inputs, kwargs, [torch.cuda.current_device()])
+ return self.module(*inputs[0], **kwargs[0])
diff --git a/det3d/torchie/parallel/scatter_gather.py b/det3d/torchie/parallel/scatter_gather.py
new file mode 100644
index 0000000..1ea64d3
--- /dev/null
+++ b/det3d/torchie/parallel/scatter_gather.py
@@ -0,0 +1,54 @@
+import torch
+from torch.nn.parallel._functions import Scatter as OrigScatter
+
+from ._functions import Scatter
+from .data_container import DataContainer
+
+
+def scatter(inputs, target_gpus, dim=0):
+ """Scatter inputs to target gpus.
+
+ The only difference from original :func:`scatter` is to add support for
+ :type:`~mmcv.parallel.DataContainer`.
+ """
+
+ def scatter_map(obj):
+ if isinstance(obj, torch.Tensor):
+ return OrigScatter.apply(target_gpus, None, dim, obj)
+ if isinstance(obj, DataContainer):
+ if obj.cpu_only:
+ return obj.data
+ else:
+ return Scatter.forward(target_gpus, obj.data)
+ if isinstance(obj, tuple) and len(obj) > 0:
+ return list(zip(*map(scatter_map, obj)))
+ if isinstance(obj, list) and len(obj) > 0:
+ out = list(map(list, zip(*map(scatter_map, obj))))
+ return out
+ if isinstance(obj, dict) and len(obj) > 0:
+ out = list(map(type(obj), zip(*map(scatter_map, obj.items()))))
+ return out
+ return [obj for targets in target_gpus]
+
+ # After scatter_map is called, a scatter_map cell will exist. This cell
+ # has a reference to the actual function scatter_map, which has references
+ # to a closure that has a reference to the scatter_map cell (because the
+ # fn is recursive). To avoid this reference cycle, we set the function to
+ # None, clearing the cell
+ try:
+ return scatter_map(inputs)
+ finally:
+ scatter_map = None
+
+
+def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
+ """Scatter with support for kwargs dictionary"""
+ inputs = scatter(inputs, target_gpus, dim) if inputs else []
+ kwargs = scatter(kwargs, target_gpus, dim) if kwargs else []
+ if len(inputs) < len(kwargs):
+ inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
+ elif len(kwargs) < len(inputs):
+ kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
+ inputs = tuple(inputs)
+ kwargs = tuple(kwargs)
+ return inputs, kwargs
diff --git a/det3d/torchie/trainer/__init__.py b/det3d/torchie/trainer/__init__.py
new file mode 100644
index 0000000..25a4c23
--- /dev/null
+++ b/det3d/torchie/trainer/__init__.py
@@ -0,0 +1,60 @@
+from .checkpoint import (
+ load_checkpoint,
+ load_state_dict,
+ save_checkpoint,
+ weights_to_cpu,
+)
+from .hooks import (
+ CheckpointHook,
+ ClosureHook,
+ DisableDBSamplerHook,
+ DistSamplerSeedHook,
+ Hook,
+ IterTimerHook,
+ LoggerHook,
+ LrUpdaterHook,
+ OptimizerHook,
+ PaviLoggerHook,
+ TensorboardLoggerHook,
+ TextLoggerHook,
+)
+from .log_buffer import LogBuffer
+from .parallel_test import parallel_test
+from .priority import Priority, get_priority
+from .trainer import Trainer
+from .utils import (
+ get_dist_info,
+ get_host_info,
+ get_time_str,
+ master_only,
+ obj_from_dict,
+)
+
+__all__ = [
+ "Trainer",
+ "LogBuffer",
+ "Hook",
+ "CheckpointHook",
+ "ClosureHook",
+ "LrUpdaterHook",
+ "OptimizerHook",
+ "IterTimerHook",
+ "DisableDBSamplerHook",
+ "DistSamplerSeedHook",
+ "LoggerHook",
+ "TextLoggerHook",
+ "PaviLoggerHook",
+ "TensorboardLoggerHook",
+ "load_state_dict",
+ "load_checkpoint",
+ "weights_to_cpu",
+ "save_checkpoint",
+ "parallel_test",
+ "Priority",
+ "get_priority",
+ "get_host_info",
+ "get_dist_info",
+ "master_only",
+ "get_time_str",
+ "obj_from_dict",
+]
diff --git a/det3d/torchie/trainer/checkpoint.py b/det3d/torchie/trainer/checkpoint.py
new file mode 100644
index 0000000..728d730
--- /dev/null
+++ b/det3d/torchie/trainer/checkpoint.py
@@ -0,0 +1,216 @@
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+
+import torch
+import torchvision
+from det3d import torchie
+from terminaltables import AsciiTable
+from torch.utils import model_zoo
+
+from .utils import get_dist_info
+
+open_mmlab_model_urls = {
+ "vgg16_caffe": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth", # noqa: E501
+ "resnet50_caffe": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth", # noqa: E501
+ "resnet101_caffe": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth", # noqa: E501
+ "resnext50_32x4d": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50-32x4d-0ab1a123.pth", # noqa: E501
+ "resnext101_32x4d": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth", # noqa: E501
+ "resnext101_64x4d": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth", # noqa: E501
+ "contrib/resnet50_gn": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth", # noqa: E501
+ "detectron/resnet50_gn": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn-9186a21c.pth", # noqa: E501
+ "detectron/resnet101_gn": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn-cac0ab98.pth", # noqa: E501
+ "jhu/resnet50_gn_ws": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_ws-15beedd8.pth", # noqa: E501
+ "jhu/resnet101_gn_ws": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth", # noqa: E501
+ "jhu/resnext50_32x4d_gn_ws": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth", # noqa: E501
+ "jhu/resnext101_32x4d_gn_ws": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth", # noqa: E501
+ "jhu/resnext50_32x4d_gn": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth", # noqa: E501
+ "jhu/resnext101_32x4d_gn": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth", # noqa: E501
+ "msra/hrnetv2_w18": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w18-00eb2006.pth", # noqa: E501
+ "msra/hrnetv2_w32": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth", # noqa: E501
+ "msra/hrnetv2_w40": "https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w40-ed0b031c.pth", # noqa: E501
+ "bninception_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth", # noqa: E501
+ "kin400/i3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth", # noqa: E501
+ "kin400/nl3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth", # noqa: E501
+} # yapf: disable
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict into a module
+ """
+ unexpected_keys = []
+ shape_mismatch_pairs = []
+
+ own_state = module.state_dict()
+ for name, param in state_dict.items():
+ # a hacky fixed to load a new voxelnet
+ if name not in own_state:
+ unexpected_keys.append(name)
+ continue
+ if isinstance(param, torch.nn.Parameter):
+ # backwards compatibility for serialized parameters
+ param = param.data
+ if param.size() != own_state[name].size():
+ shape_mismatch_pairs.append([name, own_state[name].size(), param.size()])
+ continue
+ own_state[name].copy_(param)
+
+ all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
+
+ err_msg = []
+ if unexpected_keys:
+ err_msg.append(
+ "unexpected key in source state_dict: {}\n".format(
+ ", ".join(unexpected_keys)
+ )
+ )
+ if missing_keys:
+ err_msg.append(
+ "missing keys in source state_dict: {}\n".format(", ".join(missing_keys))
+ )
+ if shape_mismatch_pairs:
+ mismatch_info = "these keys have mismatched shape:\n"
+ header = ["key", "expected shape", "loaded shape"]
+ table_data = [header] + shape_mismatch_pairs
+ table = AsciiTable(table_data)
+ err_msg.append(mismatch_info + table.table)
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
+ err_msg = "\n".join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def load_url_dist(url):
+ """ In distributed setting, this function only download checkpoint at
+ local rank 0 """
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get("LOCAL_RANK", rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(url)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(url)
+ return checkpoint
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module("torchvision.models.{}".format(name))
+ if hasattr(_zoo, "model_urls"):
+ _urls = getattr(_zoo, "model_urls")
+ model_urls.update(_urls)
+ return model_urls
+
+
+def load_checkpoint(model, filename, map_location=None, strict=False, logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ # load checkpoint from modelzoo or file or url
+ if filename.startswith("modelzoo://"):
+ warnings.warn(
+ 'The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead'
+ )
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith("torchvision://"):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith("open-mmlab://"):
+ model_name = filename[13:]
+ checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
+ elif filename.startswith(("http://", "https://")):
+ checkpoint = load_url_dist(filename)
+ else:
+ if not osp.isfile(filename):
+ raise IOError("{} is not a checkpoint file".format(filename))
+ checkpoint = torch.load(filename, map_location=map_location)
+ # get state_dict from checkpoint
+ if isinstance(checkpoint, OrderedDict):
+ state_dict = checkpoint
+ elif isinstance(checkpoint, dict) and "state_dict" in checkpoint:
+ state_dict = checkpoint["state_dict"]
+ else:
+ raise RuntimeError("No state_dict found in checkpoint file {}".format(filename))
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith("module."):
+ state_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
+ # load state_dict
+ if hasattr(model, "module"):
+ load_state_dict(model.module, state_dict, strict, logger)
+ else:
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on GPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError("meta must be a dict or None, but got {}".format(type(meta)))
+
+ torchie.mkdir_or_exist(osp.dirname(filename))
+ if hasattr(model, "module"):
+ model = model.module
+
+ checkpoint = {"meta": meta, "state_dict": weights_to_cpu(model.state_dict())}
+ if optimizer is not None:
+ checkpoint["optimizer"] = optimizer.state_dict()
+
+ torch.save(checkpoint, filename)
diff --git a/det3d/torchie/trainer/hooks/__init__.py b/det3d/torchie/trainer/hooks/__init__.py
new file mode 100644
index 0000000..f3b27ff
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/__init__.py
@@ -0,0 +1,26 @@
+from .checkpoint import CheckpointHook
+from .closure import ClosureHook
+from .hook import Hook
+from .iter_timer import IterTimerHook
+from .logger import LoggerHook, PaviLoggerHook, TensorboardLoggerHook, TextLoggerHook
+from .lr_updater import LrUpdaterHook
+from .memory import EmptyCacheHook
+from .optimizer import OptimizerHook
+from .sampler_seed import DistSamplerSeedHook
+from .disable_dbsampler import DisableDBSamplerHook
+
+__all__ = [
+ "Hook",
+ "CheckpointHook",
+ "ClosureHook",
+ "LrUpdaterHook",
+ "OptimizerHook",
+ "IterTimerHook",
+ "DisableDBSamplerHook",
+ "DistSamplerSeedHook",
+ "EmptyCacheHook",
+ "LoggerHook",
+ "TextLoggerHook",
+ "PaviLoggerHook",
+ "TensorboardLoggerHook",
+]
diff --git a/det3d/torchie/trainer/hooks/checkpoint.py b/det3d/torchie/trainer/hooks/checkpoint.py
new file mode 100644
index 0000000..10052dd
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/checkpoint.py
@@ -0,0 +1,31 @@
+from ..utils import master_only
+from .hook import Hook
+
+
+class CheckpointHook(Hook):
+ def __init__(self, interval=1, save_optimizer=True, out_dir=None, **kwargs):
+ self.interval = interval
+ self.save_optimizer = save_optimizer
+ self.out_dir = out_dir
+ self.args = kwargs
+
+ @master_only
+ def after_train_epoch(self, trainer):
+ if not self.every_n_epochs(trainer, self.interval):
+ return
+
+ if not self.out_dir:
+ self.out_dir = trainer.work_dir
+
+ trainer.save_checkpoint(
+ self.out_dir, save_optimizer=self.save_optimizer, **self.args
+ )
+
+ @master_only
+ def before_run(self, trainer):
+ if not self.out_dir:
+ self.out_dir = trainer.work_dir
+
+ trainer.save_checkpoint(
+ self.out_dir, filename_tmpl="pre_epoch_{}.pth", save_optimizer=self.save_optimizer, **self.args
+ )
\ No newline at end of file
diff --git a/det3d/torchie/trainer/hooks/closure.py b/det3d/torchie/trainer/hooks/closure.py
new file mode 100644
index 0000000..8af5421
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/closure.py
@@ -0,0 +1,8 @@
+from .hook import Hook
+
+
+class ClosureHook(Hook):
+ def __init__(self, fn_name, fn):
+ assert hasattr(self, fn_name)
+ assert callable(fn)
+ setattr(self, fn_name, fn)
diff --git a/det3d/torchie/trainer/hooks/disable_dbsampler.py b/det3d/torchie/trainer/hooks/disable_dbsampler.py
new file mode 100644
index 0000000..421c6a8
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/disable_dbsampler.py
@@ -0,0 +1,12 @@
+from .hook import Hook
+
+
+class DisableDBSamplerHook(Hook):
+ def __init__(self, disable_dbsampler_after_epoch):
+ self.disable_dbsampler_after_epoch = disable_dbsampler_after_epoch
+
+ def before_epoch(self, trainer):
+ if trainer.epoch >= self.disable_dbsampler_after_epoch:
+ for pipeline in trainer.data_loader.dataset.pipeline.transforms:
+ if "db_sampler" in dir(pipeline):
+ pipeline.db_sampler = None
diff --git a/det3d/torchie/trainer/hooks/hook.py b/det3d/torchie/trainer/hooks/hook.py
new file mode 100644
index 0000000..d4b2950
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/hook.py
@@ -0,0 +1,63 @@
+class Hook(object):
+ def before_run(self, trainer):
+ pass
+
+ def after_run(self, trainer):
+ pass
+
+ def before_epoch(self, trainer):
+ pass
+
+ def after_epoch(self, trainer):
+ pass
+
+ def before_iter(self, trainer):
+ pass
+
+ def after_iter(self, trainer):
+ pass
+
+ def after_data_to_device(self, trainer):
+ pass
+
+ def after_forward(self, trainer):
+ pass
+
+ def after_parse_loss(self, trainer):
+ pass
+
+ def before_train_epoch(self, trainer):
+ self.before_epoch(trainer)
+
+ def before_val_epoch(self, trainer):
+ self.before_epoch(trainer)
+
+ def after_train_epoch(self, trainer):
+ self.after_epoch(trainer)
+
+ def after_val_epoch(self, trainer):
+ self.after_epoch(trainer)
+
+ def before_train_iter(self, trainer):
+ self.before_iter(trainer)
+
+ def before_val_iter(self, trainer):
+ self.before_iter(trainer)
+
+ def after_train_iter(self, trainer):
+ self.after_iter(trainer)
+
+ def after_val_iter(self, trainer):
+ self.after_iter(trainer)
+
+ def every_n_epochs(self, trainer, n):
+ return (trainer.epoch + 1) % n == 0 if n > 0 else False
+
+ def every_n_iters(self, trainer, n):
+ return (trainer.iter + 1) % n == 0 if n > 0 else False
+
+ def every_n_inner_iters(self, trainer, n):
+ return (trainer.inner_iter + 1) % n == 0 if n > 0 else False
+
+ def end_of_epoch(self, trainer):
+ return trainer.inner_iter + 1 == len(trainer.data_loader)
diff --git a/det3d/torchie/trainer/hooks/iter_timer.py b/det3d/torchie/trainer/hooks/iter_timer.py
new file mode 100644
index 0000000..0b95160
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/iter_timer.py
@@ -0,0 +1,24 @@
+import time
+
+from .hook import Hook
+
+
+class IterTimerHook(Hook):
+ def before_epoch(self, runner):
+ self.t = time.time()
+
+ def before_iter(self, runner):
+ runner.log_buffer.update({"data_time": time.time() - self.t})
+
+ def after_iter(self, runner):
+ runner.log_buffer.update({"time": time.time() - self.t})
+ self.t = time.time()
+
+ def after_data_to_device(self, runner):
+ runner.log_buffer.update({"transfer_time": time.time() - self.t})
+
+ def after_forward(self, runner):
+ runner.log_buffer.update({"forward_time": time.time() - self.t})
+
+ def after_parse_loss(self, runner):
+ runner.log_buffer.update({"loss_parse_time": time.time() - self.t})
diff --git a/det3d/torchie/trainer/hooks/logger/__init__.py b/det3d/torchie/trainer/hooks/logger/__init__.py
new file mode 100644
index 0000000..6da5ee1
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/logger/__init__.py
@@ -0,0 +1,6 @@
+from .base import LoggerHook
+from .pavi import PaviLoggerHook
+from .tensorboard import TensorboardLoggerHook
+from .text import TextLoggerHook
+
+__all__ = ["LoggerHook", "TextLoggerHook", "PaviLoggerHook", "TensorboardLoggerHook"]
diff --git a/det3d/torchie/trainer/hooks/logger/base.py b/det3d/torchie/trainer/hooks/logger/base.py
new file mode 100644
index 0000000..08f5bfe
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/logger/base.py
@@ -0,0 +1,57 @@
+from abc import ABCMeta, abstractmethod
+
+from ..hook import Hook
+
+
+class LoggerHook(Hook):
+ """Base class for logger hooks
+
+ Args:
+ interval (int)
+ ignore_last (bool)
+ reset_flag (bool)
+ """
+
+ __metaclass__ = ABCMeta
+
+ def __init__(self, interval=10, ignore_last=True, reset_flag=False):
+ self.interval = interval
+ self.ignore_last = ignore_last
+ self.reset_flag = reset_flag
+
+ @abstractmethod
+ def log(self, trainer):
+ pass
+
+ def before_run(self, trainer):
+ for hook in trainer.hooks[::-1]:
+ if isinstance(hook, LoggerHook):
+ hook.reset_flag = True
+ break
+
+ def before_epoch(self, trainer):
+ trainer.log_buffer.clear()
+
+ def after_train_iter(self, trainer):
+ if self.every_n_inner_iters(trainer, self.interval):
+ trainer.log_buffer.average(self.interval)
+ elif self.end_of_epoch(trainer) and not self.ignore_last:
+ # not precise but more stable
+ trainer.log_buffer.average(self.interval)
+
+ if trainer.log_buffer.ready:
+ self.log(trainer)
+ if self.reset_flag:
+ trainer.log_buffer.clear_output()
+
+ def after_train_epoch(self, trainer):
+ if trainer.log_buffer.ready:
+ self.log(trainer)
+ if self.reset_flag:
+ trainer.log_buffer.clear_output()
+
+ def after_val_epoch(self, trainer):
+ trainer.log_buffer.average()
+ self.log(trainer)
+ if self.reset_flag:
+ trainer.log_buffer.clear_output()
diff --git a/det3d/torchie/trainer/hooks/logger/pavi.py b/det3d/torchie/trainer/hooks/logger/pavi.py
new file mode 100644
index 0000000..9d91c27
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/logger/pavi.py
@@ -0,0 +1,177 @@
+from __future__ import print_function
+
+import logging
+import os
+import os.path as osp
+import time
+from datetime import datetime
+from threading import Thread
+
+import requests
+from six.moves.queue import Empty, Queue
+
+from ...utils import get_host_info, master_only
+from .base import LoggerHook
+
+
+class PaviClient(object):
+ def __init__(self, url, username=None, password=None, instance_id=None):
+ self.url = url
+ self.username = self._get_env_var(username, "PAVI_USERNAME")
+ self.password = self._get_env_var(password, "PAVI_PASSWORD")
+ self.instance_id = instance_id
+ self.log_queue = None
+ self.logger = None
+
+ def _get_env_var(self, var, env_var):
+ if var is not None:
+ return str(var)
+
+ var = os.getenv(env_var)
+ if not var:
+ raise ValueError(
+ '"{}" is neither specified nor defined as env variables'.format(env_var)
+ )
+ return var
+
+ def _print_log(self, msg, level=logging.INFO, *args, **kwargs):
+ if self.logger is not None:
+ self.logger.log(level, msg, *args, **kwargs)
+ else:
+ print(msg, *args, **kwargs)
+
+ def connect(self, model_name, work_dir=None, info=dict(), timeout=5, logger=None):
+ if logger is not None:
+ self.logger = logger
+ self._print_log("connecting pavi service {}...".format(self.url))
+ post_data = dict(
+ time=str(datetime.now()),
+ username=self.username,
+ password=self.password,
+ instance_id=self.instance_id,
+ model=model_name,
+ work_dir=osp.abspath(work_dir) if work_dir else "",
+ session_file=info.get("session_file", ""),
+ session_text=info.get("session_text", ""),
+ model_text=info.get("model_text", ""),
+ device=get_host_info(),
+ )
+ try:
+ response = requests.post(self.url, json=post_data, timeout=timeout)
+ except Exception as ex:
+ self._print_log(
+ "fail to connect to pavi service: {}".format(ex), level=logging.ERROR
+ )
+ else:
+ if response.status_code == 200:
+ self.instance_id = response.text
+ self._print_log(
+ "pavi service connected, instance_id: {}".format(self.instance_id)
+ )
+ self.log_queue = Queue()
+ self.log_thread = Thread(target=self.post_worker_fn)
+ self.log_thread.daemon = True
+ self.log_thread.start()
+ return True
+ else:
+ self._print_log(
+ "fail to connect to pavi service, status code: "
+ "{}, err message: {}".format(response.status_code, response.reason),
+ level=logging.ERROR,
+ )
+ return False
+
+ def post_worker_fn(self, max_retry=3, queue_timeout=1, req_timeout=3):
+ while True:
+ try:
+ log = self.log_queue.get(timeout=queue_timeout)
+ except Empty:
+ time.sleep(1)
+ except Exception as ex:
+ self._print_log(
+ "fail to get logs from queue: {}".format(ex), level=logging.ERROR
+ )
+ else:
+ retry = 0
+ while retry < max_retry:
+ try:
+ response = requests.post(
+ self.url, json=log, timeout=req_timeout
+ )
+ except Exception as ex:
+ retry += 1
+ self._print_log(
+ "error when posting logs to pavi: {}".format(ex),
+ level=logging.ERROR,
+ )
+ else:
+ status_code = response.status_code
+ if status_code == 200:
+ break
+ else:
+ self._print_log(
+ "unexpected status code: {}, err msg: {}".format(
+ status_code, response.reason
+ ),
+ level=logging.ERROR,
+ )
+ retry += 1
+ if retry == max_retry:
+ self._print_log(
+ "fail to send logs of iteration {}".format(log["iter_num"]),
+ level=logging.ERROR,
+ )
+
+ def log(self, phase, iter, outputs):
+ if self.log_queue is not None:
+ logs = {
+ "time": str(datetime.now()),
+ "instance_id": self.instance_id,
+ "flow_id": phase,
+ "iter_num": iter,
+ "outputs": outputs,
+ "msg": "",
+ }
+ self.log_queue.put(logs)
+
+
+class PaviLoggerHook(LoggerHook):
+ def __init__(
+ self,
+ url,
+ username=None,
+ password=None,
+ instance_id=None,
+ config_file=None,
+ interval=10,
+ ignore_last=True,
+ reset_flag=True,
+ ):
+ self.pavi = PaviClient(url, username, password, instance_id)
+ self.config_file = config_file
+ super(PaviLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+
+ def before_run(self, runner):
+ super(PaviLoggerHook, self).before_run(runner)
+ self.connect(runner)
+
+ @master_only
+ def connect(self, runner, timeout=5):
+ cfg_info = dict()
+ if self.config_file is not None:
+ with open(self.config_file, "r") as f:
+ config_text = f.read()
+ cfg_info.update(session_file=self.config_file, session_text=config_text)
+ return self.pavi.connect(
+ runner.model_name, runner.work_dir, cfg_info, timeout, runner.logger
+ )
+
+ @master_only
+ def log(self, runner):
+ log_outs = runner.log_buffer.output.copy()
+ log_outs.pop("time", None)
+ log_outs.pop("data_time", None)
+ for k, v in log_outs.items():
+ if isinstance(v, str):
+ log_outs.pop(k)
+ self.pavi.log(runner.mode, runner.iter + 1, log_outs)
diff --git a/det3d/torchie/trainer/hooks/logger/tensorboard.py b/det3d/torchie/trainer/hooks/logger/tensorboard.py
new file mode 100644
index 0000000..2ad176f
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/logger/tensorboard.py
@@ -0,0 +1,53 @@
+import os.path as osp
+
+import torch
+
+from ...utils import master_only
+from .base import LoggerHook
+
+
+class TensorboardLoggerHook(LoggerHook):
+ def __init__(self, log_dir=None, interval=10, ignore_last=True, reset_flag=True):
+ super(TensorboardLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+ self.log_dir = log_dir
+
+ @master_only
+ def before_run(self, trainer):
+ if torch.__version__ >= "1.1":
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install future tensorboard" to install '
+ "the dependencies to use torch.utils.tensorboard "
+ "(applicable to PyTorch 1.1 or higher)"
+ )
+ else:
+ try:
+ from tensorboardX import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ "Please install tensorboardX to use " "TensorboardLoggerHook."
+ )
+
+ if self.log_dir is None:
+ self.log_dir = osp.join(trainer.work_dir, "tf_logs")
+ self.writer = SummaryWriter(self.log_dir)
+
+ @master_only
+ def log(self, trainer):
+ for var in trainer.log_buffer.output:
+ if var in ["time", "data_time"]:
+ continue
+ tag = "{}/{}".format(var, trainer.mode)
+ record = trainer.log_buffer.output[var]
+ if isinstance(record, str):
+ self.writer.add_text(tag, record, trainer.iter)
+ else:
+ self.writer.add_scalar(
+ tag, trainer.log_buffer.output[var], trainer.iter
+ )
+
+ @master_only
+ def after_run(self, trainer):
+ self.writer.close()
diff --git a/det3d/torchie/trainer/hooks/logger/text.py b/det3d/torchie/trainer/hooks/logger/text.py
new file mode 100644
index 0000000..32f3857
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/logger/text.py
@@ -0,0 +1,150 @@
+import datetime
+import os.path as osp
+from collections import OrderedDict
+
+import torch
+import torch.distributed as dist
+from det3d import torchie
+
+from .base import LoggerHook
+
+
+class TextLoggerHook(LoggerHook):
+ def __init__(self, interval=10, ignore_last=True, reset_flag=False):
+ super(TextLoggerHook, self).__init__(interval, ignore_last, reset_flag)
+ self.time_sec_tot = 0
+
+ def before_run(self, trainer):
+ super(TextLoggerHook, self).before_run(trainer)
+ self.start_iter = trainer.iter
+ self.json_log_path = osp.join(
+ trainer.work_dir, "{}.log.json".format(trainer.timestamp)
+ )
+
+ def _get_max_memory(self, trainer):
+ mem = torch.cuda.max_memory_allocated()
+ mem_mb = torch.tensor(
+ [mem / (1024 * 1024)], dtype=torch.int, device=torch.device("cuda")
+ )
+ if trainer.world_size > 1:
+ dist.reduce(mem_mb, 0, op=dist.ReduceOp.MAX)
+ return mem_mb.item()
+
+ def _convert_to_precision4(self, val):
+ if isinstance(val, float):
+ val = "{:.4f}".format(val)
+ elif isinstance(val, list):
+ val = [self._convert_to_precision4(v) for v in val]
+
+ return val
+
+ def _log_info(self, log_dict, trainer):
+ if trainer.mode == "train":
+ log_str = "Epoch [{}/{}][{}/{}]\tlr: {:.5f}, ".format(
+ log_dict["epoch"],
+ trainer._max_epochs,
+ log_dict["iter"],
+ len(trainer.data_loader),
+ log_dict["lr"],
+ )
+ if "time" in log_dict.keys():
+ self.time_sec_tot += log_dict["time"] * self.interval
+ time_sec_avg = self.time_sec_tot / (trainer.iter - self.start_iter + 1)
+ eta_sec = time_sec_avg * (trainer.max_iters - trainer.iter - 1)
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
+ log_str += "eta: {}, ".format(eta_str)
+ log_str += "time: {:.3f}, data_time: {:.3f}, transfer_time: {:.3f}, forward_time: {:.3f}, loss_parse_time: {:.3f} ".format(
+ log_dict["time"],
+ log_dict["data_time"],
+ log_dict["transfer_time"] - log_dict["data_time"],
+ log_dict["forward_time"] - log_dict["transfer_time"],
+ log_dict["loss_parse_time"] - log_dict["forward_time"],
+ )
+ log_str += "memory: {}, ".format(log_dict["memory"])
+ else:
+ log_str = "Epoch({}) [{}][{}]\t".format(
+ log_dict["mode"], log_dict["epoch"] - 1, log_dict["iter"]
+ )
+
+ trainer.logger.info(log_str)
+
+ if trainer.world_size > 1:
+ class_names = trainer.model.module.bbox_head.class_names
+ else:
+ class_names = trainer.model.bbox_head.class_names
+
+ for idx, task_class_names in enumerate(class_names):
+ log_items = [f"task : {task_class_names}"]
+ log_str = ""
+ for name, val in log_dict.items():
+ # TODO:
+ if name in [
+ "mode",
+ "Epoch",
+ "iter",
+ "lr",
+ "time",
+ "data_time",
+ "memory",
+ "epoch",
+ "transfer_time",
+ "forward_time",
+ "loss_parse_time",
+ ]:
+ continue
+
+ if isinstance(val, float):
+ val = "{:.4f}".format(val)
+
+ if isinstance(val, list):
+ log_items.append(
+ "{}: {}".format(name, self._convert_to_precision4(val[idx]))
+ )
+ else:
+ log_items.append("{}: {}".format(name, val))
+
+ log_str += ", ".join(log_items)
+ if idx == (len(class_names) - 1):
+ log_str += "\n"
+ trainer.logger.info(log_str)
+
+ def _dump_log(self, log_dict, trainer):
+ json_log = OrderedDict()
+ for k, v in log_dict.items():
+ json_log[k] = self._round_float(v)
+
+ if trainer.rank == 0:
+ with open(self.json_log_path, "a+") as f:
+ torchie.dump(json_log, f, file_format="json")
+ f.write("\n")
+
+ def _round_float(self, items):
+ if isinstance(items, list):
+ return [self._round_float(item) for item in items]
+ elif isinstance(items, float):
+ return round(items, 5)
+ else:
+ return items
+
+ def log(self, trainer):
+ log_dict = OrderedDict()
+ # Training mode if the output contains the key time
+ mode = "train" if "time" in trainer.log_buffer.output else "val"
+ log_dict["mode"] = mode
+ log_dict["epoch"] = trainer.epoch + 1
+ log_dict["iter"] = trainer.inner_iter + 1
+ # Only record lr of the first param group
+ log_dict["lr"] = trainer.current_lr()[0]
+ if mode == "train":
+ log_dict["time"] = trainer.log_buffer.output["time"]
+ log_dict["data_time"] = trainer.log_buffer.output["data_time"]
+ # statistic memory
+ if torch.cuda.is_available():
+ log_dict["memory"] = self._get_max_memory(trainer)
+ for name, val in trainer.log_buffer.output.items():
+ if name in ["time", "data_time"]:
+ continue
+ log_dict[name] = val
+
+ self._log_info(log_dict, trainer)
+ self._dump_log(log_dict, trainer)
diff --git a/det3d/torchie/trainer/hooks/lr_updater.py b/det3d/torchie/trainer/hooks/lr_updater.py
new file mode 100644
index 0000000..c450660
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/lr_updater.py
@@ -0,0 +1,175 @@
+from __future__ import division
+
+from math import cos, pi
+
+from det3d.solver import learning_schedules_fastai as lsf
+
+from .hook import Hook
+
+
+class LrUpdaterHook(Hook):
+ def __init__(
+ self, by_epoch=True, warmup=None, warmup_iters=0, warmup_ratio=0.1, **kwargs
+ ):
+ if warmup is not None:
+ if warmup not in ["constant", "linear", "exp"]:
+ raise ValueError(
+ '"{}" is not a supported type for warming up, valid types'
+ ' are "constant" and "linear"'.format(warmup)
+ )
+
+ if warmup is not None:
+ assert warmup_iters > 0, '"warmup_iters" must be a positive integer'
+ assert 0 < warmup_ratio <= 1.0, '"warmup_ratio" must be in range (0,1]'
+
+ self.by_epoch = by_epoch
+ self.warmup = warmup
+ self.warmup_ratio = warmup_ratio
+ self.warmup_iters = warmup_iters
+
+ self.base_lr = [] # initial lr for all param groups
+ self.regular_lr = [] # expected lr if no warming up is performed
+
+ def _set_lr(self, trainer, lr_groups):
+ for param_group, lr in zip(trainer.optimizer.param_groups, lr_groups):
+ param_group["lr"] = lr
+
+ def get_lr(self, runner, base_lr):
+ raise NotImplementedError
+
+ def get_regular_lr(self, trainer):
+ return [self.get_lr(trainer, _base_lr) for _base_lr in self.base_lr]
+
+ def get_warmup_lr(self, cur_iters):
+ if self.warmup == "constant":
+ warmup_lr = [_lr * self.warmup_ratio for _lr in self.regular_lr]
+ elif self.warmup == "linear":
+ k = (1 - cur_iters / self.warmup_iters) * (1 - self.warmup_ratio)
+ warmup_lr = [_lr * (1 - k) for _lr in self.regular_lr]
+ elif self.warmup == "exp":
+ k = self.warmup_ratio ** (1 - cur_iters / self.warmup_iters)
+ warmup_lr = [_lr * k for _lr in self.regular_lr]
+
+ return warmup_lr
+
+ def before_run(self, trainer):
+ for group in trainer.optimizer.param_groups:
+ group.setdefault("initial_lr", group["lr"])
+ self.base_lr = [group["initial_lr"] for group in trainer.optimizer.param_groups]
+
+ def before_train_epoch(self, trainer):
+ if not self.by_epoch:
+ return
+ self.regular_lr = self.get_regular_lr(trainer)
+ self._set_lr(trainer, self.regular_lr)
+
+ def before_train_iter(self, trainer):
+ cur_iter = trainer.iter
+ if not self.by_epoch:
+ self.regular_lr = self.get_regular_lr(trainer)
+ if self.warmup is None or cur_iter >= self.warmup_iters:
+ self._set_lr(trainer, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(trainer, warmup_lr)
+ elif self.by_epoch:
+ if self.warmup is None or cur_iter > self.warmup_iters:
+ return
+ elif cur_iter == self.warmup_iters:
+ self._set_lr(trainer, self.regular_lr)
+ else:
+ warmup_lr = self.get_warmup_lr(cur_iter)
+ self._set_lr(trainer, warmup_lr)
+
+
+class FixedLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, **kwargs):
+ super(FixedLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, trainer, base_lr):
+ return base_lr
+
+
+class StepLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, step, gamma=0.1, **kwargs):
+ assert isinstance(step, (list, int))
+ if isinstance(step, list):
+ for s in step:
+ assert isinstance(s, int) and s > 0
+ elif isinstance(step, int):
+ assert step > 0
+ else:
+ raise TypeError('"step" must be a list or integer')
+ self.step = step
+ self.gamma = gamma
+ super(StepLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = runner.epoch if self.by_epoch else trainer.iter
+
+ if isinstance(self.step, int):
+ return base_lr * (self.gamma ** (progress // self.step))
+
+ exp = len(self.step)
+ for i, s in enumerate(self.step):
+ if progress < s:
+ exp = i
+ break
+
+ return base_lr * self.gamma ** exp
+
+
+class ExpLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, gamma, **kwargs):
+ self.gamma = gamma
+ super(ExpLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, runner, base_lr):
+ progress = trainer.epoch if self.by_epoch else trainer.iter
+ return base_lr * self.gamma ** progress
+
+
+class PolyLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, power=1.0, min_lr=0.0, **kwargs):
+ self.power = power
+ self.min_lr = min_lr
+ super(PolyLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, trainer, base_lr):
+ if self.by_epoch:
+ progress = trainer.epoch
+ max_progress = trainer.max_epochs
+ else:
+ progress = trainer.iter
+ max_progress = trainer.max_iters
+ coeff = (1 - progress / max_progress) ** self.power
+ return (base_lr - self.min_lr) * coeff + self.min_lr
+
+
+class InvLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, gamma, power=1.0, **kwargs):
+ self.gamma = gamma
+ self.power = power
+ super(InvLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, trainer, base_lr):
+ progress = trainer.epoch if self.by_epoch else trainer.iter
+ return base_lr * (1 + self.gamma * progress) ** (-self.power)
+
+
+class CosineLrUpdaterHook(LrUpdaterHook):
+ def __init__(self, target_lr=0, **kwargs):
+ self.target_lr = target_lr
+ super(CosineLrUpdaterHook, self).__init__(**kwargs)
+
+ def get_lr(self, trainer, base_lr):
+ if self.by_epoch:
+ progress = trainer.epoch
+ max_progress = trainer.max_epochs
+ else:
+ progress = trainer.iter
+ max_progress = trainer.max_iters
+
+ return self.target_lr + 0.5 * (base_lr - self.target_lr) * (
+ 1 + cos(pi * (progress / max_progress))
+ )
diff --git a/det3d/torchie/trainer/hooks/memory.py b/det3d/torchie/trainer/hooks/memory.py
new file mode 100644
index 0000000..990f8ce
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/memory.py
@@ -0,0 +1,22 @@
+import torch
+
+from .hook import Hook
+
+
+class EmptyCacheHook(Hook):
+ def __init__(self, before_epoch=False, after_epoch=True, after_iter=False):
+ self._before_epoch = before_epoch
+ self._after_epoch = after_epoch
+ self._after_iter = after_iter
+
+ def after_iter(self, trainer):
+ if self._after_iter:
+ torch.cuda.empty_cache()
+
+ def before_epoch(self, trainer):
+ if self._before_epoch:
+ torch.cuda.empty_cache()
+
+ def after_epoch(self, trainer):
+ if self._after_epoch:
+ torch.cuda.empty_cache()
diff --git a/det3d/torchie/trainer/hooks/optimizer.py b/det3d/torchie/trainer/hooks/optimizer.py
new file mode 100644
index 0000000..9a50e62
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/optimizer.py
@@ -0,0 +1,21 @@
+from torch.nn.utils import clip_grad
+
+from .hook import Hook
+
+
+class OptimizerHook(Hook):
+ def __init__(self, grad_clip=None):
+ self.grad_clip = grad_clip
+
+ def clip_grads(self, params):
+ clip_grad.clip_grad_norm_(
+ filter(lambda p: p.requires_grad, params), **self.grad_clip
+ )
+
+ def after_train_iter(self, trainer):
+ trainer.optimizer.zero_grad()
+ # print(trainer.outputs["loss"])
+ trainer.outputs["loss"].backward()
+ if self.grad_clip is not None:
+ self.clip_grads(trainer.model.parameters())
+ trainer.optimizer.step()
diff --git a/det3d/torchie/trainer/hooks/sampler_seed.py b/det3d/torchie/trainer/hooks/sampler_seed.py
new file mode 100644
index 0000000..f3f0df3
--- /dev/null
+++ b/det3d/torchie/trainer/hooks/sampler_seed.py
@@ -0,0 +1,6 @@
+from .hook import Hook
+
+
+class DistSamplerSeedHook(Hook):
+ def before_epoch(self, trainer):
+ trainer.data_loader.sampler.set_epoch(trainer.epoch)
diff --git a/det3d/torchie/trainer/log_buffer.py b/det3d/torchie/trainer/log_buffer.py
new file mode 100644
index 0000000..694dc3d
--- /dev/null
+++ b/det3d/torchie/trainer/log_buffer.py
@@ -0,0 +1,42 @@
+from collections import OrderedDict
+
+import numpy as np
+
+
+class LogBuffer(object):
+ def __init__(self):
+ self.val_history = OrderedDict()
+ self.n_history = OrderedDict()
+ self.output = OrderedDict()
+ self.ready = False
+
+ def clear(self):
+ self.val_history.clear()
+ self.n_history.clear()
+ self.clear_output()
+
+ def clear_output(self):
+ self.output.clear()
+ self.ready = False
+
+ def update(self, vars, count=1):
+ assert isinstance(vars, dict)
+ for key, var in vars.items():
+ if key not in self.val_history:
+ self.val_history[key] = []
+ self.n_history[key] = []
+ self.val_history[key].append(var)
+ self.n_history[key].append(count)
+
+ def average(self, n=0):
+ """Average latest n values or all values"""
+ assert n >= 0
+ for key in self.val_history:
+ values = np.array(self.val_history[key][-n:])
+ nums = np.array(self.n_history[key][-n:])
+ if values.shape == nums.shape:
+ avg = np.sum(values * nums) / np.sum(nums)
+ else:
+ avg = np.mean(values, axis=0).tolist()
+ self.output[key] = avg
+ self.ready = True
diff --git a/det3d/torchie/trainer/parallel_test.py b/det3d/torchie/trainer/parallel_test.py
new file mode 100644
index 0000000..9947c27
--- /dev/null
+++ b/det3d/torchie/trainer/parallel_test.py
@@ -0,0 +1,77 @@
+import multiprocessing
+
+import torch
+from det3d import torchie
+
+from .checkpoint import load_checkpoint
+
+
+def worker_func(
+ model_cls,
+ model_kwargs,
+ checkpoint,
+ dataset,
+ data_func,
+ gpu_id,
+ idx_queue,
+ result_queue,
+):
+ model = model_cls(**model_kwargs)
+ load_checkpoint(model, checkpoint, map_location="cpu")
+ torch.cuda.set_device(gpu_id)
+ model.cuda()
+ model.eval()
+ with torch.no_grad():
+ while True:
+ idx = idx_queue.get()
+ data = dataset[idx]
+ result = model(**data_func(data, gpu_id))
+ result_queue.put((idx, result))
+
+
+def parallel_test(
+ model_cls, model_kwargs, checkpoint, dataset, data_func, gpus, workers_per_gpu=1
+):
+ """Parallel testing on multiple GPUs.
+
+ Args:
+ model_cls (type): Model class type.
+ model_kwargs (dict): Arguments to init the model.
+ checkpoint (str): Checkpoint filepath.
+ dataset (:obj:`Dataset`): The dataset to be tested.
+ data_func (callable): The function that generates model inputs.
+ gpus (list[int]): GPU ids to be used.
+ workers_per_gpu (int): Number of processes on each GPU. It is possible
+ to run multiple workers on each GPU.
+
+ Returns:
+ list: Test results.
+ """
+ ctx = multiprocessing.get_context("spawn")
+ idx_queue = ctx.Queue()
+ result_queue = ctx.Queue()
+ num_workers = len(gpus) * workers_per_gpu
+ workers = [
+ ctx.Process(
+ target=worker_func,
+ args=(
+ model_cls,
+ model_kwargs,
+ checkpoint,
+ dataset,
+ data_func,
+ gpus[i % len(gpus)],
+ idx_queue,
+ result_queue,
+ ),
+ )
+ for i in range(num_workers)
+ ]
+ for w in workers:
+ w.daemon = True
+ w.start()
+
+ for i in range(len(dataset)):
+ idx_queue.put(i)
+
+ results = [None for _ in range(len(dataset))]
diff --git a/det3d/torchie/trainer/priority.py b/det3d/torchie/trainer/priority.py
new file mode 100644
index 0000000..8daf7fb
--- /dev/null
+++ b/det3d/torchie/trainer/priority.py
@@ -0,0 +1,53 @@
+from enum import Enum
+
+
+class Priority(Enum):
+ """Hook priority levels.
+
+ +------------+------------+
+ | Level | Value |
+ +============+============+
+ | HIGHEST | 0 |
+ +------------+------------+
+ | VERY_HIGH | 10 |
+ +------------+------------+
+ | HIGH | 30 |
+ +------------+------------+
+ | NORMAL | 50 |
+ +------------+------------+
+ | LOW | 70 |
+ +------------+------------+
+ | VERY_LOW | 90 |
+ +------------+------------+
+ | LOWEST | 100 |
+ +------------+------------+
+ """
+
+ HIGHEST = 0
+ VERY_HIGH = 10
+ HIGH = 30
+ NORMAL = 50
+ LOW = 70
+ VERY_LOW = 90
+ LOWEST = 100
+
+
+def get_priority(priority):
+ """Get priority value.
+
+ Args:
+ priority (int or str or :obj:`Priority`): Priority.
+
+ Returns:
+ int: The priority value.
+ """
+ if isinstance(priority, int):
+ if priority < 0 or priority > 100:
+ raise ValueError("priority must be between 0 and 100")
+ return priority
+ elif isinstance(priority, Priority):
+ return priority.value
+ elif isinstance(priority, str):
+ return Priority[priority.upper()].value
+ else:
+ raise TypeError("priority must be an integer or Priority enum value")
diff --git a/det3d/torchie/trainer/trainer.py b/det3d/torchie/trainer/trainer.py
new file mode 100644
index 0000000..4f4918c
--- /dev/null
+++ b/det3d/torchie/trainer/trainer.py
@@ -0,0 +1,598 @@
+import logging
+import os.path as osp
+import queue
+import sys
+import threading
+import time
+from collections import OrderedDict
+
+import torch
+from det3d import torchie
+
+from . import hooks
+from .checkpoint import load_checkpoint, save_checkpoint
+from .hooks import (
+ CheckpointHook,
+ Hook,
+ IterTimerHook,
+ LrUpdaterHook,
+ OptimizerHook,
+ lr_updater,
+)
+from .log_buffer import LogBuffer
+from .priority import get_priority
+from .utils import (
+ all_gather,
+ get_dist_info,
+ get_host_info,
+ get_time_str,
+ obj_from_dict,
+ synchronize,
+)
+
+
+def example_to_device(example, device, non_blocking=False) -> dict:
+ example_torch = {}
+ float_names = ["voxels", "bev_map"]
+ for k, v in example.items():
+ if k in ["anchors", "anchors_mask", "reg_targets", "reg_weights", "labels", "hm",
+ "anno_box", "ind", "mask", 'cat','corners','points']:
+ example_torch[k] = [res.to(device, non_blocking=non_blocking) for res in v]
+ elif k in [
+ "voxels",
+ "bev_map",
+ "coordinates",
+ "num_points",
+ # "points",
+ "num_voxels",
+ "cyv_voxels",
+ "cyv_num_voxels",
+ "cyv_coordinates",
+ "cyv_num_points",
+ "gt_boxes_and_cls",
+ "gt_boxes_mask",
+ "gt_offset",
+ "times"
+ ]:
+ example_torch[k] = v.to(device, non_blocking=non_blocking)
+ elif k == "calib":
+ calib = {}
+ for k1, v1 in v.items():
+ calib[k1] = v1.to(device, non_blocking=non_blocking)
+ example_torch[k] = calib
+ else:
+ example_torch[k] = v
+
+ return example_torch
+
+
+def parse_second_losses(losses):
+
+ log_vars = OrderedDict()
+ loss = sum(losses["loss"])
+ for loss_name, loss_value in losses.items():
+ if loss_name == "loc_loss_elem":
+ log_vars[loss_name] = [[i.item() for i in j] for j in loss_value]
+ else:
+ log_vars[loss_name] = [i.item() for i in loss_value]
+
+ return loss, log_vars
+
+
+class BackgroundGenerator(threading.Thread):
+ def __init__(self, generator, max_prefetch=1):
+ threading.Thread.__init__(self)
+ self.queue = queue.Queue(max_prefetch)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def next(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ # Python 3 compatibility
+ def __next__(self):
+ return self.next()
+
+ def __iter__(self):
+ return self
+
+
+class Prefetcher(object):
+ def __init__(self, dataloader):
+ self.loader = iter(dataloader)
+ self.stream = torch.cuda.Stream()
+ self.preload()
+
+ def preload(self):
+ try:
+ self.next_input = next(self.loader)
+ except StopIteration:
+ self.next_input = None
+ return
+ with torch.cuda.stream(self.stream):
+ self.next_input = example_to_device(
+ self.next_input, torch.cuda.current_device(), non_blocking=False
+ )
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ input = self.next_input
+ self.preload()
+ return input
+
+
+class Trainer(object):
+ """ A training helper for PyTorch
+
+ Args:
+ model:
+ batch_processor:
+ optimizer:
+ workdir:
+ log_level:
+ logger:
+ """
+
+ def __init__(
+ self,
+ model,
+ batch_processor,
+ optimizer=None,
+ lr_scheduler=None,
+ work_dir=None,
+ log_level=logging.INFO,
+ logger=None,
+ **kwargs,
+ ):
+ assert callable(batch_processor)
+ self.model = model
+ self.optimizer = optimizer
+ self.lr_scheduler = lr_scheduler
+
+ self.batch_processor = batch_processor
+
+ # Create work_dir
+ if torchie.is_str(work_dir):
+ self.work_dir = osp.abspath(work_dir)
+ torchie.mkdir_or_exist(self.work_dir)
+ elif work_dir is None:
+ self.work_dir = None
+ else:
+ raise TypeError("'work_dir' must be a str or None")
+
+ # Get model name from the model class
+ if hasattr(self.model, "module"):
+ self._model_name = self.model.module.__class__.__name__
+ else:
+ self._model_name = self.model.__class__.__name__
+
+ self._rank, self._world_size = get_dist_info()
+ self.timestamp = get_time_str()
+ if logger is None:
+ self.logger = self.init_logger(work_dir, log_level)
+ else:
+ self.logger = logger
+ self.log_buffer = LogBuffer()
+
+ self.mode = None
+ self._hooks = []
+ self._epoch = 0
+ self._iter = 0
+ self._inner_iter = 0
+ self._max_epochs = 0
+ self._max_iters = 0
+
+ @property
+ def model_name(self):
+ """str: Name of the model, usually the module class name."""
+ return self._model_name
+
+ @property
+ def rank(self):
+ """int: Rank of current process. (distributed training)"""
+ return self._rank
+
+ @property
+ def world_size(self):
+ """int: Number of processes participating in the job.
+ (distributed training)"""
+ return self._world_size
+
+ @property
+ def hooks(self):
+ """list[:obj:`Hook`]: A list of registered hooks."""
+ return self._hooks
+
+ @property
+ def epoch(self):
+ """int: Current epoch."""
+ return self._epoch
+
+ @property
+ def iter(self):
+ """int: Current iteration."""
+ return self._iter
+
+ @property
+ def inner_iter(self):
+ """int: Iteration in an epoch."""
+ return self._inner_iter
+
+ @property
+ def max_epochs(self):
+ """int: Maximum training epochs."""
+ return self._max_epochs
+
+ @property
+ def max_iters(self):
+ """int: Maximum training iterations."""
+ return self._max_iters
+
+ def init_optimizer(self, optimizer):
+ """Init the optimizer
+
+ Args:
+ optimizer (dict or :obj:`~torch.optim.Optimizer`)
+
+ Returns:
+ :obj:`~torch.optim.Optimizer`
+
+ Examples:
+ >>> optimizer = dict(type='SGD', lr=0.01, momentum=0.9)
+ >>> type(runner.init_optimizer(optimizer))
+
+ """
+ if isinstance(optimizer, dict):
+ optimizer = obj_from_dict(
+ optimizer, torch.optim, dict(params=self.model.parameters())
+ )
+ elif not isinstance(optimizer, torch.optim.Optimizer):
+ raise TypeError(
+ "optimizer must be either an Optimizer object or a dict, "
+ "but got {}".format(type(optimizer))
+ )
+ return optimizer
+
+ def _add_file_handler(self, logger, filename=None, mode="w", level=logging.INFO):
+ # TODO: move this method out of runner
+ file_handler = logging.FileHandler(filename, mode)
+ file_handler.setFormatter(
+ logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+ )
+ file_handler.setLevel(level)
+ logger.addHandler(file_handler)
+ return logger
+
+ def init_logger(self, log_dir=None, level=logging.INFO):
+ """Init the logger.
+
+ Args:
+
+ Returns:
+ :obj:`~logging.Logger`: Python logger.
+ """
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - % (message)s", level=level
+ )
+ logger = logging.getLogger(__name__)
+ if log_dir and self.rank == 0:
+ filename = "{}.log".format(self.timestamp)
+ log_file = osp.join(log_dir, filename)
+ self._add_file_handler(logger, log_file, level=level)
+ return logger
+
+ def current_lr(self):
+ if self.optimizer is None:
+ raise RuntimeError("lr is not applicable because optimizer does not exist.")
+ return [group["lr"] for group in self.optimizer.param_groups]
+
+ def register_hook(self, hook, priority="NORMAL"):
+ """Register a hook into the hook list.
+
+ Args:
+ hook (:obj:`Hook`)
+ priority (int or str or :obj:`Priority`)
+ """
+ assert isinstance(hook, Hook)
+ if hasattr(hook, "priority"):
+ raise ValueError('"priority" is a reserved attribute for hooks')
+ priority = get_priority(priority)
+ hook.priority = priority
+ # Insert the hook to a sorted list
+ inserted = False
+ for i in range(len(self._hooks) - 1, -1, -1):
+ if priority >= self._hooks[i].priority:
+ self._hooks.insert(i + 1, hook)
+ inserted = True
+ break
+ if not inserted:
+ self._hooks.insert(0, hook)
+
+ def build_hook(self, args, hook_type=None):
+ if isinstance(args, Hook):
+ return args
+ elif isinstance(args, dict):
+ assert issubclass(hook_type, Hook)
+ return hook_type(**args)
+ else:
+ raise TypeError(
+ "'args' must be either a Hook object"
+ " or dict, not {}".format(type(args))
+ )
+
+ def call_hook(self, fn_name):
+ for hook in self._hooks:
+ getattr(hook, fn_name)(self)
+
+ def load_checkpoint(self, filename, map_location="cpu", strict=False):
+ self.logger.info("load checkpoint from %s", filename)
+ return load_checkpoint(self.model, filename, map_location, strict, self.logger)
+
+ def save_checkpoint(
+ self, out_dir, filename_tmpl="epoch_{}.pth", save_optimizer=True, meta=None
+ ):
+ if meta is None:
+ meta = dict(epoch=self.epoch + 1, iter=self.iter)
+ else:
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+
+ filename = filename_tmpl.format(self.epoch + 1)
+ filepath = osp.join(out_dir, filename)
+ linkpath = osp.join(out_dir, "latest.pth")
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # Use relative symlink
+ torchie.symlink(filename, linkpath)
+
+ def batch_processor_inline(self, model, data, train_mode, **kwargs):
+
+ if "local_rank" in kwargs:
+ device = torch.device(kwargs["local_rank"])
+ else:
+ device = None
+
+ # data = example_convert_to_torch(data, device=device)
+ example = example_to_device(
+ data, torch.cuda.current_device(), non_blocking=False
+ )
+
+ self.call_hook("after_data_to_device")
+
+ if train_mode:
+ losses = model(example, return_loss=True)
+ self.call_hook("after_forward")
+ loss, log_vars = parse_second_losses(losses)
+ del losses
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=-1 # TODO: FIX THIS
+ )
+ self.call_hook("after_parse_loss")
+
+ return outputs
+ else:
+ return model(example, return_loss=False)
+
+ def train(self, data_loader, epoch, **kwargs):
+
+ self.model.train()
+ self.mode = "train"
+ self.data_loader = data_loader
+ self.length = len(data_loader)
+ self._max_iters = self._max_epochs * self.length
+ self.call_hook("before_train_epoch")
+
+ base_step = epoch * self.length
+
+ # prefetcher = Prefetcher(data_loader)
+ # for data_batch in BackgroundGenerator(data_loader, max_prefetch=3):
+ for i, data_batch in enumerate(data_loader):
+ global_step = base_step + i
+ if self.lr_scheduler is not None:
+ #print(global_step)
+ self.lr_scheduler.step(global_step)
+
+ self._inner_iter = i
+
+ self.call_hook("before_train_iter")
+
+ # outputs = self.batch_processor(self.model,
+ # data_batch,
+ # train_mode=True,
+ # **kwargs)
+ outputs = self.batch_processor_inline(
+ self.model, data_batch, train_mode=True, **kwargs
+ )
+
+ if not isinstance(outputs, dict):
+ raise TypeError("batch_processor() must return a dict")
+ if "log_vars" in outputs:
+ self.log_buffer.update(outputs["log_vars"], outputs["num_samples"])
+ self.outputs = outputs
+ self.call_hook("after_train_iter")
+ self._iter += 1
+
+ self.call_hook("after_train_epoch")
+ self._epoch += 1
+
+ def val(self, data_loader, **kwargs):
+ self.model.eval()
+ self.mode = "val"
+ self.data_loader = data_loader
+ self.call_hook("before_val_epoch")
+
+ self.logger.info(f"work dir: {self.work_dir}")
+
+ if self.rank == 0:
+ prog_bar = torchie.ProgressBar(len(data_loader.dataset))
+
+ detections = {}
+ cpu_device = torch.device("cpu")
+
+ for i, data_batch in enumerate(data_loader):
+ self._inner_iter = i
+ self.call_hook("before_val_iter")
+ with torch.no_grad():
+ outputs = self.batch_processor(
+ self.model, data_batch, train_mode=False, **kwargs
+ )
+ for output in outputs:
+ token = output["metadata"]["token"]
+ for k, v in output.items():
+ if k not in [
+ "metadata",
+ ]:
+ output[k] = v.to(cpu_device)
+ detections.update(
+ {token: output,}
+ )
+ if self.rank == 0:
+ for _ in range(self.world_size):
+ prog_bar.update()
+
+ synchronize()
+
+ all_predictions = all_gather(detections)
+
+ if self.rank != 0:
+ return
+
+ predictions = {}
+ for p in all_predictions:
+ predictions.update(p)
+
+ # torch.save(predictions, "final_predictions_debug.pkl")
+ # TODO fix evaluation module
+ result_dict, _ = self.data_loader.dataset.evaluation(
+ predictions, output_dir=self.work_dir
+ )
+
+ self.logger.info("\n")
+ for k, v in result_dict["results"].items():
+ self.logger.info(f"Evaluation {k}: {v}")
+
+ self.call_hook("after_val_epoch")
+
+ def resume(self, checkpoint, resume_optimizer=True, map_location="default"):
+ if map_location == "default":
+ checkpoint = self.load_checkpoint(
+ checkpoint , map_location='cuda:{}'.format(torch.cuda.current_device()) # TODO: FIX THIS!!
+ )
+ else:
+ checkpoint = self.load_checkpoint(checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint["meta"]["epoch"]
+ self._iter = checkpoint["meta"]["iter"]
+ if "optimizer" in checkpoint and resume_optimizer:
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
+
+ self.logger.info("resumed epoch %d, iter %d", self.epoch, self.iter)
+
+ def run(self, data_loaders, workflow, max_epochs, **kwargs):
+ """ Start running.
+
+ Args:
+ data_loaders (list[:obj:`DataLoader`])
+ workflow (list[tuple]): A list of (phase, epochs) to specify the
+ running order and epochs.
+ max_epochs (int)
+ """
+ assert isinstance(data_loaders, list)
+ assert torchie.is_list_of(workflow, tuple)
+ assert len(data_loaders) == len(workflow)
+
+ self._max_epochs = max_epochs
+ work_dir = self.work_dir if self.work_dir is not None else "NONE"
+ self.logger.info(
+ "Start running, host: %s, work_dir: %s", get_host_info(), work_dir
+ )
+ self.logger.info("workflow: %s, max: %d epochs", workflow, max_epochs)
+ self.call_hook("before_run")
+
+ while self.epoch < max_epochs:
+ for i, flow in enumerate(workflow):
+ mode, epochs = flow
+ if isinstance(mode, str):
+ if not hasattr(self, mode):
+ raise ValueError(
+ "Trainer has no method named '{}' to run an epoch".format(
+ mode
+ )
+ )
+ epoch_runner = getattr(self, mode)
+ elif callable(mode):
+ epoch_runner = mode
+ else:
+ raise TypeError(
+ "mode in workflow must be a str or "
+ "callable function not '{}'".format(type(mode))
+ )
+
+ for _ in range(epochs):
+ if mode == "train" and self.epoch >= max_epochs:
+ return
+ elif mode == "val":
+ epoch_runner(data_loaders[i], **kwargs)
+ else:
+ epoch_runner(data_loaders[i], self.epoch, **kwargs)
+
+ # time.sleep(1)
+ self.call_hook("after_run")
+
+ def register_lr_hooks(self, lr_config):
+ if isinstance(lr_config, LrUpdaterHook):
+ self.register_hook(lr_config)
+ elif isinstance(lr_config, dict):
+ assert "policy" in lr_config
+ hook_name = lr_config["policy"].title() + "LrUpdaterHook"
+ if not hasattr(lr_updater, hook_name):
+ raise ValueError('"{}" does not exist'.format(hook_name))
+ hook_cls = getattr(lr_updater, hook_name)
+ self.register_hook(hook_cls(**lr_config))
+ else:
+ raise TypeError(
+ "'lr_config' must be eigher a LrUpdaterHook object"
+ " or dict, not '{}'".format(type(lr_config))
+ )
+
+ def register_logger_hooks(self, log_config):
+ log_interval = log_config["interval"]
+ for info in log_config["hooks"]:
+ logger_hook = obj_from_dict(
+ info, hooks, default_args=dict(interval=log_interval)
+ )
+ self.register_hook(logger_hook, priority="VERY_LOW")
+
+ def register_training_hooks(
+ self, lr_config, optimizer_config=None, checkpoint_config=None, log_config=None
+ ):
+ """Register default hooks for training.
+
+ Default hooks include:
+ - LrUpdaterHook
+ - OptimizerStepperHook
+ - CheckpointSaverHook
+ - IterTimerHook
+ - LoggerHook(s)
+ """
+ if optimizer_config is None:
+ optimizer_config = {}
+ if checkpoint_config is None:
+ checkpoint_config = {}
+ if lr_config is not None:
+ assert self.lr_scheduler is None
+ self.register_lr_hooks(lr_config)
+ self.register_hook(self.build_hook(optimizer_config, OptimizerHook))
+ self.register_hook(self.build_hook(checkpoint_config, CheckpointHook))
+ self.register_hook(IterTimerHook())
+ if log_config is not None:
+ self.register_logger_hooks(log_config)
diff --git a/det3d/torchie/trainer/utils.py b/det3d/torchie/trainer/utils.py
new file mode 100644
index 0000000..e095890
--- /dev/null
+++ b/det3d/torchie/trainer/utils.py
@@ -0,0 +1,183 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import functools
+import pickle
+import sys
+import time
+from getpass import getuser
+from socket import gethostname
+
+import torch
+import torch.distributed as dist
+from det3d import torchie
+
+
+def get_host_info():
+ return "{}@{}".format(getuser(), gethostname())
+
+
+def get_dist_info():
+ if torch.__version__ < "1.0":
+ initialized = dist._initialized
+ else:
+ initialized = dist.is_initialized()
+ if initialized:
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def master_only(func):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ rank, _ = get_dist_info()
+ if rank == 0:
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def get_time_str():
+ return time.strftime("%Y%m%d_%H%M%S", time.localtime())
+
+
+def obj_from_dict(info, parent=None, default_args=None):
+ """Initialize an object from dict.
+
+ The dict must contain the key "type", which indicates the object type
+
+ Args:
+ info (dict): Object types and arguments
+ parent (:class:`modules`):
+ default_args (dict, optional):
+ """
+ assert isinstance(info, dict) and "type" in info
+ assert isinstance(default_args, dict) or default_args is None
+ args = info.copy()
+ obj_type = args.pop("type")
+ if torchie.is_str(obj_type):
+ if parent is not None:
+ obj_type = getattr(parent, obj_type)
+ else:
+ obj_type = sys.modules[obj_type]
+ elif not isinstance(obj_type, type):
+ raise TypeError(
+ "type must be a str or valid type, but got {}".format(type(obj_type))
+ )
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+ return obj_type(**args)
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+ size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
diff --git a/det3d/torchie/utils/__init__.py b/det3d/torchie/utils/__init__.py
new file mode 100644
index 0000000..3375e4a
--- /dev/null
+++ b/det3d/torchie/utils/__init__.py
@@ -0,0 +1,62 @@
+from .config import Config, ConfigDict
+from .misc import (
+ check_prerequisites,
+ concat_list,
+ is_list_of,
+ is_seq_of,
+ is_str,
+ is_tuple_of,
+ iter_cast,
+ list_cast,
+ requires_executable,
+ requires_package,
+ slice_list,
+ tuple_cast,
+)
+from .path import (
+ FileNotFoundError,
+ check_file_exist,
+ fopen,
+ is_filepath,
+ mkdir_or_exist,
+ scandir,
+ symlink,
+)
+from .progressbar import (
+ ProgressBar,
+ track_iter_progress,
+ track_parallel_progress,
+ track_progress,
+)
+from .timer import Timer, TimerError, check_time
+
+__all__ = [
+ "ConfigDict",
+ "Config",
+ "is_str",
+ "iter_cast",
+ "list_cast",
+ "tuple_cast",
+ "is_seq_of",
+ "is_list_of",
+ "is_tuple_of",
+ "slice_list",
+ "concat_list",
+ "check_prerequisites",
+ "requires_package",
+ "requires_executable",
+ "is_filepath",
+ "fopen",
+ "check_file_exist",
+ "mkdir_or_exist",
+ "symlink",
+ "scandir",
+ "FileNotFoundError",
+ "ProgressBar",
+ "track_progress",
+ "track_iter_progress",
+ "track_parallel_progress",
+ "Timer",
+ "TimerError",
+ "check_time",
+]
diff --git a/det3d/torchie/utils/config.py b/det3d/torchie/utils/config.py
new file mode 100644
index 0000000..97c6ef2
--- /dev/null
+++ b/det3d/torchie/utils/config.py
@@ -0,0 +1,162 @@
+import os.path as osp
+import sys
+from argparse import ArgumentParser
+from importlib import import_module
+
+from addict import Dict
+
+from .misc import collections_abc
+from .path import check_file_exist
+
+
+class ConfigDict(Dict):
+ def __missing__(self, name):
+ raise KeyError(name)
+
+ def __getattr__(self, name):
+ try:
+ value = super(ConfigDict, self).__getattr__(name)
+ except KeyError:
+ ex = AttributeError(
+ "'{}' object has no attribute '{}'".format(
+ self.__class__.__name__, name
+ )
+ )
+ except Exception as e:
+ ex = e
+ else:
+ return value
+ raise ex
+
+
+def add_args(parser, cfg, prefix=""):
+ for k, v in cfg.items():
+ if isinstance(v, str):
+ parser.add_argument("--" + prefix + k)
+ elif isinstance(v, int):
+ parser.add_argument("--" + prefix + k, type=int)
+ elif isinstance(v, float):
+ parser.add_argument("--" + prefix + k, type=float)
+ elif isinstance(v, bool):
+ parser.add_argument("--" + prefix + k, action="store_true")
+ elif isinstance(v, dict):
+ add_args(parser, v, k + ".")
+ elif isinstance(v, collections_abc.Iterable):
+ parser.add_argument("--" + prefix + k, type=type(v[0]), nargs="+")
+ else:
+ print("connot parse key {} of type {}".format(prefix + k, type(v)))
+ return parser
+
+
+class Config(object):
+ """A facility for config and config files.
+
+ It supports common file formats as configs: python/json/yaml. The interface
+ is the same as a dict object and also allows access config values as
+ attributes.
+
+ Example:
+ >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
+ >>> cfg.a
+ 1
+ >>> cfg.b
+ {'b1': [0, 1]}
+ >>> cfg.b.b1
+ [0, 1]
+ >>> cfg = Config.fromfile('tests/data/config/a.py')
+ >>> cfg.filename
+ "/home/kchen/projects/torchie/tests/data/config/a.py"
+ >>> cfg.item4
+ 'test'
+ >>> cfg
+ "Config [path: /home/kchen/projects/torchie/tests/data/config/a.py]: "
+ "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
+
+ """
+
+ @staticmethod
+ def fromfile(filename):
+ filename = osp.abspath(osp.expanduser(filename))
+ check_file_exist(filename)
+ if filename.endswith(".py"):
+ module_name = osp.basename(filename)[:-3]
+ if "." in module_name:
+ raise ValueError("Dots are not allowed in config file path.")
+ config_dir = osp.dirname(filename)
+ sys.path.insert(0, config_dir)
+ mod = import_module(module_name)
+ sys.path.pop(0)
+ cfg_dict = {
+ name: value
+ for name, value in mod.__dict__.items()
+ if not name.startswith("__")
+ }
+ elif filename.endswith((".yml", ".yaml", ".json")):
+ import torchie
+
+ cfg_dict = torchie.load(filename)
+ else:
+ raise IOError("Only py/yml/yaml/json type are supported now!")
+ return Config(cfg_dict, filename=filename)
+
+ @staticmethod
+ def auto_argparser(description=None):
+ """Generate argparser from config file automatically (experimental)
+ """
+ partial_parser = ArgumentParser(description=description)
+ partial_parser.add_argument("config", help="config file path")
+ cfg_file = partial_parser.parse_known_args()[0].config
+ cfg = Config.fromfile(cfg_file)
+ parser = ArgumentParser(description=description)
+ parser.add_argument("config", help="config file path")
+ add_args(parser, cfg)
+ return parser, cfg
+
+ def __init__(self, cfg_dict=None, filename=None):
+ if cfg_dict is None:
+ cfg_dict = dict()
+ elif not isinstance(cfg_dict, dict):
+ raise TypeError(
+ "cfg_dict must be a dict, but got {}".format(type(cfg_dict))
+ )
+
+ super(Config, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
+ super(Config, self).__setattr__("_filename", filename)
+ if filename:
+ with open(filename, "r") as f:
+ super(Config, self).__setattr__("_text", f.read())
+ else:
+ super(Config, self).__setattr__("_text", "")
+
+ @property
+ def filename(self):
+ return self._filename
+
+ @property
+ def text(self):
+ return self._text
+
+ def __repr__(self):
+ return "Config (path: {}): {}".format(self.filename, self._cfg_dict.__repr__())
+
+ def __len__(self):
+ return len(self._cfg_dict)
+
+ def __getattr__(self, name):
+ return getattr(self._cfg_dict, name)
+
+ def __getitem__(self, name):
+ return self._cfg_dict.__getitem__(name)
+
+ def __setattr__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setattr__(name, value)
+
+ def __setitem__(self, name, value):
+ if isinstance(value, dict):
+ value = ConfigDict(value)
+ self._cfg_dict.__setitem__(name, value)
+
+ def __iter__(self):
+ return iter(self._cfg_dict)
diff --git a/det3d/torchie/utils/misc.py b/det3d/torchie/utils/misc.py
new file mode 100644
index 0000000..5526daf
--- /dev/null
+++ b/det3d/torchie/utils/misc.py
@@ -0,0 +1,221 @@
+import collections
+import functools
+import itertools
+import subprocess
+from importlib import import_module
+
+import six
+
+# ABCs from collections will be deprecated in python 3.8+,
+# while collections.abc is not available in python 2.7
+try:
+ import collections.abc as collections_abc
+except ImportError:
+ import collections as collections_abc
+
+
+def is_str(x):
+ """Whether the input is an string instance."""
+ return isinstance(x, six.string_types)
+
+
+def iter_cast(inputs, dst_type, return_type=None):
+ """Cast elements of an iterable object into some type.
+
+ Args:
+ inputs (Iterable): The input object.
+ dst_type (type): Destination type.
+ return_type (type, optional): If specified, the output object will be
+ converted to this type, otherwise an iterator.
+
+ Returns:
+ iterator or specified type: The converted object.
+ """
+ if not isinstance(inputs, collections_abc.Iterable):
+ raise TypeError("inputs must be an iterable object")
+ if not isinstance(dst_type, type):
+ raise TypeError('"dst_type" must be a valid type')
+
+ out_iterable = six.moves.map(dst_type, inputs)
+
+ if return_type is None:
+ return out_iterable
+ else:
+ return return_type(out_iterable)
+
+
+def list_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a list of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=list)
+
+
+def tuple_cast(inputs, dst_type):
+ """Cast elements of an iterable object into a tuple of some type.
+
+ A partial method of :func:`iter_cast`.
+ """
+ return iter_cast(inputs, dst_type, return_type=tuple)
+
+
+def is_seq_of(seq, expected_type, seq_type=None):
+ """Check whether it is a sequence of some type.
+
+ Args:
+ seq (Sequence): The sequence to be checked.
+ expected_type (type): Expected type of sequence items.
+ seq_type (type, optional): Expected sequence type.
+
+ Returns:
+ bool: Whether the sequence is valid.
+ """
+ if seq_type is None:
+ exp_seq_type = collections_abc.Sequence
+ else:
+ assert isinstance(seq_type, type)
+ exp_seq_type = seq_type
+ if not isinstance(seq, exp_seq_type):
+ return False
+ for item in seq:
+ if not isinstance(item, expected_type):
+ return False
+ return True
+
+
+def is_list_of(seq, expected_type):
+ """Check whether it is a list of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=list)
+
+
+def is_tuple_of(seq, expected_type):
+ """Check whether it is a tuple of some type.
+
+ A partial method of :func:`is_seq_of`.
+ """
+ return is_seq_of(seq, expected_type, seq_type=tuple)
+
+
+def slice_list(in_list, lens):
+ """Slice a list into several sub lists by a list of given length.
+
+ Args:
+ in_list (list): The list to be sliced.
+ lens(int or list): The expected length of each out list.
+
+ Returns:
+ list: A list of sliced list.
+ """
+ if not isinstance(lens, list):
+ raise TypeError('"indices" must be a list of integers')
+ elif sum(lens) != len(in_list):
+ raise ValueError(
+ "sum of lens and list length does not match: {} != {}".format(
+ sum(lens), len(in_list)
+ )
+ )
+ out_list = []
+ idx = 0
+ for i in range(len(lens)):
+ out_list.append(in_list[idx : idx + lens[i]])
+ idx += lens[i]
+ return out_list
+
+
+def concat_list(in_list):
+ """Concatenate a list of list into a single list.
+
+ Args:
+ in_list (list): The list of list to be merged.
+
+ Returns:
+ list: The concatenated flat list.
+ """
+ return list(itertools.chain(*in_list))
+
+
+def check_prerequisites(
+ prerequisites,
+ checker,
+ msg_tmpl='Prerequisites "{}" are required in method "{}" but not '
+ "found, please install them first.",
+):
+ """A decorator factory to check if prerequisites are satisfied.
+
+ Args:
+ prerequisites (str of list[str]): Prerequisites to be checked.
+ checker (callable): The checker method that returns True if a
+ prerequisite is meet, False otherwise.
+ msg_tmpl (str): The message template with two variables.
+
+ Returns:
+ decorator: A specific decorator.
+ """
+
+ def wrap(func):
+ @functools.wraps(func)
+ def wrapped_func(*args, **kwargs):
+ requirements = (
+ [prerequisites] if isinstance(prerequisites, str) else prerequisites
+ )
+ missing = []
+ for item in requirements:
+ if not checker(item):
+ missing.append(item)
+ if missing:
+ print(msg_tmpl.format(", ".join(missing), func.__name__))
+ raise RuntimeError("Prerequisites not meet.")
+ else:
+ return func(*args, **kwargs)
+
+ return wrapped_func
+
+ return wrap
+
+
+def _check_py_package(package):
+ try:
+ import_module(package)
+ except ImportError:
+ return False
+ else:
+ return True
+
+
+def _check_executable(cmd):
+ if subprocess.call("which {}".format(cmd), shell=True) != 0:
+ return False
+ else:
+ return True
+
+
+def requires_package(prerequisites):
+ """A decorator to check if some python packages are installed.
+
+ Example:
+ >>> @requires_package('numpy')
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ array([0.])
+ >>> @requires_package(['numpy', 'non_package'])
+ >>> func(arg1, args):
+ >>> return numpy.zeros(1)
+ ImportError
+ """
+ return check_prerequisites(prerequisites, checker=_check_py_package)
+
+
+def requires_executable(prerequisites):
+ """A decorator to check if some executable files are installed.
+
+ Example:
+ >>> @requires_executable('ffmpeg')
+ >>> func(arg1, args):
+ >>> print(1)
+ 1
+ """
+ return check_prerequisites(prerequisites, checker=_check_executable)
diff --git a/det3d/torchie/utils/path.py b/det3d/torchie/utils/path.py
new file mode 100644
index 0000000..6722f48
--- /dev/null
+++ b/det3d/torchie/utils/path.py
@@ -0,0 +1,79 @@
+import os
+import os.path as osp
+import sys
+from pathlib import Path
+
+import six
+
+from .misc import is_str
+
+if sys.version_info <= (3, 3):
+ FileNotFoundError = IOError
+else:
+ FileNotFoundError = FileNotFoundError
+
+
+def is_filepath(x):
+ if is_str(x) or isinstance(x, Path):
+ return True
+ else:
+ return False
+
+
+def fopen(filepath, *args, **kwargs):
+ if is_str(filepath):
+ return open(filepath, *args, **kwargs)
+ elif isinstance(filepath, Path):
+ return filepath.open(*args, **kwargs)
+
+
+def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
+ if not osp.isfile(filename):
+ raise FileNotFoundError(msg_tmpl.format(filename))
+
+
+def mkdir_or_exist(dir_name, mode=0o777):
+ if dir_name == "":
+ return
+ dir_name = osp.expanduser(dir_name)
+ if six.PY3:
+ os.makedirs(dir_name, mode=mode, exist_ok=True)
+ else:
+ if not osp.isdir(dir_name):
+ os.makedirs(dir_name, mode=mode)
+
+
+def symlink(src, dst, overwrite=True, **kwargs):
+ if os.path.lexists(dst) and overwrite:
+ os.remove(dst)
+ os.symlink(src, dst, **kwargs)
+
+
+def _scandir_py35(dir_path, suffix=None):
+ for entry in os.scandir(dir_path):
+ if not entry.is_file():
+ continue
+ filename = entry.name
+ if suffix is None:
+ yield filename
+ elif filename.endswith(suffix):
+ yield filename
+
+
+def _scandir_py(dir_path, suffix=None):
+ for filename in os.listdir(dir_path):
+ if not osp.isfile(osp.join(dir_path, filename)):
+ continue
+ if suffix is None:
+ yield filename
+ elif filename.endswith(suffix):
+ yield filename
+
+
+def scandir(dir_path, suffix=None):
+ if suffix is not None and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+ if sys.version_info >= (3, 5):
+ return _scandir_py35(dir_path, suffix)
+ else:
+ return _scandir_py(dir_path, suffix)
diff --git a/det3d/torchie/utils/progressbar.py b/det3d/torchie/utils/progressbar.py
new file mode 100644
index 0000000..a572449
--- /dev/null
+++ b/det3d/torchie/utils/progressbar.py
@@ -0,0 +1,216 @@
+import sys
+from multiprocessing import Pool
+
+from .misc import collections_abc
+from .timer import Timer
+
+
+class ProgressBar(object):
+ """A progress bar which can print the progress"""
+
+ def __init__(self, task_num=0, bar_width=50, start=True):
+ self.task_num = task_num
+ max_bar_width = self._get_max_bar_width()
+ self.bar_width = bar_width if bar_width <= max_bar_width else max_bar_width
+ self.completed = 0
+ if start:
+ self.start()
+
+ def _get_max_bar_width(self):
+ if sys.version_info > (3, 3):
+ from shutil import get_terminal_size
+ else:
+ from backports.shutil_get_terminal_size import get_terminal_size
+ terminal_width, _ = get_terminal_size()
+ max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
+ if max_bar_width < 10:
+ print(
+ "terminal width is too small ({}), please consider "
+ "widen the terminal for better progressbar "
+ "visualization".format(terminal_width)
+ )
+ max_bar_width = 10
+ return max_bar_width
+
+ def start(self):
+ if self.task_num > 0:
+ sys.stdout.write(
+ "[{}] 0/{}, elapsed: 0s, ETA:".format(
+ " " * self.bar_width, self.task_num
+ )
+ )
+ else:
+ sys.stdout.write("completed: 0, elapsed: 0s")
+ sys.stdout.flush()
+ self.timer = Timer()
+
+ def update(self):
+ self.completed += 1
+ elapsed = self.timer.since_start()
+ fps = self.completed / elapsed
+ if self.task_num > 0:
+ percentage = self.completed / float(self.task_num)
+ eta = int(elapsed * (1 - percentage) / percentage + 0.5)
+ mark_width = int(self.bar_width * percentage)
+ bar_chars = ">" * mark_width + " " * (self.bar_width - mark_width)
+ sys.stdout.write(
+ "\r[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s".format(
+ bar_chars,
+ self.completed,
+ self.task_num,
+ fps,
+ int(elapsed + 0.5),
+ eta,
+ )
+ )
+ else:
+ sys.stdout.write(
+ "completed: {}, elapsed: {}s, {:.1f} tasks/s".format(
+ self.completed, int(elapsed + 0.5), fps
+ )
+ )
+ sys.stdout.flush()
+
+
+def track_progress(func, tasks, bar_width=50, **kwargs):
+ """Track the progress of tasks execution with a progress bar.
+
+ Tasks are done with a simple for-loop.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], collections_abc.Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, collections_abc.Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError('"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width)
+ results = []
+ for task in tasks:
+ results.append(func(task, **kwargs))
+ prog_bar.update()
+ sys.stdout.write("\n")
+ return results
+
+
+def init_pool(process_num, initializer=None, initargs=None):
+ if initializer is None:
+ return Pool(process_num)
+ elif initargs is None:
+ return Pool(process_num, initializer)
+ else:
+ if not isinstance(initargs, tuple):
+ raise TypeError('"initargs" must be a tuple')
+ return Pool(process_num, initializer, initargs)
+
+
+def track_parallel_progress(
+ func,
+ tasks,
+ nproc,
+ initializer=None,
+ initargs=None,
+ bar_width=50,
+ chunksize=1,
+ skip_first=False,
+ keep_order=True,
+):
+ """Track the progress of parallel task execution with a progress bar.
+
+ The built-in :mod:`multiprocessing` module is used for process pools and
+ tasks are done with :func:`Pool.map` or :func:`Pool.imap_unordered`.
+
+ Args:
+ func (callable): The function to be applied to each task.
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ nproc (int): Process (worker) number.
+ initializer (None or callable): Refer to :class:`multiprocessing.Pool`
+ for details.
+ initargs (None or tuple): Refer to :class:`multiprocessing.Pool` for
+ details.
+ chunksize (int): Refer to :class:`multiprocessing.Pool` for details.
+ bar_width (int): Width of progress bar.
+ skip_first (bool): Whether to skip the first sample for each worker
+ when estimating fps, since the initialization step may takes
+ longer.
+ keep_order (bool): If True, :func:`Pool.imap` is used, otherwise
+ :func:`Pool.imap_unordered` is used.
+
+ Returns:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], collections_abc.Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, collections_abc.Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError('"tasks" must be an iterable object or a (iterator, int) tuple')
+ pool = init_pool(nproc, initializer, initargs)
+ start = not skip_first
+ task_num -= nproc * chunksize * int(skip_first)
+ prog_bar = ProgressBar(task_num, bar_width, start)
+ results = []
+ if keep_order:
+ gen = pool.imap(func, tasks, chunksize)
+ else:
+ gen = pool.imap_unordered(func, tasks, chunksize)
+ for result in gen:
+ results.append(result)
+ if skip_first:
+ if len(results) < nproc * chunksize:
+ continue
+ elif len(results) == nproc * chunksize:
+ prog_bar.start()
+ continue
+ prog_bar.update()
+ sys.stdout.write("\n")
+ pool.close()
+ pool.join()
+ return results
+
+
+def track_iter_progress(tasks, bar_width=50, **kwargs):
+ """Track the progress of tasks iteration or enumeration with a progress bar.
+
+ Tasks are yielded with a simple for-loop.
+
+ Args:
+ tasks (list or tuple[Iterable, int]): A list of tasks or
+ (tasks, total num).
+ bar_width (int): Width of progress bar.
+
+ Yields:
+ list: The task results.
+ """
+ if isinstance(tasks, tuple):
+ assert len(tasks) == 2
+ assert isinstance(tasks[0], collections_abc.Iterable)
+ assert isinstance(tasks[1], int)
+ task_num = tasks[1]
+ tasks = tasks[0]
+ elif isinstance(tasks, collections_abc.Iterable):
+ task_num = len(tasks)
+ else:
+ raise TypeError('"tasks" must be an iterable object or a (iterator, int) tuple')
+ prog_bar = ProgressBar(task_num, bar_width)
+ for task in tasks:
+ yield task
+ prog_bar.update()
+ sys.stdout.write("\n")
diff --git a/det3d/torchie/utils/timer.py b/det3d/torchie/utils/timer.py
new file mode 100644
index 0000000..f562937
--- /dev/null
+++ b/det3d/torchie/utils/timer.py
@@ -0,0 +1,116 @@
+from time import time
+
+
+class TimerError(Exception):
+ def __init__(self, message):
+ self.message = message
+ super(TimerError, self).__init__(message)
+
+
+class Timer(object):
+ """A flexible Timer class.
+
+ :Example:
+
+ >>> import time
+ >>> import mmcv
+ >>> with mmcv.Timer():
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ 1.000
+ >>> with mmcv.Timer(print_tmpl='it takes {:.1f} seconds'):
+ >>> # simulate a code block that will run for 1s
+ >>> time.sleep(1)
+ it takes 1.0 seconds
+ >>> timer = mmcv.Timer()
+ >>> time.sleep(0.5)
+ >>> print(timer.since_start())
+ 0.500
+ >>> time.sleep(0.5)
+ >>> print(timer.since_last_check())
+ 0.500
+ >>> print(timer.since_start())
+ 1.000
+ """
+
+ def __init__(self, start=True, print_tmpl=None):
+ self._is_running = False
+ self.print_tmpl = print_tmpl if print_tmpl else "{:.3f}"
+ if start:
+ self.start()
+
+ @property
+ def is_running(self):
+ """bool: indicate whether the timer is running"""
+ return self._is_running
+
+ def __enter__(self):
+ self.start()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ print(self.print_tmpl.format(self.since_last_check()))
+ self._is_running = False
+
+ def start(self):
+ """Start the timer."""
+ if not self._is_running:
+ self._t_start = time()
+ self._is_running = True
+ self._t_last = time()
+
+ def since_start(self):
+ """Total time since the timer is started.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError("timer is not running")
+ self._t_last = time()
+ return self._t_last - self._t_start
+
+ def since_last_check(self):
+ """Time since the last checking.
+
+ Either :func:`since_start` or :func:`since_last_check` is a checking
+ operation.
+
+ Returns (float): Time in seconds.
+ """
+ if not self._is_running:
+ raise TimerError("timer is not running")
+ dur = time() - self._t_last
+ self._t_last = time()
+ return dur
+
+
+_g_timers = {} # global timers
+
+
+def check_time(timer_id):
+ """Add check points in a single line.
+
+ This method is suitable for running a task on a list of items. A timer will
+ be registered when the method is called for the first time.
+
+ :Example:
+
+ >>> import time
+ >>> import mmcv
+ >>> for i in range(1, 6):
+ >>> # simulate a code block
+ >>> time.sleep(i)
+ >>> mmcv.check_time('task1')
+ 2.000
+ 3.000
+ 4.000
+ 5.000
+
+ Args:
+ timer_id (str): Timer identifier.
+ """
+ if timer_id not in _g_timers:
+ _g_timers[timer_id] = Timer()
+ return 0
+ else:
+ return _g_timers[timer_id].since_last_check()
diff --git a/det3d/utils/__init__.py b/det3d/utils/__init__.py
new file mode 100644
index 0000000..499c4f9
--- /dev/null
+++ b/det3d/utils/__init__.py
@@ -0,0 +1,4 @@
+from .flops_counter import get_model_complexity_info
+from .registry import Registry, build_from_cfg
+
+__all__ = ["Registry", "build_from_cfg", "get_model_complexity_info"]
diff --git a/det3d/utils/buildtools/__init__.py b/det3d/utils/buildtools/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/det3d/utils/buildtools/command.py b/det3d/utils/buildtools/command.py
new file mode 100644
index 0000000..9d5ed56
--- /dev/null
+++ b/det3d/utils/buildtools/command.py
@@ -0,0 +1,292 @@
+import multiprocessing
+import os
+import re
+import subprocess
+from concurrent.futures import ProcessPoolExecutor
+from enum import Enum
+from functools import partial
+from pathlib import Path
+
+import fire
+from det3d.utils.find import find_cuda, find_cuda_device_arch
+
+
+class Gpp:
+ def __init__(
+ self,
+ sources,
+ target,
+ std="c++11",
+ includes: list = None,
+ defines: dict = None,
+ cflags: str = None,
+ compiler="g++",
+ link=False,
+ libraries: dict = None,
+ lflags: str = None,
+ extra_cflags: str = None,
+ extra_lflags: str = None,
+ build_directory: str = None,
+ ):
+ if not isinstance(sources, (list, tuple)):
+ sources = [sources]
+ if build_directory is not None:
+ build_directory = Path(build_directory)
+ new_sources = []
+ for p in sources:
+ if not Path(p).is_absolute():
+ new_sources.append(str(build_directory / p))
+ else:
+ new_sources.append(p)
+ sources = new_sources
+ target = Path(target)
+ if not target.is_absolute():
+ target = str(build_directory / target)
+ self.sources = [str(p) for p in sources]
+ self.target = str(target)
+ self.std = std
+ self.includes = includes or []
+ self.cflags = cflags or "-fPIC -O3"
+ self.defines = defines or {}
+ self.compiler = compiler
+ self.link = link
+ self.libraries = libraries or {}
+ self.lflags = lflags or ""
+ self.extra_cflags = extra_cflags or ""
+ self.extra_lflags = extra_lflags or ""
+
+ def shell(self, target: str = None, compiler: str = None):
+ defines = [f"-D {n}={v}" for n, v in self.defines.items()]
+ includes = [f"-I{inc}" for inc in self.includes]
+ libraries = [
+ f"-L{n} {' '.join(['-l' + l for l in v])}"
+ for n, v in self.libraries.items()
+ ]
+ compiler = compiler or self.compiler
+ string = f"{compiler} -std={self.std} "
+ if self.link:
+ string += " -shared "
+ else:
+ string += " -c "
+ target = target or self.target
+ string += (
+ f"-o {target} {' '.join(self.sources)} "
+ f"{' '.join(defines)} "
+ f"{' '.join(includes)} "
+ f"{self.cflags} {self.extra_cflags}"
+ f"{' '.join(libraries)} "
+ f"{self.lflags} {self.extra_lflags}"
+ )
+ return re.sub(r" +", r" ", string)
+
+
+class Link:
+ def __init__(self, outs, target, compiler="ld", build_directory: str = None):
+ if not isinstance(outs, (list, tuple)):
+ outs = [outs]
+ if build_directory is not None:
+ build_directory = Path(build_directory)
+ new_outs = []
+ for p in outs:
+ if not Path(p).is_absolute():
+ new_outs.append(str(build_directory / p))
+ else:
+ new_outs.append(p)
+ outs = new_outs
+ target = Path(target)
+ if target.is_absolute():
+ target = str(build_directory / target)
+ self.outs = [str(p) for p in outs]
+ self.target = str(target)
+ self.compiler = compiler
+
+ def shell(self, target: str = None):
+ string = f"{self.compiler} -r "
+ if target is None:
+ target = self.target
+ string += f"-o {target} {' '.join(self.outs)} "
+ return string
+
+
+class Nvcc(Gpp):
+ def __init__(
+ self,
+ sources,
+ target,
+ arch=None,
+ std="c++11",
+ includes: list = None,
+ defines: dict = None,
+ cflags: str = None,
+ extra_cflags: str = None,
+ extra_lflags: str = None,
+ build_directory: str = None,
+ ):
+ if arch is None:
+ arch = find_cuda_device_arch()
+ if arch is None:
+ raise ValueError("you must specify arch if use cuda.")
+
+ cflags = (
+ cflags or f"-x cu -Xcompiler -fPIC -arch={arch} --expt-relaxed-constexpr"
+ )
+ try:
+ cuda_home = find_cuda()
+ except:
+ cuda_home = None
+ if cuda_home is not None:
+ cuda_include = Path(cuda_home) / "include"
+ includes = includes or []
+ includes += [str(cuda_include)]
+ super().__init__(
+ sources,
+ target,
+ std,
+ includes,
+ defines,
+ cflags,
+ compiler="nvcc",
+ extra_cflags=extra_cflags,
+ extra_lflags=extra_lflags,
+ build_directory=build_directory,
+ )
+
+
+class CUDALink(Gpp):
+ def __init__(
+ self,
+ sources,
+ target,
+ std="c++11",
+ includes: list = None,
+ defines: dict = None,
+ cflags: str = None,
+ libraries: dict = None,
+ lflags: str = None,
+ extra_cflags: str = None,
+ extra_lflags: str = None,
+ build_directory: str = None,
+ ):
+ includes = includes or []
+ defines = defines or {}
+ libraries = libraries or {}
+ cflags = cflags or "-fPIC -O3"
+ try:
+ cuda_home = find_cuda()
+ except:
+ cuda_home = None
+ if cuda_home is not None:
+ cuda_include = Path(cuda_home) / "include"
+ includes += [str(cuda_include)]
+ cuda_lib_path = Path(cuda_home) / "lib64"
+ cuda_libs = {str(cuda_lib_path): ["cublas", "cudart"]}
+ libraries = {**libraries, **cuda_libs}
+ super().__init__(
+ sources,
+ target,
+ std,
+ includes,
+ defines,
+ cflags,
+ link=True,
+ libraries=libraries,
+ lflags=lflags,
+ extra_cflags=extra_cflags,
+ extra_lflags=extra_lflags,
+ build_directory=build_directory,
+ )
+
+
+class NodeState(Enum):
+ Evaled = "evaled"
+ Normal = "normal"
+ Error = "error"
+
+
+class Node:
+ def __init__(self, name=None):
+ self.name = name
+ self.prev = []
+ self.next = []
+ self.state = NodeState.Normal
+
+ def __call__(self, *nodes):
+ for node in nodes:
+ self.prev.append(node)
+ node.next.append(self)
+ return self
+
+ def _eval(self, *args, **kw):
+ return True
+
+ def eval(self, *args, **kw):
+ for p in self.prev:
+ if not p.eval(*args, **kw):
+ self.state = NodeState.Error
+ return False
+ if self.state == NodeState.Normal:
+ if self._eval(*args, **kw):
+ self.state = NodeState.Evaled
+ else:
+ self.state = NodeState.Error
+ return True
+ return True
+
+ def reset(self):
+ self.state = NodeState.Normal
+ self.prev = []
+ self.next = []
+ for node in self.prev:
+ node.reset()
+
+
+class TargetNode(Node):
+ def __init__(self, srcs, hdrs, deps, copts, name=None):
+ super().__init__(name)
+ self.srcs = srcs
+ self.hdrs = hdrs
+ self.deps = deps
+ self.copts = copts
+
+ def _eval(self, executor):
+ pass
+
+
+def compile_func(cmd, code_folder, compiler):
+ if not isinstance(cmd, (Link, Nvcc)):
+ shell = cmd.shell(compiler=compiler)
+ else:
+ shell = cmd.shell()
+ print(shell)
+ cwd = None
+ if code_folder is not None:
+ cwd = str(code_folder)
+ ret = subprocess.run(shell, shell=True, cwd=cwd)
+ if ret.returncode != 0:
+ raise RuntimeError("compile failed with retcode", ret.returncode)
+ return ret
+
+
+def compile_libraries(cmds, code_folder=None, compiler: str = None, num_workers=-1):
+ if num_workers == -1:
+ num_workers = min(len(cmds), multiprocessing.cpu_count())
+ # for cmd in cmds:
+ # print(cmd.shell())
+ if num_workers == 0:
+ rets = map(
+ partial(compile_func, code_folder=code_folder, compiler=compiler), cmds
+ )
+ else:
+ with ProcessPoolExecutor(num_workers) as pool:
+ func = partial(compile_func, code_folder=code_folder, compiler=compiler)
+ rets = pool.map(func, cmds)
+
+ if any([r.returncode != 0 for r in rets]):
+ cmds.clear()
+ return False
+ cmds.clear()
+ return True
+
+
+def out(path):
+ return Path(path).parent / (Path(path).stem + ".o")
diff --git a/det3d/utils/buildtools/pybind11_build.py b/det3d/utils/buildtools/pybind11_build.py
new file mode 100644
index 0000000..606ebae
--- /dev/null
+++ b/det3d/utils/buildtools/pybind11_build.py
@@ -0,0 +1,128 @@
+import shutil
+import subprocess
+import tempfile
+from pathlib import Path
+
+from det3d.utils.find import find_cuda_device_arch
+from det3d.utils.loader import import_file
+
+from .command import CUDALink, Gpp, Nvcc, compile_libraries, out
+
+
+class Pybind11Link(Gpp):
+ def __init__(
+ self,
+ sources,
+ target,
+ std="c++11",
+ includes: list = None,
+ defines: dict = None,
+ cflags: str = None,
+ libraries: dict = None,
+ lflags: str = None,
+ extra_cflags: str = None,
+ extra_lflags: str = None,
+ build_directory: str = None,
+ ):
+ pb11_includes = (
+ subprocess.check_output("python3 -m pybind11 --includes", shell=True)
+ .decode("utf8")
+ .strip("\n")
+ )
+ cflags = cflags or "-fPIC -O3 "
+ cflags += pb11_includes
+ super().__init__(
+ sources,
+ target,
+ std,
+ includes,
+ defines,
+ cflags,
+ link=True,
+ libraries=libraries,
+ lflags=lflags,
+ extra_cflags=extra_cflags,
+ extra_lflags=extra_lflags,
+ build_directory=build_directory,
+ )
+
+
+class Pybind11CUDALink(CUDALink):
+ def __init__(
+ self,
+ sources,
+ target,
+ std="c++11",
+ includes: list = None,
+ defines: dict = None,
+ cflags: str = None,
+ libraries: dict = None,
+ lflags: str = None,
+ extra_cflags: str = None,
+ extra_lflags: str = None,
+ build_directory: str = None,
+ ):
+ pb11_includes = (
+ subprocess.check_output("python3 -m pybind11 --includes", shell=True)
+ .decode("utf8")
+ .strip("\n")
+ )
+ cflags = cflags or "-fPIC -O3 "
+ cflags += pb11_includes
+ super().__init__(
+ sources,
+ target,
+ std,
+ includes,
+ defines,
+ cflags,
+ libraries=libraries,
+ lflags=lflags,
+ extra_cflags=extra_cflags,
+ extra_lflags=extra_lflags,
+ build_directory=build_directory,
+ )
+
+
+def load_pb11(
+ sources,
+ target,
+ cwd=".",
+ cuda=False,
+ arch=None,
+ num_workers=4,
+ includes: list = None,
+ build_directory=None,
+ compiler="g++",
+):
+ cmd_groups = []
+ cmds = []
+ outs = []
+ main_sources = []
+ if arch is None:
+ arch = find_cuda_device_arch()
+
+ for s in sources:
+ s = str(s)
+ if ".cu" in s or ".cu.cc" in s:
+ assert cuda is True, "cuda must be true if contain cuda file"
+ cmds.append(Nvcc(s, out(s), arch))
+ outs.append(out(s))
+ else:
+ main_sources.append(s)
+
+ if cuda is True and arch is None:
+ raise ValueError("you must specify arch if sources contains" " cuda files")
+ cmd_groups.append(cmds)
+ if cuda:
+ cmd_groups.append(
+ [Pybind11CUDALink(outs + main_sources, target, includes=includes)]
+ )
+ else:
+ cmd_groups.append(
+ [Pybind11Link(outs + main_sources, target, includes=includes)]
+ )
+ for cmds in cmd_groups:
+ compile_libraries(cmds, cwd, num_workers=num_workers, compiler=compiler)
+
+ return import_file(target, add_to_sys=False, disable_warning=True)
diff --git a/det3d/utils/check.py b/det3d/utils/check.py
new file mode 100644
index 0000000..16a4137
--- /dev/null
+++ b/det3d/utils/check.py
@@ -0,0 +1,17 @@
+import numpy as np
+
+
+def is_array_like(x):
+ return isinstance(x, (list, tuple, np.ndarray))
+
+
+def shape_mergeable(x, expected_shape):
+ mergeable = True
+ if is_array_like(x) and is_array_like(expected_shape):
+ x = np.array(x)
+ if len(x.shape) == len(expected_shape):
+ for s, s_ex in zip(x.shape, expected_shape):
+ if s_ex is not None and s != s_ex:
+ mergeable = False
+ break
+ return mergeable
diff --git a/det3d/utils/checkpoint.py b/det3d/utils/checkpoint.py
new file mode 100644
index 0000000..3cf8501
--- /dev/null
+++ b/det3d/utils/checkpoint.py
@@ -0,0 +1,325 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import json
+import logging
+import os
+from collections import OrderedDict
+from pathlib import Path
+
+import torch
+from tensorboardX import SummaryWriter
+
+
+def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""):
+ for k, v in json_dict.items():
+ if isinstance(v, dict):
+ _flat_nested_json_dict(v, flatted, sep, start + sep + str(k))
+ else:
+ flatted[start + sep + str(k)] = v
+
+
+def flat_nested_json_dict(json_dict, sep=".") -> dict:
+ """flat a nested json-like dict. this function make shadow copy.
+ """
+ flatted = {}
+ for k, v in json_dict.items():
+ if isinstance(v, dict):
+ _flat_nested_json_dict(v, flatted, sep, str(k))
+ else:
+ flatted[str(k)] = v
+ return flatted
+
+
+def metric_to_str(metrics, sep="."):
+ flatted_metrics = flat_nested_json_dict(metrics, sep)
+ metrics_str_list = []
+ for k, v in flatted_metrics.items():
+ if isinstance(v, float):
+ metrics_str_list.append(f"{k}={v:.4}")
+ elif isinstance(v, (list, tuple)):
+ if v and isinstance(v[0], float):
+ v_str = ", ".join([f"{e:.4}" for e in v])
+ metrics_str_list.append(f"{k}=[{v_str}]")
+ else:
+ metrics_str_list.append(f"{k}={v}")
+ else:
+ metrics_str_list.append(f"{k}={v}")
+ return ", ".join(metrics_str_list)
+
+
+def align_and_update_state_dicts(model_state_dict, loaded_state_dict, logger=None):
+ """
+ Strategy: suppose that the models that we will create will have prefixes appended
+ to each of its keys, for example due to an extra level of nesting that the original
+ pre-trained weights from ImageNet won't contain. For example, model.state_dict()
+ might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains
+ res2.conv1.weight. We thus want to match both parameters together.
+ For that, we look for each model weight, look among all loaded keys if there is one
+ that is a suffix of the current weight name, and use it if that's the case.
+ If multiple matches exist, take the one with longest size
+ of the corresponding name. For example, for the same model as before, the pretrained
+ weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case,
+ we want to match backbone[0].body.conv1.weight to conv1.weight, and
+ backbone[0].body.res2.conv1.weight to res2.conv1.weight.
+ """
+ current_keys = sorted(list(model_state_dict.keys()))
+ loaded_keys = sorted(list(loaded_state_dict.keys()))
+ # get a matrix of string matches, where each (i, j) entry correspond to the size of the
+ # loaded_key string, if it matches
+ match_matrix = [
+ len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys
+ ]
+ match_matrix = torch.as_tensor(match_matrix).view(
+ len(current_keys), len(loaded_keys)
+ )
+ max_match_size, idxs = match_matrix.max(1)
+ # remove indices that correspond to no-match
+ idxs[max_match_size == 0] = -1
+
+ # used for logging
+ max_size = max([len(key) for key in current_keys]) if current_keys else 1
+ max_size_loaded = max([len(key) for key in loaded_keys]) if loaded_keys else 1
+ log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
+ if logger is None:
+ logger = logging.getLogger(__name__)
+ for idx_new, idx_old in enumerate(idxs.tolist()):
+ if idx_old == -1:
+ continue
+ key = current_keys[idx_new]
+ key_old = loaded_keys[idx_old]
+ model_state_dict[key] = loaded_state_dict[key_old]
+ logger.info(
+ log_str_template.format(
+ key,
+ max_size,
+ key_old,
+ max_size_loaded,
+ tuple(loaded_state_dict[key_old].shape),
+ )
+ )
+
+
+def strip_prefix_if_present(state_dict, prefix):
+ keys = sorted(state_dict.keys())
+ if not all(key.startswith(prefix) for key in keys):
+ return state_dict
+ stripped_state_dict = OrderedDict()
+ for key, value in state_dict.items():
+ stripped_state_dict[key.replace(prefix, "")] = value
+ return stripped_state_dict
+
+
+def load_state_dict(model, loaded_state_dict, logger=None):
+ model_state_dict = model.state_dict()
+ # if the state_dict comes from a model that was wrapped in a
+ # DataParallel or DistributedDataParallel during serialization,
+ # remove the "module" prefix before performing the matching
+ loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
+ align_and_update_state_dicts(model_state_dict, loaded_state_dict, logger=logger)
+
+ # use strict loading
+ model.load_state_dict(model_state_dict)
+
+
+def finetune_load_state_dict(model, loaded_state_dict, logger=None):
+ model_state_dict = model.state_dict()
+ # if the state_dict comes from a model that was wrapped in a
+ # DataParallel or DistributedDataParallel during serialization,
+ # remove the "module" prefix before performing the matching
+ loaded_state_dict = strip_prefix_if_present(loaded_state_dict, prefix="module.")
+ loaded_state_dict = {
+ k: v for k, v in loaded_state_dict.items() if not k.startswith("rpn.tasks")
+ }
+ align_and_update_state_dicts(model_state_dict, loaded_state_dict, logger=logger)
+
+ # use strict loading
+ model.load_state_dict(model_state_dict)
+
+
+class Checkpointer(object):
+ def __init__(
+ self,
+ model,
+ optimizer=None,
+ scheduler=None,
+ save_dir="",
+ ckpt_path=None,
+ save_to_disk=None,
+ logger=None,
+ ):
+ self.model = model
+ self.optimizer = optimizer
+ self.scheduler = scheduler
+ self.pretrained_path = ckpt_path # whether pretrained
+ self.finetune = False
+ self.save_dir = save_dir
+ self.save_to_disk = save_to_disk
+ if logger is None:
+ logger = logging.getLogger(__name__)
+ self.logger = logger
+
+ def save(self, name, **kwargs):
+ self.logger.info(name)
+ if not self.save_dir:
+ return
+
+ if not self.save_to_disk:
+ return
+
+ data = {}
+ data["model"] = self.model.state_dict()
+ if self.optimizer is not None:
+ data["optimizer"] = self.optimizer.state_dict()
+ if self.scheduler is not None:
+ print(dir(self.scheduler))
+ data["scheduler"] = self.scheduler.state_dict()
+ data.update(kwargs)
+
+ save_file = os.path.join(self.save_dir, "{}.pth".format(name))
+ self.logger.info("Saving checkpoint to {}".format(save_file))
+ torch.save(data, save_file)
+ self.tag_last_checkpoint(save_file)
+
+ def load(self, f=None):
+ if f is not None:
+ f = self.get_checkpoint_file(f)
+ elif self.has_checkpoint(self.save_dir):
+ # override argument with existing checkpoint
+ f = self.get_checkpoint_file(self.save_dir)
+
+ if not f:
+ # no checkpoint could be found
+ self.logger.info("No checkpoint found. Initializing model from scratch")
+ return {}
+ self.logger.info("Loading checkpoint from {}".format(f))
+ checkpoint = self._load_file(f)
+ self._load_model(checkpoint)
+ if "optimizer" in checkpoint and self.optimizer:
+ self.logger.info("Loading optimizer from {}".format(f))
+ self.optimizer.load_state_dict(checkpoint.pop("optimizer"))
+ if "scheduler" in checkpoint and self.scheduler:
+ self.logger.info("Loading scheduler from {}".format(f))
+ self.scheduler.load_state_dict(checkpoint.pop("scheduler"))
+
+ # return any further checkpoint data
+ return checkpoint
+
+ def finetune_load(self, ckpt_path=None, f=None):
+ if ckpt_path is not None:
+ self.pretrained_path = ckpt_path
+ self.finetune = True
+ f = self.get_checkpoint_file(ckpt_path)
+ assert f is not None, "Finetune should provide a valid ckpt path"
+ self.logger.info("Loading pretrained model from {}".format(f))
+ checkpoint = self._load_file(f)
+ self._load_model(checkpoint)
+
+ def has_checkpoint(self, save_dir):
+ save_file = os.path.join(save_dir, "last_checkpoint")
+ return os.path.exists(save_file)
+
+ def get_checkpoint_file(self, save_dir):
+ save_file = os.path.join(save_dir, "last_checkpoint")
+ try:
+ with open(save_file, "r") as f:
+ last_saved = f.read()
+ last_saved = last_saved.strip()
+ except IOError:
+ # if file doesn't exist, maybe because it has just been
+ # deleted by a separate process
+ last_saved = ""
+ return last_saved
+
+ def tag_last_checkpoint(self, last_filename):
+ save_file = os.path.join(self.save_dir, "last_checkpoint")
+ with open(save_file, "w") as f:
+ f.write(last_filename)
+
+ def _load_file(self, f):
+ return torch.load(f, map_location=torch.device("cpu"))
+
+ def _load_model(self, checkpoint):
+ if self.finetune:
+ finetune_load_state_dict(
+ self.model, checkpoint.pop("model"), logger=self.logger
+ )
+ else:
+ load_state_dict(self.model, checkpoint.pop("model"), logger=self.logger)
+
+
+class det3dCheckpointer(Checkpointer):
+ def __init__(
+ self,
+ # cfg,
+ model,
+ optimizer=None,
+ scheduler=None,
+ save_dir="",
+ save_to_disk=None,
+ logger=None,
+ ):
+ super(det3dCheckpointer, self).__init__(
+ model, optimizer, scheduler, save_dir, save_to_disk, logger
+ )
+ # self.cfg = cfg.clone()
+ # self.writer = Writer(save_dir)
+ self.logger = logger
+
+ def _load_file(self, f):
+ # load native detectron.pytorch checkpoint
+ loaded = super(det3dCheckpointer, self)._load_file(f)
+ if "model" not in loaded:
+ loaded = dict(model=loaded)
+ return loaded
+
+
+class Writer:
+ def __init__(self, save_dir):
+ self.save_dir = Path(save_dir)
+ self.log_mjson_file = None
+ self.summary_writter = None
+ self.metrics = []
+ self._text_current_gstep = -1
+ self._tb_texts = []
+
+ def open(self):
+ save_dir = self.save_dir
+ assert save_dir.exists()
+ summary_dir = save_dir / "summary"
+ summary_dir.mkdir(parents=True, exist_ok=True)
+ self.summary_writter = SummaryWriter(str(summary_dir))
+ return self
+
+ def close(self):
+ assert self.summary_writter is not None
+ tb_json_path = str(self.save_dir / "tensorboard_scalars.json")
+ self.summary_writter.export_scalars_to_json(tb_json_path)
+ self.summary_writter.close()
+ self.summary_writter = None
+
+ def log_text(self, text, step, tag="regular log"):
+ """This function only add text to log.txt and tensorboard texts
+ """
+ if step > self._text_current_gstep and self._text_current_gstep != -1:
+ total_text = "\n".join(self._tb_texts)
+ self.summary_writter.add_text(tag, total_text, global_step=step)
+ self._tb_texts = []
+ self._text_current_gstep = step
+ else:
+ self._tb_texts.append(text)
+
+ if self._text_current_gstep == -1:
+ self._text_current_gstep = step
+
+ def log_metrics(self, metrics: dict, step):
+ flatted_summarys = flat_nested_json_dict(metrics, "/")
+ for k, v in flatted_summarys.items():
+ if isinstance(v, (list, tuple)):
+ if any([isinstance(e, str) for e in v]):
+ continue
+ v_dict = {str(i): e for i, e in enumerate(v)}
+ for k1, v1 in v_dict.items():
+ self.summary_writter.add_scalar(k + "/" + k1, v1, step)
+ else:
+ if isinstance(v, str):
+ continue
+ self.summary_writter.add_scalar(k, v, step)
diff --git a/det3d/utils/config_tool.py b/det3d/utils/config_tool.py
new file mode 100644
index 0000000..96c6d1f
--- /dev/null
+++ b/det3d/utils/config_tool.py
@@ -0,0 +1,53 @@
+# This file contains some config modification function.
+# some functions should be only used for KITTI dataset.
+
+from pathlib import Path
+
+import numpy as np
+from google.protobuf import text_format
+
+
+def change_detection_range(model_config, new_range):
+ assert len(new_range) == 4, "you must provide a list such as [-50, -50, 50, 50]"
+ old_pc_range = list(model_config.voxel_generator.point_cloud_range)
+ old_pc_range[:2] = new_range[:2]
+ old_pc_range[3:5] = new_range[2:]
+ model_config.voxel_generator.point_cloud_range[:] = old_pc_range
+ for anchor_generator in model_config.target_assigner.anchor_generators:
+ a_type = anchor_generator.WhichOneof("anchor_generator")
+ if a_type == "anchor_generator_range":
+ a_cfg = anchor_generator.anchor_generator_range
+ old_a_range = list(a_cfg.anchor_ranges)
+ old_a_range[:2] = new_range[:2]
+ old_a_range[3:5] = new_range[2:]
+ a_cfg.anchor_ranges[:] = old_a_range
+ elif a_type == "anchor_generator_stride":
+ a_cfg = anchor_generator.anchor_generator_stride
+ old_offset = list(a_cfg.offsets)
+ stride = list(a_cfg.strides)
+ old_offset[0] = new_range[0] + stride[0] / 2
+ old_offset[1] = new_range[1] + stride[1] / 2
+ a_cfg.offsets[:] = old_offset
+ else:
+ raise ValueError("unknown")
+ old_post_range = list(model_config.post_center_limit_range)
+ old_post_range[:2] = new_range[:2]
+ old_post_range[3:5] = new_range[2:]
+ model_config.post_center_limit_range[:] = old_post_range
+
+
+def get_downsample_factor(model_config):
+ try:
+ neck_cfg = model_config["neck"]
+ except:
+ model_config = model_config['first_stage_cfg']
+ neck_cfg = model_config['neck']
+ downsample_factor = np.prod(neck_cfg.get("ds_layer_strides", [1]))
+ if len(neck_cfg.get("us_layer_strides", [])) > 0:
+ downsample_factor /= neck_cfg.get("us_layer_strides", [])[-1]
+
+ backbone_cfg = model_config['backbone']
+ downsample_factor *= backbone_cfg["ds_factor"]
+ downsample_factor = int(downsample_factor)
+ assert downsample_factor > 0
+ return downsample_factor
diff --git a/det3d/utils/dist/collect_env.py b/det3d/utils/dist/collect_env.py
new file mode 100644
index 0000000..2d0641d
--- /dev/null
+++ b/det3d/utils/dist/collect_env.py
@@ -0,0 +1,14 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import PIL
+
+from torch.utils.collect_env import get_pretty_env_info
+
+
+def get_pil_version():
+ return "\n Pillow ({})".format(PIL.__version__)
+
+
+def collect_env_info():
+ env_str = get_pretty_env_info()
+ env_str += get_pil_version()
+ return env_str
diff --git a/det3d/utils/dist/dist_common.py b/det3d/utils/dist/dist_common.py
new file mode 100644
index 0000000..46d7c55
--- /dev/null
+++ b/det3d/utils/dist/dist_common.py
@@ -0,0 +1,117 @@
+"""
+This file contains primitives for multi-gpu communication.
+This is useful when doing distributed training.
+"""
+
+import pickle
+import time
+
+import torch
+import torch.distributed as dist
+
+
+def get_world_size():
+ if not dist.is_available():
+ return 1
+ if not dist.is_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def get_rank():
+ if not dist.is_available():
+ return 0
+ if not dist.is_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def synchronize():
+ """
+ Helper function to synchronize (barrier) among all processes when
+ using distributed training
+ """
+ if not dist.is_available():
+ return
+ if not dist.is_initialized():
+ return
+ world_size = dist.get_world_size()
+ if world_size == 1:
+ return
+ dist.barrier()
+
+
+def all_gather(data):
+ """
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
+ Args:
+ data: any picklable object
+ Returns:
+ list[data]: list of data gathered from each rank
+ """
+ world_size = get_world_size()
+ if world_size == 1:
+ return [data]
+
+ # serialized to a Tensor
+ buffer = pickle.dumps(data)
+ storage = torch.ByteStorage.from_buffer(buffer)
+ tensor = torch.ByteTensor(storage).to("cuda")
+
+ # obtain Tensor size of each rank
+ local_size = torch.IntTensor([tensor.numel()]).to("cuda")
+ size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)]
+ dist.all_gather(size_list, local_size)
+ size_list = [int(size.item()) for size in size_list]
+ max_size = max(size_list)
+
+ # receiving Tensor from all ranks
+ # we pad the tensor because torch all_gather does not support
+ # gathering tensors of different shapes
+ tensor_list = []
+ for _ in size_list:
+ tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda"))
+ if local_size != max_size:
+ padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda")
+ tensor = torch.cat((tensor, padding), dim=0)
+ dist.all_gather(tensor_list, tensor)
+
+ data_list = []
+ for size, tensor in zip(size_list, tensor_list):
+ buffer = tensor.cpu().numpy().tobytes()[:size]
+ data_list.append(pickle.loads(buffer))
+
+ return data_list
+
+
+def reduce_dict(input_dict, average=True):
+ """
+ Args:
+ input_dict (dict): all the values will be reduced
+ average (bool): whether to do average or sum
+ Reduce the values in the dictionary from all processes so that process with rank
+ 0 has the averaged results. Returns a dict with the same fields as
+ input_dict, after reduction.
+ """
+ world_size = get_world_size()
+ if world_size < 2:
+ return input_dict
+ with torch.no_grad():
+ names = []
+ values = []
+ # sort the keys so that they are consistent across processes
+ for k in sorted(input_dict.keys()):
+ names.append(k)
+ values.append(input_dict[k])
+ values = torch.stack(values, dim=0)
+ dist.reduce(values, dst=0)
+ if dist.get_rank() == 0 and average:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ values /= world_size
+ reduced_dict = {k: v for k, v in zip(names, values)}
+ return reduced_dict
diff --git a/det3d/utils/dist/logger.py b/det3d/utils/dist/logger.py
new file mode 100644
index 0000000..159fffa
--- /dev/null
+++ b/det3d/utils/dist/logger.py
@@ -0,0 +1,26 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import logging
+import os
+import sys
+
+
+def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
+ logger = logging.getLogger(name)
+ logger.setLevel(logging.DEBUG)
+ # don't log results for the non-master process
+ if distributed_rank > 0:
+ return logger
+
+ ch = logging.StreamHandler(stream=sys.stdout)
+ ch.setLevel(logging.DEBUG)
+ formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ if save_dir:
+ fh = logging.FileHandler(os.path.join(save_dir, filename))
+ fh.setLevel(logging.DEBUG)
+ fh.setFormatter(formatter)
+ logger.addHandler(fh)
+
+ return logger
diff --git a/det3d/utils/find.py b/det3d/utils/find.py
new file mode 100644
index 0000000..3397cff
--- /dev/null
+++ b/det3d/utils/find.py
@@ -0,0 +1,214 @@
+import glob
+import json
+import os
+import subprocess
+import sys
+import tempfile
+from pathlib import Path
+
+import fire
+
+
+def _get_info_from_anaconda_info(info, split=":"):
+ info = info.strip("\n").replace(" ", "")
+ info_dict = {}
+ latest_key = ""
+ for line in info.splitlines():
+ if split in line:
+ pair = line.split(split)
+ info_dict[pair[0]] = pair[1]
+ latest_key = pair[0]
+ else:
+ if not isinstance(info_dict[latest_key], list):
+ info_dict[latest_key] = [info_dict[latest_key]]
+ info_dict[latest_key].append(line)
+ return info_dict
+
+
+def find_anaconda():
+ # try find in default path
+ path = Path.home() / "anaconda3"
+ if path.exists():
+ return path
+ # try conda in cmd
+ try:
+ info = subprocess.check_output("conda info", shell=True).decode("utf-8")
+ info_dict = _get_info_from_anaconda_info(info)
+ return info_dict["activeenvlocation"]
+ except subprocess.CalledProcessError:
+ raise RuntimeError("find anadonda failed")
+
+
+def find_cuda():
+ """Finds the CUDA install path."""
+ # Guess #1
+ cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH")
+ if cuda_home is None:
+ # Guess #2
+ if sys.platform == "win32":
+ cuda_homes = glob.glob(
+ "C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v*.*"
+ )
+ if len(cuda_homes) == 0:
+ cuda_home = ""
+ else:
+ cuda_home = cuda_homes[0]
+ else:
+ cuda_home = "/usr/local/cuda"
+ if not os.path.exists(cuda_home):
+ # Guess #3
+ try:
+ which = "where" if sys.platform == "win32" else "which"
+ nvcc = subprocess.check_output([which, "nvcc"]).decode().rstrip("\r\n")
+ cuda_home = os.path.dirname(os.path.dirname(nvcc))
+ except Exception:
+ cuda_home = None
+ if cuda_home is None:
+ raise RuntimeError(
+ "No CUDA runtime is found, using CUDA_HOME='{}'".format(cuda_home)
+ )
+ return cuda_home
+
+
+def find_cuda_device_arch():
+ if sys.platform == "win32":
+ # TODO: add windows support
+ return None
+ cuda_home = find_cuda()
+ if cuda_home is None:
+ return None
+ cuda_home = Path(cuda_home)
+ try:
+ device_query_path = cuda_home / "extras/demo_suite/deviceQuery"
+ if not device_query_path.exists():
+ source = """
+ #include
+ #include
+ int main(){
+ int nDevices;
+ cudaGetDeviceCount(&nDevices);
+ for (int i = 0; i < nDevices; i++) {
+ cudaDeviceProp prop;
+ cudaGetDeviceProperties(&prop, i);
+ std::cout << prop.major << "." << prop.minor << std::endl;
+ }
+ return 0;
+ }
+ """
+ with tempfile.NamedTemporaryFile("w", suffix=".cc") as f:
+ f_path = Path(f.name)
+ f.write(source)
+ f.flush()
+ try:
+ # TODO: add windows support
+ cmd = (
+ f"g++ {f.name} -o {f_path.stem}"
+ f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart"
+ )
+ print(cmd)
+ subprocess.check_output(cmd, shell=True, cwd=f_path.parent)
+ cmd = f"./{f_path.stem}"
+ arches = (
+ subprocess.check_output(cmd, shell=True, cwd=f_path.parent)
+ .decode()
+ .rstrip("\r\n")
+ .split("\n")
+ )
+ if len(arches) < 1:
+ return None
+ arch = arches[0]
+ except Exception:
+ return None
+ else:
+ cmd = f"{str(device_query_path)} | grep 'CUDA Capability'"
+ arch = (
+ subprocess.check_output(cmd, shell=True)
+ .decode()
+ .rstrip("\r\n")
+ .split(" ")[-1]
+ )
+ # assert len(arch) == 2
+ arch_list = [int(s) for s in arch.split(".")]
+ arch_int = arch_list[0] * 10 + arch_list[1]
+ find_work_arch = False
+ while arch_int > 10:
+ try:
+ res = subprocess.check_output(
+ "nvcc -arch=sm_{}".format(arch_int),
+ shell=True,
+ stderr=subprocess.STDOUT,
+ )
+ except subprocess.CalledProcessError as e:
+ if "No input files specified" in e.output.decode():
+ find_work_arch = True
+ break
+ elif (
+ "is not defined for option 'gpu-architecture'" in e.output.decode()
+ ):
+ arch_int -= 1
+ else:
+ raise RuntimeError("unknown error")
+ if find_work_arch:
+ arch = f"sm_{arch_int}"
+ else:
+ arch = None
+
+ except Exception:
+ arch = None
+ return arch
+
+
+def get_gpu_memory_usage():
+ if sys.platform == "win32":
+ # TODO: add windows support
+ return None
+ cuda_home = find_cuda()
+ if cuda_home is None:
+ return None
+ cuda_home = Path(cuda_home)
+ source = """
+ #include
+ #include
+ int main(){
+ int nDevices;
+ cudaGetDeviceCount(&nDevices);
+ size_t free_m, total_m;
+ // output json format.
+ std::cout << "[";
+ for (int i = 0; i < nDevices; i++) {
+ cudaSetDevice(i);
+ cudaMemGetInfo(&free_m, &total_m);
+ std::cout << "[" << free_m << "," << total_m << "]";
+ if (i != nDevices - 1)
+ std::cout << "," << std::endl;
+ }
+ std::cout << "]" << std::endl;
+ return 0;
+ }
+ """
+ with tempfile.NamedTemporaryFile("w", suffix=".cc") as f:
+ f_path = Path(f.name)
+ f.write(source)
+ f.flush()
+ try:
+ # TODO: add windows support
+ cmd = (
+ f"g++ {f.name} -o {f_path.stem} -std=c++11"
+ f" -I{cuda_home / 'include'} -L{cuda_home / 'lib64'} -lcudart"
+ )
+ print(cmd)
+ subprocess.check_output(cmd, shell=True, cwd=f_path.parent)
+ cmd = f"./{f_path.stem}"
+ usages = subprocess.check_output(
+ cmd, shell=True, cwd=f_path.parent
+ ).decode()
+ usages = json.loads(usages)
+ return usages
+ except Exception:
+ return None
+ return None
+
+
+if __name__ == "__main__":
+ print(find_cuda_device_arch())
+ # fire.Fire()
diff --git a/det3d/utils/flops_counter.py b/det3d/utils/flops_counter.py
new file mode 100644
index 0000000..4fc0927
--- /dev/null
+++ b/det3d/utils/flops_counter.py
@@ -0,0 +1,446 @@
+# Modified from flops-counter.pytorch by Vladislav Sovrasov
+# original repo: https://github.com/sovrasov/flops-counter.pytorch
+
+# MIT License
+
+# Copyright (c) 2018 Vladislav Sovrasov
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
+from torch.nn.modules.pooling import (
+ _AdaptiveAvgPoolNd,
+ _AdaptiveMaxPoolNd,
+ _AvgPoolNd,
+ _MaxPoolNd,
+)
+
+CONV_TYPES = (_ConvNd,)
+DECONV_TYPES = (_ConvTransposeMixin,)
+LINEAR_TYPES = (nn.Linear,)
+POOLING_TYPES = (_AvgPoolNd, _MaxPoolNd, _AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd)
+RELU_TYPES = (nn.ReLU, nn.PReLU, nn.ELU, nn.LeakyReLU, nn.ReLU6)
+BN_TYPES = (_BatchNorm,)
+UPSAMPLE_TYPES = (nn.Upsample,)
+
+SUPPORTED_TYPES = (
+ CONV_TYPES
+ + DECONV_TYPES
+ + LINEAR_TYPES
+ + POOLING_TYPES
+ + RELU_TYPES
+ + BN_TYPES
+ + UPSAMPLE_TYPES
+)
+
+
+def get_model_complexity_info(
+ model,
+ input_res,
+ print_per_layer_stat=True,
+ as_strings=True,
+ input_constructor=None,
+ ost=sys.stdout,
+):
+ assert type(input_res) is tuple
+ assert len(input_res) >= 2
+ flops_model = add_flops_counting_methods(model)
+ flops_model.eval().start_flops_count()
+ if input_constructor:
+ input = input_constructor(input_res)
+ _ = flops_model(**input)
+ else:
+ batch = torch.ones(()).new_empty(
+ (1, *input_res),
+ dtype=next(flops_model.parameters()).dtype,
+ device=next(flops_model.parameters()).device,
+ )
+ flops_model(batch)
+
+ if print_per_layer_stat:
+ print_model_with_flops(flops_model, ost=ost)
+ flops_count = flops_model.compute_average_flops_cost()
+ params_count = get_model_parameters_number(flops_model)
+ flops_model.stop_flops_count()
+
+ if as_strings:
+ return flops_to_string(flops_count), params_to_string(params_count)
+
+ return flops_count, params_count
+
+
+def flops_to_string(flops, units="GMac", precision=2):
+ if units is None:
+ if flops // 10 ** 9 > 0:
+ return str(round(flops / 10.0 ** 9, precision)) + " GMac"
+ elif flops // 10 ** 6 > 0:
+ return str(round(flops / 10.0 ** 6, precision)) + " MMac"
+ elif flops // 10 ** 3 > 0:
+ return str(round(flops / 10.0 ** 3, precision)) + " KMac"
+ else:
+ return str(flops) + " Mac"
+ else:
+ if units == "GMac":
+ return str(round(flops / 10.0 ** 9, precision)) + " " + units
+ elif units == "MMac":
+ return str(round(flops / 10.0 ** 6, precision)) + " " + units
+ elif units == "KMac":
+ return str(round(flops / 10.0 ** 3, precision)) + " " + units
+ else:
+ return str(flops) + " Mac"
+
+
+def params_to_string(params_num):
+ """converting number to string
+ :param float params_num: number
+ :returns str: number
+ >>> params_to_string(1e9)
+ '1000.0 M'
+ >>> params_to_string(2e5)
+ '200.0 k'
+ >>> params_to_string(3e-9)
+ '3e-09'
+ """
+ if params_num // 10 ** 6 > 0:
+ return str(round(params_num / 10 ** 6, 2)) + " M"
+ elif params_num // 10 ** 3:
+ return str(round(params_num / 10 ** 3, 2)) + " k"
+ else:
+ return str(params_num)
+
+
+def print_model_with_flops(model, units="GMac", precision=3, ost=sys.stdout):
+ total_flops = model.compute_average_flops_cost()
+
+ def accumulate_flops(self):
+ if is_supported_instance(self):
+ return self.__flops__ / model.__batch_counter__
+ else:
+ sum = 0
+ for m in self.children():
+ sum += m.accumulate_flops()
+ return sum
+
+ def flops_repr(self):
+ accumulated_flops_cost = self.accumulate_flops()
+ return ", ".join(
+ [
+ flops_to_string(
+ accumulated_flops_cost, units=units, precision=precision
+ ),
+ "{:.3%} MACs".format(accumulated_flops_cost / total_flops),
+ self.original_extra_repr(),
+ ]
+ )
+
+ def add_extra_repr(m):
+ m.accumulate_flops = accumulate_flops.__get__(m)
+ flops_extra_repr = flops_repr.__get__(m)
+ if m.extra_repr != flops_extra_repr:
+ m.original_extra_repr = m.extra_repr
+ m.extra_repr = flops_extra_repr
+ assert m.extra_repr != m.original_extra_repr
+
+ def del_extra_repr(m):
+ if hasattr(m, "original_extra_repr"):
+ m.extra_repr = m.original_extra_repr
+ del m.original_extra_repr
+ if hasattr(m, "accumulate_flops"):
+ del m.accumulate_flops
+
+ model.apply(add_extra_repr)
+ print(model, file=ost)
+ model.apply(del_extra_repr)
+
+
+def get_model_parameters_number(model):
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ return params_num
+
+
+def add_flops_counting_methods(net_main_module):
+ # adding additional methods to the existing module object,
+ # this is done this way so that each function has access to self object
+ net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
+ net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
+ net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
+ net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(
+ net_main_module
+ )
+
+ net_main_module.reset_flops_count()
+
+ # Adding variables necessary for masked flops computation
+ net_main_module.apply(add_flops_mask_variable_or_reset)
+
+ return net_main_module
+
+
+def compute_average_flops_cost(self):
+ """
+ A method that will be available after add_flops_counting_methods() is
+ called on a desired net object.
+ Returns current mean flops consumption per image.
+ """
+
+ batches_count = self.__batch_counter__
+ flops_sum = 0
+ for module in self.modules():
+ if is_supported_instance(module):
+ flops_sum += module.__flops__
+
+ return flops_sum / batches_count
+
+
+def start_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is
+ called on a desired net object.
+ Activates the computation of mean flops consumption per image.
+ Call it before you run the network.
+ """
+ add_batch_counter_hook_function(self)
+ self.apply(add_flops_counter_hook_function)
+
+
+def stop_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is
+ called on a desired net object.
+ Stops computing the mean flops consumption per image.
+ Call whenever you want to pause the computation.
+ """
+ remove_batch_counter_hook_function(self)
+ self.apply(remove_flops_counter_hook_function)
+
+
+def reset_flops_count(self):
+ """
+ A method that will be available after add_flops_counting_methods() is
+ called on a desired net object.
+ Resets statistics computed so far.
+ """
+ add_batch_counter_variables_or_reset(self)
+ self.apply(add_flops_counter_variable_or_reset)
+
+
+def add_flops_mask(module, mask):
+ def add_flops_mask_func(module):
+ if isinstance(module, torch.nn.Conv2d):
+ module.__mask__ = mask
+
+ module.apply(add_flops_mask_func)
+
+
+def remove_flops_mask(module):
+ module.apply(add_flops_mask_variable_or_reset)
+
+
+def is_supported_instance(module):
+ if isinstance(module, SUPPORTED_TYPES):
+ return True
+ else:
+ return False
+
+
+def empty_flops_counter_hook(module, input, output):
+ module.__flops__ += 0
+
+
+def upsample_flops_counter_hook(module, input, output):
+ output_size = output[0]
+ batch_size = output_size.shape[0]
+ output_elements_count = batch_size
+ for val in output_size.shape[1:]:
+ output_elements_count *= val
+ module.__flops__ += int(output_elements_count)
+
+
+def relu_flops_counter_hook(module, input, output):
+ active_elements_count = output.numel()
+ module.__flops__ += int(active_elements_count)
+
+
+def linear_flops_counter_hook(module, input, output):
+ input = input[0]
+ batch_size = input.shape[0]
+ module.__flops__ += int(batch_size * input.shape[1] * output.shape[1])
+
+
+def pool_flops_counter_hook(module, input, output):
+ input = input[0]
+ module.__flops__ += int(np.prod(input.shape))
+
+
+def bn_flops_counter_hook(module, input, output):
+ module.affine
+ input = input[0]
+
+ batch_flops = np.prod(input.shape)
+ if module.affine:
+ batch_flops *= 2
+ module.__flops__ += int(batch_flops)
+
+
+def deconv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ input_height, input_width = input.shape[2:]
+
+ kernel_height, kernel_width = conv_module.kernel_size
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = (
+ kernel_height * kernel_width * in_channels * filters_per_channel
+ )
+
+ active_elements_count = batch_size * input_height * input_width
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+ bias_flops = 0
+ if conv_module.bias is not None:
+ output_height, output_width = output.shape[2:]
+ bias_flops = out_channels * batch_size * output_height * output_height
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def conv_flops_counter_hook(conv_module, input, output):
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+
+ batch_size = input.shape[0]
+ output_dims = list(output.shape[2:])
+
+ kernel_dims = list(conv_module.kernel_size)
+ in_channels = conv_module.in_channels
+ out_channels = conv_module.out_channels
+ groups = conv_module.groups
+
+ filters_per_channel = out_channels // groups
+ conv_per_position_flops = np.prod(kernel_dims) * in_channels * filters_per_channel
+
+ active_elements_count = batch_size * np.prod(output_dims)
+
+ if conv_module.__mask__ is not None:
+ # (b, 1, h, w)
+ output_height, output_width = output.shape[2:]
+ flops_mask = conv_module.__mask__.expand(
+ batch_size, 1, output_height, output_width
+ )
+ active_elements_count = flops_mask.sum()
+
+ overall_conv_flops = conv_per_position_flops * active_elements_count
+
+ bias_flops = 0
+
+ if conv_module.bias is not None:
+
+ bias_flops = out_channels * active_elements_count
+
+ overall_flops = overall_conv_flops + bias_flops
+
+ conv_module.__flops__ += int(overall_flops)
+
+
+def batch_counter_hook(module, input, output):
+ batch_size = 1
+ if len(input) > 0:
+ # Can have multiple inputs, getting the first one
+ input = input[0]
+ batch_size = len(input)
+ else:
+ print(
+ "Warning! No positional inputs found for a module, "
+ "assuming batch size is 1."
+ )
+ module.__batch_counter__ += batch_size
+
+
+def add_batch_counter_variables_or_reset(module):
+
+ module.__batch_counter__ = 0
+
+
+def add_batch_counter_hook_function(module):
+ if hasattr(module, "__batch_counter_handle__"):
+ return
+
+ handle = module.register_forward_hook(batch_counter_hook)
+ module.__batch_counter_handle__ = handle
+
+
+def remove_batch_counter_hook_function(module):
+ if hasattr(module, "__batch_counter_handle__"):
+ module.__batch_counter_handle__.remove()
+ del module.__batch_counter_handle__
+
+
+def add_flops_counter_variable_or_reset(module):
+ if is_supported_instance(module):
+ module.__flops__ = 0
+
+
+def add_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, "__flops_handle__"):
+ return
+
+ if isinstance(module, CONV_TYPES):
+ handle = module.register_forward_hook(conv_flops_counter_hook)
+ elif isinstance(module, RELU_TYPES):
+ handle = module.register_forward_hook(relu_flops_counter_hook)
+ elif isinstance(module, LINEAR_TYPES):
+ handle = module.register_forward_hook(linear_flops_counter_hook)
+ elif isinstance(module, POOLING_TYPES):
+ handle = module.register_forward_hook(pool_flops_counter_hook)
+ elif isinstance(module, BN_TYPES):
+ handle = module.register_forward_hook(bn_flops_counter_hook)
+ elif isinstance(module, UPSAMPLE_TYPES):
+ handle = module.register_forward_hook(upsample_flops_counter_hook)
+ elif isinstance(module, DECONV_TYPES):
+ handle = module.register_forward_hook(deconv_flops_counter_hook)
+ else:
+ handle = module.register_forward_hook(empty_flops_counter_hook)
+ module.__flops_handle__ = handle
+
+
+def remove_flops_counter_hook_function(module):
+ if is_supported_instance(module):
+ if hasattr(module, "__flops_handle__"):
+ module.__flops_handle__.remove()
+ del module.__flops_handle__
+
+
+# --- Masked flops counting
+# Also being run in the initialization
+def add_flops_mask_variable_or_reset(module):
+ if is_supported_instance(module):
+ module.__mask__ = None
diff --git a/det3d/utils/imports.py b/det3d/utils/imports.py
new file mode 100644
index 0000000..50dc637
--- /dev/null
+++ b/det3d/utils/imports.py
@@ -0,0 +1,24 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+import torch
+
+if torch._six.PY3:
+ import importlib
+ import importlib.util
+ import sys
+
+ # from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path?utm_medium=organic&utm_source=google_rich_qa&utm_campaign=google_rich_qa
+ def import_file(module_name, file_path, make_importable=False):
+ spec = importlib.util.spec_from_file_location(module_name, file_path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ if make_importable:
+ sys.modules[module_name] = module
+ return module
+
+
+else:
+ import imp
+
+ def import_file(module_name, file_path, make_importable=None):
+ module = imp.load_source(module_name, file_path)
+ return module
diff --git a/det3d/utils/loader.py b/det3d/utils/loader.py
new file mode 100644
index 0000000..4a360bf
--- /dev/null
+++ b/det3d/utils/loader.py
@@ -0,0 +1,79 @@
+import importlib
+import logging
+import os
+import sys
+from pathlib import Path
+
+logger = logging.getLogger("det3d.utils.loader")
+
+CUSTOM_LOADED_MODULES = {}
+
+
+def _get_possible_module_path(paths):
+ ret = []
+ for p in paths:
+ p = Path(p)
+ for path in p.glob("*"):
+ if path.suffix in ["py", ".so"] or (path.is_dir()):
+ if path.stem.isidentifier():
+ ret.append(path)
+ return ret
+
+
+def _get_regular_import_name(path, module_paths):
+ path = Path(path)
+ for mp in module_paths:
+ mp = Path(mp)
+ if mp == path:
+ return path.stem
+ try:
+ relative_path = path.relative_to(Path(mp))
+ parts = list((relative_path.parent / relative_path.stem).parts)
+ module_name = ".".join([mp.stem] + parts)
+ return module_name
+ except Exception:
+ pass
+ return None
+
+
+def import_file(path, name: str = None, add_to_sys=True, disable_warning=False):
+ global CUSTOM_LOADED_MODULES
+ path = Path(path)
+ module_name = path.stem
+ try:
+ user_paths = os.environ["PYTHONPATH"].split(os.pathsep)
+ except KeyError:
+ user_paths = []
+ possible_paths = _get_possible_module_path(user_paths)
+ model_import_name = _get_regular_import_name(path, possible_paths)
+ if model_import_name is not None:
+ return import_name(model_import_name)
+ if name is not None:
+ module_name = name
+ spec = importlib.util.spec_from_file_location(module_name, path)
+ module = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(module)
+ if not disable_warning:
+ logger.warning(
+ (
+ f"Failed to perform regular import for file {path}. "
+ "this means this file isn't in any folder in PYTHONPATH "
+ "or don't have __init__.py in that project. "
+ "directly file import may fail and some reflecting features are "
+ "disabled even if import succeed. please add your project to PYTHONPATH "
+ "or add __init__.py to ensure this file can be regularly imported. "
+ )
+ )
+
+ if add_to_sys: # this will enable find objects defined in a file.
+ # avoid replace system modules.
+ if module_name in sys.modules and module_name not in CUSTOM_LOADED_MODULES:
+ raise ValueError(f"{module_name} exists in system.")
+ CUSTOM_LOADED_MODULES[module_name] = module
+ sys.modules[module_name] = module
+ return module
+
+
+def import_name(name, package=None):
+ module = importlib.import_module(name, package)
+ return module
diff --git a/det3d/utils/print_utils.py b/det3d/utils/print_utils.py
new file mode 100644
index 0000000..3916da8
--- /dev/null
+++ b/det3d/utils/print_utils.py
@@ -0,0 +1,35 @@
+def _flat_nested_json_dict(json_dict, flatted, sep=".", start=""):
+ for k, v in json_dict.items():
+ if isinstance(v, dict):
+ _flat_nested_json_dict(v, flatted, sep, start + sep + str(k))
+ else:
+ flatted[start + sep + str(k)] = v
+
+
+def flat_nested_json_dict(json_dict, sep=".") -> dict:
+ """flat a nested json-like dict. this function make shadow copy.
+ """
+ flatted = {}
+ for k, v in json_dict.items():
+ if isinstance(v, dict):
+ _flat_nested_json_dict(v, flatted, sep, str(k))
+ else:
+ flatted[str(k)] = v
+ return flatted
+
+
+def metric_to_str(metrics, sep="."):
+ flatted_metrics = flat_nested_json_dict(metrics, sep)
+ metrics_str_list = []
+ for k, v in flatted_metrics.items():
+ if isinstance(v, float):
+ metrics_str_list.append(f"{k}={v:.4}")
+ elif isinstance(v, (list, tuple)):
+ if v and isinstance(v[0], float):
+ v_str = ", ".join([f"{e:.4}" for e in v])
+ metrics_str_list.append(f"{k}=[{v_str}]")
+ else:
+ metrics_str_list.append(f"{k}={v}")
+ else:
+ metrics_str_list.append(f"{k}={v}")
+ return ", ".join(metrics_str_list)
diff --git a/det3d/utils/registry.py b/det3d/utils/registry.py
new file mode 100644
index 0000000..e51fd3a
--- /dev/null
+++ b/det3d/utils/registry.py
@@ -0,0 +1,78 @@
+import inspect
+
+from det3d import torchie
+
+
+class Registry(object):
+ def __init__(self, name):
+ self._name = name
+ self._module_dict = dict()
+
+ def __repr__(self):
+ format_str = self.__class__.__name__ + "(name={}, items={})".format(
+ self._name, list(self._module_dict.keys())
+ )
+ return format_str
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def module_dict(self):
+ return self._module_dict
+
+ def get(self, key):
+ return self._module_dict.get(key, None)
+
+ def _register_module(self, module_class):
+ """Register a module.
+ Args:
+ module (:obj:`nn.Module`): Module to be registered.
+ """
+ if not inspect.isclass(module_class):
+ raise TypeError(
+ "module must be a class, but got {}".format(type(module_class))
+ )
+ module_name = module_class.__name__
+ if module_name in self._module_dict:
+ raise KeyError(
+ "{} is already registered in {}".format(module_name, self.name)
+ )
+ self._module_dict[module_name] = module_class
+
+ def register_module(self, cls):
+ self._register_module(cls)
+ return cls
+
+
+def build_from_cfg(cfg, registry, default_args=None):
+ """Build a module from config dict.
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+ registry (:obj:`Registry`): The registry to search the type from.
+ default_args (dict, optional): Default initialization arguments.
+ Returns:
+ obj: The constructed object.
+ """
+ assert isinstance(cfg, dict) and "type" in cfg
+ assert isinstance(default_args, dict) or default_args is None
+ args = cfg.copy()
+ obj_type = args.pop("type")
+ if torchie.is_str(obj_type):
+ obj_cls = registry.get(obj_type)
+ if obj_cls is None:
+ raise KeyError(
+ "{} is not in the {} registry".format(obj_type, registry.name)
+ )
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ "type must be a str or valid type, but got {}".format(type(obj_type))
+ )
+ if default_args is not None:
+ for name, value in default_args.items():
+ args.setdefault(name, value)
+
+ return obj_cls(**args)
diff --git a/det3d/utils/utils.py b/det3d/utils/utils.py
new file mode 100644
index 0000000..9bc0a03
--- /dev/null
+++ b/det3d/utils/utils.py
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+
+
+def example_to_device(
+ example, dtype=torch.float32, device=None, non_blocking=True
+) -> dict:
+ device = device or torch.device("cuda:0")
+ example_torch = {}
+ float_names = ["voxels", "bev_map"]
+ for k, v in example.items():
+ if k in ["anchors", "reg_targets", "reg_weights", "labels", "anchors_mask"]:
+ res = []
+ for kk, vv in v.items():
+ vv = [vvv.unsqueeze_(0) for vvv in vv]
+ res.append(torch.cat(vv, dim=0).cuda(device, non_blocking=non_blocking))
+ example_torch[k] = res
+ elif k in [
+ "voxels",
+ "bev_map",
+ "coordinates",
+ "num_points",
+ "points",
+ "num_voxels",
+ ]:
+ # slow when directly provide fp32 data with dtype=torch.half
+ example_torch[k] = v.cuda(device, non_blocking=non_blocking)
+ elif k == "calib":
+ calib = {}
+ for k1, v1 in v.items():
+ calib[k1] = v1.cuda(device, non_blocking=non_blocking)
+ example_torch[k] = calib
+ else:
+ example_torch[k] = v
+
+ return example_torch
+
+
+def _worker_init_fn(worker_id):
+ time_seed = np.array(time.time(), dtype=np.int32)
+ np.random.seed(time_seed + worker_id)
+ print(f"WORKER {worker_id} seed:", np.random.get_state()[1][0])
diff --git a/docs/INSTALL.md b/docs/INSTALL.md
new file mode 100644
index 0000000..7853b74
--- /dev/null
+++ b/docs/INSTALL.md
@@ -0,0 +1,27 @@
+## Installation
+Modified from [CenterPoint](https://github.com/tianweiy/CenterPoint)
+
+Our experiments are tested on the following environments:
+
+- Python: 3.9.12
+- PyTorch: 1.9.1
+- CUDA: 11.1
+- We use spconv 1.2.1 in our experiment.
+
+### Installation
+
+```bash
+# basic python libraries
+conda create --name centerformer python=3.9
+conda activate centerformer
+pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
+git clone [this repo]
+cd centerformer
+pip install -r requirements.txt
+sh setup.sh
+
+# add CenterFormer to PYTHONPATH by adding the following line to ~/.bashrc (change the path accordingly)
+export PYTHONPATH="${PYTHONPATH}:PATH_TO_CENTERFORMER"
+```
+
+Most of the libaraies are the same as [CenterPoint](https://github.com/tianweiy/CenterPoint) except for the transformer part. If you run into any issues, you can also refer to their detailed instructions and search from the issues in their repo.
\ No newline at end of file
diff --git a/docs/NOTICE b/docs/NOTICE
new file mode 100644
index 0000000..5a642fa
--- /dev/null
+++ b/docs/NOTICE
@@ -0,0 +1,273 @@
+Portions of this software are derived from det3d(https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07).
+
+==============================================================================
+det3d licence
+==============================================================================
+
+MIT License
+
+Copyright (c) 2019 朱本金
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Portions of this software are derived from second.
+
+==============================================================================
+second license
+==============================================================================
+
+MIT License
+
+Copyright (c) 2018
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Portions of this software are derived from CenterTrack.
+
+==============================================================================
+CenterTrack license
+==============================================================================
+
+MIT License
+
+Copyright (c) 2020 Xingyi Zhou
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Portions of this software are derived from CenterNet.
+
+MIT License
+
+==============================================================================
+CenterNet license
+==============================================================================
+
+Copyright (c) 2019 Xingyi Zhou
+All rights reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Portions of this software are derived from nuscenes-devkit.
+
+==============================================================================
+nuscenes-devkit licence
+==============================================================================
+
+Copyright 2019 Aptiv
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Portions of this software are derived from mmdetection.
+
+==============================================================================
+mmdetection licence
+==============================================================================
+
+Copyright 2018-2019 Open-MMLab.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Portions of this software are derived from mmcv.
+
+==============================================================================
+mmcv licence
+==============================================================================
+
+Copyright 2018-2020 Open-MMLab.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
+Portions of this software are derived from PCDet.
+
+==============================================================================
+PCDet licence
+==============================================================================
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+Portions of this software are derived from maskrcnn-benchmark.
+
+==============================================================================
+maskrcnn-benchmark licence
+==============================================================================
+
+MIT License
+
+Copyright (c) 2018 Facebook
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+Portions of this software are derived from pillar-od.
+
+==============================================================================
+pillar-od licence
+==============================================================================
+
+MIT License
+
+Copyright (c) Massachusetts Institute of Technology and its affiliates.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+==============================================================================
+centerpoint licence
+==============================================================================
+
+MIT License
+
+Copyright (c) 2020-2021 Tianwei Yin and Xingyi Zhou
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/docs/WAYMO.md b/docs/WAYMO.md
new file mode 100644
index 0000000..52d9a3e
--- /dev/null
+++ b/docs/WAYMO.md
@@ -0,0 +1,121 @@
+## Getting Started with CenterFormer on Waymo
+
+### Prerequisite
+
+- Follow [INSTALL.md](INSTALL.md) to install all required libraries.
+- Tensorflow
+- Waymo-open-dataset devkit
+
+```bash
+conda activate centerformer
+pip install waymo-open-dataset-tf-2-6-0==1.4.3
+```
+
+### Prepare data
+
+#### Download data and organise as follows
+
+```
+# For Waymo Dataset
+└── WAYMO_DATASET_ROOT
+ ├── tfrecord_training
+ ├── tfrecord_validation
+ ├── tfrecord_testing
+```
+
+Convert the tfrecord data to pickle files.
+
+```bash
+# train set
+CUDA_VISIBLE_DEVICES=-1 python det3d/datasets/waymo/waymo_converter.py --record_path 'WAYMO_DATASET_ROOT/tfrecord_training/*.tfrecord' --root_path 'WAYMO_DATASET_ROOT/train/'
+
+# validation set
+CUDA_VISIBLE_DEVICES=-1 python det3d/datasets/waymo/waymo_converter.py --record_path 'WAYMO_DATASET_ROOT/tfrecord_validation/*.tfrecord' --root_path 'WAYMO_DATASET_ROOT/val/'
+
+# testing set
+CUDA_VISIBLE_DEVICES=-1 python det3d/datasets/waymo/waymo_converter.py --record_path 'WAYMO_DATASET_ROOT/tfrecord_testing/*.tfrecord' --root_path 'WAYMO_DATASET_ROOT/test/'
+```
+
+Create a symlink to the dataset root
+```bash
+mkdir data && cd data
+ln -s WAYMO_DATASET_ROOT Waymo
+```
+Remember to change the WAYMO_DATASET_ROOT to the actual path in your system.
+
+
+#### Create info files
+
+```bash
+# One Sweep Infos
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split train --nsweeps=1
+
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split val --nsweeps=1
+
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split test --nsweeps=1
+
+# Two Sweep Infos
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split train --nsweeps=2
+
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split val --nsweeps=2
+
+python tools/create_data.py waymo_data_prep --root_path=data/Waymo --split test --nsweeps=2
+
+# More Sweep Infos etc.
+```
+
+In the end, the data and info files should be organized as follows
+
+```
+└── CenterFormer
+ └── data
+ └── Waymo
+ ├── tfrecord_training
+ ├── tfrecord_validation
+ ├── train <-- all training frames and annotations
+ ├── val <-- all validation frames and annotations
+ ├── test <-- all testing frames and annotations
+ ├── infos_train_01sweeps_filter_zero_gt.pkl
+ ├── infos_train_02sweeps_filter_zero_gt.pkl
+ ├── infos_val_01sweeps_filter_zero_gt.pkl
+ ├── infos_val_02sweeps_filter_zero_gt.pkl
+ ├── infos_test_01sweeps_filter_zero_gt.pkl
+ ├── infos_test_02sweeps_filter_zero_gt.pkl
+ ├── ...
+```
+
+### Train & Evaluate in Command Line
+
+Use the following command to start a distributed training using 4 GPUs. The models and logs will be saved to ```work_dirs/CONFIG_NAME```.
+
+```bash
+python -m torch.distributed.launch --nproc_per_node=4 ./tools/train.py CONFIG_PATH
+```
+
+For distributed testing with 4 gpus,
+
+```bash
+python -m torch.distributed.launch --nproc_per_node=4 ./tools/dist_test.py CONFIG_PATH --work_dir work_dirs/CONFIG_NAME --checkpoint work_dirs/CONFIG_NAME/latest.pth
+```
+
+For testing with one gpu and see the inference time,
+
+```bash
+python ./tools/dist_test.py CONFIG_PATH --work_dir work_dirs/CONFIG_NAME --checkpoint work_dirs/CONFIG_NAME/latest.pth --speed_test
+```
+
+This will generate a `my_preds.bin` file in the work_dir. You can create submission to Waymo server using waymo-open-dataset code by following the instructions [here](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md).
+
+If you want to do local evaluation (e.g. for a subset), generate the gt prediction bin files using the script below and follow the waymo instructions [here](https://github.com/waymo-research/waymo-open-dataset/blob/master/docs/quick_start.md).
+
+```bash
+python det3d/datasets/waymo/waymo_common.py --info_path data/Waymo/infos_val_01sweeps_filter_zero_gt.pkl --result_path data/Waymo/ --gt
+```
+
+### Test Set
+
+Add the ```--testset``` flag to the end.
+
+```bash
+python ./tools/dist_test.py CONFIG_PATH --work_dir work_dirs/CONFIG_NAME --checkpoint work_dirs/CONFIG_NAME/latest.pth --testset
+```
diff --git a/docs/teaser.png b/docs/teaser.png
new file mode 100644
index 0000000..9488bdb
Binary files /dev/null and b/docs/teaser.png differ
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..454fbb9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,31 @@
+numba
+matplotlib
+fire
+protobuf
+opencv-python
+opencv-contrib-python
+pybind11
+easydict
+open3d-python
+terminaltables
+pytest-runner
+addict
+pycocotools
+imagecorruptions
+objgraph
+cachetools
+descartes
+jupyter
+matplotlib
+motmetrics<=1.1.3
+numpy
+pandas>=0.24
+Pillow<=6.2.1 # Latest Pillow is incompatible with current torchvision, https://github.com/pytorch/vision/issues/1712
+pyquaternion>=0.9.5
+scikit-learn
+Shapely
+tqdm
+pyyaml
+requests
+nuscenes-devkit==1.0.5
+einops
diff --git a/setup.sh b/setup.sh
new file mode 100644
index 0000000..0635cbb
--- /dev/null
+++ b/setup.sh
@@ -0,0 +1,14 @@
+# (Optional) DCN is supported in CenterPoint, but not used in CenterFormer
+# cd det3d/ops/dcn
+# python setup.py build_ext --inplace
+
+# cd .. && cd iou3d_nms
+# python setup.py build_ext --inplace
+
+cd det3d/ops/iou3d_nms
+python setup.py build_ext --inplace
+
+cd ../.. && cd models/ops/
+python setup.py build install
+# unit test (should see all checking is True)
+python test.py
diff --git a/tools/compare_model.py b/tools/compare_model.py
new file mode 100644
index 0000000..b73af09
--- /dev/null
+++ b/tools/compare_model.py
@@ -0,0 +1,34 @@
+import torch
+import numpy as np
+import sys
+
+def compare_models(model_1, model_2):
+ models_differ = 0
+
+ print(model_1["state_dict"]['neck.transformer_layer.layers.1.2.fn.net.0.weight'])
+
+ print(model_2["state_dict"]['neck.transformer_layer.layers.1.2.fn.net.0.weight'])
+
+ # for key_item_1, key_item_2 in zip(model_1["state_dict"].items(), model_2["state_dict"].items()):
+ # if torch.equal(key_item_1[1], key_item_2[1]):
+ # # print('match found at', key_item_1[0])
+ # pass
+ # else:
+ # models_differ += 1
+ # if (key_item_1[0] == key_item_2[0]):
+ # print('Mismtach found at', key_item_1[0])
+ # # print(key_item_1[1])
+ # # print(key_item_2[1])
+ # # raise Exception
+ # else:
+ # raise Exception
+ # if models_differ == 0:
+ # print('Models match perfectly! :)')
+
+model1_path='/mnt/truenas/scratch/zixiang.zhou/code/CenterPoint/work_dirs/waymo_centerpoint_voxelnet_transforemer_SCA_multiscale_CBAM_test_6epoch/pre_epoch_11.pth'
+model2_path='/mnt/truenas/scratch/zixiang.zhou/code/CenterPoint/work_dirs/waymo_centerpoint_voxelnet_transforemer_SCA_multiscale_CBAM_test_6epoch/pre_epoch_19.pth'
+
+checkpoint1 = torch.load(model1_path, map_location="cpu")
+checkpoint2 = torch.load(model2_path, map_location="cpu")
+
+compare_models(checkpoint1,checkpoint2)
diff --git a/tools/convert_voxelnet.py b/tools/convert_voxelnet.py
new file mode 100644
index 0000000..9b56fbc
--- /dev/null
+++ b/tools/convert_voxelnet.py
@@ -0,0 +1,186 @@
+import argparse
+import copy
+from io import UnsupportedOperation
+import json
+import os
+import sys
+import os.path as osp
+from collections import OrderedDict
+
+try:
+ import apex
+except:
+ print("No APEX!")
+import numpy as np
+import torch
+import yaml
+from det3d import __version__, torchie
+from det3d.datasets import build_dataloader, build_dataset
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.torchie.apis import (
+ batch_processor,
+ build_optimizer,
+ get_root_logger,
+ init_dist,
+ set_random_seed,
+ train_detector,
+)
+from det3d.torchie.trainer import get_dist_info, load_checkpoint
+from det3d.torchie.trainer.utils import all_gather, synchronize
+from torch.nn.parallel import DistributedDataParallel
+import pickle
+import time
+
+def convert_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict into a module
+ """
+ unexpected_keys = []
+ shape_mismatch_pairs = []
+
+ own_state = module.state_dict()
+ for name, param in state_dict.items():
+ # a hacky fixed to load a new voxelnet
+ if name not in own_state:
+ if name[:20] == 'backbone.middle_conv':
+ index = int(name[20:].split('.')[1])
+
+ if index in [0, 1, 2]:
+ new_name = 'backbone.conv_input.{}.{}'.format(str(index), name[23:])
+ elif index in [3, 4]:
+ new_name = 'backbone.conv1.{}.{}'.format(str(index-3), name[23:])
+ elif index in [5, 6, 7, 8, 9]:
+ new_name = 'backbone.conv2.{}.{}'.format(str(index-5), name[23:])
+ elif index in [10, 11, 12, 13, 14]:
+ new_name = 'backbone.conv3.{}.{}'.format(str(index-10), name[24:])
+ elif index in [15, 16, 17, 18, 19]:
+ new_name = 'backbone.conv4.{}.{}'.format(str(index-15), name[24:])
+ elif index in [20, 21, 22]:
+ new_name = 'backbone.extra_conv.{}.{}'.format(str(index-20), name[24:])
+ else:
+ raise NotImplementedError(index)
+
+ if param.size() != own_state[new_name].size():
+ shape_mismatch_pairs.append([name, own_state[name].size(), param.size()])
+ continue
+
+ own_state[new_name].copy_(param)
+ print("load {}'s param from {}".format(new_name, name))
+ continue
+
+ unexpected_keys.append(name)
+ continue
+ if isinstance(param, torch.nn.Parameter):
+ # backwards compatibility for serialized parameters
+ param = param.data
+ if param.size() != own_state[name].size():
+ shape_mismatch_pairs.append([name, own_state[name].size(), param.size()])
+ continue
+ own_state[name].copy_(param)
+
+ all_missing_keys = set(own_state.keys()) - set(state_dict.keys())
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [key for key in all_missing_keys if "num_batches_tracked" not in key]
+
+ err_msg = []
+ if unexpected_keys:
+ err_msg.append(
+ "unexpected key in source state_dict: {}\n".format(
+ ", ".join(unexpected_keys)
+ )
+ )
+ if missing_keys:
+ err_msg.append(
+ "missing keys in source state_dict: {}\n".format(", ".join(missing_keys))
+ )
+ if shape_mismatch_pairs:
+ mismatch_info = "these keys have mismatched shape:\n"
+ header = ["key", "expected shape", "loaded shape"]
+ table_data = [header] + shape_mismatch_pairs
+ table = AsciiTable(table_data)
+ err_msg.append(mismatch_info + table.table)
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(0, "The model and loaded state dict do not match exactly\n")
+ err_msg = "\n".join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train a detector")
+ parser.add_argument("config", help="train config file path")
+ parser.add_argument("--work_dir", help="the dir to save logs and models")
+ parser.add_argument(
+ "--checkpoint", help="the dir to checkpoint which the model read from"
+ )
+ args = parser.parse_args()
+
+ return args
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+ OrderedDict: Model weights on GPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def save_checkpoint(model, filename, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError("meta must be a dict or None, but got {}".format(type(meta)))
+
+ torchie.mkdir_or_exist(osp.dirname(filename))
+ if hasattr(model, "module"):
+ model = model.module
+
+ checkpoint = {"meta": meta, "state_dict": weights_to_cpu(model.state_dict())}
+
+ torch.save(checkpoint, filename)
+
+
+def main():
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+ # update configs according to CLI args
+ if args.work_dir is not None:
+ cfg.work_dir = args.work_dir
+
+ model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+
+ checkpoint = torch.load(args.checkpoint, map_location='cpu')
+ state_dict = checkpoint['state_dict']
+
+ if list(state_dict.keys())[0].startswith("module."):
+ state_dict = {k[7:]: v for k, v in checkpoint["state_dict"].items()}
+
+ convert_state_dict(model, state_dict)
+
+ save_checkpoint(model, osp.join(args.work_dir, 'voxelnet_converted.pth'))
+
+main()
\ No newline at end of file
diff --git a/tools/create_data.py b/tools/create_data.py
new file mode 100644
index 0000000..dffad3f
--- /dev/null
+++ b/tools/create_data.py
@@ -0,0 +1,33 @@
+import copy
+from pathlib import Path
+import pickle
+
+import fire, os
+
+from det3d.datasets.nuscenes import nusc_common as nu_ds
+from det3d.datasets.utils.create_gt_database import create_groundtruth_database
+from det3d.datasets.waymo import waymo_common as waymo_ds
+
+def nuscenes_data_prep(root_path, version, nsweeps=10, filter_zero=True):
+ nu_ds.create_nuscenes_infos(root_path, version=version, nsweeps=nsweeps, filter_zero=filter_zero)
+ create_groundtruth_database(
+ "NUSC",
+ root_path,
+ Path(root_path) / "infos_train_{:02d}sweeps_withvelo_filter_{}.pkl".format(nsweeps, filter_zero),
+ nsweeps=nsweeps,
+ )
+
+def waymo_data_prep(root_path, split, nsweeps=1):
+ waymo_ds.create_waymo_infos(root_path, split=split, nsweeps=nsweeps)
+ if split == 'train':
+ create_groundtruth_database(
+ "WAYMO",
+ root_path,
+ Path(root_path) / "infos_train_{:02d}sweeps_filter_zero_gt.pkl".format(nsweeps),
+ used_classes=['VEHICLE', 'CYCLIST', 'PEDESTRIAN'],
+ nsweeps=nsweeps
+ )
+
+
+if __name__ == "__main__":
+ fire.Fire()
diff --git a/tools/demo.py b/tools/demo.py
new file mode 100644
index 0000000..319758e
--- /dev/null
+++ b/tools/demo.py
@@ -0,0 +1,150 @@
+import argparse
+import copy
+import json
+import os
+import sys
+
+try:
+ import apex
+except:
+ print("No APEX!")
+import numpy as np
+import torch
+import yaml
+from det3d import torchie
+from det3d.datasets import build_dataloader, build_dataset
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.torchie.apis import (
+ batch_processor,
+ build_optimizer,
+ get_root_logger,
+ init_dist,
+ set_random_seed,
+ train_detector,
+)
+from det3d.torchie.trainer import load_checkpoint
+import pickle
+import time
+from matplotlib import pyplot as plt
+from det3d.torchie.parallel import collate, collate_kitti
+from torch.utils.data import DataLoader
+import matplotlib.cm as cm
+import subprocess
+import cv2
+from tools.demo_utils import visual, visual_attention
+from collections import defaultdict
+
+def convert_box(info):
+ boxes = info["gt_boxes"].astype(np.float32)
+ names = info["gt_names"]
+
+ assert len(boxes) == len(names)
+
+ detection = {}
+
+ detection['box3d_lidar'] = boxes
+
+ # dummy value
+ detection['label_preds'] = np.zeros(len(boxes))
+ detection['scores'] = np.ones(len(boxes))
+
+ return detection
+
+def convert_box_waymo(info):
+ boxes = info.astype(np.float32)[0]
+
+ detection = {}
+
+ detection['box3d_lidar'] = boxes[:,:7]
+
+ # dummy value
+ detection['label_preds'] = np.zeros(len(boxes))
+ detection['scores'] = np.ones(len(boxes))
+
+ return detection
+
+
+def main():
+ cfg = Config.fromfile('/mnt/truenas/scratch/zixiang.zhou/code/CenterPoint/configs/waymo/voxelnet/waymo_centerpoint_voxelnet_transforemer_SCA_multiscale_CBAM_36epoch.py')
+
+ model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+
+ dataset = build_dataset(cfg.data.val)
+
+ torch.manual_seed(1)
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=1,
+ sampler=None,
+ shuffle=True,
+ num_workers=8,
+ collate_fn=collate_kitti,
+ pin_memory=False,
+ )
+
+ checkpoint = load_checkpoint(model, 'work_dirs/waymo_centerpoint_voxelnet_transforemer_SCA_multiscale_CBAM_36epoch/epoch_36.pth', map_location="cpu")
+ model.eval()
+
+ model = model.cuda()
+
+ cpu_device = torch.device("cpu")
+
+ points_list = []
+ neighbor_points_list = []
+ gt_annos = []
+ detections = []
+
+ for i, data_batch in enumerate(data_loader):
+ if i>10:
+ break
+ if 'gt_boxes_and_cls' in data_batch:
+ gt_annos.append(convert_box_waymo(data_batch['gt_boxes_and_cls'].cpu().numpy()))
+
+ points = data_batch['points'][0][:,0:3].cpu().numpy()
+ with torch.no_grad():
+ outputs = batch_processor(
+ model, data_batch, train_mode=False, local_rank=0,
+ )
+ for output in outputs:
+ for k, v in output.items():
+ if k not in [
+ "metadata",
+ ]:
+ output[k] = v.to(cpu_device)
+ detections.append(output)
+
+ points_list.append(points.T)
+
+
+ print('Done model inference. Please wait a minute, the matplotlib is a little slow...')
+
+ for i in range(len(points_list)):
+ visual(points_list[i], gt_annos[i], detections[i], i)
+ print("Rendered Image {}".format(i))
+
+ image_folder = 'demo'
+ video_name = 'video.avi'
+
+ images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
+ images.sort()
+ frame = cv2.imread(os.path.join(image_folder, images[0]))
+ height, width, layers = frame.shape
+
+ video = cv2.VideoWriter(video_name, 0, 1, (width,height))
+ cv2_images = []
+
+ for image in images:
+ cv2_images.append(cv2.imread(os.path.join(image_folder, image)))
+
+ for img in cv2_images:
+ video.write(img)
+
+ cv2.destroyAllWindows()
+ video.release()
+
+ print("Successfully save video in the main folder")
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/demo_utils.py b/tools/demo_utils.py
new file mode 100644
index 0000000..140763b
--- /dev/null
+++ b/tools/demo_utils.py
@@ -0,0 +1,374 @@
+"""The following code is takend from the nuscenes-devkit"""
+
+import copy
+import os.path as osp
+import struct
+from abc import ABC, abstractmethod
+from functools import reduce
+from typing import Tuple, List, Dict
+
+import cv2
+import numpy as np
+from matplotlib.axes import Axes
+from sklearn.metrics.pairwise import cosine_similarity
+from pyquaternion import Quaternion
+from matplotlib import pyplot as plt
+
+
+def view_points(points: np.ndarray, view: np.ndarray, normalize: bool) -> np.ndarray:
+ """
+ This is a helper class that maps 3d points to a 2d plane. It can be used to implement both perspective and
+ orthographic projections. It first applies the dot product between the points and the view. By convention,
+ the view should be such that the data is projected onto the first 2 axis. It then optionally applies a
+ normalization along the third dimension.
+
+ For a perspective projection the view should be a 3x3 camera matrix, and normalize=True
+ For an orthographic projection with translation the view is a 3x4 matrix and normalize=False
+ For an orthographic projection without translation the view is a 3x3 matrix (optionally 3x4 with last columns
+ all zeros) and normalize=False
+
+ :param points: Matrix of points, where each point (x, y, z) is along each column.
+ :param view: . Defines an arbitrary projection (n <= 4).
+ The projection should be such that the corners are projected onto the first 2 axis.
+ :param normalize: Whether to normalize the remaining coordinate (along the third axis).
+ :return: . Mapped point. If normalize=False, the third coordinate is the height.
+ """
+
+ assert view.shape[0] <= 4
+ assert view.shape[1] <= 4
+ assert points.shape[0] == 3
+
+ viewpad = np.eye(4)
+ viewpad[:view.shape[0], :view.shape[1]] = view
+
+ nbr_points = points.shape[1]
+
+ # Do operation in homogenous coordinates.
+ points = np.concatenate((points, np.ones((1, nbr_points))))
+ points = np.dot(viewpad, points)
+ points = points[:3, :]
+
+ if normalize:
+ points = points / points[2:3, :].repeat(3, 0).reshape(3, nbr_points)
+
+ return points
+
+def _second_det_to_nusc_box(detection):
+ box3d = detection["box3d_lidar"]
+ scores = detection["scores"]
+ labels = detection["label_preds"]
+ box3d[:, -1] = -box3d[:, -1] - np.pi / 2
+ box_list = []
+ for i in range(box3d.shape[0]):
+ quat = Quaternion(axis=[0, 0, 1], radians=box3d[i, -1])
+ velocity = (*box3d[i, 6:8], 0.0)
+ box = Box(
+ list(box3d[i, :3]),
+ list(box3d[i, 3:6]),
+ quat,
+ label=labels[i],
+ score=scores[i],
+ velocity=velocity,
+ )
+ box_list.append(box)
+ return box_list
+
+
+class Box:
+ """ Simple data class representing a 3d box including, label, score and velocity. """
+
+ def __init__(self,
+ center: List[float],
+ size: List[float],
+ orientation: Quaternion,
+ label: int = np.nan,
+ score: float = np.nan,
+ velocity: Tuple = (np.nan, np.nan, np.nan),
+ name: str = None,
+ token: str = None):
+ """
+ :param center: Center of box given as x, y, z.
+ :param size: Size of box in width, length, height.
+ :param orientation: Box orientation.
+ :param label: Integer label, optional.
+ :param score: Classification score, optional.
+ :param velocity: Box velocity in x, y, z direction.
+ :param name: Box name, optional. Can be used e.g. for denote category name.
+ :param token: Unique string identifier from DB.
+ """
+ # print(center.shape)
+ assert not np.any(np.isnan(center))
+ assert not np.any(np.isnan(size))
+ assert len(center) == 3
+ assert len(size) == 3
+ assert type(orientation) == Quaternion
+
+ self.center = np.array(center)
+ self.wlh = np.array(size)
+ self.orientation = orientation
+ self.label = int(label) if not np.isnan(label) else label
+ self.score = float(score) if not np.isnan(score) else score
+ self.velocity = np.array(velocity)
+ self.name = name
+ self.token = token
+
+ def __eq__(self, other):
+ center = np.allclose(self.center, other.center)
+ wlh = np.allclose(self.wlh, other.wlh)
+ orientation = np.allclose(self.orientation.elements, other.orientation.elements)
+ label = (self.label == other.label) or (np.isnan(self.label) and np.isnan(other.label))
+ score = (self.score == other.score) or (np.isnan(self.score) and np.isnan(other.score))
+ vel = (np.allclose(self.velocity, other.velocity) or
+ (np.all(np.isnan(self.velocity)) and np.all(np.isnan(other.velocity))))
+
+ return center and wlh and orientation and label and score and vel
+
+ def __repr__(self):
+ repr_str = 'label: {}, score: {:.2f}, xyz: [{:.2f}, {:.2f}, {:.2f}], wlh: [{:.2f}, {:.2f}, {:.2f}], ' \
+ 'rot axis: [{:.2f}, {:.2f}, {:.2f}], ang(degrees): {:.2f}, ang(rad): {:.2f}, ' \
+ 'vel: {:.2f}, {:.2f}, {:.2f}, name: {}, token: {}'
+
+ return repr_str.format(self.label, self.score, self.center[0], self.center[1], self.center[2], self.wlh[0],
+ self.wlh[1], self.wlh[2], self.orientation.axis[0], self.orientation.axis[1],
+ self.orientation.axis[2], self.orientation.degrees, self.orientation.radians,
+ self.velocity[0], self.velocity[1], self.velocity[2], self.name, self.token)
+
+ @property
+ def rotation_matrix(self) -> np.ndarray:
+ """
+ Return a rotation matrix.
+ :return: . The box's rotation matrix.
+ """
+ return self.orientation.rotation_matrix
+
+ def translate(self, x: np.ndarray) -> None:
+ """
+ Applies a translation.
+ :param x: . Translation in x, y, z direction.
+ """
+ self.center += x
+
+ def rotate(self, quaternion: Quaternion) -> None:
+ """
+ Rotates box.
+ :param quaternion: Rotation to apply.
+ """
+ self.center = np.dot(quaternion.rotation_matrix, self.center)
+ self.orientation = quaternion * self.orientation
+ self.velocity = np.dot(quaternion.rotation_matrix, self.velocity)
+
+ def corners(self, wlh_factor: float = 1.0) -> np.ndarray:
+ """
+ Returns the bounding box corners.
+ :param wlh_factor: Multiply w, l, h by a factor to scale the box.
+ :return: . First four corners are the ones facing forward.
+ The last four are the ones facing backwards.
+ """
+ w, l, h = self.wlh * wlh_factor
+
+ # 3D bounding box corners. (Convention: x points forward, y to the left, z up.)
+ x_corners = l / 2 * np.array([1, 1, 1, 1, -1, -1, -1, -1])
+ y_corners = w / 2 * np.array([1, -1, -1, 1, 1, -1, -1, 1])
+ z_corners = h / 2 * np.array([1, 1, -1, -1, 1, 1, -1, -1])
+ corners = np.vstack((x_corners, y_corners, z_corners))
+
+ # Rotate
+ corners = np.dot(self.orientation.rotation_matrix, corners)
+
+ # Translate
+ x, y, z = self.center
+ corners[0, :] = corners[0, :] + x
+ corners[1, :] = corners[1, :] + y
+ corners[2, :] = corners[2, :] + z
+
+ return corners
+
+ def bottom_corners(self) -> np.ndarray:
+ """
+ Returns the four bottom corners.
+ :return: . Bottom corners. First two face forward, last two face backwards.
+ """
+ return self.corners()[:, [2, 3, 7, 6]]
+
+ def render(self,
+ axis: Axes,
+ view: np.ndarray = np.eye(3),
+ normalize: bool = False,
+ colors: Tuple = ('b', 'r', 'k'),
+ linewidth: float = 2) -> None:
+ """
+ Renders the box in the provided Matplotlib axis.
+ :param axis: Axis onto which the box should be drawn.
+ :param view: . Define a projection in needed (e.g. for drawing projection in an image).
+ :param normalize: Whether to normalize the remaining coordinate.
+ :param colors: (: 3). Valid Matplotlib colors ( or normalized RGB tuple) for front,
+ back and sides.
+ :param linewidth: Width in pixel of the box sides.
+ """
+ corners = view_points(self.corners(), view, normalize=normalize)[:2, :]
+
+ def draw_rect(selected_corners, color):
+ prev = selected_corners[-1]
+ for corner in selected_corners:
+ axis.plot([prev[0], corner[0]], [prev[1], corner[1]], color=color, linewidth=linewidth)
+ prev = corner
+
+ # Draw the sides
+ for i in range(4):
+ axis.plot([corners.T[i][0], corners.T[i + 4][0]],
+ [corners.T[i][1], corners.T[i + 4][1]],
+ color=colors[2], linewidth=linewidth)
+
+ # Draw front (first 4 corners) and rear (last 4 corners) rectangles(3d)/lines(2d)
+ draw_rect(corners.T[:4], colors[0])
+ draw_rect(corners.T[4:], colors[1])
+
+ # Draw line indicating the front
+ center_bottom_forward = np.mean(corners.T[2:4], axis=0)
+ center_bottom = np.mean(corners.T[[2, 3, 7, 6]], axis=0)
+ axis.plot([center_bottom[0], center_bottom_forward[0]],
+ [center_bottom[1], center_bottom_forward[1]],
+ color=colors[0], linewidth=linewidth)
+
+ def render_cv2(self,
+ im: np.ndarray,
+ view: np.ndarray = np.eye(3),
+ normalize: bool = False,
+ colors: Tuple = ((0, 0, 255), (255, 0, 0), (155, 155, 155)),
+ linewidth: int = 2) -> None:
+ """
+ Renders box using OpenCV2.
+ :param im: . Image array. Channels are in BGR order.
+ :param view: . Define a projection if needed (e.g. for drawing projection in an image).
+ :param normalize: Whether to normalize the remaining coordinate.
+ :param colors: ((R, G, B), (R, G, B), (R, G, B)). Colors for front, side & rear.
+ :param linewidth: Linewidth for plot.
+ """
+ corners = view_points(self.corners(), view, normalize=normalize)[:2, :]
+
+ def draw_rect(selected_corners, color):
+ prev = selected_corners[-1]
+ for corner in selected_corners:
+ cv2.line(im,
+ (int(prev[0]), int(prev[1])),
+ (int(corner[0]), int(corner[1])),
+ color, linewidth)
+ prev = corner
+
+ # Draw the sides
+ for i in range(4):
+ cv2.line(im,
+ (int(corners.T[i][0]), int(corners.T[i][1])),
+ (int(corners.T[i + 4][0]), int(corners.T[i + 4][1])),
+ colors[2][::-1], linewidth)
+
+ # Draw front (first 4 corners) and rear (last 4 corners) rectangles(3d)/lines(2d)
+ draw_rect(corners.T[:4], colors[0][::-1])
+ draw_rect(corners.T[4:], colors[1][::-1])
+
+ # Draw line indicating the front
+ center_bottom_forward = np.mean(corners.T[2:4], axis=0)
+ center_bottom = np.mean(corners.T[[2, 3, 7, 6]], axis=0)
+ cv2.line(im,
+ (int(center_bottom[0]), int(center_bottom[1])),
+ (int(center_bottom_forward[0]), int(center_bottom_forward[1])),
+ colors[0][::-1], linewidth)
+
+ def copy(self) -> 'Box':
+ """
+ Create a copy of self.
+ :return: A copy.
+ """
+ return copy.deepcopy(self)
+
+
+def visual(points, gt_anno, det, i, eval_range=50, conf_th=0.5, neighbor_points=None):
+ _, ax = plt.subplots(1, 1, figsize=(9, 9), dpi=200)
+ points = remove_close(points, radius=3)
+ points = view_points(points[:3, :], np.eye(4), normalize=False)
+
+ dists = np.sqrt(np.sum(points[:2, :] ** 2, axis=0))
+ colors = np.minimum(1, dists / eval_range)
+ ax.scatter(points[0, :], points[1, :], c=colors, s=0.2)
+
+ cls_color = ['b','g','m']
+ boxes_est = _second_det_to_nusc_box(det)
+ # Show GT boxes.
+ if gt_anno is not None:
+ boxes_gt = _second_det_to_nusc_box(gt_anno)
+ for box in boxes_gt:
+ box.render(ax, view=np.eye(4), colors=('r', 'r', 'r'), linewidth=0.1)
+
+ print(len(boxes_est))
+
+ # Show EST boxes.
+ for box in boxes_est:
+ if box.score >= conf_th:
+ box.render(ax, view=np.eye(4), colors=(cls_color[box.label], cls_color[box.label], cls_color[box.label]), linewidth=1)
+ elif box.score >= 0.3:
+ box.render(ax, view=np.eye(4), colors=('y', 'y', 'y'), linewidth=1)
+
+
+ axes_limit = eval_range + 3 # Slightly bigger to include boxes that extend beyond the range.
+ ax.set_xlim(-axes_limit, axes_limit)
+ ax.set_ylim(-axes_limit, axes_limit)
+ plt.axis('off')
+
+ plt.savefig("demo/file%02d.png" % i)
+ plt.close()
+
+def visual_attention(points, gt_anno, det, i, eval_range=80, conf_th=0.5, neighbor_points=None):
+ _, ax = plt.subplots(2, 4, figsize=(24, 12), dpi=200)
+ points = remove_close(points, radius=3)
+ points = view_points(points[:3, :], np.eye(4), normalize=False)
+
+ dists = np.sqrt(np.sum(points[:2, :] ** 2, axis=0))
+ colors = np.minimum(1, dists / eval_range)
+
+ det['sampled_points_loc'] = det['sampled_points_loc'].numpy()
+ det['attention'] = det['attention'].numpy()
+
+ boxes_gt = _second_det_to_nusc_box(gt_anno)
+ boxes_est = _second_det_to_nusc_box(det)
+
+ for a in range(2):
+ for b in range(4):
+
+ ax[a,b].scatter(points[0, :], points[1, :], c=colors, s=0.2)
+
+ for pts,feat,box in zip(det['sampled_points_loc'],det['attention'],boxes_est):
+ if box.score >= conf_th:
+ ax[a,b].scatter(pts[:, 0], pts[:, 1], c=feat[a,b], norm = plt.Normalize(vmin=0, vmax=0.5),cmap='Blues', s=0.4)
+ # ax.scatter(pts[:, 0], pts[:, 1], c='g', s=0.4)
+
+ # Show GT boxes.
+ for box in boxes_gt:
+ box.render(ax[a,b], view=np.eye(4), colors=('r', 'r', 'r'), linewidth=0.3)
+
+ # Show EST boxes.
+ for box in boxes_est:
+ if box.score >= conf_th:
+ box.render(ax[a,b], view=np.eye(4), colors=('b', 'b', 'b'), linewidth=0.3)
+ # elif box.score >= 0.3:
+ # box.render(ax, view=np.eye(4), colors=('y', 'y', 'y'), linewidth=0.3)
+
+
+ axes_limit = eval_range + 3 # Slightly bigger to include boxes that extend beyond the range.
+ ax[a,b].set_xlim(-axes_limit, axes_limit)
+ ax[a,b].set_ylim(-axes_limit, axes_limit)
+ # plt.axis('off')
+
+ plt.savefig("demo/file%02d.png" % i)
+ plt.close()
+
+
+def remove_close(points, radius: float) -> None:
+ """
+ Removes point too close within a certain radius from origin.
+ :param radius: Radius below which points are removed.
+ """
+ x_filt = np.abs(points[0, :]) < radius
+ y_filt = np.abs(points[1, :]) < radius
+ not_close = np.logical_not(np.logical_and(x_filt, y_filt))
+ points = points[:, not_close]
+ return points
diff --git a/tools/dist_test.py b/tools/dist_test.py
new file mode 100644
index 0000000..7c29ba4
--- /dev/null
+++ b/tools/dist_test.py
@@ -0,0 +1,224 @@
+import argparse
+import copy
+import json
+import os
+import sys
+
+# try:
+# import apex
+# except:
+# print("No APEX!")
+import numpy as np
+import torch
+import yaml
+from det3d import torchie
+from det3d.datasets import build_dataloader, build_dataset
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.torchie.apis import (
+ batch_processor,
+ build_optimizer,
+ get_root_logger,
+ init_dist,
+ set_random_seed,
+ train_detector,
+)
+from det3d.torchie.trainer import get_dist_info, load_checkpoint
+from det3d.torchie.trainer.utils import all_gather, synchronize
+from torch.nn.parallel import DistributedDataParallel
+import pickle
+import time
+
+def save_pred(pred, root):
+ with open(os.path.join(root, "prediction.pkl"), "wb") as f:
+ pickle.dump(pred, f)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train a detector")
+ parser.add_argument("config", help="train config file path")
+ parser.add_argument("--work_dir", required=True, help="the dir to save logs and models")
+ parser.add_argument(
+ "--checkpoint", help="the dir to checkpoint which the model read from"
+ )
+ parser.add_argument(
+ "--txt_result",
+ type=bool,
+ default=False,
+ help="whether to save results to standard KITTI format of txt type",
+ )
+ parser.add_argument(
+ "--gpus",
+ type=int,
+ default=1,
+ help="number of gpus to use " "(only applicable to non-distributed training)",
+ )
+ parser.add_argument(
+ "--launcher",
+ choices=["none", "pytorch", "slurm", "mpi"],
+ default="none",
+ help="job launcher",
+ )
+ parser.add_argument("--speed_test", action="store_true")
+ parser.add_argument("--local_rank", type=int, default=0)
+ parser.add_argument("--testset", action="store_true")
+
+ args = parser.parse_args()
+ if "LOCAL_RANK" not in os.environ:
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
+
+ return args
+
+
+def main():
+
+ # torch.manual_seed(0)
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ # np.random.seed(0)
+
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+ cfg.local_rank = args.local_rank
+
+ # update configs according to CLI args
+ if args.work_dir is not None:
+ cfg.work_dir = args.work_dir
+
+ distributed = False
+ if "WORLD_SIZE" in os.environ:
+ distributed = int(os.environ["WORLD_SIZE"]) > 1
+
+ if distributed:
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(backend="nccl", init_method="env://")
+
+ cfg.gpus = torch.distributed.get_world_size()
+ else:
+ cfg.gpus = args.gpus
+
+ # init logger before other steps
+ logger = get_root_logger(cfg.log_level)
+ logger.info("Distributed testing: {}".format(distributed))
+ logger.info(f"torch.backends.cudnn.benchmark: {torch.backends.cudnn.benchmark}")
+
+ # change center number in model config based on test_cfg
+ if 'obj_num' in cfg.test_cfg:
+ cfg.model['neck']['obj_num'] = cfg.test_cfg['obj_num']
+ print('Use center number {} in inference'.format(cfg.model['neck']['obj_num']))
+ if 'score_threshold' in cfg.test_cfg:
+ cfg.model['neck']['score_threshold'] = cfg.test_cfg['score_threshold']
+ print('Use heatmap score threshold {} in inference'.format(cfg.model['neck']['score_threshold']))
+
+ model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+
+ if args.testset:
+ print("Use Test Set")
+ dataset = build_dataset(cfg.data.test)
+ else:
+ print("Use Val Set")
+ dataset = build_dataset(cfg.data.val)
+
+ data_loader = build_dataloader(
+ dataset,
+ batch_size=cfg.data.samples_per_gpu if not args.speed_test else 1,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False,
+ )
+
+ checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu")
+
+ # put model on gpus
+ if distributed:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = DistributedDataParallel(
+ model.cuda(cfg.local_rank),
+ device_ids=[cfg.local_rank],
+ output_device=cfg.local_rank,
+ # broadcast_buffers=False,
+ find_unused_parameters=True,
+ )
+ else:
+ # model = fuse_bn_recursively(model)
+ model = model.cuda()
+
+ pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ print('parameter size:', pytorch_total_params)
+
+ model.eval()
+ mode = "val"
+
+ logger.info(f"work dir: {args.work_dir}")
+ if cfg.local_rank == 0:
+ prog_bar = torchie.ProgressBar(len(data_loader.dataset) // cfg.gpus)
+
+ detections = {}
+ cpu_device = torch.device("cpu")
+
+ start = time.time()
+
+ start = 100
+ end = 100+int(len(dataset) * 1 /50)
+
+ time_start = 0
+ time_end = 0
+
+ for i, data_batch in enumerate(data_loader):
+ if i == start:
+ torch.cuda.synchronize()
+ time_start = time.time()
+
+ if i == end:
+ torch.cuda.synchronize()
+ time_end = time.time()
+
+ with torch.no_grad():
+ outputs = batch_processor(
+ model, data_batch, train_mode=False, local_rank=args.local_rank,
+ )
+ for output in outputs:
+ token = output["metadata"]["token"]
+ for k, v in output.items():
+ if k not in [
+ "metadata",
+ ]:
+ output[k] = v.to(cpu_device)
+ detections.update(
+ {token: output,}
+ )
+ if args.local_rank == 0:
+ prog_bar.update()
+
+ synchronize()
+
+ # torch.cuda.empty_cache()
+ all_predictions = all_gather(detections)
+
+ print("\n Total time per frame: ", (time_end - time_start) / (end - start))
+ # print("\n Total time per frame: ", (time_end - time_start) / (i+1))
+
+ if args.local_rank != 0:
+ return
+
+ predictions = {}
+ for p in all_predictions:
+ predictions.update(p)
+
+ if not os.path.exists(args.work_dir):
+ os.makedirs(args.work_dir)
+
+ save_pred(predictions, args.work_dir)
+
+ result_dict, _ = dataset.evaluation(copy.deepcopy(predictions), output_dir=args.work_dir, testset=args.testset)
+
+ if result_dict is not None:
+ for k, v in result_dict["results"].items():
+ print(f"Evaluation {k}: {v}")
+
+ if args.txt_result:
+ assert False, "No longer support kitti"
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/multi_sweep_inference.py b/tools/multi_sweep_inference.py
new file mode 100644
index 0000000..63b9900
--- /dev/null
+++ b/tools/multi_sweep_inference.py
@@ -0,0 +1,412 @@
+
+import rospy
+import ros_numpy
+import numpy as np
+import copy
+import json
+import os
+import sys
+import torch
+import yaml
+import time
+
+from std_msgs.msg import Header
+import sensor_msgs.point_cloud2 as pc2
+from nav_msgs.msg import Odometry
+from sensor_msgs.msg import PointCloud2, PointField
+from jsk_recognition_msgs.msg import BoundingBox, BoundingBoxArray
+from pyquaternion import Quaternion
+
+from det3d import __version__, torchie
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.core.input.voxel_generator import VoxelGenerator
+
+import cupy as cp
+from collections import deque
+from copy import deepcopy
+from functools import reduce
+
+
+def yaw2quaternion(yaw: float) -> Quaternion:
+ return Quaternion(axis=[0, 0, 1], radians=yaw)
+
+def transform_matrix(translation: np.ndarray = np.array([0, 0, 0]),
+ rotation: Quaternion = Quaternion([1, 0, 0, 0]),
+ inverse: bool = False) -> np.ndarray:
+ """
+ Convert pose to transformation matrix.
+ :param translation: . Translation in x, y, z.
+ :param rotation: Rotation in quaternions (w ri rj rk).
+ :param inverse: Whether to compute inverse transform matrix.
+ :return: . Transformation matrix.
+ """
+ tm = np.eye(4)
+ if inverse:
+ rot_inv = rotation.rotation_matrix.T
+ trans = np.transpose(-np.array(translation))
+ tm[:3, :3] = rot_inv
+ tm[:3, 3] = rot_inv.dot(trans)
+ else:
+ tm[:3, :3] = rotation.rotation_matrix
+ tm[:3, 3] = np.transpose(np.array(translation))
+ return tm
+
+
+def get_annotations_indices(types, thresh, label_preds, scores):
+ indexs = []
+ annotation_indices = []
+ for i in range(label_preds.shape[0]):
+ if label_preds[i] == types:
+ indexs.append(i)
+ for index in indexs:
+ if scores[index] >= thresh:
+ annotation_indices.append(index)
+ return annotation_indices
+
+
+def remove_low_score_nu(image_anno, thresh):
+ img_filtered_annotations = {}
+ label_preds_ = image_anno["label_preds"].detach().cpu().numpy()
+ scores_ = image_anno["scores"].detach().cpu().numpy()
+
+ car_indices = get_annotations_indices(0, 0.4, label_preds_, scores_)
+ truck_indices = get_annotations_indices(1, 0.4, label_preds_, scores_)
+ construction_vehicle_indices = get_annotations_indices(
+ 2, 0.4, label_preds_, scores_)
+ bus_indices = get_annotations_indices(3, 0.3, label_preds_, scores_)
+ trailer_indices = get_annotations_indices(4, 0.4, label_preds_, scores_)
+ barrier_indices = get_annotations_indices(5, 0.4, label_preds_, scores_)
+ motorcycle_indices = get_annotations_indices(
+ 6, 0.15, label_preds_, scores_)
+ bicycle_indices = get_annotations_indices(7, 0.15, label_preds_, scores_)
+ pedestrain_indices = get_annotations_indices(
+ 8, 0.12, label_preds_, scores_)
+ traffic_cone_indices = get_annotations_indices(
+ 9, 0.1, label_preds_, scores_)
+
+ for key in image_anno.keys():
+ if key == 'metadata':
+ continue
+ img_filtered_annotations[key] = (
+ image_anno[key][car_indices +
+ pedestrain_indices +
+ bicycle_indices +
+ bus_indices +
+ construction_vehicle_indices +
+ traffic_cone_indices +
+ trailer_indices +
+ barrier_indices +
+ truck_indices
+ ])
+
+ return img_filtered_annotations
+
+
+class Processor_ROS:
+ def __init__(self, config_path, model_path):
+ self.points = None
+ self.config_path = config_path
+ self.model_path = model_path
+ self.device = None
+ self.net = None
+ self.voxel_generator = None
+
+ self.lidar_deque = deque(maxlen=5)
+ self.current_frame = {
+ "lidar_stamp": None,
+ "lidar_seq": None,
+ "points": None,
+ "odom_seq": None,
+ "odom_stamp": None,
+ "translation": None,
+ "rotation": None
+ }
+ self.pc_list = deque(maxlen=5)
+ self.inputs = None
+
+ def initialize(self):
+ self.read_config()
+
+ def read_config(self):
+ config_path = self.config_path
+ cfg = Config.fromfile(self.config_path)
+ self.device = torch.device(
+ "cuda" if torch.cuda.is_available() else "cpu")
+ self.net = build_detector(
+ cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+ self.net.load_state_dict(torch.load(self.model_path)["state_dict"])
+ self.net = self.net.to(self.device).eval()
+
+ self.range = cfg.voxel_generator.range
+ self.voxel_size = cfg.voxel_generator.voxel_size
+ self.max_points_in_voxel = cfg.voxel_generator.max_points_in_voxel
+ self.max_voxel_num = cfg.voxel_generator.max_voxel_num
+ self.voxel_generator = VoxelGenerator(
+ voxel_size=self.voxel_size,
+ point_cloud_range=self.range,
+ max_num_points=self.max_points_in_voxel,
+ max_voxels=self.max_voxel_num[1],
+ )
+ # nuscenes dataset
+ lidar2imu_t = np.array([0.985793, 0.0, 1.84019])
+ lidar2imu_r = Quaternion([0.706749235, -0.01530099378, 0.0173974518, -0.7070846])
+
+ ## UDI dataset
+ # lidar2imu_t = np.array([1.50, 0., 1.42])
+ # lidar2imu_r = Quaternion([1., 0., 0., 0.])
+ self.lidar2imu = transform_matrix(lidar2imu_t, lidar2imu_r, inverse=True)
+ self.imu2lidar = transform_matrix(lidar2imu_t, lidar2imu_r, inverse=False)
+
+ def run(self):
+
+ # print(f"input points shape: {points.shape}")
+ # num_features = 5
+ # self.points = points.reshape([-1, num_features])
+
+ voxels, coords, num_points = self.voxel_generator.generate(self.points)
+ num_voxels = np.array([voxels.shape[0]], dtype=np.int64)
+ grid_size = self.voxel_generator.grid_size
+ coords = np.pad(coords, ((0, 0), (1, 0)),
+ mode='constant', constant_values=0)
+
+ voxels = torch.tensor(voxels, dtype=torch.float32, device=self.device)
+ coords = torch.tensor(coords, dtype=torch.int32, device=self.device)
+ num_points = torch.tensor(
+ num_points, dtype=torch.int32, device=self.device)
+ num_voxels = torch.tensor(
+ num_voxels, dtype=torch.int32, device=self.device)
+ # grid_size = torch.tensor(grid_size, dtype=torch.float32, device=self.device)
+
+ # t = time.time()
+ self.inputs = dict(
+ voxels=voxels,
+ num_points=num_points,
+ num_voxels=num_voxels,
+ coordinates=coords,
+ shape=[grid_size] # simulate a batch of one example
+ )
+ torch.cuda.synchronize()
+ t = time.time()
+
+ with torch.no_grad():
+ outputs = self.net(self.inputs, return_loss=False)[0]
+
+ torch.cuda.synchronize()
+ print(" network predict time cost:", time.time() - t)
+
+ outputs = remove_low_score_nu(outputs, 0.45)
+
+ boxes_lidar = outputs["box3d_lidar"].detach().cpu().numpy()
+ print(" predict boxes:", boxes_lidar.shape)
+
+ scores = outputs["scores"].detach().cpu().numpy()
+ types = outputs["label_preds"].detach().cpu().numpy()
+
+ boxes_lidar[:, -1] = -boxes_lidar[:, -1] - np.pi / 2
+
+ return scores, boxes_lidar, types
+
+ def get_lidar_data(self, input_points: dict):
+ print("get one frame lidar data.")
+ self.current_frame["lidar_stamp"] = input_points['stamp']
+ self.current_frame["lidar_seq"] = input_points['seq']
+ self.current_frame["points"] = input_points['points'].T
+ self.lidar_deque.append(deepcopy(self.current_frame))
+ if len(self.lidar_deque) == 5:
+
+ ref_from_car = self.imu2lidar
+ car_from_global = transform_matrix(self.lidar_deque[-1]['translation'], self.lidar_deque[-1]['rotation'], inverse=True)
+
+ ref_from_car_gpu = cp.asarray(ref_from_car)
+ car_from_global_gpu = cp.asarray(car_from_global)
+
+ for i in range(len(self.lidar_deque) - 1):
+ last_pc = self.lidar_deque[i]['points']
+ last_pc_gpu = cp.asarray(last_pc)
+
+ global_from_car = transform_matrix(self.lidar_deque[i]['translation'], self.lidar_deque[i]['rotation'], inverse=False)
+ car_from_current = self.lidar2imu
+ global_from_car_gpu = cp.asarray(global_from_car)
+ car_from_current_gpu = cp.asarray(car_from_current)
+
+ transform = reduce(
+ cp.dot,
+ [ref_from_car_gpu, car_from_global_gpu, global_from_car_gpu, car_from_current_gpu],
+ )
+ # tmp_1 = cp.dot(global_from_car_gpu, car_from_current_gpu)
+ # tmp_2 = cp.dot(car_from_global_gpu, tmp_1)
+ # transform = cp.dot(ref_from_car_gpu, tmp_2)
+
+ last_pc_gpu = cp.vstack((last_pc_gpu[:3, :], cp.ones(last_pc_gpu.shape[1])))
+ last_pc_gpu = cp.dot(transform, last_pc_gpu)
+
+ self.pc_list.append(last_pc_gpu[:3, :])
+
+ current_pc = self.lidar_deque[-1]['points']
+ current_pc_gpu = cp.asarray(current_pc)
+ self.pc_list.append(current_pc_gpu[:3,:])
+
+ all_pc = np.zeros((5, 0), dtype=float)
+ for i in range(len(self.pc_list)):
+ tmp_pc = cp.vstack((self.pc_list[i], cp.zeros((2, self.pc_list[i].shape[1]))))
+ tmp_pc = cp.asnumpy(tmp_pc)
+ ref_timestamp = self.lidar_deque[-1]['lidar_stamp'].to_sec()
+ timestamp = self.lidar_deque[i]['lidar_stamp'].to_sec()
+ tmp_pc[3, ...] = self.lidar_deque[i]['points'][3, ...]
+ tmp_pc[4, ...] = ref_timestamp - timestamp
+ all_pc = np.hstack((all_pc, tmp_pc))
+
+ all_pc = all_pc.T
+ print(f" concate pointcloud shape: {all_pc.shape}")
+
+ self.points = all_pc
+ sync_cloud = xyz_array_to_pointcloud2(all_pc[:, :3], stamp=self.lidar_deque[-1]["lidar_stamp"], frame_id="lidar_top")
+ pub_sync_cloud.publish(sync_cloud)
+ return True
+
+ def get_odom_data(self, input_odom):
+
+ self.current_frame["odom_stamp"] = input_odom.header.stamp
+ self.current_frame["odom_seq"] = input_odom.header.seq
+ x_t = input_odom.pose.pose.position.x
+ y_t = input_odom.pose.pose.position.y
+ z_t = input_odom.pose.pose.position.z
+ self.current_frame["translation"] = np.array([x_t, y_t, z_t])
+ x_r = input_odom.pose.pose.orientation.x
+ y_r = input_odom.pose.pose.orientation.y
+ z_r = input_odom.pose.pose.orientation.z
+ w_r = input_odom.pose.pose.orientation.w
+ self.current_frame["rotation"] = Quaternion([w_r, x_r, y_r, z_r])
+
+
+def get_xyz_points(cloud_array, remove_nans=True, dtype=np.float):
+ '''
+ '''
+ if remove_nans:
+ mask = np.isfinite(cloud_array['x']) & np.isfinite(
+ cloud_array['y']) & np.isfinite(cloud_array['z'])
+ cloud_array = cloud_array[mask]
+
+ points = np.zeros(cloud_array.shape + (5,), dtype=dtype)
+ points[..., 0] = cloud_array['x']
+ points[..., 1] = cloud_array['y']
+ points[..., 2] = cloud_array['z']
+ points[..., 3] = cloud_array['intensity']
+ return points
+
+
+def xyz_array_to_pointcloud2(points_sum, stamp=None, frame_id=None):
+ '''
+ Create a sensor_msgs.PointCloud2 from an array of points.
+ '''
+ msg = PointCloud2()
+ if stamp:
+ msg.header.stamp = stamp
+ if frame_id:
+ msg.header.frame_id = frame_id
+ msg.height = 1
+ msg.width = points_sum.shape[0]
+ msg.fields = [
+ PointField('x', 0, PointField.FLOAT32, 1),
+ PointField('y', 4, PointField.FLOAT32, 1),
+ PointField('z', 8, PointField.FLOAT32, 1)
+ # PointField('i', 12, PointField.FLOAT32, 1)
+ ]
+ msg.is_bigendian = False
+ msg.point_step = 12
+ msg.row_step = points_sum.shape[0]
+ msg.is_dense = int(np.isfinite(points_sum).all())
+ msg.data = np.asarray(points_sum, np.float32).tostring()
+ # msg.data = points_sum.astype(np.float32).tobytes()
+ return msg
+
+
+def rslidar_callback(msg):
+ # t_t = time.time()
+ arr_bbox = BoundingBoxArray()
+ msg_cloud = ros_numpy.point_cloud2.pointcloud2_to_array(msg)
+ np_p = get_xyz_points(msg_cloud, True)
+
+ print(" ")
+ seq = msg.header.seq
+ stamp = msg.header.stamp
+ input_points = {
+ 'stamp': stamp,
+ 'seq': seq,
+ 'points': np_p
+ }
+ if(proc_1.get_lidar_data(input_points)):
+ scores, dt_box_lidar, types = proc_1.run()
+
+ if scores.size != 0:
+ for i in range(scores.size):
+ bbox = BoundingBox()
+ bbox.header.frame_id = msg.header.frame_id
+ bbox.header.stamp = rospy.Time.now()
+ q = yaw2quaternion(float(dt_box_lidar[i][8]))
+ bbox.pose.orientation.x = q[1]
+ bbox.pose.orientation.y = q[2]
+ bbox.pose.orientation.z = q[3]
+ bbox.pose.orientation.w = q[0]
+ bbox.pose.position.x = float(dt_box_lidar[i][0])
+ bbox.pose.position.y = float(dt_box_lidar[i][1])
+ bbox.pose.position.z = float(dt_box_lidar[i][2])
+ bbox.dimensions.x = float(dt_box_lidar[i][4])
+ bbox.dimensions.y = float(dt_box_lidar[i][3])
+ bbox.dimensions.z = float(dt_box_lidar[i][5])
+ bbox.value = scores[i]
+ bbox.label = int(types[i])
+ arr_bbox.boxes.append(bbox)
+ # print("total callback time: ", time.time() - t_t)
+ arr_bbox.header.frame_id = msg.header.frame_id
+ arr_bbox.header.stamp = msg.header.stamp
+ if len(arr_bbox.boxes) is not 0:
+ pub_arr_bbox.publish(arr_bbox)
+ arr_bbox.boxes = []
+ else:
+ arr_bbox.boxes = []
+ pub_arr_bbox.publish(arr_bbox)
+
+
+def odom_callback(msg):
+ '''
+ get odom data
+ '''
+ proc_1.get_odom_data(msg)
+
+
+if __name__ == "__main__":
+
+ global proc
+ # CenterPoint
+ config_path = 'configs/centerpoint/nusc_centerpoint_pp_02voxel_circle_nms_demo.py'
+ model_path = 'models/last.pth'
+
+ proc_1 = Processor_ROS(config_path, model_path)
+
+ proc_1.initialize()
+
+ rospy.init_node('centerpoint_ros_node')
+ sub_lidar_topic = ["/velodyne_points",
+ "/top/rslidar_points",
+ "/points_raw",
+ "/aligned/point_cloud",
+ "/merged_cloud",
+ "/lidar_top",
+ "/roi_pclouds"]
+ sub_lidar = rospy.Subscriber(
+ sub_lidar_topic[5], PointCloud2, rslidar_callback, queue_size=1, buff_size=2**24)
+
+ sub_odom_topic = ["/golfcar/odom",
+ "/aligned/odometry",
+ "/odom"]
+
+ sub_odom = rospy.Subscriber(
+ sub_odom_topic[2], Odometry, odom_callback, queue_size=10, buff_size=2**10, tcp_nodelay=True)
+
+ pub_arr_bbox = rospy.Publisher("pp_boxes", BoundingBoxArray, queue_size=1)
+ pub_sync_cloud = rospy.Publisher("sync_5sweeps_cloud", PointCloud2, queue_size=1)
+
+ print("[+] CenterPoint ros_node has started!")
+ rospy.spin()
diff --git a/tools/nms_better.py b/tools/nms_better.py
new file mode 100644
index 0000000..bc2b3bc
--- /dev/null
+++ b/tools/nms_better.py
@@ -0,0 +1,196 @@
+import argparse
+import copy
+import json
+import os
+import sys
+
+import numpy as np
+import pickle
+from pathlib import Path
+from pyquaternion import Quaternion
+from nuscenes.utils.data_classes import LidarPointCloud, Box, RadarPointCloud
+from nuscenes import NuScenes
+from nuscenes.utils.geometry_utils import BoxVisibility, transform_matrix
+from nuscenes.utils.geometry_utils import points_in_box
+from functools import reduce
+from tqdm import tqdm
+from det3d.core import box_torch_ops
+from collections import defaultdict
+import torch
+import glob
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Ensemble Models")
+ parser.add_argument("ensemble_dir", help="path to a dir that contains all prediction file")
+ parser.add_argument("--output_path", help="the path to save ensemble output")
+ parser.add_argument("--data_root", type=str, default="data/nuScenes/v1.0-trainval")
+
+ args = parser.parse_args()
+
+ return args
+
+
+def get_sample_data(pred):
+ box_list = []
+ score_list = []
+ pred = pred.copy()
+
+ for item in pred:
+ box = Box(item['translation'], item['size'], Quaternion(item['rotation']),
+ name=item['detection_name'])
+ score_list.append(item['detection_score'])
+ box_list.append(box)
+
+ top_boxes = reorganize_boxes(box_list)
+ top_scores = np.array(score_list).reshape(-1)
+
+ return top_boxes, top_scores
+
+def reorganize_boxes(box_lidar_nusc):
+ rots = []
+ centers = []
+ wlhs = []
+ for i, box_lidar in enumerate(box_lidar_nusc):
+ v = np.dot(box_lidar.rotation_matrix, np.array([1, 0, 0]))
+ rot = np.arctan2(v[1], v[0])
+
+ rots.append(-rot- np.pi / 2)
+ centers.append(box_lidar.center)
+ wlhs.append(box_lidar.wlh)
+
+ rots = np.asarray(rots)
+ centers = np.asarray(centers)
+ wlhs = np.asarray(wlhs)
+ gt_boxes_lidar = np.concatenate([centers.reshape(-1,3), wlhs.reshape(-1,3), rots[..., np.newaxis].reshape(-1,1) ], axis=1)
+
+ return gt_boxes_lidar
+
+def reorganize_pred_by_class(pred):
+ ret_dicts = defaultdict(list)
+ for item in pred:
+ ret_dicts[item['detection_name']].append(item)
+
+ return ret_dicts
+
+def concatenate_list(lists):
+ ret = []
+ for l in lists:
+ ret += l
+
+ return ret
+
+ENS_CLASS = ['car', 'truck', 'bus', 'construction_vehicle', 'bicycle']
+SMALL_CLASS = ['pedestrian', 'barrier', 'traffic_cone', 'motorcycle']
+LARGE_CLASS = ['trailer']
+ALL_CLASS = ['car', 'truck', 'bus', 'construction_vehicle', 'bicycle', 'pedestrian', 'barrier', 'traffic_cone', 'motorcycle', 'trailer']
+
+def filter_pred_by_class(preds, small=False, large=False):
+ ret_dict = {}
+ for token, pred in preds.items():
+ filtered = []
+
+ for item in pred:
+ assert item['detection_name'] in ALL_CLASS
+
+ if small:
+ if item['detection_name'] not in LARGE_CLASS:
+ filtered.append(item)
+ elif large:
+ if item['detection_name'] not in SMALL_CLASS:
+ filtered.append(item)
+
+ ret_dict[token] = filtered
+
+ return ret_dict
+
+def get_pred(path):
+ with open(path, 'rb') as f:
+ pred=pickle.load(f)
+
+ return pred
+
+def main():
+ args = parse_args()
+
+ pred_paths = glob.glob(os.path.join(args.ensemble_dir, '*.pkl'))
+ print(pred_paths)
+
+ preds = []
+ for path in pred_paths:
+ preds.append(get_pred(path))
+
+ merged_predictions = {}
+ for token in preds[0].keys():
+ annos = [pred[token] for pred in preds]
+
+ merged_predictions[token] = concatenate_list(annos)
+
+ predictions = merged_predictions
+
+ print("Finish Merging")
+
+ nusc_annos = {
+ "results": {},
+ "meta": None,
+ }
+
+ for sample_token, prediction in tqdm(predictions.items()):
+ annos = []
+
+ # reorganize pred by class
+ pred_dicts = reorganize_pred_by_class(prediction)
+
+ for name, pred in pred_dicts.items():
+ # in global coordinate
+ top_boxes, top_scores = get_sample_data(pred)
+
+ with torch.no_grad():
+ top_boxes_tensor = torch.from_numpy(top_boxes)
+ boxes_for_nms = top_boxes_tensor[:, [0, 1, 2, 4, 3, 5, -1]]
+ boxes_for_nms[:, -1] = boxes_for_nms[:, -1] + np.pi /2
+ top_scores_tensor = torch.from_numpy(top_scores)
+
+ selected = box_torch_ops.rotate_nms(boxes_for_nms, top_scores_tensor,
+ pre_max_size=None,
+ post_max_size=50,
+ iou_threshold=0.2,
+ ).numpy()
+
+ pred = [pred[s] for s in selected]
+
+ annos.extend(pred)
+
+ nusc_annos["results"].update({sample_token: annos})
+
+ nusc_annos["meta"] = {
+ "use_camera": False,
+ "use_lidar": True,
+ "use_radar": True,
+ "use_map": False,
+ "use_external": False,
+ }
+
+ res_dir = os.path.join(args.work_dir)
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ with open(os.path.join(args.work_dir, 'result.json'), "w") as f:
+ json.dump(nusc_annos, f)
+
+ from nuscenes.eval.detection.config import config_factory
+ from nuscenes.eval.detection.evaluate import NuScenesEval
+ nusc = NuScenes(version="v1.0-trainval", dataroot=args.data_root, verbose=True)
+ cfg = config_factory("cvpr_2019")
+ nusc_eval = NuScenesEval(
+ nusc,
+ config=cfg,
+ result_path=os.path.join(args.work_dir, 'result.json'),
+ eval_set='val',
+ output_dir=args.work_dir,
+ verbose=True,
+ )
+ metrics_summary = nusc_eval.main(plot_examples=0,)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/nusc_tracking/__init__.py b/tools/nusc_tracking/__init__.py
new file mode 100644
index 0000000..7c28433
--- /dev/null
+++ b/tools/nusc_tracking/__init__.py
@@ -0,0 +1,3 @@
+from .pub_tracker import PubTracker
+
+__all__ = ["PubTracker"]
\ No newline at end of file
diff --git a/tools/nusc_tracking/pub_test.py b/tools/nusc_tracking/pub_test.py
new file mode 100644
index 0000000..4234ae6
--- /dev/null
+++ b/tools/nusc_tracking/pub_test.py
@@ -0,0 +1,192 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import json
+import numpy as np
+import time
+import copy
+import argparse
+import copy
+import json
+import os
+import numpy as np
+from pub_tracker import PubTracker as Tracker
+from nuscenes import NuScenes
+import json
+import time
+from nuscenes.utils import splits
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Tracking Evaluation")
+ parser.add_argument("--work_dir", help="the dir to save logs and tracking results")
+ parser.add_argument(
+ "--checkpoint", help="the dir to checkpoint which the model read from"
+ )
+ parser.add_argument("--hungarian", action='store_true')
+ parser.add_argument("--root", type=str, default="data/nuScenes")
+ parser.add_argument("--version", type=str, default='v1.0-trainval')
+ parser.add_argument("--max_age", type=int, default=3)
+
+ args = parser.parse_args()
+
+ return args
+
+
+def save_first_frame():
+ args = parse_args()
+ nusc = NuScenes(version=args.version, dataroot=args.root, verbose=True)
+ if args.version == 'v1.0-trainval':
+ scenes = splits.val
+ elif args.version == 'v1.0-test':
+ scenes = splits.test
+ else:
+ raise ValueError("unknown")
+
+ frames = []
+ for sample in nusc.sample:
+ scene_name = nusc.get("scene", sample['scene_token'])['name']
+ if scene_name not in scenes:
+ continue
+
+ timestamp = sample["timestamp"] * 1e-6
+ token = sample["token"]
+ frame = {}
+ frame['token'] = token
+ frame['timestamp'] = timestamp
+
+ # start of a sequence
+ if sample['prev'] == '':
+ frame['first'] = True
+ else:
+ frame['first'] = False
+ frames.append(frame)
+
+ del nusc
+
+ res_dir = os.path.join(args.work_dir)
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ with open(os.path.join(args.work_dir, 'frames_meta.json'), "w") as f:
+ json.dump({'frames': frames}, f)
+
+
+def main():
+ args = parse_args()
+ print('Deploy OK')
+
+ tracker = Tracker(max_age=args.max_age, hungarian=args.hungarian)
+
+ with open(args.checkpoint, 'rb') as f:
+ predictions=json.load(f)['results']
+
+ with open(os.path.join(args.work_dir, 'frames_meta.json'), 'rb') as f:
+ frames=json.load(f)['frames']
+
+ nusc_annos = {
+ "results": {},
+ "meta": None,
+ }
+ size = len(frames)
+
+ print("Begin Tracking\n")
+ start = time.time()
+ for i in range(size):
+ token = frames[i]['token']
+
+ # reset tracking after one video sequence
+ if frames[i]['first']:
+ # use this for sanity check to ensure your token order is correct
+ # print("reset ", i)
+ tracker.reset()
+ last_time_stamp = frames[i]['timestamp']
+
+ time_lag = (frames[i]['timestamp'] - last_time_stamp)
+ last_time_stamp = frames[i]['timestamp']
+
+ preds = predictions[token]
+
+ outputs = tracker.step_centertrack(preds, time_lag)
+ annos = []
+
+ for item in outputs:
+ if item['active'] == 0:
+ continue
+ nusc_anno = {
+ "sample_token": token,
+ "translation": item['translation'],
+ "size": item['size'],
+ "rotation": item['rotation'],
+ "velocity": item['velocity'],
+ "tracking_id": str(item['tracking_id']),
+ "tracking_name": item['detection_name'],
+ "tracking_score": item['detection_score'],
+ }
+ annos.append(nusc_anno)
+ nusc_annos["results"].update({token: annos})
+
+
+ end = time.time()
+
+ second = (end-start)
+
+ speed=size / second
+ print("The speed is {} FPS".format(speed))
+
+ nusc_annos["meta"] = {
+ "use_camera": False,
+ "use_lidar": True,
+ "use_radar": False,
+ "use_map": False,
+ "use_external": False,
+ }
+
+ res_dir = os.path.join(args.work_dir)
+ if not os.path.exists(res_dir):
+ os.makedirs(res_dir)
+
+ with open(os.path.join(args.work_dir, 'tracking_result.json'), "w") as f:
+ json.dump(nusc_annos, f)
+ return speed
+
+def eval_tracking():
+ args = parse_args()
+ eval(os.path.join(args.work_dir, 'tracking_result.json'),
+ "val",
+ args.work_dir,
+ args.root
+ )
+
+def eval(res_path, eval_set="val", output_dir=None, root_path=None):
+ from nuscenes.eval.tracking.evaluate import TrackingEval
+ from nuscenes.eval.common.config import config_factory as track_configs
+
+
+ cfg = track_configs("tracking_nips_2019")
+ nusc_eval = TrackingEval(
+ config=cfg,
+ result_path=res_path,
+ eval_set=eval_set,
+ output_dir=output_dir,
+ verbose=True,
+ nusc_version="v1.0-trainval",
+ nusc_dataroot=root_path,
+ )
+ metrics_summary = nusc_eval.main()
+
+
+def test_time():
+ speeds = []
+ for i in range(3):
+ speeds.append(main())
+
+ print("Speed is {} FPS".format( max(speeds) ))
+
+if __name__ == '__main__':
+ save_first_frame()
+ main()
+ # test_time()
+ eval_tracking()
diff --git a/tools/nusc_tracking/pub_tracker.py b/tools/nusc_tracking/pub_tracker.py
new file mode 100644
index 0000000..3dad488
--- /dev/null
+++ b/tools/nusc_tracking/pub_tracker.py
@@ -0,0 +1,154 @@
+import numpy as np
+import copy
+from track_utils import greedy_assignment
+import copy
+import importlib
+import sys
+
+NUSCENES_TRACKING_NAMES = [
+ 'bicycle',
+ 'bus',
+ 'car',
+ 'motorcycle',
+ 'pedestrian',
+ 'trailer',
+ 'truck'
+]
+
+
+# 99.9 percentile of the l2 velocity error distribution (per clss / 0.5 second)
+# This is an earlier statistcs and I didn't spend much time tuning it.
+# Tune this for your model should provide some considerable AMOTA improvement
+NUSCENE_CLS_VELOCITY_ERROR = {
+ 'car':4,
+ 'truck':4,
+ 'bus':5.5,
+ 'trailer':3,
+ 'pedestrian':1,
+ 'motorcycle':13,
+ 'bicycle':3,
+}
+
+
+
+class PubTracker(object):
+ def __init__(self, hungarian=False, max_age=0):
+ self.hungarian = hungarian
+ self.max_age = max_age
+
+ print("Use hungarian: {}".format(hungarian))
+
+ self.NUSCENE_CLS_VELOCITY_ERROR = NUSCENE_CLS_VELOCITY_ERROR
+
+ self.reset()
+
+ def reset(self):
+ self.id_count = 0
+ self.tracks = []
+
+ def step_centertrack(self, results, time_lag):
+ if len(results) == 0:
+ self.tracks = []
+ return []
+ else:
+ temp = []
+ for det in results:
+ # filter out classes not evaluated for tracking
+ if det['detection_name'] not in NUSCENES_TRACKING_NAMES:
+ continue
+
+ det['ct'] = np.array(det['translation'][:2])
+ det['tracking'] = np.array(det['velocity'][:2]) * -1 * time_lag
+ det['label_preds'] = NUSCENES_TRACKING_NAMES.index(det['detection_name'])
+ temp.append(det)
+
+ results = temp
+
+ N = len(results)
+ M = len(self.tracks)
+
+ # N X 2
+ if 'tracking' in results[0]:
+ dets = np.array(
+ [ det['ct'] + det['tracking'].astype(np.float32)
+ for det in results], np.float32)
+ else:
+ dets = np.array(
+ [det['ct'] for det in results], np.float32)
+
+ item_cat = np.array([item['label_preds'] for item in results], np.int32) # N
+ track_cat = np.array([track['label_preds'] for track in self.tracks], np.int32) # M
+
+ max_diff = np.array([self.NUSCENE_CLS_VELOCITY_ERROR[box['detection_name']] for box in results], np.float32)
+
+ tracks = np.array(
+ [pre_det['ct'] for pre_det in self.tracks], np.float32) # M x 2
+
+ if len(tracks) > 0: # NOT FIRST FRAME
+ dist = (((tracks.reshape(1, -1, 2) - \
+ dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
+ dist = np.sqrt(dist) # absolute distance in meter
+
+ invalid = ((dist > max_diff.reshape(N, 1)) + \
+ (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
+
+ dist = dist + invalid * 1e18
+ if self.hungarian:
+ dist[dist > 1e18] = 1e18
+ matched_indices = linear_assignment(copy.deepcopy(dist))
+ else:
+ matched_indices = greedy_assignment(copy.deepcopy(dist))
+ else: # first few frame
+ assert M == 0
+ matched_indices = np.array([], np.int32).reshape(-1, 2)
+
+ unmatched_dets = [d for d in range(dets.shape[0]) \
+ if not (d in matched_indices[:, 0])]
+
+ unmatched_tracks = [d for d in range(tracks.shape[0]) \
+ if not (d in matched_indices[:, 1])]
+
+ if self.hungarian:
+ matches = []
+ for m in matched_indices:
+ if dist[m[0], m[1]] > 1e16:
+ unmatched_dets.append(m[0])
+ else:
+ matches.append(m)
+ matches = np.array(matches).reshape(-1, 2)
+ else:
+ matches = matched_indices
+
+ ret = []
+ for m in matches:
+ track = results[m[0]]
+ track['tracking_id'] = self.tracks[m[1]]['tracking_id']
+ track['age'] = 1
+ track['active'] = self.tracks[m[1]]['active'] + 1
+ ret.append(track)
+
+ for i in unmatched_dets:
+ track = results[i]
+ self.id_count += 1
+ track['tracking_id'] = self.id_count
+ track['age'] = 1
+ track['active'] = 1
+ ret.append(track)
+
+ # still store unmatched tracks if its age doesn't exceed max_age, however, we shouldn't output
+ # the object in current frame
+ for i in unmatched_tracks:
+ track = self.tracks[i]
+ if track['age'] < self.max_age:
+ track['age'] += 1
+ track['active'] = 0
+ ct = track['ct']
+
+ # movement in the last second
+ if 'tracking' in track:
+ offset = track['tracking'] * -1 # move forward
+ track['ct'] = ct + offset
+ ret.append(track)
+
+ self.tracks = ret
+ return ret
diff --git a/tools/nusc_tracking/track_utils.py b/tools/nusc_tracking/track_utils.py
new file mode 100644
index 0000000..288ca8d
--- /dev/null
+++ b/tools/nusc_tracking/track_utils.py
@@ -0,0 +1,12 @@
+import numpy as np
+
+def greedy_assignment(dist):
+ matched_indices = []
+ if dist.shape[1] == 0:
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
+ for i in range(dist.shape[0]):
+ j = dist[i].argmin()
+ if dist[i][j] < 1e16:
+ dist[:, j] = 1e18
+ matched_indices.append([i, j])
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
diff --git a/tools/simple_inference_waymo.py b/tools/simple_inference_waymo.py
new file mode 100644
index 0000000..2ecc614
--- /dev/null
+++ b/tools/simple_inference_waymo.py
@@ -0,0 +1,161 @@
+# modified from the single_inference.py by @muzi2045
+from spconv.utils import VoxelGenerator as VoxelGenerator
+from det3d.datasets.pipelines.loading import read_single_waymo
+from det3d.datasets.pipelines.loading import get_obj
+from det3d.torchie.trainer import load_checkpoint
+from det3d.models import build_detector
+from det3d.torchie import Config
+from tqdm import tqdm
+import numpy as np
+import pickle
+import open3d as o3d
+import argparse
+import torch
+import time
+import os
+
+voxel_generator = None
+model = None
+device = None
+
+def initialize_model(args):
+ global model, voxel_generator
+ cfg = Config.fromfile(args.config)
+ model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+ if args.checkpoint is not None:
+ load_checkpoint(model, args.checkpoint, map_location="cpu")
+ # print(model)
+ if args.fp16:
+ print("cast model to fp16")
+ model = model.half()
+
+ model = model.cuda()
+ model.eval()
+
+ global device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ range = cfg.voxel_generator.range
+ voxel_size = cfg.voxel_generator.voxel_size
+ max_points_in_voxel = cfg.voxel_generator.max_points_in_voxel
+ max_voxel_num = cfg.voxel_generator.max_voxel_num[1]
+ voxel_generator = VoxelGenerator(
+ voxel_size=voxel_size,
+ point_cloud_range=range,
+ max_num_points=max_points_in_voxel,
+ max_voxels=max_voxel_num
+ )
+ return model
+
+def voxelization(points, voxel_generator):
+ voxel_output = voxel_generator.generate(points)
+ voxels, coords, num_points = \
+ voxel_output['voxels'], voxel_output['coordinates'], voxel_output['num_points_per_voxel']
+
+ return voxels, coords, num_points
+
+def _process_inputs(points, fp16):
+ voxels, coords, num_points = voxel_generator.generate(points)
+ num_voxels = np.array([voxels.shape[0]], dtype=np.int32)
+ grid_size = voxel_generator.grid_size
+ coords = np.pad(coords, ((0, 0), (1, 0)), mode='constant', constant_values = 0)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ voxels = torch.tensor(voxels, dtype=torch.float32, device=device)
+ coords = torch.tensor(coords, dtype=torch.int32, device=device)
+ num_points = torch.tensor(num_points, dtype=torch.int32, device=device)
+ num_voxels = torch.tensor(num_voxels, dtype=torch.int32, device=device)
+
+ if fp16:
+ voxels = voxels.half()
+
+ inputs = dict(
+ voxels = voxels,
+ num_points = num_points,
+ num_voxels = num_voxels,
+ coordinates = coords,
+ shape = [grid_size]
+ )
+
+ return inputs
+
+def run_model(points, fp16=False):
+ with torch.no_grad():
+ data_dict = _process_inputs(points, fp16)
+ outputs = model(data_dict, return_loss=False)[0]
+
+ return {'boxes': outputs['box3d_lidar'].cpu().numpy(),
+ 'scores': outputs['scores'].cpu().numpy(),
+ 'classes': outputs['label_preds'].cpu().numpy()}
+
+def process_example(points, fp16=False):
+ output = run_model(points, fp16)
+
+ assert len(output) == 3
+ assert set(output.keys()) == set(('boxes', 'scores', 'classes'))
+ num_objs = output['boxes'].shape[0]
+ assert output['scores'].shape[0] == num_objs
+ assert output['classes'].shape[0] == num_objs
+
+ return output
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="CenterPoint")
+ parser.add_argument("config", help="path to config file")
+ parser.add_argument(
+ "--checkpoint", help="the path to checkpoint which the model read from", default=None, type=str
+ )
+ parser.add_argument('--input_data_dir', type=str, required=True)
+ parser.add_argument('--output_dir', type=str, required=True)
+ parser.add_argument('--fp16', action='store_true')
+ parser.add_argument('--threshold', default=0.5)
+ parser.add_argument('--visual', action='store_true')
+ parser.add_argument("--online", action='store_true')
+ parser.add_argument('--num_frame', default=-1, type=int)
+ args = parser.parse_args()
+
+ print("Please prepare your point cloud in waymo format and save it as a pickle dict with points key into the {}".format(args.input_data_dir))
+ print("One point cloud should be saved in one pickle file.")
+ print("Download and save the pretrained model at {}".format(args.checkpoint))
+
+ # Run any user-specified initialization code for their submission.
+ model = initialize_model(args)
+
+ latencies = []
+ visual_dicts = []
+ pred_dicts = {}
+ counter = 0
+ for frame_name in tqdm(sorted(os.listdir(args.input_data_dir))):
+ if counter == args.num_frame:
+ break
+ else:
+ counter += 1
+
+ pc_name = os.path.join(args.input_data_dir, frame_name)
+ points = pickle.load(open(pc_name, 'rb'))['points']
+ # points = read_single_waymo(get_obj(pc_name))
+
+ detections = process_example(points, args.fp16)
+
+ if args.visual and args.online:
+ pcd = o3d.geometry.PointCloud()
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points[:, :3])
+
+ visual = [pcd]
+ num_dets = detections['scores'].shape[0]
+ visual += plot_boxes(detections, args.threshold)
+
+ o3d.visualization.draw_geometries(visual)
+ elif args.visual:
+ visual_dicts.append({'points': points, 'detections': detections})
+
+ pred_dicts.update({frame_name: detections})
+
+ if args.visual:
+ with open(os.path.join(args.output_dir, 'visualization.pkl'), 'wb') as f:
+ pickle.dump(visual_dicts, f)
+
+ with open(os.path.join(args.output_dir, 'detections.pkl'), 'wb') as f:
+ pickle.dump(pred_dicts, f)
diff --git a/tools/single_infernece.py b/tools/single_infernece.py
new file mode 100644
index 0000000..669bbee
--- /dev/null
+++ b/tools/single_infernece.py
@@ -0,0 +1,253 @@
+
+import rospy
+import ros_numpy
+import numpy as np
+import copy
+import json
+import os
+import sys
+import torch
+import time
+
+from std_msgs.msg import Header
+import sensor_msgs.point_cloud2 as pc2
+from sensor_msgs.msg import PointCloud2, PointField
+from jsk_recognition_msgs.msg import BoundingBox, BoundingBoxArray
+from pyquaternion import Quaternion
+
+from det3d import __version__, torchie
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.core.input.voxel_generator import VoxelGenerator
+
+def yaw2quaternion(yaw: float) -> Quaternion:
+ return Quaternion(axis=[0,0,1], radians=yaw)
+
+def get_annotations_indices(types, thresh, label_preds, scores):
+ indexs = []
+ annotation_indices = []
+ for i in range(label_preds.shape[0]):
+ if label_preds[i] == types:
+ indexs.append(i)
+ for index in indexs:
+ if scores[index] >= thresh:
+ annotation_indices.append(index)
+ return annotation_indices
+
+
+def remove_low_score_nu(image_anno, thresh):
+ img_filtered_annotations = {}
+ label_preds_ = image_anno["label_preds"].detach().cpu().numpy()
+ scores_ = image_anno["scores"].detach().cpu().numpy()
+
+ car_indices = get_annotations_indices(0, 0.4, label_preds_, scores_)
+ truck_indices = get_annotations_indices(1, 0.4, label_preds_, scores_)
+ construction_vehicle_indices = get_annotations_indices(2, 0.4, label_preds_, scores_)
+ bus_indices = get_annotations_indices(3, 0.3, label_preds_, scores_)
+ trailer_indices = get_annotations_indices(4, 0.4, label_preds_, scores_)
+ barrier_indices = get_annotations_indices(5, 0.4, label_preds_, scores_)
+ motorcycle_indices = get_annotations_indices(6, 0.15, label_preds_, scores_)
+ bicycle_indices = get_annotations_indices(7, 0.15, label_preds_, scores_)
+ pedestrain_indices = get_annotations_indices(8, 0.1, label_preds_, scores_)
+ traffic_cone_indices = get_annotations_indices(9, 0.1, label_preds_, scores_)
+
+ for key in image_anno.keys():
+ if key == 'metadata':
+ continue
+ img_filtered_annotations[key] = (
+ image_anno[key][car_indices +
+ pedestrain_indices +
+ bicycle_indices +
+ bus_indices +
+ construction_vehicle_indices +
+ traffic_cone_indices +
+ trailer_indices +
+ barrier_indices +
+ truck_indices
+ ])
+
+ return img_filtered_annotations
+
+
+class Processor_ROS:
+ def __init__(self, config_path, model_path):
+ self.points = None
+ self.config_path = config_path
+ self.model_path = model_path
+ self.device = None
+ self.net = None
+ self.voxel_generator = None
+ self.inputs = None
+
+ def initialize(self):
+ self.read_config()
+
+ def read_config(self):
+ config_path = self.config_path
+ cfg = Config.fromfile(self.config_path)
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ self.net = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg)
+ self.net.load_state_dict(torch.load(self.model_path)["state_dict"])
+ self.net = self.net.to(self.device).eval()
+
+ self.range = cfg.voxel_generator.range
+ self.voxel_size = cfg.voxel_generator.voxel_size
+ self.max_points_in_voxel = cfg.voxel_generator.max_points_in_voxel
+ self.max_voxel_num = cfg.voxel_generator.max_voxel_num
+ self.voxel_generator = VoxelGenerator(
+ voxel_size=self.voxel_size,
+ point_cloud_range=self.range,
+ max_num_points=self.max_points_in_voxel,
+ max_voxels=self.max_voxel_num,
+ )
+
+ def run(self, points):
+ t_t = time.time()
+ print(f"input points shape: {points.shape}")
+ num_features = 5
+ self.points = points.reshape([-1, num_features])
+ self.points[:, 4] = 0 # timestamp value
+
+ voxels, coords, num_points = self.voxel_generator.generate(self.points)
+ num_voxels = np.array([voxels.shape[0]], dtype=np.int64)
+ grid_size = self.voxel_generator.grid_size
+ coords = np.pad(coords, ((0, 0), (1, 0)), mode='constant', constant_values = 0)
+
+ voxels = torch.tensor(voxels, dtype=torch.float32, device=self.device)
+ coords = torch.tensor(coords, dtype=torch.int32, device=self.device)
+ num_points = torch.tensor(num_points, dtype=torch.int32, device=self.device)
+ num_voxels = torch.tensor(num_voxels, dtype=torch.int32, device=self.device)
+
+ self.inputs = dict(
+ voxels = voxels,
+ num_points = num_points,
+ num_voxels = num_voxels,
+ coordinates = coords,
+ shape = [grid_size]
+ )
+ torch.cuda.synchronize()
+ t = time.time()
+
+ with torch.no_grad():
+ outputs = self.net(self.inputs, return_loss=False)[0]
+
+ # print(f"output: {outputs}")
+
+ torch.cuda.synchronize()
+ print(" network predict time cost:", time.time() - t)
+
+ outputs = remove_low_score_nu(outputs, 0.45)
+
+ boxes_lidar = outputs["box3d_lidar"].detach().cpu().numpy()
+ print(" predict boxes:", boxes_lidar.shape)
+
+ scores = outputs["scores"].detach().cpu().numpy()
+ types = outputs["label_preds"].detach().cpu().numpy()
+
+ boxes_lidar[:, -1] = -boxes_lidar[:, -1] - np.pi / 2
+
+ print(f" total cost time: {time.time() - t_t}")
+
+ return scores, boxes_lidar, types
+
+def get_xyz_points(cloud_array, remove_nans=True, dtype=np.float):
+ '''
+ '''
+ if remove_nans:
+ mask = np.isfinite(cloud_array['x']) & np.isfinite(cloud_array['y']) & np.isfinite(cloud_array['z'])
+ cloud_array = cloud_array[mask]
+
+ points = np.zeros(cloud_array.shape + (5,), dtype=dtype)
+ points[...,0] = cloud_array['x']
+ points[...,1] = cloud_array['y']
+ points[...,2] = cloud_array['z']
+ return points
+
+def xyz_array_to_pointcloud2(points_sum, stamp=None, frame_id=None):
+ '''
+ Create a sensor_msgs.PointCloud2 from an array of points.
+ '''
+ msg = PointCloud2()
+ if stamp:
+ msg.header.stamp = stamp
+ if frame_id:
+ msg.header.frame_id = frame_id
+ msg.height = 1
+ msg.width = points_sum.shape[0]
+ msg.fields = [
+ PointField('x', 0, PointField.FLOAT32, 1),
+ PointField('y', 4, PointField.FLOAT32, 1),
+ PointField('z', 8, PointField.FLOAT32, 1)
+ # PointField('i', 12, PointField.FLOAT32, 1)
+ ]
+ msg.is_bigendian = False
+ msg.point_step = 12
+ msg.row_step = points_sum.shape[0]
+ msg.is_dense = int(np.isfinite(points_sum).all())
+ msg.data = np.asarray(points_sum, np.float32).tostring()
+ return msg
+
+def rslidar_callback(msg):
+ t_t = time.time()
+ arr_bbox = BoundingBoxArray()
+
+ msg_cloud = ros_numpy.point_cloud2.pointcloud2_to_array(msg)
+ np_p = get_xyz_points(msg_cloud, True)
+ print(" ")
+ scores, dt_box_lidar, types = proc_1.run(np_p)
+
+ if scores.size != 0:
+ for i in range(scores.size):
+ bbox = BoundingBox()
+ bbox.header.frame_id = msg.header.frame_id
+ bbox.header.stamp = rospy.Time.now()
+ q = yaw2quaternion(float(dt_box_lidar[i][8]))
+ bbox.pose.orientation.x = q[1]
+ bbox.pose.orientation.y = q[2]
+ bbox.pose.orientation.z = q[3]
+ bbox.pose.orientation.w = q[0]
+ bbox.pose.position.x = float(dt_box_lidar[i][0])
+ bbox.pose.position.y = float(dt_box_lidar[i][1])
+ bbox.pose.position.z = float(dt_box_lidar[i][2])
+ bbox.dimensions.x = float(dt_box_lidar[i][4])
+ bbox.dimensions.y = float(dt_box_lidar[i][3])
+ bbox.dimensions.z = float(dt_box_lidar[i][5])
+ bbox.value = scores[i]
+ bbox.label = int(types[i])
+ arr_bbox.boxes.append(bbox)
+ print("total callback time: ", time.time() - t_t)
+ arr_bbox.header.frame_id = msg.header.frame_id
+ arr_bbox.header.stamp = msg.header.stamp
+ if len(arr_bbox.boxes) is not 0:
+ pub_arr_bbox.publish(arr_bbox)
+ arr_bbox.boxes = []
+ else:
+ arr_bbox.boxes = []
+ pub_arr_bbox.publish(arr_bbox)
+
+if __name__ == "__main__":
+
+ global proc
+ ## CenterPoint
+ config_path = 'configs/centerpoint/nusc_centerpoint_pp_02voxel_circle_nms_demo.py'
+ model_path = 'models/last.pth'
+
+ proc_1 = Processor_ROS(config_path, model_path)
+
+ proc_1.initialize()
+
+ rospy.init_node('centerpoint_ros_node')
+ sub_lidar_topic = [ "/velodyne_points",
+ "/top/rslidar_points",
+ "/points_raw",
+ "/lidar_protector/merged_cloud",
+ "/merged_cloud",
+ "/lidar_top",
+ "/roi_pclouds"]
+
+ sub_ = rospy.Subscriber(sub_lidar_topic[5], PointCloud2, rslidar_callback, queue_size=1, buff_size=2**24)
+
+ pub_arr_bbox = rospy.Publisher("pp_boxes", BoundingBoxArray, queue_size=1)
+
+ print("[+] CenterPoint ros_node has started!")
+ rospy.spin()
\ No newline at end of file
diff --git a/tools/train.py b/tools/train.py
new file mode 100644
index 0000000..24dc811
--- /dev/null
+++ b/tools/train.py
@@ -0,0 +1,137 @@
+import argparse
+import json
+import os
+import sys
+
+from numba.core.errors import NumbaDeprecationWarning, NumbaPendingDeprecationWarning, NumbaWarning
+import warnings
+warnings.simplefilter('ignore', category=NumbaDeprecationWarning)
+warnings.simplefilter('ignore', category=NumbaWarning)
+
+import numpy as np
+import torch
+import yaml
+from det3d.datasets import build_dataset
+from det3d.models import build_detector
+from det3d.torchie import Config
+from det3d.torchie.apis import (
+ build_optimizer,
+ get_root_logger,
+ init_dist,
+ set_random_seed,
+ train_detector,
+)
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Train a detector")
+ parser.add_argument("config", help="train config file path")
+ parser.add_argument("--work_dir", help="the dir to save logs and models")
+ parser.add_argument("--resume_from", help="the checkpoint file to resume from")
+ parser.add_argument(
+ "--validate",
+ action="store_true",
+ help="whether to evaluate the checkpoint during training",
+ )
+ parser.add_argument(
+ "--gpus",
+ type=int,
+ default=1,
+ help="number of gpus to use " "(only applicable to non-distributed training)",
+ )
+ parser.add_argument("--seed", type=int, default=None, help="random seed")
+ parser.add_argument(
+ "--launcher",
+ choices=["none", "pytorch", "slurm", "mpi"],
+ default="none",
+ help="job launcher",
+ )
+ parser.add_argument("--local_rank", type=int, default=0)
+ parser.add_argument(
+ "--autoscale-lr",
+ action="store_true",
+ help="automatically scale lr with the number of gpus",
+ )
+ args = parser.parse_args()
+ if "LOCAL_RANK" not in os.environ:
+ os.environ["LOCAL_RANK"] = str(args.local_rank)
+
+ return args
+
+
+def main():
+
+ torch.manual_seed(10)
+ # torch.backends.cudnn.deterministic = True
+ # torch.backends.cudnn.benchmark = False
+ np.random.seed(10)
+
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+ cfg.local_rank = args.local_rank
+
+ # update configs according to CLI args
+ if args.work_dir is not None:
+ cfg.work_dir = args.work_dir
+ if args.resume_from is not None:
+ cfg.resume_from = args.resume_from
+
+ distributed = False
+ if "WORLD_SIZE" in os.environ:
+ distributed = int(os.environ["WORLD_SIZE"]) > 1
+
+ if distributed:
+ torch.cuda.set_device(args.local_rank)
+ torch.distributed.init_process_group(backend="nccl", init_method="env://")
+
+ cfg.gpus = torch.distributed.get_world_size()
+
+ if args.autoscale_lr:
+ cfg.lr_config.lr_max = cfg.lr_config.lr_max * cfg.gpus
+
+ # init logger before other steps
+ logger = get_root_logger(cfg.log_level)
+ logger.info("Distributed training: {}".format(distributed))
+ logger.info(f"torch.backends.cudnn.benchmark: {torch.backends.cudnn.benchmark}")
+
+ if args.local_rank == 0:
+ # copy important files to backup
+ backup_dir = os.path.join(cfg.work_dir, "det3d")
+ os.makedirs(backup_dir, exist_ok=True)
+ # os.system("cp -r * %s/" % backup_dir)
+ # logger.info(f"Backup source files to {cfg.work_dir}/det3d")
+
+ # set random seeds
+ if args.seed is not None:
+ logger.info("Set random seed to {}".format(args.seed))
+ set_random_seed(args.seed)
+
+ model = build_detector(cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
+
+ datasets = [build_dataset(cfg.data.train)]
+
+ if len(cfg.workflow) == 2:
+ datasets.append(build_dataset(cfg.data.val))
+
+ if cfg.checkpoint_config is not None:
+ # save det3d version, config file content and class names in
+ # checkpoints as meta data
+ cfg.checkpoint_config.meta = dict(
+ config=cfg.text, CLASSES=datasets[0].CLASSES
+ )
+
+ # add an attribute for visualization convenience
+ model.CLASSES = datasets[0].CLASSES
+ train_detector(
+ model,
+ datasets,
+ cfg,
+ distributed=distributed,
+ validate=args.validate,
+ logger=logger,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/visual.py b/tools/visual.py
new file mode 100644
index 0000000..7c88dc1
--- /dev/null
+++ b/tools/visual.py
@@ -0,0 +1,71 @@
+from det3d.core.bbox.box_np_ops import center_to_corner_box3d
+import open3d as o3d
+import argparse
+import pickle
+
+def label2color(label):
+ colors = [[204/255, 0, 0], [52/255, 101/255, 164/255],
+ [245/255, 121/255, 0], [115/255, 210/255, 22/255]]
+
+ return colors[label]
+
+def corners_to_lines(qs, color=[204/255, 0, 0]):
+ """ Draw 3d bounding box in image
+ qs: (8,3) array of vertices for the 3d box in following order:
+ 7 -------- 4
+ /| /|
+ 6 -------- 5 .
+ | | | |
+ . 3 -------- 0
+ |/ |/
+ 2 -------- 1
+ """
+ idx = [(1,0), (5,4), (2,3), (6,7), (1,2), (5,6), (0,3), (4,7), (1,5), (0,4), (2,6), (3,7)]
+ cl = [color for i in range(12)]
+
+ line_set = o3d.geometry.LineSet(
+ points=o3d.utility.Vector3dVector(qs),
+ lines=o3d.utility.Vector2iVector(idx),
+ )
+ line_set.colors = o3d.utility.Vector3dVector(cl)
+
+ return line_set
+
+def plot_boxes(boxes, score_thresh):
+ visuals =[]
+ num_det = boxes['scores'].shape[0]
+ for i in range(num_det):
+ score = boxes['scores'][i]
+ if score < score_thresh:
+ continue
+
+ box = boxes['boxes'][i:i+1]
+ label = boxes['classes'][i]
+ corner = center_to_corner_box3d(box[:, :3], box[:, 3:6], box[:, -1])[0].tolist()
+ color = label2color(label)
+ visuals.append(corners_to_lines(corner, color))
+ return visuals
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="CenterPoint")
+ parser.add_argument('--path', help='path to visualization file', type=str)
+ parser.add_argument('--thresh', help='visualization threshold', type=float, default=0.3)
+ args = parser.parse_args()
+
+ with open(args.path, 'rb') as f:
+ data_dicts = pickle.load(f)
+
+ for data in data_dicts:
+ points = data['points']
+ detections = data['detections']
+
+ pcd = o3d.geometry.PointCloud()
+ pcd = o3d.geometry.PointCloud()
+ pcd.points = o3d.utility.Vector3dVector(points[:, :3])
+
+ visual = [pcd]
+ num_dets = detections['scores'].shape[0]
+ visual += plot_boxes(detections, args.thresh)
+
+ o3d.visualization.draw_geometries(visual)
diff --git a/tools/waymo_tracking/__init__.py b/tools/waymo_tracking/__init__.py
new file mode 100644
index 0000000..114554f
--- /dev/null
+++ b/tools/waymo_tracking/__init__.py
@@ -0,0 +1,3 @@
+from .tracker import PubTracker
+
+__all__ = ["PubTracker"]
\ No newline at end of file
diff --git a/tools/waymo_tracking/line_search.py b/tools/waymo_tracking/line_search.py
new file mode 100644
index 0000000..7396607
--- /dev/null
+++ b/tools/waymo_tracking/line_search.py
@@ -0,0 +1,35 @@
+import os
+import numpy as np
+
+scores = {
+ 0: np.arange(0.4, 0.8, 0.02),
+ 1: np.arange(0.4, 0.8, 0.02),
+ 2: np.arange(0.4, 0.8, 0.02)
+}
+
+dists = {
+ 0: np.arange(0.4, 0.8, 0.04),
+ 1: np.arange(0.1, 0.5, 0.04),
+ 2: np.arange(0.3, 0.7, 0.04)
+}
+
+for label in range(3):
+ score_list = scores[label]
+ dist_list = dists[label]
+
+ for score in score_list:
+ for dist in dist_list:
+ work_dir = "waymo_track/label_{}_score_{}_max_age_{}_dist_{}".format(label, score, dist)
+
+ cmd=("python tools/waymo_tracking/test.py " +
+ "--checkpoint /home/tianweiy/base/work_dirs/waymo_centerpoint_voxelnet_two_sweeps_3x_with_velo/prediction.pkl"
+ "--work_dir {}".format(work_dir) +
+ " --info_path data/Waymo/infos_val_02sweeps_filter_zero_gt.pkl" +
+ "--vehicle {} --pedestrian {} --cyclist".format(dist, dist, dist) +
+ "--score_thresh {}".format(score) +
+ "--name {}".format(label) +
+ "> {}/stats.txt ".format(work_dir)
+ )[0]
+
+ print(cmd)
+ os.system(cmd)
diff --git a/tools/waymo_tracking/test.py b/tools/waymo_tracking/test.py
new file mode 100644
index 0000000..db14135
--- /dev/null
+++ b/tools/waymo_tracking/test.py
@@ -0,0 +1,263 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+import json
+import numpy as np
+import time
+import copy
+import argparse
+import copy
+import json
+import os
+import numpy as np
+from tools.waymo_tracking.tracker import PubTracker as Tracker
+from tqdm import tqdm
+import json
+import time
+from nuscenes.utils.geometry_utils import transform_matrix
+import pickle
+from pyquaternion import Quaternion
+from det3d.datasets.waymo.waymo_common import _create_pd_detection
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="Tracking Evaluation")
+ parser.add_argument("--work_dir", help="the dir to save logs and tracking results")
+ parser.add_argument(
+ "--checkpoint", help="the dir to prediction file"
+ )
+ parser.add_argument(
+ "--info_path", type=str
+ )
+ parser.add_argument("--max_age", type=int, default=3)
+ parser.add_argument("--vehicle", type=float, default=0.8)
+ parser.add_argument("--pedestrian", type=float, default=0.4)
+ parser.add_argument("--cyclist", type=float, default=0.6)
+ parser.add_argument("--score_thresh", type=float, default=0.75)
+
+ args = parser.parse_args()
+
+ return args
+
+def get_obj(path):
+ with open(path, 'rb') as f:
+ obj = pickle.load(f)
+ return obj
+
+def veh_pos_to_transform(veh_pos):
+ "convert vehicle pose to two transformation matrix"
+ rotation = veh_pos[:3, :3]
+ tran = veh_pos[:3, 3]
+
+ global_from_car = transform_matrix(
+ tran, Quaternion(matrix=rotation), inverse=False
+ )
+
+ car_from_global = transform_matrix(
+ tran, Quaternion(matrix=rotation), inverse=True
+ )
+
+ return global_from_car, car_from_global
+
+def reorganize_info(infos):
+ new_info = {}
+
+ for info in infos:
+ token = info['token']
+ new_info[token] = info
+
+ return new_info
+
+def main():
+ args = parse_args()
+ print('Deploy OK')
+
+ max_dist = {
+ 'VEHICLE': args.vehicle,
+ 'PEDESTRIAN': args.pedestrian,
+ 'CYCLIST': args.cyclist
+ }
+
+ tracker = Tracker(max_age=args.max_age, max_dist=max_dist, score_thresh=args.score_thresh)
+
+ with open(args.checkpoint, 'rb') as f:
+ predictions=pickle.load(f)
+
+ with open(args.info_path, 'rb') as f:
+ infos=pickle.load(f)
+ infos = reorganize_info(infos)
+
+ global_preds, detection_results = convert_detection_to_global_box(predictions, infos)
+ size = len(global_preds)
+
+ print("Begin Tracking {} frames\n".format(size))
+
+ predictions = {}
+
+ for i in tqdm(range(size)):
+ pred = global_preds[i]
+ token = pred['token']
+
+ # reset tracking after one video sequence
+ if pred['frame_id'] == 0:
+ tracker.reset()
+ last_time_stamp = pred['timestamp']
+
+ time_lag = (pred['timestamp'] - last_time_stamp)
+ last_time_stamp = pred['timestamp']
+
+ current_det = pred['global_boxs']
+
+ outputs = tracker.step_centertrack(current_det, time_lag)
+ tracking_ids = []
+ box_ids = []
+
+ for item in outputs:
+ if item['active'] == 0:
+ continue
+
+ box_ids.append(item['box_id'])
+ tracking_ids.append(item['tracking_id'])
+
+ # now reorder
+ detection = detection_results[token]
+
+ remained_box_ids = np.array(box_ids)
+
+ track_result = {}
+
+ # store box id
+ track_result['tracking_ids']= np.array(tracking_ids)
+
+ # store box parameter
+ track_result['box3d_lidar'] = detection['box3d_lidar'][remained_box_ids]
+
+ # store box label
+ track_result['label_preds'] = detection['label_preds'][remained_box_ids]
+
+ # store box score
+ track_result['scores'] = detection['scores'][remained_box_ids]
+
+ predictions[token] = track_result
+
+ os.makedirs(args.work_dir, exist_ok=True)
+ # save prediction files to args.work_dir
+ _create_pd_detection(predictions, infos, args.work_dir, tracking=True)
+
+ result_path = os.path.join(args.work_dir, 'tracking_pred.bin')
+ gt_path = os.path.join(args.work_dir, '../gt_preds.bin')
+
+ print("Use Waymo devkit or online server to evaluate the result")
+ print("After building the devkit, you can use the following command")
+ print("waymo-open-dataset/bazel-bin/waymo_open_dataset/metrics/tools/compute_tracking_metrics_main \
+ {} {} ".format(result_path, gt_path))
+
+ # os.system("waymo_open_dataset/metrics/tools/compute_tracking_metrics_main \
+ # {} {} ".format(result_path, gt_path))
+
+def transform_box(box, pose):
+ """Transforms 3d upright boxes from one frame to another.
+ Args:
+ box: [..., N, 7] boxes.
+ from_frame_pose: [...,4, 4] origin frame poses.
+ to_frame_pose: [...,4, 4] target frame poses.
+ Returns:
+ Transformed boxes of shape [..., N, 7] with the same type as box.
+ """
+ transform = pose
+ heading = box[..., -1] + np.arctan2(transform[..., 1, 0], transform[..., 0,
+ 0])
+ center = np.einsum('...ij,...nj->...ni', transform[..., 0:3, 0:3],
+ box[..., 0:3]) + np.expand_dims(
+ transform[..., 0:3, 3], axis=-2)
+
+ velocity = box[..., [6, 7]]
+
+ velocity = np.concatenate([velocity, np.zeros((velocity.shape[0], 1))], axis=-1) # add z velocity
+
+ velocity = np.einsum('...ij,...nj->...ni', transform[..., 0:3, 0:3],
+ velocity)[..., [0, 1]] # remove z axis
+
+ return np.concatenate([center, box[..., 3:6], velocity, heading[..., np.newaxis]], axis=-1)
+
+def label_to_name(label):
+ if label == 0:
+ return "VEHICLE"
+ elif label == 1 :
+ return "PEDESTRIAN"
+ elif label == 2:
+ return "CYCLIST"
+ else:
+ raise NotImplemented()
+
+def sort_detections(detections):
+ indices = []
+
+ for det in detections:
+ f = det['token']
+ seq_id = int(f.split("_")[1])
+ frame_id= int(f.split("_")[3][:-4])
+
+ idx = seq_id * 1000 + frame_id
+ indices.append(idx)
+
+ rank = list(np.argsort(np.array(indices)))
+
+ detections = [detections[r] for r in rank]
+
+ return detections
+
+def convert_detection_to_global_box(detections, infos):
+ ret_list = []
+
+ detection_results = {} # copy.deepcopy(detections)
+
+ for token in tqdm(infos.keys()):
+ detection = detections[token]
+ detection_results[token] = copy.deepcopy(detection)
+
+ info = infos[token]
+ # pose = get_transform(info)
+ anno_path = info['anno_path']
+ ref_obj = get_obj(anno_path)
+ pose = np.reshape(ref_obj['veh_to_global'], [4, 4])
+
+ box3d = detection["box3d_lidar"].detach().clone().cpu().numpy()
+ labels = detection["label_preds"].detach().clone().cpu().numpy()
+ scores = detection['scores'].detach().clone().cpu().numpy()
+ box3d[:, -1] = -box3d[:, -1] - np.pi / 2
+ box3d[:, [3, 4]] = box3d[:, [4, 3]]
+
+ box3d = transform_box(box3d, pose)
+
+ frame_id = token.split('_')[3][:-4]
+
+ num_box = len(box3d)
+
+ anno_list =[]
+ for i in range(num_box):
+ anno = {
+ 'translation': box3d[i, :3],
+ 'velocity': box3d[i, [6, 7]],
+ 'detection_name': label_to_name(labels[i]),
+ 'score': scores[i],
+ 'box_id': i
+ }
+
+ anno_list.append(anno)
+
+ ret_list.append({
+ 'token': token,
+ 'frame_id':int(frame_id),
+ 'global_boxs': anno_list,
+ 'timestamp': info['timestamp']
+ })
+
+ sorted_ret_list = sort_detections(ret_list)
+
+ return sorted_ret_list, detection_results
+
+if __name__ == '__main__':
+ main()
diff --git a/tools/waymo_tracking/tracker.py b/tools/waymo_tracking/tracker.py
new file mode 100644
index 0000000..b94f583
--- /dev/null
+++ b/tools/waymo_tracking/tracker.py
@@ -0,0 +1,136 @@
+import numpy as np
+import copy
+import copy
+import importlib
+import sys
+
+import numpy as np
+
+def greedy_assignment(dist):
+ matched_indices = []
+ if dist.shape[1] == 0:
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
+ for i in range(dist.shape[0]):
+ j = dist[i].argmin()
+ if dist[i][j] < 1e16:
+ dist[:, j] = 1e18
+ matched_indices.append([i, j])
+ return np.array(matched_indices, np.int32).reshape(-1, 2)
+
+
+WAYMO_TRACKING_NAMES = [
+ 'VEHICLE',
+ 'PEDESTRIAN',
+ 'CYCLIST'
+]
+
+class PubTracker(object):
+ def __init__(self, max_age=0, max_dist={}, score_thresh=0.1):
+ self.max_age = max_age
+
+ self.WAYMO_CLS_VELOCITY_ERROR = max_dist
+
+ self.WAYMO_TRACKING_NAMES = WAYMO_TRACKING_NAMES
+ self.score_thresh = score_thresh
+
+ self.reset()
+
+ def reset(self):
+ self.id_count = 0
+ self.tracks = []
+
+ def step_centertrack(self, results, time_lag):
+ if len(results) == 0:
+ self.tracks = []
+ return []
+ else:
+ temp = []
+ for det in results:
+ # filter out classes not evaluated for tracking
+ if det['detection_name'] not in self.WAYMO_TRACKING_NAMES:
+ print("filter {}".format(det['detection_name']))
+ continue
+
+ det['ct'] = np.array(det['translation'][:2])
+ det['tracking'] = np.array(det['velocity'][:2]) * -1 * time_lag
+ det['label_preds'] = self.WAYMO_TRACKING_NAMES.index(det['detection_name'])
+ temp.append(det)
+
+ results = temp
+
+ N = len(results)
+ M = len(self.tracks)
+
+ # N X 2
+ if 'tracking' in results[0]:
+ dets = np.array(
+ [ det['ct'] + det['tracking'].astype(np.float32)
+ for det in results], np.float32)
+ else:
+ dets = np.array(
+ [det['ct'] for det in results], np.float32)
+
+ item_cat = np.array([item['label_preds'] for item in results], np.int32) # N
+ track_cat = np.array([track['label_preds'] for track in self.tracks], np.int32) # M
+
+ max_diff = np.array([self.WAYMO_CLS_VELOCITY_ERROR[box['detection_name']] for box in results], np.float32)
+
+ tracks = np.array(
+ [pre_det['ct'] for pre_det in self.tracks], np.float32) # M x 2
+
+ if len(tracks) > 0: # NOT FIRST FRAME
+ dist = (((tracks.reshape(1, -1, 2) - \
+ dets.reshape(-1, 1, 2)) ** 2).sum(axis=2)) # N x M
+ dist = np.sqrt(dist) # absolute distance in meter
+
+ invalid = ((dist > max_diff.reshape(N, 1)) + \
+ (item_cat.reshape(N, 1) != track_cat.reshape(1, M))) > 0
+
+ dist = dist + invalid * 1e18
+ matched_indices = greedy_assignment(copy.deepcopy(dist))
+ else: # first few frame
+ assert M == 0
+ matched_indices = np.array([], np.int32).reshape(-1, 2)
+
+ unmatched_dets = [d for d in range(dets.shape[0]) \
+ if not (d in matched_indices[:, 0])]
+
+ unmatched_tracks = [d for d in range(tracks.shape[0]) \
+ if not (d in matched_indices[:, 1])]
+
+ matches = matched_indices
+
+ ret = []
+ for m in matches:
+ track = results[m[0]]
+ track['tracking_id'] = self.tracks[m[1]]['tracking_id']
+ track['age'] = 1
+ track['active'] = self.tracks[m[1]]['active'] + 1
+ ret.append(track)
+
+ for i in unmatched_dets:
+ track = results[i]
+ if track['score'] > self.score_thresh:
+ self.id_count += 1
+ track['tracking_id'] = self.id_count
+ track['age'] = 1
+ track['active'] = 1
+ ret.append(track)
+
+ # still store unmatched tracks if its age doesn't exceed max_age, however, we shouldn't output
+ # the object in current frame
+ for i in unmatched_tracks:
+ track = self.tracks[i]
+ if track['age'] < self.max_age:
+ track['age'] += 1
+ track['active'] = 0
+ ct = track['ct']
+
+ # movement in the last second
+ if 'tracking' in track:
+ offset = track['tracking'] * -1 # move forward
+ track['ct'] = ct + offset
+ ret.append(track)
+
+ self.tracks = ret
+ return ret