Simple Guidance Mechanisms for Discrete Diffusion Models

[Graphical abstract]

This repository contains code for reproducing the experiments in the paper Simple Guidance Mechanisms for Discrete Diffusion Models.

We also share trained models on HuggingFace 🤗 and support integration with these models; see the "Using HuggingFace Models" section below.

Code Organization

  1. main.py: Routines for training (language models and classifiers)
  2. noise_schedule.py: Noise schedules
  3. diffusion.py: Forward/reverse diffusion
    • Absorbing state / uniform noise diffusion
    • AR
  4. dataloader.py: Dataloaders
    • For the Discretized CIFAR10 and Species10 datasets, we use custom dataset classes defined in custom_datasets/
  5. utils.py: LR scheduler, logging, fsspec handling
  6. models/: Denoising network architectures.
  7. configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
  8. scripts/: Shell scripts for training/evaluation
  9. guidance_eval/: Guidance evaluation scripts

Implemented Decoding Mechanisms

In diffusion.py, we define baseline and proposed decoding mechanisms for guidance. These decoding schemes are controlled via the Hydra config using the guidance field. For example, to use the proposed D-CFG guidance mechanism, set guidance=cfg in the config file and optionally set the guidance.gamma parameter to control the strength of the guidance signal.
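
Because configuration is handled by Hydra, the same settings can also be passed as command-line overrides. A minimal sketch (the gamma value is arbitrary, and other required overrides are omitted):

python main.py guidance=cfg guidance.gamma=2.0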

The implemented decoding methods are as follows:

  • AR (Baseline):
    • Standard decoding (i.e., no guidance); set guidance=null
    • Classifier-free guidance (D-CFG); set guidance=cfg
    • Classifier-based guidance using FUDGE (set guidance=fudge) and using PPLM (set guidance=pplm)
  • Diffusion:
    • Standard decoding (i.e., no guidance); set guidance=null
    • Classifier-free guidance (D-CFG); set guidance=cfg
    • Classifier-based guidance (D-CBG); set guidance=cbg
    • Classifier-based (baseline) method of NOS; set guidance=nos

Implemented Generative Models

The three modeling parameterizations we explore in this work are:

  1. Autoregressive (AR) Models
  2. Masked Diffusion Language Models (MDLM)
  3. Uniform Diffusion Language Models (UDLM)

The config files can be used to specify which of these parameterizations to use. Below we detail which config parameters correspond to which model.

AR

diffusion="absorbing_state"  # AR models can be thought of as a special case of abosrbing state diffusion models
parameterization="ar"
T=0  # N/A for AR models, this is a placeholder
time_conditioning=False  # AR models are not conditioned on time
zero_recon_loss=False  # N/A for this model

MDLM

diffusion="absorbing_state"
parameterization="subs"  # See MDLM paper for details: https://arxiv.org/abs/2406.07524
T=0  # Indicates continuous time, i.e., T --> infinity
time_conditioning=False  # MDLM is not conditioned on time
zero_recon_loss=False  # N/A for this model

UDLM

diffusion="uniform"
parameterization="d3pm"  # Indicates that we explicitly compute KL on posteriors
T=0  # Indicates continuous time, i.e., T --> infinity
time_conditioning=True  # UDLM is conditioned on time
zero_recon_loss=True  # In continuous time, recon loss evaluates to zero

Getting started in this repository

To get started, create a conda environment containing the required dependencies.

conda env create -f requirements.yaml
conda activate discdiff

Create the following directories to store saved models and slurm logs:

mkdir outputs
mkdir watch_folder

We rely on Weights & Biases (wandb) integration to log experiments and evaluation curves.
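
If wandb has not been configured on your machine, log in once before launching runs:

wandb login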

Reproducing Experiments

Below, we describe the steps required to reproduce the experiments in the paper. Throughout, the main entry point for running experiments is the main.py script. We also provide sample slurm scripts for launching pre-training and evaluation experiments in the scripts/ directory.

Language Modeling Experiments

To reproduce the language modeling results, please refer to the training shell scripts in the scripts/ directory.

Each script contains a comment detailing its usage. For example, to train an AR, MDLM, or UDLM model on the text8 dataset, use the following command:

cd scripts/
MODEL=<ar|mdlm|udlm>
sbatch \
  --export=ALL,MODEL=${MODEL} \
  --job-name=train_text8_${MODEL} \
  train_text8.sh

Guidance Training

Classifier-Free

For classifier-free guidance, we train models that condition on the class label to model conditional distributions; during training, we randomly mask out the conditioning signal, replacing it with a dummy value of num_classes + 1, so that the same model also learns the unconditional distribution. Refer to the shell scripts with the _guidance suffix to train these models for the CIFAR10, QM9, and Species10 datasets. For QM9, we have two experiments: one where we condition on drug-likeness (qed) and another where we condition on ring count (ring_count).
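
As a conceptual illustration of this label dropout (a minimal PyTorch sketch for exposition only, not the repository's implementation; the drop probability and shapes are arbitrary):

import torch

def drop_labels(labels: torch.Tensor, num_classes: int, drop_prob: float = 0.1) -> torch.Tensor:
    # With probability drop_prob, replace the class label with the dummy
    # "unconditional" index (num_classes + 1, as described above).
    drop = torch.rand(labels.shape, device=labels.device) < drop_prob
    return torch.where(drop, torch.full_like(labels, num_classes + 1), labels)

labels = torch.randint(0, 10, (8,))           # toy batch of class labels
labels = drop_labels(labels, num_classes=10)  # mix of conditional and "unconditional" examples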

Classifier-Based

For classifier-based guidance, we need to train a classifier on the noisy latent samples. Refer to the classifier-training shell scripts in the scripts/ directory.
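
To make "training on noisy latent samples" concrete, here is a toy, self-contained sketch using an absorbing-state forward process (for exposition only; the actual classifiers are the time-conditioned networks defined in models/ and trained via main.py):

import torch
import torch.nn as nn
import torch.nn.functional as F

def absorbing_noise(x0, t, mask_id):
    # Mask each token independently with probability t (absorbing-state forward process).
    keep = torch.rand(x0.shape) >= t[:, None]
    return torch.where(keep, x0, torch.full_like(x0, mask_id))

vocab_size, seq_len, num_classes, mask_id = 28, 32, 10, 28
clf = nn.Sequential(nn.Embedding(vocab_size + 1, 64), nn.Flatten(), nn.Linear(64 * seq_len, num_classes))

x0 = torch.randint(0, vocab_size, (4, seq_len))  # clean token sequences
y = torch.randint(0, num_classes, (4,))          # class labels
t = torch.rand(4)                                # diffusion times in [0, 1)
x_t = absorbing_noise(x0, t, mask_id)            # noisy latents seen by the classifier
loss = F.cross_entropy(clf(x_t), y)              # classifier learns p(y | x_t)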

PPLM / NOS baselines

An alternative classifier-based guidance mechanism to D-CBG is that of PPLM (which was adapted for diffusion models in NOS). To train these classifiers, refer to the following shell script: train_qm9_pplm_classifier.sh (for both PPLM and NOS classifiers).

Guidance Evaluation

To evaluate guidance mechanisms, we load trained models (and classifiers, if applicable), generate samples, and compute "quality" metrics (e.g., validity/novelty in the QM9 experiments) and control-label satisfaction (e.g., the mean value of the property of interest among novel generated molecules in the QM9 experiments).

The scripts for these evaluations can be found in the guidance_eval/ directory. To run them, please refer to the corresponding evaluation shell scripts (e.g., eval_qm9_guidance.sh, used below).

In the paper, we performed an extensive hyperparameter sweep for our proposed guidance mechanisms and for baselines. The shell scripts can be used to reproduce these experiments, e.g., for the D-CFG experiments on QM9:

export MODEL=<ar|mdlm|udlm>
export PROP=<qed|ring_count>
export GUIDANCE=cfg
for GAMMA in $(seq 1 5); do
    sbatch \
      --export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},GAMMA=${GAMMA} \
      --job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL}_GAMMA-${GAMMA} \
      eval_qm9_guidance.sh
done

Once each evaluation run is complete, a .csv file containing the results is saved in the run directory of the trained generative model.
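
To aggregate results across a sweep, the per-run files can be collected with pandas (the glob pattern below is illustrative; adjust it to your output directory layout):

import glob

import pandas as pd

csv_paths = glob.glob("outputs/**/*.csv", recursive=True)
results = pd.concat([pd.read_csv(p) for p in csv_paths], ignore_index=True)
print(results.head())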

Using HuggingFace Models

We provide pre-trained models on HuggingFace 🤗 (e.g., kuleshov-group/udlm-lm1b and kuleshov-group/udlm-qm9).

Please see the README pages for these models on HuggingFace, or our paper, for more details about how these models were trained.

To use these models, you can load them using the HuggingFace API, e.g.,

from transformers import AutoModelForMaskedLM

model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/udlm-lm1b")
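
If the model repository on HuggingFace also hosts a tokenizer (check the model card), it can be loaded in the same way:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("kuleshov-group/udlm-lm1b")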

To use these models in our repository, set the following config parameters:

backbone="hf_dit"
model="hf"
model.pretrained_model_name_or_path="kuleshov-group/udlm-lm1b"  # or "kuleshov-group/udlm-qm9"

Acknowledgements

This repository was built on top of MDLM, which in turn was built on SEDD. Our implementation of D-CBG is adapted from Nisonoff et al.'s repo.

Citation

@article{
    schiff2024discreteguidance,
    title={Simple Guidance Mechanisms for Discrete Diffusion Models},
    author={Schiff, Yair and Sahoo, Subham Sekhar and Phung, Hao and Wang, Guanghan and Boshar, Sam and Dalla-torre, Hugo and de Almeida, Bernardo P and Rush, Alexander and Pierrot, Thomas and Kuleshov, Volodymyr},
    journal={arXiv preprint arXiv:2412.10193},
    year={2024}
}
