Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add dataset and display #284

Open
wants to merge 10 commits into
base: dev
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,5 @@ temp/
# VSCode
.vscode/
*.zip

models/lower_pelvic_reg/eval/
85 changes: 85 additions & 0 deletions models/lower_pelvic_reg/configs/inference.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
---
imports:
- $import matplotlib.pyplot as plt
dataset_dir: "/Users/yiwenli/data/multiorgan_final"
bundle_root: "./"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
output_dir: "$@bundle_root + '/eval'"
ckpt: "$@bundle_root + '/lower_pelvic_reg_cpu_nonparallel-2.pth'"
cross_subjects: false # whether the input images are from the same subject

dataset:
_target_: "scripts.dataset.RegDataset"
train: false
dataset_dir: "@dataset_dir"
pixdim: [0.75, 0.75, 2.5]
spatial_size: [256, 256, 40]
rotate_range: $np.pi / 36
translate_range: [20, 20, 4]
scale_range: [0.15, 0.15, 0.15]

data_loader:
_target_: "torch.utils.data.DataLoader"
dataset: "@dataset"
batch_size: 1
num_workers: 0

# display first pair of data
first_pair: $@dataset[0]
display:
- $plt.subplot(2,2,1)
- $plt.gca().set_title("moving image")
- $plt.gca().axis('off')
- $plt.imshow(np.transpose(@first_pair[0]["image"][0, ..., @first_pair[0]["image"].shape[-1]//2]))
- $plt.subplot(2,2,2)
- $plt.gca().set_title("fixed image")
- $plt.gca().axis('off')
- $plt.imshow(np.transpose(@first_pair[1]["image"][0, ..., @first_pair[0]["image"].shape[-1]//2]))
- $plt.subplot(2,2,3)
- $plt.gca().set_title("moving label")
- $plt.gca().axis('off')
- $plt.imshow(np.transpose(@first_pair[0]["label"][0, ..., @first_pair[0]["image"].shape[-1]//2]))
- $plt.subplot(2,2,4)
- $plt.gca().set_title("fixed label")
- $plt.gca().axis('off')
- $plt.imshow(np.transpose(@first_pair[1]["label"][0, ..., @first_pair[0]["image"].shape[-1]//2]))
- $plt.show()

network:
_target_: LocalNet
spatial_dims: 3
in_channels: 2
out_channels: 3
num_channel_initial: 32
extract_levels: [0, 1, 2, 3]
out_kernel_initializer: "zeros"

handlers:
- _target_: CheckpointLoader
load_path: "@ckpt"
load_dict: {model: "@network"}

inferer:
_target_: "scripts.inferer.RegistrationInferer"

evaluator:
_target_: "scripts.evaluator.RegistrationEvaluator"
device: "@device"
val_data_loader: "@data_loader"
network: "@network"
epoch_length: $len(@dataset) // @data_loader#batch_size
inferer: "@inferer"
val_handlers: "@handlers"
postprocessing:
_target_: Compose
transforms:
- _target_: "scripts.visualise.SaveRegd"
keys: ["moving_image", "moving_label", "fixed_image", "fixed_label", "warped_image", "warped_label"]
pixdim: [ 0.75, 0.75, 2.5 ]
spatial_size: [ 256, 256, 40 ]
output_dir: "@output_dir"

eval:
- $monai.utils.set_determinism(seed=123)
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- $@evaluator.run()
21 changes: 21 additions & 0 deletions models/lower_pelvic_reg/configs/logging.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[loggers]
keys=root

[handlers]
keys=consoleHandler

[formatters]
keys=fullFormatter

[logger_root]
level=INFO
handlers=consoleHandler

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=fullFormatter
args=(sys.stdout,)

[formatter_fullFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
64 changes: 64 additions & 0 deletions models/lower_pelvic_reg/configs/metadata.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
{
"schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
"version": "0.0.3",
"changelog": {
"0.0.3": "update to use monai 1.1.0",
"0.0.2": "update to use rc1",
"0.0.1": "Initial version"
},
"monai_version": "1.1.0",
"pytorch_version": "1.13.0",
"numpy_version": "1.22.2",
"optional_packages_version": {
"pytorch-ignite": "0.4.8"
},
"task": "Spatial transformer for hand image registration from the MedNIST dataset",
"description": "This is an example of a ResNet and spatial transformer for hand xray image registration",
"authors": "MONAI team",
"copyright": "Copyright (c) MONAI Consortium",
"intended_use": "This is an example of image registration using MONAI, suitable for demonstration purposes only.",
"data_type": "jpeg",
"network_data_format": {
"inputs": {
"image": {
"type": "image",
"format": "magnitude",
"num_channels": 2,
"spatial_shape": [
64,
64
],
"dtype": "float32",
"value_range": [
0,
1
],
"is_patch_data": false,
"channel_def": {
"0": "moving image",
"1": "fixed image"
}
}
},
"outputs": {
"pred": {
"type": "image",
"format": "magnitude",
"num_channels": 1,
"spatial_shape": [
64,
64
],
"dtype": "float32",
"value_range": [
0,
1
],
"is_patch_data": false,
"channel_def": {
"0": "image"
}
}
}
}
}
217 changes: 217 additions & 0 deletions models/lower_pelvic_reg/configs/train.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
---
imports:
- $import glob
- $import matplotlib.pyplot as plt

# workflow parameters
bundle_root: "./"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
ckpt_dir: "$@bundle_root + '/models'" # folder to save new checkpoints
ckpt: "" # path to load an existing checkpoint
val_interval: 1 # every epoch
max_epochs: 300
cross_subjects: false # whether the input images are from the same subject

# construct the moving and fixed datasets
dataset_dir: "../MedNIST/Hand"
datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[:7000]" # training with 7000 images
val_datalist: "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))[7000:8500]" # validation with 1500 images

image_load:
- _target_: LoadImage
image_only: True
ensure_channel_first: True

- _target_: ScaleIntensityRange
a_min: 0.0
a_max: 255.0
b_min: 0.0
b_max: 1.0

- _target_: EnsureType
device: "@device"

image_aug:
- _target_: RandAffine
spatial_size: [64, 64]
translate_range: 5
scale_range: [-0.15, 0.15]
prob: 1.0
rotate_range: $np.pi / 8
mode: bilinear
padding_mode: border
cache_grid: True
device: "@device"

- _target_: RandGridDistortion
prob: 0.2
num_cells: 8
device: "@device"
distort_limit: 0.1

preprocessing:
_target_: Compose
transforms: "$@image_load + @image_aug"

cache_datasets:
- _target_: ShuffleBuffer
data:
_target_: CacheDataset
data: "@datalist"
transform: $@preprocessing.set_random_state(123)
hash_as_key: true
runtime_cache: threads
epochs: "@max_epochs"
seed: "$int(3) if @cross_subjects else int(2)"
- _target_: ShuffleBuffer
data:
_target_: CacheDataset
data: "@datalist"
transform: $@preprocessing.set_random_state(234)
hash_as_key: true
runtime_cache: threads
epochs: "@max_epochs"
seed: 2

zip_dataset:
_target_: IterableDataset
data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@cache_datasets))"

data_loader:
_target_: ThreadDataLoader
dataset: "@zip_dataset"
batch_size: 64
num_workers: 0


# components for debugging
first_pair: $monai.utils.misc.first(@data_loader)
display:
- $monai.utils.set_determinism(seed=123)
- $print(@first_pair.keys(), @first_pair['image'].meta['filename_or_obj'])
- "$print(@trainer#loss_function(@first_pair['image'][:, 0:1], @first_pair['image'][:, 1:2]))" # print loss
- $plt.subplot(1,2,1)
- $plt.imshow(@first_pair['image'][0, 0], cmap="gray")
- $plt.subplot(1,2,2)
- $plt.imshow(@first_pair['image'][0, 1], cmap="gray")
- $plt.show()


# network definition
net:
_target_: scripts.net.RegResNet
image_size: [64, 64]
spatial_dims: 2
mode: "bilinear"
padding_mode: "border"

optimizer:
_target_: torch.optim.Adam
params: $@net.parameters()
lr: 0.00001

# create a validation evaluator
val:
cache_datasets:
- _target_: ShuffleBuffer
data:
_target_: CacheDataset
data: "@val_datalist"
transform: $@preprocessing.set_random_state(123)
hash_as_key: true
runtime_cache: threads
epochs: -1 # infinite
seed: "$int(3) if @cross_subjects else int(2)"
- _target_: ShuffleBuffer
data:
_target_: CacheDataset
data: "@val_datalist"
transform: $@preprocessing.set_random_state(234)
hash_as_key: true
runtime_cache: threads
epochs: -1 # infinite
seed: 2

zip_dataset:
_target_: IterableDataset
data: "$map(lambda t: dict(image=monai.transforms.concatenate(t), label=t[1]), zip(*@val#cache_datasets))"

data_loader:
_target_: ThreadDataLoader
dataset: "@val#zip_dataset"
batch_size: 64
num_workers: 0

evaluator:
_target_: SupervisedEvaluator
device: "@device"
val_data_loader: "@val#data_loader"
network: "@net"
epoch_length: $len(@val_datalist) // @val#data_loader#batch_size
inferer: "$monai.inferers.SimpleInferer()"
metric_cmp_fn: "$lambda x, y: x < y"
key_val_metric:
val_mse:
_target_: MeanSquaredError
output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
additional_metrics: {"mutual info loss": "@loss_metric#metric_handler"}
val_handlers:
- _target_: StatsHandler
iteration_log: false
- _target_: CheckpointSaver
save_dir: "@ckpt_dir"
save_dict: {model: "@net"}
save_key_metric: true
key_metric_negative_sign: true
# key_metric_filename: "model.pt"

# training handlers
handlers:
- _target_: StatsHandler
tag_name: "train_loss"
output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
- _target_: ValidationHandler
validator: "@val#evaluator"
epoch_level: true
interval: "@val_interval"

loss_metric:
metric_handler:
_target_: IgniteMetric
output_transform: "$monai.handlers.from_engine(['pred', 'label'])"
metric_fn:
_target_: LossMetric
loss_fn: "@mutual_info_loss"
get_not_nans: true

ckpt_loader:
- _target_: CheckpointLoader
load_path: "@ckpt"
load_dict: {model: "@net"}

lncc_loss:
_target_: LocalNormalizedCrossCorrelationLoss
spatial_dims: 2
kernel_size: 5
kernel_type: rectangular
reduction: mean

mutual_info_loss:
_target_: GlobalMutualInformationLoss

# create the primary trainer
trainer:
_target_: SupervisedTrainer
device: "@device"
train_data_loader: "@data_loader"
network: "@net"
max_epochs: "@max_epochs"
epoch_length: $len(@datalist) // @data_loader#batch_size
loss_function: "@lncc_loss"
optimizer: "@optimizer"
train_handlers: "$@handlers + @ckpt_loader if @ckpt else @handlers"

training:
- $monai.utils.set_determinism(seed=23)
- "$setattr(torch.backends.cudnn, 'benchmark', True)"
- $@trainer.run()
Loading