Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sw inferer progress #3899

Merged
merged 1 commit into from
Mar 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monai/inferers/inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class SlidingWindowInferer(Inferer):
By default the device (and accordingly the memory) of the `inputs` is used. If for example
set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
`inputs` and `roi_size`. Output is on the `device`.
progress: whether to print a tqdm progress bar.

Note:
``sw_batch_size`` denotes the max number of windows per network inference iteration,
Expand All @@ -137,6 +138,7 @@ def __init__(
cval: float = 0.0,
sw_device: Union[torch.device, str, None] = None,
device: Union[torch.device, str, None] = None,
progress: bool = False,
) -> None:
Inferer.__init__(self)
self.roi_size = roi_size
Expand All @@ -148,6 +150,7 @@ def __init__(
self.cval = cval
self.sw_device = sw_device
self.device = device
self.progress = progress

def __call__(
self, inputs: torch.Tensor, network: Callable[..., torch.Tensor], *args: Any, **kwargs: Any
Expand All @@ -174,6 +177,7 @@ def __call__(
self.cval,
self.sw_device,
self.device,
self.progress,
*args,
**kwargs,
)
Expand Down
8 changes: 6 additions & 2 deletions monai/inferers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import torch.nn.functional as F

from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option
from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option, optional_import

tqdm, _ = optional_import("tqdm", name="tqdm")

__all__ = ["sliding_window_inference"]

Expand All @@ -32,6 +34,7 @@ def sliding_window_inference(
cval: float = 0.0,
sw_device: Union[torch.device, str, None] = None,
device: Union[torch.device, str, None] = None,
progress: bool = False,
*args: Any,
**kwargs: Any,
) -> torch.Tensor:
Expand Down Expand Up @@ -74,6 +77,7 @@ def sliding_window_inference(
By default the device (and accordingly the memory) of the `inputs` is used. If for example
set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
`inputs` and `roi_size`. Output is on the `device`.
progress: whether to print a `tqdm` progress bar.
args: optional args to be passed to ``predictor``.
kwargs: optional keyword args to be passed to ``predictor``.

Expand Down Expand Up @@ -120,7 +124,7 @@ def sliding_window_inference(
# Perform predictions
output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)
_initialized = False
for slice_g in range(0, total_slices, sw_batch_size):
for slice_g in tqdm(range(0, total_slices, sw_batch_size)) if progress else range(0, total_slices, sw_batch_size):
slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices))
unravel_slice = [
[slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win])
Expand Down
11 changes: 8 additions & 3 deletions tests/test_sliding_window_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
from parameterized import parameterized

from monai.inferers import SlidingWindowInferer, sliding_window_inference
from monai.utils import optional_import
from tests.utils import skip_if_no_cuda

_, has_tqdm = optional_import("tqdm")

TEST_CASES = [
[(2, 3, 16), (4,), 3, 0.25, "constant", torch.device("cpu:0")], # 1D small roi
[(2, 3, 16, 15, 7, 9), 4, 3, 0.25, "constant", torch.device("cpu:0")], # 4D small roi
Expand Down Expand Up @@ -145,6 +148,7 @@ def compute(self, data):
cval=-1,
mode="gaussian",
sigma_scale=1.0,
progress=has_tqdm,
)
expected = np.array(
[
Expand Down Expand Up @@ -222,15 +226,16 @@ def compute(data, test1, test2):
0.0,
device,
device,
has_tqdm,
t1,
test2=t2,
)
expected = np.ones((1, 1, 3, 3)) + 2.0
np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)

result = SlidingWindowInferer(roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1)(
inputs, compute, t1, test2=t2
)
result = SlidingWindowInferer(
roi_shape, sw_batch_size, overlap=0.5, mode="constant", cval=-1, progress=has_tqdm
)(inputs, compute, t1, test2=t2)
np.testing.assert_allclose(result.cpu().numpy(), expected, rtol=1e-4)


Expand Down