MiniVAE

A Variational Auto-Encoder (VAE) implemented in JAX.

Quick tour

The main files are

Details

The main goal of MiniVAE was to learn about VAEs, so the resulting model is deliberately simple: a stack of convolutional layers for the encoder and of deconvolutional (transposed convolutional) layers for the decoder.
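As a rough illustration of such a stack, here is a minimal JAX sketch of a convolutional encoder and a transposed-convolutional decoder. The image size (28x28, one channel), the layer widths, the 2-D latent space and the names `init_params`, `encode`, `decode`, `conv` and `deconv` are hypothetical choices for this sketch, not MiniVAE's actual architecture. Biases are omitted for brevity.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for this sketch: 28x28 single-channel images, 2-D latent space.
LATENT_DIM = 2
DN = ("NHWC", "HWIO", "NHWC")  # layouts for input, kernel and output


def init_params(key):
    """Randomly initialise kernels for a small conv encoder / deconv decoder (no biases)."""
    k = jax.random.split(key, 7)
    s = 0.1  # small init scale, chosen arbitrarily
    return {
        "enc1": s * jax.random.normal(k[0], (3, 3, 1, 8)),    # 28x28x1 -> 14x14x8
        "enc2": s * jax.random.normal(k[1], (3, 3, 8, 16)),   # 14x14x8 -> 7x7x16
        "mu": s * jax.random.normal(k[2], (7 * 7 * 16, LATENT_DIM)),
        "logvar": s * jax.random.normal(k[3], (7 * 7 * 16, LATENT_DIM)),
        "dec_in": s * jax.random.normal(k[4], (LATENT_DIM, 7 * 7 * 16)),
        "dec1": s * jax.random.normal(k[5], (3, 3, 16, 8)),   # 7x7x16 -> 14x14x8
        "dec2": s * jax.random.normal(k[6], (3, 3, 8, 1)),    # 14x14x8 -> 28x28x1
    }


def conv(x, w):
    """Stride-2 convolution that halves height and width."""
    return jax.lax.conv_general_dilated(x, w, (2, 2), "SAME", dimension_numbers=DN)


def deconv(x, w):
    """Stride-2 transposed convolution that doubles height and width."""
    return jax.lax.conv_transpose(x, w, (2, 2), "SAME", dimension_numbers=DN)


def encode(params, x):
    """Map images (N, 28, 28, 1) to the Gaussian mean and log-variance, each (N, LATENT_DIM)."""
    h = jax.nn.relu(conv(x, params["enc1"]))
    h = jax.nn.relu(conv(h, params["enc2"]))
    h = h.reshape(h.shape[0], -1)
    return h @ params["mu"], h @ params["logvar"]


def decode(params, z):
    """Map latent codes (N, LATENT_DIM) back to images (N, 28, 28, 1)."""
    h = jax.nn.relu(z @ params["dec_in"]).reshape(-1, 7, 7, 16)
    h = jax.nn.relu(deconv(h, params["dec1"]))
    return deconv(h, params["dec2"])


params = init_params(jax.random.PRNGKey(0))
x = jnp.zeros((4, 28, 28, 1))
mu, logvar = encode(params, x)
print(mu.shape, decode(params, mu).shape)  # (4, 2) (4, 28, 28, 1)
```

Each stride-2 convolution halves the spatial size and each stride-2 transposed convolution doubles it, so the decoder mirrors the encoder layer for layer.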

The approximate posterior is a Gaussian whose divergence from the prior is penalised with a KL-divergence term; the reconstruction loss is simply the mean squared error (MSE).
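A minimal sketch of this loss in JAX, assuming a diagonal Gaussian posterior parameterised by a mean and a log-variance, and a standard normal prior N(0, I) (the usual choice, but an assumption here). The function names and the equal weighting of the MSE and KL terms are hypothetical; MiniVAE's own code may scale the two terms differently.

```python
import jax
import jax.numpy as jnp


def reparameterise(key, mu, logvar):
    """Draw z ~ N(mu, diag(exp(logvar))) as a differentiable function of mu and logvar."""
    eps = jax.random.normal(key, mu.shape)
    return mu + jnp.exp(0.5 * logvar) * eps


def kl_to_standard_normal(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), one value per example."""
    return 0.5 * jnp.sum(jnp.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)


def vae_loss(x, x_hat, mu, logvar):
    """Batch-averaged MSE reconstruction loss plus KL penalty (equal weighting assumed)."""
    mse = jnp.mean((x - x_hat) ** 2, axis=tuple(range(1, x.ndim)))  # per-example MSE
    return jnp.mean(mse + kl_to_standard_normal(mu, logvar))


# A posterior equal to the prior gives zero KL; shifting the mean by 1 in two dims costs 1.0.
print(kl_to_standard_normal(jnp.zeros((1, 2)), jnp.zeros((1, 2))))
print(kl_to_standard_normal(jnp.ones((1, 2)), jnp.zeros((1, 2))))
```

Sampling through `reparameterise` keeps the latent sample differentiable with respect to `mu` and `logvar`, which is what lets `jax.grad` train the encoder through the sampling step.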

During training, you can pass a flag to log the losses to a CSV file and then run the plot.py script to show the loss curve. Both minivae/training.py (training) and minivae/inference.py (inference) are executable.
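The CSV format MiniVAE writes is not described here, so the following is only a sketch of the general idea: a hypothetical `step,loss` column layout and hypothetical helpers `log_loss` and `read_losses`. It does not reproduce the actual training flag or plot.py.

```python
import csv
import os


def log_loss(path, step, loss):
    """Append one (step, loss) row, writing a header when the file is first created.

    The "step,loss" column layout is a hypothetical choice for this sketch.
    """
    new_file = not os.path.exists(path)
    with open(path, "a", newline="") as f:
        writer = csv.writer(f)
        if new_file:
            writer.writerow(["step", "loss"])
        writer.writerow([step, float(loss)])


def read_losses(path):
    """Read the logged curve back as parallel lists of steps and losses."""
    steps, losses = [], []
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            steps.append(int(row["step"]))
            losses.append(float(row["loss"]))
    return steps, losses
```

The returned lists can then be plotted, for example with matplotlib's `plt.plot(steps, losses)`.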