PyTorch Implementation of "Multi-Stage Partitioned Transformer for Efficient Image Deraining"
Images captured outdoors often contain rain, which obscures the clean scene and significantly degrades visual quality. Because rain appearance varies with rain density and wind direction, removing rain streaks from a rainy image is difficult. Motivated by the recent success of transformers in vision tasks, we propose a novel Multi-stage Partitioned Transformer (MPT) for image deraining. MPT partitions the attention module and the multi-layer perceptron (MLP) to decompose a rainy image into a rain layer and a clean background. It uses the proposed global and local rain-aware attention mechanism to estimate the rain layer. In addition, we add atrous convolutions to the MLP to aggregate contextualized background features and produce a clean background at multiple stages. MPT is a parameter-economical and computationally efficient deraining model that effectively removes rain streaks from the input rainy image. Experimental results demonstrate that the proposed MPT performs favorably against state-of-the-art models on benchmark deraining datasets.
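As a rough illustration of the "atrous convolutions in the MLP" idea, the sketch below shows an MLP block whose hidden layer aggregates context through parallel dilated 3x3 convolutions. This is a minimal sketch under assumptions: the class name `DilatedConvMLP`, the dilation rates, the depthwise branches, and the expansion factor are all illustrative choices, not the paper's actual architecture.

```python
import torch
import torch.nn as nn

class DilatedConvMLP(nn.Module):
    """Illustrative sketch (not the paper's exact module): an MLP whose
    hidden layer uses atrous (dilated) convolutions to aggregate
    contextualized features over a larger receptive field."""

    def __init__(self, dim, expansion=2, dilations=(1, 2, 3)):
        super().__init__()
        hidden = dim * expansion
        self.proj_in = nn.Conv2d(dim, hidden, kernel_size=1)
        # Parallel depthwise atrous branches: dilation d with padding d
        # keeps the spatial size while enlarging the receptive field.
        self.branches = nn.ModuleList([
            nn.Conv2d(hidden, hidden, kernel_size=3,
                      padding=d, dilation=d, groups=hidden)
            for d in dilations
        ])
        self.act = nn.GELU()
        self.proj_out = nn.Conv2d(hidden, dim, kernel_size=1)

    def forward(self, x):  # x: (B, dim, H, W)
        h = self.act(self.proj_in(x))
        h = sum(branch(h) for branch in self.branches)  # fuse multi-rate context
        return self.proj_out(h)
```

Because each branch is depthwise and the dilation only changes padding, all branches produce feature maps of the same shape and can be summed directly.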
Performance comparison on the five test datasets in terms of deraining quality and model size (number of parameters in millions). (click to expand)
- Rain100L: 200 training pairs and 100 test pairs [paper][dataset](2017 CVPR)
- Rain100H: 1254 training pairs and 100 test pairs [paper][dataset](2017 CVPR)
- Rain800: 700 training pairs and 98 test pairs (we drop 2 of the 100 test images because they are too large)
- Rain1400(DDN-Data): 12600 training pairs and 1400 test pairs [paper][dataset] (2017 CVPR)
- Rain1200(DID-Data): 12000 training pairs and 1200 test pairs [paper][dataset] (2018 CVPR)
- For example on Rain100L: './data/Rain100L'
./data/Rain100L
+--- train
| +--- norain
| +--- rain
|
+--- test
| +--- norain
| +--- rain
- PSNR (Peak Signal-to-Noise Ratio) [paper] [matlab code]
- SSIM (Structural Similarity) [paper] [matlab code]
The implementation is modified from "RCDNet_simple"
git clone https://github.com/WENYICAT/MPT.git
cd MPT
conda create -n MPT python=3.8
source activate MPT
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch -c conda-forge
pip install opencv-python tqdm ptflops glog scikit-image tensorboardX torchsummary
Taking training on Rain100L (200 training pairs) as an example: unzip the dataset to ./data, so that the data paths are:
data_path = r"./data/Rain100L/train/rain/rain-*.png"
gt_path = r"./data/Rain100L/train/norain/norain-*.png"
Note that when using other datasets, please organize the files in the same way.
$ python -m torch.distributed.launch --nproc_per_node=2 --master_port=25911 train_main_syn_parallel.py --use_gpu="0,1" --batchSize=12 --resume=-1 --model_dir="./checkpoints/Rain100L/"
$ python -m torch.distributed.launch --nproc_per_node=1 --master_port=25911 test.py --use_gpu="0" --model_dir="./checkpoints/Rain100L/" --save_path="./results/Rain100L/"
Place the pre-trained weights in ./weights/, then resume from them by running train_main_syn_parallel.py with --resume=1.