Leveraging Segment Anything Model for Source-Free Domain Adaptation via Dual Feature Guided Auto-Prompting
This repository contains the PyTorch implementation of our source-free domain adaptation (SFDA) method with the Dual Feature Guided (DFG) auto-prompting approach.
Source-free domain adaptation (SFDA) for segmentation aims to adapt a model trained in the source domain so that it performs well in the target domain, using only the source model and unlabeled target data. Inspired by the recent success of the Segment Anything Model (SAM), which can segment images of various modalities and from different domains given human-annotated prompts such as bounding boxes or points, we explore, for the first time, the potential of SAM for SFDA by automatically finding an accurate bounding box prompt. We find that the bounding boxes directly generated with existing SFDA approaches are defective due to the domain gap. To tackle this issue, we propose a novel Dual Feature Guided (DFG) auto-prompting approach to search for the box prompt. Specifically, the source model is first trained in a feature aggregation phase, which not only preliminarily adapts the source model to the target domain but also builds a feature distribution well-prepared for box prompt search. In the second phase, based on two feature distribution observations, we gradually expand the box prompt with the guidance of the target model feature and the SAM feature to handle the class-wise clustered target features and the class-wise dispersed target features, respectively. To remove the potentially enlarged false positive regions caused by the over-confident prediction of the target model, the refined pseudo-labels produced by SAM are further postprocessed based on connectivity analysis. Experiments on 3D and 2D datasets indicate that our approach yields superior performance compared to conventional methods.
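The two generic building blocks mentioned above, a box prompt derived from a mask and a connectivity-based cleanup, can be illustrated with a minimal sketch. It is not the repository's implementation: the function names `bounding_box` and `keep_largest_component` are hypothetical, the mask is a plain 2D list of 0/1 values, and keeping only the largest 4-connected component is an assumed rule, since the description above does not state which connectivity criterion is used. On real volumes a library routine such as `scipy.ndimage.label` would normally replace the hand-written search.

```python
from collections import deque


def bounding_box(mask):
    """Return (row_min, col_min, row_max, col_max) of the foreground, or None if empty."""
    coords = [(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v]
    if not coords:
        return None
    rows = [r for r, _ in coords]
    cols = [c for _, c in coords]
    return min(rows), min(cols), max(rows), max(cols)


def keep_largest_component(mask):
    """Keep only the largest 4-connected foreground component (assumed cleanup rule)."""
    height = len(mask)
    width = len(mask[0]) if height else 0
    seen = [[False] * width for _ in range(height)]
    best = []
    for r in range(height):
        for c in range(width):
            if not mask[r][c] or seen[r][c]:
                continue
            # Breadth-first search collects one connected component.
            component = []
            queue = deque([(r, c)])
            seen[r][c] = True
            while queue:
                y, x = queue.popleft()
                component.append((y, x))
                for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                    inside = 0 <= ny < height and 0 <= nx < width
                    if inside and mask[ny][nx] and not seen[ny][nx]:
                        seen[ny][nx] = True
                        queue.append((ny, nx))
            # Strict comparison: on a tie, the first component found is kept.
            if len(component) > len(best):
                best = component
    cleaned = [[0] * width for _ in range(height)]
    for y, x in best:
        cleaned[y][x] = 1
    return cleaned
```

Applying `keep_largest_component` before `bounding_box` shows why the cleanup matters for prompting: a single isolated false-positive pixel is enough to stretch a tight box across the whole image.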
Create the environment from the `environment.yml` file:
conda env create -f environment.yml
conda activate dfg
- Download the BTCV dataset from the MICCAI 2015 Multi-Atlas Abdomen Labeling Challenge and the CHAOS dataset from the 2019 CHAOS Challenge. Then preprocess the downloaded data following `./preprocess.ipynb`. You can also directly download our preprocessed datasets from here. The paths to the datasets need to be specified in the yaml files in `./configs`.
The following are the steps for the CHAOS (MRI) to BTCV (CT) adaptation.
- Download the source domain model from here, or train it yourself: specify the data path in `configs/train_source_seg.yaml` and then run
python main_trainer_source.py --config_file configs/train_source_seg.yaml
- Download the trained model after the feature aggregation phase from here, or train it yourself: specify the source model path and data path in `configs/train_target_adapt_FA.yaml` and then run
python main_trainer_fa.py --config_file configs/train_target_adapt_FA.yaml
- Download the MedSAM model checkpoint from here and put it under `./medsam/work_dir/MedSAM`.
- Specify the model (after feature aggregation) path, data path, and refined pseudo-label paths in `configs/train_target_adapt_SAM.yaml`, and then run
python main_trainer_sam.py --config_file configs/train_target_adapt_SAM.yaml
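For convenience, the three training commands above can be chained from one small driver script. This is only a sketch and not part of the repository: `build_commands` and `run_pipeline` are hypothetical helpers, the script and config names are copied from the steps above, and the default is a dry run that only prints the commands.

```python
import subprocess
import sys

# (script, config) pairs copied from the steps above, in execution order.
STAGES = [
    ("main_trainer_source.py", "configs/train_source_seg.yaml"),
    ("main_trainer_fa.py", "configs/train_target_adapt_FA.yaml"),
    ("main_trainer_sam.py", "configs/train_target_adapt_SAM.yaml"),
]


def build_commands(stages=STAGES, python=sys.executable):
    """Turn each (script, config) pair into an argument list for subprocess."""
    return [[python, script, "--config_file", config] for script, config in stages]


def run_pipeline(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for command in commands:
        print(" ".join(command))
        if not dry_run:
            # check=True stops the pipeline as soon as one stage fails.
            subprocess.run(command, check=True)


if __name__ == "__main__":
    run_pipeline(build_commands(), dry_run=True)
```

If you downloaded the source model or the feature-aggregation model instead of training it, pass a shortened stage list, for example `build_commands(STAGES[2:])`, so that only the SAM-guided stage runs. The config files still need their paths filled in beforehand.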