All experiments were run on a machine with:
- 1 NVIDIA RTX A6000 GPU
- Python 3.10 and PyTorch (torch==2.0.1)
- Install Anaconda.
- Clone this repository and change into its directory:
```shell
git clone https://github.com/Valeyards/ProgPath.git
cd ProgPath
```
- Create a new environment and install the dependencies:
```shell
conda create -n progpath python=3.10 -y --no-default-packages
conda activate progpath
pip install --upgrade pip
pip install -r requirements.txt
```
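After installation, a quick sanity check can confirm that the environment matches the setup above. The sketch below is not part of the repository: `check_environment` is a hypothetical helper, and its default package list is an assumption drawn from the imports used in this README (`requirements.txt` remains the authoritative list).

```python
import importlib.util
import sys


def check_environment(required=("torch", "timm", "pandas")):
    """Report the Python version and whether each package can be imported.

    The default package list is an assumption taken from the imports used in this
    README; requirements.txt remains the authoritative list.
    """
    report = {"python": tuple(sys.version_info[:2])}
    for name in required:
        # find_spec returns None when the package is not installed
        report[name] = importlib.util.find_spec(name) is not None
    if report.get("torch"):
        import torch  # only imported when it is actually installed

        report["torch_version"] = torch.__version__
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

Inside the `progpath` environment, you would expect `python: (3, 10)`, `torch_version` starting with `2.0.1`, and `cuda_available: True` on a GPU machine.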
The ProgPath model can be accessed here.
Preprocess the slides following CLAM, including foreground tissue segmentation and stitching.
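To make the segmentation-and-patching idea concrete, here is a minimal pure-Python sketch of one common way to turn a tissue mask into patch coordinates. It is an illustration only and does not reproduce CLAM's implementation: `tile_coordinates`, `mask_downsample`, and `min_tissue_fraction` are hypothetical names, and the mask is assumed to be a binary grid at a downsampled resolution.

```python
from typing import List, Sequence, Tuple


def tile_coordinates(
    mask: Sequence[Sequence[int]],
    mask_downsample: int,
    patch_size: int,
    min_tissue_fraction: float = 0.5,
) -> List[Tuple[int, int]]:
    """Return top-left (x, y) level-0 coordinates of non-overlapping patches
    whose footprint on a binary tissue mask is mostly foreground.

    Hypothetical helper for illustration; CLAM's own patching logic differs.
    mask: 2D grid (rows = y, cols = x) of 0/1 values at a downsampled resolution.
    mask_downsample: number of level-0 pixels per mask pixel along each axis.
    """
    height = len(mask) * mask_downsample
    width = len(mask[0]) * mask_downsample
    coords = []
    for y in range(0, height - patch_size + 1, patch_size):
        for x in range(0, width - patch_size + 1, patch_size):
            # mask cells covered by this patch's footprint
            y0, y1 = y // mask_downsample, (y + patch_size - 1) // mask_downsample + 1
            x0, x1 = x // mask_downsample, (x + patch_size - 1) // mask_downsample + 1
            cells = [mask[r][c] for r in range(y0, y1) for c in range(x0, x1)]
            # keep the patch only if enough of it lies on tissue
            if sum(cells) / len(cells) >= min_tissue_fraction:
                coords.append((x, y))
    return coords
```

For example, with the mask `[[1, 0], [1, 1]]` at a downsample of 256 and 256-pixel patches, the sketch keeps `(0, 0)`, `(0, 256)`, and `(256, 256)` and drops the background patch at `(256, 0)`.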
- Download the pretrained Virchow2 model weights, put them in ./weights/, and load the model:
```python
import timm
import torch
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from timm.layers import SwiGLUPacked
from PIL import Image

# need to specify MLP layer and activation function for proper init
model = timm.create_model("hf-hub:paige-ai/Virchow2", pretrained=True, mlp_layer=SwiGLUPacked, act_layer=torch.nn.SiLU)
model = model.eval()

transforms = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
```
- Use Virchow2 to extract image embeddings:
```python
image = Image.open("/path/to/your/image.png").convert("RGB")  # drop any alpha channel
image = transforms(image).unsqueeze(0)  # size: 1 x 3 x 224 x 224
with torch.inference_mode():  # no gradients needed for feature extraction
    output = model(image)  # size: 1 x 261 x 1280
class_token = output[:, 0]  # size: 1 x 1280
patch_tokens = output[:, 5:]  # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those
# concatenate class token and average pool of patch tokens
embedding = torch.cat([class_token, patch_tokens.mean(1)], dim=-1)  # size: 1 x 2560
```
Please refer to demo.ipynb for a demonstration.
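The tensor sizes in the Virchow2 snippet follow from simple token bookkeeping. The sketch below reproduces that arithmetic. It assumes a 14-pixel ViT patch size (Virchow2 is a ViT-H/14), which is what makes the 261-token count work out. The 4 register tokens and the 1280 width come from the comments in the snippet, and `virchow2_token_layout` is a hypothetical helper name.

```python
def virchow2_token_layout(image_size=224, patch_size=14, num_registers=4, width=1280):
    """Token bookkeeping for a ViT with one class token plus register tokens.

    patch_size=14 is an assumption (ViT-H/14); the other defaults mirror the
    shapes quoted in the Virchow2 snippet.
    """
    grid = image_size // patch_size                       # 224 // 14 = 16 patches per side
    num_patch_tokens = grid * grid                        # 16 * 16 = 256
    total_tokens = 1 + num_registers + num_patch_tokens   # 1 + 4 + 256 = 261
    patch_slice = slice(1 + num_registers, total_tokens)  # same as output[:, 5:]
    embedding_dim = 2 * width                             # class token ++ mean patch token = 2560
    return total_tokens, patch_slice, embedding_dim


print(virchow2_token_layout())  # (261, slice(5, 261, None), 2560)
```

This is why the patch tokens are taken as `output[:, 5:]` and why the final tile embedding has 2560 dimensions.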
- Load the ProgPath model:
```python
from utils.utils import read_yaml
from utils.model_factory import load_model
import torch

cfg = read_yaml('projects/configs/cfg_progpath.yaml')
model = load_model(cfg)
model.load_state_dict(torch.load('weights/progpath.pt'), strict=True)
```
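With `strict=True`, `load_state_dict` raises if the checkpoint has missing or unexpected keys. When that happens, it helps to see exactly which keys differ. The helpers below are a hypothetical, framework-free sketch of that comparison. The `module.` prefix stripping assumes a checkpoint saved from a `torch.nn.DataParallel` wrapper, which may not apply to `progpath.pt`.

```python
from typing import Dict, Iterable, List, Tuple


def diff_state_dict_keys(model_keys: Iterable[str], checkpoint_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys: the two lists that make
    load_state_dict(..., strict=True) raise when either is non-empty."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent in the file
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the file, unknown to the model
    return missing, unexpected


def strip_prefix(state_dict: Dict[str, object], prefix: str = "module.") -> Dict[str, object]:
    """Drop a leading prefix (e.g. added by torch.nn.DataParallel) from every key."""
    return {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in state_dict.items()}
```

Usage with the objects from the snippet above might look like `state = torch.load('weights/progpath.pt', map_location='cpu')` followed by `diff_state_dict_keys(model.state_dict().keys(), state.keys())`. The `map_location='cpu'` argument also lets the checkpoint load on a machine without a GPU.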
- Load the image features and clinical data:
```python
import pandas as pd
import random
import torch
from datasets.SurvivalBagDataset import process_clinical

def encode_cancer_type(cancer_type):
    # one-hot encoding over the 15 supported cancer types; unknown types map to all zeros
    cancer_types = ['blca', 'brca', 'cesc', 'crc', 'gbm', 'hnsc', 'rcc', 'lgg', 'lihc', 'luad', 'lusc', 'paad', 'skcm', 'stad', 'ucec']
    encoding = [0] * len(cancer_types)
    if cancer_type in cancer_types:
        encoding[cancer_types.index(cancer_type)] = 1
    return encoding

# randomly pick one patient from the SR-CRC cohort
random.seed(1)
patient_csv = pd.read_csv('csvs/sr_crc.csv')
patient_data = random.sample(list(patient_csv['patient_id']), 1)
patient_data = patient_csv[patient_csv['patient_id'] == patient_data[0]]

# pre-extracted patch features for this patient's slide
patch_features = torch.load(f'features/{patient_data["filename"].values[0]}')

# clinical covariates: age, sex, and stage
clinical_feature = process_clinical(patient_data, columns=['age', 'sex', 'stage'])['processed_df'].drop(columns=['filename']).values
clinical_feature = torch.from_numpy(clinical_feature).float()

cancer_type = 'crc'
cancer_encoding = torch.tensor(encode_cancer_type(cancer_type)).float()
```
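`encode_cancer_type` silently returns an all-zero vector for any unrecognised string, so a typo such as `'CRC'` or `'coad'` would go unnoticed. The variant below is a hypothetical stricter encoder with a matching decoder. It keeps the same 15-type ordering but normalises case and raises on unknown types. Whether the model is ever meant to receive an all-zero encoding is not stated here, so treat this as an optional safeguard, not the repository's behaviour.

```python
CANCER_TYPES = ('blca', 'brca', 'cesc', 'crc', 'gbm', 'hnsc', 'rcc', 'lgg',
                'lihc', 'luad', 'lusc', 'paad', 'skcm', 'stad', 'ucec')


def encode_cancer_type_strict(cancer_type):
    """One-hot encode like encode_cancer_type, but reject unknown types
    instead of silently returning an all-zero vector."""
    key = cancer_type.strip().lower()  # assumption: case/whitespace are not meaningful
    if key not in CANCER_TYPES:
        raise ValueError(f"unknown cancer type {cancer_type!r}; expected one of {CANCER_TYPES}")
    return [int(t == key) for t in CANCER_TYPES]


def decode_cancer_type(encoding):
    """Inverse mapping: recover the cancer type from a one-hot vector."""
    values = list(encoding)
    if len(values) != len(CANCER_TYPES) or sorted(values) != [0] * (len(CANCER_TYPES) - 1) + [1]:
        raise ValueError(f"expected a one-hot vector of length {len(CANCER_TYPES)}")
    return CANCER_TYPES[values.index(1)]
```

For example, `encode_cancer_type_strict('CRC')` places the single `1` at index 3, and `decode_cancer_type` maps it back to `'crc'`.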
- Predict the patient's risk:
```python
model.eval()
model.to('cuda')
patch_features = patch_features.to('cuda')
with torch.no_grad():  # inference only, no gradients needed
    res = model(h=patch_features, clinical_feature=clinical_feature.to('cuda'), cancer_encoding=cancer_encoding.to('cuda'))
risk = res['bag_logits'][0][1:]
print('patient id:', patient_data['patient_id'].values[0], 'risk:', risk.item())
```
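A single risk score becomes interpretable once it is compared with the rest of a cohort. Kaplan-Meier curves and log-rank tests like the ones produced below compare risk groups. The sketch splits patients into high- and low-risk groups at the cohort median. The median cutoff and the `stratify_by_median` name are assumptions for illustration. The cutoff actually used by `eval.py` is not documented here.

```python
from statistics import median


def stratify_by_median(risks):
    """Split patients into 'high' / 'low' risk groups at the cohort median.

    risks: dict mapping patient id -> predicted risk score.
    The median cutoff is an illustrative assumption, not necessarily eval.py's rule.
    """
    cutoff = median(risks.values())
    # patients exactly at the cutoff fall into the low-risk group
    groups = {pid: ('high' if r > cutoff else 'low') for pid, r in risks.items()}
    return cutoff, groups
```

With `{"p1": 0.1, "p2": 0.5, "p3": 0.9}`, the cutoff is 0.5, so only `p3` is labelled high risk.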
To reproduce the results reported in our paper, we provide a reproducible evaluation on the SR-CRC dataset.
Please refer to demo.ipynb for a demonstration.
- First, download our processed SR-CRC frozen features here.
- Put the extracted features in ./features/.
- Run the following command:
```shell
python3 eval.py --config_path projects/configs/cfg_progpath.yaml
```
The C-index and log-rank p-value will be printed to the screen:
```
sr_crc cindex_now: 0.7765380443204711 pvalue: 2.496161383249256e-10
```
The computed risk scores for this cohort and the corresponding Kaplan-Meier curve will be stored at exp_progpath/evaluation/sr_crc/
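The C-index above measures how well the predicted risks rank patients by survival time. The pure-Python sketch below shows Harrell's C-index for right-censored data. It is meant to explain the metric, not to replicate `eval.py`, which may rely on a library implementation with different handling of tied survival times. Here pairs with tied times are skipped, and tied risk scores count as half-concordant.

```python
from itertools import combinations


def harrell_c_index(times, events, risks):
    """Harrell's concordance index for right-censored survival data.

    A pair (i, j) is comparable when the patient with the shorter time had an
    event (events[k] is truthy). It is concordant when that patient also has the
    higher risk score; tied risk scores count as 0.5. Tied times are skipped.
    """
    concordant = 0.0
    comparable = 0
    for i, j in combinations(range(len(times)), 2):
        if times[i] == times[j]:
            continue  # simplification: ignore pairs with tied survival times
        first, second = (i, j) if times[i] < times[j] else (j, i)
        if not events[first]:
            continue  # earlier time is censored, so the true ordering is unknown
        comparable += 1
        if risks[first] > risks[second]:
            concordant += 1.0
        elif risks[first] == risks[second]:
            concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable
```

A value of 1.0 means that, among comparable pairs, the patient who had the event earlier was always assigned the higher risk. A value of 0.5 corresponds to random ranking.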
This project was built on several excellent repositories: Virchow, CLAM, and PORPOISE. We thank the authors and developers for their contributions.
Please open a new issue or send questions to yuanw@stu.scu.edu.cn or xiyue.wang.scu@gmail.com.
ProgPath is released under the GPLv3 License and is available for non-commercial academic purposes.