- Python 3.6
- PyTorch >= 1.0
- NVIDIA GPU + CUDA + cuDNN
- Install PyTorch.
- Install Python requirements:
```bash
pip3 install -r requirements.txt
```
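The prerequisites above can be checked with a short script. The following is a minimal sketch, not part of this repository, that compares the installed PyTorch version against the >= 1.0 requirement and reports whether CUDA and cuDNN are visible to PyTorch. The helper names `parse_version`, `torch_version_ok` and `report` are hypothetical.

```python
import sys


def parse_version(version: str) -> tuple:
    """Turn a version string such as '1.13.1+cu117' into (1, 13, 1)."""
    core = version.split("+")[0]  # drop local build tags like "+cu117"
    parts = []
    for piece in core.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at pre-release suffixes such as "0a0" or "0rc1"
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def torch_version_ok(version: str, minimum: tuple = (1, 0)) -> bool:
    """Return True if `version` is at least `minimum` (PyTorch >= 1.0 here)."""
    parsed = parse_version(version)
    parsed += (0,) * (len(minimum) - len(parsed))  # pad "1" to (1, 0)
    return parsed >= minimum


def report() -> None:
    """Print the versions and GPU support that the prerequisites ask for."""
    print("Python", sys.version.split()[0])
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed; install it before the requirements.")
        return
    version = str(torch.__version__)
    status = "OK" if torch_version_ok(version) else "too old, need >= 1.0"
    print("PyTorch", version, f"({status})")
    print("CUDA available:", torch.cuda.is_available())
    print("cuDNN available:", torch.backends.cudnn.is_available())


if __name__ == "__main__":
    report()
```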
We use the Places2 and CelebA datasets. To train a model on the full dataset, download the datasets from their official websites.
After downloading, run `scripts/flist.py` to generate the train, test and validation file lists for images or masks. For example, to generate the file lists for the Places2 dataset, run:
```bash
mkdir datasets
python3 ./scripts/flist.py --path [path_to_places2_train_set] --output ./datasets/places2_train.flist
python3 ./scripts/flist.py --path [path_to_places2_validation_set] --output ./datasets/places2_val.flist
python3 ./scripts/flist.py --path [path_to_places2_test_set] --output ./datasets/places2_test.flist
```
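A file list is a plain text file with one image path per line. The sketch below shows one way such a list could be produced; it is an illustrative stand-in, not the actual `scripts/flist.py`, and both the recognized extensions and the helper names `collect_images` and `write_flist` are assumptions.

```python
import os

# Extensions treated as images (an assumption; adjust to your data).
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def collect_images(root: str) -> list:
    """Recursively collect image paths under `root`, sorted for reproducibility."""
    paths = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            if os.path.splitext(name)[1].lower() in IMAGE_EXTS:
                paths.append(os.path.join(dirpath, name))
    return sorted(paths)


def write_flist(paths: list, output: str) -> None:
    """Write one path per line, creating the output directory if needed."""
    out_dir = os.path.dirname(output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output, "w") as f:
        f.write(("\n".join(paths) + "\n") if paths else "")
```

For example, `write_flist(collect_images('[path_to_places2_train_set]'), './datasets/places2_train.flist')` corresponds to the first command above. Sorting the paths keeps each list stable across runs, so the same dataset always yields the same file order.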
We also provide a function that generates the CelebA file lists from the official partition file. To generate the train, validation and test file lists for the CelebA dataset, run:
```bash
python3 ./scripts/flist.py --path [path_to_celeba_dataset] --celeba [path_to_celeba_partition_file]
```
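The official CelebA partition file (`list_eval_partition.txt`) contains one `image_name split_code` pair per line, where 0 marks training, 1 validation and 2 test images. The following minimal sketch splits the images that way. It is not the repository's implementation, and the output names `celeba_train.flist`, `celeba_val.flist` and `celeba_test.flist` are hypothetical.

```python
import os

# Split codes used in CelebA's list_eval_partition.txt.
SPLITS = {"0": "train", "1": "val", "2": "test"}


def split_celeba(image_dir: str, partition_lines) -> dict:
    """Group CelebA image paths into train/val/test from partition lines
    of the form '000001.jpg 0'. Blank or malformed lines are skipped."""
    groups = {"train": [], "val": [], "test": []}
    for line in partition_lines:
        fields = line.split()
        if len(fields) != 2 or fields[1] not in SPLITS:
            continue
        name, code = fields
        groups[SPLITS[code]].append(os.path.join(image_dir, name))
    return groups


def write_celeba_flists(image_dir: str, partition_file: str,
                        out_dir: str = "./datasets") -> dict:
    """Write one hypothetical celeba_<split>.flist per split; return the groups."""
    with open(partition_file) as f:
        groups = split_celeba(image_dir, f)
    os.makedirs(out_dir, exist_ok=True)
    for split, paths in groups.items():
        with open(os.path.join(out_dir, f"celeba_{split}.flist"), "w") as out:
            out.write(("\n".join(paths) + "\n") if paths else "")
    return groups
```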
Our model is trained on the irregular mask dataset provided by Liu et al. You can download the publicly available Irregular Mask Dataset from their website.
Alternatively, you can download the Quick Draw Irregular Mask Dataset by Karim Iskakov, which is a combination of 50 million strokes drawn by human hand.
We also provide code for dividing the masks into 4 classes according to the proportion of their corrupted region.
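The proportion of the corrupted region is the fraction of mask pixels marked as holes. The sketch below, an assumption rather than the repository's code, computes that ratio and assigns one of 4 classes. The bin edges in `BIN_EDGES` are placeholders because the actual thresholds are not given here, and treating any nonzero pixel as a hole is also assumed.

```python
import bisect

# Hypothetical bin edges on the hole ratio; replace with the real thresholds.
# class 0: < 0.1, class 1: [0.1, 0.2), class 2: [0.2, 0.3), class 3: >= 0.3
BIN_EDGES = [0.1, 0.2, 0.3]


def hole_ratio(mask) -> float:
    """Fraction of pixels marked as holes (assumed: any nonzero value)."""
    total = holes = 0
    for row in mask:
        for value in row:
            total += 1
            if value != 0:
                holes += 1
    return holes / total if total else 0.0


def mask_class(mask, edges=BIN_EDGES) -> int:
    """Map a mask to a class index 0..len(edges) by its hole ratio."""
    return bisect.bisect_right(edges, hole_ratio(mask))
```

A mask can be passed as a list of rows or as a 2D NumPy array, for example one loaded with `numpy.array(PIL.Image.open(path).convert('L'))`.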
Note that our Places2 models were trained on multiple GPUs, so they should also be loaded with multiple GPUs. Specifically, replace the third line of `config.yml` in your checkpoint directory with `GPU: [0,1]`.
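Editing by line number breaks if the config file's layout differs. As an alternative, the sketch below (hypothetical helpers, not part of the repository) rewrites the `GPU` entry by key. It assumes a flat YAML file with one `KEY: value` per line, and any comment on the replaced line is dropped.

```python
import re


def set_config_value(text: str, key: str, value: str) -> str:
    """Replace the value on the first top-level `KEY: ...` line of YAML text.

    Matching by key rather than by line number keeps working if lines move.
    Any trailing comment on that line is dropped. Raises KeyError if absent.
    """
    pattern = re.compile(rf"^({re.escape(key)}[ \t]*:).*$", re.MULTILINE)
    new_text, count = pattern.subn(lambda m: f"{m.group(1)} {value}", text, count=1)
    if count == 0:
        raise KeyError(key)
    return new_text


def set_gpus_in_file(path: str, gpus=(0, 1)) -> None:
    """Set `GPU: [0,1]` (or the given ids) in a config file in place."""
    with open(path) as f:
        text = f.read()
    value = "[" + ",".join(str(g) for g in gpus) + "]"
    with open(path, "w") as f:
        f.write(set_config_value(text, "GPU", value))
```

For example, `set_gpus_in_file('./checkpoints/places2/config.yml')` would set `GPU: [0,1]` in that file.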
To train the model, create a `config.yaml` file similar to the example config file and copy it into your checkpoints directory. Read the configuration guide for more information on model configuration.
To train ISNet, set `MODE: 1` in the corresponding `config.yaml` and run:
```bash
python3 train.py --checkpoints [path to checkpoints]
```
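Preparing a training run amounts to creating the checkpoints directory, placing a config file in it, and making sure `MODE` is 1. A minimal sketch of those steps follows. `prepare_checkpoint`, `read_config_value` and the example config path you pass in are hypothetical, and the default file name `config.yaml` follows the text above, so adjust it if your copy of the code expects another name.

```python
import os
import re
import shutil


def read_config_value(text: str, key: str):
    """Return the raw value of a top-level `KEY: value` line, or None.
    Trailing `# comments` are ignored."""
    match = re.search(rf"^{re.escape(key)}[ \t]*:[ \t]*([^#\n]*)", text, re.MULTILINE)
    return match.group(1).strip() if match else None


def prepare_checkpoint(checkpoint_dir: str, example_config: str,
                       config_name: str = "config.yaml"):
    """Create the checkpoint directory, copy the example config into it if no
    config exists yet, and return the MODE value found in that config."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    config_path = os.path.join(checkpoint_dir, config_name)
    if not os.path.exists(config_path):
        shutil.copyfile(example_config, config_path)
    with open(config_path) as f:
        mode = read_config_value(f.read(), "MODE")
    if mode != "1":
        print(f"warning: MODE is {mode!r} in {config_path}; set MODE: 1 to train")
    return mode
```

For example, call `prepare_checkpoint('./checkpoints/places2', '[path to example config]')`, then run `python3 train.py --checkpoints ./checkpoints/places2`. An existing config is never overwritten, so repeated calls are safe.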
To test the model, create a `config.yaml` file similar to the example config file and copy it into your checkpoints directory. Read the configuration guide for more information on model configuration.
Set `MODE: 2` in the corresponding `config.yaml` to evaluate the test dataset whose file-list path is recorded in `config.yaml`, then run:
```bash
python3 test.py --path ./checkpoints/Celeba
```
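Because evaluation reads images from the file list recorded in `config.yaml` (for example `TEST_FLIST`), a stale or moved list is a common source of errors. The following sketch, a hypothetical helper that assumes one path per line, reports listed entries that do not exist on disk before you start the test run.

```python
import os


def read_flist(flist_path: str) -> list:
    """Read a file list with one path per line, ignoring blank lines."""
    with open(flist_path) as f:
        return [line.strip() for line in f if line.strip()]


def missing_entries(flist_path: str) -> list:
    """Return the listed paths that do not exist on disk."""
    return [p for p in read_flist(flist_path) if not os.path.exists(p)]
```

Running `missing_entries('[path to TEST_FLIST]')` should return an empty list when every listed image is in place.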
To evaluate the test dataset with the progressive reconstruction algorithm, run:
```bash
python3 progressiv_test.py --path ./checkpoints/Celeba
```
To test the model on specific images and masks, provide the input images and binary masks. Make sure each mask has the same resolution as its corresponding image. Then run:
```bash
python3 test.py \
--checkpoints [path to checkpoints] \
--input [path to input directory or file] \
--mask [path to masks directory or mask file] \
--output [path to the output directory]
```
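Since each mask must match its image in resolution, it can be worth checking the inputs first. The sketch below pairs images with masks by file stem and lists pairs whose sizes differ. Pairing by name is an assumption of this sketch, not necessarily how `test.py` pairs files, and `size_mismatches` needs Pillow. If the test script instead pairs files by sorted order, giving each mask the same stem as its image usually keeps both orders aligned.

```python
import os


def pair_by_stem(image_paths, mask_paths) -> list:
    """Pair each image with the mask that has the same file stem
    (e.g. 0001.jpg <-> 0001.png). Raises ValueError for unmatched images."""
    masks = {os.path.splitext(os.path.basename(p))[0]: p for p in mask_paths}
    pairs, unmatched = [], []
    for img in sorted(image_paths):
        stem = os.path.splitext(os.path.basename(img))[0]
        if stem in masks:
            pairs.append((img, masks[stem]))
        else:
            unmatched.append(img)
    if unmatched:
        raise ValueError(f"no mask found for: {unmatched}")
    return pairs


def size_mismatches(pairs) -> list:
    """Return (image, mask) pairs whose resolutions differ (requires Pillow)."""
    from PIL import Image  # imported lazily so pairing works without Pillow
    bad = []
    for img, mask in pairs:
        with Image.open(img) as a, Image.open(mask) as b:
            if a.size != b.size:
                bad.append((img, mask))
    return bad
```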
We provide some test examples in the `./examples` directory. Download the pre-trained models and run:
```bash
python3 test.py \
--checkpoints ./checkpoints/places2 \
--input ./examples/places2/images \
--mask ./examples/places2/masks \
--output ./examples/places2/results
```
This script inpaints all images in `./examples/places2/images` using their corresponding masks in `./examples/places2/masks` and saves the results in the `./examples/places2/results` directory.
The model configuration is stored in a `config.yaml` file under your checkpoints directory. The following tables document all the options available in the configuration file:
Licensed under a Creative Commons Attribution-NonCommercial 4.0 International License.
Except where otherwise noted, this content is published under a CC BY-NC license, which means that you can copy, remix, transform and build upon the content as long as you do not use the material for commercial purposes and give appropriate credit and provide a link to the license.
Much of the code logic and this README are based on Edge-Connect; we sincerely thank its authors for their contribution.
If you use this code for your research, please cite our paper.
Option | Description |
---|---|
MODE | 1: train, 2: test, 3: eval, 4: progressive_inpainting |
MASK | 1: random block, 2: half, 3: external, 4: (external, random block), 5: (external, random block, half), 6: one-to-one image mask |
SEED | random number generator seed |
GPU | comma-separated list of GPU ids, e.g. [0,1] |
DEBUG | 0: no debug, 1: debugging mode |
VERBOSE | 0: no verbose, 1: output detailed statistics in the output console |
Option | Description |
---|---|
SAVEIMG | whether to save images in the test phase (1: save, 0: do not save) |
TRAIN_FLIST | text file containing training set files list |
VAL_FLIST | text file containing validation set files list |
TEST_FLIST | text file containing test set files list |
TRAIN_MASK_FLIST | text file containing training set masks files list (only with MASK=3, 4, 5) |
VAL_MASK_FLIST | text file containing validation set masks files list (only with MASK=3, 4, 5) |
TEST_MASK_FLIST | text file containing test set masks files list (only with MASK=3, 4, 5) |
Option | Default | Description |
---|---|---|
BLOCKS | 4 | number of residual blocks in each stage |
LR | 0.0001 | learning rate |
D2G_LR | 0.1 | discriminator/generator learning rate ratio |
BETA1 | 0.0 | Adam optimizer beta1 |
BETA2 | 0.9 | Adam optimizer beta2 |
BATCH_SIZE | 8 | input batch size |
INPUT_SIZE | 256 | input image size for training. (0 for original size) |
MAX_ITERS | 2e6 | maximum number of iterations to train the model |
MAX_STEPS | 5000 | maximum number of steps per epoch |
MAX_EPOCHES | 100 | maximum number of epochs |
L1_LOSS_WEIGHT | 1 | l1 loss weight |
FM_LOSS_WEIGHT | 10 | feature-matching loss weight |
STYLE_LOSS_WEIGHT | 1 | style loss weight |
CONTENT_LOSS_WEIGHT | 1 | perceptual loss weight |
INPAINT_ADV_LOSS_WEIGHT | 0.01 | adversarial loss weight |
GAN_LOSS | nsgan | nsgan: non-saturating GAN, lsgan: least-squares GAN, hinge: hinge-loss GAN |
GAN_POOL_SIZE | 0 | fake images pool size |
SAVE_INTERVAL | 1000 | how many iterations to wait before saving model (0: never) |
EVAL_INTERVAL | 0 | how many iterations to wait before evaluating the model (0: never) |
SAMPLE_INTERVAL | 1000 | how many iterations to wait before saving sample (0: never) |
SAMPLE_SIZE | 12 | number of images to sample at each sampling interval |
EVAL_INTERVAL | 3 | how many sample intervals to wait between validation runs (0: never; 36000 on Places2) |
LOG_INTERVAL | 10 | how many iterations to wait before logging training status (0: never) |
TEST_INTERVAL | 3657 | how many intervals to wait between test runs |
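The tables above imply a few consistency rules that can be checked before a run. The sketch below is a set of hypothetical helpers, not the repository's config loader. It validates a parsed config dict against the allowed values listed in the tables, flags a missing mask list for the external mask modes (the MODE-to-list mapping is an assumption), and shows the "0: never" interval convention. It assumes the dict comes from something like PyYAML's `yaml.safe_load`; PyYAML reads `2e6` (no decimal point) as a string, so `MAX_ITERS` is converted with `int(float(...))`.

```python
# Allowed values copied from the option tables above.
ALLOWED = {
    "MODE": {1, 2, 3, 4},
    "MASK": {1, 2, 3, 4, 5, 6},
    "GAN_LOSS": {"nsgan", "lsgan", "hinge"},
    "DEBUG": {0, 1},
    "VERBOSE": {0, 1},
    "SAVEIMG": {0, 1},
}

# Assumed mapping from MODE to the mask list it reads when MASK is 3, 4 or 5.
MASK_FLIST_FOR_MODE = {1: "TRAIN_MASK_FLIST", 2: "TEST_MASK_FLIST", 3: "VAL_MASK_FLIST"}


def validate_config(cfg: dict) -> list:
    """Return human-readable problems found in a parsed config dict."""
    problems = []
    for key, allowed in ALLOWED.items():
        if key in cfg and cfg[key] not in allowed:
            problems.append(f"{key}={cfg[key]!r} is not one of {sorted(allowed)}")
    needed = MASK_FLIST_FOR_MODE.get(cfg.get("MODE"))
    if cfg.get("MASK") in (3, 4, 5) and needed and not cfg.get(needed):
        problems.append(f"MASK={cfg['MASK']} requires {needed}")
    return problems


def max_iters(cfg: dict) -> int:
    """MAX_ITERS may arrive as the string '2e6' from PyYAML, so go via float."""
    return int(float(cfg.get("MAX_ITERS", 2e6)))


def should_run(iteration: int, interval: int) -> bool:
    """Periodic action check; an interval of 0 means 'never', as in the tables."""
    return interval > 0 and iteration % interval == 0
```

For example, with `SAVE_INTERVAL: 1000` a training loop would save at iterations 1000, 2000 and so on, and with `EVAL_INTERVAL: 0` it would never evaluate.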