nlsfnr/MiniVAE

A minimal JAX implementation of a Variational Auto-Encoder.

MiniVAE

A Variational Auto-Encoder (VAE) implemented in JAX.

Quick tour

The main files are

Details

MiniVAE was written mainly as a way to learn about VAEs, so the model is deliberately simple: a stack of convolutional and transposed-convolutional ("deconvolutional") layers.
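
A minimal sketch of such a conv/deconv stack in plain JAX (jax.lax), assuming 28x28 single-channel inputs, two stride-2 layers on each side, ReLU activations, an 8-dimensional latent space and no biases. These sizes and the names init_params, encode and decode are hypothetical choices for illustration, not MiniVAE's actual code. The encoder outputs a mean and a log-variance for the Gaussian posterior; the decoder mirrors it with transposed convolutions.

```python
import jax
import jax.numpy as jnp

# Layout: (batch, height, width, channels) for activations,
# (height, width, in_channels, out_channels) for kernels.
DN = ("NHWC", "HWIO", "NHWC")


def init_params(key, latent_dim=8):
    """Hypothetical layer sizes for 28x28x1 images (not taken from MiniVAE).

    Biases are omitted to keep the sketch short.
    """
    ks = jax.random.split(key, 6)

    def w(k, shape):
        return jax.random.normal(k, shape) * 0.05

    return {
        "enc_conv1": w(ks[0], (3, 3, 1, 16)),     # 28x28x1  -> 14x14x16
        "enc_conv2": w(ks[1], (3, 3, 16, 32)),    # 14x14x16 -> 7x7x32
        "enc_dense": w(ks[2], (7 * 7 * 32, 2 * latent_dim)),  # -> mean, log-var
        "dec_dense": w(ks[3], (latent_dim, 7 * 7 * 32)),
        "dec_deconv1": w(ks[4], (3, 3, 32, 16)),  # 7x7x32   -> 14x14x16
        "dec_deconv2": w(ks[5], (3, 3, 16, 1)),   # 14x14x16 -> 28x28x1
    }


def encode(params, x):
    """Map images to the mean and log-variance of a Gaussian posterior."""
    conv = jax.lax.conv_general_dilated
    h = jax.nn.relu(conv(x, params["enc_conv1"], (2, 2), "SAME", dimension_numbers=DN))
    h = jax.nn.relu(conv(h, params["enc_conv2"], (2, 2), "SAME", dimension_numbers=DN))
    h = h.reshape(h.shape[0], -1) @ params["enc_dense"]
    mean, log_var = jnp.split(h, 2, axis=-1)
    return mean, log_var


def decode(params, z):
    """Map latent vectors back to image space with transposed convolutions."""
    deconv = jax.lax.conv_transpose
    h = jax.nn.relu(z @ params["dec_dense"]).reshape(-1, 7, 7, 32)
    h = jax.nn.relu(deconv(h, params["dec_deconv1"], (2, 2), "SAME", dimension_numbers=DN))
    # No output activation: the reconstruction is compared to the input with MSE.
    return deconv(h, params["dec_deconv2"], (2, 2), "SAME", dimension_numbers=DN)


params = init_params(jax.random.PRNGKey(0))
x = jnp.zeros((4, 28, 28, 1))
mean, log_var = encode(params, x)
recon = decode(params, mean)
```

Stride-2 convolutions with "SAME" padding halve the spatial size (28 to 14 to 7), and stride-2 transposed convolutions with "SAME" padding double it back, so the decoder output has the same shape as the input.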

The posterior is approximated with a Gaussian and penalised by its KL divergence from the prior; the reconstruction loss is simply the mean squared error (MSE).
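
The loss described above can be sketched as follows, assuming a diagonal Gaussian posterior parameterised by a mean and a log-variance, and a standard normal prior N(0, I). Under those assumptions the KL term has the closed form -1/2 * sum(1 + log(sigma^2) - mu^2 - sigma^2). Samples are drawn with the reparameterisation trick so that gradients flow through the mean and variance. The function names and the kl_weight parameter are hypothetical; whether MiniVAE sums or averages the MSE over pixels is not stated, and that choice changes the balance between the two terms.

```python
import jax
import jax.numpy as jnp


def reparameterize(key, mean, log_var):
    """Sample z ~ N(mean, diag(exp(log_var))) as z = mean + sigma * eps."""
    eps = jax.random.normal(key, mean.shape)
    return mean + jnp.exp(0.5 * log_var) * eps


def kl_divergence(mean, log_var):
    """Closed-form KL(N(mean, diag(exp(log_var))) || N(0, I)), summed over latent dims.

    Assumes a standard normal prior (an assumption, not confirmed by MiniVAE).
    """
    return -0.5 * jnp.sum(1.0 + log_var - mean**2 - jnp.exp(log_var), axis=-1)


def vae_loss(x, x_recon, mean, log_var, kl_weight=1.0):
    """Per-example MSE (averaged over pixels) plus a weighted KL penalty, averaged over the batch.

    kl_weight is a hypothetical knob; 1.0 gives the plain VAE objective.
    """
    mse = jnp.mean((x - x_recon) ** 2, axis=tuple(range(1, x.ndim)))
    kl = kl_divergence(mean, log_var)
    return jnp.mean(mse + kl_weight * kl)


# Usage with dummy values: a posterior equal to the prior has zero KL.
mean = jnp.zeros((3, 4))
log_var = jnp.zeros((3, 4))
z = reparameterize(jax.random.PRNGKey(0), mean, log_var)
x = jnp.ones((3, 28, 28, 1))
loss = vae_loss(x, x, mean, log_var)
```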

During training, a command-line flag enables logging to a CSV file, and the plot.py script can then display the loss curve. Both minivae/training.py and minivae/inference.py are executable, for training and inference respectively.