In this project we use conditional GAN (DCGAN) models to learn to generate images from the CIFAR-10 dataset. This dataset includes 10 classes:
- airplane
- automobile
- bird
- cat
- deer
- dog
- frog
- horse
- ship
- truck
As part of our project we take efficiency and training time into consideration. This means we are careful to implement compact models that can be trained on a single GPU in a relatively short amount of time (less than two days).
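To make the class-conditioning idea concrete, below is a minimal sketch of a conditional DCGAN-style generator for 32×32 CIFAR-10 images. The layer sizes and the embedding scheme are illustrative assumptions, not our exact architectures, which are defined by the configuration files described later:

```python
import torch
import torch.nn as nn

class ConditionalGenerator(nn.Module):
    """Illustrative conditional generator: the class label is embedded and
    concatenated with the noise vector before the transposed convolutions."""

    def __init__(self, latent_dim: int = 100, num_classes: int = 10, features: int = 64):
        super().__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.net = nn.Sequential(
            # (latent_dim + num_classes) x 1 x 1 -> (features*4) x 4 x 4
            nn.ConvTranspose2d(latent_dim + num_classes, features * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(features * 4),
            nn.ReLU(True),
            # -> (features*2) x 8 x 8
            nn.ConvTranspose2d(features * 4, features * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features * 2),
            nn.ReLU(True),
            # -> features x 16 x 16
            nn.ConvTranspose2d(features * 2, features, 4, 2, 1, bias=False),
            nn.BatchNorm2d(features),
            nn.ReLU(True),
            # -> 3 x 32 x 32
            nn.ConvTranspose2d(features, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Condition by concatenating a label embedding to the noise vector.
        cond = torch.cat([z, self.label_emb(labels)], dim=1)
        return self.net(cond.unsqueeze(-1).unsqueeze(-1))

gen = ConditionalGenerator()
imgs = gen(torch.randn(8, 100), torch.randint(0, 10, (8,)))  # -> (8, 3, 32, 32)
```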
To set up the environment we use Anaconda. You can use one of the two methods below.

- First create a new environment using the following command:

  ```
  conda create -n environment_name python=3.9
  ```

  We use Python 3.9 in our testing but other versions might be compatible. We use PyTorch as our main machine learning library, which can be installed using one of the following commands:

  ```
  # To install the cpu-only version
  conda install pytorch torchvision torchaudio cpuonly -c pytorch
  # To install the nvidia/cuda version
  conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
  ```

  We used PyTorch 1.10 and CUDA 10.2 in our testing but other versions might be compatible. The other necessary packages can be installed using `pip` from the requirements file:

  ```
  pip install -r requirements.txt
  ```

- Alternatively, you can create a conda environment with the name `gan` (which can be changed manually in the `name` variable inside the `.yml` file) and install all necessary packages using the provided `environment.yml` file:

  ```
  conda env create --file environment.yml
  ```

  This is set by default to install the GPU version of PyTorch.
Both methods will also install the `black` package for formatting code and `mypy`* for static typing in Python. You can run both using the following commands:

```
# you can run both commands on a file or a directory as shown below
black file.py
mypy file.py
```
*Not fully implemented throughout the code.
To start a training instance, run the `src/train.py` program with the path to a valid `.yaml` configuration file (see `configs/default.yaml`). If no such path is specified, the program will look for a `default.yaml` file under the `configs/` directory. If the configuration file is read correctly, a directory will be created to save the progress of the training using TensorBoard, and `.pt` files will be created to save the generator and discriminator. The former is saved under `experiments/config_name/current_time_key/` and the latter under `models/config_name/current_time_key/gen_or_disc.pt`. Furthermore, we also save the configuration file used to initialize the training under `experiments/config_name/current_time_key/`, which comes in handy if small changes are made to the configuration before starting a training run. During training the loss values of both generator and discriminator are tracked at each epoch. The last requirement for training the models properly is to create the statistics for the dataset in order to calculate the Fréchet Inception Distance (FID), which is explained in the next paragraph.
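Concretely, after starting a run with the `dcgan` configuration, the resulting layout looks roughly like this (the time key is generated per run; the checkpoint names follow the program output shown further below, with `disc.pt` assumed as the discriminator counterpart):

```
experiments/dcgan/<current_time_key>/   # TensorBoard logs + copy of the .yaml used
models/dcgan/<current_time_key>/gen.pt  # generator checkpoint
models/dcgan/<current_time_key>/disc.pt # discriminator checkpoint
```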
Testing is carried out at specific intervals set in the configuration. During testing the test set is used to evaluate the loss for both generator and discriminator, as well as the accuracy of the discriminator on real and generated images. Some of the generated images are saved next to real images for 'human observation' evaluation. During testing we also use the Inception Score (IS) and Fréchet Inception Distance (FID) to evaluate the generative model. These are calculated using the `pytorch_gan_metrics` library (GitHub), by calling the function `get_inception_score_and_fid(gen_imgs, stats_path)`, where `gen_imgs` refers to the generated images and `stats_path` is the path to the precalculated statistics for the given dataset. To create the precalculated statistics, first use `src/save_cifar10_imgs.py` to create a folder with all test images of CIFAR-10, then use the command:
```
python -m pytorch_gan_metrics.calc_fid_stats --path dataset/cifar10_images --output dataset/cifar10_fid_stats.npz
```
This will create a `.npz` file where the data necessary to calculate FID is stored. The FID stats are created using test data only, as this is a fairer way to compare the generator against images that have not been used during training. Using the training data would almost certainly result in a slightly better FID.
We have also implemented a separate program for testing, `test_gen.py generator_path.pt config_path.yaml num_samples num_images`, which generates, saves, and evaluates images. You need to pass the path to a generator model and its corresponding configuration file as arguments. There is also the optional argument for the number of samples, which is set to 1000 by default; to get a fair evaluation of the FID and IS scores you need to use at least 10000 samples. You can also pass the number of images you want to save, which will be divided equally among the 10 classes; the default value is 100.
To run an experiment you can either use one of the existing configurations under the `configs/` folder or create your own. To help define a configuration you can use the `default.yaml` configuration as a template, where all possible options for each parameter are written. The default configuration trains a general conditional GAN. The `dcgan.yaml` configuration has been optimized for our task of interest, CIFAR-10. The `wgan.yaml` and `wgan_gp.yaml` configurations both train a conditional Wasserstein GAN: the first applies clipping to the parameters of the discriminator, whereas the second calculates a gradient penalty as proposed in this paper. Finally, `deep_dcgan.yaml` uses a feature matching loss and architectures developed based on this work.
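For orientation, a configuration file might look roughly like the following. The keys shown here are illustrative assumptions only; the real parameter names and their possible values are listed in `configs/default.yaml`:

```yaml
# Hypothetical excerpt -- actual keys are defined in configs/default.yaml
model:
  type: dcgan        # e.g. dcgan | wgan | wgan_gp | deep_dcgan
  hidden_features: 64
training:
  epochs: 1000
  batch_size: 128
  test_interval: 10  # epochs between evaluation runs
```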
To run an experiment after setting up the environment, also make sure to have created the stats for the FID score. For this to work you need to have the CIFAR-10 dataset downloaded under `dataset/cifar10_images`. To download the dataset and save it under the appropriate folder you can run `python src/train.py` and stop the program after the dataset has finished downloading; otherwise the run will return an error during testing, since the stats will not have been created yet.
```
# create a folder with all test images under dataset/cifar10_images/
python src/save_cifar10_imgs.py
# create the stats under dataset/cifar10_fid_stats.npz
python -m pytorch_gan_metrics.calc_fid_stats --path dataset/cifar10_images --output dataset/cifar10_fid_stats.npz
# run an experiment
python src/train.py configs/dcgan.yaml

# program output
Train model using configs/dcgan.yaml as configuration file.
Saving experiment under: /path_to_repo/gan_cifar/experiments/dcgan/Tue_Jan_18_12_11_26_2022
Saving experiment models under: /path_to_repo/gan_cifar/models/dcgan/Tue_Jan_18_12_11_26_2022
Files already downloaded and verified
Files already downloaded and verified
  2%|██                  | 25/1000 [00:00<?, ?it/s]
 14%|████████████        | 55/391 [00:00<?, ?it/s]
# The progress over the epochs and over the current training or testing pass is shown
# If the dataset is not yet downloaded it will be placed under the dataset/ directory

# To check the experiment results use Tensorboard
tensorboard --logdir /path_to_repo/gan_cifar/experiments/dcgan/Tue_Jan_18_12_11_26_2022

# To test the generative model again use test_gen.py
python src/test_gen.py /path_to_repo/gan_cifar/models/dcgan/Tue_Jan_18_12_11_26_2022/gen.pt /path_to_repo/gan_cifar/experiments/dcgan/Tue_Jan_18_12_11_26_2022/dcgan.yaml 10000 100
```
Due to time and hardware constraints we have not yet tested multiple seeds for each training/model configuration. Below are the best FID and IS scores* achieved while testing different initializations of the model, features‡, and normalization. Some of the instances were trained with older versions of our code.
- The DCGAN (corresponds to `configs/dcgan.yaml`) uses a generator and discriminator with 4 hidden layers; the number of hidden features is given by the number beside the name.
- The WGAN (corresponds to `configs/wgan_gp.yaml`) uses a generator and discriminator with 4 hidden layers; the number of hidden features is given by the number beside the name. We use the gradient penalty as proposed here (instead of clipping; see also improvements on WGANs) and apply spectral norm by default. A sketch of the penalty term follows this list.
- The Deeper_DCGAN (corresponds to `configs/deep_dcgan.yaml`) uses the generator and discriminator architectures from improved techniques for training GANs (GitHub). We also use feature matching as proposed in the paper.
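Since the gradient penalty is central to the `wgan_gp` configuration, here is a minimal sketch of the standard penalty term from the cited paper (the conditional critic interface `disc(images, labels)` is a placeholder, not our exact code):

```python
import torch

def gradient_penalty(disc, real, fake, labels):
    """Two-sided WGAN-GP penalty: push the critic's gradient norm toward 1
    on random interpolates between real and generated samples."""
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    # Random points on the line between each real/generated pair.
    interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = disc(interp, labels)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interp,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
    )[0].view(real.size(0), -1)
    # Penalize deviation of the gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```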
| Models & Training \ Metrics | FID (test 10k) | IS |
|---|---|---|
| DCGAN64† | 39.94 | 6.56±0.15 |
| DCGAN128† | 37.99 | 6.59±0.15 |
| DCGAN64_Batch_Dropout | 52.15 | 6.23±0.17 |
| DCGAN96_Feature† | 37.85 | 7.00±0.19 |
| DCGAN64_ExpSigmoid | 42.89 | 6.43±0.18 |
| WGAN64_Instance† | 33.95 | 7.07±0.20 |
| WGAN96_Instance† | 30.20 | 7.41±0.20 |
| WGAN128_Instance† | 29.57 | 7.19±0.15 |
| WGAN64_Layer | 49.21 | 5.90±0.13 |
| WGAN64_Batch | 49.82 | 6.00±0.20 |
| DEEPER_DCGAN16_Dropout | 45.03 | 5.87±0.15 |
| DEEPER_DCGAN64_Feature† | 35.80 | 6.64±0.17 |
*For FID, lower is better. For IS, higher is better.
†Reproduced result over more than one instance.
‡The number of features is multiplied in each layer of the module.