From bd8797594046eb978460dc83b21682b3cbddf58d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 29 Jul 2023 18:01:02 +0100 Subject: [PATCH 1/3] update Signed-off-by: Wenqi Li --- monai/apps/pathology/inferers/inferer.py | 1 + monai/inferers/inferer.py | 5 +++++ monai/inferers/utils.py | 8 +++++++- tests/test_sliding_window_hovernet_inference.py | 1 + tests/test_sliding_window_inference.py | 1 + 5 files changed, 15 insertions(+), 1 deletion(-) diff --git a/monai/apps/pathology/inferers/inferer.py b/monai/apps/pathology/inferers/inferer.py index 7a60c23aa2..71259ca7df 100644 --- a/monai/apps/pathology/inferers/inferer.py +++ b/monai/apps/pathology/inferers/inferer.py @@ -178,6 +178,7 @@ def __call__( self.process_output, self.buffer_steps, self.buffer_dim, + False, *args, **kwargs, ) diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py index 5484970d82..bf8c27e5c3 100644 --- a/monai/inferers/inferer.py +++ b/monai/inferers/inferer.py @@ -426,6 +426,8 @@ class SlidingWindowInferer(Inferer): (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. buffer_dim: the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. + with_coord: whether to pass the window coordinates to ``network``. Defaults to False. + If True, the ``network``'s 2nd input argument should accept the window coordinates. Note: ``sw_batch_size`` denotes the max number of windows per network inference iteration, @@ -449,6 +451,7 @@ def __init__( cpu_thresh: int | None = None, buffer_steps: int | None = None, buffer_dim: int = -1, + with_coord: bool = False, ) -> None: super().__init__() self.roi_size = roi_size @@ -464,6 +467,7 @@ def __init__( self.cpu_thresh = cpu_thresh self.buffer_steps = buffer_steps self.buffer_dim = buffer_dim + self.with_coord = with_coord # compute_importance_map takes long time when computing on cpu. We thus # compute it once if it's static and then save it for future usage @@ -525,6 +529,7 @@ def __call__( None, buffer_steps, buffer_dim, + self.with_coord, *args, **kwargs, ) diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index 92f267e8a2..a080284e7c 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -57,6 +57,7 @@ def sliding_window_inference( process_fn: Callable | None = None, buffer_steps: int | None = None, buffer_dim: int = -1, + with_coord: bool = False, *args: Any, **kwargs: Any, ) -> torch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]: @@ -125,6 +126,8 @@ def sliding_window_inference( (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency. buffer_dim: the spatial dimension along which the buffers are created. 0 indicates the first spatial dimension. Default is -1, the last spatial dimension. + with_coord: whether to pass the window coordinates to ``predictor``. Default is False. + If True, the signature of ``predictor`` should be ``predictor(patch_data, patch_coord, ...)``. args: optional args to be passed to ``predictor``. kwargs: optional keyword args to be passed to ``predictor``. @@ -220,7 +223,10 @@ def sliding_window_inference( win_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) else: win_data = inputs[unravel_slice[0]].to(sw_device) - seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch + if with_coord: + seg_prob_out = predictor(win_data, unravel_slice, *args, **kwargs) # batched patch + else: + seg_prob_out = predictor(win_data, *args, **kwargs) # batched patch # convert seg_prob_out to tuple seg_tuple, this does not allocate new memory. dict_keys, seg_tuple = _flatten_struct(seg_prob_out) diff --git a/tests/test_sliding_window_hovernet_inference.py b/tests/test_sliding_window_hovernet_inference.py index b17e8525ec..276bd1e372 100644 --- a/tests/test_sliding_window_hovernet_inference.py +++ b/tests/test_sliding_window_hovernet_inference.py @@ -237,6 +237,7 @@ def compute(data, test1, test2): None, None, 0, + False, t1, test2=t2, ) diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py index f9d49361a6..8f0c074403 100644 --- a/tests/test_sliding_window_inference.py +++ b/tests/test_sliding_window_inference.py @@ -294,6 +294,7 @@ def compute(data, test1, test2): None, None, 0, + False, t1, test2=t2, ) From a1893f613590a078fd5d77318fef07e68641dda6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sat, 5 Aug 2023 17:30:01 +0100 Subject: [PATCH 2/3] temp test Signed-off-by: Wenqi Li --- runtests.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/runtests.sh b/runtests.sh index 01f58088ec..1f0baf21ca 100755 --- a/runtests.sh +++ b/runtests.sh @@ -578,6 +578,8 @@ then then install_deps fi + which ruff + whereis -b ruff ruff --version if [ $doRuffFix = true ] From 0d2e89a6836e04363dba7bf7923d9070039b72ee Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 21 Sep 2023 07:11:45 +0100 Subject: [PATCH 3/3] remove runtests change Signed-off-by: Wenqi Li --- runtests.sh | 2 -- 1 file changed, 2 deletions(-) diff --git a/runtests.sh b/runtests.sh index 1f0baf21ca..01f58088ec 100755 --- a/runtests.sh +++ b/runtests.sh @@ -578,8 +578,6 @@ then then install_deps fi - which ruff - whereis -b ruff ruff --version if [ $doRuffFix = true ]