Single-Domain Adaptation via Target-Aware Generative Augmentations


This is the PyTorch implementation of Single-Domain Adaptation via Target-Aware Generative Augmentations.

Dataset

The CelebA-HQ dataset was used for all our experiments.

Requirements

We tested our code with the following package versions:

pytorch 1.10.2
cudatoolkit 10.2.89
ninja 1.10.2.3
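To check a local environment against the versions above, a small script can compare installed package versions. This is a minimal sketch and not part of the repository. It assumes PyTorch and ninja are visible under their pip distribution names torch and ninja. Because cudatoolkit is a conda package, the script only reports the CUDA version PyTorch was built with.

```python
"""Compare installed package versions with the versions listed above (sketch)."""
from importlib import metadata

# Versions the authors report testing with, keyed by pip distribution name.
# cudatoolkit is a conda package, so it is not checked here.
TESTED = {"torch": "1.10.2", "ninja": "1.10.2.3"}


def find_mismatches(tested, lookup=metadata.version):
    """Return {package: (tested_version, installed_version_or_None)} for mismatches."""
    mismatches = {}
    for pkg, want in tested.items():
        try:
            have = lookup(pkg)
        except metadata.PackageNotFoundError:
            have = None
        # Ignore local build suffixes such as "+cu102" when comparing.
        if have is None or have.split("+")[0] != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in find_mismatches(TESTED).items():
        print(f"{pkg}: tested with {want}, found {have}")
    try:
        import torch
        print("torch built with CUDA", torch.version.cuda)
    except ImportError:
        print("torch is not installed")
```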

Checkpoints

The checkpoints for the StyleGAN models can be downloaded from here.

Folder Structure

├── SISTA_DA
│   ├── celeba_dataloader.py
│   ├── celebahq_dataloader.py
│   ├── data_list.py
│   ├── image_NRC_target.py
│   ├── image_source.py
│   ├── image_target_memo.py
│   ├── image_target.py
│   ├── loss.py
│   ├── network.py
│   ├── randconv.py
│   ├── DATA/
│   ├── utils_memo.py
│   └── run.sh
│
├── GenerativeAugmentations
│   ├── data_augmentation.ipynb
│   ├── e4e
│   ├── e4e_projection.py
│   ├── model.py
│   ├── models
│   │   ├── dlibshape_predictor_68_face_landmarks.dat
│   │   ├── e4e_ffhq_encode.pt
│   │   ├── psp_ffhq_toonify.pt
│   │   ├── stylegan2-color_sketch.pt
│   │   ├── stylegan2-ffhq-config-f.pt
│   │   ├── stylegan2-pencil_sketch.pt
│   │   ├── stylegan2-sketch.pt
│   │   ├── stylegan2-toon.pt
│   │   └── stylegan2-water_color.pt
│   ├── op
│   ├── README.MD
│   ├── transformations.py
│   └── util.py
└── README.MD
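Since the generative augmentation code expects the checkpoints under GenerativeAugmentations/models, a quick check for missing files can save a failed notebook run. The sketch below is illustrative only: the file names are copied from the tree above, and the helper missing_checkpoints is hypothetical, not part of the repository.

```python
"""Report which checkpoint files are missing from the models folder (sketch)."""
from pathlib import Path

# File names copied from the folder structure above.
EXPECTED_MODELS = [
    "dlibshape_predictor_68_face_landmarks.dat",
    "e4e_ffhq_encode.pt",
    "psp_ffhq_toonify.pt",
    "stylegan2-color_sketch.pt",
    "stylegan2-ffhq-config-f.pt",
    "stylegan2-pencil_sketch.pt",
    "stylegan2-sketch.pt",
    "stylegan2-toon.pt",
    "stylegan2-water_color.pt",
]


def missing_checkpoints(models_dir="GenerativeAugmentations/models"):
    """Return the expected checkpoint names that are not files in models_dir."""
    root = Path(models_dir)
    return [name for name in EXPECTED_MODELS if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    print("All checkpoints found." if not missing else f"Missing: {missing}")
```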

Generative Augmentation

The Jupyter notebook data_augmentation.ipynb walks through generating the augmented images. It illustrates three major steps:

  1. StyleGAN fine-tuning to the target domain
  2. Target-domain image generation
  3. Target-aware augmentation
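The notebook is the reference for step 3. Purely as a hypothetical illustration of the pruning-zero and pruning-rewind strategies used in the adaptation commands below, and assuming they prune the lowest-magnitude activations of the target-finetuned generator and then either zero them or "rewind" them to the source generator's activations, the idea can be sketched on NumPy arrays. The function name and its arguments are assumptions; the actual implementation may differ.

```python
"""Toy sketch of activation pruning on generator feature maps (assumed semantics).

h_target: an intermediate activation of the target-finetuned generator.
h_source: the same layer of the source generator for the same latent code.
"""
import numpy as np


def prune_activations(h_target, h_source, ratio, mode="zero"):
    """Prune the `ratio` fraction of lowest-magnitude target activations.

    mode="zero":   pruned entries are set to 0.
    mode="rewind": pruned entries are replaced by the source activations.
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError("ratio must be in [0, 1]")
    if mode not in ("zero", "rewind"):
        raise ValueError("mode must be 'zero' or 'rewind'")
    flat = np.abs(h_target).ravel()
    k = int(round(ratio * flat.size))
    if k == 0:
        return h_target.copy()
    # Indices of the k smallest-magnitude activations.
    idx = np.argpartition(flat, k - 1)[:k]
    mask = np.zeros(flat.size, dtype=bool)
    mask[idx] = True
    mask = mask.reshape(h_target.shape)
    fill = np.zeros_like(h_target) if mode == "zero" else h_source
    return np.where(mask, fill, h_target)
```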

Classifier Training

Source Model Training

To train the source model for a desired attribute, run:

python SISTA_DA/image_source.py --attribute 'Smiling'
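The --attribute flag names one of the 40 binary CelebA attributes, such as 'Smiling'. How the repository's dataloaders read these labels is not described here; as a hedged sketch, assuming the standard CelebA annotation layout (an image count, a line of attribute names, then one row of -1/1 values per image), a hypothetical helper could turn one attribute column into 0/1 labels:

```python
"""Parse a CelebA-style attribute annotation file into binary labels (sketch)."""


def load_attribute_labels(lines, attribute="Smiling"):
    """Return {image_name: 0 or 1} for a single attribute.

    Assumes the CelebA layout: line 1 = image count, line 2 = attribute names,
    then one row per image: "<file> <v_1> ... <v_n>" with values in {-1, 1}.
    """
    lines = [ln for ln in lines if ln.strip()]
    names = lines[1].split()
    col = names.index(attribute)  # raises ValueError for an unknown attribute
    labels = {}
    for row in lines[2:]:
        fields = row.split()
        image, values = fields[0], fields[1:]
        labels[image] = 1 if int(values[col]) == 1 else 0
    return labels
```

Usage, with a hypothetical path: with open("DATA/attribute_annotations.txt") as f: labels = load_attribute_labels(f, "Smiling").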

Domain adaptation on synthetic target domain images

python SISTA_DA/image_NRC_target.py --variant 'interp_concat' --attribute 'Smiling' --t 1

This command adapts a source-trained model to target domain A (selected by the --t flag) for the attribute 'Smiling' using the SISTA protocol.
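Judging by its name, image_NRC_target.py builds on NRC (listed in the Acknowledgments), which adapts a source model without source data by encouraging each target sample's prediction to agree with its nearest neighbors, weighting reciprocal neighbors more heavily. The sketch below is a simplified NumPy illustration of that idea, not the repository's code: it omits NRC's feature memory bank and expanded neighborhoods, and the function names, k and r values are assumptions.

```python
"""Simplified sketch of a reciprocal-neighbor consistency loss in the spirit of NRC."""
import numpy as np


def softmax(logits):
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def knn_indices(features, k):
    """Indices of the k nearest neighbors by cosine similarity, excluding self."""
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = f @ f.T
    np.fill_diagonal(sim, -np.inf)
    return np.argsort(-sim, axis=1)[:, :k]


def nrc_style_loss(features, logits, k=3, r=0.1):
    """Neighbor consistency weighted by reciprocity, plus a diversity term.

    Reciprocal neighbors (j in kNN(i) and i in kNN(j)) get affinity 1;
    non-reciprocal neighbors get the smaller affinity r.
    """
    probs = softmax(logits)
    nn = knn_indices(features, k)
    n = len(features)
    neighbor_sets = [set(row.tolist()) for row in nn]
    consistency = 0.0
    for i in range(n):
        for j in nn[i]:
            affinity = 1.0 if i in neighbor_sets[j] else r
            consistency -= affinity * float(probs[i] @ probs[j])
    consistency /= n
    # Diversity: the negative entropy of the mean prediction is lowest
    # when predictions are spread across classes instead of collapsing.
    mean_p = probs.mean(axis=0)
    diversity = float(np.sum(mean_p * np.log(mean_p + 1e-8)))
    return consistency + diversity
```

In practice such a loss would be computed on PyTorch tensors and back-propagated through the target model; NumPy is used here only to keep the example self-contained.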

Domain adaptation on images generated using sampling strategies

Similarly, to use the images generated by pruning-zero or pruning-rewind, run:

Pruning Zero: python SISTA_DA/image_NRC_target.py --variant '' --prune True --attribute 'Smiling' --t 1

Pruning Rewind: python SISTA_DA/image_NRC_target.py --variant 'prune_rewind' --prune True --attribute 'Smiling' --t 1

Oracle

Adaptation using the unlabeled target-domain images:

python SISTA_DA/image_NRC_target.py --variant 'direct_target' --attribute 'Smiling' --t 1
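To run the adaptation protocols above for other attributes or target domains, the commands can be generated programmatically. This is a hypothetical convenience sketch: the flags are copied from the commands in this README, while the protocol keys (sista, pruning_zero, pruning_rewind, oracle) are labels chosen here for illustration.

```python
"""Build the adaptation commands from this README for any attribute/domain (sketch)."""
import shlex

SCRIPT = "SISTA_DA/image_NRC_target.py"

# Protocol-specific flags, copied from the commands above.
PROTOCOLS = {
    "sista": ["--variant", "interp_concat"],
    "pruning_zero": ["--variant", "", "--prune", "True"],
    "pruning_rewind": ["--variant", "prune_rewind", "--prune", "True"],
    "oracle": ["--variant", "direct_target"],
}


def build_command(protocol, attribute="Smiling", t=1):
    """Return the argv list for one protocol; no shell quoting is needed."""
    return (["python", SCRIPT] + PROTOCOLS[protocol]
            + ["--attribute", attribute, "--t", str(t)])


if __name__ == "__main__":
    for name in PROTOCOLS:
        # shlex.join (Python 3.8+) prints a shell-safe version of the command.
        print(shlex.join(build_command(name)))
        # To launch: import subprocess; subprocess.run(build_command(name), check=True)
```

Passing an argv list avoids the shell, so the empty --variant value of pruning-zero is preserved as an empty string, matching --variant '' on the command line.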

Baselines

To reproduce the baseline results for MEMO with its two variants, MEMO (AugMix) and MEMO (RandConv), run:

MEMO (AugMix): python SISTA_DA/image_target_memo.py --augmix True --attribute 'Smiling' --t 1

MEMO (RandConv): python SISTA_DA/image_target_memo.py --augmix False --attribute 'Smiling' --t 1
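MEMO adapts a model at test time by minimizing the entropy of its prediction averaged over several augmented copies of a single test input; the --augmix flag presumably selects AugMix or RandConv (randconv.py) as the augmentation. A minimal NumPy sketch that only evaluates this marginal entropy is shown below; the function name is hypothetical, and the actual script would minimize the quantity with respect to the model parameters in PyTorch.

```python
"""Marginal entropy objective in the style of MEMO (evaluation-only sketch)."""
import numpy as np


def marginal_entropy(aug_logits):
    """Entropy of the prediction averaged over augmented copies of ONE test image.

    aug_logits: array of shape (num_augmentations, num_classes).
    """
    z = aug_logits - aug_logits.max(axis=1, keepdims=True)
    probs = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    marginal = probs.mean(axis=0)  # average prediction over augmentations
    return float(-np.sum(marginal * np.log(marginal + 1e-12)))
```

Confident predictions that agree across augmentations give a marginal entropy near zero, while disagreeing augmentations push it toward log(num_classes), which is what the adaptation step penalizes.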

Citation

If you use this code or ideas from our paper, please cite:

@article{subramanyam2022SISTA,
  title={Single-Domain Adaptation via Target-Aware Generative Augmentations},
  author = {Subramanyam, Rakshith and Thopalli, Kowshik and Berman, Spring and Turaga, Pavan and Thiagarajan, Jayaraman J.},
  journal={arXiv preprint arXiv:2210.16692},
  year={2022}
}

Acknowledgments

This code builds upon the following codebases: StyleGAN2 by rosinality, e4e, StyleGAN-NADA, NRC, MEMO and RandConv. We thank the authors of the respective works for publicly sharing their code. Please cite them when appropriate.