This is the official implementation of ACFlow.
Install the dependencies listed in requirements.txt.
Download CelebA, CIFAR10, MNIST and Omniglot to your local workspace. You may need to change the path for each dataset in the datasets folder accordingly.
MNIST and CIFAR10 can be downloaded with torchvision. Links for CelebA and Omniglot are provided here. Please cite their work if you use this repo.
You can train your own model with the scripts provided below, or download our pretrained weights from here.
- Train with Gaussian base likelihood
python scripts/train.py --cfg_file=./exp/celeba/rnvp/params.json
- Train with autoregressive likelihood
python scripts/train_tan.py --cfg_file=./exp/celeba/tan/params.json
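Both training commands, like the evaluation and sampling scripts that follow, take their settings from a single params.json file passed through --cfg_file. A minimal sketch of how such a flag can be read with argparse and json, assuming the file is one JSON object; the function name load_config is hypothetical and this is not the repository's own loader.

```python
import argparse
import json


def load_config(argv=None):
    """Parse --cfg_file from argv (defaults to sys.argv[1:]) and load the JSON file.

    Hypothetical helper for illustration; the real scripts may post-process
    the parsed values differently.
    """
    parser = argparse.ArgumentParser()
    # argparse accepts both "--cfg_file=path" and "--cfg_file path".
    parser.add_argument("--cfg_file", required=True,
                        help="path to an experiment params.json")
    args = parser.parse_args(argv)
    with open(args.cfg_file) as f:
        return json.load(f)
```

For example, load_config(["--cfg_file=./exp/celeba/rnvp/params.json"]) would return that file's contents as a dict; which keys it contains depends on the experiment and is not described here.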
- Compute the log-likelihood on the test set, and compute PSNR and PRD scores from samples.
python scripts/test.py --cfg_file=./exp/celeba/rnvp/params.json
NOTE: you can run this script multiple times with different random seeds to obtain the mean score and standard deviation.
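test.py reports PSNR among its scores, and the note above suggests averaging over several seeds. A minimal NumPy sketch of both steps, assuming pixel values on a 0-255 scale (max_val) and an optional boolean mask selecting which entries to score; the helper names psnr and summarize are hypothetical, and whether test.py scores the whole image or only the imputed pixels is not stated here.

```python
import numpy as np


def psnr(target, imputed, max_val=255.0, mask=None):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_val**2 / MSE).

    mask (hypothetical option): boolean array selecting the entries to score,
    e.g. only the imputed pixels. When None, every entry is scored.
    """
    target = np.asarray(target, dtype=np.float64)
    imputed = np.asarray(imputed, dtype=np.float64)
    sq_err = (target - imputed) ** 2
    if mask is not None:
        sq_err = sq_err[np.asarray(mask, dtype=bool)]
    mse = sq_err.mean()
    if mse == 0:
        return float("inf")  # identical inputs
    return float(10.0 * np.log10(max_val ** 2 / mse))


def summarize(scores):
    """Mean and sample standard deviation (ddof=1) of per-seed scores."""
    scores = np.asarray(scores, dtype=np.float64)
    return float(scores.mean()), float(scores.std(ddof=1))
```

In practice you would collect one score per seeded run and pass the list to summarize. The sketch reports the sample standard deviation; use ddof=0 instead if a population estimate is wanted.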
- Compute the joint likelihood p(x).
python scripts/test_joint.py --cfg_file=./exp/celeba/rnvp/params.json
- Sample from an arbitrary conditional distribution p(x_u | x_o) for multiple imputation.
python scripts/sample.py --cfg_file=./exp/celeba/rnvp/params.json
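Conditioning on an arbitrary observed subset can be expressed with a binary mask b (1 = observed, 0 = unobserved) that splits x into x_o and x_u. The sketch below shows that split and a multiple-imputation loop. The mask convention is an assumption, and sample_x_u is a hypothetical stand-in for the trained model's conditional sampler, not an API of this repo.

```python
import numpy as np


def split_observed(x, b):
    """Return (x_o, x_u): entries where b == 1 and where b == 0, in row-major order."""
    x = np.asarray(x)
    b = np.asarray(b, dtype=bool)
    return x[b], x[~b]


def merge(x_o, x_u, b):
    """Inverse of split_observed: scatter x_o and x_u back into b's shape."""
    b = np.asarray(b, dtype=bool)
    out = np.empty(b.shape, dtype=np.result_type(x_o, x_u))
    out[b] = x_o
    out[~b] = x_u
    return out


def multiple_imputation(x, b, sample_x_u, n_samples, rng):
    """Draw n_samples completions of x that all keep the observed entries fixed.

    sample_x_u(x_o, b, n_unobserved, rng) is a hypothetical stand-in for the
    model's sampler for p(x_u | x_o); it must return n_unobserved values.
    """
    x_o, _ = split_observed(x, b)
    n_unobserved = int((~np.asarray(b, dtype=bool)).sum())
    completions = [
        merge(x_o, sample_x_u(x_o, b, n_unobserved, rng), b)
        for _ in range(n_samples)
    ]
    return np.stack(completions)
```

Each slice along the first axis of the result is one completion; the spread across slices reflects the model's uncertainty about x_u given x_o.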
- Sample the 'Best Guess' single imputation.
python scripts/sample_single.py --cfg_file=./exp/celeba/rnvp/params.json
- Sample from the joint distribution p(x).
python scripts/sample_joint.py --cfg_file=./exp/celeba/rnvp/params.json
- Gibbs sampling
python scripts/gibbs_sampling.py --cfg_file=./exp/celeba/rnvp/params.json
Sample the upper and lower halves, each conditioned on the remaining half.
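One way to read the Gibbs sampling description is block Gibbs sampling over two blocks: resample the upper half given the lower half, then the lower half given the new upper half, and repeat. A sketch of that loop for a single-channel (H, W) image, using a 1 = observed mask convention (an assumption). sample_conditional is a hypothetical placeholder for the model's p(x_u | x_o) sampler, and how gibbs_sampling.py initialises x or how many sweeps it runs is not specified here.

```python
import numpy as np


def half_masks(height, width):
    """Masks (1 = observed) for the two Gibbs blocks of an (H, W) image.

    upper_unobserved observes the lower half and resamples the upper half;
    lower_unobserved observes the upper half and resamples the lower half.
    """
    is_upper = np.arange(height)[:, None] < height // 2
    is_upper = np.broadcast_to(is_upper, (height, width))
    upper_unobserved = (~is_upper).astype(np.uint8)
    lower_unobserved = is_upper.astype(np.uint8)
    return upper_unobserved, lower_unobserved


def gibbs_halves(x_init, sample_conditional, n_steps, rng):
    """Alternate between resampling the upper and the lower half.

    sample_conditional(x, b, rng) is a hypothetical placeholder for the
    model's p(x_u | x_o) sampler: it must return an array shaped like x in
    which only the entries with b == 0 have been redrawn.
    """
    x = np.array(x_init, dtype=np.float64, copy=True)
    masks = half_masks(*x.shape)
    trajectory = []
    for step in range(n_steps):
        b = masks[step % 2]  # even steps: upper half, odd steps: lower half
        x = sample_conditional(x, b, rng)
        trajectory.append(x.copy())
    return trajectory
```

With a real conditional sampler, each entry of trajectory is one state of the chain, so consecutive states differ in only one half of the image.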
For MNIST, similar commands can be run. Config files are provided in the exp/mnist folder.
For Omniglot, similar commands can be run. Config files are provided in the exp/omniglot folder.
For CIFAR10, similar commands can be run. Config files are provided in the exp/cifar folder.
Code for evaluating FID and PRD is adapted from their public implementations. Please cite their work if you use this repo.
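For reference, FID compares Gaussians fitted to feature embeddings of two image sets: FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). The sketch below evaluates that formula with NumPy only, assuming the features (Inception-v3 activations in the standard definition) have already been extracted. It is not the adapted evaluation code in this repository; in particular, it obtains the trace of the matrix square root from a symmetric eigendecomposition instead of a general matrix square root.

```python
import numpy as np


def _sqrtm_psd(a):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(a)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, float), np.asarray(mu2, float)
    sigma1, sigma2 = np.asarray(sigma1, float), np.asarray(sigma2, float)
    diff = mu1 - mu2
    # sigma1 @ sigma2 has the same eigenvalues as s1 @ sigma2 @ s1 with
    # s1 = sigma1^(1/2), which is symmetric PSD, so the trace of the square
    # root is the sum of the square roots of those eigenvalues.
    s1 = _sqrtm_psd(sigma1)
    inner = s1 @ sigma2 @ s1
    inner = (inner + inner.T) / 2.0  # symmetrize against round-off
    eig = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
    tr_covmean = np.sqrt(eig).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)


def fid_from_features(feats1, feats2):
    """FID from two (N, D) feature matrices."""
    mu1, mu2 = feats1.mean(axis=0), feats2.mean(axis=0)
    sigma1 = np.cov(feats1, rowvar=False)
    sigma2 = np.cov(feats2, rowvar=False)
    return frechet_distance(mu1, sigma1, mu2, sigma2)
```

Identical feature sets give a distance of (numerically) zero, and shifting the mean of an identity-covariance Gaussian by d adds exactly ||d||^2.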