Skip to content

Commit

Permalink
[Feature]Support NeRF-Det (#2732)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yanyirong authored Jan 4, 2024
1 parent 8deeb6e commit 2dad86c
Show file tree
Hide file tree
Showing 17 changed files with 4,449 additions and 0 deletions.
115 changes: 115 additions & 0 deletions projects/NeRF-Det/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection

> [NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection](https://arxiv.org/abs/2307.14620)
<!-- [ALGORITHM] -->

## Abstract

NeRF-Det is a novel method for indoor 3D detection with posed RGB images as input. Unlike existing indoor 3D detection methods that struggle to model scene geometry, NeRF-Det makes novel use of NeRF in an end-to-end manner to explicitly estimate 3D geometry, thereby improving 3D detection performance. Specifically, to avoid the significant extra latency associated with per-scene optimization of NeRF, NeRF-Det introduce sufficient geometry priors to enhance the generalizability of NeRF-MLP. Furthermore, it subtly connect the detection and NeRF branches through a shared MLP, enabling an efficient adaptation of NeRF to detection and yielding geometry-aware volumetric representations for 3D detection. NeRF-Det outperforms state-of-the-arts by 3.9 mAP and 3.1 mAP on the ScanNet and ARKITScenes benchmarks, respectively. The author provide extensive analysis to shed light on how NeRF-Det works. As a result of joint-training design, NeRF-Det is able to generalize well to unseen scenes for object detection, view synthesis, and depth estimation tasks without requiring per-scene optimization. Code will be available at https://github.com/facebookresearch/NeRF-Det

<div align=center>
<img src="https://chenfengxu714.github.io/nerfdet/static/images/method-cropped_1.png" width="800"/>
</div>

## Introduction

This directory contains the implementations of NeRF-Det (https://arxiv.org/abs/2307.14620). Our implementations are built on top of MMdetection3D. We have updated NeRF-Det to be compatible with latest mmdet3d version. The codebase and config files have all changed to adapt to the new mmdet3d version. All previous pretrained models are verified with the result listed below. However, newly trained models are yet to be uploaded.

<!-- Share any information you would like others to know. For example:
Author: @xxx.
This is an implementation of \[XXX\]. -->

## Dataset

The format of the scannet dataset in the latest version of mmdet3d only supports the lidar tasks. For NeRF-Det, we need to create the new format of ScanNet Dataset.

Please following the files in mmdet3d to prepare the raw data of ScanNet. After that, please use this command to generate the pkls used in nerfdet.

```bash
python projects/NeRF-Det/prepare_infos.py --root-path ./data/scannet --out-dir ./data/scannet
```

The new format of the pkl is organized as below:

- scannet_infos_train.pkl: The train data infos, the detailed info of each scan is as follows:
- info\['instances'\]:A list of dict contains all annotations, each dict contains all annotation information of single instance.For the i-th instance:
- info\['instances'\]\[i\]\['bbox_3d'\]: List of 6 numbers representing the axis_aligned in depth coordinate system, in (x,y,z,l,w,h) order.
- info\['instances'\]\[i\]\['bbox_label_3d'\]: The label of each 3d bounding boxes.
- info\['cam2img'\]: The intrinsic matrix.Every scene has one matrix.
- info\['lidar2cam'\]: The extrinsic matrixes.Every scene has 300 matrixes.
- info\['img_paths'\]: The paths of the 300 rgb pictures.
- info\['axis_align_matrix'\]: The align matrix.Every scene has one matrix.

After preparing your scannet dataset pkls,please change the paths in configs to fit your project.

## Train

In MMDet3D's root directory, run the following command to train the model:

```bash
python tools/train.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${WORK_DIR}
```

## Results and Models

### NeRF-Det

| Backbone | mAP@25 | mAP@50 | Log |
| :-------------------------------------------------------------: | :----: | :----: | :-------: |
| [NeRF-Det-R50](./configs/nerfdet_res50_2x_low_res.py) | 53.0 | 26.8 | [log](<>) |
| [NeRF-Det-R50\*](./configs/nerfdet_res50_2x_low_res_depth.py) | 52.2 | 28.5 | [log](<>) |
| [NeRF-Det-R101\*](./configs/nerfdet_res101_2x_low_res_depth.py) | 52.3 | 28.5 | [log](<>) |

(Here NeRF-Det-R50\* means this model uses depth information in the training step)

### Notes

- The values showed in the chart all represents the best mAP in the training.

- Since there is a lot of randomness in the behavior of the model, we conducted three experiments on each config and took the average. The mAP showed on the above chart are all average values.

- We also conducted the same experiments in the original code, the results are showed below.

| Backbone | mAP@25 | mAP@50 |
| :-------------: | :----: | :----: |
| NeRF-Det-R50 | 52.8 | 26.8 |
| NeRF-Det-R50\* | 52.4 | 27.5 |
| NeRF-Det-R101\* | 52.8 | 28.6 |

- Attention: Because of the randomness in the construction of the ScanNet dataset itself and the behavior of the model, the training results will fluctuate considerably. According to experimental results and experience, the experimental results will fluctuate by plus or minus 1.5 points.

## Evaluation using pretrained models

1. Download the pretrained checkpoints through the linkings in the above chart.

2. Testing

To test, use:

```bash
python tools/test.py projects/NeRF-Det/configs/nerfdet_res50_2x_low_res.py ${CHECKPOINT_PATH}
```

## Citation

<!-- You may remove this section if not applicable. -->

```latex
@inproceedings{
xu2023nerfdet,
title={NeRF-Det: Learning Geometry-Aware Volumetric Representation for Multi-View 3D Object Detection},
author={Xu, Chenfeng and Wu, Bichen and Hou, Ji and Tsai, Sam and Li, Ruilong and Wang, Jialiang and Zhan, Wei and He, Zijian and Vajda, Peter and Keutzer, Kurt and Tomizuka, Masayoshi},
booktitle={ICCV},
year={2023},
}
@inproceedings{
park2023time,
title={Time Will Tell: New Outlooks and A Baseline for Temporal Multi-View 3D Object Detection},
author={Jinhyung Park and Chenfeng Xu and Shijia Yang and Kurt Keutzer and Kris M. Kitani and Masayoshi Tomizuka and Wei Zhan},
booktitle={The Eleventh International Conference on Learning Representations },
year={2023},
url={https://openreview.net/forum?id=H3HcEJA2Um}
}
```
198 changes: 198 additions & 0 deletions projects/NeRF-Det/configs/nerfdet_res101_2x_low_res_depth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
_base_ = ['../../../configs/_base_/default_runtime.py']

custom_imports = dict(imports=['projects.NeRF-Det.nerfdet'])
prior_generator = dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-3.2, -3.2, -1.28, 3.2, 3.2, 1.28]],
rotations=[.0])

model = dict(
type='NerfDet',
data_preprocessor=dict(
type='NeRFDetDataPreprocessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_size_divisor=10),
backbone=dict(
type='mmdet.ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=False),
norm_eval=True,
init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet101'),
style='pytorch'),
neck=dict(
type='mmdet.FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
num_outs=4),
neck_3d=dict(
type='IndoorImVoxelNeck',
in_channels=256,
out_channels=128,
n_blocks=[1, 1, 1]),
bbox_head=dict(
type='NerfDetHead',
bbox_loss=dict(type='AxisAlignedIoULoss', loss_weight=1.0),
n_classes=18,
n_levels=3,
n_channels=128,
n_reg_outs=6,
pts_assign_threshold=27,
pts_center_threshold=18,
prior_generator=prior_generator),
prior_generator=prior_generator,
voxel_size=[.16, .16, .2],
n_voxels=[40, 40, 16],
aabb=([-2.7, -2.7, -0.78], [3.7, 3.7, 1.78]),
near_far_range=[0.2, 8.0],
N_samples=64,
N_rand=2048,
nerf_mode='image',
depth_supervise=True,
use_nerf_mask=True,
nerf_sample_view=20,
squeeze_scale=4,
nerf_density=True,
train_cfg=dict(),
test_cfg=dict(nms_pre=1000, iou_thr=.25, score_thr=.01))

dataset_type = 'MultiViewScanNetDataset'
data_root = 'data/scannet/'
class_names = [
'cabinet', 'bed', 'chair', 'sofa', 'table', 'door', 'window', 'bookshelf',
'picture', 'counter', 'desk', 'curtain', 'refrigerator', 'showercurtrain',
'toilet', 'sink', 'bathtub', 'garbagebin'
]
metainfo = dict(CLASSES=class_names)
file_client_args = dict(backend='disk')

input_modality = dict(
use_camera=True,
use_depth=True,
use_lidar=False,
use_neuralrecon_depth=False,
use_ray=True)
backend_args = None

train_collect_keys = [
'img', 'gt_bboxes_3d', 'gt_labels_3d', 'depth', 'lightpos', 'nerf_sizes',
'raydirs', 'gt_images', 'gt_depths', 'denorm_images'
]

test_collect_keys = [
'img',
'depth',
'lightpos',
'nerf_sizes',
'raydirs',
'gt_images',
'gt_depths',
'denorm_images',
]

train_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=48,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=10),
dict(type='RandomShiftOrigin', std=(.7, .7, .0)),
dict(type='PackNeRFDetInputs', keys=train_collect_keys)
]

test_pipeline = [
dict(type='LoadAnnotations3D'),
dict(
type='MultiViewPipeline',
n_images=101,
transforms=[
dict(type='LoadImageFromFile', file_client_args=file_client_args),
dict(type='Resize', scale=(320, 240), keep_ratio=True),
],
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
margin=10,
depth_range=[0.5, 5.5],
loading='random',
nerf_target_views=1),
dict(type='PackNeRFDetInputs', keys=test_collect_keys)
]

train_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=6,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_train_new.pkl',
pipeline=train_pipeline,
modality=input_modality,
test_mode=False,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo)))
val_dataloader = dict(
batch_size=1,
num_workers=5,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='scannet_infos_val_new.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
filter_empty_gt=True,
box_type_3d='Depth',
metainfo=metainfo))
test_dataloader = val_dataloader

val_evaluator = dict(type='IndoorMetric')
test_evaluator = val_evaluator

train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1)
test_cfg = dict()
val_cfg = dict()

optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='AdamW', lr=0.0002, weight_decay=0.0001),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}),
clip_grad=dict(max_norm=35., norm_type=2))
param_scheduler = [
dict(
type='MultiStepLR',
begin=0,
end=12,
by_epoch=True,
milestones=[8, 11],
gamma=0.1)
]

# hooks
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', interval=1, max_keep_ckpts=12))

# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used
Loading

0 comments on commit 2dad86c

Please sign in to comment.