By Subham Sekhar Sahoo, Justin Deschenaux, Aaron Gokaslan, Guanghan Wang, Justin Chiu, Volodymyr Kuleshov
We unlock few-step generation in discrete diffusion language models via the underlying Gaussian diffusion.
In this repo, we release:
- The DUO framework
- Baseline implementations [Examples]:
- Autoregressive Model.
- MDLM: Sahoo et al., "Simple and Effective Masked Diffusion Language Model", NeurIPS 2024.
- SEDD (absorb): Lou et al., "Score Entropy Based Discrete Diffusion", ICML 2024.
- D3PM (absorb): Austin et al., "Structured Denoising Diffusion Models in Discrete State-Spaces", NeurIPS 2021.
To get started, create a conda environment containing the required dependencies.
```bash
conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
```
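As an optional sanity check, you can verify that PyTorch sees the GPU stack (this assumes `torch` is installed by `requirements.txt`):

```bash
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```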
Curriculum Learning (Sec. 4.1) and Discrete Consistency Distillation (Sec. 4.2) require mapping Gaussian to discrete diffusion parameters via the Diffusion Transformation operator (Sec. 3), which involves computing an integral (dependent only on the tokenizer's vocabulary size). To avoid slowing down training, we pre-compute and cache this integral. Cached operators for `bert-base-uncased` (LM1B) and `gpt2` (OWT) are in `integral/`. For other tokenizers, run:

```bash
python utils.py --vocab_size=N
```

where `N` is the vocabulary size of the tokenizer.
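For example, for a hypothetical tokenizer with a 32,000-token vocabulary:

```bash
# Pre-compute and cache the Diffusion Transformation operator for a
# hypothetical tokenizer with a 32,000-token vocabulary.
python utils.py --vocab_size=32000
```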
The checkpoints for the DUO models (distilled/undistilled) trained on OpenWebText for 1M training steps are available on:
- Huggingface🤗.
- Google Drive folder, since the HF checkpoints can't be fine-tuned.
Run `mkdir watch_folder` to create a directory to store saved models and slurm logs, and then run any script in `scripts/` as a slurm job:

```bash
sbatch scripts/ABC_XYZ.sh
```
To train DUO use the following scripts:
- LM1B
  - w/ sentencepacking (same as in D3PM)
    - Training script: `scripts/train_lm1b_duo_sentencepacking.sh`
    - Wandb run
  - w/o sentencepacking (same as in MDLM, SEDD)
    - Training script: `scripts/train_lm1b_duo.sh`
    - Wandb run
- OWT: `scripts/train_owt_duo.sh`
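For example, to launch LM1B training without sentence packing as a slurm job:

```bash
sbatch scripts/train_lm1b_duo.sh
```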
Curriculum Learning increases memory consumption. For faster training on OWT, one may consider a two-stage approach (a sketch of the stage 1 overrides is shown after this list):
- Stage 1: Curriculum Learning for 500K steps
  - Use `scripts/train_owt_duo.sh` with the following modifications:
    - Reduced batch size (`loader.batch_size=32` on an 80 GB GPU)
    - `trainer.max_steps=500000`
- Stage 2: Finetuning the checkpoint from stage 1 for 500K more steps
  - Training script: `scripts/train_owt_duo_finetune.sh`
  - Features a larger batch size (`loader.batch_size=64` on an 80 GB GPU) than stage 1.
  - Wandb run: Although this run uses a stage 1 checkpoint trained for 1M steps, the results reported in the paper correspond to the checkpoint at 500K steps.
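A minimal sketch of stage 1, assuming `scripts/train_owt_duo.sh` forwards extra command-line arguments as Hydra overrides to `main.py` (check the script; otherwise edit these values inside it directly):

```bash
# Stage 1 (sketch): reduced per-GPU batch size and 500K curriculum-learning steps.
# Assumes the slurm script forwards these overrides to main.py; verify before use.
sbatch scripts/train_owt_duo.sh \
  loader.batch_size=32 \
  trainer.max_steps=500000
```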
Control the batch size per GPU using the argument `loader.batch_size`. If `loader.batch_size * num_gpus < loader.global_batch_size`, PyTorch Lightning resorts to gradient accumulation.
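For example, with a hypothetical `loader.global_batch_size=512`, 8 GPUs, and `loader.batch_size=32`, each GPU processes 32 samples per micro-batch, so Lightning accumulates gradients over 512 / (32 × 8) = 2 micro-batches before every optimizer step.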
To distil a model using Discrete Consistency Distillation (Alg. 1 in the paper), use `scripts/distil_owt.sh`.
To compute test perplexity on the validation set of OWT use `scripts/eval_owt_duo.sh`, and for zero-shot perplexities use `scripts/zero_shot_duo.sh`.
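For example, to queue both evaluations as slurm jobs:

```bash
# Validation perplexity on OWT
sbatch scripts/eval_owt_duo.sh
# Zero-shot perplexities
sbatch scripts/zero_shot_duo.sh
```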
To generate samples from a pre-trained model, use one of the following commands. Set:
- `sampling.noise_removal=greedy` to use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see Sec. 4.2 in the paper).
- `sampling.noise_removal=ancestral` for standard ancestral sampling. This produces more diverse samples (higher entropy) but with worse generative perplexity.
We have released the distilled model `s-sahoo/duo-distilled` and the un-distilled model `s-sahoo/duo` on Huggingface🤗. To sample from a HF model, run the following command:
```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=greedy \
  +wandb.offline=true
```
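As a variant, to sample from the un-distilled model with ancestral sampling, swap the checkpoint path and the noise-removal mode (the step count below is illustrative, not a recommended setting):

```bash
python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo \
  sampling.steps=1024 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=ancestral \
  +wandb.offline=true
```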
We’ve also released checkpoints for the distilled model `duo-distilled.ckpt` and the un-distilled model `duo.ckpt` trained on OWT in this Google Drive folder. Download them and use the command in `scripts/gen_ppl_owt_duo.sh` while specifying the paths correctly.
We release the checkpoints for the baselines: SEDD, MDLM and AR trained on OpenWebText in this Google Drive folder. Download the checkpoints `ar.ckpt`, `mdlm.ckpt`, `sedd.ckpt` and specify the paths appropriately in the respective shell scripts:
- `scripts/eval_owt_*.sh` for computing validation perplexity on OWT.
- `scripts/gen_ppl_*.sh` for generating text samples and evaluating them.
- `scripts/zero_shot_*.sh` for computing zero-shot perplexities.
- `scripts/train_*.sh` for training the models.
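For example, a baseline evaluation might be queued like this (the concrete script name is hypothetical; match it to the actual files under `scripts/` after editing the checkpoint path inside):

```bash
# Hypothetical concrete name for one of the scripts/eval_owt_*.sh baselines;
# check the scripts/ directory for the actual filename.
sbatch scripts/eval_owt_mdlm.sh
```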
This repository was built off of MDLM's GitHub repository. Cite our paper using:
```bibtex
@inproceedings{
sahoo2025the,
title={The Diffusion Duality},
author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
booktitle={Forty-second International Conference on Machine Learning},
year={2025},
url={https://openreview.net/forum?id=9P9Y8FOSOk}
}
```