
Improved PyTorch implementation of RandLA (https://arxiv.org/abs/1911.11236) with easier transferability and reproducibility


SC-shendazt/RandLA-pytorch

 
 


RandLA-Net-pytorch

This repository contains the implementation of RandLA-Net (CVPR 2020 Oral) in PyTorch.

Updates:

  • We extend the model to train with a synthetic dataset, SynLiDAR.
  • We replace the weighted cross entropy in the original implementation with focal loss to alleviate the influence of class imbalance.
  • We add frequency-weighted mIoU as an indicator for training & validation.
  • We fix some minor bugs in the original implementation.
  • We improve the mIoU on the validation set from 53.1% to 55.1%.
  • This is a good starting point & backbone choice for those who plan to start their research on point cloud segmentation.
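
As a rough illustration of why focal loss helps with class imbalance, the sketch below computes it for a single point in plain Python. It is not the repository's loss code, which is not shown here; the function names are hypothetical, and gamma=2 is chosen to match the documented default of --focal_gamma. The factor (1 - p_t)^gamma shrinks the loss of points that are already classified confidently, so hard (often rare-class) points dominate the gradient.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of raw class scores."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, target):
    """Plain cross entropy for one point: -log(p_t)."""
    return -math.log(softmax(logits)[target])

def focal_loss(logits, target, gamma=2.0):
    """Focal loss for one point: -(1 - p_t)^gamma * log(p_t)."""
    p_t = softmax(logits)[target]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# Hypothetical scores for a 3-class problem; the true class is 0 in both cases.
easy_point = [4.0, 0.0, 0.0]   # confidently correct
hard_point = [0.0, 1.0, 0.0]   # misclassified

for name, logits in (("easy", easy_point), ("hard", hard_point)):
    ce = cross_entropy(logits, 0)
    fl = focal_loss(logits, 0)
    print(f"{name}: cross entropy = {ce:.4f}, focal loss = {fl:.6f}")
```

With gamma=0 the focal term equals 1 and the expression reduces to ordinary cross entropy. A PyTorch version would apply the same formula to the output of a log-softmax over the class dimension.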

Previous:

  • The SemanticKITTI dataset is now supported. (Everyone is welcome to contribute and open PRs.)
  • The pretrained model is placed at pretrain_model/checkpoint.tar.

Performance

Results on Validation Set (seq 08)

  • Comparison with the original implementation

Model                         mIoU
Original TensorFlow           0.531
Our PyTorch implementation    0.551

  • Per-class IoU (overall mIoU: 55.1%)

Class            IoU
car              0.939
bicycle          0.092
motorcycle       0.347
truck            0.659
other-vehicle    0.453
person           0.548
bicyclist        0.707
motorcyclist     0.000
road             0.920
parking          0.401
sidewalk         0.784
other-ground     0.006
building         0.886
fence            0.520
vegetation       0.855
trunk            0.627
terrain          0.747
pole             0.568
traffic-sign     0.403
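
The reported mIoU is the unweighted mean of the 19 per-class IoU values above. The short check below reproduces it from the table; the dictionary simply restates the numbers listed in this README.

```python
# Per-class IoU on the validation set (seq 08), copied from the table above.
per_class_iou = {
    "car": 0.939, "bicycle": 0.092, "motorcycle": 0.347, "truck": 0.659,
    "other-vehicle": 0.453, "person": 0.548, "bicyclist": 0.707,
    "motorcyclist": 0.000, "road": 0.920, "parking": 0.401,
    "sidewalk": 0.784, "other-ground": 0.006, "building": 0.886,
    "fence": 0.520, "vegetation": 0.855, "trunk": 0.627,
    "terrain": 0.747, "pole": 0.568, "traffic-sign": 0.403,
}

# mIoU is the plain average over classes, so rare classes count as much as road.
miou = sum(per_class_iou.values()) / len(per_class_iou)
print(f"mIoU over {len(per_class_iou)} classes: {miou:.3f}")  # prints 0.551

# The weakest classes pull the mean down the most.
weakest = sorted(per_class_iou, key=per_class_iou.get)[:3]
print("Lowest IoU:", weakest)
```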

A. Environment Setup

  1. Use conda to install pytorch>=1.4, following the PyTorch installation webpage (make sure the CUDA version matches your system).

  2. Install the Python packages

pip install -r requirements.txt
  3. Compile the C++ wrappers
sh compile_op.sh
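
After the setup steps, a quick sanity check of the installed PyTorch version can save a failed training run. The helper below is a sketch and not part of this repository; version_tuple, meets_minimum and report_environment are hypothetical names. The parsing is done in plain Python so it also handles version strings with a local build suffix such as "+cu117".

```python
import re

def version_tuple(version):
    """Parse a version string such as '1.13.1+cu117' into (1, 13, 1)."""
    core = version.split("+")[0]          # drop local build suffix, if any
    parts = []
    for piece in core.split("."):
        match = re.match(r"\d+", piece)   # keep only the leading digits
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)

def meets_minimum(version, minimum=(1, 4)):
    """Return True if the version is at least the required minimum."""
    return version_tuple(version) >= minimum

def report_environment():
    """Print the PyTorch and CUDA status of the current environment."""
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return
    status = "OK" if meets_minimum(torch.__version__) else "too old, need >=1.4"
    print(f"PyTorch {torch.__version__}: {status}")
    print(f"CUDA build: {torch.version.cuda}, available: {torch.cuda.is_available()}")
```

Calling report_environment() inside the conda environment prints the detected versions.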

B. Prepare Data

Download the Semantic KITTI dataset, and preprocess the data:

python data_prepare_semantickitti.py --src_path path/to/sequences --dst_path destination/for/preprocessed/sequences

Note:

  • Change the dataset path in data_prepare_semantickitti.py to your own path.
  • The preprocessing code converts the labels to indices 0-19.
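
To make the 0-19 conversion concrete, here is a minimal sketch of how raw SemanticKITTI labels can be remapped. It is not the code in data_prepare_semantickitti.py: the function name is hypothetical, and the mapping is only an excerpt that is assumed to follow the learning_map section of the SemanticKITTI config (utils/semantic-kitti.yaml). Each raw label is a 32-bit value whose lower 16 bits hold the semantic class and whose upper 16 bits hold the instance ID.

```python
# Excerpt of the raw-ID -> training-index mapping. Assumption: these entries
# follow the learning_map of the SemanticKITTI config; the full map has more IDs.
LEARNING_MAP_EXCERPT = {
    0: 0,     # unlabeled
    1: 0,     # outlier -> unlabeled
    10: 1,    # car
    11: 2,    # bicycle
    15: 3,    # motorcycle
    18: 4,    # truck
    30: 6,    # person
    40: 9,    # road
    252: 1,   # moving-car -> car
}

def to_train_index(raw_labels, learning_map):
    """Map raw 32-bit SemanticKITTI labels to training indices."""
    indices = []
    for raw in raw_labels:
        semantic_id = raw & 0xFFFF  # lower 16 bits: semantic class ID
        # Assumption for this sketch: IDs missing from the map become 0 (unlabeled).
        indices.append(learning_map.get(semantic_id, 0))
    return indices

# A car point that also carries instance ID 5 in the upper 16 bits.
car_with_instance = (5 << 16) | 10
print(to_train_index([car_with_instance, 40, 252, 1], LEARNING_MAP_EXCERPT))
# prints [1, 9, 1, 0]
```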

C. Training & Testing

  1. Training
python3 train_SemanticKITTI.py <args>
Options:
--backbone           select the backbone to be used: choices=['randla', 'baflac', 'baaf']
--checkpoint_path    path to a pretrained model (if any); otherwise training starts from scratch
--log_dir            name of the log dir; logs are written under logs/, suffixed with the start time
--max_epoch          maximum number of epochs to train, default 80
--batch_size         training batch size, default 6 (as in the original implementation); increase it to fully utilize the GPU(s)
--val_batch_size     batch size for validation, default 30
--num_workers        number of workers for I/O
--focal              whether to use focal loss or not, default True
--focal_gamma        gamma for focal loss, default 2

  2. Testing
python3 test_SemanticKITTI.py <args>
Options:
--backbone           select the backbone to be used: choices=['randla', 'baflac', 'baaf']
--infer_type         all: infer all points in specified sequence, sub: subsamples in specified sequence
--checkpoint_path    required. path to the model to test
--test_id            sequence id to test
--index_to_label     whether to convert .npy (label 0-19) back to .label (original labels)

Note: if the flag --index_to_label is set, the output predictions are ".label" files (original label IDs), which can be visualized; otherwise, they are ".npy" files (0-19 indices), which are used for evaluation afterwards.
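
The conversion behind --index_to_label can be pictured as the inverse of the preprocessing map. The sketch below is not the repository's implementation; the function name is hypothetical, and the table is assumed to follow the learning_map_inv section of the SemanticKITTI config. It encodes predictions as little-endian unsigned 32-bit integers, the binary layout of a ".label" file, and leaves the instance ID (upper 16 bits) at zero.

```python
import struct

# Training index -> original SemanticKITTI label ID. Assumption: this follows
# the learning_map_inv section of the SemanticKITTI config.
LEARNING_MAP_INV = {
    0: 0,    1: 10,  2: 11,  3: 15,  4: 18,
    5: 20,   6: 30,  7: 31,  8: 32,  9: 40,
    10: 44, 11: 48, 12: 49, 13: 50, 14: 51,
    15: 70, 16: 71, 17: 72, 18: 80, 19: 81,
}

def indices_to_label_bytes(pred_indices):
    """Encode 0-19 predictions as the raw bytes of a .label file."""
    raw_ids = [LEARNING_MAP_INV[i] for i in pred_indices]
    # '<' = little-endian, 'I' = unsigned 32-bit integer, one per point.
    return struct.pack(f"<{len(raw_ids)}I", *raw_ids)

# Three predicted points: car, road, traffic-sign.
payload = indices_to_label_bytes([1, 9, 19])
print(len(payload))                    # 12 bytes: 3 points x 4 bytes
print(struct.unpack("<3I", payload))   # (10, 40, 81)
```

Writing payload to a file opened in binary mode would produce a file in the ".label" layout for those three points.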

D. Visualization & Evaluation

  1. Visualization
python3 visualize_SemanticKITTI.py <args>

Options:
--dataset            path to dataset for visualization
--config             dataset config, default utils/semantic-kitti.yaml
--sequence           sequence to visualize
--predictions        location for predictions
  2. Evaluation
  • Example evaluation command
python3 evaluate_SemanticKITTI.py --dataset /tmp2/tsunghan/PCL_Seg_data/sequences_0.06/ \
    --predictions runs/supervised/predictions/ --sequences 8
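
For readers who want to see what the evaluation measures, the sketch below computes per-class IoU, mIoU and a frequency-weighted IoU from a confusion matrix in plain Python. It is an illustration, not the code of evaluate_SemanticKITTI.py: all function names are hypothetical, the frequency-weighted variant uses the common definition (each class IoU weighted by its share of ground-truth points), which may differ from the one used in this repository, and no class is ignored here even though the unlabeled class is left out of the 19-class results reported above.

```python
def confusion_matrix(gt, pred, num_classes):
    """Count points: rows are ground-truth classes, columns are predictions."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for g, p in zip(gt, pred):
        cm[g][p] += 1
    return cm

def per_class_iou(cm):
    """IoU_c = TP / (TP + FP + FN) for every class c."""
    n = len(cm)
    ious = []
    for c in range(n):
        tp = cm[c][c]
        fn = sum(cm[c]) - tp                       # class c points predicted as other
        fp = sum(cm[r][c] for r in range(n)) - tp  # other points predicted as c
        denom = tp + fp + fn
        ious.append(tp / denom if denom else 0.0)
    return ious

def mean_iou(cm):
    """Unweighted mean of the per-class IoU values."""
    ious = per_class_iou(cm)
    return sum(ious) / len(ious)

def frequency_weighted_iou(cm):
    """Weight each class IoU by its share of ground-truth points."""
    total = sum(sum(row) for row in cm)
    ious = per_class_iou(cm)
    return sum(sum(cm[c]) / total * ious[c] for c in range(len(cm)))

# Toy example with two classes and six points.
gt = [0, 0, 0, 0, 1, 1]
pred = [0, 0, 0, 1, 1, 1]
cm = confusion_matrix(gt, pred, num_classes=2)
print(per_class_iou(cm))                     # class 0: 3/4, class 1: 2/3
print(f"mIoU   = {mean_iou(cm):.4f}")        # 0.7083
print(f"FWIoU  = {frequency_weighted_iou(cm):.4f}")  # 0.7222
```

The frequency-weighted value is higher here because the more frequent class 0 also has the higher IoU; on SemanticKITTI, classes such as road and vegetation dominate this variant in the same way.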

Acknowledgement
