
GitHubPro18/gan-mnist-generator

GAN for MNIST Digit Generation

This project uses PyTorch Lightning to implement a Generative Adversarial Network (GAN) that generates MNIST-style handwritten digits. The code includes a custom GAN class (a LightningModule) and a data module, MNISTDataModule, that manages the MNIST dataset.

Features

  • Generator: Produces realistic MNIST-style images from latent vectors.
  • Discriminator: Differentiates between real and fake images.
  • Adversarial Loss: Uses binary cross-entropy to train both models.
  • PyTorch Lightning Integration: Simplifies training loops, logging, and hardware acceleration.
  • MNIST Data Module: Manages data loading and preprocessing for training, validation, and testing.
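For reference, the standard binary cross-entropy objectives for a GAN are written out below. Here D(x) is the discriminator's sigmoid output (the probability that x is real), G(z) is the generator's image for a latent vector z, real images get target 1, and generated images get target 0. The generator loss shown is the common "non-saturating" form, which trains G with target 1 on its own samples. It is an assumption that this project uses exactly these forms.

```latex
\begin{aligned}
\mathcal{L}_D &= -\,\mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
                 -\,\mathbb{E}_{z \sim p(z)}\big[\log\big(1 - D(G(z))\big)\big] \\
\mathcal{L}_G &= -\,\mathbb{E}_{z \sim p(z)}\big[\log D(G(z))\big]
\end{aligned}
```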

Setup and Dependencies

Prerequisites

  • Python 3.8+
  • PyTorch 1.12+
  • PyTorch Lightning
  • Torchvision
  • NumPy
  • Matplotlib

Installation

Install dependencies using pip:

pip install torch torchvision pytorch-lightning numpy matplotlib

Code Overview

1. MNISTDataModule

This class handles the MNIST dataset, including downloading, transforming, and creating DataLoader objects for training, validation, and testing.

  • Normalizes pixel values using the MNIST mean (0.1307) and standard deviation (0.3081).
  • Splits the training set into training (55,000 samples) and validation (5,000 samples).

2. Generator

The generator model consists of:

  • Fully connected layers to project latent vectors.
  • Transposed convolutional layers for upscaling to 28x28 images.
  • A Tanh activation that scales outputs to the range [-1, 1].
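The generator layout can be sketched as follows. This is an illustration that assumes a 100-dimensional latent vector projected to a 64 x 7 x 7 feature map, followed by two stride-2 transposed convolutions (7 → 14 → 28). The latent size, channel widths, and the BatchNorm layer are assumptions.

```python
import torch
import torch.nn as nn


class Generator(nn.Module):
    """Sketch: latent_dim=100 and the channel widths are assumptions."""

    def __init__(self, latent_dim: int = 100):
        super().__init__()
        # Fully connected projection of the latent vector to a 64 x 7 x 7 map.
        self.project = nn.Sequential(
            nn.Linear(latent_dim, 64 * 7 * 7),
            nn.ReLU(inplace=True),
        )
        # Transposed convolutions upsample 7x7 -> 14x14 -> 28x28.
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),  # 7 -> 14
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),   # 14 -> 28
            nn.Tanh(),  # pixel values in [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.project(z).view(-1, 64, 7, 7)
        return self.upsample(x)
```

Each transposed convolution doubles the spatial size because (H - 1) * 2 - 2 * 1 + 4 = 2H.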

3. Discriminator

The discriminator model includes:

  • Convolutional layers for feature extraction.
  • Fully connected layers for binary classification.
  • A sigmoid activation that outputs the probability that the input image is real.
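A matching discriminator might look like the sketch below. The two stride-2 convolutions (28 → 14 → 7), the LeakyReLU slope of 0.2, and the 128-unit hidden layer are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Sketch: channel widths, hidden size, and LeakyReLU slope are assumptions."""

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 28x28 -> 14x14 -> 7x7.
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),   # 28 -> 14
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 14 -> 7
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Fully connected classifier ending in one probability per image.
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # P(image is real)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.features(x))
```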

4. GAN Class

This PyTorch Lightning module encapsulates:

  • The generator and discriminator.
  • Adversarial training logic.
  • Separate optimizers for the generator and the discriminator.

Training steps:

  1. Train the generator to produce images that the discriminator classifies as real.
  2. Train the discriminator to distinguish real images from generated ones.

5. Training Script

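To train the model, use:

import pytorch_lightning as pl

# model is an instance of this project's GAN class and dm an instance of
# MNISTDataModule; their constructor arguments depend on the project code.
trainer = pl.Trainer(
    accelerator="auto",  # use an available accelerator (e.g. a GPU), else the CPU
    devices=1,
    max_epochs=10,
)
trainer.fit(model, dm)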

Logging

  • Images generated during training and validation are logged to TensorBoard.

Results

  • Training logs and generated images are saved in the lightning_logs directory.
  • Example images generated during training can be viewed in TensorBoard by running tensorboard --logdir lightning_logs.
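Since Matplotlib is listed as a prerequisite, samples can also be plotted outside TensorBoard. plot_samples below is a hypothetical helper. It assumes any generator module that maps (n, latent_dim) noise to (n, 1, 28, 28) images in [-1, 1], such as the trained GAN's generator submodule; that submodule's attribute name depends on the project code.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line for interactive windows
import matplotlib.pyplot as plt
import torch


@torch.no_grad()
def plot_samples(generator, latent_dim=100, n=16, cols=4, path="samples.png"):
    """Hypothetical helper: sample n digits and draw them in a grid."""
    generator.eval()
    device = next(generator.parameters()).device
    z = torch.randn(n, latent_dim, device=device)
    imgs = (generator(z).cpu() + 1) / 2  # map Tanh output [-1, 1] to [0, 1]
    rows = (n + cols - 1) // cols
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for ax, img in zip(axes.flat, imgs):
        ax.imshow(img.squeeze(0).numpy(), cmap="gray", vmin=0, vmax=1)
    fig.tight_layout()
    if path:
        fig.savefig(path)
    return fig
```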
