This is the official repository for DeSAM: Decoupled Segment Anything Model for Generalizable Medical Image Segmentation.
**DeSAM: Decoupled Segment Anything Model for Generalizable Medical Image Segmentation**
Yifan Gao, Wei Xia, Dingdu Hu, Wenkui Wang, and Xin Gao
MICCAI 2024
Deep learning-based medical image segmentation models often suffer from domain shift, where the models trained on a source domain do not generalize well to other unseen domains. As a prompt-driven foundation model with powerful generalization capabilities, the Segment Anything Model (SAM) shows potential for improving the cross-domain robustness of medical image segmentation. However, SAM performs significantly worse in automatic segmentation scenarios than when manually prompted, hindering its direct application to domain generalization. Upon further investigation, we discovered that the degradation in performance was related to the coupling effect of inevitable poor prompts and mask generation. To address the coupling effect, we propose the Decoupled SAM (DeSAM). DeSAM modifies SAM’s mask decoder by introducing two new modules: a prompt-relevant IoU module (PRIM) and a prompt-decoupled mask module (PDMM). PRIM predicts the IoU score and generates mask embeddings, while PDMM extracts multi-scale features from the intermediate layers of the image encoder and fuses them with the mask embeddings from PRIM to generate the final segmentation mask. This decoupled design allows DeSAM to leverage the pre-trained weights while minimizing the performance degradation caused by poor prompts. We conducted experiments on publicly available cross-site prostate and cross-modality abdominal image segmentation datasets. The results show that our DeSAM leads to a substantial performance improvement over previous state-of-the-art domain generalization methods.
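To make the decoupling concrete, the sketch below caricatures the two modules in PyTorch. All names, shapes, and the fusion scheme are assumptions for exposition only, not the repository's actual implementation; the point is that PDMM's mask path consumes encoder features directly, so a poor prompt can only perturb the embedding produced by PRIM, not the image features themselves.

```python
# Illustrative sketch of DeSAM's decoupled design (names, shapes, and fusion
# scheme are assumptions for exposition, not the repository's actual modules).
import torch
import torch.nn as nn
import torch.nn.functional as F

class PRIM(nn.Module):
    """Prompt-relevant IoU module: prompt tokens -> IoU score + mask embedding."""
    def __init__(self, dim=256):
        super().__init__()
        self.iou_head = nn.Linear(dim, 1)
        self.embed_head = nn.Conv2d(dim, dim, 1)

    def forward(self, prompt_tokens, decoder_feats):
        iou = self.iou_head(prompt_tokens.mean(dim=1))   # (B, 1) IoU prediction
        mask_embedding = self.embed_head(decoder_feats)  # (B, C, H, W)
        return iou, mask_embedding

class PDMM(nn.Module):
    """Prompt-decoupled mask module: fuses multi-scale encoder features with
    PRIM's mask embedding, so the final mask does not hinge on prompt quality."""
    def __init__(self, dim=256, num_scales=2):
        super().__init__()
        self.proj = nn.ModuleList(nn.Conv2d(dim, dim, 1) for _ in range(num_scales))
        self.head = nn.Conv2d(dim, 1, 1)

    def forward(self, encoder_feats, mask_embedding):
        size = encoder_feats[0].shape[-2:]
        fused = F.interpolate(mask_embedding, size=size, mode="bilinear", align_corners=False)
        for proj, feat in zip(self.proj, encoder_feats):
            fused = fused + proj(F.interpolate(feat, size=size, mode="bilinear", align_corners=False))
        return self.head(fused)  # segmentation logits

if __name__ == "__main__":
    feats = [torch.randn(1, 256, 64, 64), torch.randn(1, 256, 32, 32)]
    iou, emb = PRIM()(torch.randn(1, 5, 256), torch.randn(1, 256, 64, 64))
    print(PDMM()(feats, emb).shape)  # torch.Size([1, 1, 64, 64])
```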
- Create a virtual environment and activate it:

  ```bash
  conda create -n desam python=3.10 -y
  conda activate desam
  ```
- Install [PyTorch](https://pytorch.org/) following the official instructions for your CUDA version.
- Clone the repository:

  ```bash
  git clone https://github.com/yifangao112/DeSAM.git
  ```
- Enter the DeSAM folder and install the requirements:

  ```bash
  cd DeSAM
  pip install -r requirements.txt
  ```
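Optionally, run a quick sanity check that the environment is functional (this is a generic PyTorch check, not part of the repository):

```python
# Verify that PyTorch is installed and whether a CUDA GPU is visible.
import torch
print(torch.__version__, torch.cuda.is_available())
```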
Our files are organized as follows, similar to nnU-Net:
```
work_dir/
├── raw_data/
├── checkpoint/
├── image_embeddings/
└── results_folder/
```
- Download the cross-site prostate dataset from Google Drive, unzip it, and put the files under the `work_dir/raw_data` directory. The data are also hosted on Baidu Netdisk (password: dsam). The original preprocessed data were downloaded from MaxStyle, many thanks!
- Download the [SAM ViT-H checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and place it at `work_dir/checkpoint/sam_vit_h_4b8939.pth`.
- Precompute the image embeddings (about 90 GB; make sure your `work_dir` is on an SSD):

  ```bash
  python precompute_embeddings.py --work_dir your_work_dir
  ```
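Conceptually, this step runs SAM's frozen image encoder once per image and caches the result, so training never repeats the expensive encoder forward pass. A minimal sketch of the idea using the segment-anything API follows; the random image, file paths, and `.npy` format are placeholders, and the repository's `precompute_embeddings.py` is the authoritative version:

```python
# Conceptual sketch of embedding precomputation (the actual logic lives in
# precompute_embeddings.py; paths and the .npy format here are placeholders).
import numpy as np
import torch
from segment_anything import sam_model_registry, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_h"](checkpoint="work_dir/checkpoint/sam_vit_h_4b8939.pth")
sam.to(device)
predictor = SamPredictor(sam)

# Stand-in for a preprocessed slice: SAM expects an HxWx3 uint8 RGB array.
image = (np.random.rand(512, 512, 3) * 255).astype(np.uint8)

with torch.no_grad():
    predictor.set_image(image)                   # runs the ViT-H image encoder
    embedding = predictor.get_image_embedding()  # (1, 256, 64, 64) tensor

np.save("work_dir/image_embeddings/example_slice.npy", embedding.cpu().numpy())
```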
Train DeSAM with whole-box prompts or with grid-point prompts:

```bash
python desam_train_wholebox.py --work_dir your_work_dir --center=1 --pred_embedding=True --mixprecision=True
python desam_train_gridpoints.py --work_dir your_work_dir --center=1 --pred_embedding=True --mixprecision=True
```
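The two scripts correspond to DeSAM's two training-time prompt modes: `wholebox` feeds a single box covering the whole image, and `gridpoints` feeds a regular grid of point prompts. The sketch below only illustrates how such prompts can be constructed; the grid size and coordinates are arbitrary choices, not the values used by the training scripts:

```python
# Illustrative construction of the two prompt types (values are arbitrary;
# the training scripts define the actual prompt generation).
import numpy as np

H, W = 1024, 1024  # SAM's input resolution

# "wholebox": one box spanning the entire image, i.e. a deliberately
# uninformative prompt, in (x1, y1, x2, y2) format.
whole_box = np.array([0, 0, W, H], dtype=np.float32)

# "gridpoints": a regular grid of foreground point prompts.
n = 8  # 8 x 8 grid
xs = np.linspace(0, W - 1, n)
ys = np.linspace(0, H - 1, n)
grid_points = np.stack(np.meshgrid(xs, ys), axis=-1).reshape(-1, 2)  # (64, 2)
point_labels = np.ones(len(grid_points), dtype=np.int32)             # all foreground

print(whole_box, grid_points.shape, point_labels.shape)
```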
This repository is based on MedSAM. We thank Jun Ma for making the source code of MedSAM publicly available. Parts of the code are reused from nnU-Net.
If this code is helpful for your study, please cite our paper:
```bibtex
@inproceedings{gao2024desam,
  title={{DeSAM}: Decoupled segment anything model for generalizable medical image segmentation},
  author={Gao, Yifan and Xia, Wei and Hu, Dingdu and Wang, Wenkui and Gao, Xin},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
  pages={509--519},
  year={2024},
  organization={Springer}
}
```