
Single-Shot Domain Adaptation via Target-Aware Generative Augmentations

arXiv: https://arxiv.org/abs/2210.16692

This is the PyTorch implementation of Single-Shot Domain Adaptation via Target-Aware Generative Augmentations.

Dataset

The CelebA-HQ dataset was used for all our experiments.

Requirements

We tested our code with the following package versions:

pytorch 1.10.2
cudatoolkit 10.2.89
ninja 1.10.2.3
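
For reference, a conda environment matching these versions could be set up as follows; the Python version here is our assumption, so adjust it to your setup:

conda create -n sista python=3.8
conda activate sista
conda install pytorch=1.10.2 cudatoolkit=10.2.89 -c pytorch
pip install ninja==1.10.2.3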

Checkpoints

The StyleGAN checkpoints can be downloaded from here.

Folder Structure

├── SISTA_DA
│   ├── celeba_dataloader.py
│   ├── celebahq_dataloader.py
│   ├── data_list.py
│   ├── image_NRC_target.py
│   ├── image_source.py
│   ├── image_target_memo.py
│   ├── image_target.py
│   ├── loss.py
│   ├── network.py
│   ├── randconv.py
│   ├── DATA/
│   ├── utils_memo.py
│   └── run.sh
├── GenerativeAugmentations
│   ├── data_augmentation.ipynb
│   ├── e4e
│   ├── e4e_projection.py
│   ├── model.py
│   ├── models
│   │   ├── dlibshape_predictor_68_face_landmarks.dat
│   │   ├── e4e_ffhq_encode.pt
│   │   ├── psp_ffhq_toonify.pt
│   │   ├── stylegan2-color_sketch.pt
│   │   ├── stylegan2-ffhq-config-f.pt
│   │   ├── stylegan2-pencil_sketch.pt
│   │   ├── stylegan2-sketch.pt
│   │   ├── stylegan2-toon.pt
│   │   └── stylegan2-water_color.pt
│   ├── op
│   ├── README.MD
│   ├── transformations.py
│   └── util.py
└── README.MD

Generative Augmentation

The Jupyter notebook data_augmentation.ipynb walks through generating the augmented images. The notebook illustrates three major steps (a condensed sketch follows the list):

  1. StyleGAN fine tuning to target domain
  2. Target domain image generation
  3. Target-Aware augmentation
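
The sketch below compresses those three steps. The helpers load_generator, project_to_latent, and finetune_on_target are hypothetical placeholders for routines in model.py and e4e_projection.py (the real entry points live in the notebook), and the weight interpolation shown for step 3 is only one plausible reading of the 'interp_concat' variant:

import torch

device = 'cuda'

# Step 1: fine-tune the FFHQ StyleGAN2 towards the single target image.
# load_generator / project_to_latent / finetune_on_target are hypothetical
# names standing in for code in model.py and e4e_projection.py.
g_source = load_generator('models/stylegan2-ffhq-config-f.pt', device)
w_target = project_to_latent('target_image.png', device)  # e4e inversion
g_target = finetune_on_target(g_source, w_target)

# Step 2: sample latents and synthesize target-domain images
# (rosinality-style Generators return (images, latents)).
z = torch.randn(1000, 512, device=device)
with torch.no_grad():
    target_images, _ = g_target([z])

# Step 3: target-aware augmentation, sketched here as interpolating
# between source and fine-tuned generator weights (an assumption; see
# the notebook for the actual 'interp_concat' strategy).
alpha = 0.5
g_aug = load_generator('models/stylegan2-ffhq-config-f.pt', device)
for p_a, p_s, p_t in zip(g_aug.parameters(), g_source.parameters(),
                         g_target.parameters()):
    p_a.data.copy_((1 - alpha) * p_s.data + alpha * p_t.data)
with torch.no_grad():
    augmented_images, _ = g_aug([z])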

Classifier Training

Source Model Training

To train the source model on a desired attribute, run:

python image_source.py --attribute 'Smiling'
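
To build source models for several attributes at once, a simple loop works; the attribute names other than 'Smiling' are just example CelebA-HQ annotations:

for attr in 'Smiling' 'Eyeglasses' 'Young'; do
    python image_source.py --attribute "$attr"
done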

Domain adaptation on synthetic target domain images

python SISTA_DA/image_NRC_target.py --variant 'interp_concat' --attribute 'Smiling' --t 1

This command adapts a source-trained model to target domain A (selected by the --t flag) for the attribute 'Smiling' using the SISTA_DA protocol.
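
To sweep all synthetic targets, you can loop over the --t flag; we assume here that it indexes the five target styles whose StyleGAN checkpoints appear under models/ above:

for t in 1 2 3 4 5; do
    python SISTA_DA/image_NRC_target.py --variant 'interp_concat' --attribute 'Smiling' --t "$t"
done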

Domain adaptation on images generated using sampling strategies

Similarly, to use the images generated by pruning-zero or pruning-rewind, run:

Pruning Zero: python SISTA_DA/image_NRC_target.py --variant '' --prune True --attribute 'Smiling' --t 1

Pruning Rewind: python SISTA_DA/image_NRC_target.py --variant 'prune_rewind' --prune True --attribute 'Smiling' --t 1
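
For intuition, here is a minimal sketch of how the two strategies differ, under our assumption that pruning-zero zeroes the smallest-magnitude fine-tuned generator weights while pruning-rewind resets them to the source generator's values; the repo's implementation is authoritative:

import torch

def prune_mask(w: torch.Tensor, ratio: float = 0.3) -> torch.Tensor:
    """Boolean mask over the `ratio` fraction of smallest-magnitude entries."""
    k = max(1, int(w.numel() * ratio))
    threshold = w.abs().flatten().kthvalue(k).values
    return w.abs() <= threshold

def prune_zero(w_target: torch.Tensor, ratio: float = 0.3) -> torch.Tensor:
    # Pruning-zero: selected weights are set to zero.
    out = w_target.clone()
    out[prune_mask(w_target, ratio)] = 0.0
    return out

def prune_rewind(w_target: torch.Tensor, w_source: torch.Tensor,
                 ratio: float = 0.3) -> torch.Tensor:
    # Pruning-rewind: selected weights are reset to the source values.
    out = w_target.clone()
    mask = prune_mask(w_target, ratio)
    out[mask] = w_source[mask]
    return out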

Oracle

Adaptation using the unlabeled target data:

python SISTA_DA/image_NRC_target.py --variant 'direct_target' --attribute 'Smiling' --t 1

Baselines

To generate baseline results for MEMO with its two variants, MEMO (AugMix) and MEMO (RandConv):

MEMO (AugMix): python SISTA_DA/image_target_memo.py --augmix True --attribute 'Smiling' --t 1

MEMO (RandConv): python SISTA_DA/image_target_memo.py --augmix False --attribute 'Smiling' --t 1
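
randconv.py in this repo implements the full RandConv augmentation used by the MEMO (RandConv) baseline; the core idea is to pass each batch through a freshly re-sampled random convolution. A minimal sketch, with kernel choice and mixing simplified relative to the original paper:

import torch
import torch.nn as nn
import torch.nn.functional as F

def randconv(images: torch.Tensor, kernel_size: int = 3,
             mix: bool = True) -> torch.Tensor:
    """Random-convolution augmentation for a batch of (N, 3, H, W) images.

    A new kernel is sampled on every call, so each batch receives a
    different random texture/color perturbation.
    """
    weight = torch.empty(3, 3, kernel_size, kernel_size,
                         device=images.device)
    nn.init.kaiming_normal_(weight)
    out = F.conv2d(images, weight, padding=kernel_size // 2)
    if mix:
        # Blend with the clean image so semantic content stays recognizable.
        alpha = torch.rand(1, device=images.device)
        out = alpha * images + (1 - alpha) * out
    return out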

Citation

If you use this code or ideas from our paper, please cite:

@article{subramanyam2022SISTA,
  title={Single-Shot Domain Adaptation via Target-Aware Generative Augmentations},
  author={Subramanyam, Rakshith and Thopalli, Kowshik and Berman, Spring and Turaga, Pavan and Thiagarajan, Jayaraman J.},
  journal={arXiv preprint arXiv:2210.16692},
  year={2022}
}

Acknowledgments

This code builds upon the following codebases: StyleGAN2 by rosinality, e4e, StyleGAN-NADA, NRC, MEMO, and RandConv. We thank the authors of the respective works for publicly sharing their code. Please cite them where appropriate.
