# PyTorch Implementation of Denoising Diffusion Probabilistic Models [paper] [official repo]
## Features

- Original DDPM[^1] training & sampling
- DDIM[^2] sampler
- Standard evaluation metrics: Fréchet Inception Distance[^3] (FID), Precision & Recall[^4]
- Distributed Data Parallel[^5] (DDP) multi-GPU training
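At its core, DDPM training amounts to predicting the noise injected by the forward diffusion process at a random timestep. The snippet below is a minimal sketch of that objective, not the exact code of this repo; the linear beta schedule, `T = 1000`, and the `model(x_t, t)` signature are assumptions.

```python
import torch
import torch.nn.functional as F

# Assumed linear beta schedule (beta_1 = 1e-4, beta_T = 0.02 over T = 1000 steps), as in the DDPM paper.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)  # cumulative product \bar{alpha}_t

def ddpm_loss(model, x0):
    """Simplified DDPM loss: predict the Gaussian noise added at a random timestep."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)       # uniform random timestep per sample
    noise = torch.randn_like(x0)                                     # eps ~ N(0, I)
    a_bar = alphas_bar.to(x0.device)[t].view(-1, 1, 1, 1)            # broadcast over (B, C, H, W)
    x_t = a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise           # forward diffusion q(x_t | x_0)
    return F.mse_loss(model(x_t, t), noise)                          # || eps - eps_theta(x_t, t) ||^2
```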
## Requirements

- torch>=1.12.0
- torchvision>=0.13.0
- scipy>=1.7.3
## Code usage

Command-line usage is provided for toy-data training (`train_toy.py`) as well as real-world training (`train.py`), generation (`generate.py`), and evaluation (`eval.py`); run each script with `--help` for the full list of options.
## Examples
- Train a 25-Gaussian toy model on a single GPU (device id: 0) for a total of 100 epochs (the toy data itself is sketched below):

  `python train_toy.py --dataset gaussian25 --device cuda:0 --epochs 100`
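The `gaussian25` dataset is a mixture of 25 Gaussians arranged on a 5×5 grid. The snippet below only illustrates such data; the grid range and standard deviation used by `train_toy.py` may differ.

```python
import torch

def sample_gaussian25(n, std=0.05):
    """Draw n points from a mixture of 25 Gaussians centered on a 5x5 grid in [-2, 2]^2.

    Illustrative only: the spacing/std used by train_toy.py may differ.
    """
    centers = torch.cartesian_prod(torch.linspace(-2, 2, 5), torch.linspace(-2, 2, 5))  # (25, 2) grid
    idx = torch.randint(0, 25, (n,))                 # pick a mixture component per sample
    return centers[idx] + std * torch.randn(n, 2)    # add small Gaussian jitter around the center
```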
- Train a CIFAR-10 model on a single GPU (device id: 0) for a total of 50 epochs:

  `python train.py --dataset cifar10 --train-device cuda:0 --epochs 50`

  (You can always use `--dry-run` for testing/tuning purposes.)
- Train a CelebA model with an effective batch size of 128 (per-GPU batch size × 2 accumulation steps × 4 GPUs) on a four-card machine (single node) using shared file-system initialization (sketched below):

  `python train.py --dataset celeba --num-accum 2 --num-gpus 4 --distributed --rigid-launch`

  - `--num-accum 2`: accumulate gradients over 2 mini-batches
  - `--num-gpus`: number of GPU(s) to use for training, i.e. `WORLD_SIZE` of the process group
  - `--distributed`: enable multi-GPU DDP training
  - `--rigid-launch`: use shared file-system initialization and `torch.multiprocessing`
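Roughly, `--distributed --rigid-launch` corresponds to spawning one worker per GPU with `torch.multiprocessing` and rendezvousing through a shared file, while `--num-accum 2` delays the optimizer step so gradients from 2 mini-batches are accumulated. The sketch below illustrates that pattern under stated assumptions (a stand-in `nn.Linear` model, random data, and a temp-file path); it is not the repo's actual launcher.

```python
import os
import tempfile
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size, init_file, num_accum=2):
    # Shared file-system initialization: every process rendezvouses through the same file.
    dist.init_process_group("nccl", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    model = DDP(nn.Linear(16, 16).cuda(rank), device_ids=[rank])  # stand-in for the diffusion model
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

    for step in range(8):                                # stand-in for iterating a DataLoader
        x = torch.randn(32, 16, device=rank)             # per-GPU mini-batch
        loss = (model(x) - x).pow(2).mean() / num_accum  # scale so accumulated grads average correctly
        loss.backward()                                  # DDP all-reduces gradients on each backward
        if (step + 1) % num_accum == 0:                  # optimizer step once every num_accum mini-batches
            optimizer.step()
            optimizer.zero_grad()
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4                                       # one process per GPU (--num-gpus 4)
    init_file = os.path.join(tempfile.mkdtemp(), "shared_init")
    mp.spawn(worker, args=(world_size, init_file), nprocs=world_size)
```

In real training code the non-final micro-batches are often wrapped in `model.no_sync()` so gradients are all-reduced only once per optimizer step.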
- (Recommended) Train a CelebA model with an effective batch size of 64 x 1 x 2 = 128 using only two GPUs with `torchrun` Elastic Launch[^6] (TCP initialization; see the sketch below):

  `export CUDA_VISIBLE_DEVICES=0,1 && torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --distributed`
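With `torchrun`, each worker process already receives `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` through environment variables, so the training script only needs the default env-var rendezvous. A minimal sketch (not the repo's actual code; the `nn.Linear` model is a placeholder):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# torchrun sets RANK / WORLD_SIZE / LOCAL_RANK and the rendezvous endpoint,
# so "env://" initialization needs no explicit rank or world size.
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = DDP(torch.nn.Linear(16, 16).cuda(local_rank), device_ids=[local_rank])
# ... build the optimizer/dataloader and train as usual; launch with e.g.
#   torchrun --standalone --nproc_per_node 2 this_script.py
```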
- Generate 50,000 samples (128 per mini-batch) from the checkpoint located at `./chkpts/cifar10/cifar10_2040.pt` in parallel using 4 GPUs and the DDIM sampler. The results are stored in `./images/eval/cifar10/cifar10_2040_ddim`:

  `python generate.py --dataset cifar10 --chkpt-path ./chkpts/cifar10/cifar10_2040.pt --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _ddim --num-gpus 4`

  - `--use-ddim`: use the DDIM sampler
  - `--skip-schedule quadratic`: use the quadratic skip schedule (sketched below)
  - `--subseq-size`: length of the sub-sequence, i.e. the number of DDIM timesteps
  - `--suffix`: suffix string appended to the dataset name in the folder name
  - `--num-gpus`: number of GPU(s) to use for generation
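For intuition, `--skip-schedule quadratic --subseq-size 100` keeps roughly 100 of the original 1,000 timesteps with quadratically growing spacing, and the DDIM update is then applied only on that sub-sequence. The sketch below follows the formulation in the DDIM paper[^2]; the exact selection rule and the `model(x, t)` signature in this repo may differ.

```python
import torch

T, subseq_size = 1000, 100
# Quadratic skip schedule: timesteps packed more densely near t = 0 (as in the DDIM paper).
subseq = (torch.linspace(0, (T * 0.8) ** 0.5, subseq_size) ** 2).long()

@torch.no_grad()
def ddim_sample(model, x, alphas_bar, eta=0.0):
    """DDIM sampling over the selected sub-sequence (eta = 0 gives the deterministic sampler)."""
    for i in range(len(subseq) - 1, 0, -1):
        t, t_prev = subseq[i], subseq[i - 1]
        a_t, a_prev = alphas_bar[t], alphas_bar[t_prev]
        eps = model(x, t.repeat(x.shape[0]))                   # predicted noise eps_theta(x_t, t)
        x0_pred = (x - (1 - a_t).sqrt() * eps) / a_t.sqrt()    # implied clean sample x_0
        sigma = eta * ((1 - a_prev) / (1 - a_t)).sqrt() * (1 - a_t / a_prev).sqrt()
        x = (a_prev.sqrt() * x0_pred
             + (1 - a_prev - sigma ** 2).sqrt() * eps
             + sigma * torch.randn_like(x))
    return x
```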
- Evaluate the FID and Precision/Recall of the generated samples in `./images/eval/cifar10/cifar10_2040` (the FID computation is sketched below):

  `python eval.py --dataset cifar10 --sample-folder ./images/eval/cifar10/cifar10_2040`
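FID[^3] compares Gaussian fits to Inception features of real and generated images. Given pre-extracted feature matrices, the metric itself reduces to a few lines; the sketch below assumes the features have already been computed with an InceptionV3 network and is not necessarily how `eval.py` implements it.

```python
import numpy as np
from scipy import linalg

def fid_from_features(feats_real, feats_fake):
    """Frechet Inception Distance between two sets of Inception activations (N x D arrays)."""
    mu1, mu2 = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    sigma1 = np.cov(feats_real, rowvar=False)
    sigma2 = np.cov(feats_fake, rowvar=False)
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)  # matrix square root of sigma1 @ sigma2
    covmean = covmean.real                                  # discard tiny imaginary parts from numerics
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
```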
## Results

### Toy data

| Dataset | 8 Gaussian | 25 Gaussian | Swiss Roll |
|---|---|---|---|
| True | | | |
| Generated | | | |
### Real-world data

**Table of evaluated metrics**
Dataset | FID (↓) | Precision (↑) | Recall (↑) | Training steps | Training loss | Checkpoint |
---|---|---|---|---|---|---|
CIFAR-10 | 9.162 | 0.691 | 0.473 | 46.8k | 0.0295 | - |
|__ | 5.778 | 0.697 | 0.516 | 93.6k | 0.0293 | - |
|__ | 4.083 | 0.705 | 0.539 | 187.2k | 0.0291 | - |
|__ | 3.31 | 0.722 | 0.551 | 421.2k | 0.0284 | - |
|__ | 3.188 | 0.739 | 0.544 | 795.6k | 0.0277 | [Link] |
CelebA | 4.806 | 0.772 | 0.484 | 189.8k | 0.0155 | - |
|__ | 3.797 | 0.764 | 0.511 | 379.7k | 0.0152 | - |
|__ | 2.995 | 0.760 | 0.540 | 949.2k | 0.0148 | [Link] |
CelebA-HQ | 19.742 | 0.683 | 0.256 | 56.2k | 0.0105 | - |
|__ | 11.971 | 0.705 | 0.364 | 224.6k | 0.0097 | - |
|__ | 8.851 | 0.768 | 0.376 | 393.1k | 0.0098 | - |
|__ | 8.91 | 0.800 | 0.357 | 561.6k | 0.0097 | [Link] |
| Dataset | CIFAR-10 | CelebA | CelebA-HQ |
|---|---|---|---|
| Generated images | | | |
## Related projects

- Simple web app powered by Streamlit: [tqch/diffusion-webapp]
- Classifier-Free Guidance: [tqch/v-diffusion-torch]
## Footnotes
[^1]: Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

[^2]: Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." International Conference on Learning Representations. 2021.

[^3]: Heusel, Martin, et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium." Advances in Neural Information Processing Systems 30 (2017).

[^4]: Kynkäänniemi, Tuomas, et al. "Improved precision and recall metric for assessing generative models." Advances in Neural Information Processing Systems 32 (2019).

[^5]: DistributedDataParallel - PyTorch 1.12 Documentation, https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html.

[^6]: Torchrun (Elastic Launch) - PyTorch 1.12 Documentation, https://pytorch.org/docs/stable/elastic/run.html.