FuseNet implementation in PyTorch
This is the PyTorch implementation of FuseNet, built on the pix2pix codebase.
- Linux
- Python 3.7.0
- CPU or NVIDIA GPU + CUDA and cuDNN
- Install PyTorch 0.4.1.post2 and dependencies from http://pytorch.org
- Clone this repo:
git clone https://github.com/MehmetAygun/fusenet-pytorch
cd fusenet-pytorch
pip install -r requirements.txt
- Download and untar the preprocessed SUNRGBD dataset under /datasets/sunrgbd
- For NYUv2, download the dataset and create the training set:
cd datasets
sh download_nyuv2.sh
python create_training_set.py
- Download scannet_frames_25k and scannet_frames_test under /datasets/scannet/tasks/
- To view training errors and loss plots, set --display_id 1, run python -m visdom.server and open http://localhost:8097
- Checkpoints are saved under ./checkpoints/sunrgbd/
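The test commands pick a checkpoint with --epoch. To see which epochs have been saved, a small helper like the one below can list them. It assumes the pix2pix naming scheme <epoch>_net_<network>.pth inherited from the base code (an assumption; check the actual file names in your checkpoint directory), and both function names are hypothetical.

```python
import re
from pathlib import Path

# Assumption: pix2pix-style names such as "400_net_FuseNet.pth".
# "latest_net_*.pth" files carry no epoch number and are skipped.
EPOCH_FILE = re.compile(r"^(\d+)_net_.+\.pth$")


def epochs_from_names(file_names):
    """Return the sorted, de-duplicated epoch numbers encoded in file_names."""
    epochs = set()
    for name in file_names:
        match = EPOCH_FILE.match(name)
        if match:
            epochs.add(int(match.group(1)))
    return sorted(epochs)


def saved_epochs(checkpoint_dir):
    """List the epochs saved in a checkpoint directory (empty if it does not exist)."""
    return epochs_from_names(p.name for p in Path(checkpoint_dir).glob("*.pth"))


if __name__ == "__main__":
    print(saved_epochs("./checkpoints/sunrgbd"))
```

Any value it returns is a candidate for the --epoch flag of test.py.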
python train.py --dataroot datasets/sunrgbd --dataset sunrgbd --name sunrgbd
python test.py --dataroot datasets/sunrgbd --dataset sunrgbd --name sunrgbd --epoch 400
python train.py --dataroot datasets/nyuv2 --dataset nyuv2 --name nyuv2
python test.py --dataroot datasets/nyuv2 --dataset nyuv2 --name nyuv2 --epoch 400
python train.py --dataroot datasets/scannet/tasks/scannet_frames_25k --dataset scannetv2 \
--name scannetv2
python test.py --dataroot datasets/scannet/tasks/scannet_frames_25k --dataset scannetv2 \
--name scannetv2 --epoch 380 --phase val
python test.py --dataroot datasets/scannet/tasks/scannet_frames_test --dataset scannetv2 \
--name scannetv2 --epoch 380 --phase test
- We use the training scheme defined in FuseNet.
- The loss is class-weighted for the SUNRGBD dataset.
- The learning rate is set to 0.01 for the NYUv2 dataset.
- Results may be further improved with a hyper-parameter search.
- Results on scannetv2-test (without class-weighted loss) can be found here
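The notes say the loss is class-weighted for SUNRGBD but not how the weights are obtained. One common choice for segmentation is median frequency balancing, weight[c] = median(freq) / freq[c]; the sketch below assumes that scheme and an ignore label of 0 for unlabeled pixels, so the repository's actual weights may differ. In PyTorch such weights would typically be passed to torch.nn.CrossEntropyLoss(weight=..., ignore_index=...); the pure-Python loss here mirrors that reduction for illustration only.

```python
import math
from statistics import median

IGNORE = 0  # assumption: label 0 marks unlabeled pixels, like an ignore_index


def median_frequency_weights(label_maps, num_classes):
    """Median frequency balancing over flattened label maps.

    freq[c] = pixels of class c / labeled pixels of the images in which c appears;
    weight[c] = median(freq) / freq[c]. Classes that never appear get weight 0.
    """
    class_pixels = [0] * num_classes
    image_pixels = [0] * num_classes
    for labels in label_maps:
        valid = [label for label in labels if label != IGNORE]
        for c in set(valid):
            class_pixels[c] += valid.count(c)
            image_pixels[c] += len(valid)
    freqs = {c: class_pixels[c] / image_pixels[c]
             for c in range(num_classes) if image_pixels[c] > 0}
    med = median(freqs.values())
    return [med / freqs[c] if c in freqs else 0.0 for c in range(num_classes)]


def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean of per-pixel negative log-softmax, skipping IGNORE pixels.

    Uses the same 'mean' reduction as CrossEntropyLoss with class weights:
    sum(w[y] * nll) / sum(w[y]).
    """
    total, norm = 0.0, 0.0
    for scores, y in zip(logits, targets):
        if y == IGNORE:
            continue
        m = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
        total += weights[y] * (log_sum - scores[y])
        norm += weights[y]
    return total / norm
```

Rare classes receive weights above 1 and frequent ones below 1, so errors on rare classes count more toward the loss.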
Dataset | FuseNet-SF5 (Caffe) overall acc | FuseNet-SF5 (Caffe) mean acc | FuseNet-SF5 (Caffe) IoU | FuseNet-SF5 (PyTorch) overall acc | FuseNet-SF5 (PyTorch) mean acc | FuseNet-SF5 (PyTorch) IoU |
---|---|---|---|---|---|---|
sunrgbd | 76.30 | 48.30 | 37.30 | 75.41 | 46.48 | 35.69 |
nyuv2 | 66.00 | 43.40 | 32.70 | 68.76 | 46.42 | 35.48 |
scannetv2-val | -- | -- | -- | 76.32 | 55.84 | 44.12 |
scannetv2-cls_weighted-val | -- | -- | -- | 76.26 | 55.74 | 44.40 |
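The table reports overall, mean and IoU scores for each method. Assuming the usual definitions (overall pixel accuracy, mean per-class accuracy and mean intersection-over-union, reported in percent), all three can be derived from a pixel confusion matrix. This sketch illustrates those definitions and is not the repository's evaluation code; skipping classes with an empty denominator is an assumption.

```python
def segmentation_scores(confusion):
    """Compute (overall acc, mean acc, mean IoU) as fractions in [0, 1].

    confusion[i][j] = number of pixels of ground-truth class i predicted as class j.
    Classes with an empty denominator are left out of the averages.
    """
    n = len(confusion)
    diag = [confusion[c][c] for c in range(n)]
    gt = [sum(row) for row in confusion]                              # row sums
    pred = [sum(confusion[i][j] for i in range(n)) for j in range(n)]  # column sums

    overall = sum(diag) / sum(gt)
    accs = [diag[c] / gt[c] for c in range(n) if gt[c] > 0]
    unions = [gt[c] + pred[c] - diag[c] for c in range(n)]
    ious = [diag[c] / unions[c] for c in range(n) if unions[c] > 0]
    return overall, sum(accs) / len(accs), sum(ious) / len(ious)


if __name__ == "__main__":
    overall, mean_acc, mean_iou = segmentation_scores([[3, 1], [2, 4]])
    print(f"{100 * overall:.2f} | {100 * mean_acc:.2f} | {100 * mean_iou:.2f}")
```

Because mean accuracy and IoU average over classes, they weight rare classes as heavily as frequent ones, which is why they sit well below the overall accuracy in every row.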
scannetv2-test | avg iou | bathtub | bed | bookshelf | cabinet | chair | counter | curtain | desk | door | floor | other furniture | picture | refrigerator | shower curtain | sink | sofa | table | toilet | wall | window |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
no-cls_weighted | 52.1 | 59.1 | 68.2 | 22.0 | 48.8 | 27.9 | 34.4 | 61.0 | 46.1 | 47.5 | 91.0 | 29.3 | 44.7 | 51.2 | 39.7 | 61.8 | 56.7 | 45.2 | 73.4 | 78.2 | 56.6 |
cls_weighted | 53.5 | 57.0 | 68.1 | 18.2 | 51.2 | 29.0 | 43.1 | 65.9 | 50.4 | 49.5 | 90.3 | 30.8 | 42.8 | 52.3 | 36.5 | 67.6 | 62.1 | 47.0 | 76.2 | 77.9 | 54.1 |
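The avg iou column is consistent with an unweighted mean of the 20 per-class IoUs. The check below reproduces it from the table values (52.14 and 53.50, which round to the reported 52.1 and 53.5).

```python
# Per-class IoU values (%) copied from the scannetv2-test table,
# in column order from bathtub to window.
NO_CLS_WEIGHTED = [59.1, 68.2, 22.0, 48.8, 27.9, 34.4, 61.0, 46.1, 47.5, 91.0,
                   29.3, 44.7, 51.2, 39.7, 61.8, 56.7, 45.2, 73.4, 78.2, 56.6]
CLS_WEIGHTED = [57.0, 68.1, 18.2, 51.2, 29.0, 43.1, 65.9, 50.4, 49.5, 90.3,
                30.8, 42.8, 52.3, 36.5, 67.6, 62.1, 47.0, 76.2, 77.9, 54.1]


def avg_iou(per_class):
    """Unweighted mean over classes: every class counts equally."""
    return sum(per_class) / len(per_class)


print(round(avg_iou(NO_CLS_WEIGHTED), 1))  # 52.1
print(round(avg_iou(CLS_WEIGHTED), 1))     # 53.5
```

Since every class counts equally, the class-weighted run's gains on classes such as counter and sofa outweigh its losses on bookshelf and window.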
@inproceedings{hazirbas16fusenet,
Title = {{FuseNet}: Incorporating Depth into Semantic Segmentation via Fusion-Based CNN Architecture},
Author = {Hazirbas, Caner and Ma, Lingni and Domokos, Csaba and Cremers, Daniel},
Booktitle = {Asian Conference on Computer Vision ({ACCV})},
Year = {2016},
Doi = {10.1007/978-3-319-54181-5_14},
Url = {https://github.com/tum-vision/fusenet}
}
Code is inspired by pytorch-CycleGAN-and-pix2pix.