This repository was archived by the owner on Oct 13, 2021. It is now read-only.

Initial PointPillars Release #1

Merged: 5 commits, Dec 14, 2018
227 changes: 68 additions & 159 deletions README.md
# PointPillars

Welcome to PointPillars.

This repo demonstrates how to reproduce the results from
_PointPillars: Fast Encoders for Object Detection from Point Clouds_ on the
[KITTI dataset](http://www.cvlibs.net/datasets/kitti/) by making the minimum required changes to the preexisting
open source codebase [SECOND](https://github.com/traveller59/second.pytorch).

This is not an official nuTonomy codebase, but it can be used to match the published PointPillars results.

![Example Results](https://raw.githubusercontent.com/nutonomy/second.pytorch/master/images/pointpillars_kitti_results.pdf)


## Getting Started

This is a fork of [SECOND for KITTI object detection](https://github.com/traveller59/second.pytorch) and the relevant
subset of the original README is reproduced here.

### Code Support

ONLY supports Python 3.6+ and PyTorch 0.4.1+. The code has only been tested on Ubuntu 16.04/18.04.

### Install

#### 1. Clone code

```bash
git clone https://github.com/nutonomy/second.pytorch.git
```

#### 2. Install Python packages

It is recommended to use the Anaconda package manager.

First, use Anaconda to configure as many packages as possible.
```bash
conda create -n pointpillars python=3.7 anaconda
source activate pointpillars
conda install shapely pybind11 protobuf scikit-image numba pillow
conda install pytorch torchvision -c pytorch
conda install google-sparsehash -c bioconda
```
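
A quick sanity check that the environment is usable (a minimal sketch, not part of the upstream instructions; it assumes the conda installs above succeeded):

```bash
# All imports should succeed and PyTorch should report CUDA availability.
python -c "import shapely, pybind11, skimage, numba, PIL; print('imports ok')"
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```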

Then use pip for the packages missing from Anaconda.
```bash
pip install --upgrade pip
pip install fire tensorboardX
```

Finally, install SparseConvNet. This is not required for PointPillars, but the general SECOND codebase expects it
to be correctly configured.
```bash
git clone git@github.com:facebookresearch/SparseConvNet.git
cd SparseConvNet/
bash build.sh
# NOTE: if bash build.sh fails, try bash develop.sh instead
```
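
If the build succeeds, the package should import cleanly. A minimal smoke test (it assumes the module installs under the name sparseconvnet):

```bash
python -c "import sparseconvnet; print('SparseConvNet ok')"
```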

Additionally, you may need to install Boost geometry:

```bash
sudo apt-get install libboost-all-dev
```


#### 3. Set up CUDA for numba

You need to add the following environment variables for numba.cuda to your ~/.bashrc:

```bash
export NUMBAPRO_CUDA_DRIVER=/usr/lib/x86_64-linux-gnu/libcuda.so
export NUMBAPRO_NVVM=/usr/local/cuda/nvvm/lib64/libnvvm.so
export NUMBAPRO_LIBDEVICE=/usr/local/cuda/nvvm/libdevice
```
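
After sourcing ~/.bashrc you can confirm that numba can find the CUDA toolkit (a quick check, assuming numba installed correctly):

```bash
source ~/.bashrc
# numba's built-in probe lists the CUDA devices it can see and errors out
# if libcuda/libnvvm/libdevice were not found via the variables above.
python -c "from numba import cuda; cuda.detect()"
```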

#### 4. PYTHONPATH

Add second.pytorch/ to your PYTHONPATH.
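
For example, assuming the repository was cloned into your home directory, the following appends the export to ~/.bashrc (adjust the path if you cloned elsewhere):

```bash
echo "export PYTHONPATH=\$PYTHONPATH:$HOME/second.pytorch" >> ~/.bashrc
source ~/.bashrc
```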

### Prepare dataset

#### 1. Dataset preparation

Download KITTI dataset and create some directories first:

```
└── KITTI_DATASET_ROOT
       ├── training    <-- 7481 train data
       |   ├── image_2 <-- for visualization
       |   ├── calib
       |   ├── label_2
       |   ├── velodyne
       |   └── velodyne_reduced <-- empty directory
       └── testing     <-- 7580 test data
           ├── image_2 <-- for visualization
           ├── calib
           ├── velodyne
           └── velodyne_reduced <-- empty directory
```
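
The empty velodyne_reduced directories can be created up front. A sketch, assuming the /data/sets/kitti_second root used by the PointPillars protos (see the note below):

```bash
KITTI_DATASET_ROOT=/data/sets/kitti_second
mkdir -p $KITTI_DATASET_ROOT/training/velodyne_reduced
mkdir -p $KITTI_DATASET_ROOT/testing/velodyne_reduced
```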

Note: the PointPillars protos use ```KITTI_DATASET_ROOT=/data/sets/kitti_second/```.

#### 2. Create kitti infos:

```bash
python create_data.py create_kitti_info_file --data_path=KITTI_DATASET_ROOT
```

#### 3. Create reduced point cloud:

```bash
python create_data.py create_reduced_point_cloud --data_path=KITTI_DATASET_ROOT
```

#### 4. Create groundtruth-database infos:

```bash
python create_data.py create_groundtruth_database --data_path=KITTI_DATASET_ROOT
```
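
The three preparation steps can also be chained in one go. A convenience sketch, assuming KITTI_DATASET_ROOT is set as above and you are in second.pytorch/second:

```bash
KITTI_DATASET_ROOT=/data/sets/kitti_second
# Run info-file creation, point cloud reduction, and gt-database creation back to back.
for step in create_kitti_info_file create_reduced_point_cloud create_groundtruth_database; do
    python create_data.py $step --data_path=$KITTI_DATASET_ROOT
done
```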

#### 5. Modify config file

The config file needs to be edited to point to the above datasets:

```bash
train_input_reader: {
  ...
  database_sampler {
    database_info_path: "/path/to/kitti_dbinfos_train.pkl"
    ...
  }
  kitti_info_path: "/path/to/kitti_infos_train.pkl"
  kitti_root_path: "KITTI_DATASET_ROOT"
}
...
eval_input_reader: {
  ...
  kitti_info_path: "/path/to/kitti_infos_val.pkl"
  kitti_root_path: "KITTI_DATASET_ROOT"
}
```
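
For example, with the dataset root above, the filled-in paths might look like the following (the file names reflect the default outputs of create_data.py; treat the exact values as an assumption for your setup):

```bash
train_input_reader: {
  ...
  database_sampler {
    database_info_path: "/data/sets/kitti_second/kitti_dbinfos_train.pkl"
    ...
  }
  kitti_info_path: "/data/sets/kitti_second/kitti_infos_train.pkl"
  kitti_root_path: "/data/sets/kitti_second"
}
...
eval_input_reader: {
  ...
  kitti_info_path: "/data/sets/kitti_second/kitti_infos_val.pkl"
  kitti_root_path: "/data/sets/kitti_second"
}
```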

### Train

```bash
cd ~/second.pytorch/second
python ./pytorch/train.py train --config_path=./configs/pointpillars/car/xyres_16.proto --model_dir=/path/to/model_dir
```

* If you want to train a new model, make sure "/path/to/model_dir" doesn't exist.
* If "/path/to/model_dir" does exist, training will be resumed from the last checkpoint.
* Training only supports a single GPU.
* Training uses a batchsize=2 which should fit in memory on most standard GPUs.
* On a single 1080Ti, training xyres_16 requires approximately 20 hours for 160 epochs.
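
Training progress can be monitored with TensorBoard; the code depends on tensorboardX, and it is an assumption here that summaries land inside the model directory:

```bash
# Point TensorBoard at the training run (logdir location is an assumption).
tensorboard --logdir=/path/to/model_dir
```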

### Evaluate

```bash
cd ~/second.pytorch/second/
python pytorch/train.py evaluate --config_path=configs/pointpillars/car/xyres_16.proto --model_dir=/path/to/model_dir
```
* Detection results will be saved in model_dir/eval_results/step_xxx.
* By default, results are stored as a result.pkl file. To save in the official KITTI label format, use --pickle_result=False.
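
A minimal sketch for inspecting a pickled result file (it assumes result.pkl holds one entry per evaluated frame; the step_xxx directory name depends on the checkpoint step):

```bash
python -c "
import pickle
with open('/path/to/model_dir/eval_results/step_xxx/result.pkl', 'rb') as f:
    results = pickle.load(f)
print(type(results), len(results))
"
```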
Binary file added images/pointpillars_kitti_results.pdf
1 change: 1 addition & 0 deletions second/builder/dataset_builder.py
@@ -85,6 +85,7 @@ def build(input_reader_config,
gt_loc_noise_std=list(cfg.groundtruth_localization_noise_std),
global_rotation_noise=list(cfg.global_rotation_uniform_noise),
global_scaling_noise=list(cfg.global_scaling_uniform_noise),
global_loc_noise_std=(0.2, 0.2, 0.2),
global_random_rot_range=list(
cfg.global_random_rotation_range_per_object),
db_sampler=db_sampler,
3 changes: 3 additions & 0 deletions second/configs/pointpillars/README.md
@@ -0,0 +1,3 @@
# PointPillars Configs

The configuration files in these directories can be used to reproduce the results published in PointPillars.