This repository contains the implementation of TreeDiffusion from the paper *Hierarchical Clustering for Conditional Diffusion in Image Generation* by Jorge da Silva Goncalves, Laura Manduchi, Moritz Vandenhirtz, and Julia E. Vogt.
TreeDiffusion is a deep generative model that incorporates hierarchical clustering into the framework of diffusion models (Sohl-Dickstein et al., 2015; Ho et al., 2020; Song et al., 2021). It enables cluster-guided diffusion in unsupervised settings, as opposed to the classifier-guided diffusion for labeled data introduced by Dhariwal and Nichol (2021). In our framework, the TreeVAE of Manduchi et al. (2023) serves as the clustering model, encoding hierarchical clusters within its latent tree structure, where the leaves represent the clusters. A second-stage diffusion model, conditioned on the TreeVAE leaves, uses these leaf representations to generate improved cluster-conditional samples. This is achieved with an adapted version of the DiffuseVAE framework by Pandey et al. (2022). The result is a model that not only improves image quality but also ensures that the generated samples are representative of their respective clusters, addressing the limitations of previous VAE-based methods and advancing the state of clustering-based generative modeling.
The following figure illustrates the architecture and workflow of TreeDiffusion.
conda env create --name envname --file=treevae.yml
conda activate envname
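To sanity-check the installation, a quick hedged check (assuming PyTorch is among the dependencies in `treevae.yml`, as is typical for diffusion-model code):

```bash
# Verify that PyTorch (assumed dependency) imports and report CUDA availability.
python -c "import torch; print(torch.__version__, 'CUDA available:', torch.cuda.is_available())"
```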
Currently, the code supports the following datasets:
- MNIST (`"mnist"`)
- FashionMNIST (`"fmnist"`)
- CIFAR-10 (`"cifar10"`)
- CelebA (`"celeba"`)
- CUBICC (`"cubicc"`)
To train and evaluate the TreeVAE model, you can use the `main.py` script. Follow the steps below to configure and run your training session.
The recommended approach is to modify the appropriate `.yml` file in the `configs` folder to set up your configurations. Once you've updated the configuration file, run the following command for the desired dataset:
python main.py --config_name "cifar10"
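For example, to train a TreeVAE on every supported dataset in sequence, here is a minimal shell sketch using only the documented `--config_name` flag and the dataset names listed above:

```bash
# Train a TreeVAE for each supported dataset, one after another.
for dataset in "mnist" "fmnist" "cifar10" "celeba" "cubicc"; do
    python main.py --config_name "$dataset"
done
```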
Given a trained and saved TreeVAE model on a given dataset, you can use the following command to generate 10,000 reconstructions of the test set (`mode = "vae_recons"`) or 10,000 newly generated images (`mode = "vae_samples"`) for each leaf in the tree. In the command below, `"/20240307-195731_9e95e"` denotes the folder in the `"models/experiments/{dataset}"` directory for the trained TreeVAE instance; this folder name will differ for each trained model.
python vae_generations.py --config_name "cifar10" --seed 1 --mode "vae_recons" --model_name "/20240307-195731_9e95e"
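To produce both reconstructions and newly generated images for the same trained model, a short sketch looping over the two documented modes (the `--model_name` value is the placeholder folder from above):

```bash
# Generate test-set reconstructions and fresh samples from one trained TreeVAE.
for mode in "vae_recons" "vae_samples"; do
    python vae_generations.py --config_name "cifar10" --seed 1 --mode "$mode" --model_name "/20240307-195731_9e95e"
done
```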
Given a trained and saved TreeVAE model, you can train the conditional second-stage DDPM of the TreeDiffusion model using the `train_ddpm.py` script.
The recommended approach is to modify the appropriate `.yml` file in the `configs` folder to set up your configurations. In particular, make sure to update the paths, such as the directory of the pre-trained TreeVAE model on which the DDPM is conditioned (`vae_chkpt_path`) and the results directory (`results_dir`). Once you've updated the configuration file, run the following command for the desired dataset:
python train_ddpm.py --config_name "cifar10"
To retrieve reconstructions or samples from the diffusion model, further adjust the appropriate `.yml` file in the `configs` folder with the corresponding path to the trained DDPM model (`chkpt_path`). You can use the following command to generate 10,000 reconstructions of the test set for the most probable leaf (`eval_mode = "recons"`) or for all leaves (`eval_mode = "recons_all_leaves"`). Furthermore, you can create 10,000 newly generated images for the most probable leaf (`eval_mode = "sample"`) or for all leaves (`eval_mode = "sample_all_leaves"`).
python test_ddpm.py --config_name $dataset --seed $seed --eval_mode "sample"
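Here, `$dataset` and `$seed` are shell variables. A minimal sketch that runs all four documented evaluation modes (the variable values are illustrative):

```bash
dataset="cifar10"   # any of the supported dataset names
seed=1              # illustrative seed

# Run every documented evaluation mode for the trained DDPM.
for eval_mode in "recons" "recons_all_leaves" "sample" "sample_all_leaves"; do
    python test_ddpm.py --config_name "$dataset" --seed "$seed" --eval_mode "$eval_mode"
done
```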
Below, we present some key results achieved using the models.
We compare the reconstruction quality of test set images between the TreeVAE and TreeDiffusion models. The TreeDiffusion model generates images of higher quality and with a distribution closer to the original data distribution.
The TreeDiffusion model can generate leaf-specific images, where each leaf represents a cluster. The image below showcases randomly generated images from a TreeDiffusion model trained on CIFAR-10.
To evaluate the quality of the generated images, we trained a classifier on the original dataset and used it to classify the newly generated images from our TreeDiffusion, analyzing each cluster separately. Ideally, most images from a cluster should be classified into one or very few classes from the original dataset, indicating "pure" or "unambiguous" generations. The normalized histograms below the leaf-specific generated images show the distribution of predicted classes for these new images.
For the classifier, we utilize a ResNet-50 model (He et al., 2016) trained on each dataset. The pre-trained classifiers are included in this repo under the `classifier_pretraining` directory. If you want to retrain the models yourself, you can run the following command:
python classifier_pretraining/clf_training.py --data_name "cifar10"
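To retrain the classifiers for all supported datasets, a minimal sketch using the documented `--data_name` flag:

```bash
# Retrain the evaluation classifier for each dataset.
for dataset in "mnist" "fmnist" "cifar10" "celeba" "cubicc"; do
    python classifier_pretraining/clf_training.py --data_name "$dataset"
done
```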
The following figure compares TreeVAE with TreeDiffusion. It shows the image generations from each leaf of the TreeVAE and the cluster-conditional TreeDiffusion, both trained on CUBICC. Each row displays the generated images from all leaves of the specified model, starting with the same sample from the root. The corresponding leaf probabilities are shown at the top of the image and are by design the same for both models. The results show that the cluster-conditional TreeDiffusion model produces higher-quality and more diverse images, better adapting to each cluster.