Full implementation will be made available upon acceptance of the paper.
Official repository for M3DSR, a generative adversarial network for 3D medical image reconstruction.
If you find this project useful, please consider citing the paper:
@article{
...
}
import torch
from monai.inferers.inferer import SlidingWindowInferer
# Define device
device = torch.device("cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu"))
# Set patch size (D, H, W)
patch_size = (16, 16, 16)
# Define inferer (Optional)
inferer = SlidingWindowInferer(roi_size=patch_size, overlap=0.125, mode='gaussian', sw_batch_size=16, sw_device=device, device=device)
# Two random volumes (batch size of 2) with 1 channel and size of 64x64x64 (D, H, W). Mean is 0 and STD is 1.
volumes = torch.randn(2, 1, 64, 64, 64).to(device)
# Avoid HTTP Error 403
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
# Load model (check pretrained_models/README.md)
model = torch.hub.load(
'joaolcguerreiro/M3DSR', 'load_model',
checkpoint_url='https://drive.google.com/uc?export=download&id=15mrIzDnomlko1PXFHsnlF12DY9s4lf2p', pretrained=True,
in_channels=1, out_channels=1, num_features=192, num_rrdb_blocks=1, num_rrdb_db_blocks=2, num_rrdb_db_conv_blocks=3, num_rrdb_db_growth_features=96, scale_factor=2, residual_learning=True
).to(device)
model.eval()
# Generated volumes. Model expects z-score normalized input volumes.
gen_volumes = model(volumes) if not inferer else inferer(inputs=volumes, network=model)
import torch
from monai.inferers.inferer import SlidingWindowInferer
# Same as Reconstruction
...
# Load model (check pretrained_models/README.md)
model = torch.hub.load(
'joaolcguerreiro/M3DSR', 'load_model',
checkpoint_url='https://drive.google.com/uc?export=download&id=15mrIzDnomlko1PXFHsnlF12DY9s4lf2p', pretrained=True,
in_channels=1, out_channels=1, num_features=192, num_rrdb_blocks=1, num_rrdb_db_blocks=2, num_rrdb_db_conv_blocks=3, num_rrdb_db_growth_features=96, scale_factor=2, residual_learning=True
).to(device)
model.forward = model.extract_features
model.eval()
# Extract features. Model expects z-score normalized input volumes.
features = model(volumes) if not inferer else inferer(inputs=volumes, network=model)
-
Clone GitHub repository:
git clone https://github.com/joaolcguerreiro/M3DSR.git cd M3DSR
-
Create environment:
conda create -n m3dsr python=3.10.11 conda activate m3dsr
-
Install Python dependencies (Python 3.10 is recommended):
pip install -r requirements.txt
-
See data/README.md for details.
-
After preparing the data, navigate to
configs/config.yaml
and adjust data paths.
-
Navigate to
configs/config.yaml
and adjust configuration parameters needed. For example configurations see configs/examples. -
Navigate to
src/scripts/
. -
Run script:
python train.py
Note: Validation is optional. Use option
--val=True
if validation is intended.
Note: Set
LR_PATCH_SIZE
inconfig.yaml
file for patch-based training (recommended). Patch-based training is not employed whenLR_PATCH_SIZE=[-1, -1, -1]
(more computationally expensive). Also, forLR_PATCH_SIZE
equal or smaller than[8, 8, 8]
thebatch_size
is required to be greater than1
because of the discriminator batch normalization.
-
See pretrained_models/README.md to download available pre-trained models or use checkpoint path from a previous experiment.
-
Navigate to
configs/config.yaml
and updatePRE_TRAINED_PATH
parameters. -
Navigate to
src/scripts/
. -
Run script:
python train.py
-
Navigate to
experiments/
and copy experiment path. -
Navigate to
src/scripts/
. -
Run script:
python train.py --resume=<experiment_path>
Note: Load pre-trained differs from resume training since in pre-training only the model weights are loaded and the training is reset. If you resume training the last epoch checkpoint weights will be prioritized over the pre-trained weights.
-
Navigate to
experiments/
and copy experiment path. -
Navigate to
src/scripts/
. -
Run script:
python test.py --exp=<experiment_path> --epoch=<epoch> [--lr-patch-size=<lr_patch_size>]
Note: Results are in
<experiment_path>/results
. Additionally, check<experiment_path>/logs/test
and<experiment_path>/figures/test
for more insights.
Note: Increase testing
PATCH_SLIDING_WINDOW_BATCH_SIZE
inconfig.yaml
file to accelerate testing time. It controls the number of concurrent sliding window patches being processed. Notice, this will also increase memory used.
Note: Set
--lr-patch-size
to override the training low-resolution patch size.
-
Get model checkpoint path.
-
Navigate to
src/scripts/
. -
Run script:
python inference.py --checkpoint=<path_to_checkpoint> --lr-dir=<path_to_lr_dir> --save-dir=<path_to_save_dir> [--batch-size=<batch_size>] [--patch-sliding-window-batch-size=<patch_sliding_window_batch_size>] [--lr-patch-size=<lr_patch_size>]
Note: Increase
--patch-sliding-window-batch-size
to accelerate inference time. It controls the number of concurrent sliding window patches being processed. Notice, this will also increase memory used.
Note: Set
--lr-patch-size
to override the training low-resolution patch size.
Everytime you train a model, an experiment folder is built (unless training is resumed, in which case the folder from the prior run is used). The configuration file for the execution is stored in the experiment folder, allowing the original parameters to be retrieved when resuming training.
See experiments/README.md for details.
This project is licensed under MIT license. See LICENSE for details.