diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py
index b7462c11dd..7f64c394c5 100644
--- a/monai/transforms/croppad/array.py
+++ b/monai/transforms/croppad/array.py
@@ -384,14 +384,17 @@ class CropForeground(Transform):
     """

     def __init__(
-        self, select_fn: Callable = lambda x: x > 0, channel_indices: Optional[IndexSelection] = None, margin: int = 0
+        self,
+        select_fn: Callable = lambda x: x > 0,
+        channel_indices: Optional[IndexSelection] = None,
+        margin: Union[Sequence[int], int] = 0,
     ) -> None:
         """
         Args:
             select_fn: function to select expected foreground, default is to select values > 0.
             channel_indices: if defined, select foreground only on the specified channels of image.
                 if None, select foreground on the whole image.
-            margin: add margin to all dims of the bounding box.
+            margin: add margin value to spatial dims of the bounding box; if only 1 value is provided, use it for all dims.
         """
         self.select_fn = select_fn
         self.channel_indices = ensure_tuple(channel_indices) if channel_indices is not None else None
diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py
index fdd42dc87e..ec69c0801b 100644
--- a/monai/transforms/croppad/dictionary.py
+++ b/monai/transforms/croppad/dictionary.py
@@ -336,7 +336,7 @@ def __init__(
             select_fn: function to select expected foreground, default is to select values > 0.
             channel_indices: if defined, select foreground only on the specified channels of image.
                 if None, select foreground on the whole image.
-            margin: add margin to all dims of the bounding box.
+            margin: add margin value to spatial dims of the bounding box; if only 1 value is provided, use it for all dims.
         """
         super().__init__(keys)
         self.source_key = source_key
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index c5131dea1e..d7867c2bc6 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -17,7 +17,7 @@
 import torch

 from monai.config import IndexSelection
-from monai.utils import ensure_tuple, ensure_tuple_size, fall_back_tuple, min_version, optional_import
+from monai.utils import ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple, min_version, optional_import

 measure, _ = optional_import("skimage.measure", "0.14.2", min_version)

@@ -455,7 +455,7 @@ def generate_spatial_bounding_box(
     img: np.ndarray,
     select_fn: Callable = lambda x: x > 0,
     channel_indices: Optional[IndexSelection] = None,
-    margin: int = 0,
+    margin: Union[Sequence[int], int] = 0,
 ) -> Tuple[List[int], List[int]]:
     """
     generate the spatial bounding box of foreground in the image with start-end positions.
@@ -467,19 +467,19 @@
         select_fn: function to select expected foreground, default is to select values > 0.
         channel_indices: if defined, select foreground only on the specified channels of image.
             if None, select foreground on the whole image.
-        margin: add margin to all dims of the bounding box.
+        margin: add margin value to spatial dims of the bounding box; if only 1 value is provided, use it for all dims.
     """
-    assert isinstance(margin, int), "margin must be int type."
     data = img[[*(ensure_tuple(channel_indices))]] if channel_indices is not None else img
     data = np.any(select_fn(data), axis=0)
     nonzero_idx = np.nonzero(data)
+    margin = ensure_tuple_rep(margin, data.ndim)

     box_start = list()
     box_end = list()
     for i in range(data.ndim):
         assert len(nonzero_idx[i]) > 0, f"did not find nonzero index at spatial dim {i}"
-        box_start.append(max(0, np.min(nonzero_idx[i]) - margin))
-        box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin + 1))
+        box_start.append(max(0, np.min(nonzero_idx[i]) - margin[i]))
+        box_end.append(min(data.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1))
     return box_start, box_end
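Note: taken in isolation, the per-dimension margin logic introduced above amounts to the short sketch below. This is an illustrative standalone rendering, not the MONAI implementation itself; the helper name bbox_with_margin and the explicit scalar-broadcast branch (standing in for ensure_tuple_rep) are assumptions made for the example.

    from typing import List, Sequence, Tuple, Union

    import numpy as np

    def bbox_with_margin(mask: np.ndarray, margin: Union[Sequence[int], int]) -> Tuple[List[int], List[int]]:
        # `mask` is the boolean foreground mask after the channel dim is
        # collapsed with np.any, as in the patched generate_spatial_bounding_box.
        if isinstance(margin, int):
            margin = [margin] * mask.ndim  # broadcast one value to all spatial dims
        if len(margin) != mask.ndim:
            raise ValueError(f"margin must have {mask.ndim} values, got {len(margin)}.")
        nonzero_idx = np.nonzero(mask)
        box_start, box_end = [], []
        for i in range(mask.ndim):
            # pad each side by the per-dimension margin, clipped to the image bounds
            box_start.append(int(max(0, np.min(nonzero_idx[i]) - margin[i])))
            box_end.append(int(min(mask.shape[i], np.max(nonzero_idx[i]) + margin[i] + 1)))
        return box_start, box_end

    # Mirrors TEST_CASE_5 below: the foreground spans rows 1-2 and columns 1-3
    # of a 5x5 mask, so margin=[2, 1] expands the box to the full extent.
    mask = np.zeros((5, 5), dtype=bool)
    mask[1:3, 1:4] = True
    print(bbox_with_margin(mask, [2, 1]))  # ([0, 0], [5, 5])
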
diff --git a/tests/test_crop_foreground.py b/tests/test_crop_foreground.py
index 74da6da5a6..88d99932c1 100644
--- a/tests/test_crop_foreground.py
+++ b/tests/test_crop_foreground.py
@@ -40,9 +40,15 @@
     np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]),
 ]

+TEST_CASE_5 = [
+    {"select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]},
+    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
+    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
+]
+

 class TestCropForeground(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
     def test_value(self, argments, image, expected_data):
         result = CropForeground(**argments)(image)
         np.testing.assert_allclose(result, expected_data)
diff --git a/tests/test_crop_foregroundd.py b/tests/test_crop_foregroundd.py
index f7c0c2176b..256f8a937f 100644
--- a/tests/test_crop_foregroundd.py
+++ b/tests/test_crop_foregroundd.py
@@ -49,9 +49,15 @@
     np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]]),
 ]

+TEST_CASE_5 = [
+    {"keys": ["img"], "source_key": "img", "select_fn": lambda x: x > 0, "channel_indices": None, "margin": [2, 1]},
+    {"img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]])},
+    np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
+]
+

 class TestCropForegroundd(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
     def test_value(self, argments, image, expected_data):
         result = CropForegroundd(**argments)(image)
         np.testing.assert_allclose(result["img"], expected_data)
diff --git a/tests/test_generate_spatial_bounding_box.py b/tests/test_generate_spatial_bounding_box.py
index e9935d6c8f..338b6fe5d4 100644
--- a/tests/test_generate_spatial_bounding_box.py
+++ b/tests/test_generate_spatial_bounding_box.py
@@ -56,9 +56,19 @@
     ([0, 0], [4, 5]),
 ]

+TEST_CASE_5 = [
+    {
+        "img": np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]),
+        "select_fn": lambda x: x > 0,
+        "channel_indices": None,
+        "margin": [2, 1],
+    },
+    ([0, 0], [5, 5]),
+]
+

 class TestGenerateSpatialBoundingBox(unittest.TestCase):
-    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
+    @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
     def test_value(self, input_data, expected_box):
         result = generate_spatial_bounding_box(**input_data)
         self.assertTupleEqual(result, expected_box)
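For the user-facing behavior, a quick usage sketch of CropForeground with the new sequence-valued margin; the input array here is made up for illustration, and the transform crops only the spatial dims, leaving the channel dim intact.

    import numpy as np

    from monai.transforms import CropForeground

    # Channel-first image: 1 channel, 4x5 spatial; foreground is the block of values > 0.
    img = np.array([[[0, 0, 0, 0, 0], [0, 1, 2, 1, 0], [0, 2, 3, 2, 0], [0, 0, 0, 0, 0]]])
    # One margin value per spatial dim: pad rows by 1, keep columns tight.
    cropper = CropForeground(select_fn=lambda x: x > 0, margin=[1, 0])
    print(cropper(img).shape)  # (1, 4, 3): rows 0-3 kept, columns 1-3 kept

A scalar margin still works as before: it is simply repeated for every spatial dim.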