forked from open-mmlab/mmdetection3d
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Add TR3D detector to projects (open-mmlab#2274)
* first tr3d commit * all tr3d files added * all tr3d is ok * fix comments * fix config imports and readme * fix comments * update links in readme * fix lint
- Loading branch information
Showing
12 changed files
with
1,309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
# TR3D: Towards Real-Time Indoor 3D Object Detection | ||
|
||
> [TR3D: Towards Real-Time Indoor 3D Object Detection](https://arxiv.org/abs/2302.02858) | ||
## Abstract | ||
|
||
Recently, sparse 3D convolutions have changed 3D object detection. Performing on par with the voting-based approaches, 3D CNNs are memory-efficient and scale to large scenes better. However, there is still room for improvement. With a conscious, practice-oriented approach to problem-solving, we analyze the performance of such methods and localize the weaknesses. Applying modifications that resolve the found issues one by one, we end up with TR3D: a fast fully-convolutional 3D object detection model trained end-to-end, that achieves state-of-the-art results on the standard benchmarks, ScanNet v2, SUN RGB-D, and S3DIS. Moreover, to take advantage of both point cloud and RGB inputs, we introduce an early fusion of 2D and 3D features. We employ our fusion module to make conventional 3D object detection methods multimodal and demonstrate an impressive boost in performance. Our model with early feature fusion, which we refer to as TR3D+FF, outperforms existing 3D object detection approaches on the SUN RGB-D dataset. Overall, besides being accurate, both TR3D and TR3D+FF models are lightweight, memory-efficient, and fast, thereby marking another milestone on the way toward real-time 3D object detection. | ||
|
||
<div align="center"> | ||
<img src="https://user-images.githubusercontent.com/6030962/219644780-646516ec-a6c1-4ec5-9b8c-63bbc9702d05.png" width="800"/> | ||
</div> | ||
|
||
## Usage | ||
|
||
Training and inference in this project were tested with `mmdet3d==1.1.0rc3`. | ||
|
||
### Training commands | ||
|
||
In MMDet3D's root directory, run the following command to train the model: | ||
|
||
```bash | ||
python tools/train.py projects/TR3D/configs/tr3d_1xb16_scannet-3d-18class.py | ||
``` | ||
|
||
### Testing commands | ||
|
||
In MMDet3D's root directory, run the following command to test the model: | ||
|
||
```bash | ||
python tools/test.py projects/TR3D/configs/tr3d_1xb16_scannet-3d-18class.py ${CHECKPOINT_PATH} | ||
``` | ||
|
||
## Results and models | ||
|
||
### ScanNet | ||
|
||
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download | | ||
| :--------------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | ||
| [MinkResNet34](./configs/tr3d_1xb16_scannet-3d-18class.py) | 8.6 | 23.7 | 72.9 (72.0) | 59.3 (57.4) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_scannet-3d-18class/tr3d_1xb16_scannet-3d-18class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_scannet-3d-18class/tr3d_1xb16_scannet-3d-18class.log.json) | | ||
|
||
### SUN RGB-D | ||
|
||
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download | | ||
| :--------------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | ||
| [MinkResNet34](./configs/tr3d_1xb16_sunrgbd-3d-10class.py) | 3.8 | 27.5 | 67.1 (66.3) | 50.4 (49.6) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_sunrgbd-3d-10class/tr3d_1xb16_sunrgbd-3d-10class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_sunrgbd-3d-10class/tr3d_1xb16_sunrgbd-3d-10class.log.json) | | ||
|
||
### S3DIS | ||
|
||
| Backbone | Mem (GB) | Inf time (fps) | AP@0.25 | AP@0.5 | Download | | ||
| :-----------------------------------------------------: | :------: | :------------: | :---------: | :---------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | | ||
| [MinkResNet34](./configs/tr3d_1xb16_s3dis-3d-5class.py) | 15.2 | 21.0 | 74.5 (72.1) | 51.7 (47.6) | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_s3dis-3d-5class/tr3d_1xb16_s3dis-3d-5class.pth) \| [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/tr3d/tr3d_1xb16_s3dis-3d-5class/tr3d_1xb16_s3dis-3d-5class.log.json) | | ||
|
||
**Note** | ||
|
||
- We report the results across 5 train runs followed by 5 test runs. Median values are in round brackets. | ||
- Inference time is given for a single NVidia GeForce RTX 4090 GPU. | ||
|
||
## Citation | ||
|
||
```latex | ||
@article{rukhovich2023tr3d, | ||
title={TR3D: Towards Real-Time Indoor 3D Object Detection}, | ||
author={Rukhovich, Danila and Vorontsova, Anna and Konushin, Anton}, | ||
journal={arXiv preprint arXiv:2302.02858}, | ||
year={2023} | ||
} | ||
``` | ||
|
||
## Checklist | ||
|
||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`. | ||
|
||
- [x] Finish the code | ||
|
||
- [x] Basic docstrings & proper citation | ||
|
||
- [x] Test-time correctness | ||
|
||
- [x] A full README | ||
|
||
- [x] Milestone 2: Indicates a successful model implementation. | ||
|
||
- [x] Training-time correctness | ||
|
||
- [ ] Milestone 3: Good to be a part of our core package! | ||
|
||
- [x] Type hints and docstrings | ||
|
||
- [ ] Unit tests | ||
|
||
- [ ] Code polishing | ||
|
||
- [ ] Metafile.yml | ||
|
||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure. | ||
|
||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
_base_ = ['mmdet3d::_base_/default_runtime.py'] | ||
custom_imports = dict(imports=['projects.TR3D.tr3d']) | ||
|
||
model = dict( | ||
type='MinkSingleStage3DDetector', | ||
data_preprocessor=dict(type='Det3DDataPreprocessor'), | ||
backbone=dict( | ||
type='TR3DMinkResNet', | ||
in_channels=3, | ||
depth=34, | ||
norm='batch', | ||
num_planes=(64, 128, 128, 128)), | ||
neck=dict( | ||
type='TR3DNeck', in_channels=(64, 128, 128, 128), out_channels=128), | ||
bbox_head=dict( | ||
type='TR3DHead', | ||
in_channels=128, | ||
voxel_size=0.01, | ||
pts_center_threshold=6, | ||
num_reg_outs=6), | ||
train_cfg=dict(), | ||
test_cfg=dict(nms_pre=1000, iou_thr=0.5, score_thr=0.01)) | ||
|
||
optim_wrapper = dict( | ||
type='OptimWrapper', | ||
optimizer=dict(type='AdamW', lr=0.001, weight_decay=0.0001), | ||
clip_grad=dict(max_norm=10, norm_type=2)) | ||
|
||
# learning rate | ||
param_scheduler = dict( | ||
type='MultiStepLR', | ||
begin=0, | ||
end=12, | ||
by_epoch=True, | ||
milestones=[8, 11], | ||
gamma=0.1) | ||
|
||
custom_hooks = [dict(type='EmptyCacheHook', after_iter=True)] | ||
|
||
# training schedule for 1x | ||
train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=12, val_interval=1) | ||
val_cfg = dict(type='ValLoop') | ||
test_cfg = dict(type='TestLoop') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/s3dis-3d.py'] | ||
custom_imports = dict(imports=['projects.TR3D.tr3d']) | ||
|
||
dataset_type = 'S3DISDataset' | ||
data_root = 'data/s3dis/' | ||
metainfo = dict(classes=('table', 'chair', 'sofa', 'bookcase', 'board')) | ||
train_area = [1, 2, 3, 4, 6] | ||
|
||
model = dict(bbox_head=dict(label2level=[1, 0, 1, 1, 0])) | ||
|
||
train_pipeline = [ | ||
dict( | ||
type='LoadPointsFromFile', | ||
coord_type='DEPTH', | ||
shift_height=False, | ||
use_color=True, | ||
load_dim=6, | ||
use_dim=[0, 1, 2, 3, 4, 5]), | ||
dict(type='LoadAnnotations3D'), | ||
dict(type='PointSample', num_points=100000), | ||
dict( | ||
type='RandomFlip3D', | ||
sync_2d=False, | ||
flip_ratio_bev_horizontal=0.5, | ||
flip_ratio_bev_vertical=0.5), | ||
dict( | ||
type='GlobalRotScaleTrans', | ||
rot_range=[0, 0], | ||
scale_ratio_range=[0.95, 1.05], | ||
translation_std=[0.1, 0.1, 0.1], | ||
shift_height=False), | ||
dict(type='NormalizePointsColor', color_mean=None), | ||
dict( | ||
type='Pack3DDetInputs', | ||
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=16, | ||
num_workers=8, | ||
dataset=dict( | ||
dataset=dict(datasets=[ | ||
dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file=f's3dis_infos_Area_{i}.pkl', | ||
pipeline=train_pipeline, | ||
filter_empty_gt=False, | ||
metainfo=metainfo, | ||
box_type_3d='Depth') for i in train_area | ||
]))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/scannet-3d.py'] | ||
custom_imports = dict(imports=['projects.TR3D.tr3d']) | ||
|
||
model = dict( | ||
bbox_head=dict( | ||
label2level=[0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0])) | ||
|
||
train_pipeline = [ | ||
dict( | ||
type='LoadPointsFromFile', | ||
coord_type='DEPTH', | ||
shift_height=False, | ||
use_color=True, | ||
load_dim=6, | ||
use_dim=[0, 1, 2, 3, 4, 5]), | ||
dict(type='LoadAnnotations3D'), | ||
dict(type='GlobalAlignment', rotation_axis=2), | ||
# We do not sample 100k points for ScanNet, as very few scenes have | ||
# significantly more then 100k points. So we sample 33 to 100% of them. | ||
dict(type='TR3DPointSample', num_points=0.33), | ||
dict( | ||
type='RandomFlip3D', | ||
sync_2d=False, | ||
flip_ratio_bev_horizontal=0.5, | ||
flip_ratio_bev_vertical=0.5), | ||
dict( | ||
type='GlobalRotScaleTrans', | ||
rot_range=[-0.02, 0.02], | ||
scale_ratio_range=[0.9, 1.1], | ||
translation_std=[0.1, 0.1, 0.1], | ||
shift_height=False), | ||
dict(type='NormalizePointsColor', color_mean=None), | ||
dict( | ||
type='Pack3DDetInputs', | ||
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
] | ||
test_pipeline = [ | ||
dict( | ||
type='LoadPointsFromFile', | ||
coord_type='DEPTH', | ||
shift_height=False, | ||
use_color=True, | ||
load_dim=6, | ||
use_dim=[0, 1, 2, 3, 4, 5]), | ||
dict(type='GlobalAlignment', rotation_axis=2), | ||
dict( | ||
type='MultiScaleFlipAug3D', | ||
img_scale=(1333, 800), | ||
pts_scale_ratio=1, | ||
flip=False, | ||
transforms=[ | ||
# We do not sample 100k points for ScanNet, as very few scenes have | ||
# significantly more then 100k points. So it doesn't affect | ||
# inference time and we can accept all points. | ||
# dict(type='PointSample', num_points=100000), | ||
dict(type='NormalizePointsColor', color_mean=None), | ||
]), | ||
dict(type='Pack3DDetInputs', keys=['points']) | ||
] | ||
train_dataloader = dict( | ||
batch_size=16, | ||
num_workers=8, | ||
dataset=dict( | ||
type='RepeatDataset', | ||
times=15, | ||
dataset=dict(pipeline=train_pipeline, filter_empty_gt=False))) | ||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
_base_ = ['./tr3d.py', 'mmdet3d::_base_/datasets/sunrgbd-3d.py'] | ||
custom_imports = dict(imports=['projects.TR3D.tr3d']) | ||
|
||
model = dict( | ||
bbox_head=dict( | ||
num_reg_outs=8, | ||
label2level=[1, 1, 1, 0, 0, 1, 0, 0, 1, 0], | ||
bbox_loss=dict( | ||
type='TR3DRotatedIoU3DLoss', mode='diou', reduction='none'))) | ||
|
||
train_pipeline = [ | ||
dict( | ||
type='LoadPointsFromFile', | ||
coord_type='DEPTH', | ||
shift_height=False, | ||
use_color=True, | ||
load_dim=6, | ||
use_dim=[0, 1, 2, 3, 4, 5]), | ||
dict(type='LoadAnnotations3D'), | ||
dict(type='PointSample', num_points=100000), | ||
dict( | ||
type='RandomFlip3D', | ||
sync_2d=False, | ||
flip_ratio_bev_horizontal=0.5, | ||
flip_ratio_bev_vertical=0), | ||
dict( | ||
type='GlobalRotScaleTrans', | ||
rot_range=[-0.523599, 0.523599], | ||
scale_ratio_range=[.85, 1.15], | ||
translation_std=[.1, .1, .1], | ||
shift_height=False), | ||
dict( | ||
type='Pack3DDetInputs', | ||
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) | ||
] | ||
test_pipeline = [ | ||
dict( | ||
type='LoadPointsFromFile', | ||
coord_type='DEPTH', | ||
shift_height=False, | ||
use_color=True, | ||
load_dim=6, | ||
use_dim=[0, 1, 2, 3, 4, 5]), | ||
dict( | ||
type='MultiScaleFlipAug3D', | ||
img_scale=(1333, 800), | ||
pts_scale_ratio=1, | ||
flip=False, | ||
transforms=[ | ||
dict(type='PointSample', num_points=100000), | ||
]), | ||
dict(type='Pack3DDetInputs', keys=['points']) | ||
] | ||
train_dataloader = dict( | ||
batch_size=16, | ||
num_workers=8, | ||
dataset=dict( | ||
type='RepeatDataset', | ||
times=5, | ||
dataset=dict(pipeline=train_pipeline, filter_empty_gt=False))) | ||
val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) | ||
test_dataloader = val_dataloader |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .axis_aligned_iou_loss import TR3DAxisAlignedIoULoss | ||
from .mink_resnet import TR3DMinkResNet | ||
from .rotated_iou_loss import TR3DRotatedIoU3DLoss | ||
from .tr3d_head import TR3DHead | ||
from .tr3d_neck import TR3DNeck | ||
from .transforms_3d import TR3DPointSample | ||
|
||
__all__ = [ | ||
'TR3DAxisAlignedIoULoss', 'TR3DMinkResNet', 'TR3DRotatedIoU3DLoss', | ||
'TR3DHead', 'TR3DNeck', 'TR3DPointSample' | ||
] |
Oops, something went wrong.