Retrieval-Augmented Diffusion Models
Andreas Blattmann*,
Robin Rombach*,
Kaan Oktay,
Jonas Müller,
Björn Ommer
* equal contribution
- Code release
A suitable conda environment named rdm can be created and activated with:
conda env create -f environment.yaml
conda activate rdm
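To quickly verify the environment before downloading any models, a minimal sanity check might look like the following (it only assumes that environment.yaml installs PyTorch, as in the related latent diffusion repository; this is an assumption, not a required step):

```python
# Minimal sanity check for the freshly created `rdm` environment.
# Assumes PyTorch is installed via environment.yaml (as in related latent-diffusion repos).
import torch

print("torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("device:", torch.cuda.get_device_name(0))
```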
A general list of all available checkpoints is available via our model zoo. If you use any of these models in your work, we are always happy to receive a citation.
Example inference notebooks can be found under scripts/demo_rdm.ipynb and scripts/demo_rarm.ipynb.
To run an RDM/RARM conditioned on a text prompt and, additionally, on images retrieved for this prompt, you will also need to download the corresponding retrieval database. We provide two distinct databases extracted from the OpenImages and ImageNet datasets. Interchanging the databases results in different capabilities of the model, as visualized below, although the learned weights are the same in both cases.
Download the retrieval databases, which contain the retrieval datasets (OpenImages, ~18 GB, and ImageNet, ~1 GB) compressed into CLIP image embeddings:
bash scripts/download_databases.sh
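The downloaded databases are NumPy archives of CLIP image embeddings (for example, the ImageNet database linked below is a single 1281200x512 array). A minimal sketch for inspecting such an archive, assuming the embeddings sit under a standard .npz key (the actual key name in the released files may differ):

```python
# Inspect a downloaded retrieval database (.npz archive of CLIP image embeddings).
# The key/file name below is illustrative; print f.files to see the actual keys.
import numpy as np

path = "1281200x512-part_1.npz"  # example: the ImageNet database file
with np.load(path) as f:
    print("available arrays:", f.files)
    emb = f[f.files[0]]          # take the first stored array
    print("embedding matrix:", emb.shape, emb.dtype)  # expected: (1281200, 512)
```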
Since CLIP offers a shared image/text feature space, and RDMs learn to cover a neighborhood of a given example during training, we can directly take a CLIP text embedding of a given prompt and condition on it. Run this mode via
python scripts/rdm_sample.py --caption "a happy bear reading a newspaper, oil on canvas" --gpu 0
or sample the model unconditionally
python scripts/rdm_sample.py --gpu 0
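To make the mechanism behind the text-conditional mode concrete: because CLIP embeds images and text into the same space, the prompt is encoded once and the resulting embedding is used for conditioning just like a retrieved image embedding. A rough sketch of that encoding step, using OpenAI's clip package; the exact CLIP variant used by the released models (ViT-B/32 here) is an assumption:

```python
# Encode a text prompt into the shared CLIP image/text feature space.
# The sampling scripts above do this internally; ViT-B/32 is an assumed backbone.
import clip
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device=device)

prompt = "a happy bear reading a newspaper, oil on canvas"
with torch.no_grad():
    tokens = clip.tokenize([prompt]).to(device)
    text_emb = model.encode_text(tokens)                        # shape: (1, 512)
    text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)   # unit-normalize for retrieval

print(text_emb.shape)
```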
RARMs can be used in a similar manner to RDMs with
python scripts/rarm_sample.py --caption "a happy bear reading a newspaper, oil on canvas" --gpu 0
or sampled unconditionally with
python scripts/rarm_sample.py --gpu 0
The code will try to download (through Academic Torrents) and prepare ImageNet the first time it is used. However, since ImageNet is quite large, this requires a lot of disk space and time. If you already have ImageNet on your disk, you can speed things up by putting the data into ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/ (which defaults to ~/.cache/autoencoders/data/ILSVRC2012_{split}/data/), where {split} is one of train/validation. It should have the following structure:
${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/
├── n01440764
│ ├── n01440764_10026.JPEG
│ ├── n01440764_10027.JPEG
│ ├── ...
├── n01443537
│ ├── n01443537_10007.JPEG
│ ├── n01443537_10014.JPEG
│ ├── ...
├── ...
If you haven't extracted the data, you can also place ILSVRC2012_img_train.tar/ILSVRC2012_img_val.tar (or symlinks to them) into ${XDG_CACHE}/autoencoders/data/ILSVRC2012_train/ and ${XDG_CACHE}/autoencoders/data/ILSVRC2012_validation/, which will then be extracted into the above structure without downloading it again. Note that this will only happen if neither a folder ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/data/ nor a file ${XDG_CACHE}/autoencoders/data/ILSVRC2012_{split}/.ready exists. Remove them if you want to force running the dataset preparation again.
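To verify a local copy before launching training, a small check of the expected layout (assuming ${XDG_CACHE} resolves to $XDG_CACHE_HOME or ~/.cache, as described above) could look like:

```python
# Check whether the local ImageNet splits are laid out as expected.
import os

xdg_cache = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
for split in ("train", "validation"):
    root = os.path.join(xdg_cache, "autoencoders", "data", f"ILSVRC2012_{split}")
    data_dir = os.path.join(root, "data")
    ready_flag = os.path.join(root, ".ready")
    if os.path.isfile(ready_flag) or os.path.isdir(data_dir):
        n_synsets = len(os.listdir(data_dir)) if os.path.isdir(data_dir) else 0
        print(f"{split}: found ({n_synsets} synset folders)")
    else:
        print(f"{split}: not prepared yet, will be downloaded/extracted on first use")
```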
Create a symlink data/ffhq pointing to the images1024x1024 folder obtained from the FFHQ repository.
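For example, assuming the FFHQ images live under /path/to/ffhq/images1024x1024 (the source path is a placeholder), the symlink can be created like this:

```python
# Create the data/ffhq symlink; the source path below is a placeholder.
import os

os.makedirs("data", exist_ok=True)
os.symlink("/path/to/ffhq/images1024x1024", "data/ffhq")
```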
- Download a database of your choice from above.
- Precompute nearest neighbors for a given query dataset (a conceptual sketch of the neighbor search is shown below):
  - Create a new config for the QueryDataset under configs/query_datasets (see the template for creation).
  - Start nearest-neighbor extraction with python scripts/search_neighbors.py -rc <path to retrieval_config> -qc <path to query config> -s <query dataset split> -bs <batch size> -w <n_workers>, e.g. python scripts/search_neighbors.py --rconfig configs/dataset_builder/openimages.yaml --qc configs/query_datasets/imagenet.yaml -s validation -n
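For orientation, the core of the neighbor search is a cosine-similarity top-k lookup between query embeddings and database embeddings; a minimal NumPy sketch of that idea (not the repository's implementation, which additionally handles batching, sharding and config parsing):

```python
# Conceptual sketch of nearest-neighbor search over CLIP embeddings (cosine similarity).
# scripts/search_neighbors.py does the real work; this only illustrates the idea.
import numpy as np

def top_k_neighbors(queries: np.ndarray, database: np.ndarray, k: int = 4) -> np.ndarray:
    """Return indices of the k most similar database entries for each query."""
    q = queries / np.linalg.norm(queries, axis=-1, keepdims=True)
    d = database / np.linalg.norm(database, axis=-1, keepdims=True)
    sims = q @ d.T                              # (num_queries, num_database)
    return np.argsort(-sims, axis=-1)[:, :k]    # indices of the top-k neighbors

# toy example with random 512-d "CLIP" embeddings
rng = np.random.default_rng(0)
idx = top_k_neighbors(rng.normal(size=(8, 512)), rng.normal(size=(1000, 512)), k=4)
print(idx.shape)  # (8, 4)
```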
Logs and checkpoints for trained models are saved to logs/<START_DATE_AND_TIME>_<config_spec>.
For training autoencoders see the latent diffusion repository.
In configs/rdm/ and configs/rarm/ we provide configs for training RDMs on the FFHQ and ImageNet datasets and RARMs on ImageNet subsets.
Training can be started after having downloaded the appropriate files by running
CUDA_VISIBLE_DEVICES=<GPU_ID> python main.py --base configs/<{rdm,rarm}>/<config_spec>.yaml -t --gpus 0, --scale_lr false
Source Dataset | Size [GB] | Link |
---|---|---|
OpenImages | 18 | https://ommer-lab.com/files/rdm/database/OpenImages/ |
ImageNet | 1.2 | https://ommer-lab.com/files/rdm/database/ImageNet/1281200x512-part_1.npz |
For pretrained autoencoders see the latent diffusion repository.
Train-Dataset | Train-Database | FID (Validation) | Precision | Recall | Link | Filesize [GB] |
---|---|---|---|---|---|---|
ImageNet | ImageNet | 5.32 | 0.74 | 0.51 | https://ommer-lab.com/files/rdm/models/rdm/imagenet_in-db/model.ckpt | 6.2 |
ImageNet | OpenImages | 12.28 | 0.69 | 0.55 | https://ommer-lab.com/files/rdm/models/rdm/imagenet/model.ckpt | 6.2 |
FFHQ | OpenImages | 1.92* | 0.93* | 0.35* | https://ommer-lab.com/files/rdm/models/rdm/ffhq/model.ckpt | 6.2 |
*: Evaluated using CLIP as feature extractor instead of Inception
Train-Dataset | Train-Database | FID (Validation) | Precision | Recall | Link | Filesize [GB] |
---|---|---|---|---|---|---|
ImageNet-Dogs | OpenImages | 45.27 | 0.64 | 0.55 | https://ommer-lab.com/files/rdm/models/rarm/imagenet/dogs/model.ckpt | 2.9 |
ImageNet-Mammals | OpenImages | 49.92 | 0.56 | 0.58 | https://ommer-lab.com/files/rdm/models/rarm/imagenet/mammals/model.ckpt | 2.9 |
ImageNet-Animals | OpenImages | 49.03 | 0.55 | 0.58 | https://ommer-lab.com/files/rdm/models/rarm/imagenet/animals/model.ckpt | 2.9 |
All models listed above can be downloaded and extracted jointly via
bash scripts/download_models.sh
The models can then be found in models/{rdm,rarm}/<model_spec>.
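The checkpoints are regular PyTorch checkpoint files; a quick way to inspect one after downloading, assuming it follows the usual PyTorch Lightning layout with a state_dict entry (the concrete path below is an illustrative <model_spec>):

```python
# Inspect a downloaded checkpoint; the 'state_dict' key is the usual Lightning convention.
# On newer PyTorch versions you may need torch.load(..., weights_only=False).
import torch

ckpt = torch.load("models/rdm/imagenet/model.ckpt", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)   # fall back to a raw state dict
print("number of parameter tensors:", len(state_dict))
print("example keys:", list(state_dict)[:5])
```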
- Our codebase for the diffusion models builds heavily on OpenAI's ADM codebase and https://github.com/lucidrains/denoising-diffusion-pytorch. Thanks for open-sourcing!
@inproceedings{blattmann2022retrieval,
title = {Retrieval-Augmented Diffusion Models},
author = {Blattmann, Andreas and Rombach, Robin and Oktay, Kaan and M{\"u}ller, Jonas and Ommer, Bj{\"o}rn},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022},
doi = {10.48550/ARXIV.2204.11824},
url = {https://arxiv.org/abs/2204.11824},
}