Commit a893b4b (parent 4401498): 48 changed files with 5,674 additions and 2 deletions.
README.md:
# Medical-Transformer-pytorch

<a href="https://arxiv.org/abs/2006.04878"> Paper (Conference) </a>

Official PyTorch Code for the paper [""](https://arxiv.org/)

Journal Extension: ["Medical Transformer: Gated Axial-Attention for Medical Image Segmentation"](https://arxiv.org/)
# About this repo:

This repo hosts the code for the following networks:

1) Gated Axial Attention U-Net
2) MedT
# Introduction

The majority of existing Transformer-based network architectures proposed for vision applications require large-scale datasets to train properly. However, compared to the datasets for general vision applications, the number of data samples in medical imaging is relatively low, making it difficult to train transformers efficiently for medical applications. To this end, we propose a Gated Axial-Attention model which extends the existing architectures by introducing an additional control mechanism in the self-attention module. Furthermore, to train the model effectively on medical images, we propose a Local-Global training strategy (LoGo) which further improves the performance. Specifically, we operate on the whole image and on patches to learn global and local features, respectively. The proposed Medical Transformer (MedT) uses the LoGo training strategy on the Gated Axial Attention U-Net.
<p align="center">
  <img src="img/arch.png" width="800"/>
</p>
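To make the gating idea concrete, below is a minimal, illustrative PyTorch sketch of self-attention along a single axis with learnable scalar gates on the relative-positional terms. The module name, shapes, and initialization are assumptions made for illustration, not the exact implementation in this repo.

```python
import torch
import torch.nn as nn

class GatedAxialAttention(nn.Module):
    """Self-attention along one axis (a row or column of a feature map), with
    learnable scalar gates on the relative-positional terms. On small datasets
    the gates let the model learn how much to trust positional information."""

    def __init__(self, dim, heads=8, span=128):
        super().__init__()
        assert dim % heads == 0
        self.heads, self.dh, self.span = heads, dim // heads, span
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        # Relative positional embeddings for queries, keys and values.
        self.rel = nn.Parameter(torch.randn(3, self.dh, 2 * span - 1) * 0.02)
        # Gates start at zero, so positional terms are initially muted.
        self.g_q = nn.Parameter(torch.zeros(1))
        self.g_k = nn.Parameter(torch.zeros(1))
        self.g_v = nn.Parameter(torch.zeros(1))
        # Lookup table mapping each (i, j) pair to a relative offset index.
        idx = torch.arange(span)
        self.register_buffer("rel_idx", idx[None, :] - idx[:, None] + span - 1)

    def forward(self, x):
        # x: (batch, length, dim) -- one axis of the feature map at a time.
        b, n, _ = x.shape
        assert n <= self.span
        q, k, v = (t.view(b, n, self.heads, self.dh).transpose(1, 2)
                   for t in self.qkv(x).chunk(3, dim=-1))
        r_q, r_k, r_v = (self.rel[i][:, self.rel_idx[:n, :n]] for i in range(3))
        logits = torch.einsum("bhid,bhjd->bhij", q, k) * self.dh ** -0.5
        logits = logits + self.g_q * torch.einsum("bhid,dij->bhij", q, r_q)
        logits = logits + self.g_k * torch.einsum("bhjd,dij->bhij", k, r_k)
        attn = logits.softmax(dim=-1)
        out = torch.einsum("bhij,bhjd->bhid", attn, v)
        out = out + self.g_v * torch.einsum("bhij,dij->bhid", attn, r_v)
        return out.transpose(1, 2).reshape(b, n, -1)
```

In a full model, one such block is applied along the height axis and another along the width axis, reducing the quadratic cost of 2D self-attention to two 1D passes.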
### Using the code:

- Clone this repository:
```bash
git clone https://github.com/jeya-maria-jose/Medical-Transformer
cd Medical-Transformer
```
The code is tested with Python 3.6.10 and PyTorch 1.4.0.
To install all the dependencies using conda:

```bash
conda env create -f environment.yml
conda activate medt
```
### Links for downloading the public datasets:

1) GLAS Dataset - <a href="https://warwick.ac.uk/fac/sci/dcs/research/tia/glascontest/"> Link (Original) </a> | <a href="https://drive.google.com/drive/folders/1dwhjqE0vC0KL_siGUeqMUq08KyO1bPKH?usp=sharing"> Link (Resized) </a>
2) MoNuSeG Dataset - <a href="https://monuseg.grand-challenge.org/Data/"> Link (Original) </a>
3) The Brain Anatomy US dataset introduced in the paper will be made public soon!
# Using the Code for your dataset

### Dataset Preparation
Prepare the dataset in the following format for easy use of the code. The train and test folders should each contain two subfolders: img and label. Make sure the images and their corresponding segmentation masks are placed under these folders with identical filenames so they can be matched, as in the layout below. If you prefer not to prepare the dataset in this format, change the data loaders to suit your needs.
```bash
Train Folder-----
      img----
            0001.png
            0002.png
            .......
      label---
            0001.png
            0002.png
            .......
Validation Folder-----
      img----
            0001.png
            0002.png
            .......
      label---
            0001.png
            0002.png
            .......
Test Folder-----
      img----
            0001.png
            0002.png
            .......
      label---
            0001.png
            0002.png
            .......
```
- The ground truth images should have pixel values corresponding to the class labels. For example, in binary segmentation the pixels in the GT should be either 0 or 255.
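As a rough illustration of how such a layout can be consumed, here is a minimal PyTorch `Dataset` sketch that pairs each image with its same-named mask and maps 0/255 masks to class indices. The class name, resizing, and normalization are assumptions; the repo's own data loaders may differ.

```python
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

class PairedSegDataset(Dataset):
    """Loads (image, mask) pairs from the img/ and label/ subfolders above,
    matching files by name. Resizing and normalization are illustrative."""

    def __init__(self, root, imgsize=128):
        self.img_dir = os.path.join(root, "img")
        self.label_dir = os.path.join(root, "label")
        self.names = sorted(os.listdir(self.img_dir))
        self.imgsize = imgsize

    def __len__(self):
        return len(self.names)

    def __getitem__(self, i):
        name = self.names[i]  # the mask shares the image's filename
        img = Image.open(os.path.join(self.img_dir, name)).convert("RGB")
        mask = Image.open(os.path.join(self.label_dir, name)).convert("L")
        img = img.resize((self.imgsize, self.imgsize), Image.BILINEAR)
        mask = mask.resize((self.imgsize, self.imgsize), Image.NEAREST)
        x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0
        y = (torch.from_numpy(np.array(mask)) > 127).long()  # 0/255 -> 0/1
        return x, y
```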
### Training Command:
```bash
python train.py --train_dataset "enter train directory" --val_dataset "enter validation directory" --direc 'path for results to be saved' --batch_size 4 --epoch 400 --save_freq 10 --modelname "gatedaxialunet" --learning_rate 0.001 --imgsize 128 --gray "no"
```
Change the `--modelname` argument to `medt` or `logo` to train those models.
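For intuition about what `logo` trains, here is a minimal sketch of the Local-Global idea: a global branch sees the whole image, a local branch sees non-overlapping patches, and their logit maps are fused. The wrapper name, patch size, and fusion by summation are assumptions for illustration, not this repo's exact code.

```python
import torch
import torch.nn as nn

class LoGoWrapper(nn.Module):
    """Illustrative Local-Global (LoGo) forward pass. Both branches are
    assumed to return per-pixel logits at their input resolution."""

    def __init__(self, global_net: nn.Module, local_net: nn.Module, patch: int = 32):
        super().__init__()
        self.global_net, self.local_net, self.patch = global_net, local_net, patch

    def forward(self, x):
        b, c, h, w = x.shape
        p = self.patch
        assert h % p == 0 and w % p == 0, "image size must be divisible by patch size"
        out_g = self.global_net(x)                   # (b, k, h, w) global logits
        # Cut the image into (h/p)*(w/p) patches and batch them together.
        patches = x.unfold(2, p, p).unfold(3, p, p)  # (b, c, h/p, w/p, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(-1, c, p, p)
        out_l = self.local_net(patches)              # (b * h/p * w/p, k, p, p)
        k = out_l.shape[1]
        # Re-tile the per-patch logits back into a full-resolution map.
        out_l = out_l.reshape(b, h // p, w // p, k, p, p)
        out_l = out_l.permute(0, 3, 1, 4, 2, 5).reshape(b, k, h, w)
        return out_g + out_l                         # fuse global and local
```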
### Testing Command:
```bash
python test.py --loaddirec "./saved_model_path/model_name.pth" --val_dataset "test dataset directory" --direc 'path for results to be saved' --batch_size 1 --modelname "gatedaxialunet" --imgsize 128 --gray "no"
```
The results, including the predicted segmentation maps, will be placed in the results folder along with the model weights. Run the performance metrics code in MATLAB to calculate the F1 score and mIoU.
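If MATLAB is not at hand, the same metrics are straightforward to compute in Python. A minimal sketch for the binary case (function name assumed), given predicted and ground-truth masks thresholded to booleans:

```python
import numpy as np

def f1_and_iou(pred: np.ndarray, gt: np.ndarray):
    """pred, gt: boolean arrays of the same shape (True = foreground)."""
    tp = np.logical_and(pred, gt).sum()
    fp = np.logical_and(pred, ~gt).sum()
    fn = np.logical_and(~pred, gt).sum()
    f1 = 2 * tp / (2 * tp + fp + fn + 1e-8)   # F1 equals the Dice score
    iou = tp / (tp + fp + fn + 1e-8)
    return f1, iou

# Example: load the 0/255 PNGs and threshold at 127 before calling f1_and_iou.
```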
### Acknowledgement:

The dataloader code is inspired from <a href="https://github.com/cosmic-cortex/pytorch-UNet"> pytorch-UNet </a>. The axial attention code is developed from <a href="https://github.com/csrhddlam/axial-deeplab"> axial-deeplab </a>.
# Citation:

```bibtex
@inproceedings{valanarasu2020kiu,
  title={Medical Transformer: Gated Axial-Attention for Medical Image Segmentation},
  author={Valanarasu, Jeya Maria Jose and Sindagi, Vishwanath A and Hacihaliloglu, Ilker and Patel, Vishal M},
  booktitle={Medical Image Computing and Computer Assisted Intervention--MICCAI 2020: 23rd International Conference, Lima, Peru, October 4--8, 2020, Proceedings, Part IV 23},
  pages={363--373},
  year={2020},
  organization={Springer}
}
```
Open an issue in case of any queries or mail me directly.
Training script (new file in this commit):
```bash
python train.py --train_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/train/" --val_dataset "/media/jeyamariajose/7888230b-5c10-4229-90f2-c78bdae9c5de/Data/Brain_Ultrasound/Final/resized/test/" --direc "./results/axial128_en/" --batch_size 4 --modelname "logo" --epoch 401 --save_freq 50 --learning_rate 0.0001 --imgsize 128
```
environment.yml (new file):
```yaml
name: medt
channels:
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - argon2-cffi=20.1.0=py36h8c4c3a4_1
  - attrs=20.1.0=pyh9f0ad1d_0
  - backcall=0.2.0=pyh9f0ad1d_0
  - backports=1.0=py_2
  - backports.functools_lru_cache=1.6.1=py_0
  - blas=1.0=mkl
  - bleach=3.1.5=pyh9f0ad1d_0
  - brotlipy=0.7.0=py36h8c4c3a4_1000
  - ca-certificates=2020.6.20=hecda079_0
  - certifi=2020.6.20=py36h9f0ad1d_0
  - cffi=1.11.5=py36_0
  - chardet=3.0.4=py36h9f0ad1d_1006
  - cryptography=3.1=py36h45558ae_0
  - decorator=4.4.2=py_0
  - defusedxml=0.6.0=py_0
  - entrypoints=0.3=py36h9f0ad1d_1001
  - idna=2.10=pyh9f0ad1d_0
  - importlib-metadata=1.7.0=py36h9f0ad1d_0
  - importlib_metadata=1.7.0=0
  - intel-openmp=2020.1=217
  - ipykernel=5.3.4=py36h95af2a2_0
  - ipython=7.16.1=py36h95af2a2_0
  - ipython_genutils=0.2.0=py_1
  - ipywidgets=7.5.1=py_0
  - jedi=0.17.2=py36h9f0ad1d_0
  - jinja2=2.11.2=pyh9f0ad1d_0
  - json5=0.9.4=pyh9f0ad1d_0
  - jsonschema=3.2.0=py36h9f0ad1d_1
  - jupyter_client=6.1.7=py_0
  - jupyter_core=4.6.3=py36h9f0ad1d_1
  - jupyterlab=2.2.6=py_0
  - jupyterlab_server=1.2.0=py_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20191231=h7b6447c_0
  - libffi=3.3=he6710b0_1
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libsodium=1.0.18=h516909a_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - markupsafe=1.1.1=py36h8c4c3a4_1
  - mistune=0.8.4=py36h8c4c3a4_1001
  - mkl=2020.1=217
  - mkl-service=2.3.0=py36he904b0f_0
  - mkl_fft=1.1.0=py36h23d657b_0
  - mkl_random=1.1.1=py36h0573a6f_0
  - nbconvert=5.6.1=py36h9f0ad1d_1
  - nbformat=5.0.7=py_0
  - ncurses=6.2=he6710b0_1
  - notebook=6.1.3=py36h9f0ad1d_0
  - numpy=1.18.5=py36ha1c710e_0
  - numpy-base=1.18.5=py36hde5b4d6_0
  - openssl=1.1.1g=h516909a_1
  - packaging=20.4=pyh9f0ad1d_0
  - pandoc=2.10.1=h516909a_0
  - pandocfilters=1.4.2=py_1
  - parso=0.7.1=pyh9f0ad1d_0
  - pexpect=4.8.0=py36h9f0ad1d_1
  - pickleshare=0.7.5=py36h9f0ad1d_1001
  - pip=20.1.1=py36_1
  - prometheus_client=0.8.0=pyh9f0ad1d_0
  - prompt-toolkit=3.0.7=py_0
  - ptyprocess=0.6.0=py_1001
  - pycparser=2.20=pyh9f0ad1d_2
  - pygments=2.6.1=py_0
  - pyopenssl=19.1.0=py_1
  - pyparsing=2.4.7=pyh9f0ad1d_0
  - pyrsistent=0.16.0=py36h8c4c3a4_0
  - pysocks=1.7.1=py36h9f0ad1d_1
  - python=3.6.10=h7579374_2
  - python-dateutil=2.8.1=py_0
  - python_abi=3.6=1_cp36m
  - pyzmq=19.0.2=py36h9947dbf_0
  - readline=8.0=h7b6447c_0
  - requests=2.24.0=pyh9f0ad1d_0
  - send2trash=1.5.0=py_0
  - setuptools=47.3.1=py36_0
  - six=1.15.0=py_0
  - sqlite=3.32.3=h62c20be_0
  - terminado=0.8.3=py36h9f0ad1d_1
  - testpath=0.4.4=py_0
  - tk=8.6.10=hbc83047_0
  - tornado=6.0.4=py36h8c4c3a4_1
  - traitlets=4.3.3=py36h9f0ad1d_1
  - urllib3=1.25.10=py_0
  - wcwidth=0.2.5=pyh9f0ad1d_1
  - webencodings=0.5.1=py_1
  - wheel=0.34.2=py36_0
  - widgetsnbextension=3.5.1=py36h9f0ad1d_1
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - zeromq=4.3.2=he1b5a44_3
  - zipp=3.1.0=py_0
  - zlib=1.2.11=h7b6447c_3
  - pip:
    - ci-info==0.2.0
    - click==7.1.2
    - cython==0.29.20
    - et-xmlfile==1.0.1
    - etelemetry==0.2.1
    - filelock==3.0.12
    - isodate==0.6.0
    - jdcal==1.4.1
    - joblib==0.17.0
    - lxml==4.5.1
    - matplotlib==3.3.2
    - medpy==0.4.0
    - natsort==7.0.1
    - nibabel==3.1.0
    - nipype==1.5.0
    - openpyxl==3.0.4
    - prov==1.5.3
    - pydicom==2.0.0
    - pydot==1.4.1
    - pydotplus==2.0.2
    - pynrrd==0.4.2
    - rdflib==5.0.0
    - scikit-learn==0.23.2
    - scipy==1.5.3
    - setproctitle==1.1.10
    - simplejson==3.17.0
    - threadpoolctl==2.1.0
    - torch==1.4.0
    - torch-dwconv==0.1.0
    - torchvision==0.4.0
    - traits==6.1.0
prefix: /home/jeyamariajose/anaconda3/envs/medt
```