Official implementation for Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models (ICML 2022), and a reimplementation of Analytic-DPM: an Analytic Estimate of the Optimal Reverse Variance in Diffusion Probabilistic Models (ICLR 2022)

baofff/Extended-Analytic-DPM


Extended Analytic-DPM

  • This is the official implementation of Estimating the Optimal Covariance with Imperfect Mean in Diffusion Probabilistic Models (accepted at ICML 2022). It extends Analytic-DPM to the following two settings:

    • The reverse process adopts complicated, state-dependent covariance matrices instead of simple scalar variances (this motivates SN-DPM in the paper).
    • The score-based model has some error w.r.t. the exact score function (this motivates NPR-DPM in the paper).
  • This codebase also reimplements Analytic-DPM and reproduces most of its results. The pretrained DPMs used in the Analytic-DPM paper are provided here, already converted to a format that this codebase can use directly. We additionally apply Analytic-DPM to score-based SDEs.

  • The models and FID statistics needed to reproduce the results in the paper are available here.
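To make the two settings concrete, here is a minimal NumPy sketch of the per-dimension (diagonal) reverse variance each model plugs in. It is written from our reading of the paper, not copied from this codebase, and the function names are hypothetical. It assumes the generalized reverse step with noise scale lambda_n^2, where beta_bar_n = 1 - alpha_bar_n, alpha_n = alpha_bar_n / alpha_bar_{n-1} and gamma_n = sqrt(beta_bar_n / alpha_n) - sqrt(beta_bar_{n-1} - lambda_n^2); moment matching then gives lambda_n^2 plus gamma_n^2 times the conditional variance of the noise. SN-DPM is assumed to add a head estimating the squared noise E[eps^2 | x_n], and NPR-DPM a head estimating the squared residual E[(eps - eps_hat(x_n))^2 | x_n] of the imperfect noise prediction.

```python
import numpy as np


def gamma(alpha_bar_n, alpha_bar_prev, lam2):
    """gamma_n = sqrt(beta_bar_n / alpha_n) - sqrt(beta_bar_{n-1} - lambda_n^2)."""
    alpha_n = alpha_bar_n / alpha_bar_prev  # per-step alpha between the two timesteps
    beta_bar_n, beta_bar_prev = 1.0 - alpha_bar_n, 1.0 - alpha_bar_prev
    return np.sqrt(beta_bar_n / alpha_n) - np.sqrt(beta_bar_prev - lam2)


def sn_dpm_variance(eps_hat, eps2_hat, alpha_bar_n, alpha_bar_prev, lam2):
    """SN-DPM (hypothetical helper): diagonal variance from a squared-noise head.

    eps_hat:  noise prediction, an estimate of E[eps | x_n]
    eps2_hat: squared-noise head, an estimate of E[eps**2 | x_n]
    """
    g = gamma(alpha_bar_n, alpha_bar_prev, lam2)
    # Clamp at 0: the two heads are learned separately, so the difference can go negative.
    return lam2 + g ** 2 * np.clip(eps2_hat - eps_hat ** 2, 0.0, None)


def npr_dpm_variance(res2_hat, alpha_bar_n, alpha_bar_prev, lam2):
    """NPR-DPM (hypothetical helper): res2_hat estimates E[(eps - eps_hat)**2 | x_n]."""
    g = gamma(alpha_bar_n, alpha_bar_prev, lam2)
    return lam2 + g ** 2 * np.clip(res2_hat, 0.0, None)
```

For DDPM-type sampling, lambda_n^2 = beta_tilde_n = (beta_bar_{n-1} / beta_bar_n) * (1 - alpha_n), which makes gamma_n reduce to the familiar coefficient beta_n / sqrt(alpha_n * beta_bar_n); for DDIM with eta = 0, lambda_n = 0.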

Dependencies

The codebase is based on PyTorch. The dependencies are listed below.

pip install "torch>=1.9.0" torchvision ml-collections ninja tensorboard

Basic usage

The basic usage for training is

python run_train.py --pretrained_path path/to/pretrained_dpm --dataset dataset --workspace path/to/working_directory $train_hparams
  • pretrained_path is the path to a pretrained diffusion probabilistic model (DPM). All pretrained DPMs used in this work are provided here.
  • dataset is the training dataset, one of <cifar10|celeba64|imagenet64|lsun_bedroom>.
  • workspace is where training outputs, e.g., logs and intermediate checkpoints, are saved.
  • train_hparams specifies the other hyperparameters used in training. The train_hparams for all models are listed here.
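The training command can also be assembled programmatically, e.g. to launch several of the models in this README in a loop. This is a minimal sketch around a hypothetical helper, build_train_argv, which is not part of the repository. It checks dataset against the four documented values and uses shlex.split so that a train_hparams string becomes separate arguments; for the flag strings used in this README, that gives the same words as the shell's splitting of the unquoted $train_hparams.

```python
import shlex

DATASETS = {"cifar10", "celeba64", "imagenet64", "lsun_bedroom"}


def build_train_argv(pretrained_path: str, dataset: str, workspace: str,
                     train_hparams: str) -> list[str]:
    """Assemble a run_train.py command line (hypothetical helper, not part of the repo)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {sorted(DATASETS)}")
    return [
        "python", "run_train.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        # Split the hparams string into separate arguments, as the shell does for $train_hparams.
        *shlex.split(train_hparams),
    ]


if __name__ == "__main__":
    # SN-DPM on top of the CIFAR10 (CS) model; train_hparams copied from the models table.
    argv = build_train_argv("path/to/pretrained_dpm", "cifar10", "path/to/working_directory",
                            "--method pred_eps_eps2_pretrained --schedule cosine_1000")
    print(shlex.join(argv))
```

The resulting list can be passed to subprocess.run from the repository root.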

The basic usage for evaluation is

python run_eval.py --pretrained_path path/to/evaluated_model --dataset dataset --workspace path/to/working_directory \
    --phase phase --sample_steps sample_steps --batch_size batch_size --method method $eval_hparams
  • pretrained_path is the path to the model to evaluate. All models evaluated in this work are provided here.
  • dataset is the dataset the model was trained on, one of <cifar10|celeba64|imagenet64|lsun_bedroom>.
  • workspace is where evaluation outputs, e.g., logs, samples and bpd values, are saved.
  • phase selects sampling or likelihood evaluation, one of <sample4test|nll4test>.
  • sample_steps is the number of steps used during inference; the smaller this value, the faster the inference.
  • batch_size is the batch size, e.g., 500.
  • method specifies the type of the model, one of:
    • pred_eps the original DPM (i.e., a noise prediction model) with discrete timesteps
    • pred_eps_eps2_pretrained the SN-DPM with discrete timesteps
    • pred_eps_epsc_pretrained the NPR-DPM with discrete timesteps
    • pred_eps_ct2dt the original DPM (i.e., a noise prediction model) with continuous timesteps (i.e., a score-based SDE)
    • pred_eps_eps2_pretrained_ct2dt the SN-DPM with continuous timesteps
    • pred_eps_epsc_pretrained_ct2dt the NPR-DPM with continuous timesteps
  • eval_hparams specifies other optional hyperparameters used in evaluation.
  • The method and eval_hparams values used for the NPR/SN-DPM and Analytic-DPM results in the paper are listed here.
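When sweeping over many settings, it can help to validate the options before launching run_eval.py. The following is a minimal sketch, not repository code: build_eval_argv, DATASETS, PHASES and METHODS are hypothetical names, and the allowed values are copied from the option list above. METHODS also records which model family and timestep type each method value selects, and eval_hparams is split with shlex.split, which for these flag strings matches the shell's splitting of an unquoted $eval_hparams.

```python
import shlex

DATASETS = {"cifar10", "celeba64", "imagenet64", "lsun_bedroom"}
PHASES = {"sample4test", "nll4test"}
# The six documented method values -> (model family, timestep type).
METHODS = {
    "pred_eps": ("original DPM", "discrete"),
    "pred_eps_eps2_pretrained": ("SN-DPM", "discrete"),
    "pred_eps_epsc_pretrained": ("NPR-DPM", "discrete"),
    "pred_eps_ct2dt": ("original DPM", "continuous"),
    "pred_eps_eps2_pretrained_ct2dt": ("SN-DPM", "continuous"),
    "pred_eps_epsc_pretrained_ct2dt": ("NPR-DPM", "continuous"),
}


def build_eval_argv(pretrained_path: str, dataset: str, workspace: str, phase: str,
                    sample_steps: int, method: str, eval_hparams: str = "",
                    batch_size: int = 500) -> list[str]:
    """Assemble a run_eval.py command line (hypothetical helper, not part of the repo)."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}")
    if phase not in PHASES:
        raise ValueError(f"unknown phase {phase!r}")
    if method not in METHODS:
        raise ValueError(f"unknown method {method!r}")
    if sample_steps < 1:
        raise ValueError("sample_steps must be a positive integer")
    return [
        "python", "run_eval.py",
        "--pretrained_path", pretrained_path,
        "--dataset", dataset,
        "--workspace", workspace,
        "--phase", phase,
        "--sample_steps", str(sample_steps),
        "--batch_size", str(batch_size),
        "--method", method,
        *shlex.split(eval_hparams),
    ]


if __name__ == "__main__":
    # SN-DDPM sampling on CIFAR10 (LS) with 10 steps, eval_hparams copied from the tables below.
    argv = build_eval_argv("path/to/evaluated_model", "cifar10", "path/to/working_directory",
                           "sample4test", 10, "pred_eps_eps2_pretrained",
                           "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2")
    print(shlex.join(argv))
```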

Models and FID statistics

The NPR-DPMs and SN-DPMs trained in this work are listed below. Each of them only trains an additional prediction head in the last layer of a pretrained DPM.

NPR/SN-DPM | Pretrained DPM | train_hparams
CIFAR10 (LS), NPR-DPM | CIFAR10 (LS) | "--method pred_eps_epsc_pretrained"
CIFAR10 (LS), SN-DPM | CIFAR10 (LS) | "--method pred_eps_eps2_pretrained"
CIFAR10 (CS), NPR-DPM | CIFAR10 (CS) | "--method pred_eps_epsc_pretrained --schedule cosine_1000"
CIFAR10 (CS), SN-DPM | CIFAR10 (CS) | "--method pred_eps_eps2_pretrained --schedule cosine_1000"
CIFAR10 (VP SDE), NPR-DPM | CIFAR10 (VP SDE) | "--method pred_eps_epsc_pretrained_ct --sde vpsde"
CIFAR10 (VP SDE), SN-DPM | CIFAR10 (VP SDE) | "--method pred_eps_eps2_pretrained_ct --sde vpsde"
CelebA 64x64, NPR-DPM | CelebA 64x64 | "--method pred_eps_epsc_pretrained"
CelebA 64x64, SN-DPM | CelebA 64x64 | "--method pred_eps_eps2_pretrained"
ImageNet 64x64, NPR-DPM | ImageNet 64x64 | "--method pred_eps_epsc_pretrained --mode simple"
ImageNet 64x64, SN-DPM | ImageNet 64x64 | "--method pred_eps_eps2_pretrained --mode complex"
LSUN Bedroom, NPR-DPM | LSUN Bedroom | "--method pred_eps_epsc_pretrained --mode simple"
LSUN Bedroom, SN-DPM | LSUN Bedroom | "--method pred_eps_eps2_pretrained --mode complex"
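The (LS) and (CS) models differ in their noise schedule (linear vs. cosine), which is why the CS rows pass --schedule cosine_1000 while the LS rows presumably rely on the default. For reference, a minimal sketch of the two standard schedules as cumulative products alpha_bar_n: the DDPM linear schedule (betas from 1e-4 to 0.02) and the Improved DDPM cosine schedule (offset s = 0.008, betas capped at 0.999). Whether this codebase's linear and cosine_1000 options use exactly these constants is an assumption; the VP SDE schedule is not covered here.

```python
import numpy as np


def linear_alpha_bars(n_steps=1000, beta_start=1e-4, beta_end=0.02):
    """DDPM linear schedule: evenly spaced betas, alpha_bar_n = prod_{i<=n} (1 - beta_i)."""
    betas = np.linspace(beta_start, beta_end, n_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def cosine_alpha_bars(n_steps=1000, s=0.008, max_beta=0.999):
    """Improved-DDPM cosine schedule: alpha_bar(t) = f(t) / f(0),
    f(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2)."""
    t = np.arange(n_steps + 1, dtype=np.float64) / n_steps
    f = np.cos((t + s) / (1.0 + s) * np.pi / 2.0) ** 2
    alpha_bar = f / f[0]
    # Convert to per-step betas and cap them to avoid the singularity at t = T.
    betas = np.minimum(1.0 - alpha_bar[1:] / alpha_bar[:-1], max_beta)
    return np.cumprod(1.0 - betas)
```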

The pretrained DPMs below are collected from prior works and converted to a format that this codebase can use directly.

Pretrained DPM | Expected mean squared norm (ms_eps), used in Analytic-DPM | From
CIFAR10 (LS) | Link | Analytic-DPM
CIFAR10 (CS) | Link | Analytic-DPM
CIFAR10 (VP SDE) | Link | score-sde
CelebA 64x64 | Link | DDIM
ImageNet 64x64 | Link | Improved DDPM
LSUN Bedroom | Link | pytorch_diffusion
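The ms_eps files hold the precomputed statistic Analytic-DPM needs. Our reading of "expected mean squared norm" is Gamma_n = E ||eps_theta(x_n, n)||^2 / d for each timestep n, with x_n drawn from q(x_n) and d the data dimension. The sketch below estimates it by Monte Carlo, sampling x_n = sqrt(alpha_bar_n) x_0 + sqrt(1 - alpha_bar_n) eps from training images. eps_model is a hypothetical stand-in for the pretrained noise prediction network, and the sample count, timestep set and file format used for the provided files are not specified here.

```python
import numpy as np


def estimate_ms_eps(eps_model, data, alpha_bars, timesteps, n_samples=1000, seed=0):
    """Monte Carlo estimate of E ||eps_model(x_n, n)||^2 / d for each n in timesteps.

    eps_model:  callable (x_n, n) -> noise prediction shaped like x_n (hypothetical)
    data:       clean samples x_0 with shape (N, ...)
    alpha_bars: alpha_bar_n stored at index n
    """
    rng = np.random.default_rng(seed)
    d = int(np.prod(data.shape[1:]))  # data dimension
    out = {}
    for n in timesteps:
        x0 = data[rng.integers(0, len(data), size=n_samples)]
        noise = rng.standard_normal(x0.shape)
        # Sample from q(x_n | x_0) = N(sqrt(alpha_bar_n) x_0, (1 - alpha_bar_n) I).
        xn = np.sqrt(alpha_bars[n]) * x0 + np.sqrt(1.0 - alpha_bars[n]) * noise
        eps = eps_model(xn, n)
        sq_norms = np.sum(eps.reshape(n_samples, -1) ** 2, axis=1)
        out[n] = float(np.mean(sq_norms) / d)
    return out
```

With all-zero data, the optimal predictor recovers the noise exactly and the estimate is close to 1. For the exact conditional-mean predictor, Gamma_n <= 1 in general (by Jensen's inequality), which keeps the factor 1 - Gamma_n in Analytic-DPM's variance non-negative.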

This link provides precalculated FID statistics on CIFAR10, CelebA 64x64, ImageNet 64x64 and LSUN Bedroom. They are computed following Appendix F.2 in Analytic-DPM.
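FID fits a Gaussian (mean mu, covariance Sigma) to Inception features of real and generated images and reports ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). A minimal NumPy-only sketch of that formula, assuming the precalculated statistics provide such a mean vector and covariance matrix (the key names inside the linked files are not stated here). It uses Tr((Sigma_1 Sigma_2)^(1/2)) = Tr((Sigma_1^(1/2) Sigma_2 Sigma_1^(1/2))^(1/2)), which holds for symmetric positive semi-definite matrices and avoids a general matrix square root.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """FID between two Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=np.float64), np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    s1_half = _sqrtm_psd(sigma1)
    inner = s1_half @ sigma2 @ s1_half
    inner = (inner + inner.T) / 2.0  # symmetrize against round-off
    # Trace of the square root = sum of square roots of the (non-negative) eigenvalues.
    tr_covmean = np.sum(np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)))
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

Tools such as pytorch-fid compute the same quantity with scipy.linalg.sqrtm, so small numerical differences between the two routes are expected.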

Evaluation Hyperparameters for NPR/SN-DPM and Analytic-DPM

Note: Analytic-DPM requires the precomputed expected mean squared norm of the noise prediction model (ms_eps), which is provided here. Specify its path with --ms_eps_path.
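Given Gamma_n from ms_eps, Analytic-DPM's optimal scalar reverse variance is sigma_n^2 = lambda_n^2 + gamma_n^2 (1 - Gamma_n), with gamma_n = sqrt(beta_bar_n / alpha_n) - sqrt(beta_bar_{n-1} - lambda_n^2). The sketch below evaluates it along a trajectory of sample_steps timesteps, where alpha_n becomes the ratio of alpha_bar between consecutive kept timesteps. Assumptions, also labeled in the code: an evenly spaced trajectory (the codebase may choose timesteps differently), ms_eps stored as one value per training timestep, lambda_n^2 = eta^2 * beta_tilde_n (eta = 1 for DDPM-type sampling, eta = 0 matching --forward_type ddim --eta 0), and clipping 1 - Gamma_n to [0, 1], which is our own safeguard. It does not model --clip_sigma_idx or --clip_pixel.

```python
import numpy as np


def even_trajectory(n_steps: int, sample_steps: int) -> np.ndarray:
    """sample_steps evenly spaced timestep indices in [0, n_steps - 1] (assumed trajectory)."""
    if not 1 <= sample_steps <= n_steps:
        raise ValueError("need 1 <= sample_steps <= n_steps")
    return np.linspace(0, n_steps - 1, sample_steps).astype(int)


def analytic_variances(alpha_bars, ms_eps, sample_steps, eta=1.0):
    """Analytic-DPM reverse variances along an even trajectory.

    alpha_bars: alpha_bar at each of the N training timesteps (index 0 = least noisy).
    ms_eps:     Gamma_n = E ||eps_theta(x_n, n)||^2 / d per timestep (same indexing, assumed).
    eta:        1.0 -> lambda^2 = beta_tilde (DDPM-type); 0.0 -> lambda = 0 (DDIM).
    Returns (taus, variances); variances[k] is used when stepping down from taus[k].
    """
    alpha_bars = np.asarray(alpha_bars, dtype=np.float64)
    ms_eps = np.asarray(ms_eps, dtype=np.float64)
    taus = even_trajectory(len(alpha_bars), sample_steps)
    ab_n = alpha_bars[taus]
    ab_prev = np.concatenate([[1.0], ab_n[:-1]])  # alpha_bar = 1 before the first kept step
    alpha_n = ab_n / ab_prev                        # effective alpha across skipped steps
    bb_n, bb_prev = 1.0 - ab_n, 1.0 - ab_prev       # beta_bar = 1 - alpha_bar
    beta_tilde = bb_prev / bb_n * (1.0 - alpha_n)
    lam2 = eta ** 2 * beta_tilde
    gamma = np.sqrt(bb_n / alpha_n) - np.sqrt(np.clip(bb_prev - lam2, 0.0, None))
    # Our safeguard: keep 1 - Gamma_n in [0, 1] despite Monte Carlo error in ms_eps.
    one_minus_ms = np.clip(1.0 - ms_eps[taus], 0.0, 1.0)
    return taus, lam2 + gamma ** 2 * one_minus_ms
```

With ms_eps set to zero and eta = 1 over the full 1000-step trajectory, the variance reduces to beta_n / alpha_n; with Gamma_n = 1 it collapses to lambda_n^2.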

  • Sampling experiments on CIFAR10 (LS) or CelebA 64x64, Table 1 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2"
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --ms_eps_path ms_eps_path"
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0"
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0"
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path"
  • Sampling experiments on CIFAR10 (CS), Table 1 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000"
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --schedule cosine_1000 --ms_eps_path ms_eps_path"
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000"
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000"
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule cosine_1000 --ms_eps_path ms_eps_path"
  • Sampling experiments on CIFAR10 (VP SDE), Table 1 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000"
SN-DDPM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000"
Analytic-DDPM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 2 --schedule vpsde_1000 --ms_eps_path ms_eps_path"
NPR-DDIM | pred_eps_epsc_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000"
SN-DDIM | pred_eps_eps2_pretrained_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000"
Analytic-DDIM | pred_eps_ct2dt | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --schedule vpsde_1000 --ms_eps_path ms_eps_path"
  • Sampling experiments on ImageNet 64x64, Table 1 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode simple"
SN-DDPM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --mode complex"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --ms_eps_path ms_eps_path"
NPR-DDIM | pred_eps_epsc_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode simple"
SN-DDIM | pred_eps_eps2_pretrained | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --mode complex"
Analytic-DDIM | pred_eps | "--rev_var_type optimal --clip_sigma_idx 1 --clip_pixel 1 --forward_type ddim --eta 0 --ms_eps_path ms_eps_path"
  • Likelihood experiments on CIFAR10 (LS) or CelebA 64x64, Table 3 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path"
  • Likelihood experiments on CIFAR10 (CS), Table 3 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --schedule cosine_1000"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --schedule cosine_1000 --ms_eps_path ms_eps_path"
  • Likelihood experiments on ImageNet 64x64, Table 3 in the paper:
Model | method | eval_hparams
NPR-DDPM | pred_eps_epsc_pretrained | "--rev_var_type optimal --mode simple"
Analytic-DDPM | pred_eps | "--rev_var_type optimal --ms_eps_path ms_eps_path"

This implementation is based on / inspired by
