diff --git a/monai/inferers/inferer.py b/monai/inferers/inferer.py
index f9c04664be..5484970d82 100644
--- a/monai/inferers/inferer.py
+++ b/monai/inferers/inferer.py
@@ -22,6 +22,7 @@
 from monai.apps.utils import get_logger
 from monai.data.meta_tensor import MetaTensor
+from monai.data.thread_buffer import ThreadBuffer
 from monai.inferers.merger import AvgMerger, Merger
 from monai.inferers.splitter import Splitter
 from monai.inferers.utils import compute_importance_map, sliding_window_inference
@@ -103,6 +104,7 @@ class PatchInferer(Inferer):
             the output dictionary to be used for merging.
             Defaults to None, where all the keys are used.
         match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
+        buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0 (no buffer).
         merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
             `merged_shape` is calculated automatically based on the input shape and
             the output patch shape unless it is passed here.
@@ -117,6 +119,7 @@ def __init__(
         postprocessing: Callable | None = None,
         output_keys: Sequence | None = None,
         match_spatial_shape: bool = True,
+        buffer_size: int = 0,
         **merger_kwargs: Any,
     ) -> None:
         Inferer.__init__(self)
@@ -157,6 +160,8 @@ def __init__(
         self.postprocessing = postprocessing
 
         # batch size for patches
+        if batch_size < 1:
+            raise ValueError(f"`batch_size` must be a positive number, but {batch_size} was given.")
         self.batch_size = batch_size
 
         # model output keys
@@ -165,6 +170,9 @@
         # whether to crop the output to match the input shape
         self.match_spatial_shape = match_spatial_shape
 
+        # buffer size for multithreaded batch sampling
+        self.buffer_size = buffer_size
+
     def _batch_sampler(
         self, patches: Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor
     ) -> Iterator[tuple[torch.Tensor, Sequence, int]]:
@@ -182,10 +190,16 @@ def _batch_sampler(
                 batch_size = min(self.batch_size, total_size - i)
                 yield patches[i : i + batch_size], patches[i : i + batch_size].meta[PatchKeys.LOCATION], batch_size  # type: ignore
         else:
+            buffer: Iterable | ThreadBuffer
+            if self.buffer_size > 0:
+                # Use multi-threading to sample patches with a buffer
+                buffer = ThreadBuffer(patches, buffer_size=self.buffer_size, timeout=0.1)
+            else:
+                buffer = patches
             patch_batch: list[Any] = [None] * self.batch_size
             location_batch: list[Any] = [None] * self.batch_size
             idx_in_batch = 0
-            for sample in patches:
+            for sample in buffer:
                 patch_batch[idx_in_batch] = sample[0]
                 location_batch[idx_in_batch] = sample[1]
                 idx_in_batch += 1
diff --git a/tests/test_patch_inferer.py b/tests/test_patch_inferer.py
index b0d25a98b9..bc6fc30a88 100644
--- a/tests/test_patch_inferer.py
+++ b/tests/test_patch_inferer.py
@@ -127,6 +127,7 @@
     TENSOR_4x4,
 ]
 
+
 # non-divisible patch_size leading to larger image (without matching spatial shape)
 TEST_CASE_11_PADDING = [
     TENSOR_4x4,
@@ -155,6 +156,23 @@
     TENSOR_4x4,
 ]
 
+# multi-threading
+TEST_CASE_14_MULTITHREAD_BUFFER = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=2),
+    lambda x: x,
+    TENSOR_4x4,
+]
+
+# multi-threading with batch
+TEST_CASE_15_MULTITHREAD_BATCH = [
+    TENSOR_4x4,
+    dict(splitter=SlidingWindowSplitter(patch_size=(2, 2)), merger_cls=AvgMerger, buffer_size=4, batch_size=4),
+    lambda x: x,
+    TENSOR_4x4,
+]
+
+
 # list of tensor output
 TEST_CASE_0_LIST_TENSOR = [
     TENSOR_4x4,
@@ -245,6 +263,8 @@ class PatchInfererTests(unittest.TestCase):
             TEST_CASE_11_PADDING,
             TEST_CASE_12_MATCHING,
             TEST_CASE_13_PADDING_MATCHING,
+            TEST_CASE_14_MULTITHREAD_BUFFER,
+            TEST_CASE_15_MULTITHREAD_BATCH,
         ]
     )
     def test_patch_inferer_tensor(self, inputs, arguments, network, expected):
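
Usage note (not part of the diff): a minimal sketch of how the new `buffer_size` option could be exercised, mirroring the values in TEST_CASE_15_MULTITHREAD_BATCH. The identity network, tensor shape, and import paths below are assumptions for illustration, not part of this change.

import torch

from monai.inferers import PatchInferer
from monai.inferers.merger import AvgMerger
from monai.inferers.splitter import SlidingWindowSplitter

# 1x1x4x4 image, split into 2x2 patches and re-assembled by AvgMerger
image = torch.arange(16, dtype=torch.float32).reshape(1, 1, 4, 4)

inferer = PatchInferer(
    splitter=SlidingWindowSplitter(patch_size=(2, 2)),
    merger_cls=AvgMerger,
    batch_size=4,   # patches per forward pass; must be >= 1 after this change
    buffer_size=4,  # > 0 enables ThreadBuffer-based prefetching of patches
)

# identity "network" for illustration: the merged output equals the input
output = inferer(inputs=image, network=lambda x: x)
assert torch.allclose(output, image)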