Code for the paper "Optimizing Hierarchical Image VAEs for Sample Quality". Hierarchical VAEs extend regular VAEs by using a sequence of learned normal distributions for both the prior and the posterior. Notable examples include NVAE and Very Deep VAE (VDVAE). We propose changes to these hierarchical VAEs that help them generate better-looking samples, namely:
- controlling how much information is added in each latent variable layer
- using a continuous Gaussian KL loss instead of a discrete (mixture of logistic distributions) loss (a minimal sketch of this KL term is given after this list)
- using a guided sampling strategy similar to classifier-free guidance in diffusion models
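For reference, the sketch below shows the closed-form KL divergence between two diagonal Gaussians, which is the per-layer quantity a continuous Gaussian KL loss sums over the latent hierarchy (measured in nats, it is also one way to quantify how much information each latent layer adds). The helper name and tensor shapes are illustrative and not taken from this codebase.

```python
import torch

def gaussian_kl(q_mean, q_logvar, p_mean, p_logvar):
    """Elementwise KL( N(q_mean, exp(q_logvar)) || N(p_mean, exp(p_logvar)) ).

    Illustrative helper (not code from this repository): the standard
    closed-form KL between diagonal Gaussians that a hierarchical VAE
    accumulates for every latent layer when using a continuous Gaussian loss.
    """
    return 0.5 * (
        p_logvar - q_logvar
        + (q_logvar.exp() + (q_mean - p_mean) ** 2) / p_logvar.exp()
        - 1.0
    )

# Toy example: one latent layer with 16 channels on a 4x4 spatial grid.
q_mu, q_lv = torch.randn(1, 16, 4, 4), torch.zeros(1, 16, 4, 4)
p_mu, p_lv = torch.zeros(1, 16, 4, 4), torch.zeros(1, 16, 4, 4)
kl_this_layer = gaussian_kl(q_mu, q_lv, p_mu, p_lv).sum()  # in nats
```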
This release includes pretrained models for CIFAR-10, ImageNet $32^2$, and ImageNet $32^2 \rightarrow 64^2$ (see the Pretrained Models section below).
First, clone our repository and change directory into it. Install the requirements via:
pip install -r requirements.txt
Then install the package in editable mode:
pip install -e .
To sample from our pretrained models, you should first download them using the links from the Pretrained Models section below. In these examples, we assume you've downloaded the relevant model files into the directory "./models".
To create an 8x8 grid of CIFAR-10 samples, run:
python scripts/evaluate.py --config "config/cifar10.py" --save_dir "./results" --checkpoint_path "./models/cifar10_ema_weights.pt" --n_samples 64 --nrow 8
To create a 6x6 grid of ImageNet $64^2$ samples (sampling the $32^2$ base model and upsampling with the $32^2 \rightarrow 64^2$ super-resolution model), run:
python scripts/evaluate.py --config "config/imagenet32.py" --save_dir "./results" --checkpoint_path "./models/i32_ema_weights.pt" --superres_config "config/imagenet64.py" --superres_checkpoint_path "./models/i64_ema_weights.pt" --n_samples 36 --nrow 6 --mean_scale 1.5 --var_scale 5.0
To perform unguided sampling (ImageNet only), set --mean_scale 0.0 and --var_scale 0.0. If not specified, the default guidance values are 2.0 and 4.0 respectively; a sketch of how such guidance scales can act on the latent priors is given below. For unconditional sampling, set --label 1000. Alternatively, to generate images from a specific class, set --label $LABEL_NUM (a list of ImageNet class numbers can be found online).
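For intuition only, here is a hypothetical sketch of how separate mean and variance scales could guide a latent layer's prior, by direct analogy with classifier-free guidance. The function below is an assumption made for illustration and may not match the exact rule implemented in scripts/evaluate.py or described in the paper.

```python
import torch

def guided_prior(cond_mean, cond_logvar, uncond_mean, uncond_logvar,
                 mean_scale=2.0, var_scale=4.0):
    # Hypothetical CFG-style extrapolation from the unconditional toward the
    # conditional prediction, applied separately to the mean and log-variance
    # of one latent layer's prior. With both scales at 0.0 the conditional
    # prediction is returned unchanged, i.e. unguided sampling as above.
    guided_mean = cond_mean + mean_scale * (cond_mean - uncond_mean)
    guided_logvar = cond_logvar + var_scale * (cond_logvar - uncond_logvar)
    return guided_mean, guided_logvar

# Toy usage mirroring the guidance scales from the ImageNet example above:
c_mu, c_lv = torch.randn(1, 16, 8, 8), torch.zeros(1, 16, 8, 8)
u_mu, u_lv = torch.randn(1, 16, 8, 8), torch.zeros(1, 16, 8, 8)
g_mu, g_lv = guided_prior(c_mu, c_lv, u_mu, u_lv, mean_scale=1.5, var_scale=5.0)
```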
To generate ImageNet $64^2$ samples with super-resolution, also pass the --superres_config and --superres_checkpoint_path arguments.
To create a .npz file of samples instead of an image grid, e.g. for FID evaluation, add the argument --save_format "npz".
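As a quick sanity check, the saved .npz can be inspected with NumPy. The file name below is a placeholder (the actual name written under ./results depends on the script), and the array key is read from the archive rather than assumed:

```python
import numpy as np

# Placeholder path: substitute the actual .npz written under ./results.
data = np.load("./results/samples.npz")
key = data.files[0]                       # first (typically only) array in the archive
samples = data[key]
print(key, samples.shape, samples.dtype)  # e.g. (n_samples, H, W, 3) uint8
```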
If you trained your own model with a different config, remember to set the correct model config via --config "config/my_new_config.py".
The instructions above are for sampling with PyTorch. Sampling with the JAX models is essentially the same, except:
- use scripts/evaluate_jax.py instead of scripts/evaluate.py
- use the JAX checkpoints instead of the PyTorch ones (e.g. cifar10_ema_weights_jax.p instead of cifar10_ema_weights.pt)
Note: you will need to install JAX and Flax, e.g. via pip install "jax>=0.3.0" flax (quote the version specifier so the shell does not treat >= as a redirect).
The training configuration files are located in the config folder. Each config is divided into four parts: model architecture hyperparameters, dataset information, training details, and optimizer settings. We encourage you to look at the existing config files for more information, or if you want to change certain hyperparameters.
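As a purely hypothetical orientation aid, a four-part config of this kind might be organized as below; every field name here is an illustrative assumption, so consult the actual files in the config folder for the real keys and values.

```python
# Hypothetical sketch only -- field names are illustrative, not the
# repository's actual config keys. See config/*.py for the real structure.

def get_config():
    return {
        # 1) Model architecture hyperparameters
        "model": {"num_latent_layers": ..., "base_channels": ...},
        # 2) Dataset information
        "dataset": {"name": "cifar10", "image_size": 32},
        # 3) Training details
        "train": {"batch_size": ..., "num_steps": ..., "ema_decay": ...},
        # 4) Optimizer settings
        "optimizer": {"name": ..., "learning_rate": ..., "weight_decay": ...},
    }
```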
Training in PyTorch (CIFAR-10 dataset):
python scripts/train.py --config config/cifar10.py --global_dir "./training_results"
This will save the training checkpoints and logs to the folder ./training_results.
We include model checkpoints for our CIFAR-10, ImageNet $32^2$, and ImageNet $32^2 \rightarrow 64^2$ models.
PyTorch checkpoints:
- CIFAR-10: cifar10_ema_weights.pt
- ImageNet $32^2$: i32_ema_weights.pt
- ImageNet $32^2 \rightarrow 64^2$: i64_ema_weights.pt
JAX checkpoints:
- CIFAR-10: cifar10_ema_weights_jax.p
- ImageNet $32^2$: i32_ema_weights_jax.p
- ImageNet $32^2 \rightarrow 64^2$: i64_ema_weights_jax.p
- This research was supported by Google's TPU Research Cloud (TRC) Program, which provided Google Cloud TPUs for training the models.
- Portions of our codebase were adapted from the Efficient-VDVAE, Progressive Distillation, and Guided Diffusion repositories - thanks for open-sourcing!
If you found this repository useful to your research, please consider citing our paper:
@misc{luhman2022optimizing,
title={Optimizing Hierarchical Image VAEs for Sample Quality},
author={Eric Luhman and Troy Luhman},
year={2022},
eprint={2210.10205},
archivePrefix={arXiv},
primaryClass={cs.LG}
}