GAN implementation in PyTorch for MNIST
A GAN is composed of two adversarial networks: a discriminator and a generator.
The discriminator network is going to be a fairly typical classifier built from fully connected (linear) layers. To make this network a universal function approximator, we'll need at least one hidden layer, and those hidden layers should have one key attribute:
All hidden layers will have a leaky ReLU activation function applied to their outputs.
We should use a leaky ReLU to allow gradients to flow backwards through the layer unimpeded. A leaky ReLU is like a normal ReLU, except that there is a small non-zero output for negative input values.
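As a concrete sketch of this kind of discriminator (the layer widths, dropout probability, and leaky-ReLU negative slope of 0.2 below are assumptions for illustration, not fixed requirements), the hidden layers each get a leaky ReLU and the final layer returns a raw score per image:

```python
import torch.nn as nn
import torch.nn.functional as F

class Discriminator(nn.Module):
    """Fully connected discriminator sketch; layer sizes are illustrative."""

    def __init__(self, input_size=784, hidden_dim=32, output_size=1):
        super().__init__()
        # hidden layers, shrinking toward the single output score
        self.fc1 = nn.Linear(input_size, hidden_dim * 4)
        self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim)
        # final fully connected layer: one score per image
        self.fc4 = nn.Linear(hidden_dim, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        # flatten the 28x28 MNIST image into a 784-dim vector
        x = x.view(-1, 784)
        # leaky ReLU on every hidden layer so negative inputs still pass a small gradient
        x = F.leaky_relu(self.fc1(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # return raw scores (logits) rather than sigmoid probabilities;
        # this pairs with the logits-based loss discussed next
        return self.fc4(x)
```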
We'll also take the approach of using a more numerically stable loss function on the outputs. Recall that we want the discriminator to output a value between 0 and 1 indicating whether an image is real or fake.
We will ultimately use BCEWithLogitsLoss, which combines a sigmoid activation function and binary cross-entropy loss in one function.
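One way to wire this up is a pair of small helper functions that apply `BCEWithLogitsLoss` directly to the discriminator's raw logits, with targets of 1 for real images and 0 for fakes. This is a minimal sketch; the helper names `real_loss`/`fake_loss` and the optional label-smoothing value of 0.9 are assumptions, not requirements:

```python
import torch
import torch.nn as nn

def real_loss(D_out, smooth=False):
    """Loss for a batch the discriminator should classify as real (label 1)."""
    batch_size = D_out.size(0)
    # optional label smoothing: target 0.9 instead of 1.0 for real images
    labels = torch.ones(batch_size) * (0.9 if smooth else 1.0)
    criterion = nn.BCEWithLogitsLoss()  # sigmoid + binary cross entropy in one call
    return criterion(D_out.squeeze(), labels)

def fake_loss(D_out):
    """Loss for a batch the discriminator should classify as fake (label 0)."""
    batch_size = D_out.size(0)
    labels = torch.zeros(batch_size)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(D_out.squeeze(), labels)
```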
The generator network will be almost exactly the same as the discriminator network, except that we're applying a tanh activation function to our output layer.
The generator has been found to perform best with a tanh output, which means its outputs fall in the range -1 to 1; we'll therefore rescale our MNIST images to that same range rather than 0 to 1.
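A matching generator sketch, mirroring the discriminator but ending in tanh, might look like the following (again, the latent vector size of 100, the hidden widths, and the dropout are illustrative assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(nn.Module):
    """Fully connected generator sketch: latent vector in, flattened image out."""

    def __init__(self, input_size=100, hidden_dim=32, output_size=784):
        super().__init__()
        # hidden layers, growing toward the flattened image size
        self.fc1 = nn.Linear(input_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, hidden_dim * 4)
        self.fc4 = nn.Linear(hidden_dim * 4, output_size)
        self.dropout = nn.Dropout(0.3)

    def forward(self, z):
        # leaky ReLU hidden layers, same as the discriminator
        x = F.leaky_relu(self.fc1(z), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = self.dropout(x)
        x = F.leaky_relu(self.fc3(x), 0.2)
        x = self.dropout(x)
        # tanh squashes the output to [-1, 1], matching the rescaled MNIST images
        return torch.tanh(self.fc4(x))
```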