diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 1c3ee288a13..c010b1895ff 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -356,6 +356,18 @@ Post-processing :members: :special-members: __call__ +`Filter` +""""""""""""""""""""""""""""""" +.. autoclass:: Filter + :members: + :special-members: __call__ + +`FillHoles` +""""""""""""""""""""""""""""""" +.. autoclass:: FillHoles + :members: + :special-members: __call__ + `LabelToContour` """""""""""""""" .. autoclass:: LabelToContour @@ -947,6 +959,18 @@ Post-processing (Dict) :members: :special-members: __call__ +`Filterd` +"""""""""""""""""""""""""""""""" +.. autoclass:: Filterd + :members: + :special-members: __call__ + +`FillHolesd` +"""""""""""""""""""""""""""""""" +.. autoclass:: FillHolesd + :members: + :special-members: __call__ + `LabelToContourd` """"""""""""""""" .. autoclass:: LabelToContourd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 20e29d5aa96..d9b6cf08fb1 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -194,6 +194,8 @@ from .post.array import ( Activations, AsDiscrete, + FillHoles, + Filter, KeepLargestConnectedComponent, LabelToContour, MeanEnsemble, @@ -201,33 +203,39 @@ VoteEnsemble, ) from .post.dictionary import ( - Activationsd, ActivationsD, + Activationsd, ActivationsDict, - AsDiscreted, AsDiscreteD, + AsDiscreted, AsDiscreteDict, Ensembled, - Invertd, + FillHolesD, + FillHolesd, + FillHolesDict, + FilterD, + Filterd, + FilterDict, InvertD, + Invertd, InvertDict, - KeepLargestConnectedComponentd, KeepLargestConnectedComponentD, + KeepLargestConnectedComponentd, KeepLargestConnectedComponentDict, - LabelToContourd, LabelToContourD, + LabelToContourd, LabelToContourDict, - MeanEnsembled, MeanEnsembleD, + MeanEnsembled, MeanEnsembleDict, - ProbNMSd, ProbNMSD, + ProbNMSd, ProbNMSDict, - SaveClassificationd, SaveClassificationD, + SaveClassificationd, SaveClassificationDict, - VoteEnsembled, VoteEnsembleD, + VoteEnsembled, VoteEnsembleDict, ) from .spatial.array import ( diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 397b14e2e2c..fc41a106bec 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -23,17 +23,19 @@ from monai.networks import one_hot from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform -from monai.transforms.utils import get_largest_connected_component_mask +from monai.transforms.utils import get_filled_holes, get_largest_connected_component_mask from monai.utils import ensure_tuple __all__ = [ "Activations", "AsDiscrete", + "FillHoles", + "Filter", "KeepLargestConnectedComponent", "LabelToContour", "MeanEnsemble", - "VoteEnsemble", "ProbNMS", + "VoteEnsemble", ] @@ -289,6 +291,140 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: return output +class Filter: + """ + This transform filters out labels and can be used as a processing step to view only certain labels. + + The list of applied labels defines which labels will be kept. + + Note: + All labels which do not match the `applied_labels` are set to the background label (0). + + For example: + + Use Filter with applied_labels=[1, 5, 9]:: + + [1, 2, 3] [1, 0, 0] + [4, 5, 6] => [0, 5 ,0] + [7, 8, 9] [0, 0, 9] + """ + + def __init__(self, applied_labels: Union[Sequence[int], int]) -> None: + """ + Initialize the Filter class with the labels to filter on. + + Args: + applied_labels (Union[Sequence[int], int]): Label(s) to filter on. + """ + self.applied_labels = ensure_tuple(applied_labels) + + def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + """ + Filter the image on the `applied_labels`. + + Args: + img (Union[np.ndarray, torch.Tensor]): Pytorch tensor or numpy array of any shape. + + Raises: + NotImplementedError: The provided image was not a Pytorch Tensor or numpy array. + + Returns: + Union[np.ndarray, torch.Tensor]: Pytorch tensor or numpy array of the same shape as the input. + """ + if isinstance(img, np.ndarray): + return np.where(np.isin(img, self.applied_labels), img, 0) + elif isinstance(img, torch.Tensor): + img_arr = img.detach().cpu().numpy() + img_arr = self(img_arr) + return torch.as_tensor(img_arr, device=img.device) + else: + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + + +class FillHoles(Transform): + """ + This transform fills holes in the image and can be used as a post-processing step to remove artifacts inside areas in model output. + + An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class. + The definition of enclosed can be defined with the connectivity parameter:: + + 1-connectivity 2-connectivity diagonal connection close-up + + [ ] [ ] [ ] [ ] [ ] + | \ | / | <- hop 2 + [ ]--[x]--[ ] [ ]--[x]--[ ] [x]--[ ] + | / | \ hop 1 + [ ] [ ] [ ] [ ] + + It is possible to define for which labels the hole filling should be applied. + The input image is assumed to be a PyTorch Tensor or numpy array + with shape [batch_size, 1, spatial_dim1[, spatial_dim2, ...]] and the values correspond to expected labels. + + Note: + The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label. + + For example: + + Use FillHoles with default parameters:: + + [1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3] + [1, 0, 1, 2, 0, 0, 3, 0] => [1, 1 ,1, 2, 0, 0, 3, 0] + [1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3] + + The hole in label 1 is fully enclosed and therefore filled with label 1. + The background label near label 2 and 3 is not fully enclosed and therefore not filled. + """ + + def __init__( + self, connectivity: Optional[int] = None, applied_labels: Optional[Union[Sequence[int], int]] = None + ) -> None: + """ + Initialize the connectivity and limit the labels for which holes are filled. + + Args: + connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. Defaults to a full + connectivity of ``input.ndim``. + applied_labels (Optional[Union[Sequence[int], int]], optional): Labels for which to fill holes. Defaults to None, + that is filling holes for all labels. + """ + super().__init__() + self.connectivity = connectivity + self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None + + def __call__(self, img: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: + """ + Fill the holes in the provided image. + + Note: + The value 0 is assumed as background label. + + Args: + img (Union[np.ndarray, torch.Tensor]): Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]]. + + Raises: + NotImplementedError: The provided image was not a Pytorch Tensor or numpy array. + + Returns: + Union[np.ndarray, torch.Tensor]: Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]]. + """ + if isinstance(img, np.ndarray): + channel_axis = 1 + img_arr = np.squeeze(img, axis=channel_axis) + output = get_filled_holes(img_arr, self.connectivity) + if self.applied_labels: + output = Filter(self.applied_labels)(output) + output = np.expand_dims(output, axis=channel_axis) + output = img_arr + output + return output + elif isinstance(img, torch.Tensor): + img_arr = img.detach().cpu().numpy() + img_arr = self(img_arr) + return torch.as_tensor(img_arr, device=img.device) + else: + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + + class LabelToContour(Transform): """ Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6cba08948b0..8d2373cddbf 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -28,6 +28,8 @@ from monai.transforms.post.array import ( Activations, AsDiscrete, + FillHoles, + Filter, KeepLargestConnectedComponent, LabelToContour, MeanEnsemble, @@ -41,34 +43,40 @@ from monai.utils.enums import InverseKeys __all__ = [ - "Activationsd", - "AsDiscreted", - "KeepLargestConnectedComponentd", - "LabelToContourd", - "Ensembled", - "MeanEnsembled", - "VoteEnsembled", "ActivationsD", "ActivationsDict", + "Activationsd", "AsDiscreteD", "AsDiscreteDict", + "AsDiscreted", + "Ensembled", + "FillHolesD", + "FillHolesDict", + "FillHolesd", + "FilterD", + "FilterDict", + "Filterd", "InvertD", "InvertDict", "Invertd", "KeepLargestConnectedComponentD", "KeepLargestConnectedComponentDict", + "KeepLargestConnectedComponentd", "LabelToContourD", "LabelToContourDict", + "LabelToContourd", "MeanEnsembleD", "MeanEnsembleDict", - "VoteEnsembleD", - "VoteEnsembleDict", - "ProbNMSd", + "MeanEnsembled", "ProbNMSD", "ProbNMSDict", - "SaveClassificationd", + "ProbNMSd", "SaveClassificationD", "SaveClassificationDict", + "SaveClassificationd", + "VoteEnsembleD", + "VoteEnsembleDict", + "VoteEnsembled", ] @@ -208,6 +216,69 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class Filterd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.Filter`. + """ + + def __init__( + self, + keys: KeysCollection, + applied_labels: Union[Sequence[int], int], + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + applied_labels (Union[Sequence[int], int]): Label(s) to filter on. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.converter = Filter(applied_labels) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class FillHolesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`. + """ + + def __init__( + self, + keys: KeysCollection, + connectivity: Optional[int] = None, + applied_labels: Optional[Union[Sequence[int], int]] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. Defaults to a full + connectivity of ``input.ndim``. + applied_labels (Optional[Union[Sequence[int], int]], optional): Labels for which to fill holes. Defaults to None, + that is filling holes for all labels. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.converter = FillHoles(connectivity, applied_labels) + + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + class LabelToContourd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. @@ -620,10 +691,12 @@ def get_saver(self): ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted +FillHolesD = FillHolesDict = FillHolesd +FilterD = FilterDict = Filterd +InvertD = InvertDict = Invertd KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd LabelToContourD = LabelToContourDict = LabelToContourd MeanEnsembleD = MeanEnsembleDict = MeanEnsembled ProbNMSD = ProbNMSDict = ProbNMSd -VoteEnsembleD = VoteEnsembleDict = VoteEnsembled -InvertD = InvertDict = Invertd SaveClassificationD = SaveClassificationDict = SaveClassificationd +VoteEnsembleD = VoteEnsembleDict = VoteEnsembled diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 2da7b688cb3..5472cb2e4a3 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -37,6 +37,7 @@ ) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") @@ -66,6 +67,7 @@ "create_translate", "generate_spatial_bounding_box", "get_largest_connected_component_mask", + "get_filled_holes", "get_extreme_points", "extreme_points_to_image", "map_spatial_axes", @@ -724,6 +726,50 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option return torch.as_tensor(largest_cc, device=img.device) +def get_filled_holes(img_arr: np.ndarray, connectivity: Optional[int] = None) -> np.ndarray: + """ + Fill the holes in the provided image. + + The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label. + What is considered to be an enclosed hole is defined by the connectivity. + Holes on the edge are always considered to be open (not enclosed). + + Args: + img_arr (np.ndarray): numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]]. + connectivity (Optional[int], optional): connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. Defaults to a full + connectivity of ``input.ndim``. + + Returns: + np.ndarray: numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]]. + """ + ndim = img_arr.ndim - 1 + if not connectivity: + connectivity = ndim + + footprint = ndimage.generate_binary_structure(ndim, connectivity) + + background_label = 0 + filled_holes = np.zeros_like(img_arr) + for i, item in enumerate(img_arr): + background_mask = item == background_label + components, num_components = ndimage.label(background_mask, structure=footprint) + + for component_label in range(1, num_components + 1): + component_mask = components == component_label + # Pad with -1 to detect edge voxels + component_neighborhood = np.pad(item, 1, constant_values=-1)[ + ndimage.binary_dilation(np.pad(component_mask, 1), structure=footprint) + ] + + neighbor_labels = np.unique(component_neighborhood) + if len(neighbor_labels) == 2 and -1 not in neighbor_labels: + neighbor_label = neighbor_labels[1] + filled_holes[i, component_mask] = neighbor_label + + return filled_holes + + def get_extreme_points( img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py new file mode 100644 index 00000000000..131a52a1b7b --- /dev/null +++ b/tests/test_fill_holes.py @@ -0,0 +1,265 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import FillHoles +from tests.utils import allclose, clone + + +grid_1_raw = [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1], +] + +grid_2_raw = [ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0], +] + +grid_3_raw = [ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], +] + +grid_4_raw = [ + [0, 1, 0], + [1, 1, 1], + [0, 1, 0], +] + +grid_1 = torch.tensor([[grid_1_raw]]) + +grid_2 = torch.tensor([[grid_2_raw]]) + +grid_3 = torch.tensor([[grid_3_raw]]) + +grid_4 = torch.tensor([[grid_4_raw]]) + +grid_5 = torch.tensor( + [ + [ + [ + [1, 1, 1], + [1, 0, 0], + [1, 1, 1], + ] + ] + ] +) + +grid_6 = torch.tensor( + [ + [ + [ + [1, 1, 2, 2, 2], + [1, 0, 2, 0, 2], + [1, 1, 2, 2, 2], + ] + ] + ] +) + +grid_7 = torch.tensor( + [ + [ + [ + [1, 1, 2, 2, 2], + [1, 0, 2, 2, 2], + [1, 1, 2, 2, 2], + ] + ] + ] +) + + +TEST_CASE_0 = [ + "enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1, + grid_3, +] + +TEST_CASE_1 = [ + "enclosed_full_connectivity_default_applied_labels", + {"connectivity": 2}, + grid_1, + grid_3, +] + +TEST_CASE_2 = [ + "enclosed_full_connectivity_applied_labels_same_single", + {"connectivity": 2, "applied_labels": 1}, + grid_1, + grid_3, +] + +TEST_CASE_3 = [ + "enclosed_full_connectivity_applied_labels_same_list", + {"connectivity": 2, "applied_labels": [1]}, + grid_1, + grid_3, +] + +TEST_CASE_4 = [ + "enclosed_full_connectivity_applied_labels_other_single", + {"connectivity": 2, "applied_labels": 2}, + grid_1, + grid_1, +] + +TEST_CASE_5 = [ + "enclosed_full_connectivity_applied_labels_other_list", + {"connectivity": 2, "applied_labels": [2]}, + grid_1, + grid_1, +] + +TEST_CASE_6 = [ + "enclosed_full_connectivity_applied_labels_same_and_other", + {"connectivity": 2, "applied_labels": [1, 2]}, + grid_1, + grid_3, +] + +TEST_CASE_7 = [ + "enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_1, + grid_3, +] + +TEST_CASE_8 = [ + "enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_2, + grid_4, +] + +TEST_CASE_9 = [ + "open_full_connectivity_default_applied_labels", + {"connectivity": 2}, + grid_2, + grid_2, +] + +TEST_CASE_10 = [ + "open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_5, + grid_5, +] + +TEST_CASE_11 = [ + "open_to_other_label_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_6, + grid_7, +] + +TEST_CASE_12 = [ + "open_to_other_label_connectivity_1_applied_labels_other", + {"connectivity": 1, "applied_labels": 1}, + grid_6, + grid_6, +] + +TEST_CASE_13 = [ + "numpy_enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1.cpu().numpy(), + grid_3.cpu().numpy(), +] + +TEST_CASE_14 = [ + "3D_enclosed_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[[grid_3_raw, grid_1_raw, grid_3_raw]]]), + torch.tensor([[[grid_3_raw, grid_3_raw, grid_3_raw]]]), +] + +TEST_CASE_15 = [ + "3D_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), + torch.tensor([[[grid_4_raw, grid_4_raw, grid_4_raw]]]), +] + +TEST_CASE_16 = [ + "3D_open_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), + torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), +] + +TEST_CASE_17 = [ + "3D_open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]), + torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]), +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, + TEST_CASE_13, + TEST_CASE_14, + TEST_CASE_15, + TEST_CASE_16, + TEST_CASE_17, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + converter = FillHoles(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) + else: + result = converter(clone(input_image)) + assert allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = FillHoles(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) + else: + _ = converter(clone(input_image)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_filter.py b/tests/test_filter.py new file mode 100644 index 00000000000..58c1c57a4f4 --- /dev/null +++ b/tests/test_filter.py @@ -0,0 +1,127 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import Filter +from tests.utils import allclose, clone + +grid_1 = torch.tensor( + [ + [ + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + ] + ] +) + + +TEST_CASE_0 = [ + "filter_single_label", + {"applied_labels": 3}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + ] + ] + ] + ), +] + + +TEST_CASE_1 = [ + "filter_single_label_list", + {"applied_labels": [3]}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + ] + ] + ] + ), +] + +TEST_CASE_2 = [ + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 5, 0], + [0, 8, 0], + ] + ] + ] + ), +] + +TEST_CASE_3 = [ + "filter_all", + {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, + grid_1, + grid_1, +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + converter = Filter(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) + else: + result = converter(clone(input_image)) + assert allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = Filter(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) + else: + _ = converter(clone(input_image)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index a8835329ba3..670dd2d2ee9 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent +from tests.utils import allclose, clone grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) @@ -322,23 +323,23 @@ class TestKeepLargestConnectedComponent(unittest.TestCase): @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, args, tensor, expected): + def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - result = converter(tensor.clone().cuda()) - assert torch.allclose(result, expected.cuda()) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) else: - result = converter(tensor.clone()) - assert torch.allclose(result, expected) + result = converter(clone(input_image)) + assert allclose(result, expected) @parameterized.expand(INVALID_CASES) - def test_raise_exception(self, _, args, tensor, expected_error): + def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - _ = converter(tensor.clone().cuda()) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) else: - _ = converter(tensor.clone()) + _ = converter(clone(input_image).clone()) if __name__ == "__main__": diff --git a/tests/utils.py b/tests/utils.py index ce280a13f06..f8f1ff9bb3c 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import datetime import functools import importlib @@ -22,7 +23,7 @@ import warnings from io import BytesIO from subprocess import PIPE, Popen -from typing import Optional +from typing import Any, Optional, Union from urllib.error import ContentTooShortError, HTTPError, URLError import numpy as np @@ -39,6 +40,42 @@ quick_test_var = "QUICKTEST" +def clone(data: Any) -> Any: + """ + Clone data independent of type. + + Args: + data (Any): This can be a Pytorch Tensor, numpy array, list, ... + + Returns: + Any: Cloned data object + """ + return copy.deepcopy(data) + + +def allclose(a: Union[np.ndarray, torch.Tensor], b: Union[np.ndarray, torch.Tensor]) -> bool: + """ + Check if all values of two data objects are close. + + Note: + This method also checks that both data objects are either Pytorch Tensors or numpy arrays. + + Args: + a (Union[np.ndarray, torch.Tensor]): Pytorch Tensor or numpy array for comparison + b (Union[np.ndarray, torch.Tensor]): Pytorch Tensor or numpy array to compare against + + Returns: + bool: If both data objects are close. + """ + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.allclose(a, b) + + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.allclose(a, b) + + return False + + def test_pretrained_networks(network, input_param, device): try: net = network(**input_param).to(device)