diff --git a/README.md b/README.md index 0c9bed8..bf0e525 100644 --- a/README.md +++ b/README.md @@ -1,244 +1,337 @@ -# DeepLab with PyTorch - -This is an unofficial **PyTorch** implementation to train **DeepLab v2** model (ResNet backbone) [[1](##references)] on **COCO-Stuff** dataset [[2](##references)]. -DeepLab is one of the CNN architectures for semantic image segmentation. -COCO-Stuff is a semantic segmentation dataset, which includes 164k images annotated with 171 thing/stuff classes (+ unlabeled). - -This repository aims to reproduce the official score of DeepLab v2 on COCO-Stuff datasets. -The model can be trained both on the latest [COCO-Stuff 164k](https://github.com/nightrome/cocostuff) and [COCO-Stuff 10k](https://github.com/nightrome/cocostuff10k), *without building the official implementation in Caffe*. -[Pre-trained models are provided](#models----omit-in-toc). -ResNet-based DeepLab v3/v3+ are also included, although they are not tested. -```torch.hub``` is supported. - -- [Setup](#setup) -- [Training](#training) -- [Evaluation](#evaluation) -- [Performance](#performance) -- [Demo](#demo) -- [Misc](#misc) -- [References](#references) +# DeepLab with PyTorch -## Setup +This is an unofficial **PyTorch** implementation of **DeepLab v2** [[1](##references)] with a **ResNet** backbone. **COCO-Stuff** dataset [[2](##references)] and **PASCAL VOC** dataset [[3]()] are supported. The initial weights (`.caffemodel`) officially provided by the authors are can be converted/used without building the Caffe API. DeepLab v3/v3+ models with the identical backbone are also included (although not tested). [```torch.hub``` is supported](#torchhub). -### Requirements +## Performance -For anaconda users: +Pretrained models are provided. Note that the 2D interpolation ways are different from the original, which leads to a bit better results. + +### COCO-Stuff + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Train setEval setCRF?CodePixel
Accuracy
Mean
Accuracy
Mean IoUFreqW IoU
+ 10k train
+ (Model) +
10k valOriginal [2]65.145.534.450.4
Ours65.845.734.851.2
Ours67.146.435.652.5
+ 164k train
+ (Model) +
10k valOurs68.455.644.255.1
Ours69.255.945.055.9
164k valOurs66.851.239.151.5
Ours67.651.539.752.3
+ +† Images and labels are pre-warped to square-shape 513x513 + +### PASCAL VOC 2012 + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Train setEval setCRF?CodePixel
Accuracy
Mean
Accuracy
Mean IoUFreqW IoU
+ trainaug
+ (Model) +
valOriginal [3]--76.35-
Ours94.6486.5076.6590.41
Original [3]--77.69-
Ours95.0486.6477.9391.06
-```sh -# Please modify CUDA option according to your environment -conda env create --file config/conda_env.yaml -``` +## Setup -* python 2.7+/3.6+ -* [pytorch](https://pytorch.org/) 0.4.1+ -* [torchvision](https://pytorch.org/) -* [torchnet](https://github.com/pytorch/tnt) -* [pydensecrf](https://github.com/lucasb-eyer/pydensecrf) -* [tensorflow](https://www.tensorflow.org/install/) (for tensorboard) -* [tensorboardX](https://github.com/lanpa/tensorboard-pytorch) 1.0+ -* opencv 3.0.0+ -* tqdm -* click -* addict -* h5py -* scipy -* matplotlib -* yaml -* joblib - -### Datasets - -COCO-Stuff 164k is the latest version and recommended. - -
-COCO-Stuff 164k (click to show the structure) -
-├── images
-│   ├── train2017
-│   │   ├── 000000000009.jpg
-│   │   └── ...
-│   └── val2017
-│       ├── 000000000139.jpg
-│       └── ...
-└── annotations
-    ├── train2017
-    │   ├── 000000000009.png
-    │   └── ...
-    └── val2017
-        ├── 000000000139.png
-        └── ...
-
-
-
- -1. Run the script below to download the dataset (20GB+). +### Requirements -```sh -./scripts/setup_cocostuff164k.sh -``` +* Python 2.7+/3.6+ +* Anaconda environement -2. Set the path to the dataset in ```config/cocostuff164k.yaml```. +Then setup from `conda_env.yaml`. Please modify cuda option as needed (default: `cudatoolkit=10.0`) -```yaml -DATASET: cocostuff164k -ROOT: # <- Write here -... +```console +$ conda env create -f configs/conda_env.yaml +$ conda activate deeplab-pytorch ``` -
-COCO-Stuff 10k (click to show the structure) -
-├── images
-│   ├── COCO_train2014_000000000077.jpg
-│   └── ...
-├── annotations
-│   ├── COCO_train2014_000000000077.mat
-│   └── ...
-└── imageLists
-    ├── all.txt
-    ├── test.txt
-    └── train.txt
-
-
-
- -1. Run the script below to download the dataset (2GB). - -```sh -./scripts/setup_cocostuff10k.sh -``` +### Datasets -2. Set the path to the dataset in ```config/cocostuff10k.yaml```. +Setup instruction is provided in each link. -```yaml -DATASET: cocostuff10k -ROOT: # <- Write here -... -``` +* [COCO-Stuff 10k/164k](data/datasets/cocostuff/README.md) +* [PASCAL VOC 2012](data/datasets/voc12/README.md) -### Initial parameters +### Initial weights 1. Run the script below to download caffemodel pre-trained on ImageNet and 91-class COCO (1GB+). ```sh -./scripts/setup_caffemodels.sh +$ bash scripts/setup_caffemodels.sh ``` -2. Convert the caffemodel to pytorch compatible. No need to build the official DeepLab! +2. Convert the caffemodel to pytorch compatible. No need to build the official Caffe API! ```sh -# This generates deeplabv2_resnet101_COCO_init.pth from init.caffemodel -python convert.py --dataset init +# This generates "deeplabv2_resnet101_COCO_init.pth" from "init.caffemodel" +$ python convert.py --dataset coco ``` -You can also convert an included ```train2_iter_20000.caffemodel``` for PASCAL VOC 2012 dataset. See [here](config/README.md#voc12yaml). ## Training -Training, evaluation, and some demos are all through the [```.yaml``` configuration files](config/README.md). +Please see [```./scripts/train_eval.sh```](scripts/train_eval.sh) for example usage. -```sh -# Train DeepLab v2 on COCO-Stuff 164k -python main.py train --config config/cocostuff164k.yaml ``` +Usage: main.py train [OPTIONS] -```sh -# Monitor a cross-entropy loss -tensorboard --logdir runs -``` + Training DeepLab by v2 protocol -Default settings: +Options: + -c, --config-path FILENAME Dataset configuration file in YAML [required] + --cuda / --cpu Enable CUDA if available [default: --cuda] + --help Show this message and exit. +``` -- All the GPUs visible to the process are used. Please specify the scope with ```CUDA_VISIBLE_DEVICES=```. -- Stochastic gradient descent (SGD) is used with momentum of 0.9 and initial learning rate of 2.5e-4. Polynomial learning rate decay is employed; the learning rate is multiplied by ```(1-iter/iter_max)**power``` at every 10 iterations. -- Weights are updated 20k iterations for COCO-Stuff 10k and 100k iterations for COCO-Stuff 164k, with a mini-batch of 10. The batch is not processed at once due to high occupancy of video memories, instead, gradients of small batches are aggregated, and weight updating is performed at the end (```batch_size * iter_size = 10```). -- Input images are initially warped to 513x513, randomly re-scaled by factors ranging from 0.5 to 1.5, zero-padded if needed, and randomly cropped to 321x321 so that the input size is fixed during training (see the example below). -- The label indices range from 0 to 181 and the model outputs a 182-dim categorical distribution, but only [171 classes](https://github.com/nightrome/cocostuff/blob/master/labels.md) are supervised with COCO-Stuff. -- Loss is defined as a sum of responses from multi-scale inputs (1x, 0.75x, 0.5x) and element-wise max across the scales. The "unlabeled" class (index -1) is ignored in the loss computation. -- Moving average loss (```average_loss``` in Caffe) can be monitored in TensorBoard. -- GPU memory usage is approx. 11.2 GB with the default setting (tested on the single Titan X). You can reduce it with a small ```batch_size```. +To monitor a loss, lr values, and gpu usage: -Processed image vs. label examples: +```sh +$ tensorboard --logdir data/logs +``` -![Data](docs/data.png) +Common settings: + +- **Model**: DeepLab v2 with ResNet-101 backbone. Dilated rates of ASPP are (6, 12, 18, 24). Output stride is 8. +- **Multi-GPU**: All the GPUs visible to the process are used. Please specify the scope with +```CUDA_VISIBLE_DEVICES=```. +- **Multi-scale loss**: Loss is defined as a sum of responses from multi-scale inputs (1x, 0.75x, 0.5x) and +element-wise max across the scales. The *unlabeled* class is ignored in the loss computation. +- **Gradient accumulation**: The mini-batch of 10 samples is not processed at once due to the high occupancy of GPU +memories. Instead, gradients of small batches of 5 samples are accumulated for 2 iterations, and weight updating is +performed at the end (```batch_size * iter_size = 10```). GPU memory usage is approx. 11.2 GB with the default setting +(tested on the single Titan X). You can reduce it with a small ```batch_size```. +- **Learning rate**: Stochastic gradient descent (SGD) is used with momentum of 0.9 and initial learning rate of +2.5e-4. Polynomial learning rate decay is employed; the learning rate is multiplied by ```(1-iter/iter_max)**power``` +at every 10 iterations. +- **Monitoring**: Moving average loss (```average_loss``` in Caffe) can be monitored in TensorBoard. + +[COCO-Stuff 164k](config/cocostuff164k.yaml): +- **#Iterations**: Updated 100k iterations. +- **#Classes**: The label indices range from 0 to 181 and the model outputs a 182-dim categorical distribution, but +only [171 classes](https://github.com/nightrome/cocostuff/blob/master/labels.md) are supervised with COCO-Stuff. Index 255 is an unlabeled class to be ignored. +- **Preprocessing**: (1) Input images are randomly re-scaled by factors ranging +from 0.5 to 1.5, (2) padded if needed, and (3) randomly cropped to 321x321. + +[COCO-Stuff 10k](config/cocostuff10k.yaml): +- **#Iterations**: Updated 20k iterations. +- **#Classes**: Same as the 164k version above. +- **Preprocessing**: (1) Input images are initially warped to 513x513 squares, (2) randomly re-scaled by factors ranging from +0.5 to 1.5, (3) padded if needed, and (4) randomly cropped to 321x321 so that the input size is fixed during training. + +[PASCAL VOC 2012](config/voc12.yaml): +- **#Iterations**: Updated 20k iterations. +- **#Classes**: 20 foreground objects + 1 background. Index 255 is an unlabeled class to be ignored. +- **Preprocessing**: (1) Input images are randomly re-scaled by factors ranging from 0.5 to 1.5, (2) padded if needed, and (3) randomly cropped +to 321x321. + +Processed image vs. label examples in COCO-Stuff: + +![Data](docs/datasets/cocostuff.png) ## Evaluation -```sh -# Evaluate the final model on COCO-Stuff 164k validation set -python main.py test --config config/cocostuff164k.yaml \ - --model-path data/models/deeplab_resnet101/cocostuff164k/checkpoint_final.pth +To compute scores in: +* Pixel accuracy +* Mean accuracy +* Mean IoU +* Frequency weighted IoU + ``` +Usage: main.py test [OPTIONS] -You can run CRF post-processing with a option ```--crf```. See ```--help``` for more details. + Evaluation on validation set -## Performance +Options: + -c, --config-path FILENAME Dataset configuration file in YAML [required] + -m, --model-path PATH PyTorch model to be loaded [required] + --cuda / --cpu Enable CUDA if available [default: --cuda] + --help Show this message and exit. +``` -### Validation scores +To perform CRF post-processing: -|   | Train set | Eval set | CRF? | Pixel
Accuracy | Mean
Accuracy | Mean IoU | FreqW IoU | -| :---------------------------------------------------------------- | :-----------: | :---------: | :----: | :---------------: | :--------------: | :--------: | :--------: | -| [**Official (Caffe)**](https://github.com/nightrome/cocostuff10k) | **10k train** | **10k val** | **No** | **65.1 %** | **45.5 %** | **34.4 %** | **50.4 %** | -| **This repo** | **10k train** | **10k val** | **No** | **65.3 %** | **45.3 %** | **34.4 %** | **50.5 %** | -| This repo | 10k train | 10k val | Yes | 66.7 % | 45.9 % | 35.5 % | 51.9 % | -| This repo | 164k train | 10k val | No | 67.6 % | 54.9 % | 43.2 % | 53.9 % | -| This repo | 164k train | 10k val | Yes | 68.7 % | 55.3 % | 44.4 % | 55.1 % | -| This repo | 164k train | 164k val | No | 65.7 % | 49.7 % | 37.6 % | 50.0 % | -| This repo | 164k train | 164k val | Yes | 66.8 % | 50.1 % | 38.5 % | 51.1 % | +``` +Usage: main.py crf [OPTIONS] -### Models + CRF post-processing on pre-computed logits -* [Models trained on COCO-Stuff 10k/164k (*.pth)](https://drive.google.com/drive/folders/1m3wyXvvWy-IvGmdFS_dsQCRXhFNhek8_?usp=sharing) -* [Scores (*.json)](https://drive.google.com/drive/folders/1PouglnlwsyHTwdSo_d55WgMgdnxbxmE6?usp=sharing) +Options: + -c, --config-path FILENAME Dataset configuration file in YAML [required] + -j, --n-jobs INTEGER Number of parallel jobs [default: # cores] + --help Show this message and exit. +``` ## Demo -### From an image +| COCO-Stuff 164k | COCO-Stuff 10k | PASCAL VOC 2012 | Pretrained COCO | +| :----------------------------------: | :---------------------------------: | :--------------------------: | :------------------------------: | +| ![](docs/examples/cocostuff164k.png) | ![](docs/examples/cocostuff10k.png) | ![](docs/examples/voc12.png) | ![](docs/examples/coco_init.png) | -```bash -python demo.py single --config config/cocostuff164k.yaml \ - --model-path \ - --image-path \ - --crf +### Single image + +``` +Usage: demo.py single [OPTIONS] + + Inference from a single image + +Options: + -c, --config-path FILENAME Dataset configuration file in YAML [required] + -m, --model-path PATH PyTorch model to be loaded [required] + -i, --image-path PATH Image to be processed [required] + --cuda / --cpu Enable CUDA if available [default: --cuda] + --crf CRF post-processing [default: False] + --help Show this message and exit. ``` -### From a webcam +### Webcam A class of mouseovered pixel is shown in terminal. -```bash -python demo.py live --config config/cocostuff164k.yaml \ - --model-path \ - --camera-id \ - --crf +``` +Usage: demo.py live [OPTIONS] + + Inference from camera stream + +Options: + -c, --config-path FILENAME Dataset configuration file in YAML [required] + -m, --model-path PATH PyTorch model to be loaded [required] + --cuda / --cpu Enable CUDA if available [default: --cuda] + --crf CRF post-processing [default: False] + --camera-id INTEGER Device ID [default: 0] + --help Show this message and exit. ``` -### torch.hub +### torch.hub + Model setup with 3 lines. ```python import torch.hub - model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=182) -model.load_state_dict(torch.load("cocostuff164k_iter100k.pth")) +model.load_state_dict(torch.load("deeplabv2_resnet101_msc-cocostuff164k-100000.pth")) ``` ## Misc -### Image processing +### Difference with Caffe version -Default setting warps an image and a label to square-shape as the official code does. -To preserve aspect ratio in the image preprocessing, please modify ```.yaml``` as follows: +* While the official code employs 1/16 bilinear interpolation (```Interp``` layer) for downsampling a label for only 0.5x input, this codebase does for both 0.5x and 0.75x inputs with nearest interpolation (```PIL.Image.resize```, [related issue](https://github.com/kazuto1011/deeplab-pytorch/issues/51)). +* Bilinear interpolation on images and logits is performed with the ```align_corners=False```. -```yaml -BATCH_SIZE: - TEST: 1 -WARP_IMAGE: False -``` +### Training batch normalization -### Training batch normalization -This codebase only supports DeepLab v2 training which freezes batch normalization layers, although v3/v3+ protocols require training them. If training their parameters as well in your projects, please install [the extra library](https://hangzhang.org/PyTorch-Encoding/) below. +This codebase only supports DeepLab v2 training which freezes batch normalization layers, although +v3/v3+ protocols require training them. If training their parameters on multiple GPUs as well in your projects, please +install [the extra library](https://hangzhang.org/PyTorch-Encoding/) below. ```bash pip install torch-encoding @@ -256,8 +349,17 @@ except: ## References -1. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A. L. Yuille. DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs. *IEEE TPAMI*, 2018.
-[Project](http://liangchiehchen.com/projects/DeepLab.html) / [Code](https://bitbucket.org/aquariusjay/deeplab-public-ver2) / [arXiv paper](https://arxiv.org/abs/1606.00915) +1. L.-C. Chen, G. Papandreou, I. Kokkinos, K. Murphy, A. L. Yuille. DeepLab: Semantic Image +Segmentation with Deep Convolutional Nets, Atrous Convolution, and Fully Connected CRFs. *IEEE TPAMI*, +2018.
+[Project](http://liangchiehchen.com/projects/DeepLab.html) / +[Code](https://bitbucket.org/aquariusjay/deeplab-public-ver2) / [arXiv +paper](https://arxiv.org/abs/1606.00915) 2. H. Caesar, J. Uijlings, V. Ferrari. COCO-Stuff: Thing and Stuff Classes in Context. In *CVPR*, 2018.
-[Project](https://github.com/nightrome/cocostuff) / [Code](https://github.com/nightrome/cocostuff) / [arXiv paper](https://arxiv.org/abs/1612.03716) \ No newline at end of file +[Project](https://github.com/nightrome/cocostuff) / [arXiv paper](https://arxiv.org/abs/1612.03716) + +1. M. Everingham, L. Van Gool, C. K. I. Williams, J. Winn, A. Zisserman. The PASCAL Visual Object +Classes (VOC) Challenge. *IJCV*, 2010.
+[Project](http://host.robots.ox.ac.uk/pascal/VOC) / +[Paper](http://host.robots.ox.ac.uk/pascal/VOC/pubs/everingham10.pdf) \ No newline at end of file diff --git a/config/README.md b/config/README.md deleted file mode 100644 index 59f2ef3..0000000 --- a/config/README.md +++ /dev/null @@ -1,11 +0,0 @@ -## ```cocostuff10k.yaml``` - -COCO-Stuff 10k. See the main [```README.md```](../README.md#default-settings). - -## ```cocostuff164k.yaml``` - -COCO-Stuff 164k. See the main [```README.md```](../README.md#default-settings). - -## ```voc12.yaml``` - -PASCAL VOC2012. If you want to try the official voc12 caffemodel, please convert with the ```convert.py --dataset voc12``` and use this configuration file. \ No newline at end of file diff --git a/config/conda_env.yaml b/config/conda_env.yaml deleted file mode 100644 index 826a12f..0000000 --- a/config/conda_env.yaml +++ /dev/null @@ -1,25 +0,0 @@ -name: deeplab-pytorch -channels: - - pytorch - - conda-forge - - defaults -dependencies: - - pytorch - - torchvision - # - cudatoolkit=10.0 - - h5py - - scipy - - matplotlib - - pyyaml - - click - - tqdm - # - clang - # - clangxx - - pydensecrf - - pip: - - torchnet==0.0.2 - - opencv-python==3.4.1.15 - - tensorboardX==1.2 - - tensorflow==1.9.0 - - addict - - joblib \ No newline at end of file diff --git a/configs/coco.yaml b/configs/coco.yaml new file mode 100644 index 0000000..b358bd6 --- /dev/null +++ b/configs/coco.yaml @@ -0,0 +1,58 @@ +EXP: + ID: coco + OUTPUT_DIR: data + +DATASET: + NAME: coco + ROOT: + LABELS: ./data/datasets/coco/labels.txt + N_CLASSES: 91 + IGNORE_LABEL: + SCALES: + SPLIT: + TRAIN: + VAL: + TEST: + +DATALOADER: + NUM_WORKERS: 0 + +IMAGE: + MEAN: + R: 122.675 + G: 116.669 + B: 104.008 + SIZE: + BASE: + TRAIN: + TEST: 513 + +MODEL: + NAME: DeepLabV1_ResNet101 + N_BLOCKS: [3, 4, 23, 3] + ATROUS_RATES: + INIT_MODEL: + +SOLVER: + BATCH_SIZE: + TRAIN: 5 + TEST: 1 + ITER_MAX: 100000 + ITER_SIZE: 2 + ITER_SAVE: 5000 + ITER_TB: 100 + LR_DECAY: 10 + LR: 2.5e-4 + MOMENTUM: 0.9 + OPTIMIZER: sgd + POLY_POWER: 0.9 + WEIGHT_DECAY: 5.0e-4 + AVERAGE_LOSS: 20 + +CRF: + ITER_MAX: 10 + POS_W: 3 + POS_XY_STD: 1 + BI_W: 4 + BI_XY_STD: 67 + BI_RGB_STD: 3 diff --git a/config/cocostuff10k.yaml b/configs/cocostuff10k.yaml similarity index 67% rename from config/cocostuff10k.yaml rename to configs/cocostuff10k.yaml index c6686e6..d798908 100644 --- a/config/cocostuff10k.yaml +++ b/configs/cocostuff10k.yaml @@ -1,10 +1,13 @@ +EXP: + ID: cocostuff10k + OUTPUT_DIR: data + DATASET: NAME: cocostuff10k ROOT: /media/kazuto1011/Extra/cocostuff/cocostuff-10k-v1.1 - LABELS: ./data/datasets/cocostuff/labels_2.txt - N_CLASSES: 182 - IGNORE_LABEL: -1 - WARP_IMAGE: True + LABELS: ./data/datasets/cocostuff/labels.txt + N_CLASSES: 182 + IGNORE_LABEL: 255 SCALES: [0.5, 0.75, 1.0, 1.25, 1.5] SPLIT: TRAIN: train @@ -20,17 +23,15 @@ IMAGE: G: 116.669 B: 104.008 SIZE: - TRAIN: - BASE: 513 - CROP: 321 + BASE: + TRAIN: 321 TEST: 513 MODEL: NAME: DeepLabV2_ResNet101_MSC N_BLOCKS: [3, 4, 23, 3] ATROUS_RATES: [6, 12, 18, 24] - INIT_MODEL: ./data/models/deeplab_resnet101/coco_init/deeplabv2_resnet101_COCO_init.pth - SAVE_DIR: ./data/models/deeplab_resnet101/cocostuff10k + INIT_MODEL: data/models/coco/deeplabv1_resnet101/caffemodel/deeplabv1_resnet101-coco.pth SOLVER: BATCH_SIZE: @@ -46,7 +47,6 @@ SOLVER: OPTIMIZER: sgd POLY_POWER: 0.9 WEIGHT_DECAY: 5.0e-4 - LOG_DIR: runs/cocostuff10k AVERAGE_LOSS: 20 CRF: @@ -55,4 +55,4 @@ CRF: POS_XY_STD: 1 BI_W: 4 BI_XY_STD: 67 - BI_RGB_STD: 3 \ No newline at end of file + BI_RGB_STD: 3 diff --git a/config/cocostuff164k.yaml b/configs/cocostuff164k.yaml similarity index 70% rename from config/cocostuff164k.yaml rename to configs/cocostuff164k.yaml index 7e4a788..7e6fb45 100644 --- a/config/cocostuff164k.yaml +++ b/configs/cocostuff164k.yaml @@ -1,10 +1,13 @@ +EXP: + ID: cocostuff164k + OUTPUT_DIR: data + DATASET: NAME: cocostuff164k ROOT: /media/kazuto1011/Extra/cocostuff/cocostuff-164k - LABELS: ./data/datasets/cocostuff/labels_2.txt + LABELS: ./data/datasets/cocostuff/labels.txt N_CLASSES: 182 IGNORE_LABEL: 255 - WARP_IMAGE: True SCALES: [0.5, 0.75, 1.0, 1.25, 1.5] SPLIT: TRAIN: train2017 @@ -20,22 +23,20 @@ IMAGE: G: 116.669 B: 104.008 SIZE: - TRAIN: - BASE: 513 - CROP: 321 + BASE: # None + TRAIN: 321 TEST: 513 MODEL: NAME: DeepLabV2_ResNet101_MSC N_BLOCKS: [3, 4, 23, 3] ATROUS_RATES: [6, 12, 18, 24] - INIT_MODEL: ./data/models/deeplab_resnet101/coco_init/deeplabv2_resnet101_COCO_init.pth - SAVE_DIR: ./data/models/deeplab_resnet101/cocostuff164k + INIT_MODEL: data/models/coco/deeplabv1_resnet101/caffemodel/deeplabv1_resnet101-coco.pth SOLVER: BATCH_SIZE: TRAIN: 5 - TEST: 5 + TEST: 1 ITER_MAX: 100000 ITER_SIZE: 2 ITER_SAVE: 5000 @@ -46,7 +47,6 @@ SOLVER: OPTIMIZER: sgd POLY_POWER: 0.9 WEIGHT_DECAY: 5.0e-4 - LOG_DIR: runs/cocostuff164k AVERAGE_LOSS: 20 CRF: @@ -55,4 +55,4 @@ CRF: POS_XY_STD: 1 BI_W: 4 BI_XY_STD: 67 - BI_RGB_STD: 3 \ No newline at end of file + BI_RGB_STD: 3 diff --git a/configs/conda_env.yaml b/configs/conda_env.yaml new file mode 100644 index 0000000..fe23345 --- /dev/null +++ b/configs/conda_env.yaml @@ -0,0 +1,26 @@ +name: deeplab-pytorch +channels: + - pytorch + - conda-forge + - defaults +dependencies: + # - clang # For MacOS + # - clangxx # For MacOS + - click + - cudatoolkit=10.0 + - cython + - matplotlib + - pytorch + - pyyaml + - scipy + - torchvision + - tqdm + - pip: + - addict + - black + - joblib + - opencv-python + - pydensecrf + - tensorboardX + - tensorflow + - torchnet diff --git a/config/voc12.yaml b/configs/voc12.yaml similarity index 67% rename from config/voc12.yaml rename to configs/voc12.yaml index 310cd7f..8dcef1b 100644 --- a/config/voc12.yaml +++ b/configs/voc12.yaml @@ -1,15 +1,18 @@ +EXP: + ID: voc12 + OUTPUT_DIR: data + DATASET: - NAME: voc12 - ROOT: + NAME: vocaug + ROOT: /media/kazuto1011/Extra/VOCdevkit LABELS: ./data/datasets/voc12/labels.txt N_CLASSES: 21 IGNORE_LABEL: 255 - WARP_IMAGE: False SCALES: [0.5, 0.75, 1.0, 1.25, 1.5] SPLIT: - TRAIN: - VAL: - TEST: + TRAIN: train_aug + VAL: val + TEST: test DATALOADER: NUM_WORKERS: 0 @@ -20,22 +23,20 @@ IMAGE: G: 116.669 B: 104.008 SIZE: - TRAIN: - BASE: 513 - CROP: 321 + BASE: # None + TRAIN: 321 TEST: 513 MODEL: NAME: DeepLabV2_ResNet101_MSC N_BLOCKS: [3, 4, 23, 3] ATROUS_RATES: [6, 12, 18, 24] - INIT_MODEL: ./data/models/deeplab_resnet101/coco_init/deeplabv2_resnet101_COCO_init.pth - SAVE_DIR: ./data/models/deeplab_resnet101/voc12 + INIT_MODEL: data/models/coco/deeplabv1_resnet101/caffemodel/deeplabv1_resnet101-coco.pth SOLVER: BATCH_SIZE: TRAIN: 5 - TEST: 5 + TEST: 1 ITER_MAX: 20000 ITER_SIZE: 2 ITER_SAVE: 5000 @@ -46,7 +47,6 @@ SOLVER: OPTIMIZER: sgd POLY_POWER: 0.9 WEIGHT_DECAY: 5.0e-4 - LOG_DIR: runs/voc12 AVERAGE_LOSS: 20 CRF: @@ -55,4 +55,4 @@ CRF: POS_XY_STD: 1 BI_W: 4 BI_XY_STD: 67 - BI_RGB_STD: 3 \ No newline at end of file + BI_RGB_STD: 3 diff --git a/convert.py b/convert.py index 31ae01c..7812fec 100644 --- a/convert.py +++ b/convert.py @@ -8,14 +8,16 @@ from __future__ import absolute_import, division, print_function import re -from collections import OrderedDict +import traceback +from collections import Counter, OrderedDict import click import numpy as np import torch +from addict import Dict from libs import caffe_pb2 -from libs.models import DeepLabV2_ResNet101_MSC +from libs.models import DeepLabV1_ResNet101, DeepLabV2_ResNet101_MSC def parse_caffemodel(model_path): @@ -25,18 +27,24 @@ def parse_caffemodel(model_path): # Check trainable layers print( - *set([(layer.type, len(layer.blobs)) for layer in caffemodel.layer]), sep="\n" + *Counter( + [(layer.type, len(layer.blobs)) for layer in caffemodel.layer] + ).most_common(), + sep="\n", ) params = OrderedDict() previous_layer_type = None for layer in caffemodel.layer: - print("{} ({}): {}".format(layer.name, layer.type, len(layer.blobs))) - # Skip the shared branch if "res075" in layer.name or "res05" in layer.name: continue + print( + "\033[34m[Caffe]\033[00m", + "{} ({}): {}".format(layer.name, layer.type, len(layer.blobs)), + ) + # Convolution or Dilated Convolution if "Convolution" in layer.type: params[layer.name] = {} @@ -58,6 +66,12 @@ def parse_caffemodel(model_path): params[layer.name]["dilation"] = layer.convolution_param.dilation[0] else: params[layer.name]["dilation"] = 1 + # Fully-connected + elif "InnerProduct" in layer.type: + params[layer.name] = {} + params[layer.name]["weight"] = list(layer.blobs[0].data) + if len(layer.blobs) == 2: + params[layer.name]["bias"] = list(layer.blobs[1].data) # Batch Normalization elif "BatchNorm" in layer.type: params[layer.name] = {} @@ -71,12 +85,18 @@ def parse_caffemodel(model_path): params[layer.name][ "momentum" ] = layer.batch_norm_param.moving_average_fraction + params[layer.name]["num_batches_tracked"] = np.array(0) batch_norm_layer = layer.name # Scale elif "Scale" in layer.type: assert previous_layer_type == "BatchNorm" params[batch_norm_layer]["weight"] = list(layer.blobs[0].data) params[batch_norm_layer]["bias"] = list(layer.blobs[1].data) + elif "Pooling" in layer.type: + params[layer.name] = {} + params[layer.name]["kernel_size"] = layer.pooling_param.kernel_size + params[layer.name]["stride"] = layer.pooling_param.stride + params[layer.name]["padding"] = layer.pooling_param.pad previous_layer_type = layer.type @@ -84,9 +104,9 @@ def parse_caffemodel(model_path): # Hard coded translater -def translate_layer_name(source): +def translate_layer_name(source, target="base"): def layer_block_branch(source, target): - target += ".layer{}".format(source[0][0]) + target += "layer{}".format(source[0][0]) if len(source[0][1:]) == 1: block = {"a": 1, "b": 2, "c": 3}.get(source[0][1:]) else: @@ -104,12 +124,19 @@ def layer_block_branch(source, target): return target source = source.split("_") - target = "base" - if "conv1" in source[0]: - target += ".layer1.conv1.conv" + if "pool" in source[0]: + target += "layer1.pool" + elif "fc" in source[0]: + if len(source) == 3: + stage = source[2] + target += "aspp.{}".format(stage) + else: + target += "fc" + elif "conv1" in source[0]: + target += "layer1.conv1.conv" elif "conv1" in source[1]: - target += ".layer1.conv1.bn" + target += "layer1.conv1.bn" elif "res" in source[0]: source[0] = source[0].replace("res", "") target = layer_block_branch(source, target) @@ -118,92 +145,108 @@ def layer_block_branch(source, target): source[0] = source[0].replace("bn", "") target = layer_block_branch(source, target) target += ".bn" - elif "fc" in source[0]: - # Skip if coco_init - if len(source) == 3: - stage = source[2] - target += ".aspp.{}".format(stage) return target @click.command() -@click.option("-d", "--dataset", required=True, type=click.Choice(["voc12", "init"])) +@click.option( + "-d", + "--dataset", + type=click.Choice(["voc12", "coco"]), + required=True, + help="Caffemodel", +) def main(dataset): + """ + Convert caffemodels to pytorch models + """ + WHITELIST = ["kernel_size", "stride", "padding", "dilation", "eps", "momentum"] - CONFIG = { - "voc12": { - "path_caffe_model": "data/models/deeplab_resnet101/voc12/train2_iter_20000.caffemodel", - "path_pytorch_model": "data/models/deeplab_resnet101/voc12/deeplabv2_resnet101_VOC2012.pth", - "n_classes": 21, - }, - "init": { - "path_caffe_model": "data/models/deeplab_resnet101/coco_init/init.caffemodel", - "path_pytorch_model": "data/models/deeplab_resnet101/coco_init/deeplabv2_resnet101_COCO_init.pth", - "n_classes": 91, - }, - }.get(dataset) - - params = parse_caffemodel(CONFIG["path_caffe_model"]) - - model = DeepLabV2_ResNet101_MSC(n_classes=CONFIG["n_classes"]) + CONFIG = Dict( + { + "voc12": { + # For loading the provided VOC 2012 caffemodel + "PATH_CAFFE_MODEL": "data/models/voc12/deeplabv2_resnet101_msc/caffemodel/train2_iter_20000.caffemodel", + "PATH_PYTORCH_MODEL": "data/models/voc12/deeplabv2_resnet101_msc/caffemodel/deeplabv2_resnet101_msc-vocaug.pth", + "N_CLASSES": 21, + "MODEL": "DeepLabV2_ResNet101_MSC", + "HEAD": "base.", + }, + "coco": { + # For loading the provided initial weights pre-trained on COCO + "PATH_CAFFE_MODEL": "data/models/coco/deeplabv1_resnet101/caffemodel/init.caffemodel", + "PATH_PYTORCH_MODEL": "data/models/coco/deeplabv1_resnet101/caffemodel/deeplabv1_resnet101-coco.pth", + "N_CLASSES": 91, + "MODEL": "DeepLabV1_ResNet101", + "HEAD": "", + }, + }.get(dataset) + ) + + params = parse_caffemodel(CONFIG.PATH_CAFFE_MODEL) + + model = eval(CONFIG.MODEL)(n_classes=CONFIG.N_CLASSES) model.eval() - own_state = model.state_dict() + reference_state_dict = model.state_dict() rel_tol = 1e-7 - state_dict = OrderedDict() - for layer_name, layer_dict in params.items(): - for param_name, values in layer_dict.items(): - if param_name in WHITELIST and dataset != "coco_init" and dataset != "init": - attribute = translate_layer_name(layer_name) - attribute = eval("model." + attribute + "." + param_name) - if isinstance(attribute, tuple): - assert ( - attribute[0] == values - ), "Inconsistent values: {}@{}, {}@{}".format( - attribute[0], - translate_layer_name(layer_name) + "." + param_name, - values, - layer_name, + converted_state_dict = OrderedDict() + for caffe_layer, caffe_layer_dict in params.items(): + for param_name, caffe_values in caffe_layer_dict.items(): + pytorch_layer = translate_layer_name(caffe_layer, CONFIG.HEAD) + if pytorch_layer: + pytorch_param = pytorch_layer + "." + param_name + + # Parameter check + if param_name in WHITELIST: + pytorch_values = eval("model." + pytorch_param) + if isinstance(pytorch_values, tuple): + assert ( + pytorch_values[0] == caffe_values + ), "Inconsistent values: {} @{} (Caffe), {} @{} (PyTorch)".format( + caffe_values, + caffe_layer + "/" + param_name, + pytorch_values, + pytorch_param, + ) + else: + assert ( + abs(pytorch_values - caffe_values) < rel_tol + ), "Inconsistent values: {} @{} (Caffe), {} @{} (PyTorch)".format( + caffe_values, + caffe_layer + "/" + param_name, + pytorch_values, + pytorch_param, + ) + print( + "\033[34m[Passed!]\033[00m", + (caffe_layer + "/" + param_name).ljust(35), + "->", + pytorch_param, + ) + continue + + # Weight conversion + if pytorch_param in reference_state_dict: + caffe_values = torch.tensor(caffe_values) + caffe_values = caffe_values.view_as( + reference_state_dict[pytorch_param] ) - else: - assert ( - abs(attribute - values) < rel_tol - ), "Inconsistent values: {}@{}, {}@{}".format( - attribute, - translate_layer_name(layer_name) + "." + param_name, - values, - layer_name, + converted_state_dict[pytorch_param] = caffe_values + print( + "\033[32m[Copied!]\033[00m", + (caffe_layer + "/" + param_name).ljust(35), + "->", + pytorch_param, ) - print( - layer_name.ljust(20), - "->", - param_name, - attribute, - values, - ": Checked!", - ) - continue - param_name = translate_layer_name(layer_name) + "." + param_name - if param_name in own_state: - values = torch.FloatTensor(values) - values = values.view_as(own_state[param_name]) - state_dict[param_name] = values - print(layer_name.ljust(20), "->", param_name, ": Copied!") - - try: - print("\033[32mVerify the converted model\033[00m") - model.load_state_dict(state_dict) - except: - import traceback - - traceback.print_exc() - print("\033[32mVerify with ignoring ASPP (strict=False)\033[00m") - model.load_state_dict(state_dict, strict=False) - - print("Saving to", CONFIG["path_pytorch_model"]) - torch.save(state_dict, CONFIG["path_pytorch_model"]) + + print("\033[32mVerify the converted model\033[00m") + model.load_state_dict(converted_state_dict) + + print('Saving to "{}"'.format(CONFIG.PATH_PYTORCH_MODEL)) + torch.save(converted_state_dict, CONFIG.PATH_PYTORCH_MODEL) if __name__ == "__main__": diff --git a/data/models/.gitignore b/data/.gitignore similarity index 56% rename from data/models/.gitignore rename to data/.gitignore index e679bd5..c73f4b7 100644 --- a/data/models/.gitignore +++ b/data/.gitignore @@ -1,4 +1,6 @@ *.pth *.pth.tar *.caffemodel -*.npy \ No newline at end of file +*.npy +*.prototxt +*.zip \ No newline at end of file diff --git a/data/datasets/coco/labels.txt b/data/datasets/coco/labels.txt new file mode 100644 index 0000000..83604e0 --- /dev/null +++ b/data/datasets/coco/labels.txt @@ -0,0 +1,91 @@ +0 background +1 person +2 bicycle +3 car +4 motorcycle +5 airplane +6 bus +7 train +8 truck +9 boat +10 traffic light +11 fire hydrant +12 street sign +13 stop sign +14 parking meter +15 bench +16 bird +17 cat +18 dog +19 horse +20 sheep +21 cow +22 elephant +23 bear +24 zebra +25 giraffe +26 hat +27 backpack +28 umbrella +29 shoe +30 eye glasses +31 handbag +32 tie +33 suitcase +34 frisbee +35 skis +36 snowboard +37 sports ball +38 kite +39 baseball bat +40 baseball glove +41 skateboard +42 surfboard +43 tennis racket +44 bottle +45 plate +46 wine glass +47 cup +48 fork +49 knife +50 spoon +51 bowl +52 banana +53 apple +54 sandwich +55 orange +56 broccoli +57 carrot +58 hot dog +59 pizza +60 donut +61 cake +62 chair +63 couch +64 potted plant +65 bed +66 mirror +67 dining table +68 window +69 desk +70 toilet +71 door +72 tv +73 laptop +74 mouse +75 remote +76 keyboard +77 cell phone +78 microwave +79 oven +80 toaster +81 sink +82 refrigerator +83 blender +84 book +85 clock +86 vase +87 scissors +88 teddy bear +89 hair drier +90 toothbrush \ No newline at end of file diff --git a/data/datasets/cocostuff/README.md b/data/datasets/cocostuff/README.md new file mode 100644 index 0000000..f748afc --- /dev/null +++ b/data/datasets/cocostuff/README.md @@ -0,0 +1,76 @@ +# COCO-Stuff + +This is an instruction for setting up COCO-Stuff dataset. +COCO-Stuff 164k is the latest version and recommended. + +![](../../../docs/datasets/cocostuff.png) + +## COCO-Stuff 164k + +### Setup + +1. Run the script below to download the dataset (20GB+). + +```sh +$ bash ./scripts/setup_cocostuff164k.sh [PATH TO DOWNLOAD] +``` + +2. Set the path to the dataset in ```configs/cocostuff164k.yaml```. + +```yaml +DATASET: cocostuff164k + ROOT: # <- Write here +... +``` + +### Dataset structure + +``` +├── images +│ ├── train2017 +│ │ ├── 000000000009.jpg +│ │ └── ... +│ └── val2017 +│ ├── 000000000139.jpg +│ └── ... +└── annotations + ├── train2017 + │ ├── 000000000009.png + │ └── ... + └── val2017 + ├── 000000000139.png + └── ... +``` + +## COCO-Stuff 10k + +### Setup + +1. Run the script below to download the dataset (2GB). + +```sh +$ bash ./scripts/setup_cocostuff10k.sh [PATH TO DOWNLOAD] +``` + +2. Set the path to the dataset in ```configs/cocostuff10k.yaml```. + +```yaml +DATASET: cocostuff10k + ROOT: # <- Write here +... +``` + +### Dataset structure + +``` +├── images +│ ├── COCO_train2014_000000000077.jpg +│ └── ... +├── annotations +│ ├── COCO_train2014_000000000077.mat +│ └── ... +└── imageLists + ├── all.txt + ├── test.txt + └── train.txt +``` diff --git a/data/datasets/cocostuff/labels.txt b/data/datasets/cocostuff/labels.txt index c5d66fd..cf9f484 100644 --- a/data/datasets/cocostuff/labels.txt +++ b/data/datasets/cocostuff/labels.txt @@ -1,183 +1,182 @@ -0 unlabeled -1 person -2 bicycle -3 car -4 motorcycle -5 airplane -6 bus -7 train -8 truck -9 boat -10 traffic light -11 fire hydrant -12 street sign -13 stop sign -14 parking meter -15 bench -16 bird -17 cat -18 dog -19 horse -20 sheep -21 cow -22 elephant -23 bear -24 zebra -25 giraffe -26 hat -27 backpack -28 umbrella -29 shoe -30 eye glasses -31 handbag -32 tie -33 suitcase -34 frisbee -35 skis -36 snowboard -37 sports ball -38 kite -39 baseball bat -40 baseball glove -41 skateboard -42 surfboard -43 tennis racket -44 bottle -45 plate -46 wine glass -47 cup -48 fork -49 knife -50 spoon -51 bowl -52 banana -53 apple -54 sandwich -55 orange -56 broccoli -57 carrot -58 hot dog -59 pizza -60 donut -61 cake -62 chair -63 couch -64 potted plant -65 bed -66 mirror -67 dining table -68 window -69 desk -70 toilet -71 door -72 tv -73 laptop -74 mouse -75 remote -76 keyboard -77 cell phone -78 microwave -79 oven -80 toaster -81 sink -82 refrigerator -83 blender -84 book -85 clock -86 vase -87 scissors -88 teddy bear -89 hair drier -90 toothbrush -91 hair brush -92 banner -93 blanket -94 branch -95 bridge -96 building-other -97 bush -98 cabinet -99 cage -100 cardboard -101 carpet -102 ceiling-other -103 ceiling-tile -104 cloth -105 clothes -106 clouds -107 counter -108 cupboard -109 curtain -110 desk-stuff -111 dirt -112 door-stuff -113 fence -114 floor-marble -115 floor-other -116 floor-stone -117 floor-tile -118 floor-wood -119 flower -120 fog -121 food-other -122 fruit -123 furniture-other -124 grass -125 gravel -126 ground-other -127 hill -128 house -129 leaves -130 light -131 mat -132 metal -133 mirror-stuff -134 moss -135 mountain -136 mud -137 napkin -138 net -139 paper -140 pavement -141 pillow -142 plant-other -143 plastic -144 platform -145 playingfield -146 railing -147 railroad -148 river -149 road -150 rock -151 roof -152 rug -153 salad -154 sand -155 sea -156 shelf -157 sky-other -158 skyscraper -159 snow -160 solid-other -161 stairs -162 stone -163 straw -164 structural-other -165 table -166 tent -167 textile-other -168 towel -169 tree -170 vegetable -171 wall-brick -172 wall-concrete -173 wall-other -174 wall-panel -175 wall-stone -176 wall-tile -177 wall-wood -178 water-other -179 waterdrops -180 window-blind -181 window-other -182 wood \ No newline at end of file +0 person +1 bicycle +2 car +3 motorcycle +4 airplane +5 bus +6 train +7 truck +8 boat +9 traffic light +10 fire hydrant +11 street sign +12 stop sign +13 parking meter +14 bench +15 bird +16 cat +17 dog +18 horse +19 sheep +20 cow +21 elephant +22 bear +23 zebra +24 giraffe +25 hat +26 backpack +27 umbrella +28 shoe +29 eye glasses +30 handbag +31 tie +32 suitcase +33 frisbee +34 skis +35 snowboard +36 sports ball +37 kite +38 baseball bat +39 baseball glove +40 skateboard +41 surfboard +42 tennis racket +43 bottle +44 plate +45 wine glass +46 cup +47 fork +48 knife +49 spoon +50 bowl +51 banana +52 apple +53 sandwich +54 orange +55 broccoli +56 carrot +57 hot dog +58 pizza +59 donut +60 cake +61 chair +62 couch +63 potted plant +64 bed +65 mirror +66 dining table +67 window +68 desk +69 toilet +70 door +71 tv +72 laptop +73 mouse +74 remote +75 keyboard +76 cell phone +77 microwave +78 oven +79 toaster +80 sink +81 refrigerator +82 blender +83 book +84 clock +85 vase +86 scissors +87 teddy bear +88 hair drier +89 toothbrush +90 hair brush +91 banner +92 blanket +93 branch +94 bridge +95 building-other +96 bush +97 cabinet +98 cage +99 cardboard +100 carpet +101 ceiling-other +102 ceiling-tile +103 cloth +104 clothes +105 clouds +106 counter +107 cupboard +108 curtain +109 desk-stuff +110 dirt +111 door-stuff +112 fence +113 floor-marble +114 floor-other +115 floor-stone +116 floor-tile +117 floor-wood +118 flower +119 fog +120 food-other +121 fruit +122 furniture-other +123 grass +124 gravel +125 ground-other +126 hill +127 house +128 leaves +129 light +130 mat +131 metal +132 mirror-stuff +133 moss +134 mountain +135 mud +136 napkin +137 net +138 paper +139 pavement +140 pillow +141 plant-other +142 plastic +143 platform +144 playingfield +145 railing +146 railroad +147 river +148 road +149 rock +150 roof +151 rug +152 salad +153 sand +154 sea +155 shelf +156 sky-other +157 skyscraper +158 snow +159 solid-other +160 stairs +161 stone +162 straw +163 structural-other +164 table +165 tent +166 textile-other +167 towel +168 tree +169 vegetable +170 wall-brick +171 wall-concrete +172 wall-other +173 wall-panel +174 wall-stone +175 wall-tile +176 wall-wood +177 water-other +178 waterdrops +179 window-blind +180 window-other +181 wood \ No newline at end of file diff --git a/data/datasets/cocostuff/labels_2.txt b/data/datasets/cocostuff/labels_2.txt deleted file mode 100644 index cf9f484..0000000 --- a/data/datasets/cocostuff/labels_2.txt +++ /dev/null @@ -1,182 +0,0 @@ -0 person -1 bicycle -2 car -3 motorcycle -4 airplane -5 bus -6 train -7 truck -8 boat -9 traffic light -10 fire hydrant -11 street sign -12 stop sign -13 parking meter -14 bench -15 bird -16 cat -17 dog -18 horse -19 sheep -20 cow -21 elephant -22 bear -23 zebra -24 giraffe -25 hat -26 backpack -27 umbrella -28 shoe -29 eye glasses -30 handbag -31 tie -32 suitcase -33 frisbee -34 skis -35 snowboard -36 sports ball -37 kite -38 baseball bat -39 baseball glove -40 skateboard -41 surfboard -42 tennis racket -43 bottle -44 plate -45 wine glass -46 cup -47 fork -48 knife -49 spoon -50 bowl -51 banana -52 apple -53 sandwich -54 orange -55 broccoli -56 carrot -57 hot dog -58 pizza -59 donut -60 cake -61 chair -62 couch -63 potted plant -64 bed -65 mirror -66 dining table -67 window -68 desk -69 toilet -70 door -71 tv -72 laptop -73 mouse -74 remote -75 keyboard -76 cell phone -77 microwave -78 oven -79 toaster -80 sink -81 refrigerator -82 blender -83 book -84 clock -85 vase -86 scissors -87 teddy bear -88 hair drier -89 toothbrush -90 hair brush -91 banner -92 blanket -93 branch -94 bridge -95 building-other -96 bush -97 cabinet -98 cage -99 cardboard -100 carpet -101 ceiling-other -102 ceiling-tile -103 cloth -104 clothes -105 clouds -106 counter -107 cupboard -108 curtain -109 desk-stuff -110 dirt -111 door-stuff -112 fence -113 floor-marble -114 floor-other -115 floor-stone -116 floor-tile -117 floor-wood -118 flower -119 fog -120 food-other -121 fruit -122 furniture-other -123 grass -124 gravel -125 ground-other -126 hill -127 house -128 leaves -129 light -130 mat -131 metal -132 mirror-stuff -133 moss -134 mountain -135 mud -136 napkin -137 net -138 paper -139 pavement -140 pillow -141 plant-other -142 plastic -143 platform -144 playingfield -145 railing -146 railroad -147 river -148 road -149 rock -150 roof -151 rug -152 salad -153 sand -154 sea -155 shelf -156 sky-other -157 skyscraper -158 snow -159 solid-other -160 stairs -161 stone -162 straw -163 structural-other -164 table -165 tent -166 textile-other -167 towel -168 tree -169 vegetable -170 wall-brick -171 wall-concrete -172 wall-other -173 wall-panel -174 wall-stone -175 wall-tile -176 wall-wood -177 water-other -178 waterdrops -179 window-blind -180 window-other -181 wood \ No newline at end of file diff --git a/data/datasets/voc12/README.md b/data/datasets/voc12/README.md new file mode 100644 index 0000000..ec44ea5 --- /dev/null +++ b/data/datasets/voc12/README.md @@ -0,0 +1,60 @@ +# PASCAL VOC 2012 + +This is an instruction for setting up PASCAL VOC dataset. + +![](../../../docs/datasets/voc12.png) + +1. Download PASCAL VOC 2012. + +```sh +$ bash scripts/setup_voc12.sh [PATH TO DOWNLOAD] +``` + +``` +/VOCdevkit +└── VOC2012 + ├── Annotations + ├── ImageSets + │ └── Segmentation + ├── JPEGImages + ├── SegmentationObject + └── SegmentationClass +``` + +2. Add SBD augmentated training data as `SegmentationClassAug`. + + +* Convert by yourself ([here](https://github.com/shelhamer/fcn.berkeleyvision.org/tree/master/data/pascal)). +* Or download pre-converted files ([here](https://github.com/DrSleep/tensorflow-deeplab-resnet#evaluation)). + +3. Download official image sets as `ImageSets/SegmentationAug`. + +* From https://ucla.app.box.com/s/rd9z2xvwsfpksi7mi08i2xqrj7ab4keb/file/55053033642 + +```sh +/VOCdevkit +└── VOC2012 + ├── Annotations + ├── ImageSets + │ ├── Segmentation + │ └── SegmentationAug # ADDED!! + │ ├── test.txt + │ ├── train_aug.txt + │ ├── train.txt + │ ├── trainval_aug.txt + │ ├── trainval.txt + │ └── val.txt + ├── JPEGImages + ├── SegmentationObject + ├── SegmentationClass + └── SegmentationClassAug # ADDED!! + └── 2007_000032.png +``` + +1. Set the path to the dataset in ```configs/voc12.yaml```. + +```yaml +DATASET: voc12 + ROOT: # <- Write here +... +``` diff --git a/data/models/deeplab_resnet101/coco_init/.gitkeep b/data/models/coco/deeplabv1_resnet101/caffemodel/.gitkeep similarity index 100% rename from data/models/deeplab_resnet101/coco_init/.gitkeep rename to data/models/coco/deeplabv1_resnet101/caffemodel/.gitkeep diff --git a/data/models/deeplab_resnet101/cocostuff164k/.gitkeep b/data/models/deeplab_resnet101/cocostuff164k/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/data/models/deeplab_resnet101/init/.gitkeep b/data/models/deeplab_resnet101/init/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/data/models/deeplab_resnet101/voc12/.gitkeep b/data/models/deeplab_resnet101/voc12/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/data/models/deeplab_resnet101/cocostuff10k/.gitkeep b/data/models/voc12/deeplabv2_resnet101_msc/caffemodel/.gitkeep similarity index 100% rename from data/models/deeplab_resnet101/cocostuff10k/.gitkeep rename to data/models/voc12/deeplabv2_resnet101_msc/caffemodel/.gitkeep diff --git a/demo.py b/demo.py index ec4d573..efe8fb6 100644 --- a/demo.py +++ b/demo.py @@ -19,7 +19,7 @@ import yaml from addict import Dict -from libs.models import DeepLabV2_ResNet101_MSC +from libs.models import * from libs.utils import DenseCRF @@ -43,15 +43,6 @@ def get_classtable(CONFIG): return classes -def setup_model(model_path, n_classes, device): - model = DeepLabV2_ResNet101_MSC(n_classes=n_classes) - state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) - model.load_state_dict(state_dict) - model.eval() - model.to(device) - return model - - def setup_postprocessor(CONFIG): # CRF post-processor postprocessor = DenseCRF( @@ -89,19 +80,18 @@ def preprocessing(image, device, CONFIG): def inference(model, image, raw_image=None, postprocessor=None): - B, C, H, W = image.shape + _, _, H, W = image.shape # Image -> Probability map logits = model(image) - logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=True) - probs = F.softmax(logits, dim=1) - probs = probs.data.cpu().numpy()[0] + logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=False) + probs = F.softmax(logits, dim=1)[0] + probs = probs.cpu().numpy() # Refine the prob map with CRF if postprocessor and raw_image is not None: probs = postprocessor(raw_image, probs) - # Pixel-wise argmax labelmap = np.argmax(probs, axis=0) return labelmap @@ -110,26 +100,59 @@ def inference(model, image, raw_image=None, postprocessor=None): @click.group() @click.pass_context def main(ctx): + """ + Demo with a trained model + """ + print("Mode:", ctx.invoked_subcommand) -@main.command(help="Inference from a single image") -@click.option("-c", "--config", type=str, required=True, help="yaml") -@click.option("-i", "--image-path", type=str, required=True) -@click.option("-m", "--model-path", type=str, required=True, help="pth") -@click.option("--cuda/--no-cuda", default=True, help="Switch GPU/CPU") -@click.option("--crf", is_flag=True, help="CRF post processing") -def single(config, image_path, model_path, cuda, crf): - # Disable autograd globally - torch.set_grad_enabled(False) +@main.command() +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "-m", + "--model-path", + type=click.Path(exists=True), + required=True, + help="PyTorch model to be loaded", +) +@click.option( + "-i", + "--image-path", + type=click.Path(exists=True), + required=True, + help="Image to be processed", +) +@click.option( + "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]" +) +@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing") +def single(config_path, model_path, image_path, cuda, crf): + """ + Inference from a single image + """ # Setup + CONFIG = Dict(yaml.load(config_path)) device = get_device(cuda) - CONFIG = Dict(yaml.load(open(config))) + torch.set_grad_enabled(False) + classes = get_classtable(CONFIG) - model = setup_model(model_path, CONFIG.DATASET.N_CLASSES, device) postprocessor = setup_postprocessor(CONFIG) if crf else None + model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES) + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(state_dict) + model.eval() + model.to(device) + print("Model:", CONFIG.MODEL.NAME) + # Inference image = cv2.imread(image_path, cv2.IMREAD_COLOR) image, raw_image = preprocessing(image, device, CONFIG) @@ -158,25 +181,47 @@ def single(config, image_path, model_path, cuda, crf): plt.show() -@main.command(help="Inference from camera stream") -@click.option("-c", "--config", type=str, required=True, help="yaml") -@click.option("-m", "--model-path", type=str, required=True, help="pth") -@click.option("--cuda/--no-cuda", default=True, help="Switch GPU/CPU") -@click.option("--crf", is_flag=True, help="CRF post processing") -@click.option("--camera-id", type=int, default=0) -def live(config, model_path, cuda, crf, camera_id): - # Disable autograd globally - torch.set_grad_enabled(False) - # Auto-tune cuDNN - torch.backends.cudnn.benchmark = True +@main.command() +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "-m", + "--model-path", + type=click.Path(exists=True), + required=True, + help="PyTorch model to be loaded", +) +@click.option( + "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]" +) +@click.option("--crf", is_flag=True, show_default=True, help="CRF post-processing") +@click.option("--camera-id", type=int, default=0, show_default=True, help="Device ID") +def live(config_path, model_path, cuda, crf, camera_id): + """ + Inference from camera stream + """ # Setup + CONFIG = Dict(yaml.load(config_path)) device = get_device(cuda) - CONFIG = Dict(yaml.load(open(config))) + torch.set_grad_enabled(False) + torch.backends.cudnn.benchmark = True + classes = get_classtable(CONFIG) - model = setup_model(model_path, CONFIG.DATASET.N_CLASSES, device) postprocessor = setup_postprocessor(CONFIG) if crf else None + model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES) + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(state_dict) + model.eval() + model.to(device) + print("Model:", CONFIG.MODEL.NAME) + # UVC camera stream cap = cv2.VideoCapture(camera_id) cap.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*"YUYV")) @@ -197,7 +242,7 @@ def mouse_event(event, x, y, flags, labelmap): cv2.namedWindow(window_name, cv2.WINDOW_AUTOSIZE) while True: - ret, frame = cap.read() + _, frame = cap.read() image, raw_image = preprocessing(frame, device, CONFIG) labelmap = inference(model, image, raw_image, postprocessor) colormap = colorize(labelmap) diff --git a/docs/datasets/cocostuff.png b/docs/datasets/cocostuff.png new file mode 100644 index 0000000..74c8a04 Binary files /dev/null and b/docs/datasets/cocostuff.png differ diff --git a/docs/data.png b/docs/datasets/voc12.png similarity index 66% rename from docs/data.png rename to docs/datasets/voc12.png index 8bc51c8..290b4a6 100644 Binary files a/docs/data.png and b/docs/datasets/voc12.png differ diff --git a/docs/examples/coco_init.png b/docs/examples/coco_init.png new file mode 100644 index 0000000..e01e3ad Binary files /dev/null and b/docs/examples/coco_init.png differ diff --git a/docs/examples/cocostuff10k.png b/docs/examples/cocostuff10k.png new file mode 100644 index 0000000..ca4d71c Binary files /dev/null and b/docs/examples/cocostuff10k.png differ diff --git a/docs/examples/cocostuff164k.png b/docs/examples/cocostuff164k.png new file mode 100644 index 0000000..d5f1c27 Binary files /dev/null and b/docs/examples/cocostuff164k.png differ diff --git a/docs/examples/voc12.png b/docs/examples/voc12.png new file mode 100644 index 0000000..8f90aed Binary files /dev/null and b/docs/examples/voc12.png differ diff --git a/hubconf.py b/hubconf.py index e38bff0..94f496f 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,6 +5,8 @@ # URL: https://kazuto1011.github.io # Date: 20 December 2018 +from __future__ import print_function + def deeplabv2_resnet101(**kwargs): """ @@ -12,6 +14,12 @@ def deeplabv2_resnet101(**kwargs): n_classes (int): the number of classes """ + if kwargs["pretrained"]: + raise NotImplementedError( + "Please download from " + "https://github.com/kazuto1011/deeplab-pytorch/tree/master#pretrained-models" + ) + from libs.models.deeplabv2 import DeepLabV2 from libs.models.msc import MSC diff --git a/libs/datasets/__init__.py b/libs/datasets/__init__.py index 7aae0d2..314a751 100644 --- a/libs/datasets/__init__.py +++ b/libs/datasets/__init__.py @@ -1,6 +1,11 @@ -from __future__ import absolute_import +from .voc import VOC, VOCAug from .cocostuff import CocoStuff10k, CocoStuff164k def get_dataset(name): - return {"cocostuff10k": CocoStuff10k, "cocostuff164k": CocoStuff164k}[name] + return { + "cocostuff10k": CocoStuff10k, + "cocostuff164k": CocoStuff164k, + "voc": VOC, + "vocaug": VOCAug, + }[name] diff --git a/libs/datasets/base.py b/libs/datasets/base.py new file mode 100644 index 0000000..efb6a8e --- /dev/null +++ b/libs/datasets/base.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python +# coding: utf-8 +# +# Author: Kazuto Nakashima +# URL: http://kazuto1011.github.io +# Created: 2017-10-30 + +import random + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils import data + + +class _BaseDataset(data.Dataset): + """ + Base dataset class + """ + + def __init__( + self, + root, + split, + ignore_label, + mean_bgr, + augment=True, + base_size=None, + crop_size=321, + scales=(1.0), + flip=True, + ): + self.root = root + self.split = split + self.ignore_label = ignore_label + self.mean_bgr = np.array(mean_bgr) + self.augment = augment + self.base_size = base_size + self.crop_size = crop_size + self.scales = scales + self.flip = flip + self.files = [] + self._set_files() + + cv2.setNumThreads(0) + + def _set_files(self): + """ + Create a file path/image id list. + """ + raise NotImplementedError() + + def _load_data(self, image_id): + """ + Load the image and label in numpy.ndarray + """ + raise NotImplementedError() + + def _augmentation(self, image, label): + # Scaling + h, w = label.shape + if self.base_size: + if h > w: + h, w = (self.base_size, int(self.base_size * w / h)) + else: + h, w = (int(self.base_size * h / w), self.base_size) + scale_factor = random.choice(self.scales) + h, w = (int(h * scale_factor), int(w * scale_factor)) + image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR) + label = Image.fromarray(label).resize((w, h), resample=Image.NEAREST) + label = np.asarray(label, dtype=np.int64) + + # Padding to fit for crop_size + h, w = label.shape + pad_h = max(self.crop_size - h, 0) + pad_w = max(self.crop_size - w, 0) + pad_kwargs = { + "top": 0, + "bottom": pad_h, + "left": 0, + "right": pad_w, + "borderType": cv2.BORDER_CONSTANT, + } + if pad_h > 0 or pad_w > 0: + image = cv2.copyMakeBorder(image, value=self.mean_bgr, **pad_kwargs) + label = cv2.copyMakeBorder(label, value=self.ignore_label, **pad_kwargs) + + # Cropping + h, w = label.shape + start_h = random.randint(0, h - self.crop_size) + start_w = random.randint(0, w - self.crop_size) + end_h = start_h + self.crop_size + end_w = start_w + self.crop_size + image = image[start_h:end_h, start_w:end_w] + label = label[start_h:end_h, start_w:end_w] + + if self.flip: + # Random flipping + if random.random() < 0.5: + image = np.fliplr(image).copy() # HWC + label = np.fliplr(label).copy() # HW + return image, label + + def __getitem__(self, index): + image_id, image, label = self._load_data(index) + if self.augment: + image, label = self._augmentation(image, label) + # Mean subtraction + image -= self.mean_bgr + # HWC -> CHW + image = image.transpose(2, 0, 1) + return image_id, image.astype(np.float32), label.astype(np.int64) + + def __len__(self): + return len(self.files) + + def __repr__(self): + fmt_str = "Dataset: " + self.__class__.__name__ + "\n" + fmt_str += " # data: {}\n".format(self.__len__()) + fmt_str += " Split: {}\n".format(self.split) + fmt_str += " Root: {}".format(self.root) + return fmt_str diff --git a/libs/datasets/cocostuff.py b/libs/datasets/cocostuff.py index a7dd313..d132d48 100644 --- a/libs/datasets/cocostuff.py +++ b/libs/datasets/cocostuff.py @@ -5,151 +5,26 @@ # URL: http://kazuto1011.github.io # Created: 2017-10-30 -from __future__ import print_function +from __future__ import absolute_import, print_function -import glob import os.path as osp -import random -from collections import defaultdict from glob import glob import cv2 -import h5py import numpy as np import scipy.io as sio import torch +from PIL import Image from torch.utils import data -from tqdm import tqdm +from .base import _BaseDataset -class _CocoStuff(data.Dataset): - """COCO-Stuff base class""" - def __init__( - self, - root, - split="train", - base_size=513, - crop_size=321, - mean=(104.008, 116.669, 122.675), - scale=(0.5, 0.75, 1.0, 1.25, 1.5), - warp=True, - flip=True, - preload=False, - ): - self.root = root - self.split = split - self.base_size = base_size - self.crop_size = crop_size - self.mean = np.array(mean) - self.scale = scale - self.warp = warp - self.flip = flip - self.preload = preload - - self.files = [] - self.images = [] - self.labels = [] - self.ignore_label = None - - self._set_files() - - if self.preload: - self._preload_data() - - cv2.setNumThreads(0) - - def _set_files(self): - raise NotImplementedError() - - def _transform(self, image, label): - # Mean subtraction - image -= self.mean - # Pre-scaling - if self.warp: - base_size = (self.base_size,) * 2 - else: - raw_h, raw_w = label.shape - if raw_h > raw_w: - base_size = (int(self.base_size * raw_w / raw_h), self.base_size) - else: - base_size = (self.base_size, int(self.base_size * raw_h / raw_w)) - image = cv2.resize(image, base_size, interpolation=cv2.INTER_LINEAR) - label = cv2.resize(label, base_size, interpolation=cv2.INTER_NEAREST) - if self.scale is not None: - # Scaling - scale_factor = random.choice(self.scale) - scale_kwargs = {"dsize": None, "fx": scale_factor, "fy": scale_factor} - image = cv2.resize(image, interpolation=cv2.INTER_LINEAR, **scale_kwargs) - label = cv2.resize(label, interpolation=cv2.INTER_NEAREST, **scale_kwargs) - scale_h, scale_w = label.shape - # Padding - pad_h = max(max(base_size[1], self.crop_size) - scale_h, 0) - pad_w = max(max(base_size[0], self.crop_size) - scale_w, 0) - pad_kwargs = { - "top": 0, - "bottom": pad_h, - "left": 0, - "right": pad_w, - "borderType": cv2.BORDER_CONSTANT, - } - if pad_h > 0 or pad_w > 0: - image = cv2.copyMakeBorder(image, value=(0.0, 0.0, 0.0), **pad_kwargs) - label = cv2.copyMakeBorder(label, value=self.ignore_label, **pad_kwargs) - # Random cropping - base_h, base_w = label.shape - start_h = random.randint(0, base_h - self.crop_size) - start_w = random.randint(0, base_w - self.crop_size) - end_h = start_h + self.crop_size - end_w = start_w + self.crop_size - image = image[start_h:end_h, start_w:end_w] - label = label[start_h:end_h, start_w:end_w] - if self.flip: - # Random flipping - if random.random() < 0.5: - image = np.fliplr(image).copy() # HWC - label = np.fliplr(label).copy() # HW - # HWC -> CHW - image = image.transpose(2, 0, 1) - return image, label - - def _load_data(self, image_id): - raise NotImplementedError() - - def _preload_data(self): - for image_id in tqdm( - self.files, desc="Preloading...", leave=False, dynamic_ncols=True - ): - image, label = self._load_data(image_id) - self.images.append(image) - self.labels.append(label) - - def __getitem__(self, index): - if self.preload: - image, label = self.images[index], self.labels[index] - else: - image_id = self.files[index] - image, label = self._load_data(image_id) - image, label = self._transform(image, label) - return image.astype(np.float32), label.astype(np.int64) - - def __len__(self): - return len(self.files) - - def __repr__(self): - fmt_str = "Dataset " + self.__class__.__name__ + "\n" - fmt_str += " Number of datapoints: {}\n".format(self.__len__()) - fmt_str += " Split: {}\n".format(self.split) - fmt_str += " Root Location: {}\n".format(self.root) - return fmt_str - - -class CocoStuff10k(_CocoStuff): +class CocoStuff10k(_BaseDataset): """COCO-Stuff 10k dataset""" - def __init__(self, version="1.1", **kwargs): - self.version = version - self.ignore_label = -1 + def __init__(self, warp_image=True, **kwargs): + self.warp_image = warp_image super(CocoStuff10k, self).__init__(**kwargs) def _set_files(self): @@ -162,32 +37,28 @@ def _set_files(self): else: raise ValueError("Invalid split name: {}".format(self.split)) - def _load_data(self, image_id): + def _load_data(self, index): # Set paths + image_id = self.files[index] image_path = osp.join(self.root, "images", image_id + ".jpg") label_path = osp.join(self.root, "annotations", image_id + ".mat") - # Load an image + # Load an image and label image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) - # Load a label map - if self.version == "1.1": - label = sio.loadmat(label_path)["S"].astype(np.int64) - label -= 1 # unlabeled (0 -> -1) - elif self.version == "1.0": - label = np.array(h5py.File(label_path, "r")["S"], dtype=np.int64) - label = label.transpose(1, 0) - label -= 2 # unlabeled (1 -> -1) - else: - raise NotImplementedError( - "1.0 or 1.1 expected, but got: {}".format(self.version) - ) - return image, label - - -class CocoStuff164k(_CocoStuff): + label = sio.loadmat(label_path)["S"] + label -= 1 # unlabeled (0 -> -1) + label[label == -1] = 255 + # Warping: this is just for reproducing the official scores on GitHub + if self.warp_image: + image = cv2.resize(image, (513, 513), interpolation=cv2.INTER_LINEAR) + label = Image.fromarray(label).resize((513, 513), resample=Image.NEAREST) + label = np.asarray(label) + return image_id, image, label + + +class CocoStuff164k(_BaseDataset): """COCO-Stuff 164k dataset""" def __init__(self, **kwargs): - self.ignore_label = 255 super(CocoStuff164k, self).__init__(**kwargs) def _set_files(self): @@ -202,14 +73,15 @@ def _set_files(self): else: raise ValueError("Invalid split name: {}".format(self.split)) - def _load_data(self, image_id): + def _load_data(self, index): # Set paths + image_id = self.files[index] image_path = osp.join(self.root, "images", self.split, image_id + ".jpg") label_path = osp.join(self.root, "annotations", self.split, image_id + ".png") - # Load an image + # Load an image and label image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) - label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE).astype(np.int64) - return image, label + label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE) + return image_id, image, label def get_parent_class(value, dictionary): @@ -231,30 +103,32 @@ def get_parent_class(value, dictionary): import matplotlib.pyplot as plt import matplotlib.cm as cm import torchvision - from torchvision.utils import make_grid import yaml + from torchvision.utils import make_grid + from tqdm import tqdm kwargs = {"nrow": 10, "padding": 50} batch_size = 100 - dataset_root = "/media/kazuto1011/Extra/cocostuff/cocostuff-164k" - dataset = CocoStuff164k(root=dataset_root, split="train2017") - + dataset = CocoStuff164k( + root="/media/kazuto1011/Extra/cocostuff/cocostuff-164k", + split="train2017", + ignore_label=255, + mean_bgr=(104.008, 116.669, 122.675), + augment=True, + crop_size=321, + scales=(0.5, 0.75, 1.0, 1.25, 1.5), + flip=True, + ) print(dataset) loader = data.DataLoader(dataset, batch_size=batch_size) - for i, (images, labels) in tqdm( + for i, (image_ids, images, labels) in tqdm( enumerate(loader), total=np.ceil(len(dataset) / batch_size), leave=False ): - if i == 0: - mean = ( - torch.tensor((104.008, 116.669, 122.675)) - .unsqueeze(0) - .unsqueeze(2) - .unsqueeze(3) - ) + mean = torch.tensor((104.008, 116.669, 122.675))[None, :, None, None] images += mean.expand_as(images) image = make_grid(images, pad_value=-1, **kwargs).numpy() image = np.transpose(image, (1, 2, 0)) @@ -271,7 +145,7 @@ def get_parent_class(value, dictionary): label = label.astype(np.uint8) tiled_images = np.hstack((image, label)) - # cv2.imwrite("./docs/data.png", tiled_images) + # cv2.imwrite("./docs/datasets/cocostuff.png", tiled_images) plt.imshow(np.dstack((tiled_images[..., 2::-1], tiled_images[..., 3]))) plt.show() break diff --git a/libs/datasets/voc.py b/libs/datasets/voc.py new file mode 100644 index 0000000..80b03f4 --- /dev/null +++ b/libs/datasets/voc.py @@ -0,0 +1,140 @@ +#!/usr/bin/env python +# coding: utf-8 +# +# Author: Kazuto Nakashima +# URL: https://kazuto1011.github.io +# Date: 08 February 2019 + +from __future__ import absolute_import, print_function + +import os.path as osp + +import cv2 +import numpy as np +import torch +from PIL import Image +from torch.utils import data + +from .base import _BaseDataset + + +class VOC(_BaseDataset): + """ + PASCAL VOC Segmentation dataset + """ + + def __init__(self, year=2012, **kwargs): + self.year = year + super(VOC, self).__init__(**kwargs) + + def _set_files(self): + self.root = osp.join(self.root, "VOC{}".format(self.year)) + self.image_dir = osp.join(self.root, "JPEGImages") + self.label_dir = osp.join(self.root, "SegmentationClass") + + if self.split in ["train", "trainval", "val", "test"]: + file_list = osp.join( + self.root, "ImageSets/Segmentation", self.split + ".txt" + ) + file_list = tuple(open(file_list, "r")) + file_list = [id_.rstrip() for id_ in file_list] + self.files = file_list + else: + raise ValueError("Invalid split name: {}".format(self.split)) + + def _load_data(self, image_id): + # Set paths + image_path = osp.join(self.root, self.image_dir, image_id + ".jpg") + label_path = osp.join(self.root, self.label_dir, image_id + ".png") + # Load an image + image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) + label = np.asarray(Image.open(label_path), dtype=np.int32) + return image_id, image, label + + +class VOCAug(_BaseDataset): + """ + PASCAL VOC Segmentation dataset with extra annotations + """ + + def __init__(self, year=2012, **kwargs): + self.year = year + super(VOCAug, self).__init__(**kwargs) + + def _set_files(self): + self.root = osp.join(self.root, "VOC{}".format(self.year)) + + if self.split in ["train", "train_aug", "trainval", "trainval_aug", "val"]: + file_list = osp.join( + self.root, "ImageSets/SegmentationAug", self.split + ".txt" + ) + file_list = tuple(open(file_list, "r")) + file_list = [id_.rstrip().split(" ") for id_ in file_list] + self.files, self.labels = list(zip(*file_list)) + else: + raise ValueError("Invalid split name: {}".format(self.split)) + + def _load_data(self, index): + # Set paths + image_id = self.files[index].split("/")[-1].split(".")[0] + image_path = osp.join(self.root, self.files[index][1:]) + label_path = osp.join(self.root, self.labels[index][1:]) + # Load an image + image = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) + label = np.asarray(Image.open(label_path), dtype=np.int32) + return image_id, image, label + + +if __name__ == "__main__": + import matplotlib + import matplotlib.pyplot as plt + import matplotlib.cm as cm + import torchvision + import yaml + from torchvision.utils import make_grid + from tqdm import tqdm + + kwargs = {"nrow": 10, "padding": 50} + batch_size = 100 + + dataset = VOCAug( + root="/media/kazuto1011/Extra/VOCdevkit", + split="train_aug", + ignore_label=255, + mean_bgr=(104.008, 116.669, 122.675), + year=2012, + augment=True, + base_size=None, + crop_size=513, + scales=(0.5, 0.75, 1.0, 1.25, 1.5), + flip=True, + ) + print(dataset) + + loader = data.DataLoader(dataset, batch_size=batch_size) + + for i, (image_ids, images, labels) in tqdm( + enumerate(loader), total=np.ceil(len(dataset) / batch_size), leave=False + ): + if i == 0: + mean = torch.tensor((104.008, 116.669, 122.675))[None, :, None, None] + images += mean.expand_as(images) + image = make_grid(images, pad_value=-1, **kwargs).numpy() + image = np.transpose(image, (1, 2, 0)) + mask = np.zeros(image.shape[:2]) + mask[(image != -1)[..., 0]] = 255 + image = np.dstack((image, mask)).astype(np.uint8) + + labels = labels[:, np.newaxis, ...] + label = make_grid(labels, pad_value=255, **kwargs).numpy() + label_ = np.transpose(label, (1, 2, 0))[..., 0].astype(np.float32) + label = cm.jet_r(label_ / 21.0) * 255 + mask = np.zeros(label.shape[:2]) + label[..., 3][(label_ == 255)] = 0 + label = label.astype(np.uint8) + + tiled_images = np.hstack((image, label)) + # cv2.imwrite("./docs/datasets/voc12.png", tiled_images) + plt.imshow(np.dstack((tiled_images[..., 2::-1], tiled_images[..., 3]))) + plt.show() + break diff --git a/libs/models/__init__.py b/libs/models/__init__.py index 62a0640..1e67048 100644 --- a/libs/models/__init__.py +++ b/libs/models/__init__.py @@ -1,25 +1,33 @@ from __future__ import absolute_import from .resnet import * +from .deeplabv1 import * from .deeplabv2 import * from .deeplabv3 import * from .deeplabv3plus import * from .msc import * -def init_weights(model): - for m in model.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight) - if m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1) - if m.bias is not None: - nn.init.constant_(m.bias, 0) +def init_weights(module): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.Linear): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + elif isinstance(module, nn.BatchNorm2d): + nn.init.constant_(module.weight, 1) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + +def ResNet101(n_classes): + return ResNet(n_classes=n_classes, n_blocks=[3, 4, 23, 3]) + + +def DeepLabV1_ResNet101(n_classes): + return DeepLabV1(n_classes=n_classes, n_blocks=[3, 4, 23, 3]) def DeepLabV2_ResNet101_MSC(n_classes): @@ -40,7 +48,7 @@ def DeepLabV2S_ResNet101_MSC(n_classes): ) -def DeepLabV3_ResNet101_MSC(n_classes, output_stride): +def DeepLabV3_ResNet101_MSC(n_classes, output_stride=16): if output_stride == 16: atrous_rates = [6, 12, 18] elif output_stride == 8: @@ -48,19 +56,22 @@ def DeepLabV3_ResNet101_MSC(n_classes, output_stride): else: NotImplementedError - return MSC( - base=DeepLabV3( - n_classes=n_classes, - n_blocks=[3, 4, 23, 3], - atrous_rates=atrous_rates, - multi_grids=[1, 2, 4], - output_stride=output_stride, - ), - scales=[0.5, 0.75], + base = DeepLabV3( + n_classes=n_classes, + n_blocks=[3, 4, 23, 3], + atrous_rates=atrous_rates, + multi_grids=[1, 2, 4], + output_stride=output_stride, ) + for name, module in base.named_modules(): + if ".bn" in name: + module.momentum = 0.9997 -def DeepLabV3Plus_ResNet101_MSC(n_classes, output_stride): + return MSC(base=base, scales=[0.5, 0.75]) + + +def DeepLabV3Plus_ResNet101_MSC(n_classes, output_stride=16): if output_stride == 16: atrous_rates = [6, 12, 18] elif output_stride == 8: @@ -68,13 +79,16 @@ def DeepLabV3Plus_ResNet101_MSC(n_classes, output_stride): else: NotImplementedError - return MSC( - base=DeepLabV3Plus( - n_classes=n_classes, - n_blocks=[3, 4, 23, 3], - atrous_rates=atrous_rates, - multi_grids=[1, 2, 4], - output_stride=output_stride, - ), - scales=[0.5, 0.75], + base = DeepLabV3Plus( + n_classes=n_classes, + n_blocks=[3, 4, 23, 3], + atrous_rates=atrous_rates, + multi_grids=[1, 2, 4], + output_stride=output_stride, ) + + for name, module in base.named_modules(): + if ".bn" in name: + module.momentum = 0.9997 + + return MSC(base=base, scales=[0.5, 0.75]) diff --git a/libs/models/deeplabv1.py b/libs/models/deeplabv1.py new file mode 100644 index 0000000..ac5e673 --- /dev/null +++ b/libs/models/deeplabv1.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# coding: utf-8 +# +# Author: Kazuto Nakashima +# URL: https://kazuto1011.github.io +# Date: 19 February 2019 + +from __future__ import absolute_import, print_function + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .resnet import _ResLayer, _Stem + + +class DeepLabV1(nn.Sequential): + """ + DeepLab v1: Dilated ResNet + 1x1 Conv + Note that this is just a container for loading the pretrained COCO model and not mentioned as "v1" in papers. + """ + + def __init__(self, n_classes, n_blocks): + super(DeepLabV1, self).__init__() + self.add_module("layer1", _Stem()) + self.add_module("layer2", _ResLayer(n_blocks[0], 64, 64, 256, 1, 1)) + self.add_module("layer3", _ResLayer(n_blocks[1], 256, 128, 512, 2, 1)) + self.add_module("layer4", _ResLayer(n_blocks[2], 512, 256, 1024, 1, 2)) + self.add_module("layer5", _ResLayer(n_blocks[3], 1024, 512, 2048, 1, 4)) + self.add_module("fc", nn.Conv2d(2048, n_classes, 1)) + + +if __name__ == "__main__": + model = DeepLabV1(n_classes=21, n_blocks=[3, 4, 23, 3]) + model.eval() + image = torch.randn(1, 3, 513, 513) + + print(model) + print("input:", image.shape) + print("output:", model(image).shape) diff --git a/libs/models/deeplabv2.py b/libs/models/deeplabv2.py index bc8fc82..c2e3729 100644 --- a/libs/models/deeplabv2.py +++ b/libs/models/deeplabv2.py @@ -5,7 +5,7 @@ # URL: http://kazuto1011.github.io # Created: 2017-11-19 -from collections import OrderedDict +from __future__ import absolute_import, print_function import torch import torch.nn as nn @@ -15,11 +15,13 @@ class _ASPP(nn.Module): - """Atrous Spatial Pyramid Pooling""" + """ + Atrous spatial pyramid pooling (ASPP) + """ def __init__(self, in_ch, out_ch, rates): super(_ASPP, self).__init__() - for i, rate in enumerate(zip(rates)): + for i, rate in enumerate(rates): self.add_module( "c{}".format(i), nn.Conv2d(in_ch, out_ch, 3, 1, padding=rate, dilation=rate, bias=True), @@ -34,7 +36,10 @@ def forward(self, x): class DeepLabV2(nn.Sequential): - """DeepLab v2 (OS=8)""" + """ + DeepLab v2: Dilated ResNet + ASPP + Output stride is fixed at 8 + """ def __init__(self, n_classes, n_blocks, atrous_rates): super(DeepLabV2, self).__init__() diff --git a/libs/models/deeplabv3.py b/libs/models/deeplabv3.py index 3c2575d..2d1f821 100644 --- a/libs/models/deeplabv3.py +++ b/libs/models/deeplabv3.py @@ -5,6 +5,8 @@ # URL: http://kazuto1011.github.io # Created: 2018-03-26 +from __future__ import absolute_import, print_function + from collections import OrderedDict import torch @@ -15,13 +17,15 @@ class _ASPP(nn.Module): - """Atrous Spatial Pyramid Pooling with image pool""" + """ + Atrous spatial pyramid pooling with image-level feature + """ def __init__(self, n_in, n_out, rates): super(_ASPP, self).__init__() self.stages = nn.Module() self.stages.add_module("c0", _ConvBnReLU(n_in, n_out, 1, 1, 0, 1)) - for i, rate in enumerate(zip(rates)): + for i, rate in enumerate(rates): self.stages.add_module( "c{}".format(i + 1), _ConvBnReLU(n_in, n_out, 3, 1, padding=rate, dilation=rate), @@ -37,7 +41,7 @@ def __init__(self, n_in, n_out, rates): def forward(self, x): h = self.imagepool(x) - h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=True)] + h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)] for stage in self.stages.children(): h += [stage(x)] h = torch.cat(h, dim=1) @@ -45,7 +49,9 @@ def forward(self, x): class DeepLabV3(nn.Sequential): - """DeepLab v3""" + """ + DeepLab v3: Dilated ResNet with multi-grid + improved ASPP + """ def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride): super(DeepLabV3, self).__init__() diff --git a/libs/models/deeplabv3plus.py b/libs/models/deeplabv3plus.py index 592c837..71f095a 100644 --- a/libs/models/deeplabv3plus.py +++ b/libs/models/deeplabv3plus.py @@ -5,6 +5,8 @@ # URL: http://kazuto1011.github.io # Created: 2018-03-26 +from __future__ import absolute_import, print_function + from collections import OrderedDict import torch @@ -16,7 +18,9 @@ class DeepLabV3Plus(nn.Module): - """DeepLab v3+""" + """ + DeepLab v3+: Dilated ResNet with multi-grid + improved ASPP + decoder + """ def __init__(self, n_classes, n_blocks, atrous_rates, multi_grids, output_stride): super(DeepLabV3Plus, self).__init__() @@ -60,10 +64,10 @@ def forward(self, x): h = self.layer5(h) h = self.aspp(h) h = self.fc1(h) - h = F.interpolate(h, size=h_.shape[2:], mode="bilinear", align_corners=True) + h = F.interpolate(h, size=h_.shape[2:], mode="bilinear", align_corners=False) h = torch.cat((h, h_), dim=1) h = self.fc2(h) - h = F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=True) + h = F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False) return h @@ -73,7 +77,7 @@ def forward(self, x): n_blocks=[3, 4, 23, 3], atrous_rates=[6, 12, 18], multi_grids=[1, 2, 4], - output_stride=8, + output_stride=16, ) model.eval() image = torch.randn(1, 3, 513, 513) diff --git a/libs/models/msc.py b/libs/models/msc.py index ed863aa..c51d798 100644 --- a/libs/models/msc.py +++ b/libs/models/msc.py @@ -11,25 +11,30 @@ class MSC(nn.Module): - """Multi-scale inputs""" + """ + Multi-scale inputs + """ - def __init__(self, base, scales=[0.5, 0.75]): + def __init__(self, base, scales=None): super(MSC, self).__init__() self.base = base - self.scales = scales + if scales: + self.scales = scales + else: + self.scales = [0.5, 0.75] def forward(self, x): # Original logits = self.base(x) + _, _, H, W = logits.shape interp = lambda l: F.interpolate( - l, size=logits.shape[2:], mode="bilinear", align_corners=True + l, size=(H, W), mode="bilinear", align_corners=False ) # Scaled logits_pyramid = [] for p in self.scales: - size = [int(s * p) for s in x.shape[2:]] - h = F.interpolate(x, size=size, mode="bilinear", align_corners=True) + h = F.interpolate(x, scale_factor=p, mode="bilinear", align_corners=False) logits_pyramid.append(self.base(h)) # Pixel-wise max diff --git a/libs/models/resnet.py b/libs/models/resnet.py index 62e8b63..24d5f69 100644 --- a/libs/models/resnet.py +++ b/libs/models/resnet.py @@ -5,6 +5,8 @@ # URL: http://kazuto1011.github.io # Created: 2017-11-19 +from __future__ import absolute_import, print_function + from collections import OrderedDict import torch @@ -20,6 +22,10 @@ class _ConvBnReLU(nn.Sequential): + """ + Cascade of 2D convolution, batch norm, and ReLU. + """ + BATCH_NORM = _BATCH_NORM def __init__( @@ -39,7 +45,9 @@ def __init__( class _Bottleneck(nn.Module): - """Bottleneck Unit""" + """ + Bottleneck block of MSRA ResNet. + """ def __init__(self, in_ch, mid_ch, out_ch, stride, dilation, downsample): super(_Bottleneck, self).__init__() @@ -61,7 +69,9 @@ def forward(self, x): class _ResLayer(nn.Sequential): - """Residual blocks""" + """ + Residual layer with multi grids + """ def __init__( self, n_layers, in_ch, mid_ch, out_ch, stride, dilation, multi_grids=None @@ -88,10 +98,43 @@ def __init__( class _Stem(nn.Sequential): """ - The 1st Residual Layer + The 1st Residual Layer. + Note that the max pooling is different from both MSRA and FAIR ResNet. """ def __init__(self): super(_Stem, self).__init__() self.add_module("conv1", _ConvBnReLU(3, 64, 7, 2, 3, 1)) self.add_module("pool", nn.MaxPool2d(3, 2, 1, ceil_mode=True)) + + +class ResNet(nn.Module): + def __init__(self, n_classes, n_blocks): + super(ResNet, self).__init__() + self.add_module("layer1", _Stem()) + self.add_module("layer2", _ResLayer(n_blocks[0], 64, 64, 256, 1, 1)) + self.add_module("layer3", _ResLayer(n_blocks[1], 256, 128, 512, 2, 1)) + self.add_module("layer4", _ResLayer(n_blocks[2], 512, 256, 1024, 2, 1)) + self.add_module("layer5", _ResLayer(n_blocks[3], 1024, 512, 2048, 2, 1)) + self.add_module("pool5", nn.AdaptiveAvgPool2d(1)) + self.add_module("fc", nn.Linear(2048, n_classes)) + + def forward(self, x): + h = self.layer1(x) + h = self.layer2(h) + h = self.layer3(h) + h = self.layer4(h) + h = self.layer5(h) + h = self.pool5(h) + h = self.fc(h.view(h.size(0), -1)) + return h + + +if __name__ == "__main__": + model = ResNet(n_classes=1000, n_blocks=[3, 4, 23, 3]) + model.eval() + image = torch.randn(1, 3, 224, 224) + + print(model) + print("input:", image.shape) + print("output:", model(image).shape) diff --git a/main.py b/main.py index a15d112..709b076 100644 --- a/main.py +++ b/main.py @@ -5,11 +5,11 @@ # URL: https://kazuto1011.github.io # Date: 07 January 2019 - from __future__ import absolute_import, division, print_function import json -import os.path as osp +import multiprocessing +import os import click import joblib @@ -19,39 +19,33 @@ import torch.nn.functional as F import yaml from addict import Dict +from PIL import Image from tensorboardX import SummaryWriter from torchnet.meter import MovingAverageValueMeter from tqdm import tqdm from libs.datasets import get_dataset -from libs.models import DeepLabV2_ResNet101_MSC +from libs.models import * from libs.utils import DenseCRF, PolynomialLR, scores +def makedirs(dirs): + if not os.path.exists(dirs): + os.makedirs(dirs) + + def get_device(cuda): cuda = cuda and torch.cuda.is_available() device = torch.device("cuda" if cuda else "cpu") if cuda: - current_device = torch.cuda.current_device() - print("Device:", torch.cuda.get_device_name(current_device)) + print("Device:") + for i in range(torch.cuda.device_count()): + print(" {}:".format(i), torch.cuda.get_device_name(i)) else: print("Device: CPU") return device -def setup_model(model_path, n_classes, train=True): - model = DeepLabV2_ResNet101_MSC(n_classes=n_classes) - state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) - if train: - model.load_state_dict(state_dict, strict=False) # to skip ASPP - model = nn.DataParallel(model) - else: - model.load_state_dict(state_dict) - model = nn.DataParallel(model) - model.eval() - return model - - def get_params(model, key): # For Dilated FCN if key == "1x": @@ -74,41 +68,65 @@ def get_params(model, key): yield m[1].bias -def resize_labels(labels, shape): - labels = labels.unsqueeze(1).float() # Add channel axis - labels = F.interpolate(labels, shape, mode="nearest") - labels = labels.squeeze(1).long() - return labels +def resize_labels(labels, size): + """ + Downsample labels for 0.5x and 0.75x logits by nearest interpolation. + Other nearest methods result in misaligned labels. + -> F.interpolate(labels, shape, mode='nearest') + -> cv2.resize(labels, shape, interpolation=cv2.INTER_NEAREST) + """ + new_labels = [] + for label in labels: + label = label.float().numpy() + label = Image.fromarray(label).resize(size, resample=Image.NEAREST) + new_labels.append(np.asarray(label)) + new_labels = torch.LongTensor(new_labels) + return new_labels @click.group() @click.pass_context def main(ctx): + """ + Training and evaluation + """ print("Mode:", ctx.invoked_subcommand) @main.command() -@click.option("-c", "--config", type=str, required=True, help="yaml") -@click.option("--cuda/--no-cuda", default=True, help="Switch GPU/CPU") -def train(config, cuda): - # Auto-tune cuDNN - torch.backends.cudnn.benchmark = True +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]" +) +def train(config_path, cuda): + """ + Training DeepLab by v2 protocol + """ # Configuration + CONFIG = Dict(yaml.load(config_path)) device = get_device(cuda) - CONFIG = Dict(yaml.load(open(config))) + torch.backends.cudnn.benchmark = True - # Dataset 10k or 164k + # Dataset dataset = get_dataset(CONFIG.DATASET.NAME)( root=CONFIG.DATASET.ROOT, split=CONFIG.DATASET.SPLIT.TRAIN, - base_size=CONFIG.IMAGE.SIZE.TRAIN.BASE, - crop_size=CONFIG.IMAGE.SIZE.TRAIN.CROP, - mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R), - warp=CONFIG.DATASET.WARP_IMAGE, - scale=CONFIG.DATASET.SCALES, + ignore_label=CONFIG.DATASET.IGNORE_LABEL, + mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R), + augment=True, + base_size=CONFIG.IMAGE.SIZE.BASE, + crop_size=CONFIG.IMAGE.SIZE.TRAIN, + scales=CONFIG.DATASET.SCALES, flip=True, ) + print(dataset) # DataLoader loader = torch.utils.data.DataLoader( @@ -119,10 +137,27 @@ def train(config, cuda): ) loader_iter = iter(loader) - # Model - model = setup_model(CONFIG.MODEL.INIT_MODEL, CONFIG.DATASET.N_CLASSES, train=True) + # Model check + print("Model:", CONFIG.MODEL.NAME) + assert ( + CONFIG.MODEL.NAME == "DeepLabV2_ResNet101_MSC" + ), 'Currently support only "DeepLabV2_ResNet101_MSC"' + + # Model setup + model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES) + state_dict = torch.load(CONFIG.MODEL.INIT_MODEL) + print(" Init:", CONFIG.MODEL.INIT_MODEL) + for m in model.base.state_dict().keys(): + if m not in state_dict.keys(): + print(" Skip init:", m) + model.base.load_state_dict(state_dict, strict=False) # to skip ASPP + model = nn.DataParallel(model) model.to(device) + # Loss definition + criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL) + criterion.to(device) + # Optimizer optimizer = torch.optim.SGD( # cf lr_mult and decay_mult in train.prototxt @@ -154,14 +189,21 @@ def train(config, cuda): power=CONFIG.SOLVER.POLY_POWER, ) - # Loss definition - criterion = nn.CrossEntropyLoss(ignore_index=CONFIG.DATASET.IGNORE_LABEL) - criterion.to(device) - - # TensorBoard logger - writer = SummaryWriter(CONFIG.SOLVER.LOG_DIR) + # Setup loss logger + writer = SummaryWriter(os.path.join(CONFIG.EXP.OUTPUT_DIR, "logs", CONFIG.EXP.ID)) average_loss = MovingAverageValueMeter(CONFIG.SOLVER.AVERAGE_LOSS) + # Path to save models + checkpoint_dir = os.path.join( + CONFIG.EXP.OUTPUT_DIR, + "models", + CONFIG.EXP.ID, + CONFIG.MODEL.NAME.lower(), + CONFIG.DATASET.SPLIT.TRAIN, + ) + makedirs(checkpoint_dir) + print("Checkpoint dst:", checkpoint_dir) + # Freeze the batch norm pre-trained on COCO model.train() model.module.base.freeze_bn() @@ -169,7 +211,6 @@ def train(config, cuda): for iteration in tqdm( range(1, CONFIG.SOLVER.ITER_MAX + 1), total=CONFIG.SOLVER.ITER_MAX, - leave=False, dynamic_ncols=True, ): @@ -179,26 +220,23 @@ def train(config, cuda): loss = 0 for _ in range(CONFIG.SOLVER.ITER_SIZE): try: - images, labels = next(loader_iter) + _, images, labels = next(loader_iter) except: loader_iter = iter(loader) - images, labels = next(loader_iter) - - images = images.to(device) - labels = labels.to(device) + _, images, labels = next(loader_iter) # Propagate forward - logits = model(images) + logits = model(images.to(device)) # Loss iter_loss = 0 for logit in logits: # Resize labels for {100%, 75%, 50%, Max} logits _, _, H, W = logit.shape - labels_ = resize_labels(labels, shape=(H, W)) - iter_loss += criterion(logit, labels_) + labels_ = resize_labels(labels, size=(H, W)) + iter_loss += criterion(logit, labels_.to(device)) - # Backpropagate (just compute gradients wrt the loss) + # Propagate backward (just compute gradients wrt the loss) iter_loss /= CONFIG.SOLVER.ITER_SIZE iter_loss.backward() @@ -216,9 +254,16 @@ def train(config, cuda): if iteration % CONFIG.SOLVER.ITER_TB == 0: writer.add_scalar("loss/train", average_loss.value()[0], iteration) for i, o in enumerate(optimizer.param_groups): - writer.add_scalar("lr/group{}".format(i), o["lr"], iteration) - if False: # This produces a large log file - for name, param in model.named_parameters(): + writer.add_scalar("lr/group_{}".format(i), o["lr"], iteration) + for i in range(torch.cuda.device_count()): + writer.add_scalar( + "gpu/device_{}/memory_cached".format(i), + torch.cuda.memory_cached(i) / 1024 ** 3, + iteration, + ) + + if False: + for name, param in model.module.base.named_parameters(): name = name.replace(".", "/") # Weight/gradient distribution writer.add_histogram(name, param, iteration, bins="auto") @@ -231,50 +276,51 @@ def train(config, cuda): if iteration % CONFIG.SOLVER.ITER_SAVE == 0: torch.save( model.module.state_dict(), - osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_{}.pth".format(iteration)), + os.path.join(checkpoint_dir, "checkpoint_{}.pth".format(iteration)), ) - # To verify progress separately - torch.save( - model.module.state_dict(), - osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_current.pth"), - ) - torch.save( - model.module.state_dict(), - osp.join(CONFIG.MODEL.SAVE_DIR, "checkpoint_final.pth"), + model.module.state_dict(), os.path.join(checkpoint_dir, "checkpoint_final.pth") ) @main.command() -@click.option("-c", "--config", type=str, required=True, help="yaml") -@click.option("-m", "--model-path", type=str, required=True, help="pth") -@click.option("--cuda/--no-cuda", default=True, help="Switch GPU/CPU") -@click.option("--crf", is_flag=True, help="CRF post processing") -def test(config, model_path, cuda, crf): - # Disable autograd globally - torch.set_grad_enabled(False) +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "-m", + "--model-path", + type=click.Path(exists=True), + required=True, + help="PyTorch model to be loaded", +) +@click.option( + "--cuda/--cpu", default=True, help="Enable CUDA if available [default: --cuda]" +) +def test(config_path, model_path, cuda): + """ + Evaluation on validation set + """ - # Setup + # Configuration + CONFIG = Dict(yaml.load(config_path)) device = get_device(cuda) - CONFIG = Dict(yaml.load(open(config))) - - # If the image size never change, - if CONFIG.DATASET.WARP_IMAGE: - # Auto-tune cuDNN - torch.backends.cudnn.benchmark = True + torch.set_grad_enabled(False) - # Dataset 10k or 164k + # Dataset dataset = get_dataset(CONFIG.DATASET.NAME)( root=CONFIG.DATASET.ROOT, split=CONFIG.DATASET.SPLIT.VAL, - base_size=CONFIG.IMAGE.SIZE.TEST, - crop_size=None, - mean=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R), - warp=CONFIG.DATASET.WARP_IMAGE, - scale=None, - flip=False, + ignore_label=CONFIG.DATASET.IGNORE_LABEL, + mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R), + augment=False, ) + print(dataset) # DataLoader loader = torch.utils.data.DataLoader( @@ -285,9 +331,106 @@ def test(config, model_path, cuda, crf): ) # Model - model = setup_model(model_path, CONFIG.DATASET.N_CLASSES, train=False) + model = eval(CONFIG.MODEL.NAME)(n_classes=CONFIG.DATASET.N_CLASSES) + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + model.load_state_dict(state_dict) + model = nn.DataParallel(model) + model.eval() model.to(device) + # Path to save logits + logit_dir = os.path.join( + CONFIG.EXP.OUTPUT_DIR, + "features", + CONFIG.EXP.ID, + CONFIG.MODEL.NAME.lower(), + CONFIG.DATASET.SPLIT.VAL, + "logit", + ) + makedirs(logit_dir) + print("Logit dst:", logit_dir) + + # Path to save scores + save_dir = os.path.join( + CONFIG.EXP.OUTPUT_DIR, + "scores", + CONFIG.EXP.ID, + CONFIG.MODEL.NAME.lower(), + CONFIG.DATASET.SPLIT.VAL, + ) + makedirs(save_dir) + save_path = os.path.join(save_dir, "scores.json") + print("Score dst:", save_path) + + preds, gts = [], [] + for image_ids, images, gt_labels in tqdm( + loader, total=len(loader), dynamic_ncols=True + ): + # Image + images = images.to(device) + + # Forward propagation + logits = model(images) + + # Save on disk for CRF post-processing + for image_id, logit in zip(image_ids, logits): + filename = os.path.join(logit_dir, image_id + ".npy") + np.save(filename, logit.cpu().numpy()) + + # Pixel-wise labeling + _, H, W = gt_labels.shape + logits = F.interpolate( + logits, size=(H, W), mode="bilinear", align_corners=False + ) + probs = F.softmax(logits, dim=1) + labels = torch.argmax(probs, dim=1) + + preds += list(labels.cpu().numpy()) + gts += list(gt_labels.numpy()) + + # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU + score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES) + + with open(save_path, "w") as f: + json.dump(score, f, indent=4, sort_keys=True) + + +@main.command() +@click.option( + "-c", + "--config-path", + type=click.File(), + required=True, + help="Dataset configuration file in YAML", +) +@click.option( + "-j", + "--n-jobs", + type=int, + default=multiprocessing.cpu_count(), + show_default=True, + help="Number of parallel jobs", +) +def crf(config_path, n_jobs): + """ + CRF post-processing on pre-computed logits + """ + + # Configuration + CONFIG = Dict(yaml.load(config_path)) + torch.set_grad_enabled(False) + print("# jobs:", n_jobs) + + # Dataset + dataset = get_dataset(CONFIG.DATASET.NAME)( + root=CONFIG.DATASET.ROOT, + split=CONFIG.DATASET.SPLIT.VAL, + ignore_label=CONFIG.DATASET.IGNORE_LABEL, + mean_bgr=(CONFIG.IMAGE.MEAN.B, CONFIG.IMAGE.MEAN.G, CONFIG.IMAGE.MEAN.R), + augment=False, + ) + print(dataset) + # CRF post-processor postprocessor = DenseCRF( iter_max=CONFIG.CRF.ITER_MAX, @@ -298,37 +441,61 @@ def test(config, model_path, cuda, crf): bi_w=CONFIG.CRF.BI_W, ) - preds, gts = [], [] - for images, labels in tqdm( - loader, total=len(loader), leave=False, dynamic_ncols=True - ): - # Image - images = images.to(device) - _, H, W = labels.shape + # Path to logit files + logit_dir = os.path.join( + CONFIG.EXP.OUTPUT_DIR, + "features", + CONFIG.EXP.ID, + CONFIG.MODEL.NAME.lower(), + CONFIG.DATASET.SPLIT.VAL, + "logit", + ) + print("Logit src:", logit_dir) + if not os.path.isdir(logit_dir): + print("Logit not found, run first: python main.py test [OPTIONS]") + quit() + + # Path to save scores + save_dir = os.path.join( + CONFIG.EXP.OUTPUT_DIR, + "scores", + CONFIG.EXP.ID, + CONFIG.MODEL.NAME.lower(), + CONFIG.DATASET.SPLIT.VAL, + ) + makedirs(save_dir) + save_path = os.path.join(save_dir, "scores_crf.json") + print("Score dst:", save_path) - # Forward propagation - logits = model(images) - logits = F.interpolate(logits, size=(H, W), mode="bilinear", align_corners=True) - probs = F.softmax(logits, dim=1) - probs = probs.data.cpu().numpy() - - # Postprocessing - if crf: - # images: (B,C,H,W) -> (B,H,W,C) - images = images.data.cpu().numpy().astype(np.uint8).transpose(0, 2, 3, 1) - probs = joblib.Parallel(n_jobs=-1)( - [joblib.delayed(postprocessor)(*pair) for pair in zip(images, probs)] - ) + # Process per sample + def process(i): + image_id, image, gt_label = dataset.__getitem__(i) + + filename = os.path.join(logit_dir, image_id + ".npy") + logit = np.load(filename) - labelmaps = np.argmax(probs, axis=1) + _, H, W = image.shape + logit = torch.FloatTensor(logit)[None, ...] + logit = F.interpolate(logit, size=(H, W), mode="bilinear", align_corners=False) + prob = F.softmax(logit, dim=1)[0].numpy() + + image = image.astype(np.uint8).transpose(1, 2, 0) + prob = postprocessor(image, prob) + label = np.argmax(prob, axis=0) + + return label, gt_label + + # CRF in multi-process + results = joblib.Parallel(n_jobs=n_jobs, verbose=10, pre_dispatch="all")( + [joblib.delayed(process)(i) for i in range(len(dataset))] + ) - preds += list(labelmaps) - gts += list(labels.numpy()) + preds, gts = zip(*results) # Pixel Accuracy, Mean Accuracy, Class IoU, Mean IoU, Freq Weighted IoU score = scores(gts, preds, n_class=CONFIG.DATASET.N_CLASSES) - with open(model_path.replace(".pth", ".json"), "w") as f: + with open(save_path, "w") as f: json.dump(score, f, indent=4, sort_keys=True) diff --git a/scripts/setup_caffemodels.sh b/scripts/setup_caffemodels.sh index aed2c05..06d95d8 100755 --- a/scripts/setup_caffemodels.sh +++ b/scripts/setup_caffemodels.sh @@ -1,21 +1,18 @@ #!/bin/bash # Download released caffemodels -wget -nc http://liangchiehchen.com/projects/released/deeplab_aspp_resnet101/prototxt_and_model.zip +wget -nc -P ./data http://liangchiehchen.com/projects/released/deeplab_aspp_resnet101/prototxt_and_model.zip -unzip -n prototxt_and_model.zip +unzip -n ./data/prototxt_and_model.zip -d ./data # Move caffemodels to data directories ## MSCOCO -mv init.caffemodel data/models/deeplab_resnet101/coco_init +mv ./data/init.caffemodel ./data/models/coco/deeplabv1_resnet101/caffemodel ## PASCAL VOC 2012 -mv train_iter_20000.caffemodel data/models/deeplab_resnet101/voc12 -mv train2_iter_20000.caffemodel data/models/deeplab_resnet101/voc12 - -# Remove *.prototxt -rm *.prototxt +mv ./data/train_iter_20000.caffemodel ./data/models/voc12/deeplabv2_resnet101_msc/caffemodel +mv ./data/train2_iter_20000.caffemodel ./data/models/voc12/deeplabv2_resnet101_msc/caffemodel echo =============================================================================================== echo "Next, try running script below:" -echo -e "\033[32m python convert.py --dataset coco_init \033[00m" +echo -e "\033[32m python convert.py --dataset coco \033[00m" echo =============================================================================================== \ No newline at end of file diff --git a/scripts/setup_voc12.sh b/scripts/setup_voc12.sh new file mode 100755 index 0000000..401b024 --- /dev/null +++ b/scripts/setup_voc12.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +DATASET_DIR=$1 + +# Download PASCAL VOC12 (2GB) +wget -nc -P $DATASET_DIR http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar + +# Extract images, annotations, etc. +tar -xvf $DATASET_DIR/VOCtrainval_11-May-2012.tar -C $DATASET_DIR \ No newline at end of file diff --git a/scripts/train_eval.sh b/scripts/train_eval.sh new file mode 100755 index 0000000..bde1b6d --- /dev/null +++ b/scripts/train_eval.sh @@ -0,0 +1,40 @@ +#!/bin/bash + +set -x + + +# 0. Choose from {voc12, cocostuff10k, cocostuff164k} +DATASET=voc12 + + +# 1. Train DeepLab v2 on ${DATASET} +python main.py train \ +-c configs/${DATASET}.yaml + +# Trained models are saved into +# data/models/${DATASET}/deeplabv2_resnet101_msc/*/checkpoint_5000.pth +# data/models/${DATASET}/deeplabv2_resnet101_msc/*/checkpoint_10000.pth +# data/models/${DATASET}/deeplabv2_resnet101_msc/*/checkpoint_15000.pth +# ... + +# Tensorboard logs are in data/logs. + + +# 2. Evaluate the model on val set +python main.py test \ +-c configs/${DATASET}.yaml \ +-m data/models/${DATASET}/deeplabv2_resnet101_msc/*/checkpoint_final.pth + +# Validation scores on 4 metrics are saved as +# data/scores/${DATASET}/deeplabv2_resnet101_msc/*/scores.json + +# Logits are saved into +# data/features/${DATASET}/deeplabv2_resnet101_msc/*/logit/... + + +# 3. Re-evaluate the model with CRF post-processing +python main.py crf \ +-c configs/${DATASET}.yaml + +# Scores with CRF on 4 metrics are saved as +# data/scores/${DATASET}/deeplabv2_resnet101_msc/*/scores_crf.json \ No newline at end of file