
Implementation and extension of 'Variational Diffusion Models' (Kingma++21) in JAX and Equinox.


homerjed/vdm


vdm

Variational Diffusion Models

Implementation and extension of Variational Diffusion Models (Kingma++21) in JAX and Equinox.

Synopsis

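A Variational Diffusion Model (VDM) is essentially an infinitely deep hierarchical latent-variable model: in the continuous-time limit there is a latent z_t for every t in [0, 1], and the encoding model for each one, q(z_t | x) = N(α_t x, σ_t² I), is an analytic Gaussian set by the noise schedule rather than a neural-network encoder.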
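Because each encoder is analytic, drawing a latent needs no network at all. Below is a minimal sketch in JAX, assuming a variance-preserving process and a fixed linear schedule for γ(t) = −log SNR(t) with illustrative endpoints. `GAMMA_MIN`, `GAMMA_MAX`, `gamma`, `alpha_sigma` and `encode` are hypothetical names, not this repository's API.

```python
import jax
import jax.numpy as jnp

# Illustrative endpoints for a fixed linear log-SNR schedule (assumed values).
GAMMA_MIN, GAMMA_MAX = -13.3, 5.0


def gamma(t):
    """gamma(t) = -log SNR(t), increasing linearly from GAMMA_MIN at t=0 to GAMMA_MAX at t=1."""
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def alpha_sigma(t):
    """Variance-preserving coefficients: alpha_t^2 = sigmoid(-gamma), sigma_t^2 = sigmoid(gamma)."""
    g = gamma(t)
    return jnp.sqrt(jax.nn.sigmoid(-g)), jnp.sqrt(jax.nn.sigmoid(g))


def encode(key, x, t):
    """Draw z_t ~ q(z_t | x) = N(alpha_t x, sigma_t^2 I) via reparameterisation."""
    alpha_t, sigma_t = alpha_sigma(t)
    eps = jax.random.normal(key, x.shape)
    return alpha_t * x + sigma_t * eps, eps


# Usage: noise a batch of MNIST-shaped inputs halfway through the process.
key = jax.random.PRNGKey(0)
x = jnp.ones((4, 28, 28, 1))
z_half, eps = encode(key, x, 0.5)
```

Since α_t² + σ_t² = 1 for every t, the marginal variance is preserved while the signal fraction α_t shrinks from about 1 at t = 0 towards 0 at t = 1.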

This design shares many similarities with a Variational Autoencoder (VAE), but whereas a VAE's objective has two terms (reconstruction and KL), a VDM is fit with three: the consistency (diffusion) loss, the reconstruction loss, and the prior KL divergence.
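For reference, the variational bound in Kingma et al. (2021) splits into exactly these three terms. The closed form of the prior term below assumes the standard Gaussian prior p(z_1) = N(0, I) and d-dimensional data x.

```latex
-\log p(x) \le
  \underbrace{D_{\mathrm{KL}}\big(q(z_1 \mid x) \,\|\, p(z_1)\big)}_{\text{prior loss}}
  + \underbrace{\mathbb{E}_{q(z_0 \mid x)}\big[-\log p(x \mid z_0)\big]}_{\text{reconstruction loss}}
  + \underbrace{\mathcal{L}_T(x)}_{\text{diffusion loss}}

D_{\mathrm{KL}}\big(\mathcal{N}(\alpha_1 x, \sigma_1^2 I) \,\|\, \mathcal{N}(0, I)\big)
  = \tfrac{1}{2} \sum_{i=1}^{d} \big(\sigma_1^2 + \alpha_1^2 x_i^2 - 1 - \log \sigma_1^2\big)
```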

Here, training uses the continuous-time (infinite-depth) consistency loss rather than the discrete-time objective of DDPM methods, which corresponds to a discretised SDE.
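In continuous time the diffusion loss reduces to a γ′(t)-weighted noise-prediction error, L_∞(x) = ½ E[γ′(t) ‖ε − ε̂(z_t; t)‖²] with t ~ U(0, 1) and ε ~ N(0, I). The following is a minimal single-sample Monte Carlo sketch in JAX. It assumes a fixed linear γ schedule with illustrative endpoints and a hypothetical noise-prediction network `eps_model(z_t, gamma_t)`. `diffusion_loss` and `zero_model` are likewise hypothetical names, not this repository's API.

```python
import jax
import jax.numpy as jnp

# Illustrative fixed linear schedule for gamma(t) = -log SNR(t); endpoints are assumptions.
GAMMA_MIN, GAMMA_MAX = -13.3, 5.0


def gamma(t):
    return GAMMA_MIN + (GAMMA_MAX - GAMMA_MIN) * t


def diffusion_loss(eps_model, key, x):
    """One-sample estimate of 0.5 * E_{t, eps}[gamma'(t) * ||eps - eps_hat(z_t; t)||^2]."""
    t_key, eps_key = jax.random.split(key)
    t = jax.random.uniform(t_key, ())                        # t ~ U(0, 1)
    g, g_prime = jax.jvp(gamma, (t,), (jnp.ones_like(t),))   # gamma(t) and d gamma / dt
    alpha_t = jnp.sqrt(jax.nn.sigmoid(-g))
    sigma_t = jnp.sqrt(jax.nn.sigmoid(g))
    eps = jax.random.normal(eps_key, x.shape)
    z_t = alpha_t * x + sigma_t * eps                        # z_t ~ q(z_t | x)
    eps_hat = eps_model(z_t, g)                              # hypothetical network, conditioned on gamma_t
    return 0.5 * g_prime * jnp.sum((eps - eps_hat) ** 2)


def zero_model(z_t, g):
    """Trivial stand-in for a score network: always predicts zero noise."""
    return jnp.zeros_like(z_t)


# Usage: per-example losses for a batch of 8 flat 32-dim inputs, one PRNG key each.
keys = jax.random.split(jax.random.PRNGKey(0), 8)
x = jnp.zeros((8, 32))
losses = jax.vmap(lambda k, xi: diffusion_loss(zero_model, k, xi))(keys, x)
```

Averaging over a batch (and over many draws of t) gives an unbiased estimate of L_∞. The VDM paper also learns γ(t) rather than fixing it, which this sketch omits.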

Features

  • Conditional likelihood modelling,
  • exotic score-network architectures (more to be added),
  • multi-device training and inference.

Usage

pip install variational-diffusion-models 
python main.py

See examples.

[Figure]

CIFAR10

[Figure]

MNIST

[Figure]
