
Hippocampus Segmentation from MRI using 3D Convolutional Neural Networks in PyTorch

MLMS-CG/HippocampusSegmentationMRI

Hippocampus Segmentation from MRI using V-Net

In this repo, the hippocampus is segmented from MRI using a Convolutional Neural Network (CNN) based on the V-Net architecture. The dataset is publicly available from the Medical Segmentation Decathlon challenge and can be downloaded from the challenge website.

The model architecture, training and validation are written with PyTorch. SimpleITK handles the I/O of medical images, and 3D data augmentation is performed with torchio.

A 5-fold cross-validation was performed on the training set. Without data augmentation, it yielded a mean multi-class Dice coefficient of 0.8727 +/- 0.0364, a Dice coefficient of 0.8821 +/- 0.0363 for the anterior hippocampus and a Dice coefficient of 0.8634 +/- 0.0415 for the posterior hippocampus. Results are reported as "mean +/- std" over cases.
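In a k-fold cross-validation the training cases are partitioned into k disjoint folds; each fold is used once for validation while the model is trained on the remaining k - 1. A minimal sketch of such a split in plain Python follows. The shuffling seed and the case identifiers are illustrative assumptions, not the repo's actual splitting code.

```python
import random


def k_fold_splits(case_ids, k=5, seed=42):
    """Yield (train_ids, val_ids) pairs for k-fold cross-validation.

    Cases are shuffled once with a fixed seed (an assumption, for
    reproducibility), then dealt round-robin into k folds whose sizes
    differ by at most one. Each fold is the validation set exactly once.
    """
    ids = list(case_ids)
    random.Random(seed).shuffle(ids)
    folds = [ids[i::k] for i in range(k)]
    for i in range(k):
        val_ids = folds[i]
        # Training set = every case that is not in the current validation fold.
        train_ids = [c for j, fold in enumerate(folds) if j != i for c in fold]
        yield train_ids, val_ids


if __name__ == "__main__":
    # Hypothetical case identifiers; real ones would come from the dataset folder.
    cases = [f"hippocampus_{n:03d}" for n in range(10)]
    for fold, (train, val) in enumerate(k_fold_splits(cases)):
        print(fold, len(train), len(val))
```

Fixing the seed means every fold assignment can be reproduced, which matters when the per-fold results are later pooled into a single "mean +/- std".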

The meshes and images in the images folder were rendered with ITK-SNAP.

Quality Measures

Results
| Model | Mean Dice per case | Dice per case (Anterior) | Dice per case (Posterior) |
|---|---|---|---|
| 3D V-Net (no data augmentation) | 0.8727 +/- 0.0364 | 0.8821 +/- 0.0363 | 0.8634 +/- 0.0415 |
| 3D V-Net (with data augmentation) | 0.8761 +/- 0.0374 | 0.8875 +/- 0.0354 | 0.8647 +/- 0.0455 |
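Each table entry is a per-case Dice coefficient, 2|P ∩ G| / (|P| + |G|), aggregated over cases. The "Mean" column is consistent with averaging the two class-wise scores per case: (0.8821 + 0.8634) / 2 ≈ 0.8727 and (0.8875 + 0.8647) / 2 = 0.8761. The NumPy sketch below assumes the Medical Segmentation Decathlon label convention (0 = background, 1 = anterior, 2 = posterior) and NumPy's default population standard deviation; the repo's own metric code may differ in either respect.

```python
import numpy as np


def dice_per_class(pred, gt, labels=(1, 2)):
    """Dice coefficient 2|P ∩ G| / (|P| + |G|) for each foreground label."""
    scores = {}
    for label in labels:
        p = pred == label
        g = gt == label
        denom = p.sum() + g.sum()
        # Convention (an assumption): a label absent from both volumes scores 1.0.
        scores[label] = 1.0 if denom == 0 else 2.0 * np.logical_and(p, g).sum() / denom
    return scores


def summarize(per_case_scores):
    """Mean and std over cases of each class-wise Dice and of their per-case mean.

    `per_case_scores` is a list of dicts as returned by dice_per_class,
    assumed to contain labels 1 (anterior) and 2 (posterior).
    """
    per_case = np.array([[d[1], d[2]] for d in per_case_scores])  # shape (n_cases, 2)
    multi = per_case.mean(axis=1)  # mean of the two class-wise Dice scores, per case
    return {
        "mean": (multi.mean(), multi.std()),
        "anterior": (per_case[:, 0].mean(), per_case[:, 0].std()),
        "posterior": (per_case[:, 1].mean(), per_case[:, 1].std()),
    }
```

For example, with ground truth `[[0, 1, 1], [2, 2, 0]]` and prediction `[[0, 1, 0], [2, 2, 2]]`, `dice_per_class` gives 2/3 for label 1 and 0.8 for label 2.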

Confusion Matrices

No Data Augmentation

[Figures: Confusion Matrix (Cross-validation) | Normalized Confusion Matrix (Cross-validation)]

With Data Augmentation

[Figures: Confusion Matrix (Cross-validation) | Normalized Confusion Matrix (Cross-validation)]
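A voxel-wise confusion matrix counts how many voxels fall into each (true class, predicted class) pair; a normalized version divides each row by its total, so the diagonal reads as per-class recall. The NumPy sketch below assumes three classes (background, anterior, posterior) and row normalization; the figures above may have been produced with a different normalization or library.

```python
import numpy as np


def confusion_matrix(gt, pred, n_classes=3):
    """Voxel-wise confusion matrix: rows = ground truth, columns = prediction."""
    # Encode each (true, predicted) pair as a single index, then count occurrences.
    idx = gt.ravel().astype(np.int64) * n_classes + pred.ravel().astype(np.int64)
    counts = np.bincount(idx, minlength=n_classes * n_classes)
    return counts.reshape(n_classes, n_classes)


def normalize_rows(cm):
    """Divide each row by its total; rows with no voxels are left as zeros."""
    totals = cm.sum(axis=1, keepdims=True)
    return np.divide(cm, totals, out=np.zeros(cm.shape, dtype=float), where=totals > 0)
```

Summing the per-case matrices over all validation folds before normalizing yields a single cross-validation matrix, which is presumably what the "(Cross-validation)" figures show.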

TODO List

  • Automatic Download of dataset
  • CNN Architecture Definition
  • 3D Data Loader for NIfTI files
  • Definition of loss functions
  • Training loop
  • Cross-validation on Train set
  • Command Line Interface for training
  • Command Line Interface for validation
  • 3D Data Augmentation
  • Tuning of Optimal Parameters for 3D Data Augmentation
  • Validation on Test set

Usage

Install the package with poetry install. A complete run (dataset download, training, validation) looks like this:

git clone https://github.com/Nicolik/HippocampusSegmentationMRI.git
cd HippocampusSegmentationMRI
poetry install
python run/download.py
python run/train.py 
python run/validate.py

Dataset

If you want to download the original dataset, run run/download.py. The syntax is as follows:

python run/download.py --dir=path/to/dataset/dir

Training

To train the model, run run/train.py. The syntax is as follows:

python run/train.py --epochs=NUM_EPOCHS --batch=BATCH_SIZE --workers=NUM_WORKERS --lr=LR
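The `--flag=VALUE` options above map naturally onto a standard argparse parser. The sketch below is a hypothetical reconstruction for illustration, not the parser actually defined in run/train.py: the default values are placeholders, and the real script may take its defaults from SemSegMRIConfig instead.

```python
import argparse


def build_train_parser():
    """Hypothetical parser mirroring the run/train.py flags; defaults are placeholders."""
    parser = argparse.ArgumentParser(description="Train a 3D V-Net for hippocampus segmentation")
    parser.add_argument("--epochs", type=int, default=100, help="number of training epochs")
    parser.add_argument("--batch", type=int, default=8, help="batch size")
    parser.add_argument("--workers", type=int, default=4, help="number of DataLoader worker processes")
    parser.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    return parser


if __name__ == "__main__":
    args = build_train_parser().parse_args()
    print(args.epochs, args.batch, args.workers, args.lr)
```

argparse accepts both `--epochs=50` and `--epochs 50`, and converts each value with the given `type`.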

Further settings can be changed in config/config.py, in particular in the SemSegMRIConfig class. Data augmentation (built with torchio) is configured in config/augm.py.
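Whatever transforms are configured in config/augm.py, spatial ones must move the image and its label map identically, while intensity ones must touch only the image. torchio handles this bookkeeping itself (volumes wrapped as a LabelMap are excluded from intensity transforms). The plain NumPy sketch below only illustrates the principle and is not the repo's augmentation pipeline. The flip axis is a parameter because which array axis is left-right depends on the volume orientation, and flipping along the anterior-posterior direction would produce anatomically implausible samples for this task.

```python
import numpy as np


def random_flip(image, label, axis, p=0.5, rng=None):
    """Flip image and label together along `axis` with probability p.

    Applying the same spatial transform to both volumes keeps every label
    voxel aligned with its image voxel.
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < p:
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis).copy()
    return image, label


def random_intensity_scale(image, low=0.9, high=1.1, rng=None):
    """Multiply intensities by one random factor; the label map is not involved."""
    rng = np.random.default_rng() if rng is None else rng
    return image * rng.uniform(low, high)
```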

Validation

To run the cross-validation, use run/validate.py or run/validate_torchio.py. The syntax is as follows:

python run/validate.py --dir=path/to/logs/dir --write=WRITE --verbose=VERBOSE
python run/validate_torchio.py --dir=path/to/logs/dir --verbose=VERBOSE

The former uses a validation loop written from scratch, whereas the latter relies on a DataLoader built on top of torchio.

Output Results

Sample Images (Training set)

Ground Truth Images

Ground Truth - MRI 327 (1) Ground Truth - MRI 327 (2)

Predicted Images

Prediction - MRI 327 (1) Prediction - MRI 327 (2)

Sample Images (Test set)

Predicted Images

Prediction - MRI 283 (1) Prediction - MRI 283 (2)
