[ICLR'23 Spotlight] gDDIM: analyze and accelerate general diffusion models, isotropic and non-isotropic

qsh-zh/gDDIM

gDDIM: Generalized denoising diffusion implicit models

Qinsheng Zhang·Molei Tao·Yongxin Chen

Paper: arXiv:2206.05564


TLDR: We unbox the secret behind DDIM's acceleration using a Dirac approximation, and generalize it to general diffusion models, both isotropic and non-isotropic.
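The Dirac-approximation intuition can be checked numerically in the simplest isotropic setting. The sketch below is an illustration under stated assumptions, not the code in this repository: it assumes a scalar state, the linear-β VP schedule used in score-sde (β_min = 0.1, β_max = 20), and a data distribution that is a single point `x0` (a Dirac delta), for which the noise prediction ε is known in closed form. Under that assumption ε stays constant along the deterministic DDIM trajectory, so DDIM lands on `x0` regardless of the number of steps. All function names are hypothetical.

```python
import math
from typing import List, Tuple


def alpha_bar(t: float, beta_min: float = 0.1, beta_max: float = 20.0) -> float:
    """Assumed VP schedule with linear beta: alpha_bar(t) = exp(-integral_0^t beta(s) ds)."""
    return math.exp(-(beta_min * t + 0.5 * (beta_max - beta_min) * t * t))


def dirac_eps(x_t: float, ab_t: float, x0: float) -> float:
    """Exact noise prediction when the data distribution is a Dirac delta at x0."""
    return (x_t - math.sqrt(ab_t) * x0) / math.sqrt(1.0 - ab_t)


def ddim_step(x_t: float, eps: float, ab_t: float, ab_s: float) -> float:
    """Deterministic DDIM update from time t to an earlier time s."""
    x0_pred = (x_t - math.sqrt(1.0 - ab_t) * eps) / math.sqrt(ab_t)
    return math.sqrt(ab_s) * x0_pred + math.sqrt(1.0 - ab_s) * eps


def ddim_sample_dirac(x_T: float, x0: float, n_steps: int) -> Tuple[float, List[float]]:
    """Run DDIM from t=1 to t=0 with the exact Dirac eps; return the sample and the eps used per step."""
    times = [1.0 - i / n_steps for i in range(n_steps + 1)]  # 1.0 down to 0.0
    x, eps_history = x_T, []
    for t, s in zip(times[:-1], times[1:]):
        ab_t, ab_s = alpha_bar(t), alpha_bar(s)
        eps = dirac_eps(x, ab_t, x0)  # only evaluated at t > 0, so no division by zero
        eps_history.append(eps)
        x = ddim_step(x, eps, ab_t, ab_s)
    return x, eps_history


x, eps_hist = ddim_sample_dirac(x_T=0.7, x0=2.0, n_steps=5)
print(x)  # equals 2.0 up to floating-point error
print(max(eps_hist) - min(eps_hist))  # ~0: eps is constant along the trajectory
```

For real data the distribution is not a single point, so ε is only approximately constant along the trajectory; the sketch shows the idealized case in which DDIM is exact.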

[Figure: gDDIM Dirac approximation]

Setup

The codebase has only been tested in a Docker environment.

Docker

Reproduce results

gDDIM on CLD

Training on cifar10

cd ${gDDIM_PROJECT_FOLDER}/cld_jax
wandb login ${WANDB_KEY}
python main.py --config configs/accr_dcifar10_config.py --mode train --workdir logs/accr_dcifar_nomixed --wandb --config.seed=8
  • I tried seeds 1, 8, and 123. Seed 8 (checkpoint 15) gives the best FID on CIFAR10, while the lowest FIDs from the other two seeds are slightly higher (around 2.30).
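To repeat the seed comparison, the CLD training command can be generated once per seed. This is a minimal sketch that only prints the commands; every flag is copied from the training command above, while the per-seed workdir suffix (`_seed{seed}`) is a hypothetical naming choice so that runs do not overwrite each other (the original command uses a single `logs/accr_dcifar_nomixed`).

```python
import shlex
from typing import List

SEEDS = (1, 8, 123)  # the seeds tried in the note above


def cld_train_cmd(seed: int) -> List[str]:
    """Build the CLD CIFAR10 training command for one seed."""
    # Hypothetical naming: one workdir per seed.
    workdir = f"logs/accr_dcifar_nomixed_seed{seed}"
    return [
        "python", "main.py",
        "--config", "configs/accr_dcifar10_config.py",
        "--mode", "train",
        "--workdir", workdir,
        "--wandb",
        f"--config.seed={seed}",
    ]


# Only print the commands; run them from ${gDDIM_PROJECT_FOLDER}/cld_jax after `wandb login`.
for seed in SEEDS:
    print(shlex.join(cld_train_cmd(seed)))
```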

Eval on cifar10

  1. Download CIFAR stats to ${gDDIM_PROJECT_FOLDER}/cld_jax/assets/stats/.

  2. We provide a pretrained model checkpoint.

The checkpoint achieves an FID of 2.2565 on my machine with 50 NFE.

  3. Users can evaluate FID via
cd ${gDDIM_PROJECT_FOLDER}/cld_jax
python main.py --config configs/accr_dcifar10_config.py --mode check --result_folder logs/fid --ckpt ${CLD_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50
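To compare sample quality across sampling budgets, one might run `check` mode at several values of `--config.sampling.nfe`. This is a minimal sketch that only prints the commands; the flags are copied from the evaluation command above, while the NFE values and the per-NFE result folders are hypothetical choices. It assumes `CLD_BEST_PATH` is set in the environment and falls back to a placeholder otherwise.

```python
import os
import shlex
from typing import List


def cld_eval_cmd(ckpt: str, nfe: int, order: int = 2) -> List[str]:
    """Build the CLD FID evaluation command for one sampling budget."""
    return [
        "python", "main.py",
        "--config", "configs/accr_dcifar10_config.py",
        "--mode", "check",
        "--result_folder", f"logs/fid_nfe{nfe}",  # hypothetical per-NFE folder
        "--ckpt", ckpt,
        f"--config.sampling.deis_order={order}",
        f"--config.sampling.nfe={nfe}",
    ]


ckpt = os.environ.get("CLD_BEST_PATH", "<path-to-checkpoint>")
for nfe in (20, 30, 50):  # hypothetical budgets; 50 matches the command above
    print(shlex.join(cld_eval_cmd(ckpt, nfe)))
```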

Blur diffusion model

Training on cifar10

cd ${gDDIM_PROJECT_FOLDER}/blur_jax
wandb login ${WANDB_KEY}
python main.py --config configs/ddpm_deep_cifar10_config.py --mode train --workdir logs/ddpm_deep_sigma${sigma}_seed${seed} --wandb --config.model.sigma_blur_max=${sigma} --config.seed=${seed}
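Blurring diffusion is a non-isotropic model: in a frequency (DCT) basis, each frequency of the signal is scaled differently during the forward process, so one scalar schedule no longer describes the noising. As a minimal sketch of that idea only, the code below blurs a 1-D signal with an orthonormal DCT-II, damping frequency k by exp(-λ_k τ) with λ_k = (πk/N)². The 1-D setting, the function names, and this particular λ_k are illustrative assumptions; it is not the blur_jax implementation, and the Gaussian noise term of the forward process is omitted.

```python
import math
from typing import List


def dct_matrix(n: int) -> List[List[float]]:
    """Orthonormal DCT-II matrix C: coeffs = C x, and x = C^T coeffs."""
    rows = []
    for k in range(n):
        scale = math.sqrt(1.0 / n) if k == 0 else math.sqrt(2.0 / n)
        rows.append([scale * math.cos(math.pi * (i + 0.5) * k / n) for i in range(n)])
    return rows


def heat_blur_1d(x: List[float], tau: float) -> List[float]:
    """Heat-equation style blur: damp DCT frequency k by exp(-lambda_k * tau)."""
    n = len(x)
    c = dct_matrix(n)
    coeffs = [sum(c[k][i] * x[i] for i in range(n)) for k in range(n)]
    # lambda_k = (pi * k / n)^2: high frequencies decay fastest, k = 0 (the mean) is untouched.
    damped = [math.exp(-((math.pi * k / n) ** 2) * tau) * coeffs[k] for k in range(n)]
    return [sum(c[k][i] * damped[k] for k in range(n)) for i in range(n)]


impulse = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(heat_blur_1d(impulse, 0.0))  # reproduces the impulse up to rounding
print(heat_blur_1d(impulse, 5.0))  # spread out, with the same total mass
```

Because λ_0 = 0, the mean of the signal is preserved while higher frequencies are damped first; this per-frequency scaling is what makes the forward process non-isotropic.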

Eval on cifar10

  1. Download CIFAR stats to ${gDDIM_PROJECT_FOLDER}/blur_jax/assets/stats/.

  2. We provide a pretrained model checkpoint.

  3. Users can evaluate FID via

cd ${gDDIM_PROJECT_FOLDER}/blur_jax
python main.py --config configs/ddpm_deep_cifar10_config.py --mode check --result_folder logs/fid --ckpt ${BLUR_BEST_PATH} --config.sampling.deis_order=2 --config.sampling.nfe=50

Reference

@misc{zhang2022gddim,
  title={gDDIM: Generalized denoising diffusion implicit models},
  author={Qinsheng Zhang and Molei Tao and Yongxin Chen},
  year={2022},
  eprint={2206.05564},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

Related works

@inproceedings{song2020denoising,
  title={Denoising diffusion implicit models},
  author={Song, Jiaming and Meng, Chenlin and Ermon, Stefano},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2021}
}

@inproceedings{dockhorn2022score,
  title={Score-Based Generative Modeling with Critically-Damped Langevin Diffusion},
  author={Tim Dockhorn and Arash Vahdat and Karsten Kreis},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2022}
}

@article{hoogeboom2022blurring,
  title={Blurring diffusion models},
  author={Hoogeboom, Emiel and Salimans, Tim},
  journal={arXiv preprint arXiv:2209.05557},
  year={2022}
}

Miscellaneous

The project is built upon score-sde, developed by Yang Song. Additionally, the sampling code has been adapted from DEIS.
