Continual Learning for Abdominal Multi-Organ and Tumor Segmentation
Yixiao Zhang, Xinyi Li, Huimiao Chen, Alan Yuille, Yaoyao Liu, and Zongwei Zhou
Johns Hopkins University
MICCAI 2023 (early accept)
paper | code
git clone https://github.com/MrGiovanni/ContinualLearning
See installation instructions to create an environment and obtain requirements.
1.1 Prepare your image and label files in a customer path, then create a txt file in the dataset/dataset_list folder. See txt files under dataset/dataset_list folder as examples.
1.2 Put the class name embedding file word_embedding_38class.pth under ./pretrained_weights/ folder.
Use the train.py file for training models. An example script is
python train.py
--phase train
--data_root_path ./data
--train_data_txt_path ./dataset/dataset_list/btcv_train.txt
--val_data_txt_path ./dataset/dataset_list/btcv_val.txt
--organ_list 1 2 3 4 5 6
--max_epoch 101
--warmup_epoch 15
--batch_size 2
--num_samples 1
--lr 1e-4
--model swinunetr
--trans_encoding word_embedding
--word_embedding ./pretrained_weights/word_embedding_38class.pth
--out_nonlinear softmax
--out_channels 38
--log_name your_log_folder_name
Switch the argument --model
for different models: swinunetr
for SwinUNETR, swinunetr_partial
for the proposed model with organ-specific segmentation heads (this model should be used with --out_nonlinear sigmoid
).
Use the test.py file for testing models. An example script is
python test.py
--log_name your_log_folder_name
--resume your_checkpoint_path
--data_root_path ./data
--test_data_txt_path ./dataset/dataset_list/btcv_test.txt
--organ_list 1 2 3 4 5 6
--model swinunetr
--out_nonlinear softmax
--out_channels 38
Use predict.py to make organ segmentation predictions on the test images or images for continual learning
python predict.py \
--log_name your_log_folder_name \
--resume your_checkpoint_path \
--data_root_path ./data \
--predict_data_txt_path ./dataset/dataset_list/<step2 train files>.txt \
--organ_list 1 2 3 4 5 6 \
--model swinunetr \
--out_nonlinear softmax \
--out_channels 38
You would need to combine the predicted segmentation on some organs togethor with ground truth segmentation on other organs to generate pseudo labels for continual learning. Please take scripts in prepare_pseudo_dataset.ipynb as examples.
For the proposed model, in continual learning stages, the same train.py is used to train the model but with different settings.
python train.py \
--phase continue \
--data_root_path ./data \
--continue_data_txt_path ./dataset/dataset_list/<data list with generated pseudo label>.txt \
--organ_list 1 2 3 4 5 6 7 8 9 10 11 12 \
--num_workers 8 \
--max_epoch 101 \
--warmup_epoch 15 \
--store_num 10 \
--batch_size 2 \
--num_samples 1 \
--lr 1e-4 \
--model swinunetr_partial \
--out_nonlinear sigmoid \
--out_channels 38 \
--pretrain checkpoint_of_previous_learning_stage \
--trans_encoding word_embedding \
--word_embedding ./pretrained_weights/word_embedding_38class.pth \
--log_name your_log_folder_name
The differences are that you are now using pseudo labels listed in the continue_data_txt_path
file, training on both pretrained and continual organ targets, and use a checkpoint from the previous learning stage.
This work was supported by the Lustgarten Foundation for Pancreatic Cancer Research and partially by the Patrick J. McGovern Foundation Award. We appreciate the effort of the MONAI Team to provide open-source code for the community.