The Diffusion Duality

By Subham Sekhar Sahoo, Justin Deschenaux, Aaron Gokaslan, Guanghan Wang, Justin Chiu, Volodymyr Kuleshov

[Open In Colab] [arXiv]

We unlock few-step generation in discrete diffusion language models via the underlying Gaussian diffusion.

In this repo, we release:

  • The DUO framework
    1. Curriculum learning strategy to speed up training. [Example]
    2. Discrete Consistency Distillation pipeline. [Example]
    3. Greedy-tail sampler. [Example]
  • Baseline implementations [Examples]:
    1. Autoregressive Model.
    2. MDLM: Sahoo et al., "Simple and Effective Masked Diffusion Language Model", NeurIPS 2024.
    3. SEDD (absorb): Lou et al., "Score Entropy Based Discrete Diffusion", ICML 2024.
    4. D3PM (absorb): Austin et al., "Structured Denoising Diffusion Models in Discrete State-Spaces", NeurIPS 2021.

Getting Started

To get started, create a conda environment containing the required dependencies.

conda create -n duo python=3.12
conda activate duo
conda install nvidia/label/cuda-12.4.0::cuda-toolkit
pip install -r requirements.txt
pip install flash_attn==2.7.4.post1
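
Optionally, sanity-check the environment before training (a quick check we suggest here, not part of the official setup; it only assumes a CUDA-capable GPU is visible to PyTorch):

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import flash_attn; print(flash_attn.__version__)"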

🏮 Integral Cache [Important]

Curriculum Learning (Sec. 4.1) and Discrete Consistency Distillation (Sec. 4.2) require mapping Gaussian to discrete diffusion parameters via the Diffusion Transformation operator (Sec. 3), which involves computing an integral (dependent only on the tokenizer’s vocabulary size). To avoid slowing down training, we pre-compute and cache this integral. Cached operators for bert-base-uncased (LM1B) and gpt2 (OWT) are in integral/. For other tokenizers, run:

python utils.py --vocab_size=N

where N is the vocabulary size of the tokenizer.
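
For example, if your tokenizer is hosted on the Hugging Face Hub, you can read off its vocabulary size and pass it to utils.py. The commands below use gpt2 purely for illustration (its cached operator already ships in integral/); substitute your own tokenizer name, and double-check whether your model counts added special tokens in its vocabulary size:

python -c "from transformers import AutoTokenizer; print(AutoTokenizer.from_pretrained('gpt2').vocab_size)"
# gpt2 reports 50257
python utils.py --vocab_size=50257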

Checkpoints

The checkpoints for the DUO models (distilled and undistilled) trained on OpenWebText for 1M training steps are available on Hugging Face (s-sahoo/duo, s-sahoo/duo-distilled) and in the Google Drive folder referenced in the Sampling & Eval section below.

Slurm scripts

Run mkdir watch_folder to create a directory for saved models and slurm logs, then submit any script in scripts/ as a slurm job:

sbatch scripts/ABC_XYZ.sh
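
For example, to launch curriculum-learning training on OWT as a slurm job (the script name is taken from the Training section below):

mkdir watch_folder
sbatch scripts/train_owt_duo.sh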

Training

To train DUO, use the training scripts in scripts/ (e.g., scripts/train_owt_duo.sh for OWT).

Curriculum Learning increases memory consumption. For faster training on OWT, one may consider a two-stage approach:

  • Stage 1: Curriculum Learning for 500K steps
    • Use scripts/train_owt_duo.sh with the following modifications:
      • Reduced batch size (loader.batch_size=32 on an 80 GB GPU)
      • trainer.max_steps=500000
  • Stage 2: Finetuning the checkpoint from stage 1 for 500K more steps
    • Training script: scripts/train_owt_duo_finetune.sh
    • Features a larger batch size (loader.batch_size=64 on an 80 GB GPU) than stage 1.
    • Wandb run: although the logged run uses a stage 1 checkpoint trained for 1M steps, the results reported in the paper correspond to the checkpoint at 500K steps.

Control the batch size per GPU using the argument loader.batch_size. If loader.batch_size * num_gpus < loader.global_batch_size, PyTorch Lightning resorts to gradient accumulation.
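
Concretely, the two stages might be launched as follows. This is a sketch rather than the exact recipe: it assumes the shell scripts forward extra key=value overrides to main.py (the same override style used by the sampling command below); if they do not, edit the corresponding values inside the scripts instead.

# Stage 1: curriculum learning for 500K steps with a reduced per-GPU batch size
# (assumption: the script forwards these overrides to main.py)
sbatch scripts/train_owt_duo.sh loader.batch_size=32 trainer.max_steps=500000

# Stage 2: finetune the stage-1 checkpoint for another 500K steps
sbatch scripts/train_owt_duo_finetune.sh loader.batch_size=64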

Distillation

To distill a model using Discrete Consistency Distillation (Alg. 1 in the paper), use scripts/distil_owt.sh.

Sampling & Eval

To compute perplexity on the validation set of OWT, use scripts/eval_owt_duo.sh; for zero-shot perplexities, use scripts/zero_shot_duo.sh.

To generate samples from a pre-trained model, use one of the following commands. Set

  • sampling.noise_removal=greedy to use the "Greedy-tail sampler" (equivalent to nucleus sampling in AR models; see Sec. 4.2 in the paper).
  • sampling.noise_removal=ancestral for the standard ancestral sampling. This produces more diverse samples (higher entropy) but with worse generative perplexity.

We have released the distilled model s-sahoo/duo-distilled and the un-distilled model s-sahoo/duo on Hugging Face 🤗. To sample from an HF model, run the following command:

python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=greedy \
  +wandb.offline=true 
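
For more diverse (but higher generative-perplexity) samples, the same command can be switched to the ancestral sampler; only the noise-removal flag changes:

python main.py \
  mode=sample_eval \
  loader.batch_size=2 \
  loader.eval_batch_size=8 \
  data=openwebtext-split \
  algo=duo_base \
  algo.backbone=hf_dit \
  eval.checkpoint_path=s-sahoo/duo-distilled \
  sampling.steps=8 \
  sampling.num_sample_batches=1 \
  sampling.noise_removal=ancestral \
  +wandb.offline=true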

We’ve also released the distilled (duo-distilled.ckpt) and un-distilled (duo.ckpt) checkpoints trained on OWT in this Google Drive folder. Download them and use the command in scripts/gen_ppl_owt_duo.sh, specifying the paths appropriately.

Baselines

We release the checkpoints for the baselines SEDD, MDLM, and AR trained on OpenWebText in this Google Drive folder. Download the checkpoints ar.ckpt, mdlm.ckpt, and sedd.ckpt and specify their paths in the respective shell scripts.

Acknowledgements & Citation

This repository was built on top of the MDLM GitHub repository. Cite our paper using:

@inproceedings{sahoo2025the,
  title={The Diffusion Duality},
  author={Subham Sekhar Sahoo and Justin Deschenaux and Aaron Gokaslan and Guanghan Wang and Justin T Chiu and Volodymyr Kuleshov},
  booktitle={Forty-second International Conference on Machine Learning},
  year={2025},
  url={https://openreview.net/forum?id=9P9Y8FOSOk}
}
