TSD-SR: One-Step Diffusion with Target Score Distillation for Real-World Image Super-Resolution [PyTorch]
- [2025.03] Training code is released.
- [2025.01] Released TSD-SR, including the inference code and pretrained models.
- [2024.12] This repo was created.
🤗 If TSD-SR is helpful to your projects, please help star this repo. Thanks! 🤗
```bash
# git clone this repository
git clone https://github.com/Microtreei/TSD-SR.git
cd TSD-SR

# create an environment
conda create -n tsdsr python=3.9
conda activate tsdsr
pip install -r requirements.txt
```
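After installation, a quick sanity check like the following (a minimal sketch; requirements.txt pins the actual versions) confirms that PyTorch and diffusers import cleanly and CUDA is visible:

```python
# Quick post-install sanity check (a sketch; requirements.txt pins the real versions).
import torch
import diffusers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__)
```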
- Download the pretrained SD3 models from HuggingFace.
- Download the TSD-SR LoRA weights and prompt embeddings from GoogleDrive or OneDrive.
  - Put the model weights into `checkpoint/tsdsr`.
  - Put the prompt embeddings into `dataset/default`.
  - Put the testing images into `imgs/test`.
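Before running inference, you can verify the expected layout with a small script (a sketch; only the directory names above are assumed):

```python
# Sanity-check the assumed layout before running inference.
from pathlib import Path

for d in ("checkpoint/tsdsr", "dataset/default", "imgs/test"):
    p = Path(d)
    ok = p.is_dir() and any(p.iterdir())
    print(f"{d}: {'ok' if ok else 'MISSING or empty'}")
```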
```bash
python test/test_tsdsr.py \
--pretrained_model_name_or_path /path/to/your/sd3 \
-i imgs/test \
-o outputs/test \
--lora_dir checkpoint/tsdsr \
--embedding_dir dataset/default
```
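For intuition only, here is a heavily simplified sketch of what one-step SR inference looks like conceptually: encode the LR image into SD3's latent space, run a single pass of the LoRA-adapted transformer, and decode. The call signature, timestep handling, and update rule are assumptions; `test/test_tsdsr.py` is the authoritative implementation.

```python
# Conceptual sketch ONLY -- not the repo's pipeline; the timestep handling and the
# one-step update are assumptions for illustration.
import torch

@torch.no_grad()
def one_step_sr(vae, transformer, lr_image, prompt_embeds, pooled_embeds):
    # Encode the LR image into SD3's 16-channel latent space
    z = vae.encode(lr_image * 2 - 1).latent_dist.sample()
    z = (z - vae.config.shift_factor) * vae.config.scaling_factor
    # One forward pass of the LoRA-adapted transformer predicts the restoration direction
    t = torch.full((z.shape[0],), 1000.0, device=z.device)  # assumed fixed timestep
    v = transformer(z, t, prompt_embeds, pooled_embeds)     # hypothetical call signature
    z_sr = z - v                                            # assumed one-step velocity update
    # Decode back to image space
    img = vae.decode(z_sr / vae.config.scaling_factor + vae.config.shift_factor).sample
    return (img / 2 + 0.5).clamp(0, 1)
```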
- Download StableSR testsets (DrealSRVal_crop128, RealSRVal_crop128, DIV2K_V2_val) from GoogleDrive or OneDrive. We sincerely thank the authors of StableSR for their well-curated test dataset.
- Unzip them into `imgs/StableSR_testsets/`; the data folder should look like this:
```
├── imgs
│   └── StableSR_testsets
│       ├── DIV2K_V2_val
│       │   ├── test_LR
│       │   └── test_HR
│       ├── DrealSRVal_crop128
│       │   ├── test_LR
│       │   └── test_HR
│       └── RealSRVal_crop128
│           ├── test_LR
│           └── test_HR
```
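As a quick sanity check after unzipping (a minimal sketch, assuming LR and HR images share filenames), you can confirm each LR image has a matching HR ground truth:

```python
# Confirm each LR image has a matching HR ground truth (assumes shared filenames).
from pathlib import Path

root = Path("imgs/StableSR_testsets")
for split in ("DIV2K_V2_val", "DrealSRVal_crop128", "RealSRVal_crop128"):
    lr = {p.name for p in (root / split / "test_LR").iterdir()}
    hr = {p.name for p in (root / split / "test_HR").iterdir()}
    print(f"{split}: {len(lr)} LR / {len(hr)} HR, unmatched: {len(lr ^ hr)}")
```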
- Download the TSD-SR LoRA weights `checkpoint/tsdsr-mse` from GoogleDrive or OneDrive. We employ this model for evaluation and set `--align_method` to `adain` (a sketch of this alignment follows below).
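For intuition, here is a minimal sketch of what AdaIN-style color alignment typically does; treating `--align_method adain` this way is an assumption, and the repo's implementation is authoritative.

```python
# A sketch of AdaIN-style color alignment (an assumption about --align_method adain):
# match the per-channel mean/std of the SR output to those of the LR input.
import torch

def adain_color_fix(sr: torch.Tensor, lr: torch.Tensor) -> torch.Tensor:
    # sr, lr: (C, H, W) tensors in [0, 1]; lr is assumed resized to sr's resolution
    sr_mean, sr_std = sr.mean(dim=(1, 2), keepdim=True), sr.std(dim=(1, 2), keepdim=True)
    lr_mean, lr_std = lr.mean(dim=(1, 2), keepdim=True), lr.std(dim=(1, 2), keepdim=True)
    return ((sr - sr_mean) / (sr_std + 1e-5)) * lr_std + lr_mean
```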
Take DrealSRVal_crop128 as an example:
```bash
python test/test_tsdsr.py \
--pretrained_model_name_or_path /path/to/your/sd3 \
-i imgs/StableSR_testsets/DrealSRVal_crop128/test_LR \
-o outputs/DrealSR \
--lora_dir checkpoint/tsdsr-mse \
--embedding_dir dataset/default \
--align_method adain
```
```bash
python test/test_metrics.py \
--inp_imgs outputs/DrealSR \
--gt_imgs imgs/StableSR_testsets/DrealSRVal_crop128/test_HR \
--log logs/metrics
```
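To spot-check individual image pairs outside `test/test_metrics.py`, a minimal sketch using IQA-PyTorch (`pyiqa`) is below; the package choice and the sample filenames are assumptions.

```python
# Spot-check a single pair with IQA-PyTorch (pyiqa); filenames are hypothetical.
# pip install pyiqa
import torch
import pyiqa

device = "cuda" if torch.cuda.is_available() else "cpu"
psnr = pyiqa.create_metric("psnr", device=device)
lpips = pyiqa.create_metric("lpips", device=device)

# pyiqa metrics accept image paths; full-reference metrics take (distorted, reference)
sr = "outputs/DrealSR/0001.png"  # hypothetical filename
hr = "imgs/StableSR_testsets/DrealSRVal_crop128/test_HR/0001.png"  # hypothetical filename
print("PSNR:", psnr(sr, hr).item(), "LPIPS:", lpips(sr, hr).item())
```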
- Generate degraded images: We employ the same degradation pipeline as SeeSR. More details can be found here. Thanks for this awesome work. You can put the degraded images in `your_training_datasets/lr_bicubic` (a placeholder sketch follows the directory tree below).
- Prompt texts: You may use either tag-based prompts (generated via DAPE, download from here) or natural language descriptions (produced by the Large Language and Vision Assistant, LLaVA) as high-quality (HQ) image prompts. Thanks for these awesome works. You can put the prompt texts in `your_training_datasets/prompt_txt`.
- Modify the dataset paths in `data/data.py` and `data/process.py`, and ensure the following directory structure:
```
your_training_datasets/          # Example: FLICKR2K/
├── gt
│   ├── 0000001.png              # GT images, (3, 512, 512)
│   └── ...
├── lr_bicubic
│   ├── 0000001.png              # Bicubic LR images, (3, 512, 512)
│   └── ...
└── prompt_txt
    ├── 0000001.txt              # prompts for teacher model and lora model
    └── ...
```
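If you just want a quick dry run before wiring up SeeSR's full degradation pipeline, the placeholder LR images mentioned above can be generated with plain bicubic resampling (a sketch only; the paper's actual degradations are heavier):

```python
# Placeholder degradation for a quick dry run; the paper actually uses SeeSR's
# Real-ESRGAN-style pipeline, which is much heavier than plain bicubic.
from pathlib import Path
from PIL import Image

root = Path("your_training_datasets")
(root / "lr_bicubic").mkdir(exist_ok=True)
for gt_path in sorted((root / "gt").glob("*.png")):
    img = Image.open(gt_path).convert("RGB")       # 512x512 GT
    lr = img.resize((128, 128), Image.BICUBIC)     # 4x downscale
    lr = lr.resize(img.size, Image.BICUBIC)        # upsample back to GT size
    lr.save(root / "lr_bicubic" / gt_path.name)
```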
- Modify the SD3 path in `data/process.py` and run it to generate the training data:

```bash
python data/process.py
```
This step is designed to reduce GPU memory overhead during training. Since SD3's tokenizer and text encoders require significant parameter storage, we pre-compute their outputs and load them directly during training. Similarly, we also pre-compute the HR (high-resolution) latent tensors to optimize memory efficiency (a hedged sketch of this precomputation follows the tree below). The final data folder will look like this:
```
your_training_datasets/          # Example: FLICKR2K/
├── gt
│   ├── 0000001.png              # GT images, (3, 512, 512)
│   └── ...
├── lr_bicubic
│   ├── 0000001.png              # Bicubic LR images, (3, 512, 512)
│   └── ...
├── prompt_txt
│   ├── 0000001.txt              # prompts for teacher model and lora model
│   └── ...
├── prompt_embeds
│   ├── 0000001.pt               # SD3 prompt embedding tensors, (333, 4096)
│   └── ...
├── pool_embeds
│   ├── 0000001.pt               # SD3 pooled embedding tensors, (2048,)
│   └── ...
└── latent_hr
    ├── 0000001.pt               # SD3 latent space tensors, (16, 64, 64)
    └── ...
```
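For intuition, here is a minimal sketch of the precomputation that `data/process.py` performs, assuming diffusers' `StableDiffusion3Pipeline` API. The shapes match the tree above, but the repo's script is authoritative.

```python
# A sketch of the precomputation step, assuming diffusers' SD3 API (NOT the repo's actual script).
import torch
from pathlib import Path
from PIL import Image
from torchvision.transforms.functional import to_tensor
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "/path/to/your/sd3", torch_dtype=torch.float16
).to("cuda")

root = Path("your_training_datasets")
for sub in ("prompt_embeds", "pool_embeds", "latent_hr"):
    (root / sub).mkdir(exist_ok=True)

for txt in sorted((root / "prompt_txt").glob("*.txt")):
    prompt = txt.read_text().strip()
    with torch.no_grad():
        # CLIP-L/G + T5 embeddings concatenate to (333, 4096); pooled CLIP embeds are (2048,)
        prompt_embeds, _, pooled_embeds, _ = pipe.encode_prompt(
            prompt=prompt, prompt_2=prompt, prompt_3=prompt,
            device="cuda", do_classifier_free_guidance=False,
        )
        # Encode the matching 512x512 GT image to a (16, 64, 64) SD3 latent
        img = to_tensor(Image.open(root / "gt" / f"{txt.stem}.png").convert("RGB"))
        img = (img.unsqueeze(0) * 2 - 1).to("cuda", torch.float16)
        latent = pipe.vae.encode(img).latent_dist.sample()
        latent = (latent - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
    torch.save(prompt_embeds[0].cpu(), root / "prompt_embeds" / f"{txt.stem}.pt")
    torch.save(pooled_embeds[0].cpu(), root / "pool_embeds" / f"{txt.stem}.pt")
    torch.save(latent[0].cpu(), root / "latent_hr" / f"{txt.stem}.pt")
```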
- Download the teacher models and put them in `checkpoint/teacher/`.
- Download the null prompt embeddings and put them in `dataset/null/`.
We derive the teacher LoRA weights by fine-tuning the pretrained SD3 model on HR data with the diffusion loss; this is designed to enhance the teacher model's sensitivity to HQ data. Pass `--use_teacher_lora` to enable the LoRA weights; otherwise, the text-to-image SD3 model is used as the teacher by default.

The null prompt embeddings are used to compute classifier-free guidance (CFG) during training (a sketch follows below).
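A minimal sketch of how the null embeddings enter CFG; the variable names, the `transformer` stand-in, and the `dataset/null/` filenames are assumptions.

```python
# A sketch of CFG with precomputed null embeddings; names and filenames are assumptions.
import torch

guidance_scale = 7.5  # matches --guidance_scale in the training command below

null_prompt_embeds = torch.load("dataset/null/prompt_embeds.pt")  # assumed filename
null_pooled_embeds = torch.load("dataset/null/pool_embeds.pt")    # assumed filename

def cfg_prediction(transformer, noisy_latents, timestep, prompt_embeds, pooled_embeds):
    # `transformer` stands in for the teacher SD3 transformer used during distillation
    cond = transformer(noisy_latents, timestep, prompt_embeds, pooled_embeds)
    uncond = transformer(noisy_latents, timestep, null_prompt_embeds, null_pooled_embeds)
    # Standard classifier-free guidance: extrapolate away from the unconditional prediction
    return uncond + guidance_scale * (cond - uncond)
```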
```bash
export MODEL_NAME="/path/to/your/sd3_model";
export TEACHER_MODEL_NAME="checkpoint/teacher/";
export CHECKPOINT_PATH="checkpoint/tsdsr";
export HF_ENDPOINT="https://hf-mirror.com";
export OUTPUT_DIR="checkpoint/tsdsr-save/";
export OUTPUT_LOG="logs/tsdsr.log";
export LOG_NAME="tsdsr-train";
nohup accelerate launch --config_file config/config.yaml --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --main_process_port 57079 --mixed_precision="fp16" train/train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--teacher_lora_path=$TEACHER_MODEL_NAME \
--train_batch_size=2 --rank=64 --rank_vae=64 --rank_lora=64 \
--num_train_epochs=200 --checkpointing_steps=5000 --validation_steps=500 --max_train_steps=200000 \
--learning_rate=5e-06 --learning_rate_reg=1e-06 --lr_scheduler="cosine_with_restarts" --lr_warmup_steps=3000 \
--seed=43 --use_default_prompt --use_teacher_lora --use_random_bias \
--output_dir=$OUTPUT_DIR \
--report_to="wandb" --log_code --log_name=$LOG_NAME \
--gradient_accumulation_steps=1 \
--resume_from_checkpoint="latest" \
--guidance_scale=7.5 > $OUTPUT_LOG 2>&1 &
```
Quantitative comparison with state-of-the-art one-step methods across both synthetic and real-world benchmarks.

Quantitative comparison with state-of-the-art multi-step methods across both synthetic and real-world benchmarks.
This project is released under the Apache 2.0 license.
```bibtex
@article{dong2024tsd,
  title={TSD-SR: One-Step Diffusion with Target Score Distillation for Real-World Image Super-Resolution},
  author={Dong, Linwei and Fan, Qingnan and Guo, Yihong and Wang, Zhonghao and Zhang, Qi and Chen, Jinwei and Luo, Yawei and Zou, Changqing},
  journal={arXiv preprint arXiv:2411.18263},
  year={2024}
}
```