This project implements a Generative Adversarial Network (GAN) to generate MNIST-style handwritten digits using PyTorch Lightning. The code includes a custom `GAN` class and a `MNISTDataModule` data module for managing the MNIST dataset.
- Generator: Produces realistic MNIST-style images from latent vectors.
- Discriminator: Differentiates between real and fake images.
- Adversarial Loss: Uses binary cross-entropy to train both models.
- PyTorch Lightning Integration: Simplifies training loops, logging, and hardware acceleration.
- MNIST Data Module: Manages data loading and preprocessing for training, validation, and testing.
- Python 3.8+
- PyTorch 1.12+
- PyTorch Lightning
- Torchvision
- NumPy
- Matplotlib
Install the dependencies with pip:

```bash
pip install torch torchvision pytorch-lightning
```
The `MNISTDataModule` class handles the MNIST dataset, including downloading, transforming, and creating `DataLoader` objects for training, validation, and testing.
- Normalizes MNIST images using the dataset statistics (mean=0.1307, std=0.3081).
- Splits the training set into training (55,000 samples) and validation (5,000 samples).
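A minimal sketch of what such a data module can look like; the batch size, data directory, and other constructor arguments are illustrative assumptions rather than the project's exact implementation:

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./data", batch_size: int = 128):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        # Standard MNIST statistics used for normalization.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])

    def prepare_data(self):
        # Download once; no state is assigned here.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage in (None, "fit"):
            full = MNIST(self.data_dir, train=True, transform=self.transform)
            # 55,000 training samples, 5,000 validation samples.
            self.mnist_train, self.mnist_val = random_split(full, [55000, 5000])
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
```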
The generator model consists of:
- Fully connected layers to project latent vectors.
- Transposed convolutional layers for upscaling to 28x28 images.
- Tanh activation for output scaling between -1 and 1.
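The following is a rough sketch of a generator with this layout (a linear projection to a 7x7 feature map, transposed convolutions up to 28x28, and a final Tanh); the layer sizes and latent dimension are assumptions, not the project's exact architecture:

```python
import torch
from torch import nn


class Generator(nn.Module):
    """Illustrative generator: latent vector -> 7x7 feature map -> 28x28 image."""

    def __init__(self, latent_dim: int = 100):
        super().__init__()
        # Project the latent vector to a small spatial feature map.
        self.fc = nn.Linear(latent_dim, 128 * 7 * 7)
        self.deconv = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 7x7 -> 14x14
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),    # 14x14 -> 28x28
            nn.Tanh(),  # outputs scaled to [-1, 1]
        )

    def forward(self, z):
        x = self.fc(z).view(z.size(0), 128, 7, 7)
        return self.deconv(x)
```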
The discriminator model includes:
- Convolutional layers for feature extraction.
- Fully connected layers for binary classification.
- Sigmoid activation for final outputs.
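A corresponding discriminator sketch, again with assumed layer sizes:

```python
import torch
from torch import nn


class Discriminator(nn.Module):
    """Illustrative discriminator: 28x28 image -> probability that it is real."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),    # 28x28 -> 14x14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14x14 -> 7x7
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 7 * 7, 1),
            nn.Sigmoid(),  # probability that the input image is real
        )

    def forward(self, x):
        return self.fc(self.conv(x))
```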
The `GAN` PyTorch Lightning module encapsulates:
- The generator and discriminator.
- Adversarial training logic.
- Custom optimizers for both networks.
Training steps:
- Train the generator to produce realistic images.
- Train the discriminator to distinguish between real and fake images.
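As a rough illustration of how these pieces fit together, the sketch below reuses the `Generator` and `Discriminator` classes sketched above and implements the two-optimizer adversarial loop with PyTorch Lightning's manual optimization; the hyperparameters and logging keys are assumptions, and the project's actual module may structure the steps differently:

```python
import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class GAN(pl.LightningModule):
    """Sketch of the adversarial training loop using manual optimization."""

    def __init__(self, latent_dim: int = 100, lr: float = 2e-4):
        super().__init__()
        self.save_hyperparameters()
        self.automatic_optimization = False  # two optimizers -> step them by hand
        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()

    def forward(self, z):
        return self.generator(z)

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx):
        real_imgs, _ = batch
        opt_g, opt_d = self.optimizers()
        batch_size = real_imgs.size(0)
        z = torch.randn(batch_size, self.hparams.latent_dim, device=self.device)

        # Generator step: try to make the discriminator label fakes as real.
        fake_imgs = self(z)
        g_loss = self.adversarial_loss(
            self.discriminator(fake_imgs),
            torch.ones(batch_size, 1, device=self.device),
        )
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

        # Discriminator step: real images -> 1, generated images -> 0.
        real_loss = self.adversarial_loss(
            self.discriminator(real_imgs),
            torch.ones(batch_size, 1, device=self.device),
        )
        fake_loss = self.adversarial_loss(
            self.discriminator(fake_imgs.detach()),
            torch.zeros(batch_size, 1, device=self.device),
        )
        d_loss = (real_loss + fake_loss) / 2
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()

        self.log_dict({"g_loss": g_loss, "d_loss": d_loss}, prog_bar=True)

    def configure_optimizers(self):
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr, betas=(0.5, 0.999))
        return [opt_g, opt_d], []
```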
To train the model, use:

```python
trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=10,
)
trainer.fit(model, dm)
```
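Here `model` and `dm` are instances of the GAN module and the data module; the constructor arguments below are assumptions carried over from the sketches above:

```python
dm = MNISTDataModule(batch_size=128)   # assumed constructor arguments
model = GAN(latent_dim=100, lr=2e-4)   # assumed constructor arguments
```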
- Generated images during training and validation are logged with TensorBoard.
- Training logs and generated images are saved in the `lightning_logs` directory.
- Example images generated by the model during training can be visualized in TensorBoard.
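To browse the logged metrics and image grids, point TensorBoard at that directory:

```bash
tensorboard --logdir lightning_logs
```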