Official implementation of the paper Generalizing teacher networks for effective knowledge distillation across student architectures (BMVC'24)
Authors: Kuluhan Binici, Weiming Wu, Tulika Mitra
[Paper]
Run the Makefile to generate the config.py file:
make config
Edit config.py to set the correct paths for your data and model folders:
DATA_ROOT = # PATH/TO/DATA/FOLDER
MODEL_ROOT = # PATH/TO/MODELS/FOLDER
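For example, a filled-in config.py might look like the following; the paths are purely illustrative placeholders, not defaults shipped with the repository:
# Example config.py -- the paths below are hypothetical; point them at your own folders
DATA_ROOT = "/home/user/data"      # folder containing the datasets
MODEL_ROOT = "/home/user/models"   # folder where pre-trained checkpoints are stored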
You can download pre-trained teacher model checkpoints from this Google Drive link.
To train any network on any dataset, use the following command:
python train-model-no-KD.py --model $NETWORK_NAME --dataset $DATASET_NAME
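For instance, a concrete invocation might look like the line below; the model and dataset names here are illustrative assumptions, so check the script's argument parser for the values it actually accepts:
# Hypothetical example: train a ResNet-style network on CIFAR-100 without distillation
python train-model-no-KD.py --model resnet34 --dataset cifar100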
Alternatively, you can use the provided bash script:
bash scripts/no-KD.sh
Trained models will be saved in the checkpoints/ directory.
KD-aware teacher training can be done by running the KD-aware-teacher-training.py script.
To train a GTN teacher (our method), use the --student supernet argument:
python KD-aware-teacher-training.py --student supernet
Or simply use the pre-made bash script:
bash scripts/gtn.sh
To train an SFTN teacher, use the --student isolated_normal argument:
python KD-aware-teacher-training.py --student isolated_normal
Or run the provided bash script:
bash scripts/sftn-teacher.sh
The resulting teacher model checkpoints will be saved in the checkpoints/ directory.
To distill student models using pre-trained teacher models, run the distill-student.py script with the --kdtrain $KD_METHOD argument, where $KD_METHOD is the name of the distillation method. Options include:
DKD
SFTN
SCKD
vanilla
Example:
python distill-student.py --kdtrain DKD
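To sweep all four methods with the documented flag, a simple shell loop such as the sketch below may be convenient; it assumes distill-student.py needs no further arguments, whereas the provided scripts likely set additional options as well:
# Hypothetical sweep over the supported KD methods
for method in DKD SFTN SCKD vanilla; do
    python distill-student.py --kdtrain "$method"
done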
Alternatively, you can use the provided bash scripts located in the scripts/ directory:
scripts/DKD.sh
scripts/SFTN-kd.sh
scripts/SCKD.sh
scripts/vanilla-kd.sh
The resulting student model checkpoints will be saved in the checkpoints/ directory.
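As a rough end-to-end recipe, the steps above can be chained in one shell script. The sketch below uses only the flags documented in this README and assumes every other setting falls back to the scripts' defaults:
#!/usr/bin/env bash
set -e  # abort on the first failing step

# 1. Train a KD-aware (GTN) teacher
python KD-aware-teacher-training.py --student supernet

# 2. Distill a student from the saved teacher checkpoint using DKD
python distill-student.py --kdtrain DKD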