
Snuffy: Efficient Whole Slide Image Classifier For Efficient and Performant Diagnosis in Pathology Whole Slide Images



Snuffy: Efficient Whole Slide Image Classifier


Hossein Jafarinia, Alireza Alipanah, Danial Hamdi, Saeed Razavi, Nahal Mirzaie, Mohammad Hossein Rohban

[arXiv] [Project Page] [Demo] [BibTex]

PyTorch implementation for the Multiple Instance Learning framework described in the paper Snuffy: Efficient Whole Slide Image Classifier (ECCV 2024, accepted).

Snuffy is a novel MIL-pooling method based on sparse transformers, designed to address the computational challenges in Whole Slide Image (WSI) classification for digital pathology. Our approach mitigates performance loss with limited pre-training and enables continual few-shot pre-training as a competitive option.

Key features:

  • Tailored sparsity pattern for pathology
  • Theoretically proven to be a universal approximator, with sharp probabilistic bounds
  • Superior WSI and patch-level accuracies on CAMELYON16 and TCGA Lung cancer datasets


This repository provides a complete, runnable implementation of the Snuffy framework. It includes code for the FROC metric, which, to the best of our knowledge, no other WSI classification framework provides.
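For context, the CAMELYON16 challenge defines its FROC score as the average sensitivity at six predefined false-positive rates: 1/4, 1/2, 1, 2, 4, and 8 false positives per slide. The sketch below covers only that final averaging step and assumes you already have FROC curve points (average false positives per slide, sensitivity). It linearly interpolates between curve points. This is a simplifying assumption, and the repository's evaluation code may resolve rates that fall between points differently.

```python
from typing import Sequence

# False-positive rates (per slide) at which CAMELYON16 samples the FROC curve.
FP_RATES = (0.25, 0.5, 1.0, 2.0, 4.0, 8.0)


def interpolate(x: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Piecewise-linear interpolation; clamps to the end points outside [xs[0], xs[-1]]."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    for i in range(1, len(xs)):
        if x <= xs[i]:
            x0, x1, y0, y1 = xs[i - 1], xs[i], ys[i - 1], ys[i]
            if x1 == x0:
                return y1
            return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
    return ys[-1]  # not reached for sorted xs; kept as a safe fallback


def froc_score(avg_fps: Sequence[float], sensitivities: Sequence[float]) -> float:
    """Average sensitivity at the six predefined FP rates.

    avg_fps must be sorted in ascending order, with sensitivities aligned to it.
    """
    if not avg_fps or len(avg_fps) != len(sensitivities):
        raise ValueError("need non-empty, equally long curve arrays")
    values = [interpolate(r, avg_fps, sensitivities) for r in FP_RATES]
    return sum(values) / len(values)


# Usage with a toy curve (illustrative numbers, not results from the paper).
curve_fps = [0.25, 0.5, 1.0, 2.0, 4.0, 8.0]
curve_sens = [0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
print(round(froc_score(curve_fps, curve_sens), 4))  # 0.75
```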

The pipeline consists of four steps:

  1. Slide Patching: WSIs are divided into manageable patches.
  2. Self-Supervised Learning: An SSL method is trained on the patches to create an embedder.
  3. Feature Extraction: The embedder computes features (embeddings) for each slide.
  4. MIL Training: The Snuffy MIL framework is applied to the computed features.

Each step in this pipeline can be executed independently, with intermediate results available for download to facilitate continued processing.

Table of Contents
  1. Requirements
  2. Dataset Download
  3. Train/Val/Test Split
  4. Slide Preparation: Patching and N-Shot Dataset Creation
  5. Training the Embedder
  6. Feature Extraction
  7. MIL Training
  8. Visualization
  9. Acknowledgement
  10. Citation


System Requirements

  • Operating System: Ubuntu 20.04 LTS (or compatible Linux distribution)
  • Python Version: 3.8 or later
  • GPU: Recommended for faster processing (CUDA-compatible)


  • Disk Space: Ensure you have sufficient disk space for dataset downloads and processing, especially if you intend to work with raw slides rather than pre-computed embeddings. Raw slide data can be very large.
  • Hardware: The MIL training code can run on both GPU and CPU. For optimal performance, a GPU is strongly recommended.

Downloading and Preparing Datasets

  1. AWS CLI: To download the raw CAMELYON16 whole-slide images, you need the AWS CLI. Install it with:
curl "" -o ""
  2. GDC Client (for downloading the TCGA dataset): This is downloaded and installed automatically when you run the TCGA download script.

  3. OpenSlide is required if you intend to patch the slides yourself using the CAMELYON16 or TCGA patching scripts. Install OpenSlide with:

# Update package list and install OpenSlide
apt-get update
apt-get install openslide-tools

Running Snuffy

  1. The ASAP package is required for calculating the FROC metric. Install ASAP and its multiresolutionimageinterface Python package as follows:
# Download and install ASAP
apt-get install -f "./ASAP-2.1-py38-Ubuntu2004.deb"
  2. Required Python packages can be installed with:
# Install Python packages from requirements.txt
pip install -r requirements.txt

Note: The requirements.txt file includes specific package versions used and verified in our experiments. However, newer versions available in your environment may also be compatible.

Additional Components

  1. MAE with Adapter: Refer to the MAE repository for installation instructions.

    Important: If you are using PyTorch 1.8 or later, follow the instructions in the MAE repository to fix a compatibility issue with the timm module. Alternatively, run the following script to fix the issue.

    chmod +x

    Note that we also include a modified version of timm that supports adapters.

Download Data

CAMELYON16

  1. List and Download Dataset: Run the following commands to list and download the CAMELYON16 dataset:

    aws s3 ls --no-sign-request s3://camelyon-dataset/CAMELYON16/ --recursive
    aws s3 cp --no-sign-request s3://camelyon-dataset/CAMELYON16/ raw_data/camelyon16 --recursive
  2. Directory Structure: After downloading, your raw_data/camelyon16 directory should look like this:

    -- camelyon16
        |-- annotations
        |-- background_tissue
        |-- checksums.md5
        |-- evaluation
        |-- images
        |-- license.txt
        |-- masks
        `-- pathology-tissue-background-segmentation.json
  3. Organize Files:
    Use the provided script to copy the necessary files into the datasets/camelyon16 directory. If space is limited, modify the script to move files instead of copying them.

  4. Final Directory Structure:

    |-- annotations
    |   |-- test_001.xml
    |   |-- tumor_001.xml
    |   |-- ...
    |-- masks
    |   |-- normal_001_mask.tif
    |   |-- test_001_mask.tif
    |   |-- tumor_001_mask.tif
    |   |-- ...
    |-- 0_normal
    |   |-- normal_004.tif
    |   |-- test_018.tif
    |   |-- ...
    |-- 1_tumor
    |   |-- test_046.tif
    |   |-- tumor_075.tif
    |   |-- ...
    |-- reference.csv

TCGA Lung Cancer

To download the TCGA Lung Cancer dataset, run the following script. It downloads the slides listed in the LUAD and LUSC manifests to the datasets/tcga/{luad, lusc} directories. Each slide is stored in its own directory, named after its ID in the manifest.

chmod +x
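GDC manifests are typically tab-separated files whose header includes id and filename columns. Assuming that layout, the sketch below lists the slides in a manifest and the per-slide directory, named by the manifest ID, that each slide would be stored in. This mirrors the datasets/tcga/{luad, lusc} layout described above. The helper planned_paths is hypothetical and is not part of the repository.

```python
import csv
import io
from pathlib import Path
from typing import Dict, List, TextIO


def read_manifest(handle: TextIO) -> List[Dict[str, str]]:
    """Parse a tab-separated GDC manifest into a list of row dicts."""
    return list(csv.DictReader(handle, delimiter="\t"))


def planned_paths(rows: List[Dict[str, str]], root: Path, subtype: str) -> List[Path]:
    """Hypothetical helper: where each slide would land, one directory per manifest ID."""
    return [root / subtype / row["id"] / row["filename"] for row in rows]


# Usage with a two-row toy manifest (IDs and filenames are made up).
toy_manifest = io.StringIO(
    "id\tfilename\tmd5\tsize\tstate\n"
    "aaaa-1111\tslide_a.svs\t0123\t10\treleased\n"
    "bbbb-2222\tslide_b.svs\t4567\t20\treleased\n"
)
rows = read_manifest(toy_manifest)
for path in planned_paths(rows, Path("datasets/tcga"), "luad"):
    print(path.as_posix())
# datasets/tcga/luad/aaaa-1111/slide_a.svs
# datasets/tcga/luad/bbbb-2222/slide_b.svs
```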

MIL datasets

Download the MIL datasets (sourced from the DSMIL project) and unzip them into the datasets/ directory.

unzip -d datasets/

Slide Preparation: Patching

CAMELYON16

This script processes TIFF slides located in datasets/camelyon16/{0_normal, 1_tumor}/. For each slide, it creates a directory at datasets/camelyon16/single/{0_normal, 1_tumor}/{slide_name}, saving the extracted patches as JPEG images.


TCGA Lung Cancer

This script processes SVS slides in datasets/tcga/{lusc, luad}/ and saves the extracted patches in datasets/tcga/single/{lusc, luad}/{slide_name} as JPEG images.


For both scripts, see the argument definitions in each script for details on the available options and what they do.
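Both patching scripts cut each slide into fixed-size tiles. The sketch below is a rough, library-free illustration of that idea. It is not the scripts' actual implementation, which may use OpenSlide's DeepZoom tiling and tissue filtering. It enumerates the top-left coordinates of non-overlapping tiles for a slide of a given size and discards partial tiles at the right and bottom edges. The 224-pixel tile size and the function name tile_grid are assumptions.

```python
from typing import List, Tuple


def tile_grid(width: int, height: int, tile: int = 224) -> List[Tuple[int, int, int, int]]:
    """Return (col, row, x, y) for every full, non-overlapping tile.

    width/height are slide dimensions in pixels at the chosen magnification level;
    partial tiles at the right/bottom borders are skipped.
    """
    if tile <= 0:
        raise ValueError("tile size must be positive")
    cols, rows = width // tile, height // tile
    return [(c, r, c * tile, r * tile) for r in range(rows) for c in range(cols)]


# Usage: a 1000 x 500 px region yields a 4 x 2 grid of 224 px tiles.
grid = tile_grid(1000, 500)
print(len(grid), grid[0], grid[-1])  # 8 (0, 0, 0, 0) (3, 1, 672, 224)
```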

Train/Val/Test Split and N-Shot Dataset Creation

CAMELYON16

To split the CAMELYON16 dataset:

cd datasets/camelyon16

This script reorganizes the directory structure from:

datasets/camelyon16/single/{0_normal, 1_tumor}

to:

datasets/camelyon16/single/fold1/{train, validation, test}/{0_normal, 1_tumor}

The official CAMELYON16 test set is used for testing, while the remaining data is randomly split into training and validation sets with an 80/20 ratio. You can adjust the fold number directly in the script.
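The split rule can be sketched as follows. The sketch assumes that slides are identified by file stem and that official test slides are the ones whose names start with test_, as in the directory listing above. The seed, the rounding of the 80% cut, and the function name are illustrative assumptions. The repository's script may shuffle, round, or group by class differently.

```python
import random
from typing import Dict, List


def split_camelyon16(slides: List[str], train_ratio: float = 0.8, seed: int = 0) -> Dict[str, List[str]]:
    """Keep official test slides for testing; randomly split the rest into train/validation."""
    test = sorted(s for s in slides if s.startswith("test_"))
    rest = sorted(s for s in slides if not s.startswith("test_"))
    random.Random(seed).shuffle(rest)  # deterministic shuffle for reproducibility
    n_train = round(len(rest) * train_ratio)
    return {"train": rest[:n_train], "validation": rest[n_train:], "test": test}


# Usage with toy slide names.
names = (
    [f"normal_{i:03d}" for i in range(1, 6)]
    + [f"tumor_{i:03d}" for i in range(1, 6)]
    + ["test_001", "test_002"]
)
splits = split_camelyon16(names)
print({k: len(v) for k, v in splits.items()})  # {'train': 8, 'validation': 2, 'test': 2}
```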

To reverse the CAMELYON16 split:

cd datasets/camelyon16

The processed and shuffled datasets are saved with filenames that reflect the dataset name, fold count, and split ratio.

TCGA Lung Cancer

K-Fold Cross Validation Split

The script creates K-Fold cross-validation splits for the TCGA data, ensuring that a single patient's slides are not divided across multiple splits. It uses the patients.csv reference file and stores the fold information in datasets/tcga/folds/fold_{i}.csv.

To run the K-Fold split:

cd datasets/tcga
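The patient-level constraint amounts to assigning whole patients, not individual slides, to folds. Below is a minimal sketch. It assumes a slide-to-patient mapping such as the one patients.csv provides; because the file's exact columns are not described here, the mapping is passed in as a dict. Patients are shuffled and dealt round-robin into k folds. This balances patient counts per fold, not slide counts. The function name is hypothetical.

```python
import random
from collections import defaultdict
from typing import Dict, List


def patient_kfold(slide_to_patient: Dict[str, str], k: int, seed: int = 0) -> List[List[str]]:
    """Return k lists of slides such that no patient appears in more than one fold."""
    if k < 1:
        raise ValueError("k must be at least 1")
    by_patient: Dict[str, List[str]] = defaultdict(list)
    for slide, patient in slide_to_patient.items():
        by_patient[patient].append(slide)
    patients = sorted(by_patient)
    random.Random(seed).shuffle(patients)
    folds: List[List[str]] = [[] for _ in range(k)]
    for i, patient in enumerate(patients):
        folds[i % k].extend(sorted(by_patient[patient]))  # a patient's slides stay together
    return folds


# Usage: six slides from four (made-up) patients, two folds.
mapping = {"s1": "p1", "s2": "p1", "s3": "p2", "s4": "p3", "s5": "p3", "s6": "p4"}
for i, fold in enumerate(patient_kfold(mapping, k=2)):
    print(i, fold)
```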

Selecting a Fold

After generating folds, use the script to organize the directories according to a selected fold:


This script reorganizes the directory structure from:

datasets/tcga/single/{0_luad, 1_lusc}

to:

datasets/tcga/single/fold{i}/{train, validation, test}/{0_luad, 1_lusc}

De-selecting a Fold

To reverse the TCGA split and restore the original directory structure:

cd datasets/tcga

MIL Datasets

The script loads the MIL datasets downloaded in the previous step (Musk1, Musk2, Elephant), converts them into a format compatible with Snuffy, and creates cross-validation splits, ensuring that each fold contains both negative and positive bags.

cd datasets/mil_dataset
# python --dataset [Musk1, Musk2, Elephant] --num_folds [10] --train_valid_ratio [0.2]
python --dataset Musk1
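One simple way to guarantee that every fold contains both negative and positive bags is stratification. The indices of each class are shuffled separately and then dealt round-robin across folds. The sketch below assumes binary bag labels (0/1) and at least num_folds bags per class. It illustrates the idea and is not the script's exact procedure.

```python
import random
from typing import List


def stratified_folds(labels: List[int], num_folds: int = 10, seed: int = 0) -> List[List[int]]:
    """Split bag indices into folds, each containing bags of every class."""
    folds: List[List[int]] = [[] for _ in range(num_folds)]
    rng = random.Random(seed)
    for cls in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == cls]
        if len(idx) < num_folds:
            raise ValueError(f"class {cls} has fewer bags than folds")
        rng.shuffle(idx)
        for j, i in enumerate(idx):
            folds[j % num_folds].append(i)  # every fold receives at least one bag per class
    return folds


# Usage: 20 negative and 20 positive toy bags, 10 folds.
toy_labels = [0] * 20 + [1] * 20
folds = stratified_folds(toy_labels)
print(all({toy_labels[i] for i in fold} == {0, 1} for fold in folds))  # True
```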

N-Shot Patch Dataset Creation

CAMELYON16

To create a 50-shot patch dataset (an n-shot dataset contains at most n patches per WSI):

cd datasets/camelyon16
python --shots=50

This will create a new folder named single/fold1_50shot based on the dataset in single/fold1. In this new folder, each slide will have at most 50 patches (or all patches if the original number is less than 50).

TCGA Lung Cancer

cd datasets/tcga
python --shots 5
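The n-shot rule keeps at most n patches per slide, or all of them if the slide has fewer. It can be sketched as below. Random sampling with a fixed seed and the function name are our assumptions; the script may choose patches differently.

```python
import random
from typing import Dict, List


def n_shot(patches_per_slide: Dict[str, List[str]], n: int, seed: int = 0) -> Dict[str, List[str]]:
    """Keep at most n patches per slide; slides with fewer patches keep all of them."""
    rng = random.Random(seed)
    out: Dict[str, List[str]] = {}
    for slide, patches in patches_per_slide.items():
        ordered = sorted(patches)  # sort first so the result does not depend on file-listing order
        out[slide] = ordered if len(ordered) <= n else sorted(rng.sample(ordered, n))
    return out


# Usage: one large and one small toy slide, 50-shot.
toy = {"tumor_001": [f"{i}.jpeg" for i in range(120)], "normal_001": [f"{i}.jpeg" for i in range(30)]}
subset = n_shot(toy, n=50)
print({k: len(v) for k, v in subset.items()})  # {'tumor_001': 50, 'normal_001': 30}
```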

Training the Embedder

| Method | Instructions | Embedder Weights | Embeddings |
|---|---|---|---|
| SimCLR (from scratch) | Refer to DSMIL | Weights | Embeddings |
| DINO (from scratch) | Refer to DINO (and use a ViT-S/16) | Weights | Embeddings |
| DINO (with Adapter) | Refer to the DINO with Adapter section | Weights | Embeddings |
| MAE (with Adapter) | Refer to the MAE with Adapter section | Weights | Embeddings |

DINO with Adapter

Download the DINO ImageNet-1K pretrained ViT-S/8 full weights:


Continue pretraining with DINO Adapter:

python dino_adapter/ \
  --adapter_ffn_scalar=10 \
  --arch=vit_small \
  --batch_size_per_gpu=16 \
  --clip_grad=3 \
  --data_path_train=datasets/camelyon16/single/fold1_50shot/train \
  --data_path_valid=datasets/camelyon16/single/fold1_50shot/validation \
  --epochs=100 \
  --ffn_num=32 \
  --freeze_last_layer=0 \
  --full_checkpoint=dino_deitsmall8_pretrain_full_checkpoint.pth \
  --lr__warmup_epochs__minlr="[0.0005, 10, 1e-06]" \
  --momentum_teacher=0.9995 \
  --norm_last_layer=False \
  --output_dir=out \
  --patch_size=8 \
  --random_head=1 \
  --teacher_temp__warmup_teacher_temp_epochs="[0.04, 0]" \
  --warmup_teacher_temp=0.04 \
  --weight_decay__weight_decay_end="[0.04, 0.4]"
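Several flags above bundle related hyperparameters into one name joined by double underscores, with a list of values in the same order. For example, --lr__warmup_epochs__minlr="[0.0005, 10, 1e-06]" presumably sets the learning rate, the number of warmup epochs, and the minimum learning rate. Assuming that reading of the naming convention, here is a sketch of how such a flag could be unpacked with Python's standard library. The function name is hypothetical, and the repository's own argument parsing may work differently.

```python
import ast
from typing import Any, Dict


def unpack_composite(name: str, value: str) -> Dict[str, Any]:
    """Map 'a__b__c' and '[x, y, z]' to {'a': x, 'b': y, 'c': z}."""
    keys = name.lstrip("-").split("__")
    values = ast.literal_eval(value)  # safely parses list/number/string literals
    if not isinstance(values, (list, tuple)) or len(values) != len(keys):
        raise ValueError(f"{name}: expected {len(keys)} values, got {value!r}")
    return dict(zip(keys, values))


print(unpack_composite("--lr__warmup_epochs__minlr", "[0.0005, 10, 1e-06]"))
# {'lr': 0.0005, 'warmup_epochs': 10, 'minlr': 1e-06}
print(unpack_composite("weight_decay__weight_decay_end", "[0.04, 0.4]"))
# {'weight_decay': 0.04, 'weight_decay_end': 0.4}
```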

MAE with Adapter

Download the MAE ImageNet-1K pretrained ViT-B/16 full weights:


Continue pretraining with MAE Adapter:

torchrun \
  --accum_iter=1 \
  --adapter_ffn_scalar=1 \
  --blr__min_lr__warmup_epochs="[0.001, 0, 40]" \
  --data_path=datasets/camelyon16/single/fold1_200shot \
  --epochs=400 \
  --full_checkpoint=mae_pretrain_vit_base_full.pth \
  --norm_pix_loss=0 \
  --train_linears__linears_from_scratch="[1, 1]"

Feature Extraction

The script extracts features (embeddings) from a dataset using a specified embedder model. It processes the dataset and saves the cleaned embedder weights, feature vectors, and corresponding labels.

Input Dataset Structure

The dataset is expected to follow this directory structure:

└── {dataset_name}/
    ├── single/
    │   └── {fold}/
    │       ├── train/
    │       ├── validation/
    │       └── test/
    └── tile_label.csv
  • {dataset_name}: The name of your dataset.
  • {fold}: The specific fold of data (e.g., fold1, fold2, ...).
  • train/, validation/, test/: Directories containing the patches for training, validation, and testing, respectively.
  • tile_label.csv: CSV file with patch-level labels (if available), created by deepzoom_tiler.

Output Directory Structure

The script saves the outputs in the following directory structure:

└── {embedder}_{version_name}/
    └── {dataset_name}/
        ├── embedder.pth
        ├── {train, test, validation}/
        │   ├── {0_normal, 1_tumor}.csv
        │   └── {0_normal, 1_tumor}/
        │       └── {slide_name}.csv
        └── {dataset_name}.csv
  • {embedder}: The name of the embedder model used (e.g., SimCLR).
  • {version_name}: The version name of the embedder model.
  • {dataset_name}: The name of the dataset.
  • embedder.pth: The cleaned embedder weights.
  • {slide_name}.csv: CSV file containing the features [feature_0, ..., feature_511, position, label] for each slide, with one row per patch. The number of feature columns equals the embedder's output size (512 here, as for the ResNet-18 SimCLR embedder).
  • {split}/{class_name}.csv: CSV file containing [bag_path, bag_label] for each class in each split (train/validation/test).
  • {dataset_name}.csv: CSV file containing [bag_path, bag_label] for the whole dataset.
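Assuming each {slide_name}.csv has a header row with the column names listed above, a bag can be loaded with the standard library as sketched below. The encoding of the position column is not specified here, so it is kept as a raw string. The function name load_bag is hypothetical.

```python
import csv
import io
from typing import List, TextIO, Tuple


def load_bag(handle: TextIO) -> Tuple[List[List[float]], List[str], List[float]]:
    """Return (features, positions, labels) for one slide; one entry per patch."""
    reader = csv.DictReader(handle)
    feat_cols = sorted(
        (c for c in reader.fieldnames or [] if c.startswith("feature_")),
        key=lambda c: int(c.split("_")[1]),  # numeric order: feature_2 before feature_10
    )
    features: List[List[float]] = []
    positions: List[str] = []
    labels: List[float] = []
    for row in reader:
        features.append([float(row[c]) for c in feat_cols])
        positions.append(row["position"])
        labels.append(float(row["label"]))
    return features, positions, labels


# Usage with a toy three-feature CSV (values and position strings are made up).
toy_csv = io.StringIO(
    "feature_0,feature_1,feature_2,position,label\n"
    "0.1,0.2,0.3,0_0,0\n"
    "0.4,0.5,0.6,0_1,1\n"
)
feats, pos, labs = load_bag(toy_csv)
print(len(feats), len(feats[0]), labs)  # 2 3 [0.0, 1.0]
```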

Usage on CAMELYON16

SimCLR from scratch

python \
  --backbone=resnet18 \
  --norm_layer=instance \
  --weights=embedders/dsmil_simclr.pth \
  --embedder=SimCLR

DINO from scratch

python \
  --embedder=DINO \
  --num_classes=2048 \
  --backbone=vit_small \
  --weights=embedders/dino_scratch.pth

DINO with Adapter

python \
  --embedder=DINO \
  --num_classes=2048 \
  --backbone=vit_small \
  --patch_size=8 \
  --weights=embedders/dino_adapter.pth \
  --ffn_num=32 \
  --adapter_ffn_scalar=10 \
  --version_name=dino_adapter \
  --use_adapter \
  --transform 1

MAE with Adapter

python \
  --embedder=MAE \
  --num_classes=512 \
  --backbone=mae_vit_base_patch16 \
  --weights=embedders/mae_adapter.pth \
  --ffn_num=64 \
  --adapter_ffn_scalar=1 \
  --version_name=mae_adapter \
  --use_adapter \
  --transform 1

Usage on TCGA Lung

SimCLR from scratch

python \
  --backbone=resnet18 \
  --dataset=tcga \
  --norm_layer=instance \
  --weights=embedders/dsmil_simclr_tcga.pth \
  --embedder=SimCLR

MIL Training

Example Run for CAMELYON16

DINO from scratch

python \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=900 \
  --dataset=camelyon16 \
  --embedding=DINO_dino_scratch \
  --encoder_dropout=0.1 \
  --feats_size=384 \
  --l2normed_embeddings=1 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.7777777777777778 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=0 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

DINO with Adapter

python \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=500 \
  --dataset=camelyon16 \
  --embedding=DINO_dino_adapter \
  --encoder_dropout=0.1 \
  --feats_size=384 \
  --l2normed_embeddings=1 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.5 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=1 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

MAE with Adapter

python \
  --activation=relu \
  --arch=snuffy \
  --betas="[0.9, 0.999]" \
  --big_lambda=500 \
  --dataset=camelyon16 \
  --embedding=MAE_mae_adapter \
  --encoder_dropout=0 \
  --feats_size=768 \
  --l2normed_embeddings=0 \
  --lr=0.02 \
  --num_epochs=200 \
  --num_heads=4 \
  --optimizer=adamw \
  --random_patch_share=0.5 \
  --scheduler=cosine \
  --single_weight__lr_multiplier=1 \
  --soft_average=1 \
  --weight_decay=0.05 \
  --weight_init__weight_init_i__weight_init_b="['trunc_normal', 'xavier_uniform', 'trunc_normal']"

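--feats_size must match the dimensionality of the features produced in the Feature Extraction step (384 for the DINO ViT-S embeddings and 768 for the MAE ViT-B embeddings used above). The product --random_patch_share * --big_lambda gives the number of randomly selected patches; the rest of the --big_lambda patches are top patches.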
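Here is a worked example of that arithmetic. With --big_lambda=900 and --random_patch_share=0.7777777777777778 (7/9), 700 patches are random and 200 are top patches. With --big_lambda=500 and --random_patch_share=0.5, the split is 250/250. The sketch below reproduces the split and shows one plausible way to draw the patches: it takes the top-scoring ones first, then samples the random ones from the remainder. Rounding to the nearest integer, ranking by a per-patch score, and the function names are assumptions. Snuffy's actual selection may differ.

```python
import random
from typing import List, Sequence, Tuple


def split_lambda(big_lambda: int, random_share: float) -> Tuple[int, int]:
    """Return (num_random, num_top) for a budget of big_lambda patches."""
    num_random = round(big_lambda * random_share)
    return num_random, big_lambda - num_random


def select_patches(scores: Sequence[float], big_lambda: int, random_share: float, seed: int = 0) -> List[int]:
    """Pick top-scoring patch indices plus random ones from the rest (illustrative only)."""
    n = len(scores)
    if n <= big_lambda:
        return list(range(n))  # small bags: keep every patch
    num_random, num_top = split_lambda(big_lambda, random_share)
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    top, rest = ranked[:num_top], ranked[num_top:]
    return top + random.Random(seed).sample(rest, num_random)


print(split_lambda(900, 0.7777777777777778))  # (700, 200)
print(split_lambda(500, 0.5))  # (250, 250)
```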

For TCGA, use --arch=snuffy_multiclass.

Example Run for MIL Datasets

python \
  --arch=snuffy \
  --dataset=musk1 \
  --num_heads=2 \
  --cv_num_folds 10 \
  --cv_valid_ratio 0.2 \
  --cv_current_fold 1


  1. Feature Size: This is set automatically based on the dataset ('musk1' and 'musk2': 166, 'elephant': 230). No manual adjustment is needed.
  2. MultiHeadAttention: Ensure the feature size is divisible by the number of heads.
  3. Cross-Validation: Use the script from the MIL Datasets step to generate a shuffle file ({dataset_file_name}_{num_folds}folds_{valid_ratio}split.pkl, e.g. musk1_10folds_0.2split.pkl). Set args.cv_num_folds and args.cv_valid_ratio in the training script to the same values so the file is read correctly. Select the fold to train with args.cv_current_fold.
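These notes can be checked up front with a few lines of Python. The feature sizes and the file-name pattern come from the notes above. The helper names check_heads and shuffle_file_name are hypothetical.

```python
# Feature sizes from the note above; the training script sets these automatically.
FEATURE_SIZES = {"musk1": 166, "musk2": 166, "elephant": 230}


def check_heads(dataset: str, num_heads: int) -> int:
    """Raise if the dataset's feature size cannot be split evenly across attention heads."""
    feats = FEATURE_SIZES[dataset]
    if feats % num_heads != 0:
        raise ValueError(f"{dataset}: feature size {feats} is not divisible by {num_heads} heads")
    return feats // num_heads  # per-head dimension


def shuffle_file_name(dataset: str, num_folds: int, valid_ratio: float) -> str:
    """Build the shuffle-file name that --cv_num_folds and --cv_valid_ratio must match."""
    return f"{dataset}_{num_folds}folds_{valid_ratio}split.pkl"


print(check_heads("musk1", 2))  # 83
print(shuffle_file_name("musk1", 10, 0.2))  # musk1_10folds_0.2split.pkl
```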


Visualization

In the figure below, the black line outlines the tumor area. The model's attention is represented by a color overlay, where red indicates the highest attention and blue indicates the lowest. As shown, the model effectively highlights the tumor regions.

To create heatmaps similar to the one shown above, run the following command:

python \
  --batch_size=512 \
  --num_workers=24 \
  --embedder_weights=embedders/clean/camelyon16/SimCLR/embedder.pth \
  --aggregator_weights=aggregators/snuffy_simclr_dsmil.pth \
  --thres_tumor=0.75959325 \
  --num_heads=2 \
  --encoder_dropout=0.2 \
  --k=900 \
  --random_patch_share=0.7777777777777778 \
  --activation=gelu

The script requires the following inputs:

  • --embedder_weights: Path to the embedder weights file
  • --aggregator_weights: Path to the aggregator weights file
  • Ground truth masks located in datasets/camelyon16/masks/
  • Raw TIFF slides located in datasets/camelyon16/1_tumor/
  • Name and label of slides located in datasets/camelyon16/reference.csv

For each slide, the script generates the following outputs:

  • Heatmaps saved in roi_output/{slide_name}/cmaps/, where:
    • jet_slide.png is the raw slide.
    • jet.png is the slide with the attention map overlay and the ground truth tumor region outlined in black.

By default, the script processes 3 slides from the CAMELYON16 test set, but you can customize the slides to process by modifying the script. Additionally, reducing the DPI setting can speed up processing.
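The overlay is essentially per-patch attention placed back on the slide's patch grid and mapped to a blue-to-red colormap. The sketch below covers the first two steps without any libraries. It assumes that each patch's position is available as integer (col, row) grid coordinates and that scores are min-max normalized per slide; both are assumptions. The resulting grid could then be rendered with a jet-style colormap, such as matplotlib's jet. The function name attention_grid is hypothetical.

```python
from typing import List, Optional, Sequence, Tuple


def attention_grid(
    positions: Sequence[Tuple[int, int]], scores: Sequence[float]
) -> List[List[Optional[float]]]:
    """Place min-max normalized scores on a (rows x cols) grid; None marks empty cells."""
    if not scores or len(positions) != len(scores):
        raise ValueError("positions and scores must be non-empty and aligned")
    lo, hi = min(scores), max(scores)
    span = (hi - lo) or 1.0  # avoid division by zero when all scores are equal
    cols = max(c for c, _ in positions) + 1
    rows = max(r for _, r in positions) + 1
    grid: List[List[Optional[float]]] = [[None] * cols for _ in range(rows)]
    for (c, r), s in zip(positions, scores):
        grid[r][c] = (s - lo) / span  # 0.0 = lowest attention (blue), 1.0 = highest (red)
    return grid


print(attention_grid([(0, 0), (1, 0), (1, 1)], [0.0, 2.0, 1.0]))  # [[0.0, 1.0], [None, 0.5]]
```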

You can download the aggregator used for creating the figure above from here.


Acknowledgement

This codebase is built upon the work of DSMIL, DINO, and MAE. We extend our gratitude to the authors for their valuable contributions.


Citation

If you find our work helpful for your research, please consider giving a star to this repository and citing the following BibTeX entry.

      title={Snuffy: Efficient Whole Slide Image Classifier}, 
      author={Hossein Jafarinia and Alireza Alipanah and Danial Hamdi and Saeed Razavi and Nahal Mirzaie and Mohammad Hossein Rohban},

