This repository contains code to train a VQ-GAN model on the CelebA dataset, and then train a GPT-style Transformer on the learned discrete latent space for image generation. Multi-GPU training via PyTorch DistributedDataParallel (DDP) is supported.
The project is inspired by the "Taming Transformers for High-Resolution Image Synthesis" paper, but streamlined to work with a single dataset (CelebA) for simplicity.
Trains a Vector Quantized Generative Adversarial Network (VQ-GAN) on CelebA.
Includes perceptual (LPIPS), adversarial, and commitment losses; see the sketch below.
After training, you get a discrete codebook and a decoder to reconstruct images.
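A minimal sketch of how these losses are typically combined in a VQ-GAN generator step. The names (`lpips`, `discriminator`, `codebook_loss`) and the weighting are illustrative assumptions, not the exact implementation in `src/train_vqgan.py`:

```python
import torch

def vqgan_generator_loss(x, x_rec, codebook_loss, discriminator, lpips, disc_weight=0.8):
    """Illustrative VQ-GAN generator objective: reconstruction + perceptual
    (LPIPS) + codebook/commitment + adversarial terms."""
    rec_loss = torch.abs(x - x_rec).mean()     # L1 reconstruction loss
    perceptual = lpips(x, x_rec).mean()        # LPIPS perceptual distance
    logits_fake = discriminator(x_rec)         # discriminator scores on reconstructions
    g_loss = -logits_fake.mean()               # generator tries to raise fake logits
    return rec_loss + perceptual + codebook_loss + disc_weight * g_loss
```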
Trains a Transformer to model sequences of VQ-GAN codebook indices.
At inference time, the Transformer generates latent codes, and the VQ-GAN decoder turns them into full images.
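In stage 2, each image is encoded into a grid of codebook indices, flattened into a sequence, and the Transformer is trained with a standard next-token cross-entropy objective. A hedged sketch of that training step (the `vqgan.encode` return signature and the start-of-sequence handling are assumptions about this repo's API):

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def encode_to_indices(vqgan, images):
    """Encode images to a flattened sequence of codebook indices (assumed API)."""
    _, _, indices = vqgan.encode(images)       # hypothetical: returns latents + indices
    return indices.view(images.size(0), -1)    # (B, H*W) token sequence

def transformer_step(transformer, vqgan, images, sos_token=0):
    indices = encode_to_indices(vqgan, images)
    sos = torch.full((indices.size(0), 1), sos_token,
                     dtype=torch.long, device=indices.device)
    inp = torch.cat([sos, indices[:, :-1]], dim=1)   # shift right for next-token prediction
    logits = transformer(inp)                        # (B, T, vocab_size)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), indices.reshape(-1))
```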
git clone <your-repo-url>
cd <repo-name>
pip install -r requirements.txt
python src/train_vqgan.py data_path=<path-to-celeba-dataset> log_dir=<where-to-store-logs>
I trained this model for ~9 hours (52 epochs) using 2x NVIDIA RTX 4090 GPUs.
torchrun --standalone --nnodes=1 --nproc_per_node=2 src/train_vqgan.py data_path=<path-to-celeba-dataset> log_dir=<where-to-store-logs> batch_size=24 num_workers=8 lr=4.2e-5 disc_start=630 ddp=True
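With `ddp=True`, the training script presumably wraps the model in PyTorch's DistributedDataParallel. A minimal sketch of that standard pattern under a `torchrun` launch (not the repo's exact code):

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp(model):
    """Standard DDP setup; torchrun sets LOCAL_RANK and the rendezvous env vars."""
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return DDP(model.to(local_rank), device_ids=[local_rank])
```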
python src/train_transformer.py data_path=<path-to-celeba-dataset> log_dir=<where-to-store-logs> vqgan_weights=<path-to-vqgan-checkpoint-from-step-1>
I trained this model for ~9 hours (47 epochs) using 2x NVIDIA RTX 4090 GPUs.
torchrun --standalone --nnodes=1 --nproc_per_node=2 src/train_transformer.py data_path=<path-to-celeba-dataset> log_dir=<where-to-store-logs> vqgan_weights=<path-to-vqgan-checkpoint-from-step-1> batch_size=64 num_workers=8 lr=1.4e-4 ddp=True
After both models are trained:
python scripts/generate_images.py
This script samples new images by generating discrete codes with the Transformer and decoding them with the VQ-GAN decoder. See configs/generate_images.yaml for configuring the sampling process.
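Conceptually, sampling follows the loop below. This is a hedged sketch: the output shape of `transformer`, the `decode_indices` method, and the parameter names are assumptions about the repo's API, not its actual interface:

```python
import torch

@torch.no_grad()
def sample_images(transformer, vqgan, num_tokens=256, sos_token=0, temperature=1.0):
    """Autoregressively sample codebook indices, then decode them into an image."""
    seq = torch.full((1, 1), sos_token, dtype=torch.long, device="cuda")
    for _ in range(num_tokens):
        logits = transformer(seq)[:, -1, :] / temperature   # next-token logits
        probs = torch.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        seq = torch.cat([seq, next_token], dim=1)
    return vqgan.decode_indices(seq[:, 1:])                 # hypothetical decode API
```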
This project is simplified for research/educational purposes and focuses only on the CelebA dataset.