@@ -27,20 +27,8 @@ def batch_window_slices(slices, image_batch_size, batch_size):
27
27
)
28
28
return batched_window_slices
29
29
30
- @tf .function
31
- def gaussian_kernel_tf_v2 (roi_size , sigma ):
32
- """
33
- adapted from: https://gist.github.com/blzq
34
- """
35
- kernel_size = roi_size [0 ]
36
- sigma = sigma * kernel_size
37
- gauss = tf .range (start = 0 , limit = kernel_size , dtype = tf .float32 ) - (kernel_size - 1.0 ) / 2.0
38
- xx , yy , zz = tf .meshgrid (gauss , gauss , gauss )
39
- kernel = tf .exp (- (xx ** 2 + yy ** 2 + zz ** 2 ) / (2.0 * sigma ** 2 ))
40
- kernel = tf .math .pow (kernel , 1 / len (roi_size ))
41
- kernel = kernel / tf .reduce_max (kernel )
42
- return kernel
43
30
31
+ @tf .function
44
32
def gaussian_kernel (roi_size , sigma ):
45
33
gauss = signal .windows .gaussian (roi_size [0 ], std = sigma * roi_size [0 ])
46
34
for s in roi_size [1 :]:
@@ -57,7 +45,7 @@ def get_importance_kernel(roi_size, blend_mode, sigma):
57
45
if blend_mode == "constant" :
58
46
return tf .ones (roi_size , dtype = tf .float32 )
59
47
elif blend_mode == "gaussian" :
60
- return gaussian_kernel_tf_v2 (roi_size , sigma = sigma )
48
+ return gaussian_kernel (roi_size , sigma )
61
49
else :
62
50
raise ValueError (f'Invalid blend mode: { blend_mode } . Use either "constant" or "gaussian".' )
63
51
@@ -133,7 +121,6 @@ def sliding_window_inference(
133
121
image_size = list (input_padded .shape [1 :- 1 ])
134
122
135
123
importance_kernel = get_importance_kernel (roi_size , blend_mode , sigma = sigma )
136
-
137
124
output_shape = (batch_size ,) + tuple (image_size ) + (n_class ,)
138
125
importance_map = tf .tile (
139
126
tf .reshape (importance_kernel , shape = [1 , * roi_size , 1 ]),
0 commit comments