diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index f00a432f38..331637ba94 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -119,6 +119,7 @@ class SlidingWindowInferer(Inferer):
             By default the device (and accordingly the memory) of the `inputs` is used. If for example
             set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
             `inputs` and `roi_size`. Output is on the `device`.
+        progress: whether to print a tqdm progress bar.
 
     Note:
         ``sw_batch_size`` denotes the max number of windows per network inference iteration,
@@ -137,6 +138,7 @@ def __init__(
         cval: float = 0.0,
         sw_device: Union[torch.device, str, None] = None,
         device: Union[torch.device, str, None] = None,
+        progress: bool = False,
     ) -> None:
         Inferer.__init__(self)
         self.roi_size = roi_size
@@ -148,6 +150,7 @@ def __init__(
         self.cval = cval
         self.sw_device = sw_device
         self.device = device
+        self.progress = progress
 
     def __call__(
         self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any
@@ -174,6 +177,7 @@ def __call__(
             self.cval,
             self.sw_device,
             self.device,
+            self.progress,
             *args,
             **kwargs,
         )
diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 65a81378fa..36e4377bd6 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -15,7 +15,9 @@
 import torch.nn.functional as F
 
 from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
-from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option
+from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option, optional_import
+
+tqdm, _ = optional_import("tqdm", name="tqdm")
 
 __all__ = ["sliding_window_inference"]
@@ -32,6 +34,7 @@ def sliding_window_inference(
     cval: float = 0.0,
     sw_device: Union[torch.device, str, None] = None,
     device: Union[torch.device, str, None] = None,
+    progress: bool = False,
     *args: Any,
     **kwargs: Any,
 ) -> torch.Tensor:
@@ -74,6 +77,7 @@ def sliding_window_inference(
             By default the device (and accordingly the memory) of the `inputs` is used. If for example
             set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
             `inputs` and `roi_size`. Output is on the `device`.
+        progress: whether to print a `tqdm` progress bar.
         args: optional args to be passed to ``predictor``.
         kwargs: optional keyword args to be passed to ``predictor``.
@@ -120,7 +124,7 @@ def sliding_window_inference(
     # Perform predictions
     output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
     _initialized = False
-    for slice_g in range(0, total_slices, sw_batch_size):
+    for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
         slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
         unravel_slice = [
             [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index 3d4ad3151b..5b6995c1ea 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -16,8 +16,11 @@
 from parameterized import parameterized
 
 from monai.inferers import SlidingWindowInferer, sliding_window_inference
+from monai.utils import optional_import
 from tests.utils import skip_if_no_cuda
 
+_, has_tqdm = optional_import("tqdm")
+
 TEST_CASES = [
     [(2, 3, 16), (4,), 3, 0.25, "constant", torch.device("cpu:0")],  # 1D small roi
     [(2, 3, 16, 15, 7, 9), 4, 3, 0.25, "constant", torch.device("cpu:0")],  # 4D small roi
@@ -145,6 +148,7 @@ def compute(self, data):
             cval=-1,
             mode="gaussian",
             sigma_scale=1.0,
+            progress=has_tqdm,
         )
         expected = np.array(
             [
@@ -222,15 +226,16 @@ def compute(data, test1, test2):
             0.0,
             device,
             device,
+            has_tqdm,
             t1,
             test2=t2,
         )
         expected = np.ones((1, 1, 3, 3)) + 2.0
         np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
-        result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)(
-            inputs, compute, t1, test2=t2
-        )
+        result = SlidingWindowInferer(
+            roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm
+        )(inputs, compute, t1, test2=t2)
         np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)
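Usage sketch (not part of the patch): a minimal example of the new `progress` flag, assuming a MONAI build with this change applied and tqdm installed; the network and input shapes below are placeholders, not values from the patch.

    import torch
    from monai.inferers import SlidingWindowInferer, sliding_window_inference

    net = torch.nn.Conv3d(1, 1, kernel_size=3, padding=1)  # stand-in predictor
    inputs = torch.rand(1, 1, 64, 64, 64)

    # Functional API: a tqdm bar is shown over the sliding-window batches.
    out = sliding_window_inference(inputs, roi_size=(32, 32, 32), sw_batch_size=4, predictor=net, progress=True)

    # Class-based API: the flag is stored on the inferer and forwarded to sliding_window_inference.
    inferer = SlidingWindowInferer(roi_size=(32, 32, 32), sw_batch_size=4, progress=True)
    out = inferer(inputs, net)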