
Self-Attention Diffusion Guidance (ICCV'23)

This is the implementation of the paper Improving Sample Quality of Diffusion Models Using Self-Attention Guidance by Hong et al. For insights from our exploration of the self-attention maps of diffusion models, and for detailed explanations of the method, please see our Paper and Project Page.

This repository is based on openai/guided-diffusion, and we modified feature extraction code from yandex-research/ddpm-segmentation to get the self-attention maps. The major implementation of our method is in ./guided_diffusion/gaussian_diffusion.py and ./guided_diffusion/unet.py.
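
For intuition, below is a minimal, illustrative Python sketch of the guidance step. It is not the repository's exact code: `model`, `mask`, and the direct blurring of `x_t` are simplifications (the actual implementation degrades the model's predicted clean sample rather than `x_t` itself; see ./guided_diffusion/gaussian_diffusion.py for the authoritative logic).

```python
import torch
from torchvision.transforms.functional import gaussian_blur

def sag_eps(model, x_t, t, mask, guide_scale=1.1, blur_sigma=3.0):
    """Self-attention-guided noise prediction (simplified sketch).

    model(x, t) is assumed to return the predicted noise; mask is a
    (B, 1, H, W) binary map of salient regions derived from the model's
    self-attention (extraction lives in ./guided_diffusion/unet.py).
    """
    eps = model(x_t, t)

    # Degrade only the attended (salient) regions with a Gaussian blur.
    kernel_size = int(4 * blur_sigma + 1) | 1      # force an odd kernel size
    x_blur = gaussian_blur(x_t, kernel_size, blur_sigma)
    x_bar = mask * x_blur + (1.0 - mask) * x_t

    # Steer the prediction away from its self-degraded counterpart.
    eps_bar = model(x_bar, t)
    # guide_scale == 1.0 gives eps_bar + (eps - eps_bar) = eps, i.e. no guidance.
    return eps_bar + guide_scale * (eps - eps_bar)
```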

All you need to do is set up the environment, download pretrained models, and sample from them using our implementation. Neither further training nor a dataset is needed to apply self-attention guidance!

Updates

2023-08-14: This repository now supports DDIM sampling with SAG.

2023-02-19: The Gradio demo 🤗 of SAG for Stable Diffusion is now available.

2023-02-16: The Stable Diffusion pipeline of SAG is now available at huggingface/diffusers 🤗🧨 (see the usage sketch below).

2023-02-01: The demo for Stable Diffusion is now available in Colab.
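
The diffusers integration is exposed as `StableDiffusionSAGPipeline`. A minimal usage sketch (the model ID and the `sag_scale` value below are illustrative choices, not prescriptions from this repository):

```python
import torch
from diffusers import StableDiffusionSAGPipeline

# Load the SAG-enabled Stable Diffusion pipeline from diffusers.
pipe = StableDiffusionSAGPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# sag_scale controls the strength of self-attention guidance (0.0 disables it).
image = pipe("a photo of an astronaut riding a horse", sag_scale=0.75).images[0]
image.save("astronaut.png")
```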

Environment

  • Python 3.8, PyTorch 1.11.0
  • 8 x NVIDIA RTX 3090 (set backend="gloo" in ./guided_diffusion/dist_util.py if P2P access is not available)

```sh
git clone https://github.com/KU-CVLAB/Self-Attention-Guidance
conda create -n sag python=3.8 anaconda
conda activate sag
conda install mpi4py
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install blobfile
```
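
A quick sanity check that the pinned CUDA build is active (a small suggested helper, not part of the repository):

```python
import torch

print(torch.__version__)          # expected: 1.11.0+cu113
print(torch.cuda.is_available())  # expected: True
print(torch.cuda.device_count())  # e.g. 8 for the setup above
```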

Downloading Pretrained Diffusion Models (and Classifiers for CG)

Pretrained weights for ImageNet and LSUN can be downloaded from the openai/guided-diffusion repository. Download them and place them in the ./models/ directory.

Sampling from Pretrained Diffusion Models

You can sample from pretrained diffusion models with self-attention guidance by adjusting SAG_FLAGS in the commands below. Note that sampling with --guide_scale 1.0 means sampling without self-attention guidance. The commands assume $NUM_GPUS is set to the number of available GPUs. Below are five examples; a sketch of what the SAG-specific flags control follows them.

  • ImageNet 128x128 model (--classifier_guidance False deactivates classifier guidance):

```sh
SAMPLE_FLAGS="--batch_size 64 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --diffusion_steps 1000 --image_size 128 --learn_sigma True --noise_schedule linear --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.1 --guide_start 250 --sel_attn_block output --sel_attn_depth 8 --blur_sigma 3 --classifier_guidance True"
mpiexec -n $NUM_GPUS python classifier_sample.py $SAG_FLAGS $MODEL_FLAGS --classifier_scale 0.5 --classifier_path models/128x128_classifier.pt --model_path models/128x128_diffusion.pt $SAMPLE_FLAGS
```
  • ImageNet 256x256 model (--class_cond True for conditional models):

```sh
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.5 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python classifier_sample.py $SAG_FLAGS $MODEL_FLAGS --classifier_scale 0.0 --classifier_path models/256x256_classifier.pt --model_path models/256x256_diffusion_uncond.pt $SAMPLE_FLAGS
```
  • LSUN Cat model (respaced to 250 steps):

```sh
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.05 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python image_sample.py $SAG_FLAGS $MODEL_FLAGS --model_path models/lsun_cat.pt $SAMPLE_FLAGS
```
  • LSUN Horse model (respaced to 250 steps):

```sh
SAMPLE_FLAGS="--batch_size 16 --num_samples 10000 --timestep_respacing 250"
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond False --diffusion_steps 1000 --dropout 0.1 --image_size 256 --learn_sigma True --noise_schedule linear --num_channels 256 --num_head_channels 64 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
SAG_FLAGS="--guide_scale 1.01 --guide_start 250 --sel_attn_block output --sel_attn_depth 2 --blur_sigma 9 --classifier_guidance False"
mpiexec -n $NUM_GPUS python image_sample.py $SAG_FLAGS $MODEL_FLAGS --model_path models/lsun_horse.pt $SAMPLE_FLAGS
```
  • ImageNet 128x128 model (DDIM 25 steps):

```sh
MODEL_FLAGS="--attention_resolutions 32,16,8 --class_cond True --image_size 128 --learn_sigma True --num_channels 256 --num_heads 4 --num_res_blocks 2 --resblock_updown True --use_fp16 True --use_scale_shift_norm True"
CLASSIFIER_FLAGS="--image_size 128 --classifier_attention_resolutions 32,16,8 --classifier_depth 2 --classifier_width 128 --classifier_pool attention --classifier_resblock_updown True --classifier_use_scale_shift_norm True --classifier_scale 1.0 --classifier_use_fp16 True"
SAMPLE_FLAGS="--batch_size 8 --num_samples 8 --timestep_respacing ddim25 --use_ddim True"
SAG_FLAGS="--guide_scale 1.1 --guide_start 25 --sel_attn_block output --sel_attn_depth 8 --blur_sigma 3 --classifier_guidance True"
mpiexec -n $NUM_GPUS python classifier_sample.py \
    --model_path models/128x128_diffusion.pt \
    --classifier_path models/128x128_classifier.pt \
    $MODEL_FLAGS $CLASSIFIER_FLAGS $SAMPLE_FLAGS $SAG_FLAGS
```
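
For reference, here is a hedged sketch of how the `--sel_attn_block`/`--sel_attn_depth` selection conceptually turns a self-attention map into the spatial mask that SAG blurs. The function name and tensor shapes are illustrative; the authoritative logic lives in ./guided_diffusion/unet.py and ./guided_diffusion/gaussian_diffusion.py:

```python
import torch
import torch.nn.functional as F

def attention_to_mask(attn, x_shape):
    """Turn a self-attention map into a binary spatial mask (sketch).

    attn:    (B, heads, HW, HW) attention weights from the block selected
             via --sel_attn_block / --sel_attn_depth.
    x_shape: (B, C, H, W) shape of the sample being denoised.
    """
    B, _, H, W = x_shape
    # Average over heads and query positions to get per-position saliency.
    saliency = attn.mean(dim=1).mean(dim=1)            # (B, HW)
    h = w = int(saliency.shape[-1] ** 0.5)             # attention grid is low-res
    saliency = saliency.reshape(B, 1, h, w)
    # Upsample to the sample resolution and threshold at the mean.
    saliency = F.interpolate(saliency, size=(H, W), mode="nearest")
    return (saliency > saliency.mean()).float()        # (B, 1, H, W) binary mask
```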

Results

Compatibility of self-attention guidance (SAG) and classifier guidance (CG) on the ImageNet 128×128 model:

| SAG | CG | FID | sFID | Precision | Recall |
|:---:|:--:|----:|-----:|----------:|-------:|
|     |    | 5.91 | 5.09 | 0.70 | 0.65 |
|     | ✔  | 2.97 | 5.09 | 0.78 | 0.59 |
| ✔   |    | 5.11 | 4.09 | 0.72 | 0.65 |
| ✔   | ✔  | 2.58 | 4.35 | 0.79 | 0.59 |

Results on pretrained models:

| Model | # of steps | Self-attention guidance scale | FID | sFID | IS | Precision | Recall |
|-------|-----------:|------------------------------:|----:|-----:|---:|----------:|-------:|
| ImageNet 256×256 (Uncond.) | 250 | 0.0 (baseline) | 26.21 | 6.35 | 39.70 | 0.61 | 0.63 |
| | | 0.5 | 20.31 | 5.09 | 45.30 | 0.66 | 0.61 |
| | | 0.8 | 20.08 | 5.77 | 45.56 | 0.68 | 0.59 |
| ImageNet 256×256 (Cond.) | 250 | 0.0 (baseline) | 10.94 | 6.02 | 100.98 | 0.69 | 0.63 |
| | | 0.2 | 9.41 | 5.28 | 104.79 | 0.70 | 0.62 |
| LSUN Cat 256×256 | 250 | 0.0 (baseline) | 7.03 | 8.24 | - | 0.60 | 0.53 |
| | | 0.05 | 6.87 | 8.21 | - | 0.60 | 0.50 |
| LSUN Horse 256×256 | 250 | 0.0 (baseline) | 3.45 | 7.55 | - | 0.68 | 0.56 |
| | | 0.01 | 3.43 | 7.51 | - | 0.68 | 0.55 |

Cite as

```bibtex
@article{hong2022improving,
  title={Improving Sample Quality of Diffusion Models Using Self-Attention Guidance},
  author={Hong, Susung and Lee, Gyuseong and Jang, Wooseok and Kim, Seungryong},
  journal={arXiv preprint arXiv:2210.00939},
  year={2022}
}
```
