Unofficial PyTorch Implementation of Denoising Diffusion Probabilistic Models (DDPM) [paper] [official repo]

Features

  • Original DDPM [1] training & sampling (a training-loss sketch follows this list)
  • DDIM [2] sampler
  • Standard evaluation metrics
    • Fréchet Inception Distance [3] (FID)
    • Precision & Recall [4]
  • Distributed Data Parallel [5] (DDP) multi-GPU training
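
A minimal sketch of the DDPM training objective referenced above, using the standard epsilon parameterization: corrupt clean data x_0 with closed-form Gaussian noise and regress that noise with an MSE loss. The function and argument names below are illustrative, not this repo's API.

    import torch
    import torch.nn.functional as F

    def ddpm_loss(model, x_0, betas):
        """One simplified DDPM training step.

        model(x_t, t) is assumed to predict the added noise ("eps" mean type);
        betas is the noise schedule, e.g. torch.linspace(1e-4, 2e-2, 1000).
        """
        alphas_bar = torch.cumprod(1.0 - betas, dim=0).to(x_0.device)   # cumulative product: \bar{alpha}_t
        t = torch.randint(0, len(betas), (x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)                                    # eps ~ N(0, I)
        a_bar = alphas_bar[t].view(-1, *([1] * (x_0.dim() - 1)))         # broadcast over image dims
        x_t = a_bar.sqrt() * x_0 + (1.0 - a_bar).sqrt() * noise          # sample q(x_t | x_0)
        return F.mse_loss(model(x_t, t), noise)                          # the simple (mse) loss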

Requirements

  • torch>=1.12.0
  • torchvision>=0.13.0
  • scipy>=1.7.3

Code usage

Training on toy data (train_toy.py)

usage: train_toy.py [-h] [--dataset {gaussian8,gaussian25,swissroll}]      
                    [--size SIZE] [--root ROOT] [--epochs EPOCHS] [--lr LR]
                    [--beta1 BETA1] [--beta2 BETA2] [--lr-warmup LR_WARMUP]
                    [--batch-size BATCH_SIZE] [--timesteps TIMESTEPS]      
                    [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]  
                    [--beta-start BETA_START] [--beta-end BETA_END]        
                    [--model-mean-type {mean,x_0,eps}]                     
                    [--model-var-type {learned,fixed-small,fixed-large}]   
                    [--loss-type {kl,mse}] [--image-dir IMAGE_DIR]         
                    [--chkpt-dir CHKPT_DIR] [--chkpt-intv CHKPT_INTV]      
                    [--eval-intv EVAL_INTV] [--seed SEED] [--resume]       
                    [--device DEVICE] [--mid-features MID_FEATURES]        
                    [--num-temporal-layers NUM_TEMPORAL_LAYERS]            
optional arguments:                                                        
  -h, --help            show this help message and exit                    
  --dataset {gaussian8,gaussian25,swissroll}                               
  --size SIZE                                                              
  --root ROOT           root directory of datasets                         
  --epochs EPOCHS       total number of training epochs                    
  --lr LR               learning rate                                      
  --beta1 BETA1         beta_1 in Adam                                     
  --beta2 BETA2         beta_2 in Adam                                     
  --lr-warmup LR_WARMUP                                                    
                        number of warming-up epochs                        
  --batch-size BATCH_SIZE                                                  
  --timesteps TIMESTEPS                                                    
                        number of diffusion steps                          
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}                      
  --beta-start BETA_START                                                  
  --beta-end BETA_END                                                      
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --image-dir IMAGE_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-intv CHKPT_INTV
                        frequency of saving a checkpoint
  --eval-intv EVAL_INTV
  --seed SEED           random seed
  --resume              to resume training from a checkpoint
  --device DEVICE
  --mid-features MID_FEATURES
  --num-temporal-layers NUM_TEMPORAL_LAYERS
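
The --beta-schedule, --beta-start, --beta-end and --timesteps options select the diffusion noise schedule. The sketch below gives illustrative definitions of the named schedules, following the conventions of the official DDPM code base; the exact implementation in this repo may differ.

    import numpy as np

    def get_betas(schedule, beta_start, beta_end, timesteps):
        # Illustrative beta schedules; a sketch, not this repo's exact code.
        if schedule == "linear":
            return np.linspace(beta_start, beta_end, timesteps)
        if schedule == "quad":
            return np.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
        if schedule in ("warmup10", "warmup50"):
            frac = 0.1 if schedule == "warmup10" else 0.5
            betas = np.full(timesteps, beta_end, dtype=np.float64)
            warmup = int(timesteps * frac)
            betas[:warmup] = np.linspace(beta_start, beta_end, warmup)  # linear warm-up, then constant
            return betas
        if schedule == "jsd":
            return 1.0 / np.linspace(timesteps, 1, timesteps)           # 1/T, 1/(T-1), ..., 1
        raise ValueError(f"unknown beta schedule: {schedule}")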

Training on real-world data (train.py)

usage: train.py [-h] [--dataset {mnist,cifar10,celeba,celebahq}] [--root ROOT]
                [--epochs EPOCHS] [--lr LR] [--beta1 BETA1] [--beta2 BETA2]   
                [--batch-size BATCH_SIZE] [--num-accum NUM_ACCUM]
                [--block-size BLOCK_SIZE] [--timesteps TIMESTEPS]
                [--beta-schedule {quad,linear,warmup10,warmup50,jsd}]
                [--beta-start BETA_START] [--beta-end BETA_END]
                [--model-mean-type {mean,x_0,eps}]
                [--model-var-type {learned,fixed-small,fixed-large}]
                [--loss-type {kl,mse}] [--num-workers NUM_WORKERS]
                [--train-device TRAIN_DEVICE] [--eval-device EVAL_DEVICE]
                [--image-dir IMAGE_DIR] [--image-intv IMAGE_INTV]
                [--num-save-images NUM_SAVE_IMAGES] [--config-dir CONFIG_DIR]
                [--chkpt-dir CHKPT_DIR] [--chkpt-name CHKPT_NAME]
                [--chkpt-intv CHKPT_INTV] [--seed SEED] [--resume]
                [--chkpt-path CHKPT_PATH] [--eval] [--use-ema]
                [--ema-decay EMA_DECAY] [--distributed] [--rigid-launch]
                [--num-gpus NUM_GPUS] [--dry-run]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba,celebahq}
  --root ROOT           root directory of datasets
  --epochs EPOCHS       total number of training epochs
  --lr LR               learning rate
  --beta1 BETA1         beta_1 in Adam
  --beta2 BETA2         beta_2 in Adam
  --batch-size BATCH_SIZE
  --num-accum NUM_ACCUM
                        number of mini-batches before an update
  --block-size BLOCK_SIZE
                        block size used for pixel shuffle
  --timesteps TIMESTEPS
                        number of diffusion steps
  --beta-schedule {quad,linear,warmup10,warmup50,jsd}
  --beta-start BETA_START
  --beta-end BETA_END
  --model-mean-type {mean,x_0,eps}
  --model-var-type {learned,fixed-small,fixed-large}
  --loss-type {kl,mse}
  --chkpt-path CHKPT_PATH
                        checkpoint path used to resume training
  --eval                whether to evaluate fid during training
  --use-ema             whether to use exponential moving average
  --ema-decay EMA_DECAY
                        decay factor of ema
  --distributed         whether to use distributed training
  --rigid-launch        whether to use torch multiprocessing spawn
  --num-gpus NUM_GPUS   number of gpus for distributed training
  --dry-run             test-run till the first model update completes
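
For context, --num-accum and --use-ema/--ema-decay correspond to gradient accumulation and an exponential moving average of the model weights. The loop below is a hypothetical sketch of both mechanisms, not the repo's actual training loop; names such as loss_fn and ema_model are placeholders.

    import torch

    def train_epoch(model, ema_model, optimizer, loader, loss_fn,
                    num_accum=1, ema_decay=0.9999):
        # ema_model is typically initialized as a deep copy of model
        model.train()
        optimizer.zero_grad()
        for i, (x, _) in enumerate(loader):
            loss = loss_fn(model, x) / num_accum       # scale so accumulated grads match one large batch
            loss.backward()
            if (i + 1) % num_accum == 0:               # update once every num_accum mini-batches
                optimizer.step()
                optimizer.zero_grad()
                with torch.no_grad():                  # EMA update: p_ema <- d * p_ema + (1 - d) * p
                    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                        p_ema.mul_(ema_decay).add_(p, alpha=1.0 - ema_decay)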

Generation (generate.py)

usage: generate.py [-h] [--dataset {mnist,cifar10,celeba,celebahq}]
                   [--batch-size BATCH_SIZE] [--total-size TOTAL_SIZE]
                   [--config-dir CONFIG_DIR] [--chkpt-dir CHKPT_DIR]
                   [--chkpt-path CHKPT_PATH] [--save-dir SAVE_DIR]
                   [--device DEVICE] [--use-ema] [--use-ddim] [--eta ETA]
                   [--skip-schedule SKIP_SCHEDULE] [--subseq-size SUBSEQ_SIZE]
                   [--suffix SUFFIX] [--max-workers MAX_WORKERS]
                   [--num-gpus NUM_GPUS]
optional arguments:
  -h, --help            show this help message and exit
  --dataset {mnist,cifar10,celeba,celebahq}
  --batch-size BATCH_SIZE
  --total-size TOTAL_SIZE
  --config-dir CONFIG_DIR
  --chkpt-dir CHKPT_DIR
  --chkpt-path CHKPT_PATH
  --save-dir SAVE_DIR
  --device DEVICE
  --use-ema
  --use-ddim
  --eta ETA
  --skip-schedule SKIP_SCHEDULE
  --subseq-size SUBSEQ_SIZE
  --suffix SUFFIX
  --max-workers MAX_WORKERS
  --num-gpus NUM_GPUS
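
As background on the DDIM-related flags, --subseq-size and --skip-schedule choose which subset of the T training timesteps the sampler visits, and --eta scales the sampling noise. The spacing rule below is an illustrative sketch, not necessarily this repo's exact rule.

    import numpy as np

    def ddim_subsequence(timesteps=1000, subseq_size=100, skip_schedule="linear"):
        # Pick subseq_size timesteps out of [0, timesteps): evenly spaced ("linear"),
        # or quadratically spaced so steps are denser near t = 0 ("quadratic").
        if skip_schedule == "quadratic":
            taus = np.linspace(0, np.sqrt(timesteps * 0.8), subseq_size) ** 2
        else:
            taus = np.linspace(0, timesteps - 1, subseq_size)
        return np.unique(taus.astype(int))

For consecutive kept steps s < t, DDIM's noise scale is sigma = eta * sqrt((1 - abar_s) / (1 - abar_t)) * sqrt(1 - abar_t / abar_s), so eta = 0 gives fully deterministic sampling and eta = 1 matches DDPM-like stochasticity.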

Evaluation (eval.py)

usage: eval.py [-h] [--root ROOT] [--dataset {mnist,cifar10,celeba,celebahq}]
               [--model-device MODEL_DEVICE] [--eval-device EVAL_DEVICE]
               [--eval-batch-size EVAL_BATCH_SIZE]
               [--eval-total-size EVAL_TOTAL_SIZE] [--num-workers NUM_WORKERS]
               [--nhood-size NHOOD_SIZE] [--row-batch-size ROW_BATCH_SIZE]
               [--col-batch-size COL_BATCH_SIZE] [--device DEVICE]
               [--eval-dir EVAL_DIR] [--precomputed-dir PRECOMPUTED_DIR]
               [--metrics METRICS [METRICS ...]] [--seed SEED]
               [--folder-name FOLDER_NAME]
optional arguments:
  -h, --help            show this help message and exit
  --root ROOT
  --dataset {mnist,cifar10,celeba,celebahq}
  --model-device MODEL_DEVICE
  --eval-device EVAL_DEVICE
  --eval-batch-size EVAL_BATCH_SIZE
  --eval-total-size EVAL_TOTAL_SIZE
  --num-workers NUM_WORKERS
  --nhood-size NHOOD_SIZE
  --row-batch-size ROW_BATCH_SIZE
  --col-batch-size COL_BATCH_SIZE
  --device DEVICE
  --eval-dir EVAL_DIR
  --precomputed-dir PRECOMPUTED_DIR
  --metrics METRICS [METRICS ...]
  --seed SEED
  --folder-name FOLDER_NAME
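
For reference, the FID that eval.py reports is the Fréchet distance between Gaussians fitted to Inception features of real and generated images. A minimal sketch of that final computation (the feature extraction itself is omitted):

    import numpy as np
    from scipy import linalg

    def frechet_distance(mu1, sigma1, mu2, sigma2):
        # mu*, sigma*: mean vector and covariance matrix of Inception features
        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)   # matrix square root of sigma1 @ sigma2
        if np.iscomplexobj(covmean):                             # discard tiny imaginary parts from numerical error
            covmean = covmean.real
        return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))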
			

Examples

  • Train a 25-Gaussian toy model on a single GPU (device id: 0) for a total of 100 epochs

    python train_toy.py --dataset gaussian25 --device cuda:0 --epochs 100
  • Train a CIFAR-10 model on a single GPU (device id: 0) for a total of 50 epochs

    python train.py --dataset cifar10 --train-device cuda:0 --epochs 50

(You can always pass --dry-run for testing/tuning purposes.)

  • Train a CelebA model with an effective batch size of 64 x 2 x 4 = 512 (per-GPU batch x gradient-accumulation steps x GPUs) on a four-GPU machine (single node) using shared file-system initialization

    python train.py --dataset celeba --num-accum 2 --num-gpus 4 --distributed --rigid-launch
    • num-accum 2: accumulate gradients for 2 mini-batches
    • num-gpus: number of GPU(s) to use for training, i.e. WORLD_SIZE of the process group
    • distributed: enable multi-GPU DDP training (see the DDP sketch at the end of these examples)
    • rigid-launch: use shared file-system initialization and torch.multiprocessing.spawn
  • (Recommended) Train a CelebA model with an effective batch size of 64 x 1 x 2 = 128 using only two GPUs with torchrun Elastic Launch [6] (TCP initialization)

    export CUDA_VISIBLE_DEVICES=0,1 && torchrun --standalone --nproc_per_node 2 --rdzv_backend c10d train.py --dataset celeba --distributed
  • Generate 50,000 samples (128 per mini-batch) from the checkpoint located at ./chkpts/cifar10/cifar10_2040.pt in parallel using 4 GPUs and the DDIM sampler. The results are stored in ./images/eval/cifar10/cifar10_2040_ddim

     python generate.py --dataset cifar10 --chkpt-path ./chkpts/cifar10/cifar10_2040.pt --use-ddim --skip-schedule quadratic --subseq-size 100 --suffix _ddim --num-gpus 4
    • use-ddim: use DDIM
    • skip-schedule quadratic: use the quadratic schedule
    • subseq-size: length of sub-sequence, i.e. DDIM timesteps
    • suffix: suffix string to the dataset name in the folder name
    • num-gpus: number of GPU(s) to use for generation
  • Evaluate the FID and Precision/Recall of the generated samples in ./images/eval/cifar10/cifar10_2040

     python eval.py --dataset cifar10 --sample-folder ./images/eval/cifar10/cifar10_2040
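
The distributed examples above rely on PyTorch DDP. As a rough sketch of what --distributed plus torchrun set up under the hood (train.py handles this internally; the helper below is illustrative only):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def setup_ddp(model):
        # torchrun exports RANK, LOCAL_RANK and WORLD_SIZE; env:// reads them for the TCP rendezvous
        dist.init_process_group(backend="nccl", init_method="env://")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        return DDP(model.to(local_rank), device_ids=[local_rank])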

Experiment results

Toy data

[Image table: true vs. generated samples for the 8 Gaussian, 25 Gaussian, and Swiss Roll datasets]

Training process (animated)

[Animated image table: training progress on the 8 Gaussian, 25 Gaussian, and Swiss Roll datasets]

Real-world data

Table of evaluated metrics

| Dataset   | FID (↓) | Precision (↑) | Recall (↑) | Training steps | Training loss | Checkpoint |
|-----------|---------|---------------|------------|----------------|---------------|------------|
| CIFAR-10  | 9.162   | 0.691         | 0.473      | 46.8k          | 0.0295        | -          |
| CIFAR-10  | 5.778   | 0.697         | 0.516      | 93.6k          | 0.0293        | -          |
| CIFAR-10  | 4.083   | 0.705         | 0.539      | 187.2k         | 0.0291        | -          |
| CIFAR-10  | 3.31    | 0.722         | 0.551      | 421.2k         | 0.0284        | -          |
| CIFAR-10  | 3.188   | 0.739         | 0.544      | 795.6k         | 0.0277        | [Link]     |
| CelebA    | 4.806   | 0.772         | 0.484      | 189.8k         | 0.0155        | -          |
| CelebA    | 3.797   | 0.764         | 0.511      | 379.7k         | 0.0152        | -          |
| CelebA    | 2.995   | 0.760         | 0.540      | 949.2k         | 0.0148        | [Link]     |
| CelebA-HQ | 19.742  | 0.683         | 0.256      | 56.2k          | 0.0105        | -          |
| CelebA-HQ | 11.971  | 0.705         | 0.364      | 224.6k         | 0.0097        | -          |
| CelebA-HQ | 8.851   | 0.768         | 0.376      | 393.1k         | 0.0098        | -          |
| CelebA-HQ | 8.91    | 0.800         | 0.357      | 561.6k         | 0.0097        | [Link]     |

[Image table: generated samples for CIFAR-10, CelebA, and CelebA-HQ]

Denoising process (animated)

[Animated image table: the denoising process on CIFAR-10, CelebA, and CelebA-HQ]

Related repositories

References

Footnotes

  1. Ho, Jonathan, Ajay Jain, and Pieter Abbeel. "Denoising diffusion probabilistic models." Advances in Neural Information Processing Systems 33 (2020): 6840-6851.

  2. Song, Jiaming, Chenlin Meng, and Stefano Ermon. "Denoising Diffusion Implicit Models." International Conference on Learning Representations. 2021.

  3. Heusel, Martin, et al. "GANs trained by a two time-scale update rule converge to a local Nash equilibrium." Advances in Neural Information Processing Systems 30 (2017).

  4. Kynkäänniemi, Tuomas, et al. "Improved precision and recall metric for assessing generative models." Advances in Neural Information Processing Systems 32 (2019).

  5. DistributedDataParallel - PyTorch 1.12 Documentation, https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html.

  6. Torchrun (Elastic Launch) - PyTorch 1.12 Documentation, https://pytorch.org/docs/stable/elastic/run.html.