
[CVPR 2025 Oral] Autoregressive Distillation of Diffusion Transformers
This repository provides a re-implementation of the original work (ARD), reconstructed from the author's recollection.

Yeongmin Kim, Sotiris Anagnostidis, Yuming Du, Edgar Schoenfeld, Jonas Kohler, Markos Georgopoulos, Albert Pumarola, Ali Thabet, Artsiom Sanakoyeu

arXiv 

Overview

We propose AutoRegressive Distillation (ARD), a method that leverages the historical trajectory of diffusion ODEs to mitigate exposure bias and improve efficiency in distillation. ARD achieves strong performance on ImageNet and text-to-image synthesis with significantly fewer steps and minimal computational overhead.
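As a rough, non-authoritative illustration of the idea (the `student` name and its signature below are hypothetical, not taken from this repo): the student produces each next ODE state conditioned on the history of its own previous states, so the inputs it sees at inference can also be the inputs it is exposed to during training.

```python
# Conceptual sketch only; `student` and its call signature are hypothetical.
import torch

def ard_generate(student, x_T, y, num_steps=4):
    """Autoregressive few-step generation: each step conditions on the full history so far."""
    history = [x_T]                              # start from noise x_T
    for k in range(num_steps):
        # student(stacked history, class label, step index) -> next latent state
        x_next = student(torch.stack(history, dim=1), y, k)
        history.append(x_next)
    return history[-1], history                  # final sample and the generated trajectory
```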

Dependencies

The requirements for this code are the same as those of DiT.
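A quick way to confirm the environment, assuming the same stack as the original DiT codebase (PyTorch, torchvision, timm, diffusers):

```python
# Sanity-check the DiT-style dependencies (versions are not pinned here).
import torch, torchvision, timm, diffusers

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("timm:", timm.__version__)
print("diffusers:", diffusers.__version__)
```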

Save ODE trajectories

Save a sufficient number of ODE trajectories with sample_trajectory.py, and make sure they match the dataloader used in the subsequent training procedure (see the --data-path argument in the training script).
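For orientation, a saved trajectory record might look roughly like the sketch below; the actual on-disk format and field names produced by sample_trajectory.py may differ.

```python
# Illustrative only: the real format written by sample_trajectory.py may differ.
import torch

num_states = 5                      # e.g. x_T plus the intermediate/final ODE states (assumed)
latent_shape = (4, 32, 32)          # DiT-XL/2 latent resolution for 256x256 ImageNet

record = {
    "y": torch.tensor(207),                                 # ImageNet class label
    "trajectory": torch.randn(num_states, *latent_shape),   # teacher ODE states, noise -> data
}
torch.save(record, "example_trajectory.pt")
```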

Training

Training takes about 2 days on 8 A100 GPUs.

torchrun --nnodes=1 --nproc_per_node=8 train_ARD.py --model DiT-XL/2 --global-batch-size=64 --stack=6
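The sketch below shows one plausible shape of a distillation step, where the student's rollout is regressed onto the saved teacher trajectory; the actual objective and step weighting in train_ARD.py may differ.

```python
# Illustrative distillation step; not necessarily the exact objective of train_ARD.py.
import torch
import torch.nn.functional as F

def distill_step(student, optimizer, batch):
    teacher_traj, y = batch["trajectory"], batch["y"]   # [B, K+1, C, H, W], [B]
    history = [teacher_traj[:, 0]]                      # start from the teacher's x_T
    loss = 0.0
    for k in range(1, teacher_traj.shape[1]):
        x_pred = student(torch.stack(history, dim=1), y, k)
        loss = loss + F.mse_loss(x_pred, teacher_traj[:, k])
        history.append(x_pred)                          # expose the student to its own outputs
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()
```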

Fine-tuning with GAN loss

Fine-tuning takes a few hours on 8 A100 GPUs. Because this stage also uses real data, set the path to the actual ImageNet dataset with the --real-data-path argument.

torchrun --nnodes=1 --nproc_per_node=8 train_ARD_gan.py --model DiT-XL/2 --global-batch-size=48 --stack=6 --ckpt_path={$PATH}/checkpoints/0300000.pt
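To give a sense of what the adversarial stage adds on top of distillation, here is a generic hinge-GAN formulation; the discriminator architecture, conditioning, and loss weighting in train_ARD_gan.py are not specified here, and the names below are illustrative.

```python
# Generic hinge GAN losses; not necessarily the formulation used in train_ARD_gan.py.
import torch.nn.functional as F

def discriminator_loss(discriminator, real_latents, fake_latents, y):
    real_logits = discriminator(real_latents, y)
    fake_logits = discriminator(fake_latents.detach(), y)
    # Push real scores above +1 and fake scores below -1.
    return F.relu(1.0 - real_logits).mean() + F.relu(1.0 + fake_logits).mean()

def student_loss(discriminator, fake_latents, y, distill_loss, gan_weight=0.1):
    # The student keeps its distillation objective and additionally tries to fool the discriminator.
    return distill_loss + gan_weight * (-discriminator(fake_latents, y).mean())
```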

Generation

torchrun --nnodes=1 --nproc_per_node=1 sample_ARD.py --stack=6 --ckpt_path={$PATH}/checkpoints/0300000.pt
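Generation happens in the DiT latent space. Whether sample_ARD.py decodes latents to images internally is an assumption; if you need to decode them yourself, the usual DiT recipe uses the Stable Diffusion VAE from diffusers:

```python
# Standard DiT-style latent decoding; whether sample_ARD.py already does this is an assumption.
import torch
from diffusers.models import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").eval()
latents = torch.randn(1, 4, 32, 32)                # placeholder for sampled DiT-XL/2 latents
with torch.no_grad():
    images = vae.decode(latents / 0.18215).sample  # 0.18215 is the SD VAE scaling factor used by DiT
```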

Performance

We follow the evaluation protocol of ADM.

| Model | Steps | Latency | FID | Inception Score | Precision | Recall |
|---|---|---|---|---|---|---|
| DiT-XL/2 | 25 | 493.5 | 2.89 | 230.22 | 0.797 | 0.572 |
| Step Distillation (N=0) | 4 | 64.80 | 10.25 | 181.58 | 0.704 | 0.474 |
| ARD (N=6) | 4 | 66.34 | 4.32 | 209.03 | 0.770 | 0.574 |
| + GAN loss finetuning | 4 | 66.34 | 1.84 | 235.84 | 0.797 | 0.615 |


Citation

If you find the code useful for your research, please consider citing

@inproceedings{kim2025autoregressive,
      title={Autoregressive Distillation of Diffusion Transformers}, 
      author={Kim, Yeongmin and Anagnostidis, Sotiris and Du, Yuming and Schönfeld, Edgar and Kohler, Jonas and Georgopoulos, Markos and Pumarola, Albert and Thabet, Ali and Sanakoyeu, Artsiom},
      booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 
      year={2025},
}
