Authors: Mohamad Shahbazi, Zhiwu Huang, Danda P. Paudel, Ajad Chhatkuli, and Luc Van Gool
Paper: https://arxiv.org/pdf/2102.06696.pdf
Generative adversarial networks (GANs) have shown impressive results in both unconditional and conditional image generation. Recent literature shows that a GAN pre-trained on a different dataset can be transferred to improve image generation from small target data. The same, however, has not been well studied for conditional GANs (cGANs), which offer new opportunities for knowledge transfer compared to the unconditional setup. In particular, the new classes may borrow knowledge from related old classes, or share knowledge among themselves, to improve training. This motivates us to study the problem of efficient conditional GAN transfer with knowledge propagation across classes. To address this problem, we introduce a new GAN transfer method that explicitly propagates knowledge from the old classes to the new classes. The key idea is to enforce the popularly used conditional batch normalization (BN) to learn the class-specific information of the new classes from that of the old classes, with implicit knowledge sharing among the new ones. This allows efficient knowledge propagation from the old classes to the new classes, with the BN parameters increasing only linearly with the number of new classes. Extensive evaluation demonstrates the clear superiority of the proposed method over state-of-the-art competitors for efficient conditional GAN transfer tasks.
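For intuition, the following PyTorch sketch illustrates one simplified reading of the idea above: the BN scale and shift of each new class are expressed as a learnable linear combination of the frozen old-class BN parameters plus a small residual (the "--comb_l1_scale" and "--res_l2_scale" options described later appear to regularize exactly these two sets of parameters). All names in the sketch (TransferConditionalBN, comb, res_gamma, ...) are illustrative and do not refer to this repository's actual classes; it also assumes per-class BN parameters, whereas BigGAN derives them from class embeddings.

```python
import torch
import torch.nn as nn

class TransferConditionalBN(nn.Module):
    """Illustrative sketch (not the repository's code): class-conditional BN
    whose new-class gains and biases are linear combinations of the frozen
    old-class gains and biases plus a small learnable residual."""

    def __init__(self, num_features, old_gamma, old_beta, n_new_classes):
        super().__init__()
        # old_gamma / old_beta: (n_old_classes, num_features), kept frozen
        self.register_buffer("old_gamma", old_gamma)
        self.register_buffer("old_beta", old_beta)
        n_old = old_gamma.size(0)
        # Combination weights over the old classes (L1-regularized);
        # initialized uniformly so each new class starts from the average
        # of the old-class BN statistics.
        self.comb = nn.Parameter(torch.full((n_new_classes, n_old), 1.0 / n_old))
        # Per-class residuals (L2-regularized)
        self.res_gamma = nn.Parameter(torch.zeros(n_new_classes, num_features))
        self.res_beta = nn.Parameter(torch.zeros(n_new_classes, num_features))
        self.bn = nn.BatchNorm2d(num_features, affine=False)

    def forward(self, x, y):
        # x: (B, C, H, W) activations, y: (B,) new-class indices
        gamma = self.comb @ self.old_gamma + self.res_gamma  # (n_new, C)
        beta = self.comb @ self.old_beta + self.res_beta
        out = self.bn(x)
        return out * gamma[y, :, None, None] + beta[y, :, None, None]
```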
The code has been tested with Python 3.6 and PyTorch 1.0.1.
To install the dependencies:
pip install -r requirements.txt
The base directory, which can be different from the code directory, contains the following sub-directories:
Path | Description |
---|---|
base_dir | The base directory containing your data, weights, logs, and generated samples |
├ data | Contains the training data |
├ weights | Contains the pretraining and training weights |
├ logs | Contains the training logs |
└ samples | Contains the saved images during evaluation |
Before training, you should create the sub-directories "data" and "weights", and place the data and pretrained weights inside them accordingly. The sub-directories "logs" and "samples" will be created automatically during training.
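For instance, the required sub-directories can be created with a couple of lines of Python (the base-directory path is a placeholder):

```python
import os

base_dir = "/path/to/base_dir"  # placeholder; use your own base directory
for sub in ("data", "weights"):
    os.makedirs(os.path.join(base_dir, sub), exist_ok=True)
```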
Data preparation depends on the setup you are using. The code supports two setups:
- ImageNet: Using BigGAN pretrained on ImageNet (Image size 128x128)
- CIFAR: Using BigGAN pretrained on CIFAR (Image size 32x32)
The multi-class data should be organized in a main folder named "ImageNet" (for the ImageNet setup) or "cifar" (for the CIFAR setup), regardless of the actual target dataset. The main folder should contain one sub-folder per class of the target data. The main folder should then be placed in "base_dir/data/":
base_dir/data/ImageNet/[class_1, ..., class_n]
base_dir/data/cifar/[class_1, ..., class_n]
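Since this is the standard one-folder-per-class layout, you can sanity-check it with torchvision's ImageFolder before training (a quick check only, not part of the training pipeline; the path is a placeholder):

```python
from torchvision.datasets import ImageFolder

# Each sub-folder of base_dir/data/ImageNet (or .../cifar) is one target class.
dataset = ImageFolder("/path/to/base_dir/data/ImageNet")
print(len(dataset.classes), "classes:", dataset.classes)
print(len(dataset), "images in total")
```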
Below you can find the datasets used in our experiments:
- Places365: 5 classes (Alley, Arch, Art Gallery, Auditorium, Ballroom) are selected and down-sampled to 500 images per class
- Animal Face: The class "Natural" is excluded in our experiments.
- CIFAR100
Then, to pre-calculate the Inception moments of the target dataset for FID calculation:
bash prepare_data
Note 1: The path to the main folder of the dataset should be set in the data preparation scripts using "--data_root".
Note 2: For setups other than ImageNet and CIFAR, you need to adapt the data-specific configuration (type of data loader, image size, root folder name, ...).
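For reference, the Inception moments are simply the mean and covariance of Inception pool features computed over the real data; FID then compares them against the same statistics of generated samples via ||mu_r - mu_g||^2 + Tr(Sigma_r + Sigma_g - 2 (Sigma_r Sigma_g)^(1/2)). The sketch below shows the computation in simplified form; the actual preparation script differs in details such as the feature-extractor wrapper, resizing/normalization, and file naming.

```python
import numpy as np
import torch
from torchvision.models import inception_v3

def inception_moments(loader, device="cuda"):
    """Mean and covariance of Inception pool features over a dataset.
    Simplified sketch; the repository's prepare script handles resizing,
    normalization, and saving in its own format."""
    net = inception_v3(pretrained=True).to(device).eval()
    net.fc = torch.nn.Identity()  # expose the 2048-d pooled features
    feats = []
    with torch.no_grad():
        for images, _ in loader:  # images expected at 299x299, normalized
            feats.append(net(images.to(device)).cpu().numpy())
    feats = np.concatenate(feats, axis=0)
    mu = feats.mean(axis=0)
    sigma = np.cov(feats, rowvar=False)
    return mu, sigma
```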
Before training, BigGAN's pretrained weights should be placed in the sub-directory "weights" of the base directory.
For BigGAN on ImageNet, you can use the pretrained weights provided by BigGAN's authors. In this project, we used the main checkpoint.
If you want to use another dataset (e.g., CIFAR10/100) for pretraining, you can first train BigGAN on the desired dataset and then use the resulting weights for cGANTransfer.
To launch the training on your target data:
bash train.sh
In the experiments reported in the paper, the ImageNet-backbone model was trained with a batch size of 256 on 8 V100 (16GB) GPUs. For the CIFAR experiments, the model was trained with a batch size of 50 on a single V100 (16GB) GPU.
Some of the configurations in the scripts "train_ImageNet.sh" and "train_cifar.sh" need to be set according to your experiment. Some of the important parameters are:
Parameter | Description |
---|---|
--base_dir | The base directory containing your data, weights, logs, and generated samples |
--experiment_name | The name of the experiment you are going to run (will be generated automatically if nothing is passed) |
--batch_size | The size of the training batch |
--stage | The training stage ["BN", "FT"] (see the description of the two training stages below) |
--n_class | The number of target classes |
--n_pretrain_class | The number of pretrained classes |
--resume | If used, weights are loaded from the last checkpoint. Otherwise, pretrained weights are loaded |
--load_weights | The suffix of the checkpoint, in case of loading a specific checkpoint |
--res_l2_scale | The degree of L2 regularization (details in Section 4.3 of the paper) |
--comb_l1_scale | The degree of L1 regularization (details in Section 4.3 of the paper) |
Make sure you understand the configurations used in the scripts and their default values by reading their descriptions in "utils.py".
The training is assumed to start from a pre-trained network, and can be done in two stages.
In the first stage (BN), only the batch normalization (BN) parameters of the target classes are learned, by transferring knowledge from the pre-trained classes (see the paper for details). To train in this mode, set the flag "--stage" to "BN" in the training script.
In addition to the first stage, an extra fine-tuning stage is also possible, in which the whole network (including the BN parameters obtained for the target classes) is fine-tuned on the target dataset. To enable fine-tuning, follow these steps:
- Set the flag "--stage" to "FT" in the training script
- Pass the flag "--resume" in the training script
- Specify the suffix of the checkpoint to fine-tune from using "--load_weights" (e.g., "best1" for state_dict_best1.pth)
Note: To use the best model according to the evaluation performed during training, pass the value "best#i" to the flag "--load_weights", where #i is the index of the saved best model.
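Conceptually, the two stages differ in which parameters receive gradients: in "BN", only the new-class BN transfer parameters are updated while the pre-trained weights stay frozen; in "FT", everything is trainable. The helper below is a hypothetical illustration of this distinction, not code from the repository; "transfer_bn" is a placeholder for however the new-class BN parameters are named in your model.

```python
def set_trainable(model, stage):
    """Hypothetical illustration of the two training stages.
    'BN': only the new-class BN transfer parameters are updated; everything
    inherited from the pre-trained model stays frozen.
    'FT': the whole network, including the learned BN parameters, is
    fine-tuned on the target data."""
    for name, param in model.named_parameters():
        if stage == "BN":
            # 'transfer_bn' is a placeholder for the new-class BN parameters.
            param.requires_grad = "transfer_bn" in name
        else:  # stage == "FT"
            param.requires_grad = True
```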
The main evaluation metrics used in this project are the Fréchet Inception Distance (FID) and the Kernel Maximum Mean Discrepancy (KMMD). The Inception Score (IS) is also included in the code.
FID and IS are calculated during training using the PyTorch implementation provided by BigGAN.
For KMMD, the default implementation provided by GAN Metrics is used (Gaussian kernel with sigma=1).
KMMD calculation is deactivated by default to avoid memory errors. To activate it, use the flag "--kmmd" in the script and reduce the number of generated samples used for evaluation using "--num_inception_images" (default: 50000; we were able to avoid memory errors by setting it to 25000).
Note 1: Reducing the number of generated samples might make the FID calculation a bit less accurate.
Note 2: For very small datasets, FID score becomes unstable, since the inception statistics of the real data cannot be estimated accurately.
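For context, KMMD with a Gaussian kernel compares real and generated feature sets x and y via MMD^2 = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)], with k(a, b) = exp(-||a - b||^2 / (2 sigma^2)). Below is a minimal sketch of this computation; the referenced GAN Metrics implementation may differ in details such as the estimator variant and the feature space it operates on.

```python
import torch

def kmmd(x, y, sigma=1.0):
    """Kernel MMD between two feature sets, Gaussian kernel with sigma=1 by
    default. x: (n, d) real features, y: (m, d) generated features.
    Uses the simple biased estimator for brevity."""
    def k(a, b):
        d2 = torch.cdist(a, b) ** 2  # pairwise squared distances
        return torch.exp(-d2 / (2 * sigma ** 2))
    mmd2 = k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()
    return mmd2.clamp(min=0).sqrt()
```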
In addition to the evaluation during training, you can evaluate your final model using the following command:
bash sample.sh
Make sure to adjust the configuration in sample_ImageNet.sh / sample_cifar.sh according to your experiments.
Generated images from 5 classes of Places365, obtained by learning only the BN parameters:
For any questions, suggestions, or issues with the code, please contact Mohamad at mshahbazi@vision.ee.ethz.ch.
If you find this project helpful, please consider citing us as follows:
@inproceedings{shahbazi2021cGANTransfer,
title = {Efficient Conditional GAN Transfer with Knowledge Propagation across Classes},
author = {Shahbazi, Mohamad and
Huang, Zhiwu and
Paudel, Danda Pani and
Chhatkuli, Ajad and
Van Gool, Luc},
year = {2021},
booktitle = {2021 {IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR} 2021}
}