Authors: Santosh Sanjeev, Nuren Zhaksylyk, Ibrahim Almakky, Anees Ur Rehman Hashmi, Mohammad Areeb Qazi, Mohammad Yaqub
Abstract: The scarcity of well-annotated medical datasets requires leveraging transfer learning from broader datasets like ImageNet or pre-trained models like CLIP. Model soups average multiple fine-tuned models, aiming to improve performance on In-Domain (ID) tasks and to enhance robustness against Out-of-Distribution (OOD) datasets. However, applying these methods to the medical imaging domain faces challenges and results in suboptimal performance. This is primarily due to differences in error surface characteristics that stem from data complexities such as heterogeneity, domain shift, class imbalance, and distributional shifts between the training and testing phases. To address this issue, we propose a hierarchical merging approach that involves local and global aggregation of models at various levels based on the models' hyperparameter configurations. Furthermore, to alleviate the need for training a large number of models in the hyperparameter search, we introduce a computationally efficient method that uses a cyclical learning rate scheduler to produce multiple models for aggregation in the weight space. Our method demonstrates significant improvements over the model souping approach across multiple datasets (around 6% gain on the HAM10000 and CheXpert datasets) while maintaining low computational costs for model generation and selection. Moreover, we achieve better results than model soups on OOD datasets.
Welcome to the repository for "FissionFusion: Fast Geometric Generation and Hierarchical Souping for Medical Image Analysis". The paper examines the limitations of model soups and introduces a new approach to generating and merging models.
- Clone the repository:
git clone https://github.com/BioMedIA-MBZUAI/Fission-Fusion.git
- Create and activate a conda environment:
conda create --name fissionfusion python=3.8
conda activate fissionfusion
- Install PyTorch and the other dependencies:
pip install -r requirements.txt
Details about the currently supported datasets, where to download them, and the expected data directory structure are available here.
- Grid search:
a) The grid search experiments first require linear probing as a warm-up stage, which produces the linear-probed model (θ_lp). Before running, update the DATASETS paths and the implementation section of the corresponding config file to match your dataset and model:
python train.py --config './configs/lp.yaml'
b) Run the fine-tuning stage, which produces 48 models covering all the hyperparameter settings:
python finetune.py --config './configs/full_finetuning.yaml'
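The 48 fine-tuned models come from a grid over hyperparameter settings. As a minimal sketch, assuming the grid spans the hyperparameters named in the FGG instructions (learning rate, seed, and augmentation), the expansion into one run configuration per combination looks like this. The values below are hypothetical placeholders chosen only so that the product is 48; the actual grid is defined in ./configs/full_finetuning.yaml.

```python
from itertools import product

# Hypothetical placeholder values (assumption): the real grid lives in
# ./configs/full_finetuning.yaml and may use different axes or values.
LEARNING_RATES = [1e-5, 3e-5, 1e-4, 3e-4]
SEEDS = [0, 1, 2]
AUGMENTATIONS = ["None", "Light", "Medium", "Heavy"]


def build_grid(learning_rates, seeds, augmentations):
    """Return one config dict per combination of the three axes."""
    return [
        {"lr": lr, "seed": seed, "augmentation": aug}
        for lr, seed, aug in product(learning_rates, seeds, augmentations)
    ]


grid = build_grid(LEARNING_RATES, SEEDS, AUGMENTATIONS)
print(len(grid))  # 4 learning rates x 3 seeds x 4 augmentations
print(grid[0])
```

Each entry corresponds to one fine-tuning run starting from θ_lp, so the number of models grows multiplicatively with every axis added to the grid; this is the cost that the FGG stage is meant to reduce.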
- Fast Geometric Generation (FGG):
a) For the Fast Geometric Generation (FGG) experiments, we first fine-tune models at different learning rates, with the seed fixed to 1 and the augmentation set to Heavy. This yields 6 models per learning rate, including the initial model:
python finetune.py --config './configs/pre_fgg_finetuning.yaml'
b) In the second stage, we generate further models with a cyclical learning rate scheduler, saving a model at every cycle:
python fgg.py --config './configs/fgg.yaml'
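A well-known example of a cyclical schedule for collecting models is the piecewise-linear one from Fast Geometric Ensembling (Garipov et al., 2018): within each cycle the learning rate falls linearly from a maximum to a minimum and rises back, and a model is collected at the midpoint of each cycle, where the learning rate is lowest. The sketch below implements that schedule in plain Python. Whether fgg.py uses this exact shape, cycle length, or snapshot point is an assumption; the actual settings are in ./configs/fgg.yaml.

```python
from typing import List


def fge_learning_rate(i: int, cycle_length: int, lr_max: float, lr_min: float) -> float:
    """Piecewise-linear cyclical learning rate in the style of FGE (assumed
    shape, not necessarily the one used by fgg.py).

    `i` is the 1-based iteration index. Each cycle of `cycle_length`
    iterations goes linearly from lr_max down to lr_min and back up.
    """
    t = ((i - 1) % cycle_length + 1) / cycle_length  # position in cycle, (0, 1]
    if t <= 0.5:
        return (1 - 2 * t) * lr_max + 2 * t * lr_min
    return (2 - 2 * t) * lr_min + (2 * t - 1) * lr_max


def snapshot_iterations(num_iterations: int, cycle_length: int) -> List[int]:
    """Iterations at the middle of each cycle, where the learning rate is lowest."""
    if cycle_length % 2:
        raise ValueError("cycle_length must be even so each cycle has a midpoint")
    half = cycle_length // 2
    return [i for i in range(1, num_iterations + 1)
            if (i - 1) % cycle_length + 1 == half]


schedule = [fge_learning_rate(i, 4, 0.1, 0.01) for i in range(1, 9)]
print([round(lr, 4) for lr in schedule])
print(snapshot_iterations(12, 4))
```

In a PyTorch training loop, one would typically assign the returned value to `lr` in each of the optimizer's `param_groups` before every step and save the model's `state_dict()` at the iterations returned by `snapshot_iterations`. Because each cycle starts from where the previous one ended, one training run yields several nearby models instead of requiring one full fine-tuning run per model.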
- Model soups:
a) To run uniform and greedy soups on the models generated by Grid Search (GS), execute:
python gs_model_souping_test.py --config './configs/gs_model_souping_test.yaml'
b) To run uniform and greedy soups on the models generated by FGG, execute:
python fgg_model_souping_test.py --config './configs/fgg_model_souping_test.yaml'
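Both souping scripts average models in weight space. As a minimal sketch of the two recipes from the original model soups work (Wortsman et al., 2022), a uniform soup averages every parameter across all models, while a greedy soup visits models from best to worst held-out score and keeps each one only if adding it does not lower the soup's score. The toy state dicts (lists of floats instead of PyTorch tensors) and the scoring function are hypothetical stand-ins, not the repository's code.

```python
from typing import Callable, Dict, List

# Toy stand-in for a PyTorch state_dict: parameter name -> flat list of floats.
StateDict = Dict[str, List[float]]


def uniform_soup(models: List[StateDict]) -> StateDict:
    """Average every parameter element-wise across all models."""
    n = len(models)
    return {
        name: [sum(values) / n for values in zip(*(m[name] for m in models))]
        for name in models[0]
    }


def greedy_soup(models: List[StateDict],
                evaluate: Callable[[StateDict], float]) -> StateDict:
    """Visit models from best to worst held-out score and keep each one
    only if adding it does not lower the score of the averaged soup."""
    ranked = sorted(models, key=evaluate, reverse=True)
    ingredients = [ranked[0]]
    best_score = evaluate(ranked[0])
    for candidate in ranked[1:]:
        score = evaluate(uniform_soup(ingredients + [candidate]))
        if score >= best_score:
            ingredients.append(candidate)
            best_score = score
    return uniform_soup(ingredients)


def toy_score(state: StateDict) -> float:
    """Hypothetical held-out score: higher when w[0] is close to 2.0."""
    return -abs(state["w"][0] - 2.0)


# Usage with hypothetical toy models.
a, b, c = {"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}, {"w": [10.0, 0.0]}
print(uniform_soup([a, b, c]))
print(greedy_soup([a, b, c], toy_score))
```

Here the greedy soup keeps `a` and `b`, whose average scores better than either model alone, and rejects `c`, which would drag the soup away from the optimum. With real models, `evaluate` would be accuracy on a held-out validation split, and the list arithmetic would be replaced by tensor arithmetic.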
- Hierarchical souping:
To run the hierarchical souping approach, execute:
python hierarchical_souping.py --config './configs/hierarchical_souping.yaml'
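The abstract describes hierarchical souping as local and global aggregation of models based on their hyperparameter configurations. The sketch below is one minimal reading of that idea, not the code in hierarchical_souping.py: models are split by a hyperparameter, each group is merged first (local aggregation), and the group-level soups are then merged into a final model (global aggregation), recursing when several hyperparameters are given. The grouping keys (`lr`, `seed`), the use of a greedy soup at every level, the toy state dicts, and the scoring function are all assumptions made for illustration.

```python
from collections import defaultdict
from typing import Any, Callable, Dict, List, Sequence

StateDict = Dict[str, List[float]]  # toy stand-in for a PyTorch state_dict
Config = Dict[str, Any]             # hyperparameters a model was trained with


def uniform_soup(models: List[StateDict]) -> StateDict:
    """Average every parameter element-wise across all models."""
    n = len(models)
    return {name: [sum(v) / n for v in zip(*(m[name] for m in models))]
            for name in models[0]}


def greedy_soup(models: List[StateDict],
                evaluate: Callable[[StateDict], float]) -> StateDict:
    """Keep a model only if adding it does not lower the soup's score."""
    ranked = sorted(models, key=evaluate, reverse=True)
    ingredients, best = [ranked[0]], evaluate(ranked[0])
    for candidate in ranked[1:]:
        score = evaluate(uniform_soup(ingredients + [candidate]))
        if score >= best:
            ingredients.append(candidate)
            best = score
    return uniform_soup(ingredients)


def hierarchical_soup(models: List[StateDict], configs: List[Config],
                      keys: Sequence[str],
                      evaluate: Callable[[StateDict], float]) -> StateDict:
    """Split the models by hyperparameter keys[0], merge each group
    recursively on the remaining keys (local aggregation), then greedily
    soup the per-group results (global aggregation)."""
    if not keys:
        return greedy_soup(models, evaluate)
    groups: Dict[Any, List[int]] = defaultdict(list)
    for idx, cfg in enumerate(configs):
        groups[cfg[keys[0]]].append(idx)
    local_soups = [
        hierarchical_soup([models[i] for i in idxs], [configs[i] for i in idxs],
                          keys[1:], evaluate)
        for idxs in groups.values()
    ]
    return greedy_soup(local_soups, evaluate)


def toy_score(state: StateDict) -> float:
    """Hypothetical held-out score: higher when w[0] is close to 2.0."""
    return -abs(state["w"][0] - 2.0)


# Usage with hypothetical toy models and configs.
models = [{"w": [1.0]}, {"w": [3.0]}, {"w": [2.0]}, {"w": [10.0]}]
configs = [{"lr": 1e-4, "seed": 0}, {"lr": 1e-4, "seed": 1},
           {"lr": 1e-5, "seed": 0}, {"lr": 1e-5, "seed": 1}]
print(hierarchical_soup(models, configs, ["lr"], toy_score))
```

Merging within a group first means that models sharing a hyperparameter (and so, plausibly, a similar region of the loss surface) are averaged together before models from different groups are combined.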
If you use Fission-Fusion or our repository in your research, please cite our paper, "FissionFusion: Fast Geometric Generation and Hierarchical Souping for Medical Image Analysis":
@article{sanjeev2024fissionfusion,
title={FissionFusion: Fast Geometric Generation and Hierarchical Souping for Medical Image Analysis},
author={Santosh Sanjeev and Nuren Zhaksylyk and Ibrahim Almakky and Anees Ur Rehman Hashmi and Mohammad Areeb Qazi and Mohammad Yaqub},
year={2024},
eprint={2403.13341},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
Our work is inspired by
For any inquiries or questions, please create an issue on this repository or contact Santosh Sanjeev at santosh.sanjeev@mbzuai.ac.ae.