[IJCAI'18] Unsupervised Cross-Modality Domain Adaptation of ConvNets for Biomedical Image Segmentations with Adversarial Loss (code&data)

AArchLichKing/Medical-Cross-Modality-Domain-Adaptation

 
 

Medical Cross-Modality Domain Adaptation (Med-CMDA)

This repository contains implementations for the following papers:

PnP-AdaNet: Plug-and-Play Adversarial Domain Adaptation Network with a Benchmark at Cross-modality Cardiac Segmentation. (https://arxiv.org/abs/1812.07907) (long version)

Unsupervised Cross-Modality Domain Adaptation of ConvNets for Biomedical Image Segmentations with Adversarial Loss, IJCAI, pp. 691-697, 2018. (https://arxiv.org/abs/1804.10916) (short version)

Introduction

Deep convolutional networks have demonstrated state-of-the-art performance on various medical image computing tasks. However, the generalization capability of deep models on test data with different distributions remains a major challenge. In this project, we tackle the problem of unsupervised domain adaptation between CT and MRI by proposing a plug-and-play adversarial domain adaptation network that aligns the feature spaces of the two domains, which present a significant domain shift.

Usage

0. Packages

nibabel==2.1.0
nilearn==0.3.1
numpy==1.13.3
tensorflow-gpu==1.4.0 
python 2.7

(Note: other TensorFlow versions have not been tested; please let us know if they also work.)

1. Data preprocessing

The original data, 20 cardiac CT and 20 cardiac MR images, come from the MMWHS Challenge. The original data release license also applies to this project.

The pre-processed and augmented training data can be downloaded here, as tfrecord files ready for direct loading. The testing CT data can be downloaded here, as .nii files with the heart region cropped.
The same data are also used in our SIFA paper.

Briefly, the images were processed as follows: (1) each image was cropped around the heart region, with four cardiac substructures selected for segmentation based on their mutual visibility in 2D view; (2) for each cropped 3D image, the top 2% of its intensity histogram was cut off to alleviate artifacts; (3) each 3D image was then normalized to zero mean and unit standard deviation; (4) 2D coronal slices were sampled with data augmentation.
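Steps 2 and 3 above can be sketched in NumPy as follows. This is a minimal illustration rather than the repository's preprocessing code: the function name clip_and_normalize is hypothetical, and it assumes that "cutting off" the top 2% means saturating those voxels at the 98th-percentile intensity.

```python
import numpy as np

def clip_and_normalize(volume, top_percent=2.0):
    """Clip the brightest `top_percent` of voxels, then z-score the volume."""
    volume = np.asarray(volume, dtype=np.float32)
    # Intensity below which (100 - top_percent)% of the voxels fall.
    upper = np.percentile(volume, 100.0 - top_percent)
    # Assumption: values above the threshold are saturated, not discarded.
    clipped = np.minimum(volume, upper)
    # Normalize to zero mean and unit standard deviation over the whole volume.
    return (clipped - clipped.mean()) / clipped.std()

# Toy volume with one simulated bright artifact.
rng = np.random.RandomState(0)
toy_volume = rng.normal(loc=100.0, scale=20.0, size=(32, 32, 16))
toy_volume[0, 0, 0] = 5000.0
normalized = clip_and_normalize(toy_volume)
```

Clipping before normalization keeps a few extreme voxels from dominating the mean and standard deviation.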

To adapt a segmenter from MR to CT, use:
ct_train_tfs: training slices from 14 cases, 600 slices each, 8400 slices in total.
ct_val_tfs: validation slices from 2 cases, 600 slices each, 1200 slices in total.
mr_train_tfs: training slices from 16 cases, 600 slices each, 9600 slices in total.
mr_val_tfs: validation slices from 4 cases, 600 slices each, 2400 slices in total.
Since we are doing MR-to-CT adaptation, we do not need a real MR testing set.

For ease of training, the augmented training samples are expected to be written into tfrecord files in the following format:

import tensorflow as tf

feature = {
    # image size: dimensions of 3 consecutive slices
    'dsize_dim0': tf.FixedLenFeature([], tf.int64),  # 256
    'dsize_dim1': tf.FixedLenFeature([], tf.int64),  # 256
    'dsize_dim2': tf.FixedLenFeature([], tf.int64),  # 3
    # label size: dimensions of the middle slice
    'lsize_dim0': tf.FixedLenFeature([], tf.int64),  # 256
    'lsize_dim1': tf.FixedLenFeature([], tf.int64),  # 256
    'lsize_dim2': tf.FixedLenFeature([], tf.int64),  # 1
    # image slices of size [256, 256, 3], stored as raw bytes
    'data_vol': tf.FixedLenFeature([], tf.string),
    # label slice of size [256, 256, 1], stored as raw bytes
    'label_vol': tf.FixedLenFeature([], tf.string),
}
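To make the layout concrete, the sketch below mirrors these fields with a plain Python dictionary and NumPy instead of TensorFlow. It shows how three consecutive slices and the label of the middle slice could be packed into raw bytes and restored using the stored size fields. The helper names make_sample and decode_sample are hypothetical, and the float32 dtype and the choice of the last axis as the slicing axis are assumptions, not details confirmed by the repository.

```python
import numpy as np

def make_sample(image_vol, label_vol, index):
    """Pack slices index-1..index+1 and the middle label into raw bytes."""
    # Assumption: slices are taken along the last axis, stored as float32.
    data = image_vol[:, :, index - 1:index + 2].astype(np.float32)
    label = label_vol[:, :, index:index + 1].astype(np.float32)
    sample = {'data_vol': data.tobytes(), 'label_vol': label.tobytes()}
    for i in range(3):
        sample['dsize_dim%d' % i] = data.shape[i]
        sample['lsize_dim%d' % i] = label.shape[i]
    return sample

def decode_sample(sample):
    """Restore both arrays from raw bytes using the stored size fields."""
    dshape = [sample['dsize_dim%d' % i] for i in range(3)]
    lshape = [sample['lsize_dim%d' % i] for i in range(3)]
    data = np.frombuffer(sample['data_vol'], dtype=np.float32).reshape(dshape)
    label = np.frombuffer(sample['label_vol'], dtype=np.float32).reshape(lshape)
    return data, label

# Toy volumes: 8 slices of 256 x 256 with a square label on slice 4.
image_vol = np.random.RandomState(0).rand(256, 256, 8)
label_vol = np.zeros((256, 256, 8))
label_vol[100:150, 100:150, 4] = 1
sample = make_sample(image_vol, label_vol, index=4)
data, label = decode_sample(sample)
```

Storing the sizes next to the byte strings is what allows a reader to reshape the flat buffers back to [256, 256, 3] and [256, 256, 1].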

2. Training base segmentation network

Run train_segmenter.py, where training configurations are specified.

This calls source_segmenter.py, where network structure and training function are defined.

3. Training adversarial domain adaptation

3.1 Warming-up the discriminator

To obtain a good initial estimate of the Wasserstein distance between the feature maps of the two domains, we first pre-train the feature domain discriminator. To do this, run:

python train_gan.py --phase pre-train
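As a toy illustration of why the discriminator (critic) yields a Wasserstein estimate, recall the dual form of the Wasserstein-1 distance: any 1-Lipschitz critic gives a lower bound equal to the gap between its mean scores on the two domains. The sketch below uses a unit-norm linear critic on made-up 2D features. It is not the repository's discriminator, which is a ConvNet trained by gradient descent; the function name and the sign convention are assumptions.

```python
import numpy as np

def linear_critic_estimate(source_feats, target_feats):
    """Lower-bound the Wasserstein-1 distance with a unit-norm linear critic."""
    diff = target_feats.mean(axis=0) - source_feats.mean(axis=0)
    # x -> w.x with ||w|| = 1 is 1-Lipschitz, as the dual form requires.
    # (Toy code: assumes the two domain means differ, so the norm is non-zero.)
    w = diff / np.linalg.norm(diff)
    # Gap between the mean critic scores of the two domains.
    return target_feats.dot(w).mean() - source_feats.dot(w).mean()

# Made-up feature vectors for the two domains.
source_feats = np.array([[0.0, 0.0], [2.0, 0.0]])
target_feats = np.array([[3.0, 4.0], [5.0, 4.0]])
estimate = linear_critic_estimate(source_feats, target_feats)
```

A critic that has not been trained gives an uninformative score gap, which is one way to read the motivation for the warm-up phase.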

3.2 Training adversarial domain adaptation

After warming the discriminator up, we can then jointly train the feature domain discriminator and the domain adaptation module (generator). To do this, run

python train_gan.py --phase train-gan

The experiment configurations can be found in train_gan.py. It calls adversarial.py, where network structures and training functions are defined.
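The two-phase workflow above can be pictured with a small argparse sketch. This is a hypothetical outline, not the contents of train_gan.py: only the --phase flag and its two values come from the commands above, while build_parser, describe_phase and the returned strings are illustrative.

```python
import argparse

def build_parser():
    """Parser with a single --phase switch (hypothetical outline)."""
    parser = argparse.ArgumentParser(description='Hypothetical phase switch')
    parser.add_argument('--phase', choices=['pre-train', 'train-gan'],
                        required=True,
                        help='pre-train: warm up the discriminator; '
                             'train-gan: joint adversarial training')
    return parser

def describe_phase(phase):
    """Return which modules are updated in the given phase."""
    # Placeholder dispatch; a real script would call its training functions.
    if phase == 'pre-train':
        return 'update discriminator only'
    return 'update discriminator and generator jointly'

args = build_parser().parse_args(['--phase', 'pre-train'])
plan = describe_phase(args.phase)
```

Restricting the flag to a fixed set of choices makes a mistyped phase name fail at startup instead of partway through a run.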

4. Evaluation

# TODO: combine the testing code with training code and switch them with an additional argument

Note: We are still actively updating the repo ...

5. Citations

If you make use of the code, please cite the paper in resulting publications.

@inproceedings{dou2018unsupervised,
  title={Unsupervised cross-modality domain adaptation of convnets for biomedical image segmentations with adversarial loss},
  author={Dou, Qi and Ouyang, Cheng and Chen, Cheng and Chen, Hao and Heng, Pheng-Ann},
  booktitle={Proceedings of the 27th International Joint Conference on Artificial Intelligence (IJCAI)},
  pages={691--697},
  year={2018}
}

or

@article{dou2018pnp,
  title={PnP-AdaNet: Plug-and-play adversarial domain adaptation network with a benchmark at cross-modality cardiac segmentation},
  author={Dou, Qi and Ouyang, Cheng and Chen, Cheng and Chen, Hao and Glocker, Ben and Zhuang, Xiahai and Heng, Pheng-Ann},
  journal={arXiv preprint arXiv:1812.07907},
  year={2018}
}

6. Acknowledgements

Special thanks to Ryan Neph for the PyMedImage package, which was used for debugging in the original project.

Contact

For general questions, please email qi.dou@imperial.ac.uk (Qi Dou) and c.ouyang@imperial.ac.uk (Cheng Ouyang).
For questions on the data license, please contact qi.dou@imperial.ac.uk (Qi Dou) and zxh@fudan.edu.cn (Xiahai Zhuang).
