Official code release for Consistency Regularization for VAEs, NeurIPS 2021.
If you have any questions, please email samarth.sinha@mail.utoronto.ca. GitHub issues are not checked often and may be missed.
To install from source:
```bash
git clone https://github.com/sinhasam/CRVAE.git
cd CRVAE
pip3 install -e .
```
Basic usage of the CR-VAE API, which can be added to any VAE variant and training loop:
```python
from CRVAE import CRVAE

# ... data loading (`images` and their `augmented_images`)

# gamma, beta_1, beta_2: your chosen hyperparameter values
crvae = CRVAE(gamma=gamma, beta_1=beta_1, beta_2=beta_2)
loss, logs = crvae.calculate_loss(model, images, augmented_images)
loss.backward()
# ... optimizer step
```
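The exact loss returned by `calculate_loss` is defined in the library. Conceptually, the consistency term penalizes the encoder when it maps an image and its augmented copy to different posteriors. The NumPy sketch below illustrates that idea only and is not the library's implementation: it assumes diagonal Gaussian encoders parameterized by a mean and log-variance, and it measures consistency as KL(q(z | augmented) || q(z | original)). Treat the KL direction, the batch averaging, and the hypothetical `consistency_penalty` helper and its `weight` parameter as assumptions.

```python
import numpy as np


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(N(mu_q, var_q) || N(mu_p, var_p)) for diagonal Gaussians,
    summed over the latent dimensions (last axis), one value per example."""
    var_q = np.exp(logvar_q)
    var_p = np.exp(logvar_p)
    kl = 0.5 * (logvar_p - logvar_q + (var_q + (mu_q - mu_p) ** 2) / var_p - 1.0)
    return kl.sum(axis=-1)


def consistency_penalty(mu, logvar, mu_aug, logvar_aug, weight=1.0):
    """Hypothetical helper (assumption, not the CRVAE API): weighted KL from the
    augmented-image posterior to the original-image posterior, batch-averaged."""
    kl = gaussian_kl(mu_aug, logvar_aug, mu, logvar)
    return weight * kl.mean()


# Toy batch: 2 examples, 3 latent dimensions; the augmented means are shifted by 1.
mu = np.zeros((2, 3))
logvar = np.zeros((2, 3))
mu_aug = np.ones((2, 3))
logvar_aug = np.zeros((2, 3))
print(consistency_penalty(mu, logvar, mu_aug, logvar_aug))  # 1.5
```

The penalty is zero when the two posteriors coincide and grows as the encodings of the original and augmented images drift apart; in a loss that is minimized, a term like this would be added to the negative ELBO.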
To use the default hyperparameters, construct `CRVAE()` with no arguments:
```python
from CRVAE import CRVAE

# ... data loading

loss, logs = CRVAE().calculate_loss(model, images, augmented_images)
loss.backward()
# ... optimizer step
```
Two simple VAE architectures are implemented, and they can be easily extended. To use them:
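```python
from CRVAE.models import CNNVAE, MLPVAE

cnn_model = CNNVAE(in_channels=3)  # convolutional VAE for 3-channel (e.g. RGB) images
mlp_model = MLPVAE(latent_dim=32)  # fully connected VAE with a 32-dimensional latent space
```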
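A Gaussian VAE typically has an encoder that outputs the mean and log-variance of q(z | x), draws a latent sample with the reparameterization trick, and decodes that sample back to image space. Whether `CNNVAE` and `MLPVAE` follow exactly this layout is an assumption; check their source before extending them. The NumPy sketch below shows only the sampling step, and `reparameterize` is a hypothetical name, not a function from this package.

```python
import numpy as np


def reparameterize(mu, logvar, rng):
    """Draw z ~ N(mu, exp(logvar)) as z = mu + sigma * eps with eps ~ N(0, I).
    Hypothetical helper, not part of CRVAE.models."""
    std = np.exp(0.5 * logvar)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps


rng = np.random.default_rng(0)
mu = np.full((4, 32), 2.0)        # batch of 4, latent_dim=32 as in MLPVAE above
logvar = np.full((4, 32), -20.0)  # tiny variance, so samples stay very close to mu
z = reparameterize(mu, logvar, rng)
print(z.shape)  # (4, 32)
```

Drawing the noise `eps` separately keeps the randomness outside `mu` and `logvar`, which is what lets an autodiff framework such as PyTorch propagate gradients back to the encoder.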
A few image augmentation policies are implemented and can be used as follows:
```python
from CRVAE.augmentations import get_augmentation

simple_augmentation = get_augmentation('simple')
large_augmentation_normalize = get_augmentation('large', normalize=True)
large_color_jitter_augmentation = get_augmentation('large_jitter', normalize=True)
# vertical flip may not suit every dataset, since it assumes the data is top-bottom symmetric
large_vertical_flip = get_augmentation('large_vertical_flip', normalize=True)
```
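The caveat about vertical flips is easy to see on a raw array. The NumPy sketch below is an illustration, not the code behind `get_augmentation`: it assumes images are stored as `(channels, height, width)` arrays, and `vertical_flip` and `horizontal_flip` are hypothetical helpers. A vertical flip reverses the row order, so the top of the image becomes the bottom; that only preserves meaning for data that still looks plausible upside down.

```python
import numpy as np


def vertical_flip(image):
    """Flip a (channels, height, width) image top-to-bottom (hypothetical helper)."""
    return image[:, ::-1, :].copy()


def horizontal_flip(image):
    """Flip a (channels, height, width) image left-to-right (hypothetical helper)."""
    return image[:, :, ::-1].copy()


# Toy 1-channel 3x3 "image" whose pixel values encode their row index.
image = np.repeat(np.arange(3).reshape(1, 3, 1), 3, axis=2)
print(vertical_flip(image)[0])
# [[2 2 2]
#  [1 1 1]
#  [0 0 0]]
print(np.array_equal(horizontal_flip(image), image))  # True: each row is constant
```

Here the horizontal flip leaves the image unchanged because it is left-right symmetric, while the vertical flip changes it; the same asymmetry is what makes vertical flips risky on datasets whose content has a natural "up".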
...