@@ -53,6 +53,8 @@ def sliding_window_inference(
5353 progress : bool = False ,
5454 roi_weight_map : torch .Tensor | None = None ,
5555 process_fn : Callable | None = None ,
56+ buffer_steps : int | None = None ,
57+ buffer_dim : int = 0 ,
5658 * args : Any ,
5759 ** kwargs : Any ,
5860) -> torch .Tensor | tuple [torch .Tensor , ...] | dict [Any , torch .Tensor ]:
@@ -114,26 +116,23 @@ def sliding_window_inference(
114116 roi_weight_map: pre-computed (non-negative) weight map for each ROI.
115117 If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
116118 process_fn: process inference output and adjust the importance map per window
119+ buffer_steps: the number of sliding window iterations before writing the outputs to ``device``.
120+ default is None, no buffer.
121+ buffer_dim: the dimension along which the buffer are created, default is 0.
117122 args: optional args to be passed to ``predictor``.
118123 kwargs: optional keyword args to be passed to ``predictor``.
119124
120- - buffer_steps: the number of sliding window iterations before writing the outputs to ``device``.
121- default is None, no buffer.
122- - buffer_dim: the dimension along which the buffer are created, default is 0.
123-
124125 Note:
125126 - input must be channel-first and have a batch dim, supports N-D sliding window.
126127
127128 """
128- b_steps = kwargs .pop ("buffer_steps" , None )
129- b_plane = kwargs .pop ("buffer_dim" , 0 )
130- buffered = b_steps is not None and b_steps > 0
129+ buffered = buffer_steps is not None and buffer_steps > 0
131130 num_spatial_dims = len (inputs .shape ) - 2
132131 if buffered :
133- if b_plane < - num_spatial_dims + 1 or b_plane > num_spatial_dims :
134- raise ValueError (f"buffer_dim must be in [{ - num_spatial_dims + 1 } , { num_spatial_dims } ], got { b_plane } ." )
135- if b_steps <= 0 :
136- raise ValueError (f"buffer_steps must be >= 0, got { b_steps } ." )
132+ if buffer_dim < - num_spatial_dims + 1 or buffer_dim > num_spatial_dims :
133+ raise ValueError (f"buffer_dim must be in [{ - num_spatial_dims + 1 } , { num_spatial_dims } ], got { buffer_dim } ." )
134+ if buffer_steps <= 0 : # type: ignore
135+ raise ValueError (f"buffer_steps must be >= 0, got { buffer_steps } ." )
137136 if overlap < 0 or overlap >= 1 :
138137 raise ValueError (f"overlap must be >= 0 and < 1, got { overlap } ." )
139138 compute_dtype = inputs .dtype
@@ -165,25 +164,31 @@ def sliding_window_inference(
165164 slices = dense_patch_slices (image_size , roi_size , scan_interval , return_slice = False )
166165
167166 slices_np = np .asarray (slices )
168- if b_plane < 0 :
169- b_plane += num_spatial_dims
170- slices_np = slices_np [np .argsort (slices_np [:, b_plane , 0 ], kind = "mergesort" )]
167+ if buffer_dim < 0 :
168+ buffer_dim += num_spatial_dims
169+ slices_np = slices_np [np .argsort (slices_np [:, buffer_dim , 0 ], kind = "mergesort" )]
171170 slices = [tuple (slice (c [0 ], c [1 ]) for c in i ) for i in slices_np ]
172- _ , _p_id , _b_lens = np .unique (slices_np [:, b_plane , 0 ], return_counts = True , return_index = True )
173- b_se = [tuple (slices_np [i ][b_plane ]) for i in _p_id ] # buffer start & end along the b_plane
174- b_ends = np .cumsum (np . repeat ( _b_lens , batch_size )) # buffer flush boundaries
171+ _ , _p_id , _b_lens = np .unique (slices_np [:, buffer_dim , 0 ], return_counts = True , return_index = True )
172+ _b_se = [tuple (slices_np [i ][buffer_dim ]) for i in _p_id ] # buffer start & end along the buffer_dim
173+ b_ends = np .cumsum (_b_lens ). tolist () # possible buffer flush boundaries
175174
176175 num_win = len (slices ) # number of windows per image
177176 total_slices = num_win * batch_size # total number of windows
178177 windows_range : Iterable
179178 if not buffered :
180179 windows_range = range (0 , total_slices , sw_batch_size )
181180 else :
182- b_steps = min (len (b_se ), b_steps )
183- x = [0 , * b_ends ][::b_steps ]
181+ buffer_steps = min (len (_b_se ), int ( buffer_steps )) # type: ignore
182+ x = [0 , * b_ends ][::buffer_steps ]
184183 if x [- 1 ] < b_ends [- 1 ]:
185184 x .append (b_ends [- 1 ])
186- windows_range = itertools .chain (* [range (x [i ], x [i + 1 ], sw_batch_size ) for i in range (len (x ) - 1 )])
185+ windows_range , n_per_batch , b_ends = [], len (x ) - 1 , [0 ]
186+ for b in range (batch_size ):
187+ offset = b * x [- 1 ]
188+ for i in range (n_per_batch ):
189+ windows_range .append (range (offset + x [i ], offset + x [i + 1 ], sw_batch_size ))
190+ b_ends .append (offset + x [i + 1 ])
191+ windows_range = itertools .chain (* windows_range )
187192
188193 # Create window-level importance map
189194 valid_patch_size = get_valid_patch_size (image_size , roi_size )
@@ -206,8 +211,7 @@ def sliding_window_inference(
206211 output_image_list , count_map_list , sw_device_buffer , b_s , b_i = [], [], [], 0 , 0 # type: ignore
207212 # for each patch
208213 for slice_g in tqdm (windows_range ) if progress else windows_range :
209- _cur_max = b_ends [b_s + b_steps - 1 ] if buffered else total_slices
210- slice_range = range (slice_g , min (slice_g + sw_batch_size , _cur_max ))
214+ slice_range = range (slice_g , min (slice_g + sw_batch_size , b_ends [b_s + 1 ] if buffered else total_slices ))
211215 unravel_slice = [
212216 [slice (idx // num_win , idx // num_win + 1 ), slice (None )] + list (slices [idx % num_win ])
213217 for idx in slice_range
@@ -223,22 +227,21 @@ def sliding_window_inference(
223227 importance_map = importance_map_
224228
225229 if buffered :
226- # if len(seg_tuple) > 1:
227- # warnings.warn("Multiple outputs are not supported with buffer_steps")
228- c_start , c_end = b_se [b_s % len (b_se )], b_se [(b_s + b_steps - 1 ) % len (b_se )]
230+ c_start = slices_np [b_ends [b_s ] % num_win , buffer_dim , 0 ]
231+ c_end = slices_np [(b_ends [b_s + 1 ] - 1 ) % num_win , buffer_dim , 1 ]
229232 if not sw_device_buffer :
230- k = seg_tuple [0 ].shape [1 ]
233+ k = seg_tuple [0 ].shape [1 ] # len(seg_tuple) > 1 is currently ignored
231234 sp_size = list (image_size )
232- sp_size [b_plane ] = max ( c_end [ 1 ] - c_start [ 0 ], roi_size [ b_plane ])
235+ sp_size [buffer_dim ] = c_end - c_start
233236 sw_device_buffer = [torch .zeros (size = [1 , k , * sp_size ], dtype = compute_dtype , device = sw_device )]
234237 importance_map = importance_map .to (dtype = compute_dtype , device = sw_device )
235238 for p , s in zip (seg_tuple [0 ], unravel_slice ):
236- offset = s [b_plane + 2 ].start - c_start [ 0 ]
237- s [b_plane + 2 ] = slice (offset , offset + roi_size [b_plane ])
239+ offset = s [buffer_dim + 2 ].start - c_start
240+ s [buffer_dim + 2 ] = slice (offset , offset + roi_size [buffer_dim ])
238241 s [0 ] = slice (0 , 1 )
239242 sw_device_buffer [0 ][s ] += p * importance_map
240243 b_i += len (unravel_slice )
241- if b_i < b_ends [b_s + b_steps - 1 ]:
244+ if b_i < b_ends [b_s + 1 ]:
242245 continue
243246 else :
244247 sw_device_buffer = seg_tuple
@@ -269,8 +272,8 @@ def sliding_window_inference(
269272 w_t = w_t .to (sw_device )
270273 if buffered :
271274 o_slice = [slice (None )] * len (inputs .shape )
272- o_slice [b_plane + 2 ] = slice (c_start [ 0 ] , c_end [ 1 ] )
273- img_b = b_s // len ( b_se ) # image batch index
275+ o_slice [buffer_dim + 2 ] = slice (c_start , c_end )
276+ img_b = b_s // n_per_batch # image batch index
274277 o_slice [0 ] = slice (img_b , img_b + 1 )
275278 output_image_list [0 ][o_slice ] += sw_device_buffer [0 ].to (device = device )
276279 else :
@@ -280,7 +283,7 @@ def sliding_window_inference(
280283 _compute_coords (sw_batch_size , unravel_slice , z_scale , output_image_list [ss ], sw_t )
281284 sw_device_buffer = []
282285 if buffered :
283- b_s += b_steps
286+ b_s += 1
284287
285288 # account for any overlapping sections
286289 for ss in range (len (output_image_list )):
0 commit comments