This repository contains the official implementation of the paper *A Reparameterized Discrete Diffusion Model for Text Generation*.
The codebase is implemented with FairSeq. To install the dependencies (preferably inside a virtual environment), run the following commands:
```bash
pip install -r requirements.txt

# install our package of discrete diffusion models
pip install -e discrete_diffusion

# install our fork of fairseq
cd fairseq
python3 setup.py build develop
cd ..
```
**Note**: The environment is tested with Python 3.8.10, PyTorch 1.10.0/1.12.0, and CUDA 11.3. Also note that our fork of fairseq modifies several files in the original codebase; using more recent versions of fairseq might lead to unexpected dependency conflicts.
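To quickly confirm which PyTorch/CUDA build your environment actually picked up, an optional sanity check is:

```python
import torch

# Print the installed PyTorch version, the CUDA version it was built with,
# and whether a GPU is visible to this process.
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())
```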
We implement discrete diffusion models in a self-contained library, `discrete_diffusion`, for general use. The library provides implementations of several typical discrete diffusion models:

- **(Vanilla/Reparameterized) multinomial diffusion**: diffusion processes that inject `uniform` noise into the token sequence. The implementation of vanilla multinomial diffusion closely follows the codebase of the original paper;
- **(Vanilla/Reparameterized) absorbing diffusion**: diffusion processes where tokens within the sequence can be absorbed into the `masking` state, as described in the D3PM paper.
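As a toy illustration of the two corruption types (made-up tokens and a single fixed noise level; not library code):

```python
import random

random.seed(0)
vocab = ["the", "cat", "sat", "on", "mat", "dog", "ran"]
MASK = "[MASK]"                  # the absorbing/masking state
x_0 = ["the", "cat", "sat", "on", "the", "mat"]
noise_level = 0.4                # stands in for the time-dependent corruption rate

# Multinomial (uniform) noise: corrupted positions are resampled uniformly over the vocabulary.
x_t_multinomial = [random.choice(vocab) if random.random() < noise_level else tok for tok in x_0]

# Absorbing noise: corrupted positions collapse into the masking state.
x_t_absorbing = [MASK if random.random() < noise_level else tok for tok in x_0]

print(x_t_multinomial)
print(x_t_absorbing)
```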
click to check the implementation details as well as their arguments 👇
These diffusion models share the same set of interfaces for external use. In particular, they are defined as subclasses of the `DiscreteDiffusion` class, which takes the following form:
```python
import torch.nn as nn

class DiscreteDiffusion(nn.Module):
    """
    The parent class for discrete denoising diffusion probabilistic models.

    It supports the following methods:
    - q_sample()
        Sample x_t ~ q(x_t | x_0) to construct noisy Transformer inputs.
    - compute_losses()
        Compute the loss L_t = KL(q || p) at the t-th time step.
    - sample_step()
        Sample x_{t-1} ~ p(x_{t-1} | x_t, x_0) at the t-th time step.
    """

    def __init__(self, num_timesteps):
        super().__init__()
        self.num_timesteps = num_timesteps

    def q_sample(self, x_0, t, **kwargs):
        """
        Sample from q(x_t | x_0), which is used as the model input.

        Args:
            x_0: token ids with shape [B, N]
            t: current time step, tensor with shape [B]

        Returns:
            a dict of relevant outputs, including x_t.
        """

    def compute_losses(self, inputs, **kwargs):
        """
        Compute the loss objective KL(q || p) to train our generative process.

        Args:
            inputs: a dict of inputs specific to different diffusion processes, containing
                - x_t: token ids with shape [B, N]
                - t: scalar time steps, with shape [B]

        Returns:
            possibly a dict of relevant outputs, including the loss used for training.
        """

    def sample_step(self, decoder_out, denoising_fn, **kwargs):
        """
        Given a time step t, start from x_t and sample x_{t-k} from p(x_{t-k} | x_t).

        Args:
            decoder_out: a namedtuple that contains decoding info, including
                - x_t: token ids with shape [B, N]
                - t: scalar time steps
                - max_steps: the maximum number of decoding steps
                - ...
            denoising_fn: a function that takes x_t and t and returns model logits
            kwargs: other arguments used to control decoding

        Returns:
            a new decoder_out namedtuple.
        """
```
A `DiscreteDiffusion` model can be instantiated by configuring the following:

- Basic attributes, including
  - `--num-diffusion-timesteps <int>` specifies the total number of diffusion time steps (default: 50)
  - `--diffusion-type <str>` specifies the diffusion model type (choices: `{absorbing, multinomial, reparam-absorbing, reparam-multinomial}`)
  - `--noise-scheduler-type <str>` specifies the noise schedule, used only in vanilla/reparameterized multinomial diffusion (typical choices: `{linear, cosine}`; default: `cosine`)
- Important arguments specific to the forward sampling routine in `q_sample()`, including
  - `--q-sample-mode <str>` specifies the sampling strategy (choices: `{default, coupled, multi-step, multi-sample}`; default: `default`). We provide several choices for sampling from $q(x_t|x_0)$ to prepare corrupted token sequences for denoising:
    - `default`: a single sample is drawn as $x_t \sim q(x_t|x_0)$, identical to previous practices;
    - `multi-step`: sample two i.i.d. time steps $s, t$ and draw $x_s \sim q(x_s|x_0)$ and $x_t \sim q(x_t|x_0)$, respectively. We then optimize the average $\frac{1}{2}(\mathcal{L}_s + \mathcal{L}_t)$ for variance reduction (see the sketch after this list);
    - `multi-sample`: draw two i.i.d. samples $x_t \sim q(x_t|x_0)$ and $x_t' \sim q(x_t|x_0)$ at the same time step, and compute the loss averaged over these two samples;
    - `coupled`: also known as conditioned training, which is detailed in Appendix F of the paper. This starts with sampling two i.i.d. time steps $s, t$ (assume $s < t$). We draw $x_t \sim q(x_t|x_0)$ as usual, but draw $x_s$ from a distribution conditioned on $x_t$, i.e., $x_s \sim q(x_s|x_t, x_0)$. We then compute the average $\frac{1}{2}(\mathcal{L}_s + \mathcal{L}_t)$ as the objective. This strategy simulates the backward transition process and helps stabilize training. In preliminary experiments, we found the `coupled` sampling mode brings significant improvements for both vanilla multinomial/absorbing diffusion, but the gain is not consistently substantial in the reparameterized variants.
  - `--not-diffusing-special-sym` indicates whether to exclude special symbols from the diffusion process (default: False)
- Important arguments specific to the loss objective calculation in `compute_losses()`, including
  - `--reweighting-type <str>` specifies the reweighting scheme in our reparameterized family (choices: `{linear, reciprocal, none}`; default: `linear`)
  - `--label-smoothing <float>` specifies the rate of label smoothing (default: 0.1)
- Important arguments specific to the decoding routine in `sample_step()`, including
  - `--argmax-decoding` indicates whether to use argmax decoding for the denoised Transformer output $\tilde{x}_0$ (default: False)
  - `--temperature <float>` specifies the temperature $\tau$ for sampling $\tilde{x}_0 \sim \operatorname{Categorical}(f(x_t;\theta)/\tau)$ when argmax decoding is not used (default: 1.0)
  - `--decoding-strategy <str>` specifies whether to use the vanilla (`default`) or reparameterized (`reparam-<options>`; see the details below) decoding strategy (choices: `{default, reparam-<options>}`; default: `default`)
  - `--load-ema-weights` indicates whether to load the EMA model weights for generation (default: False)
  - `--iter-decode-max-iter <int>` specifies the maximum number of timesteps for decoding (default: 10)
  - `--iter-decode-with-beam <int>` specifies the beam size for decoding multiple sequences with different lengths in parallel (default: 1)
  - `--iter-decode-force-max-iter` indicates that iterative decoding must run the specified number of iterations without early exit. We recommend setting this flag.

See here for a more comprehensive list of arguments.
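As a rough, self-contained sketch of what the `multi-step` mode computes (not the library's actual training loop): draw two i.i.d. time steps, corrupt `x_0` at each, and average the two per-step losses. Here `diffusion` is any object exposing `q_sample()` as in the interface above (e.g., the toy class sketched earlier), `denoising_fn` is a hypothetical stand-in for the denoising Transformer, and the cross-entropy over corrupted positions is a simplified surrogate for the true $\mathcal{L}_t$ (which in the library is a reweighted KL term).

```python
import torch
import torch.nn.functional as F

def multi_step_loss(diffusion, denoising_fn, x_0):
    """Average L_s and L_t over two i.i.d. time steps for variance reduction."""
    B = x_0.size(0)
    losses = []
    for _ in range(2):
        t = torch.randint(0, diffusion.num_timesteps, (B,), device=x_0.device)
        noised = diffusion.q_sample(x_0, t)            # dict with x_t, corruption masks, ...
        logits = denoising_fn(noised["x_t"], t)        # [B, N, V] logits from the denoiser
        # Simplified per-step loss: cross-entropy on corrupted positions only.
        ce = F.cross_entropy(logits.transpose(1, 2), x_0, reduction="none")  # [B, N]
        m = noised["masks"].float()
        losses.append((ce * m).sum() / m.sum().clamp(min=1.0))
    return 0.5 * (losses[0] + losses[1])
```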
By passing `--decoding-strategy default`, the vanilla sampling scheme (specific to each discrete diffusion process) is used.

A more advanced decoding approach can be invoked by passing `--decoding-strategy reparam-<conditioning-of-v>-<topk_mode>-<schedule>`. This approach is based on the proposed reparameterization in our paper and allows for more effective decoding procedures. The options specify the decoding algorithm via
- `<conditioning-of-v>`: `uncond` or `cond` (default `uncond`): whether to generate the routing variable $v_t$ in a conditional or unconditional manner;
- `<topk_mode>`: `stochastic<float>` or `deterministic` (default `deterministic`): whether to use stochastic or deterministic top-$k$ selection. The float value in `stochastic<float>` specifies the degree of randomness in the stochastic top-$k$ selection;
- `<schedule>`: `linear` or `cosine` (default `cosine`): the schedule for $k$ during the denoising procedure, which controls the number of top-$k$ tokens to be denoised at the next decoding step.
See the implementation for more details about the options.
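To make the `<topk_mode>` and `<schedule>` options more concrete, here is a simplified sketch of the kind of selection they control at each decoding step: a schedule decides how many tokens $k$ are committed as denoised, and the top-$k$ positions are picked either deterministically or with Gumbel noise added to the per-token scores (yielding the routing decision that plays the role of $v_t$). The exact score definition, noise scaling, and schedules used by the library differ; see the linked implementation.

```python
import math
import torch

def k_schedule(step, max_steps, seq_len, schedule="cosine"):
    """Number of tokens treated as denoised after `step + 1` of `max_steps` iterations."""
    frac = (step + 1) / max_steps
    if schedule == "linear":
        return max(1, int(seq_len * frac))
    # cosine: commit few tokens in early iterations, more towards the end
    return max(1, int(seq_len * (1.0 - math.cos(0.5 * math.pi * frac))))

def select_topk(scores, k, topk_mode="deterministic", noise_scale=1.0):
    """Return a boolean mask over the k positions to denoise.

    `scores` are per-token confidences with shape [B, N], e.g. model log-probabilities.
    """
    if topk_mode == "stochastic":
        # Perturb scores with Gumbel noise so the top-k selection becomes stochastic.
        u = torch.rand_like(scores).clamp_min(1e-9)
        scores = scores + noise_scale * (-torch.log((-torch.log(u)).clamp_min(1e-9)))
    topk = scores.topk(k, dim=-1).indices                        # [B, k]
    keep = torch.zeros_like(scores, dtype=torch.bool)
    keep[torch.arange(scores.size(0), device=scores.device).unsqueeze(-1), topk] = True
    return keep
```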
Please see the scripts below for details.
**Note**
- All tasks considered in this work operate on the original data and do not adopt Knowledge Distillation (KD).
We follow the standard pre-processing in fairseq/examples to prepare the binarized data:
```bash
# fetch and preprocess the data to BPE codes
cd examples/translation/
bash prepare-iwslt14.sh
cd ../..

# binarize the data
TEXT=examples/translation/iwslt14.tokenized.de-en
fairseq-preprocess --joined-dictionary --source-lang de --target-lang en \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/iwslt14.tokenized.de-en \
    --workers 20
```
For the WMT'14 En-De dataset, we use the data released in fairseq/examples:
```bash
wget http://dl.fbaipublicfiles.com/nat/original_dataset.zip
unzip original_dataset.zip

TEXT=wmt14_ende
fairseq-preprocess --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train.en-de --validpref $TEXT/valid.en-de --testpref $TEXT/test.en-de \
    --destdir data-bin/wmt14_ende --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20
```
For the WMT'16 En-Ro dataset, we use the raw data wmt16.tar.gz as pre-processed in this repository.
```bash
tar xzvf wmt16.tar.gz
TEXT=wmt16/en-ro

# move train/ dev/ test/ bpe codes into the $TEXT folder
mv $TEXT/train/corpus.bpe.en $TEXT/train.bpe.en
mv $TEXT/train/corpus.bpe.ro $TEXT/train.bpe.ro
mv $TEXT/dev/dev.bpe.en $TEXT/dev.bpe.en
mv $TEXT/dev/dev.bpe.ro $TEXT/dev.bpe.ro
mv $TEXT/test/test.bpe.en $TEXT/test.bpe.en
mv $TEXT/test/test.bpe.ro $TEXT/test.bpe.ro

# binarize the data
fairseq-preprocess --joined-dictionary \
    --source-lang en --target-lang ro \
    --trainpref $TEXT/train.bpe --validpref $TEXT/dev.bpe --testpref $TEXT/test.bpe \
    --destdir data-bin/wmt16_enro --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20
```
We first `cd` into the `fairseq` folder and then run the following commands to train the models.
```bash
######## training scripts for IWSLT'14, WMT'14, and WMT'16
# first cd to fairseq
# we use 1 GPU for IWSLT'14, 4 GPUs for WMT'14, and 2 GPUs for WMT'16, respectively.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_train.sh -m absorbing -d <iwslt/wmt14/wmt16> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=1 bash experiments/mt_train.sh -m multinomial -d <iwslt/wmt14/wmt16> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=2 bash experiments/mt_train.sh -m reparam-absorbing -d <iwslt/wmt14/wmt16> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=3 bash experiments/mt_train.sh -m reparam-multinomial -d <iwslt/wmt14/wmt16> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
```
**Note**
- `-s <str>` is used to specify the name of the experiment.
- Custom arguments specific to training can be passed by appending them after `-e True`.
The evaluation pipeline is handled by `experiments/mt_generate.sh`. The script will generate the translation results and evaluate the BLEU score.
```bash
########### IWSLT'14, WMT'14, and WMT'16 datasets
# we recommend putting each checkpoint into a separate folder,
# since the script will put the decoded results into a file under the same folder as each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/mt_generate.sh -a false -c <checkpoint_path> -d <iwslt/wmt14/wmt16>
```
Arguments:
- `-a`: whether to average multiple checkpoints
- `-c`: the location of the checkpoint. If `-a false` (do not average checkpoints), pass the checkpoint path; if `-a true`, pass the directory that stores multiple checkpoints from different training steps for averaging.
- `-d`: the dataset name
We also provide the checkpoints of our trained models.
| Dataset | Model | Checkpoint link |
|---|---|---|
| IWSLT'14 | Multinomial | link |
| IWSLT'14 | Absorbing | link |
| IWSLT'14 | Reparam-multinomial | link |
| IWSLT'14 | Reparam-absorbing | link |
| WMT'14 | Multinomial | link |
| WMT'14 | Absorbing | link |
| WMT'14 | Reparam-multinomial | link |
| WMT'14 | Reparam-absorbing | link |
| WMT'16 | Multinomial | link |
| WMT'16 | Absorbing | link |
| WMT'16 | Reparam-multinomial | link |
| WMT'16 | Reparam-absorbing | link |
We follow the experimental setup in DiffuSeq for the question generation and paraphrasing tasks. The raw data for these two tasks can be fetched from the original DiffuSeq repository. We then binarize the data via the provided script:
```bash
# put the raw data in the directory ``diffuseq_data/QG``
# preprocess the question generation dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QG

# put the raw data in the directory ``diffuseq_data/QQP``
# preprocess the paraphrasing dataset
bash diffusion_mt/scripts/preprocess_diffuseq_datasets.sh QQP
```
```bash
# QQP or QG datasets
# first cd to fairseq
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m absorbing -d <qqp/qg> -s default -e True --store-ema --label-smoothing 0.1
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m multinomial -d <qqp/qg> -s default -e True --not-diffusing-special-sym --store-ema --label-smoothing 0.0
CUDA_VISIBLE_DEVICES=0,1 bash experiments/diffuseq_train.sh -m reparam-multinomial -d <qqp/qg> -s default -e True --not-diffusing-special-sym --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
CUDA_VISIBLE_DEVICES=2,3 bash experiments/diffuseq_train.sh -m reparam-absorbing -d <qqp/qg> -s default -e True --q-sample-mode coupled --store-ema --label-smoothing 0.1 --reweighting-type linear
```
We closely follow the generation & evaluation protocols in DiffuSeq to ensure a head-to-head comparison. The whole pipeline is re-implemented in `fairseq/diffusion_mt/scripts/decode_diffuseq.py` and `fairseq/diffusion_mt/scripts/eval_diffuseq.py`, respectively, to be compatible with Fairseq. Run the following commands:
```bash
# we recommend putting each checkpoint into a separate folder,
# since the script will put the decoded results into a file under the same folder as each checkpoint.
CUDA_VISIBLE_DEVICES=0 bash experiments/diffuseq_generate.sh -a false -b true -c <checkpoint_path> -d <qqp/qg>
```
Arguments:
- `-a`: whether to average multiple checkpoints
- `-b`: whether to use multiple samples for MBR decoding
- `-c`: the location of the checkpoint. If `-a false` (do not average checkpoints), pass the checkpoint path; if `-a true`, pass the directory that stores multiple checkpoints from different training steps for averaging.
- `-d`: the dataset name
We also provide the checkpoints of our trained models.
| Dataset | Model | Checkpoint link |
|---|---|---|
| QG | Multinomial | link |
| QG | Absorbing | link |
| QG | Reparam-multinomial | link |
| QG | Reparam-absorbing | link |
| QQP | Multinomial | link |
| QQP | Absorbing | link |
| QQP | Reparam-multinomial | link |
| QQP | Reparam-absorbing | link |
```bibtex
@article{zheng2023rdm,
  title={A Reparameterized Discrete Diffusion Model for Text Generation},
  author={Zheng, Lin and Yuan, Jianbo and Yu, Lei and Kong, Lingpeng},
  journal={arXiv preprint arXiv:2302.05737},
  year={2023}
}
```