This repository holds several notebooks that implement GANs in JAX using the Flax Linen package. All models are trained on MNIST in Colab on TPUs, with data parallelism enabled by default.
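For reference, data parallelism in JAX typically amounts to wrapping the training step in `jax.pmap` and averaging gradients across devices. The sketch below illustrates the pattern with a dummy quadratic loss standing in for the GAN losses; the batch shape, learning rate, and function names are illustrative, not the notebooks' actual code.

```python
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Dummy loss; in the notebooks this would be the generator/discriminator loss.
    return jnp.mean((batch - params) ** 2)

def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="batch")
    return params - 1e-3 * grads, loss

p_train_step = jax.pmap(train_step, axis_name="batch")

n_devices = jax.local_device_count()
params = jax.device_put_replicated(jnp.zeros((4,)), jax.local_devices())
batch = jnp.ones((n_devices, 32, 4))  # leading axis indexes the devices
params, loss = p_train_step(params, batch)
```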
The original GAN, with the architecture and training tips from the DCGAN paper (Unsupervised Representation Learning with Deep Convolutional GANs).
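As a rough illustration of those tips (strided transposed convolutions instead of pooling, batch normalization, ReLU activations, tanh output), here is a minimal Flax Linen generator for 28x28 MNIST images. The layer sizes are illustrative and need not match the notebook's exact architecture.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Generator(nn.Module):
    @nn.compact
    def __call__(self, z, train: bool = True):
        # Project the noise vector to a small spatial feature map.
        x = nn.Dense(7 * 7 * 64)(z)
        x = x.reshape((z.shape[0], 7, 7, 64))
        x = nn.relu(nn.BatchNorm(use_running_average=not train)(x))
        # Fractionally-strided convolutions upsample 7x7 -> 14x14 -> 28x28.
        x = nn.ConvTranspose(32, kernel_size=(4, 4), strides=(2, 2), padding="SAME")(x)
        x = nn.relu(nn.BatchNorm(use_running_average=not train)(x))
        x = nn.ConvTranspose(1, kernel_size=(4, 4), strides=(2, 2), padding="SAME")(x)
        return jnp.tanh(x)  # outputs in [-1, 1], matching normalized MNIST

# Usage: BatchNorm keeps running statistics, so apply with mutable batch_stats.
key = jax.random.PRNGKey(0)
z = jax.random.normal(key, (8, 100))
variables = Generator().init(key, z)
images, _ = Generator().apply(variables, z, mutable=["batch_stats"])
```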
Training GANs is a notoriously difficult process. Even with a carefully chosen architecture, training can still suffer from mode collapse. The authors of the Wasserstein GAN paper argue that the root of the problem is the divergence the vanilla GAN minimizes when learning the data distribution; switching to minimizing the earth mover (Wasserstein) distance alleviates it.
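In code the change is small: the critic outputs an unbounded score rather than a probability, the losses are plain means of those scores, and (in the original paper's formulation) the critic's weights are clipped to keep it approximately 1-Lipschitz. A hedged sketch with placeholder names:

```python
import jax
import jax.numpy as jnp

def critic_loss(real_scores, fake_scores):
    # The critic should score real samples high and generated samples low
    # (no sigmoid, no log); its gap estimates the earth mover distance.
    return jnp.mean(fake_scores) - jnp.mean(real_scores)

def generator_loss(fake_scores):
    # The generator tries to raise the critic's score on its samples.
    return -jnp.mean(fake_scores)

def clip_weights(params, c=0.01):
    # Weight clipping, the original (crude) way to enforce the Lipschitz constraint.
    return jax.tree_util.tree_map(lambda w: jnp.clip(w, -c, c), params)
```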
This is the logical next step after the vanilla GAN: if we have labels, we should use them. The Conditional GAN, as the name implies, conditions the generator's output on the label in addition to the noise. The discriminator, in turn, receives both the real or generated image and its label when deciding real versus fake.
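One common way to implement this conditioning is to embed the label and concatenate it with the generator's noise vector and with the discriminator's flattened input. The sketch below assumes that scheme with illustrative layer sizes; the notebook's exact conditioning mechanism may differ.

```python
import jax.numpy as jnp
import flax.linen as nn

NUM_CLASSES = 10

class ConditionalGenerator(nn.Module):
    @nn.compact
    def __call__(self, z, labels):
        # Embed the integer class label and concatenate it with the noise.
        y = nn.Embed(NUM_CLASSES, 50)(labels)
        h = jnp.concatenate([z, y], axis=-1)
        h = nn.relu(nn.Dense(256)(h))
        return jnp.tanh(nn.Dense(28 * 28)(h)).reshape((-1, 28, 28, 1))

class ConditionalDiscriminator(nn.Module):
    @nn.compact
    def __call__(self, x, labels):
        # The discriminator sees both the image and its (claimed) label.
        y = nn.Embed(NUM_CLASSES, 50)(labels)
        h = jnp.concatenate([x.reshape((x.shape[0], -1)), y], axis=-1)
        h = nn.leaky_relu(nn.Dense(256)(h), negative_slope=0.2)
        return nn.Dense(1)(h)  # real/fake logit
```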
My personal favorite is the information-maximizing GAN (InfoGAN). As the authors note, the info loss converges faster than the GAN loss, so the addition essentially comes for free. The result is a somewhat disentangled latent space in which the digit classes are easily separable. A great reference and interpretation of both the InfoGAN objective and the vanilla GAN objective can be found in Ferenc Huszár's blog.
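Concretely, for a categorical latent code the extra term is just a cross-entropy between the code fed to the generator and a Q head's prediction of it from the generated image, which serves as a variational lower bound on their mutual information. A minimal sketch, assuming `q_logits` come from a Q head that shares layers with the discriminator and that the term is added to both objectives with some weight lambda:

```python
import jax.numpy as jnp
import optax

def info_loss(q_logits, codes):
    # `codes` are the integer categorical codes that were fed to the generator;
    # minimizing this cross-entropy maximizes the mutual-information lower bound.
    return jnp.mean(optax.softmax_cross_entropy_with_integer_labels(q_logits, codes))
```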