Commit: nonblocking copy
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
wyli committed Apr 1, 2023
1 parent bd42c16 commit 581181f
Showing 1 changed file with 16 additions and 1 deletion.
monai/inferers/utils.py
@@ -174,11 +174,20 @@ def sliding_window_inference(
     total_slices = num_win * batch_size  # total number of windows
     windows_range: Iterable
     if not buffered:
+        non_blocking = False
         windows_range = range(0, total_slices, sw_batch_size)
     else:
         slices, n_per_batch, b_slices, windows_range = _create_buffered_slices(
             slices, batch_size, sw_batch_size, buffer_dim, buffer_steps
         )
+        non_blocking = buffered and overlap[buffer_dim] == 0 and torch.device(sw_device).type == "cuda"
+        for idx, x in enumerate(b_slices):
+            if x[1] == 0 or idx == 0:
+                _ss = -1
+            if x[1] <= _ss:
+                non_blocking = False
+                break
+            _ss = x[2]
 
     # Create window-level importance map
     valid_patch_size = get_valid_patch_size(image_size, roi_size)
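
The first hunk decides up front whether the buffered output copies may be issued asynchronously: the flag starts out true only when buffering is enabled, the windows do not overlap along the buffer dimension, and the sliding-window device is CUDA, and it is then cleared as soon as any buffered slice fails to start strictly after the previous one stopped. Below is a standalone paraphrase of that eligibility check; the function name and the (start, stop) ranges are illustrative, not MONAI's actual b_slices tuples or API.

def can_copy_non_blocking(ranges, overlap_in_buffer_dim, device_type):
    # Hypothetical paraphrase of the check above, not MONAI API.
    if overlap_in_buffer_dim != 0 or device_type != "cuda":
        return False
    prev_stop = -1
    for idx, (start, stop) in enumerate(ranges):
        if start == 0 or idx == 0:  # restarting a scan resets the tracker, as in the commit
            prev_stop = -1
        if start <= prev_stop:  # region does not start strictly after the last one ended
            return False
        prev_stop = stop
    return True

print(can_copy_non_blocking([(0, 4), (5, 9)], 0, "cuda"))  # True: strictly increasing
print(can_copy_non_blocking([(0, 4), (3, 7)], 0, "cuda"))  # False: regions overlap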
@@ -259,7 +268,10 @@ def sliding_window_inference(
                 o_slice[buffer_dim + 2] = slice(c_start, c_end)
                 img_b = b_s // n_per_batch  # image batch index
                 o_slice[0] = slice(img_b, img_b + 1)
-                output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
+                if non_blocking:
+                    output_image_list[0][o_slice].copy_(sw_device_buffer[0], non_blocking=non_blocking)
+                else:
+                    output_image_list[0][o_slice] += sw_device_buffer[0].to(device=device)
             else:
                 sw_device_buffer[ss] *= w_t
                 sw_device_buffer[ss] = sw_device_buffer[ss].to(device)
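
When non_blocking holds, each output region is written exactly once, so the accumulating += (which first forces a synchronous transfer through .to(device=device)) can become an in-place Tensor.copy_ that is merely queued on the current CUDA stream. A minimal sketch of that pattern under assumed conditions — a CUDA device and a pinned host buffer, since pinned memory is what lets a device-to-host copy proceed asynchronously:

import torch

if torch.cuda.is_available():
    out = torch.zeros(2, 3, 8, 8, pin_memory=True)  # pinned CPU buffer: enables async D2H copies
    win = torch.rand(1, 3, 8, 8, device="cuda")
    out[0:1].copy_(win, non_blocking=True)  # queued on the current stream; may return before finishing
    torch.cuda.current_stream().synchronize()  # required before "out" is read on the host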
@@ -268,6 +280,9 @@ def sliding_window_inference(
         if buffered:
             b_s += 1
 
+    if non_blocking:
+        torch.cuda.current_stream().synchronize()
+
     # account for any overlapping sections
     for ss in range(len(output_image_list)):
         output_image_list[ss] /= count_map_list.pop(0)
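
Because those copies are only queued, the final hunk places a single fence before the overlap normalization and return: operations issued on one CUDA stream execute in order, so one synchronize() at the end covers every pending copy rather than paying for one per window. A small sketch of that design choice, with illustrative names and shapes:

import torch

if torch.cuda.is_available():
    chunks = [torch.full((256,), float(i), device="cuda") for i in range(4)]
    host_out = torch.empty(4, 256, pin_memory=True)
    for i, c in enumerate(chunks):
        host_out[i].copy_(c, non_blocking=True)  # all copies queue on the current stream
    torch.cuda.current_stream().synchronize()  # one fence drains them all before host reads
    assert host_out[3, 0].item() == 3.0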