Skip to content

Commit 7593660

Browse files
committed
more tests
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 011854f commit 7593660

File tree

3 files changed

+47
-33
lines changed

3 files changed

+47
-33
lines changed

monai/inferers/inferer.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,9 @@ class SlidingWindowInferer(Inferer):
366366
cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
367367
when input image volume is larger than this threshold (in pixels/voxels).
368368
Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu.
369+
buffer_steps: the number of sliding window iterations before writing the outputs to ``device``.
370+
default is None, no buffer.
371+
buffer_dim: the dimension along which the buffers are created, default is 0.
369372
370373
Note:
371374
``sw_batch_size`` denotes the max number of windows per network inference iteration,
@@ -387,6 +390,8 @@ def __init__(
387390
progress: bool = False,
388391
cache_roi_weight_map: bool = False,
389392
cpu_thresh: int | None = None,
393+
buffer_steps: int | None = None,
394+
buffer_dim: int = 0,
390395
) -> None:
391396
super().__init__()
392397
self.roi_size = roi_size
@@ -400,6 +405,8 @@ def __init__(
400405
self.device = device
401406
self.progress = progress
402407
self.cpu_thresh = cpu_thresh
408+
self.buffer_steps = buffer_steps
409+
self.buffer_dim = buffer_dim
403410

404411
# compute_importance_map takes long time when computing on cpu. We thus
405412
# compute it once if it's static and then save it for future usage
@@ -456,6 +463,8 @@ def __call__(
456463
self.progress,
457464
self.roi_weight_map,
458465
None,
466+
self.buffer_steps,
467+
self.buffer_dim,
459468
*args,
460469
**kwargs,
461470
)

monai/inferers/utils.py

Lines changed: 36 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def sliding_window_inference(
5353
progress: bool = False,
5454
roi_weight_map: torch.Tensor | None = None,
5555
process_fn: Callable | None = None,
56+
buffer_steps: int | None = None,
57+
buffer_dim: int = 0,
5658
*args: Any,
5759
**kwargs: Any,
5860
) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]:
@@ -114,26 +116,23 @@ def sliding_window_inference(
114116
roi_weight_map: pre-computed (non-negative) weight map for each ROI.
115117
If not given, and ``mode`` is not `constant`, this map will be computed on the fly.
116118
process_fn: process inference output and adjust the importance map per window
119+
buffer_steps: the number of sliding window iterations before writing the outputs to ``device``.
120+
default is None, no buffer.
121+
buffer_dim: the dimension along which the buffers are created, default is 0.
117122
args: optional args to be passed to ``predictor``.
118123
kwargs: optional keyword args to be passed to ``predictor``.
119124
120-
- buffer_steps: the number of sliding window iterations before writing the outputs to ``device``.
121-
default is None, no buffer.
122-
- buffer_dim: the dimension along which the buffer are created, default is 0.
123-
124125
Note:
125126
- input must be channel-first and have a batch dim, supports N-D sliding window.
126127
127128
"""
128-
b_steps = kwargs.pop("buffer_steps", None)
129-
b_plane = kwargs.pop("buffer_dim", 0)
130-
buffered = b_steps is not None and b_steps > 0
129+
buffered = buffer_steps is not None and buffer_steps > 0
131130
num_spatial_dims = len(inputs.shape) - 2
132131
if buffered:
133-
if b_plane < -num_spatial_dims + 1 or b_plane > num_spatial_dims:
134-
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {b_plane}.")
135-
if b_steps <= 0:
136-
raise ValueError(f"buffer_steps must be >= 0, got {b_steps}.")
132+
if buffer_dim < -num_spatial_dims + 1 or buffer_dim > num_spatial_dims:
133+
raise ValueError(f"buffer_dim must be in [{-num_spatial_dims + 1}, {num_spatial_dims}], got {buffer_dim}.")
134+
if buffer_steps <= 0: # type: ignore
135+
raise ValueError(f"buffer_steps must be > 0, got {buffer_steps}.")
137136
if overlap < 0 or overlap >= 1:
138137
raise ValueError(f"overlap must be >= 0 and < 1, got {overlap}.")
139138
compute_dtype = inputs.dtype
@@ -165,25 +164,31 @@ def sliding_window_inference(
165164
slices = dense_patch_slices(image_size, roi_size, scan_interval, return_slice=False)
166165

167166
slices_np = np.asarray(slices)
168-
if b_plane < 0:
169-
b_plane += num_spatial_dims
170-
slices_np = slices_np[np.argsort(slices_np[:, b_plane, 0], kind="mergesort")]
167+
if buffer_dim < 0:
168+
buffer_dim += num_spatial_dims
169+
slices_np = slices_np[np.argsort(slices_np[:, buffer_dim, 0], kind="mergesort")]
171170
slices = [tuple(slice(c[0], c[1]) for c in i) for i in slices_np]
172-
_, _p_id, _b_lens = np.unique(slices_np[:, b_plane, 0], return_counts=True, return_index=True)
173-
b_se = [tuple(slices_np[i][b_plane]) for i in _p_id] # buffer start & end along the b_plane
174-
b_ends = np.cumsum(np.repeat(_b_lens, batch_size)) # buffer flush boundaries
171+
_, _p_id, _b_lens = np.unique(slices_np[:, buffer_dim, 0], return_counts=True, return_index=True)
172+
_b_se = [tuple(slices_np[i][buffer_dim]) for i in _p_id] # buffer start & end along the buffer_dim
173+
b_ends = np.cumsum(_b_lens).tolist() # possible buffer flush boundaries
175174

176175
num_win = len(slices) # number of windows per image
177176
total_slices = num_win * batch_size # total number of windows
178177
windows_range: Iterable
179178
if not buffered:
180179
windows_range = range(0, total_slices, sw_batch_size)
181180
else:
182-
b_steps = min(len(b_se), b_steps)
183-
x = [0, *b_ends][::b_steps]
181+
buffer_steps = min(len(_b_se), int(buffer_steps)) # type: ignore
182+
x = [0, *b_ends][::buffer_steps]
184183
if x[-1] < b_ends[-1]:
185184
x.append(b_ends[-1])
186-
windows_range = itertools.chain(*[range(x[i], x[i + 1], sw_batch_size) for i in range(len(x) - 1)])
185+
windows_range, n_per_batch, b_ends = [], len(x) - 1, [0]
186+
for b in range(batch_size):
187+
offset = b * x[-1]
188+
for i in range(n_per_batch):
189+
windows_range.append(range(offset + x[i], offset + x[i + 1], sw_batch_size))
190+
b_ends.append(offset + x[i + 1])
191+
windows_range = itertools.chain(*windows_range)
187192

188193
# Create window-level importance map
189194
valid_patch_size = get_valid_patch_size(image_size, roi_size)
@@ -206,8 +211,7 @@ def sliding_window_inference(
206211
output_image_list, count_map_list, sw_device_buffer, b_s, b_i = [], [], [], 0, 0 # type: ignore
207212
# for each patch
208213
for slice_g in tqdm(windows_range) if progress else windows_range:
209-
_cur_max = b_ends[b_s + b_steps - 1] if buffered else total_slices
210-
slice_range = range(slice_g, min(slice_g + sw_batch_size, _cur_max))
214+
slice_range = range(slice_g, min(slice_g + sw_batch_size, b_ends[b_s + 1] if buffered else total_slices))
211215
unravel_slice = [
212216
[slice(idx // num_win, idx // num_win + 1), slice(None)] + list(slices[idx % num_win])
213217
for idx in slice_range
@@ -223,22 +227,21 @@ def sliding_window_inference(
223227
importance_map = importance_map_
224228

225229
if buffered:
226-
# if len(seg_tuple) > 1:
227-
# warnings.warn("Multiple outputs are not supported with buffer_steps")
228-
c_start, c_end = b_se[b_s % len(b_se)], b_se[(b_s + b_steps - 1) % len(b_se)]
230+
c_start = slices_np[b_ends[b_s] % num_win, buffer_dim, 0]
231+
c_end = slices_np[(b_ends[b_s + 1] - 1) % num_win, buffer_dim, 1]
229232
if not sw_device_buffer:
230-
k = seg_tuple[0].shape[1]
233+
k = seg_tuple[0].shape[1] # len(seg_tuple) > 1 is currently ignored
231234
sp_size = list(image_size)
232-
sp_size[b_plane] = max(c_end[1] - c_start[0], roi_size[b_plane])
235+
sp_size[buffer_dim] = c_end - c_start
233236
sw_device_buffer = [torch.zeros(size=[1, k, *sp_size], dtype=compute_dtype, device=sw_device)]
234237
importance_map = importance_map.to(dtype=compute_dtype, device=sw_device)
235238
for p, s in zip(seg_tuple[0], unravel_slice):
236-
offset = s[b_plane + 2].start - c_start[0]
237-
s[b_plane + 2] = slice(offset, offset + roi_size[b_plane])
239+
offset = s[buffer_dim + 2].start - c_start
240+
s[buffer_dim + 2] = slice(offset, offset + roi_size[buffer_dim])
238241
s[0] = slice(0, 1)
239242
sw_device_buffer[0][s] += p * importance_map
240243
b_i += len(unravel_slice)
241-
if b_i < b_ends[b_s + b_steps - 1]:
244+
if b_i < b_ends[b_s + 1]:
242245
continue
243246
else:
244247
sw_device_buffer = seg_tuple
@@ -269,8 +272,8 @@ def sliding_window_inference(
269272
w_t = w_t.to(sw_device)
270273
if buffered:
271274
o_slice = [slice(None)] * len(inputs.shape)
272-
o_slice[b_plane + 2] = slice(c_start[0], c_end[1])
273-
img_b = b_s // len(b_se) # image batch index
275+
o_slice[buffer_dim + 2] = slice(c_start, c_end)
276+
img_b = b_s // n_per_batch # image batch index
274277
o_slice[0] = slice(img_b, img_b + 1)
275278
output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
276279
else:
@@ -280,7 +283,7 @@ def sliding_window_inference(
280283
_compute_coords(sw_batch_size, unravel_slice, z_scale, output_image_list[ss], sw_t)
281284
sw_device_buffer = []
282285
if buffered:
283-
b_s += b_steps
286+
b_s += 1
284287

285288
# account for any overlapping sections
286289
for ss in range(len(output_image_list)):

tests/test_sliding_window_inference.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def compute(data, test1, test2):
244244
has_tqdm,
245245
None,
246246
None,
247+
None,
248+
0,
247249
t1,
248250
test2=t2,
249251
)

0 commit comments

Comments
 (0)