This repository contains the official PyTorch implementation of the following paper:
Labeled-to-Unlabeled Distribution Alignment for Partially-Supervised Multi-Organ Medical Image Segmentation
Xixi Jiang, Dong Zhang, Xiang Li, Kangyi Liu, Kwang-Ting Cheng, Xin Yang
Partially-supervised multi-organ medical image segmentation aims to develop a unified semantic segmentation model by utilizing multiple partially-labeled datasets, with each dataset providing labels for a single class of organs. However, the limited availability of labeled foreground organs and the absence of supervision to distinguish unlabeled foreground organs from the background pose a significant challenge, which leads to a distribution mismatch between labeled and unlabeled pixels. Although existing pseudo-labeling methods can be employed to learn from both labeled and unlabeled pixels, they are prone to performance degradation in this task, as they rely on the assumption that labeled and unlabeled pixels have the same distribution. In this paper, to address the problem of distribution mismatch, we propose a labeled-to-unlabeled distribution alignment (LTUDA) framework that aligns feature distributions and enhances discriminative capability. Specifically, we introduce a cross-set data augmentation strategy, which performs region-level mixing between labeled and unlabeled organs to reduce distribution discrepancy and enrich the training set. In addition, we propose a prototype-based distribution alignment method that implicitly reduces intra-class variation and increases the separation between the unlabeled foreground and background. This is achieved by encouraging consistency between the outputs of two prototype classifiers and a linear classifier. Extensive experimental results on the AbdomenCT-1K dataset and a union of four benchmark datasets (including LiTS, MSD-Spleen, KiTS, and NIH82) demonstrate that our method outperforms the state-of-the-art partially-supervised methods by a considerable margin, and even surpasses the fully-supervised methods.
- Create conda environment and activate it:
conda create -n ltuda python=3.6
conda activate ltuda
- Clone this repo:
git clone https://github.com/xjiangmed/LTUDA.git
cd LTUDA
- Install requirements:
bash env_setup.sh
- Toy dataset: partially labeled images are sampled from AbdomenCT-1K.
- Partially labeled dataset: a union of four benchmark datasets (LiTS, MSD-Spleen, KiTS, and NIH82). We evaluate the multi-organ segmentation model trained on the partially labeled dataset on two external datasets, BTCV and AbdomenCT-1K.
Dataset | Source |
---|---|
LiTS | data |
MSD-Spleen | data |
KiTS | data |
NIH82 | data |
BTCV | data |
AbdomenCT-1K | data |
- For data preprocessing, each 3D CT volume is first sliced along the z-axis, irrelevant non-abdominal regions are then removed, and finally each axial slice is resized to 256×256 pixels.
- A preprocessed toy dataset can be downloaded from the link.
- Partially labeled datasets can be processed using the same steps; we do not provide the processed images because of their large size. We do not segment tumor regions: tumors in the LiTS dataset are treated as liver, and tumors in the KiTS dataset are treated as background. This choice reflects the observation that liver tumors lie entirely inside the organ, whereas kidney tumors largely destroy the integrity and continuity of the organ.
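The slicing, resizing, and tumor-label handling described above can be sketched as follows. This is a minimal illustration, not the repository's preprocessing script: the label ids in `LITS_REMAP` and `KITS_REMAP`, and the helper names `remap_labels`, `resize_nearest`, and `volume_to_slices`, are assumptions made for the example and should be checked against the actual dataset encodings. The removal of non-abdominal regions is not shown, because this README does not specify how it is done.

```python
import numpy as np

# Hypothetical label ids (assumption): 0 = background, 1 = organ, 2 = tumor.
LITS_REMAP = {0: 0, 1: 1, 2: 1}  # liver tumor is merged into liver
KITS_REMAP = {0: 0, 1: 1, 2: 0}  # kidney tumor is merged into background


def remap_labels(mask, mapping):
    """Return a copy of `mask` with each label id replaced via `mapping`."""
    out = np.zeros_like(mask)
    for src, dst in mapping.items():
        out[mask == src] = dst
    return out


def resize_nearest(slice_2d, size=256):
    """Nearest-neighbour resize of a 2D array to (size, size).

    Nearest-neighbour keeps label ids intact; for the image itself a
    smoother interpolation would normally be preferred.
    """
    h, w = slice_2d.shape
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    return slice_2d[rows[:, None], cols[None, :]]


def volume_to_slices(volume, mask, mapping, size=256):
    """Split a (Z, H, W) volume into resized axial (image, label) pairs."""
    mask = remap_labels(mask, mapping)
    return [
        (resize_nearest(volume[z], size), resize_nearest(mask[z], size))
        for z in range(volume.shape[0])
    ]
```

With these assumed ids, calling `volume_to_slices(volume, mask, LITS_REMAP)` yields label maps containing only background and liver, while `KITS_REMAP` yields only background and kidney.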
- Stage 1: cross-set data augmentation:
python train_CDA.py --save_model_path ../checkpoint/CDA --model unet
- Stage 2: prototype-based distribution alignment:
python train_CDA_PDA.py --save_model_path ../checkpoint/CDA_PDA --model unet_proto --reload_path '../checkpoint/CDA/model_best.pth'
The models trained on the toy dataset are available here. The models trained on the partially labeled dataset are available here.
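To give an intuition for the region-level mixing used in Stage 1, here is a minimal sketch assuming a CutMix-style rectangular paste between two slices drawn from different partially labeled sets. Whether `train_CDA.py` mixes regions exactly this way is not confirmed by this README; the function name `cross_set_mix` and its parameters are hypothetical.

```python
import numpy as np


def cross_set_mix(img_a, lbl_a, img_b, lbl_b, rng, min_frac=0.25, max_frac=0.5):
    """Paste a random rectangle from sample B into sample A (CutMix-style).

    img_*: (H, W) float arrays; lbl_*: (H, W) integer label maps.
    Returns the mixed image, the mixed label map, and the boolean box mask
    that records which pixels came from sample B.
    """
    h, w = img_a.shape
    # Box side lengths are a random fraction of the image size (assumption).
    bh = int(h * rng.uniform(min_frac, max_frac))
    bw = int(w * rng.uniform(min_frac, max_frac))
    top = int(rng.integers(0, h - bh + 1))
    left = int(rng.integers(0, w - bw + 1))

    box = np.zeros((h, w), dtype=bool)
    box[top:top + bh, left:left + bw] = True

    mixed_img = np.where(box, img_b, img_a)
    mixed_lbl = np.where(box, lbl_b, lbl_a)
    return mixed_img, mixed_lbl, box


# Usage with two dummy slices from differently labeled sets.
rng = np.random.default_rng(0)
img, lbl, box = cross_set_mix(
    np.zeros((256, 256)), np.zeros((256, 256), dtype=np.int64),
    np.ones((256, 256)), np.ones((256, 256), dtype=np.int64),
    rng,
)
```

Returning the box mask lets a training loop keep track of which annotation set applies to each pixel, since the two source slices carry labels for different organs.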
- To test the model, please run:
python test.py --model unet_proto --reload_path '../checkpoint/CDA_PDA/ema_model_best.pth'
When testing, you can choose the linear classifier, the labeled prototype classifier, or the unlabeled prototype classifier.
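The difference between a linear classifier and a prototype classifier can be illustrated with a short sketch. It assumes that a prototype is the mean feature vector of a class and that pixels are assigned by cosine similarity; the repository's `unet_proto` model may compute and update prototypes differently, and all function names below are hypothetical. Under this reading, the labeled and unlabeled prototype classifiers would differ only in which pixels (labeled or pseudo-labeled) contribute to the class means, which is also an assumption.

```python
import numpy as np


def class_prototypes(features, labels, num_classes):
    """Mean feature vector per class; features: (N, D), labels: (N,)."""
    protos = np.zeros((num_classes, features.shape[1]))
    for c in range(num_classes):
        members = features[labels == c]
        if len(members) > 0:  # classes with no pixels keep a zero prototype
            protos[c] = members.mean(axis=0)
    return protos


def prototype_predict(features, protos, eps=1e-8):
    """Assign each feature to the prototype with the highest cosine similarity."""
    f = features / (np.linalg.norm(features, axis=1, keepdims=True) + eps)
    p = protos / (np.linalg.norm(protos, axis=1, keepdims=True) + eps)
    return (f @ p.T).argmax(axis=1)


def linear_predict(features, weight, bias):
    """Ordinary linear classifier: argmax over features @ weight.T + bias."""
    return (features @ weight.T + bias).argmax(axis=1)
```

In this sketch the linear classifier relies on learned weights, whereas the prototype classifiers are derived directly from the feature space, so comparing their predictions at test time indicates how well the features of each class cluster.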