Implementation and extension of Variational Diffusion Models (Kingma++21) in jax and equinox.
A Variational Diffusion Model (VDM) is essentially an infinitely deep hierarchical model with an analytic encoding model for each of the latent variables.
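The encoder for each latent is analytic, so sampling z_t given x needs no network at all. The sketch below is a minimal illustration under stated assumptions, not this repo's code. It uses the variance-preserving parameterisation α_t² = sigmoid(−γ(t)), σ_t² = sigmoid(γ(t)), where γ(t) = −log SNR(t). It also assumes a fixed linear schedule, implemented by the hypothetical helper `gamma_linear`. VDM learns γ with a monotonic network, and the schedule used here may differ.

```python
import jax
import jax.numpy as jnp


def gamma_linear(t, gamma_min=-13.3, gamma_max=5.0):
    # ASSUMPTION: fixed linear schedule gamma(t) = -log SNR(t).
    # VDM learns gamma with a monotonic network instead.
    return gamma_min + (gamma_max - gamma_min) * t


def alpha_sigma(gamma):
    # Variance-preserving parameterisation:
    # alpha^2 = sigmoid(-gamma), sigma^2 = sigmoid(gamma),
    # so alpha^2 + sigma^2 = 1 and alpha^2 / sigma^2 = exp(-gamma) = SNR.
    alpha = jnp.sqrt(jax.nn.sigmoid(-gamma))
    sigma = jnp.sqrt(jax.nn.sigmoid(gamma))
    return alpha, sigma


def sample_q_zt(key, x, t):
    # Analytic encoder q(z_t | x) = N(alpha_t x, sigma_t^2 I).
    alpha, sigma = alpha_sigma(gamma_linear(t))
    eps = jax.random.normal(key, x.shape)
    return alpha * x + sigma * eps, eps


# Usage: a noisy latent at t = 0.5 for a toy 4-dimensional input.
z_t, eps = sample_q_zt(jax.random.PRNGKey(0), jnp.ones((4,)), 0.5)
```

Near t = 0 the schedule gives α_t ≈ 1 (almost no noise), and near t = 1 it gives σ_t ≈ 1 (almost pure noise). This endpoint behaviour is what lets the prior term compare q(z_1 | x) to a standard Gaussian.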
This design shares many similarities with a Variational Autoencoder (VAE). However, a VAE's objective has only a reconstruction loss and a prior KL divergence, whereas the VDM is fit with three loss terms: the consistency (diffusion) loss, the reconstruction loss, and the prior KL divergence.
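For reference, the three terms together bound the negative log-likelihood. The notation follows Kingma++21: z_0 is the least-noisy latent, z_1 is the final, near-Gaussian latent, and T is the number of diffusion steps.

$$
-\log p(\mathbf{x}) \le \underbrace{D_{\mathrm{KL}}\big(q(\mathbf{z}_1 \mid \mathbf{x}) \,\|\, p(\mathbf{z}_1)\big)}_{\text{prior loss}} + \underbrace{\mathbb{E}_{q(\mathbf{z}_0 \mid \mathbf{x})}\big[-\log p(\mathbf{x} \mid \mathbf{z}_0)\big]}_{\text{reconstruction loss}} + \underbrace{\mathcal{L}_T(\mathbf{x})}_{\text{diffusion loss}}
$$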
Here, training uses the continuous-time (infinite-depth) consistency loss, rather than the discrete-time loss used by DDPM-style methods.
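In the continuous-time limit with noise (ε) prediction, the diffusion loss reduces to L_∞(x) = ½ E_{ε∼N(0,I), t∼U(0,1)}[γ'(t) ‖ε − ε̂(z_t; t)‖²], where γ(t) = −log SNR(t). The sketch below gives a single-sample Monte Carlo estimate of that expression. It relies on three assumptions:

- a linear schedule, via the hypothetical helper `gamma_linear`;
- a hypothetical `eps_model(z_t, gamma_t)` noise-prediction network that is conditioned on γ_t rather than on t;
- γ'(t) obtained by autodiff.

It illustrates the loss and is not this repo's training step.

```python
import jax
import jax.numpy as jnp


def gamma_linear(t, gamma_min=-13.3, gamma_max=5.0):
    # ASSUMPTION: linear schedule gamma(t) = -log SNR(t); VDM learns it.
    return gamma_min + (gamma_max - gamma_min) * t


def diffusion_loss(eps_model, x, key, gamma_fn=gamma_linear):
    """One-sample estimate of
    L_inf(x) = 0.5 * E_{t, eps}[gamma'(t) * ||eps - eps_hat(z_t; t)||^2]."""
    t_key, eps_key = jax.random.split(key)
    t = jax.random.uniform(t_key, ())         # t ~ U(0, 1)
    gamma_t = gamma_fn(t)
    gamma_prime = jax.grad(gamma_fn)(t)       # d gamma / dt via autodiff
    alpha = jnp.sqrt(jax.nn.sigmoid(-gamma_t))
    sigma = jnp.sqrt(jax.nn.sigmoid(gamma_t))
    eps = jax.random.normal(eps_key, x.shape)
    z_t = alpha * x + sigma * eps             # sample from q(z_t | x)
    eps_hat = eps_model(z_t, gamma_t)         # hypothetical noise predictor
    return 0.5 * gamma_prime * jnp.sum(jnp.square(eps - eps_hat))


# Usage with a trivial stand-in network that always predicts zero noise.
zero_model = lambda z, g: jnp.zeros_like(z)
loss = diffusion_loss(zero_model, jnp.ones((4,)), jax.random.PRNGKey(0))
```

Sampling t uniformly and weighting by γ'(t) keeps the estimator unbiased for any monotonic schedule. In a real training step, this estimate would be averaged over a batch (for example with `jax.vmap`) and added to the prior and reconstruction terms.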
Extensions in this implementation include:
- conditional likelihood modelling,
- exotic score-network architectures (more to be added),
- multi-device training and inference.
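Conditional likelihood modelling only requires the noise-prediction network to also see the conditioning variable y. The loss is then a bound on −log p(x | y) instead of −log p(x). The sketch below shows one simple way to do this. A label is one-hot encoded and concatenated to the network input of a tiny plain-JAX MLP. The names `init_params` and `conditional_eps_model` are hypothetical, and the architectures in this repo are likely different.

```python
import jax
import jax.numpy as jnp


def init_params(key, x_dim, n_classes, hidden=64):
    # Hypothetical two-layer MLP; input is [z_t, gamma_t, one_hot(y)].
    k1, k2 = jax.random.split(key)
    in_dim = x_dim + 1 + n_classes
    return {
        "w1": jax.random.normal(k1, (in_dim, hidden)) / jnp.sqrt(in_dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden, x_dim)) / jnp.sqrt(hidden),
        "b2": jnp.zeros(x_dim),
    }


def conditional_eps_model(params, z_t, gamma_t, y, n_classes):
    # Condition on the class label by concatenating its one-hot encoding.
    y_onehot = jax.nn.one_hot(y, n_classes)
    h = jnp.concatenate([z_t, jnp.atleast_1d(gamma_t), y_onehot])
    h = jax.nn.silu(h @ params["w1"] + params["b1"])
    return h @ params["w2"] + params["b2"]


# Usage: predict noise for a 4-dimensional latent with class label 3 of 10.
params = init_params(jax.random.PRNGKey(0), x_dim=4, n_classes=10)
eps_hat = conditional_eps_model(params, jnp.zeros(4), jnp.array(-2.0), 3, 10)
```

Concatenation is the simplest conditioning scheme. Larger score networks often inject the condition through embeddings added to intermediate features instead.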
Install with:
pip install variational-diffusion-models
Run with:
python main.py
See the examples for further usage.