Skip to content

Commit 95d9fa0

Browse files
michal2409shakandrew
authored andcommitted
[nnUNet/TF2] Fix sliding window evaluation for non cubic volumes
1 parent f6d0682 commit 95d9fa0

File tree

3 files changed

+5
-17
lines changed

3 files changed

+5
-17
lines changed

TensorFlow2/Segmentation/nnUNet/main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def main(args):
3737
assert pValue.contents.value == 128
3838

3939
hvd_init()
40-
set_seed(args.seed)
40+
if args.seed is not None:
41+
set_seed(args.seed)
4142
set_tf_flags(args)
4243
data = DataModule(args)
4344
data.setup()

TensorFlow2/Segmentation/nnUNet/models/sliding_window.py

+2-15
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,8 @@ def batch_window_slices(slices, image_batch_size, batch_size):
2727
)
2828
return batched_window_slices
2929

30-
@tf.function
31-
def gaussian_kernel_tf_v2(roi_size, sigma):
32-
"""
33-
adapted from: https://gist.github.com/blzq
34-
"""
35-
kernel_size = roi_size[0]
36-
sigma = sigma * kernel_size
37-
gauss = tf.range(start = 0, limit = kernel_size, dtype = tf.float32) - (kernel_size - 1.0) / 2.0
38-
xx, yy, zz = tf.meshgrid(gauss, gauss, gauss)
39-
kernel = tf.exp(-(xx ** 2 + yy ** 2 + zz ** 2) / (2.0 * sigma ** 2))
40-
kernel = tf.math.pow(kernel, 1/len(roi_size))
41-
kernel = kernel / tf.reduce_max(kernel)
42-
return kernel
4330

31+
@tf.function
4432
def gaussian_kernel(roi_size, sigma):
4533
gauss = signal.windows.gaussian(roi_size[0], std=sigma * roi_size[0])
4634
for s in roi_size[1:]:
@@ -57,7 +45,7 @@ def get_importance_kernel(roi_size, blend_mode, sigma):
5745
if blend_mode == "constant":
5846
return tf.ones(roi_size, dtype=tf.float32)
5947
elif blend_mode == "gaussian":
60-
return gaussian_kernel_tf_v2(roi_size, sigma=sigma)
48+
return gaussian_kernel(roi_size, sigma)
6149
else:
6250
raise ValueError(f'Invalid blend mode: {blend_mode}. Use either "constant" or "gaussian".')
6351

@@ -133,7 +121,6 @@ def sliding_window_inference(
133121
image_size = list(input_padded.shape[1:-1])
134122

135123
importance_kernel = get_importance_kernel(roi_size, blend_mode, sigma=sigma)
136-
137124
output_shape = (batch_size,) + tuple(image_size) + (n_class,)
138125
importance_map = tf.tile(
139126
tf.reshape(importance_kernel, shape=[1, *roi_size, 1]),

TensorFlow2/Segmentation/nnUNet/runtime/args.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def get_main_args():
6262
p.arg("--data", type=Path, default=Path("/data"), help="Path to data directory")
6363
p.arg("--task", type=str, default="01", help="Task number, MSD uses numbers 01-10")
6464
p.arg("--dim", type=int, choices=[2, 3], default=3, help="UNet dimension")
65-
p.arg("--seed", type=non_negative_int, help="Random seed")
65+
p.arg("--seed", type=non_negative_int, default=None, help="Random seed")
6666
p.flag("--benchmark", help="Run model benchmarking")
6767
p.boolean_flag("--tta", default=False, help="Enable test time augmentation")
6868
p.boolean_flag("--save-preds", "--save_preds", default=False, help="Save predictions")

0 commit comments

Comments
 (0)