Accompanying code for the approach presented in the revised version of Prompt Tuning for Parameter-efficient Medical Image Segmentation.
The code of the initial submission is still available here.
Based on this code, we
- introduce a deeply prompt-able encoder-decoder architecture (prompt-able UNETR, PUNETR) that can incorporate additional class-dependent prompt tokens to achieve dense binary and multi-class segmentation,
- contribute architectural components comprising prompt-able shifted window (PSWin) blocks, a heterogeneous bias score generation within the attention scheme, and a weighted similarity aggregation to enable token-dependent class predictions throughout the network,
- propose a flexible contrastive pre-training scheme designed to pre-train the whole encoder-decoder structure by relying on dense self-supervision. Soft assignments to online-generated prototypes are used to learn an anatomical embedding space while avoiding a hard separation of samples into those subject to contrastive attraction and those subject to repulsion,
- show that "prompting" the pre-trained and frozen model architecture with non-frozen (learned) prompt tokens is sufficient to adapt it to a downstream segmentation task on medical imaging data,
- leverage our dense, soft assignment-based self-supervision scheme alongside a concurrent prompt-dependent segmentation supervision during pre-training, further reducing the performance gap between fully fine-tuned and efficiently adapted variants.
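To make the idea of class-dependent prompt tokens driving token-dependent class predictions more concrete, here is a minimal NumPy sketch. It is not the PUNETR implementation: it assumes one learnable prompt vector per class and scores every image token by its temperature-scaled cosine similarity to each prompt. The heterogeneous attention bias and the weighted similarity aggregation described above are not reproduced, and `prompt_similarity_logits` as well as the temperature of 0.1 are hypothetical choices.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale each row vector to (approximately) unit length."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def prompt_similarity_logits(tokens: np.ndarray, prompts: np.ndarray,
                             temperature: float = 0.1) -> np.ndarray:
    """Hypothetical per-token class logits from cosine similarity to class prompts.

    tokens:  (N, C) image token embeddings from a (frozen) backbone
    prompts: (K, C) one learnable prompt vector per class (assumed layout)
    returns: (N, K) temperature-scaled cosine similarities
    """
    return l2_normalize(tokens) @ l2_normalize(prompts).T / temperature


rng = np.random.default_rng(0)
prompts = rng.normal(size=(3, 16))   # 3 toy classes, 16 channels
labels = np.array([0, 2, 1, 1, 0])   # class of each toy token
# Toy tokens are noisy copies of their class prompt, so the expected labels are known.
tokens = prompts[labels] + 0.05 * rng.normal(size=(5, 16))

probs = softmax(prompt_similarity_logits(tokens, prompts))
pred = probs.argmax(axis=1)
print(pred)
```

In a prompting setup, only an array like `prompts` would be optimized, while the backbone that produces `tokens` stays frozen.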
The published code contains
- the prompt-able UNETR (PUNETR) architecture and underlying PSWin blocks (see Figure 1),
- the proposed dense self-supervision scheme based on contrastive prototype assignments (see Figure 2),
- the training routines, including support for multiple prompt-dependent predictions within a single batch,
- the ability to process 3D imaging data (tested FOVs are included in the config file).
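The dense self-supervision relies on soft assignments of embeddings to prototypes. As a rough illustration only, the sketch below swaps in a generic teacher-student formulation: teacher and student embeddings at the same positions are softly assigned to a set of prototypes via cosine similarity, and the student is trained to match the sharper teacher assignment with a cross-entropy. The prototypes are fixed here, so the online prototype generation of the actual scheme is not modeled, and `soft_assignment_loss` and the temperatures 0.1 and 0.04 are assumptions.

```python
import numpy as np


def l2_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale each row vector to (approximately) unit length."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def softmax(x: np.ndarray, temperature: float) -> np.ndarray:
    """Temperature-scaled, numerically stable softmax over the last axis."""
    z = x / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def soft_assignment_loss(student, teacher, prototypes,
                         t_student=0.1, t_teacher=0.04):
    """Hypothetical cross-entropy between teacher and student soft prototype assignments.

    student, teacher: (N, C) dense embeddings at the same N spatial positions
                      (e.g. two augmented views resampled to a common grid)
    prototypes:       (P, C) prototype vectors
    """
    protos = l2_normalize(prototypes)
    # Soft assignment of every position to every prototype (cosine similarity).
    q = softmax(l2_normalize(teacher) @ protos.T, t_teacher)  # sharper targets
    p = softmax(l2_normalize(student) @ protos.T, t_student)
    # During training, q would be treated as a constant (no gradient to the teacher).
    return float(-(q * np.log(p + 1e-12)).sum(axis=-1).mean())


prototypes = np.eye(4)   # 4 toy prototypes in a 4-D embedding space
teacher = np.eye(4)      # each toy position sits exactly on one prototype
aligned = soft_assignment_loss(teacher.copy(), teacher, prototypes)
shuffled = soft_assignment_loss(np.roll(teacher, 1, axis=0), teacher, prototypes)
print(aligned, shuffled)
```

Because the targets are soft distributions over prototypes rather than hard positive/negative labels, no explicit decision about which samples to attract or repel is required.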
This code is provided as is. It builds upon the PyTorch Lightning framework and uses MONAI functionality where possible.
See the data pre-processing and data gathering instructions on how to prepare data, e.g., for TCIA.
python3 ./src/main.py --gpus 1 --batch_size 8 --architecture wip --dataset tcia_btcv --dir_images /path/to/my/data --dir_masks /path/to/my/labels
Valid configuration variants for pre-training (phase 1) are listed in the config file used by the phase 1 shell script.
For the loss configuration, use
- self for self-supervision,
- seg for segmentation (semi-)supervision,
- seg_self for joint supervision,
- and the _noninstructed suffix for non-prompt-based architecture variants.
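A small parser makes the naming convention explicit. It assumes that `_noninstructed` is appended as a suffix to one of the three base names (e.g. `seg_noninstructed`); `parse_loss_variant` and `LossVariant` are hypothetical helpers for illustration, not part of the repository.

```python
from dataclasses import dataclass

NONINSTRUCTED_SUFFIX = "_noninstructed"

# Base loss names -> (uses self-supervision, uses segmentation supervision)
BASE_VARIANTS = {
    "self": (True, False),
    "seg": (False, True),
    "seg_self": (True, True),
}


@dataclass(frozen=True)
class LossVariant:
    self_supervised: bool
    segmentation: bool
    prompt_based: bool


def parse_loss_variant(name: str) -> LossVariant:
    """Decode a (hypothetical) loss variant name such as 'seg_self' or 'seg_noninstructed'."""
    prompt_based = not name.endswith(NONINSTRUCTED_SUFFIX)
    base = name if prompt_based else name[: -len(NONINSTRUCTED_SUFFIX)]
    if base not in BASE_VARIANTS:
        raise ValueError(f"unknown loss variant: {name!r}")
    use_self, use_seg = BASE_VARIANTS[base]
    return LossVariant(use_self, use_seg, prompt_based)


print(parse_loss_variant("seg_self"))
print(parse_loss_variant("seg_noninstructed"))
```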
Have a look at the flags of the main module for more details.
For ease of use, the published code defaults to 24 tokens throughout the network and omits the final high-resolution prompt-able block.
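To get a feel for how few parameters prompting adds, here is a back-of-the-envelope count. It assumes that every prompt-able block holds its own matrix of 24 tokens by C channels and uses the five hidden channel widths given in the Figure 1 caption (48, 96, 192, 384, 768). The actual number of prompt-able blocks per level, and the omitted high-resolution block, will change the total; `prompt_parameter_count` is a hypothetical helper.

```python
def prompt_parameter_count(num_tokens: int, channels_per_level, blocks_per_level: int = 1) -> int:
    """Learnable prompt parameters, assuming each prompt-able block holds its own
    (num_tokens x channels) prompt matrix. Illustrative estimate only."""
    return sum(num_tokens * c * blocks_per_level for c in channels_per_level)


channels = [48, 96, 192, 384, 768]  # hidden channels per level (Figure 1 caption)
print(prompt_parameter_count(24, channels))                      # 35712
print(prompt_parameter_count(24, channels, blocks_per_level=2))  # 71424
```

Under either assumption, the trainable prompt budget stays at a few tens of thousands of values.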
python3 ./src/main.py --gpus 1 --batch_size 8 --architecture wip --dataset tcia_btcv --dir_images /path/to/my/data --dir_masks /path/to/my/labels --ckpt /path/to/my/ckpt --no_overwrite --cold_start --downstream --adaptation_variant prompting --selective_freezing --label_indices_base 1 --label_indices_downstream_active 2 --max_epochs 100
Valid configuration variants for the downstream adaptation (phase 2) are listed in the config file used by the phase 2 shell script.
New classes are specified via their class index, e.g., --label_indices_downstream_active 2.
Have a look at the flags of the main module for more details.
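The flags --adaptation_variant prompting and --selective_freezing suggest that only the prompt tokens are trained in phase 2 while the pre-trained weights stay frozen. A dependency-free sketch of such a split follows; it assumes prompt parameters can be recognized by the substring `prompt` in their names, and both `split_trainable` and the parameter names and sizes are invented for illustration.

```python
from typing import Dict, Iterable, Tuple


def split_trainable(
    param_sizes: Dict[str, int],
    trainable_keywords: Iterable[str] = ("prompt",),
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """Partition parameters into (trainable, frozen) by a keyword in their name.

    Assumption: prompt tokens are identifiable by a keyword in the parameter name.
    """
    keywords = tuple(trainable_keywords)
    trainable, frozen = {}, {}
    for name, size in param_sizes.items():
        target = trainable if any(k in name for k in keywords) else frozen
        target[name] = size
    return trainable, frozen


# Hypothetical parameter names and sizes, not taken from the repository.
params = {
    "encoder.level0.attn.qkv.weight": 48 * 144,
    "encoder.level0.prompt_tokens": 24 * 48,
    "decoder.level0.conv.weight": 48 * 48 * 27,
    "decoder.level0.prompt_tokens": 24 * 48,
}
trainable, frozen = split_trainable(params)
fraction = sum(trainable.values()) / sum(params.values())
print(sorted(trainable), f"{fraction:.2%} trainable")
```

With PyTorch, the same rule would typically be applied by iterating over `model.named_parameters()` and setting each parameter's `requires_grad` flag accordingly.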
python3 ./src/main.py --gpus 1 --mode test --architecture wip --dataset tcia_btcv --dir_images /path/to/my/data --dir_masks /path/to/my/labels --ckpt /path/to/my/ckpt --no_overwrite --cold_start
Figure 1: Schematic illustration of the prompt-able UNETR (PUNETR). The network consists of an SWinViT encoder and a SWinUNETR decoder. A depth of 5 levels is chosen, with 48, 96, 192, 384, and 768 hidden channels.
Figure 2: Input image volumes
Figure 3: a) Exemplary slice of the TCIA/BTCV dataset with annotated classes shown in shades of blue, b-e) augmented student views with masked regions or strong contrast adjustments, f-i) respective teacher views with overlays of the cosine similarity of the predicted teacher embedding
Figure 4: 2D visualization of cosine similarities between predicted teacher embeddings