
Deblurring Masked Image Modelling is Better Recipe for Ultrasound Image Analysis

Introduction

Our initial work is presented in Deblurring Masked Autoencoder is Better Recipe for Ultrasound Image Recognition, which has been accepted at MICCAI 2023.

Compared with the initial version, which proposed a novel deblurring MAE for ultrasound image recognition, this repository extends the work with the following enhancements:

  • We extended the deblurring idea from MAE alone to two MIM-based approaches (MAE and ConvMAE).
  • We increased the number of thyroid ultrasound images for pretraining from 10,675 to 280,000.
  • We extended the downstream tasks from classification only to both classification and segmentation.

The corresponding extended paper is still a work in progress; we will release it soon.

Method

[Figure: method overview]

Pre-trained checkpoints

We provide pre-trained models for our proposed deblurring approaches (Deblurring MAE and Deblurring ConvMAE) as well as for the vanilla MAE and ConvMAE. All models are pretrained on 280,000 thyroid ultrasound images. The following table provides the pre-trained checkpoints:

|            | MAE (Base) | ConvMAE (Base) |
|------------|------------|----------------|
| Vanilla    | download   | download       |
| Deblurring | download   | download       |

Fine-tuning Results

We provide the fine-tuning segmentation results on the publicly available TN3K dataset:

| Method             | Segmentation Model | Pretraining    | IoU (%) |
|--------------------|--------------------|----------------|---------|
| -                  | TRFE+              | -              | 74.47   |
| ConvMAE            | U-Net++ (ConViT-B) | US-Esaote-280K | 74.40   |
| Deblurring ConvMAE | U-Net++ (ConViT-B) | US-Esaote-280K | 74.96   |

Installation

  • Clone this repo:
git clone https://github.com/MembrAI/DeblurringMAE
cd DeblurringMAE
  • Create a conda environment and activate it:
conda create -n deblurringmae python=3.7
conda activate deblurringmae
  • Install PyTorch==1.8.0 and torchvision==0.9.0 with CUDA==11.1:
conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=11.1 -c pytorch -c conda-forge
  • Install timm==0.3.2
pip install timm==0.3.2

Pretraining

Data preparation

Prepare the original dataset following this format:

dataset_orig
  ├── train
  │   ├── class1
  │   │   ├── img1.png
  │   │   ├── img2.png
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.png
  │   │   └── ...
  │   └── ...
  └── val
      ├── class1
      │   ├── img4.png
      │   ├── img5.png
      │   └── ...
      ├── class2
      │   ├── img6.png
      │   └── ...
      └── ...

For deblurring pretraining, you also need to apply an image blurring operation to the original dataset to prepare a blurred dataset:

python blurred_images.py --src_dir /path/to/dataset_orig/ --dst_dir /path/to/dataset_blurred/ \
     --method gaussian --sigma 1.1
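
For reference, the blurring step amounts to the following. This is a minimal sketch assuming blurred_images.py applies a per-image Gaussian filter with OpenCV; consult the script itself for the exact options:

```python
from pathlib import Path

import cv2  # pip install opencv-python

def blur_dataset(src_dir: str, dst_dir: str, sigma: float = 1.1) -> None:
    """Mirror src_dir into dst_dir, Gaussian-blurring every PNG image."""
    src, dst = Path(src_dir), Path(dst_dir)
    for img_path in src.rglob("*.png"):
        img = cv2.imread(str(img_path))
        # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
        blurred = cv2.GaussianBlur(img, (0, 0), sigmaX=sigma)
        out_path = dst / img_path.relative_to(src)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        cv2.imwrite(str(out_path), blurred)

blur_dataset("/path/to/dataset_orig/", "/path/to/dataset_blurred/", sigma=1.1)
```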

Running Pretraining Scripts

To pretrain the deblurring MAE, run the following on 1 node with 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 main_pretrain.py \
--model dmae_vit_base_patch16 --output /path/to/saved/weights/ \
--data_path_blurred /path/to/dataset_blurred/ \
--data_path_orig /path/to/dataset_orig/ --batch_size 32

To pretrain the deblurring ConvMAE, run the following on 1 node with 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 main_pretrain.py \
--model dconvmae_convvit_base_patch16 --output /path/to/saved/weights/ \
--data_path_blurred /path/to/dataset_blurred/ \
--data_path_orig /path/to/dataset_orig/ --batch_size 32

To pretrain the vanilla MAE, run the following on 1 node with 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 main_pretrain.py \
--model mae_vit_base_patch16 --output /path/to/saved/weights/ \
--data_path_blurred /path/to/dataset_orig/ \
--data_path_orig /path/to/dataset_orig/ --batch_size 32

To pretrain the vanilla ConvMAE, run the following on 1 node with 8 GPUs:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 main_pretrain.py \
--model convmae_convvit_base_patch16 --output /path/to/saved/weights/ \
--data_path_blurred /path/to/dataset_orig/ \
--data_path_orig /path/to/dataset_orig/ --batch_size 32
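
In each command, --data_path_blurred supplies the encoder input and --data_path_orig supplies the reconstruction target; for the vanilla models both point to the original data, while the deblurring models learn to reconstruct sharp images from masked blurred ones. The sketch below is conceptual, assuming a MAE-style forward pass; the function names and signatures are illustrative, not the repo's exact API:

```python
def deblurring_mim_loss(model, imgs_blurred, imgs_orig, mask_ratio=0.75):
    """Reconstruct sharp patches from masked blurred input (conceptual sketch)."""
    # Hypothetical MAE-style interface: per-patch predictions plus the
    # binary mask marking which patches were removed from the input.
    pred, mask = model(imgs_blurred, mask_ratio=mask_ratio)
    target = model.patchify(imgs_orig)          # sharp originals as targets
    loss = ((pred - target) ** 2).mean(dim=-1)  # per-patch MSE
    return (loss * mask).sum() / mask.sum()     # average over masked patches only
```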

Fine-tuning for classification

Data preparation

Prepare the dataset for classification following this format:

dataset
  ├── train
  │   ├── class1
  │   │   ├── img1.png
  │   │   ├── img2.png
  │   │   └── ...
  │   ├── class2
  │   │   ├── img3.png
  │   │   └── ...
  │   └── ...
  ├── val
  │   ├── class1
  │   │   ├── img4.png
  │   │   ├── img5.png
  │   │   └── ...
  │   ├── class2
  │   │   ├── img6.png
  │   │   └── ...
  │   └── ...
  └── test
      ├── class1
      │   ├── img7.png
      │   ├── img8.png
      │   └── ...
      ├── class2
      │   ├── img9.png
      │   └── ...
      └── ...

Note that for fine-tuning the deblurring MIM approaches, you also need to apply the image blurring operation to the original images.

Training for classification

To fine-tune the deblurring or vanilla MAE for classification, run the following on a single GPU:

python main_finetune.py --seed 0 \
    --data_path  /path/to/dataset/  \
    --output_dir /path/to/saved/weights/ \
    --model vit_base_patch16 --finetune ${PRETRAIN_CHKPT} \
    --blr 1e-4 --batch_size 256

You can change the --model parameter to convvit_base_patch16 for the deblurring or vanilla ConvMAE. Note that for the deblurring models, you should use the blurred images as the input dataset.
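
For orientation, MAE-style fine-tuning scripts typically load the pretraining checkpoint into the backbone and discard the pretraining head. A minimal sketch with timm follows; the checkpoint filename is hypothetical, and main_finetune.py may handle the keys differently:

```python
import timm
import torch

# Build the classifier backbone; num_classes matches the downstream task.
model = timm.create_model("vit_base_patch16_224", num_classes=2)

# Load the pretraining checkpoint (filename here is hypothetical).
ckpt = torch.load("pretrained_deblurring_mae.pth", map_location="cpu")
state = ckpt.get("model", ckpt)  # MAE-style checkpoints nest weights under "model"

# Drop the pretraining head so only the encoder weights transfer.
for k in ("head.weight", "head.bias"):
    state.pop(k, None)

missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing, "unexpected:", unexpected)
```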

Evaluation for classification

To evaluate the fine-tuned deblurring or vanilla MAE for classification, run the following on a single GPU:

python main_finetune.py --batch_size 256  \
--model vit_base_patch16 \
--data_path /path/to/dataset/ --nb_classes 2 \
--output_dir  /path/to/save/results/ \
--resume ${FINETUNE_CHKPT} --eval 

Fine-tuning for segmentation

Data preparation

Prepare the segmentation dataset in this format:

dataset
  ├── images_gaussian
  │   ├── train
  │   │   ├── 0000.png
  │   │   ├── 0001.png
  │   │   └── ...
  │   ├── val
  │   │   ├── 0002.png
  │   │   ├── 0003.png
  │   │   └── ...
  │   └── test
  │       ├── 0004.png
  │       ├── 0005.png
  │       └── ...
  └── masks
      ├── train
      │   ├── 0000.png
      │   ├── 0001.png
      │   └── ...
      ├── val
      │   ├── 0002.png
      │   ├── 0003.png
      │   └── ...
      └── test
          ├── 0004.png
          ├── 0005.png
          └── ...

Note that the images_gaussian folder contains Gaussian-blurred images. Masks use 255 for foreground pixels.
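
Since foreground is stored as 255, a data loader typically binarizes the masks before computing a binary segmentation loss. A short sketch of that assumed convention (check the SEG_UNET loader for the actual handling):

```python
import numpy as np
from PIL import Image

# Foreground pixels are stored as 255; binarize to {0, 1} for a binary loss.
mask = np.array(Image.open("dataset/masks/train/0000.png").convert("L"))
binary = (mask == 255).astype(np.float32)  # 1.0 = foreground, 0.0 = background
```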

Training for segmentation

Download the pretrained deblurring ConvMAE model here.

Run SEG_UNET/train_smp.py

python SEG_UNET/train_smp.py --encoder_weights /path/to/pretrained/weights/ --datapath /path/to/dataset/ --output_dir /path/to/save/results/
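
SEG_UNET/train_smp.py builds the U-Net++ model with segmentation_models.pytorch. The sketch below shows the same API with a stock encoder, since the ConViT-B encoder wiring is custom to this repo; it is an illustration, not the repo's exact configuration:

```python
import torch
import segmentation_models_pytorch as smp

# U-Net++ with a stock encoder as a stand-in for the repo's ConViT-B encoder.
model = smp.UnetPlusPlus(
    encoder_name="resnet34",   # the repo swaps in ConViT-B here
    encoder_weights=None,      # the repo loads deblurring ConvMAE weights instead
    in_channels=3,
    classes=1,                 # single foreground class (nodule vs. background)
)
loss_fn = smp.losses.DiceLoss(mode="binary")

logits = model(torch.randn(1, 3, 256, 256))  # input sides must be divisible by 32
print(logits.shape)  # torch.Size([1, 1, 256, 256])
```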

Evaluation for segmentation

We provide the checkpoint fine-tuned on the TN3K dataset here. Run:

python SEG_UNET/predict.py --weight_path /path/to/seg/checkpoint/ --save_dir /path/to/save/predictions/ --datapath /path/to/dataset/

This should reproduce the segmentation results reported in the fine-tuning results table above.

Acknowledgement

The pretraining and fine-tuning code of our project is based on MAE and ConvMAE. The segmentation part is based on segmentation_models.pytorch. Thanks for their wonderful work.

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Citation

@article{kang2023deblurring,
  title={Deblurring Masked Autoencoder is Better Recipe for Ultrasound Image Recognition},
  author={Kang, Qingbo and Gao, Jun and Li, Kang and Lao, Qicheng},
  journal={arXiv preprint arXiv:2306.08249},
  year={2023}
}
