This repository contains code for reproducing the experiments in the paper Simple Guidance Mechanisms for Discrete Diffusion Models.
We also share trained models on HuggingFace 🤗 and support integration with these models. See the "Using HuggingFace Models" section below.
- main.py: Routines for training (language models and classifiers)
- noise_schedule.py: Noise schedules
- diffusion.py: Forward/reverse diffusion
  - Absorbing state / uniform noise diffusion
  - AR
- dataloader.py: Dataloaders
  - For Discretized CIFAR10 and the Species10 datasets we use custom dataset classes defined in custom_datasets/
- utils.py: LR scheduler, logging, fsspec handling
- models/: Denoising network architectures
- configs/: Config files for datasets/denoising networks/noise schedules/LR schedules
- scripts/: Shell scripts for training/evaluation
- guidance_eval/: Guidance evaluation scripts
In diffusion.py, we define baseline and proposed decoding mechanisms for guidance.
These decoding schemes can be controlled via the hydra config with the guidance field.
For example, to use the proposed D-CFG guidance mechanism, set guidance=cfg in the config file and optionally set the guidance.gamma parameter to control the strength of the guidance signal.
The implemented decoding methods are as follows:
- AR (Baseline):
- Diffusion:
  - Standard decoding (i.e., no guidance); set guidance=null
  - Classifier-free guidance (D-CFG); set guidance=cfg
  - Classifier-based guidance (D-CBG); set guidance=cbg
  - Classifier-based (baseline) method of NOS; set guidance=nos
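As a rough illustration of how a guidance strength such as guidance.gamma can act on per-token predictions, the sketch below combines conditional and unconditional logits in log space and renormalizes. It assumes the common classifier-free guidance form, log p_gamma(x | y) ∝ gamma * log p(x | y) + (1 - gamma) * log p(x); the function names are hypothetical and this is not taken from diffusion.py, whose implementation may differ.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cfg_log_probs(cond_logits, uncond_logits, gamma):
    """Hypothetical helper: mix conditional and unconditional predictions.

    gamma = 1 recovers the conditional model, gamma = 0 the unconditional
    model, and gamma > 1 pushes further toward the conditioning signal.
    """
    cond = log_softmax(cond_logits)
    uncond = log_softmax(uncond_logits)
    mixed = [gamma * c + (1.0 - gamma) * u for c, u in zip(cond, uncond)]
    # Renormalize so the guided scores form a valid distribution.
    return log_softmax(mixed)
```

With gamma greater than 1, tokens that the conditional model prefers over the unconditional one gain probability mass relative to the plain conditional prediction.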
The three modeling parameterizations we explore in this work are:
- Autoregressive (AR) Models
- Masked Diffusion Language Models (MDLM)
- Uniform Diffusion Language Models (UDLM)
The config files can be used to specify which of these parameterizations to use.
Below we detail which config parameters correspond to which model.
AR
diffusion="absorbing_state" # AR models can be thought of as a special case of absorbing state diffusion models
parameterization="ar"
T=0 # N/A for AR models, this is a placeholder
time_conditioning=False # AR models are not conditioned on time
zero_recon_loss=False # N/A for this model
MDLM
diffusion="absorbing_state"
parameterization="subs" # See MDLM paper for details: https://arxiv.org/abs/2406.07524
T=0 # Indicates continuous-time, i.e., T --> infinity
time_conditioning=False # MDLM not conditioned on time
zero_recon_loss=False # N/A for this model
UDLM
diffusion="uniform"
parameterization="d3pm" # Indicates that we explicitly compute KL on posteriors
T=0 # Indicates continuous-time, i.e., T --> infinity
time_conditioning=True # UDLM is conditioned on time
zero_recon_loss=True # In continuous time, recon loss evaluates to zero
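The three parameter sets above can also be kept in one place and turned into command-line style overrides. The sketch below is only a convenience illustration: the PRESETS dictionary copies the values listed above, while the to_overrides helper and the idea of passing the result to main.py are assumptions, not part of the repository.

```python
# Values copied from the AR / MDLM / UDLM settings listed above.
PRESETS = {
    "ar": {
        "diffusion": "absorbing_state",
        "parameterization": "ar",
        "T": 0,
        "time_conditioning": False,
        "zero_recon_loss": False,
    },
    "mdlm": {
        "diffusion": "absorbing_state",
        "parameterization": "subs",
        "T": 0,
        "time_conditioning": False,
        "zero_recon_loss": False,
    },
    "udlm": {
        "diffusion": "uniform",
        "parameterization": "d3pm",
        "T": 0,
        "time_conditioning": True,
        "zero_recon_loss": True,
    },
}


def to_overrides(model):
    """Hypothetical helper: render a preset as key=value override strings."""
    return [f"{key}={value}" for key, value in PRESETS[model].items()]
```

For example, to_overrides("udlm") yields a list of strings such as diffusion=uniform and time_conditioning=True that mirror the UDLM block above.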
To get started, create a conda environment containing the required dependencies.
conda env create -f requirements.yaml
conda activate discdiff
Create the following directories to store saved models and slurm logs:
mkdir outputs
mkdir watch_folder
We rely on wandb integration to log experiments and eval curves.
Below, we describe the steps required for reproducing the experiments in the paper.
Throughout, the main entry point for running experiments is the main.py script.
We also provide sample slurm scripts for launching pre-training and evaluation experiments in the scripts/ directory.
To reproduce the language modeling results, please refer to the following shell scripts in the scripts/ directory:
- Species10: train_ten_species_guidance.sh
- QM9: train_qm9_no-guidance.sh
- CIFAR10: train_cifar10_unet_guidance.sh
- text8: train_text8.sh
- Amazon Polarity: train_amazon_polarity.sh
- LM1B: train_lm1b.sh
Each script contains a comment detailing its usage.
For example, to train an AR, MDLM, or UDLM model on the text8 dataset, use the following command:
cd scripts/
MODEL=<ar|mdlm|udlm>
sbatch \
--export=ALL,MODEL=${MODEL} \
--job-name=train_text8_${MODEL} \
train_text8.sh
For classifier-free guidance, we need to train models that can condition on the class label to model conditional distributions.
During training, we randomly mask out the conditioning signal, replacing it with a dummy value of num_classes + 1, to simulate an unconditional model.
Refer to the shell scripts with the _guidance suffix to train these models for the CIFAR10, QM9, and Species10 datasets.
For QM9, we have two experiments: one where we condition on the drug-likeness (qed) of the molecules and another where we condition on the ring counts (ring_count).
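The label-masking step described above can be sketched as follows. This is a minimal illustration, not the repository's training code: the function name, the drop_prob parameter, and the use of Python's random module are assumptions; only the dummy value of num_classes + 1 comes from the description above.

```python
import random


def drop_labels(labels, num_classes, drop_prob, rng=None):
    """Hypothetical helper: randomly replace class labels with a dummy value.

    Each label is independently replaced with probability drop_prob, so the
    same network sees both conditional and (simulated) unconditional inputs.
    """
    rng = rng if rng is not None else random.Random()
    dummy = num_classes + 1  # dummy "no label" value, as described in the text
    return [dummy if rng.random() < drop_prob else y for y in labels]
```

At sampling time, feeding the dummy value for every example gives the unconditional prediction needed by classifier-free guidance, while feeding the true label gives the conditional one.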
For classifier-based guidance, we need to train a classifier on the noisy latent samples. Refer to the following shell scripts to train these classifiers:
- FUDGE (AR guidance): train_qm9_fudge_classifier.sh
- D-CBG (diffusion guidance): train_qm9_classifier.sh
An alternative classifier-based guidance mechanism to D-CBG is that of PPLM (which was adapted for diffusion models in NOS).
To train these classifiers, refer to the following shell script: train_qm9_pplm_classifier.sh (for both PPLM and NOS classifiers).
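To make "noisy latent samples" concrete, the sketch below corrupts a clean token sequence under the two forward processes named in this README. It assumes the standard marginals, where each token is kept with probability alpha_t and is otherwise replaced by a mask token (absorbing state) or by a token drawn uniformly from the vocabulary (uniform noise). The function name, alpha_t, and mask_id are hypothetical and not taken from diffusion.py.

```python
import random


def corrupt(tokens, alpha_t, vocab_size, kind, mask_id=None, rng=None):
    """Hypothetical helper: sample a noisy latent z_t from clean tokens.

    alpha_t is the per-token probability of keeping the clean value, which
    a noise schedule would supply for a given time step.
    """
    rng = rng if rng is not None else random.Random()
    noisy = []
    for tok in tokens:
        if rng.random() < alpha_t:
            noisy.append(tok)  # token survives unchanged
        elif kind == "absorbing_state":
            noisy.append(mask_id)  # token is absorbed into the mask state
        elif kind == "uniform":
            noisy.append(rng.randrange(vocab_size))  # resample uniformly
        else:
            raise ValueError(f"unknown diffusion kind: {kind}")
    return noisy
```

A classifier for guidance would then be trained on pairs of such corrupted sequences and their labels, so that it remains informative at every noise level it may see during sampling.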
To evaluate guidance mechanisms, we load trained models (and classifiers, if applicable) and generate some number of samples for which we compute "quality" metrics (e.g., validity/novelty in the QM9 experiments) and control label satisfaction (e.g., mean value of novel generated molecules for the property of interest in the QM9 experiments).
The scripts for these evaluations can be found in the guidance_eval/ directory.
To run these evaluations, please refer to the following shell scripts:
- QM9: eval_qm9_guidance.sh
- Species10: eval_ten_species_guidance.sh
  - For this dataset, we also evaluate the accuracy of a HyenaDNA classifier on correctly classifying generated sequences. This model can be trained using train_ten_species_eval_classifier.sh.
    - To see how this trained evaluation classifier performs on the validation set of the original data, use this notebook: eval_hyenadna_classifier.ipynb.
In the paper, we performed an extensive hyperparameter sweep for our proposed guidance mechanisms and for baselines. The shell scripts can be used to reproduce these experiments, e.g., for the D-CFG experiments on QM9:
export MODEL=<ar|mdlm|udlm>
export PROP=<qed|ring_count>
export GUIDANCE=cfg
for GAMMA in $(seq 1 5); do
sbatch \
--export=ALL,MODEL=${MODEL},PROP=${PROP},GUIDANCE=${GUIDANCE},GAMMA=${GAMMA} \
--job-name=eval_qm9_${GUIDANCE}_${PROP}_${MODEL}_GAMMA-${GAMMA} \
eval_qm9_guidance.sh
done
Once each evaluation run is complete, a .csv file containing the results is saved in the run directory of the trained generative model.
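After a sweep, it can be convenient to gather the per-run .csv files into one list of rows. The sketch below is a generic illustration using only the Python standard library; the function name is hypothetical, and it makes no assumption about the column names in the result files beyond adding a source_file field of its own.

```python
import csv
from pathlib import Path


def collect_results(run_dir):
    """Hypothetical helper: read every .csv under run_dir into a list of dicts."""
    rows = []
    for path in sorted(Path(run_dir).rglob("*.csv")):
        with path.open(newline="") as f:
            for row in csv.DictReader(f):
                row["source_file"] = str(path)  # remember where the row came from
                rows.append(row)
    return rows
```

All values are returned as strings, so numeric columns need to be converted before computing summary statistics.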
We provide pre-trained models on HuggingFace 🤗:
- UDLM trained on LM1B: kuleshov-group/udlm-lm1b
- UDLM trained on QM9: kuleshov-group/udlm-qm9
- Note: this model was trained without guidance and can be used with classifier-free guidance.
Please see the README pages for these models on HuggingFace or our paper for more details about the training of these models.
To use these models, you can load them using the HuggingFace API, e.g.,
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("kuleshov-group/udlm-lm1b")
To use these models in our repository, set the following config parameters:
backbone="hf_dit"
model="hf"
model.pretrained_model_name_or_path="kuleshov-group/udlm-lm1b" # or "kuleshov-group/udlm-qm9"
This repository was built off of MDLM, which in turn used SEDD. Our code implementation of D-CBG is adapted from Nisonoff et al.'s repo.
@article{schiff2024discreteguidance,
title={Simple Guidance Mechanisms for Discrete Diffusion Models},
author={Schiff, Yair and Sahoo, Subham Sekhar and Phung, Hao and Wang, Guanghan and Boshar, Sam and Dalla-torre, Hugo and de Almeida, Bernardo P and Rush, Alexander and Pierrot, Thomas and Kuleshov, Volodymyr},
journal={arXiv preprint arXiv:2412.10193},
year={2024}
}