diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6a25c62c49..8a880ff151 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -356,6 +356,18 @@ Post-processing :members: :special-members: __call__ +`LabelFilter` +""""""""""""" +.. autoclass:: LabelFilter + :members: + :special-members: __call__ + +`FillHoles` +""""""""""" +.. autoclass:: FillHoles + :members: + :special-members: __call__ + `LabelToContour` """""""""""""""" .. autoclass:: LabelToContour @@ -955,6 +967,18 @@ Post-processing (Dict) :members: :special-members: __call__ +`LabelFilterd` +"""""""""""""" +.. autoclass:: LabelFilterd + :members: + :special-members: __call__ + +`FillHolesd` +"""""""""""" +.. autoclass:: FillHolesd + :members: + :special-members: __call__ + `LabelToContourd` """"""""""""""""" .. autoclass:: LabelToContourd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index cf9198dbf5..7f2873cc85 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -194,40 +194,48 @@ from .post.array import ( Activations, AsDiscrete, + FillHoles, KeepLargestConnectedComponent, + LabelFilter, LabelToContour, MeanEnsemble, ProbNMS, VoteEnsemble, ) from .post.dictionary import ( - Activationsd, ActivationsD, + Activationsd, ActivationsDict, - AsDiscreted, AsDiscreteD, + AsDiscreted, AsDiscreteDict, Ensembled, - Invertd, + FillHolesD, + FillHolesd, + FillHolesDict, InvertD, + Invertd, InvertDict, - KeepLargestConnectedComponentd, KeepLargestConnectedComponentD, + KeepLargestConnectedComponentd, KeepLargestConnectedComponentDict, - LabelToContourd, + LabelFilterD, + LabelFilterd, + LabelFilterDict, LabelToContourD, + LabelToContourd, LabelToContourDict, - MeanEnsembled, MeanEnsembleD, + MeanEnsembled, MeanEnsembleDict, - ProbNMSd, ProbNMSD, + ProbNMSd, ProbNMSDict, - SaveClassificationd, SaveClassificationD, + SaveClassificationd, SaveClassificationDict, - VoteEnsembled, VoteEnsembleD, + VoteEnsembled, VoteEnsembleDict, ) from .spatial.array import ( diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 397b14e2e2..a33fce785e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -14,26 +14,29 @@ """ import warnings -from typing import Callable, Optional, Sequence, Union +from typing import Callable, Iterable, Optional, Sequence, Union import numpy as np import torch import torch.nn.functional as F +from monai.config import NdarrayTensor from monai.networks import one_hot from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform -from monai.transforms.utils import get_largest_connected_component_mask +from monai.transforms.utils import fill_holes, get_largest_connected_component_mask from monai.utils import ensure_tuple __all__ = [ "Activations", "AsDiscrete", + "FillHoles", "KeepLargestConnectedComponent", + "LabelFilter", "LabelToContour", "MeanEnsemble", - "VoteEnsemble", "ProbNMS", + "VoteEnsemble", ] @@ -289,6 +292,137 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor: return output +class LabelFilter: + """ + This transform filters out labels and can be used as a processing step to view only certain labels. + + The list of applied labels defines which labels will be kept. + + Note: + All labels which do not match the `applied_labels` are set to the background label (0). + + For example: + + Use LabelFilter with applied_labels=[1, 5, 9]:: + + [1, 2, 3] [1, 0, 0] + [4, 5, 6] => [0, 5 ,0] + [7, 8, 9] [0, 0, 9] + """ + + def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: + """ + Initialize the LabelFilter class with the labels to filter on. + + Args: + applied_labels: Label(s) to filter on. + """ + self.applied_labels = ensure_tuple(applied_labels) + + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + """ + Filter the image on the `applied_labels`. + + Args: + img: Pytorch tensor or numpy array of any shape. + + Raises: + NotImplementedError: The provided image was not a Pytorch Tensor or numpy array. + + Returns: + Pytorch tensor or numpy array of the same shape as the input. + """ + if isinstance(img, np.ndarray): + return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0)) + elif isinstance(img, torch.Tensor): + img_arr = img.detach().cpu().numpy() + img_arr = self(img_arr) + return torch.as_tensor(img_arr, device=img.device) + else: + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + + +class FillHoles(Transform): + r""" + This transform fills holes in the image and can be used to remove artifacts inside segments. + + An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class. + The definition of enclosed can be defined with the connectivity parameter:: + + 1-connectivity 2-connectivity diagonal connection close-up + + [ ] [ ] [ ] [ ] [ ] + | \ | / | <- hop 2 + [ ]--[x]--[ ] [ ]--[x]--[ ] [x]--[ ] + | / | \ hop 1 + [ ] [ ] [ ] [ ] + + It is possible to define for which labels the hole filling should be applied. + The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]]. + If C = 1, then the values correspond to expected labels. + If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing. + + Note: + + The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label. + + The performance of this method heavily depends on the number of labels. + It is a bit faster if the list of `applied_labels` is provided. + Limiting the number of `applied_labels` results in a big decrease in processing time. + + For example: + + Use FillHoles with default parameters:: + + [1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3] + [1, 0, 1, 2, 0, 0, 3, 0] => [1, 1 ,1, 2, 0, 0, 3, 0] + [1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3] + + The hole in label 1 is fully enclosed and therefore filled with label 1. + The background label near label 2 and 3 is not fully enclosed and therefore not filled. + """ + + def __init__( + self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None + ) -> None: + """ + Initialize the connectivity and limit the labels for which holes are filled. + + Args: + applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels. + connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. Defaults to a full connectivity of ``input.ndim``. + """ + super().__init__() + self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None + self.connectivity = connectivity + + def __call__(self, img: NdarrayTensor) -> NdarrayTensor: + """ + Fill the holes in the provided image. + + Note: + The value 0 is assumed as background label. + + Args: + img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. + + Raises: + NotImplementedError: The provided image was not a Pytorch Tensor or numpy array. + + Returns: + Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. + """ + if isinstance(img, np.ndarray): + return fill_holes(img, self.applied_labels, self.connectivity) + elif isinstance(img, torch.Tensor): + img_arr = img.detach().cpu().numpy() + img_arr = self(img_arr) + return torch.as_tensor(img_arr, device=img.device) + else: + raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.") + + class LabelToContour(Transform): """ Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6cba08948b..0d9be131fc 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -17,18 +17,20 @@ import warnings from copy import deepcopy -from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Hashable, Iterable, List, Mapping, Optional, Sequence, Union import numpy as np import torch -from monai.config import KeysCollection +from monai.config import KeysCollection, NdarrayTensor from monai.data.csv_saver import CSVSaver from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( Activations, AsDiscrete, + FillHoles, KeepLargestConnectedComponent, + LabelFilter, LabelToContour, MeanEnsemble, ProbNMS, @@ -41,34 +43,40 @@ from monai.utils.enums import InverseKeys __all__ = [ - "Activationsd", - "AsDiscreted", - "KeepLargestConnectedComponentd", - "LabelToContourd", - "Ensembled", - "MeanEnsembled", - "VoteEnsembled", "ActivationsD", "ActivationsDict", + "Activationsd", "AsDiscreteD", "AsDiscreteDict", + "AsDiscreted", + "Ensembled", + "FillHolesD", + "FillHolesDict", + "FillHolesd", "InvertD", "InvertDict", "Invertd", "KeepLargestConnectedComponentD", "KeepLargestConnectedComponentDict", + "KeepLargestConnectedComponentd", + "LabelFilterD", + "LabelFilterDict", + "LabelFilterd", "LabelToContourD", "LabelToContourDict", + "LabelToContourd", "MeanEnsembleD", "MeanEnsembleDict", - "VoteEnsembleD", - "VoteEnsembleDict", - "ProbNMSd", + "MeanEnsembled", "ProbNMSD", "ProbNMSDict", - "SaveClassificationd", + "ProbNMSd", "SaveClassificationD", "SaveClassificationDict", + "SaveClassificationd", + "VoteEnsembleD", + "VoteEnsembleDict", + "VoteEnsembled", ] @@ -208,6 +216,70 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc return d +class LabelFilterd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`. + """ + + def __init__( + self, + keys: KeysCollection, + applied_labels: Union[Sequence[int], int], + allow_missing_keys: bool = False, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + applied_labels: Label(s) to filter on. + allow_missing_keys: don't raise exception if key is missing. + + """ + super().__init__(keys, allow_missing_keys) + self.converter = LabelFilter(applied_labels) + + def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + +class FillHolesd(MapTransform): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`. + """ + + def __init__( + self, + keys: KeysCollection, + applied_labels: Optional[Union[Iterable[int], int]] = None, + connectivity: Optional[int] = None, + allow_missing_keys: bool = False, + ) -> None: + """ + Initialize the connectivity and limit the labels for which holes are filled. + + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + applied_labels (Optional[Union[Iterable[int], int]], optional): Labels for which to fill holes. Defaults to None, + that is filling holes for all labels. + connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. + Accepted values are ranging from 1 to input.ndim. Defaults to a full + connectivity of ``input.ndim``. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.converter = FillHoles(applied_labels=applied_labels, connectivity=connectivity) + + def __call__(self, data: Mapping[Hashable, NdarrayTensor]) -> Dict[Hashable, NdarrayTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.converter(d[key]) + return d + + class LabelToContourd(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`. @@ -620,10 +692,12 @@ def get_saver(self): ActivationsD = ActivationsDict = Activationsd AsDiscreteD = AsDiscreteDict = AsDiscreted +FillHolesD = FillHolesDict = FillHolesd +InvertD = InvertDict = Invertd KeepLargestConnectedComponentD = KeepLargestConnectedComponentDict = KeepLargestConnectedComponentd +LabelFilterD = LabelFilterDict = LabelFilterd LabelToContourD = LabelToContourDict = LabelToContourd MeanEnsembleD = MeanEnsembleDict = MeanEnsembled ProbNMSD = ProbNMSDict = ProbNMSd -VoteEnsembleD = VoteEnsembleDict = VoteEnsembled -InvertD = InvertDict = Invertd SaveClassificationD = SaveClassificationDict = SaveClassificationd +VoteEnsembleD = VoteEnsembleDict = VoteEnsembled diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 6dd2d2539f..366e2d245e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -14,7 +14,7 @@ import re import warnings from contextlib import contextmanager -from typing import Callable, List, Optional, Sequence, Tuple, Union +from typing import Callable, Iterable, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -37,43 +37,45 @@ ) measure, _ = optional_import("skimage.measure", "0.14.2", min_version) +ndimage, _ = optional_import("scipy.ndimage") cp, has_cp = optional_import("cupy") cp_ndarray, _ = optional_import("cupy", name="ndarray") __all__ = [ - "rand_choice", - "img_bounds", - "in_bounds", - "is_empty", - "is_positive", - "zero_margins", - "rescale_array", - "rescale_instance_array", - "rescale_array_int_max", - "copypaste_arrays", + "allow_missing_keys_mode", "compute_divisible_spatial_size", - "resize_center", - "map_binary_to_indices", - "map_classes_to_indices", - "weighted_patch_samples", - "generate_pos_neg_label_crop_centers", - "generate_label_classes_crop_centers", - "create_grid", + "convert_inverse_interp_mode", + "convert_to_numpy", + "convert_to_tensor", + "copypaste_arrays", "create_control_grid", + "create_grid", "create_rotate", - "create_shear", "create_scale", + "create_shear", "create_translate", + "extreme_points_to_image", + "fill_holes", + "generate_label_classes_crop_centers", + "generate_pos_neg_label_crop_centers", "generate_spatial_bounding_box", - "get_largest_connected_component_mask", "get_extreme_points", - "extreme_points_to_image", + "get_largest_connected_component_mask", + "img_bounds", + "in_bounds", + "is_empty", + "is_positive", + "map_binary_to_indices", + "map_classes_to_indices", "map_spatial_axes", - "allow_missing_keys_mode", - "convert_inverse_interp_mode", - "convert_to_tensor", - "convert_to_numpy", + "rand_choice", + "rescale_array", + "rescale_array_int_max", + "rescale_instance_array", + "resize_center", "tensor_to_numpy", + "weighted_patch_samples", + "zero_margins", ] @@ -732,6 +734,65 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option return torch.as_tensor(largest_cc, device=img.device) +def fill_holes( + img_arr: np.ndarray, applied_labels: Optional[Iterable[int]] = None, connectivity: Optional[int] = None +) -> np.ndarray: + """ + Fill the holes in the provided image. + + The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label. + What is considered to be an enclosed hole is defined by the connectivity. + Holes on the edge are always considered to be open (not enclosed). + + Note: + + The performance of this method heavily depends on the number of labels. + It is a bit faster if the list of `applied_labels` is provided. + Limiting the number of `applied_labels` results in a big decrease in processing time. + + If the image is one-hot-encoded, then the `applied_labels` need to match the channel index. + + Args: + img_arr: numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. + applied_labels: Labels for which to fill holes. Defaults to None, + that is filling holes for all labels. + connectivity: Maximum number of orthogonal hops to + consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. + Defaults to a full connectivity of ``input.ndim``. + + Returns: + numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. + """ + channel_axis = 0 + num_channels = img_arr.shape[channel_axis] + is_one_hot = num_channels > 1 + spatial_dims = img_arr.ndim - 1 + structure = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims) + + # Get labels if not provided. Exclude background label. + applied_labels = set(applied_labels or (range(num_channels) if is_one_hot else np.unique(img_arr))) + background_label = 0 + applied_labels.discard(background_label) + + for label in applied_labels: + tmp = np.zeros(img_arr.shape[1:], dtype=bool) + ndimage.binary_dilation( + tmp, + structure=structure, + iterations=-1, + mask=np.logical_not(img_arr[label]) if is_one_hot else img_arr[0] != label, + origin=0, + border_value=1, + output=tmp, + ) + if is_one_hot: + img_arr[label] = np.logical_not(tmp) + else: + img_arr[0, np.logical_not(tmp)] = label + + return img_arr + + def get_extreme_points( img: np.ndarray, rand_state: Optional[np.random.RandomState] = None, background: int = 0, pert: float = 0.0 ) -> List[Tuple[int, ...]]: diff --git a/tests/min_tests.py b/tests/min_tests.py index 1f53569cd9..afe88f7433 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -31,59 +31,81 @@ def run_testsuit(): "test_arraydataset", "test_cachedataset", "test_cachedataset_parallel", + "test_cachedataset_persistent_workers", + "test_cachentransdataset", + "test_csv_dataset", + "test_csv_iterable_dataset", "test_dataset", + "test_dataset_summary", + "test_deepgrow_dataset", + "test_deepgrow_interaction", + "test_deepgrow_transforms", "test_detect_envelope", "test_efficientnet", - "test_iterable_dataset", "test_ensemble_evaluator", + "test_ensure_channel_first", + "test_ensure_channel_firstd", + "test_fill_holes", "test_handler_checkpoint_loader", "test_handler_checkpoint_saver", "test_handler_classification_saver", - "test_handler_lr_scheduler", + "test_handler_classification_saver_dist", "test_handler_confusion_matrix", "test_handler_confusion_matrix_dist", - "test_handler_hausdorff_distance", + "test_handler_decollate_batch", + "test_handler_early_stop", "test_handler_garbage_collector", + "test_handler_hausdorff_distance", + "test_handler_lr_scheduler", "test_handler_mean_dice", + "test_handler_metrics_saver", + "test_handler_metrics_saver_dist", + "test_handler_parameter_scheduler", + "test_handler_post_processing", "test_handler_prob_map_producer", "test_handler_regression_metrics", "test_handler_regression_metrics_dist", "test_handler_rocauc", "test_handler_rocauc_dist", - "test_handler_parameter_scheduler", "test_handler_segmentation_saver", "test_handler_smartcache", "test_handler_stats", "test_handler_surface_distance", "test_handler_tb_image", "test_handler_tb_stats", + "test_handler_transform_inverter", "test_handler_validation", "test_hausdorff_distance", "test_header_correct", "test_hilbert_transform", + "test_image_dataset", "test_img2tensorboard", "test_integration_segmentation_3d", "test_integration_sliding_window", "test_integration_unet_2d", "test_integration_workflows", "test_integration_workflows_gan", + "test_invertd", + "test_iterable_dataset", "test_keep_largest_connected_component", "test_keep_largest_connected_componentd", + "test_label_filter", "test_lltm", "test_lmdbdataset", "test_load_image", "test_load_imaged", "test_load_spacing_orientation", "test_mednistdataset", - "test_image_dataset", + "test_mlp", "test_nifti_header_revise", "test_nifti_rw", "test_nifti_saver", + "test_occlusion_sensitivity", "test_orientation", "test_orientationd", "test_parallel_execution", + "test_patchembedding", "test_persistentdataset", - "test_cachentransdataset", "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", @@ -92,50 +114,30 @@ def run_testsuit(): "test_rand_rotated", "test_rand_zoom", "test_rand_zoomd", + "test_randtorchvisiond", "test_resize", "test_resized", "test_rotate", "test_rotated", + "test_save_image", + "test_save_imaged", + "test_selfattention", + "test_senet", "test_smartcachedataset", "test_spacing", "test_spacingd", - "test_senet", "test_surface_distance", - "test_zoom", - "test_zoom_affine", - "test_zoomd", - "test_occlusion_sensitivity", + "test_testtimeaugmentation", "test_torchvision", "test_torchvisiond", - "test_randtorchvisiond", - "test_handler_metrics_saver", - "test_handler_metrics_saver_dist", - "test_handler_classification_saver_dist", - "test_dataset_summary", - "test_deepgrow_transforms", - "test_deepgrow_interaction", - "test_deepgrow_dataset", - "test_save_image", - "test_save_imaged", - "test_ensure_channel_first", - "test_ensure_channel_firstd", - "test_handler_early_stop", - "test_handler_transform_inverter", - "test_testtimeaugmentation", - "test_cachedataset_persistent_workers", - "test_invertd", - "test_handler_post_processing", - "test_write_metrics_reports", - "test_csv_dataset", - "test_csv_iterable_dataset", - "test_mlp", - "test_patchembedding", - "test_selfattention", "test_transformerblock", "test_unetr", "test_unetr_block", "test_vit", - "test_handler_decollate_batch", + "test_write_metrics_reports", + "test_zoom", + "test_zoom_affine", + "test_zoomd", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py new file mode 100644 index 0000000000..294bbd8c87 --- /dev/null +++ b/tests/test_fill_holes.py @@ -0,0 +1,297 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import FillHoles +from tests.utils import allclose, clone + +grid_1_raw = [ + [1, 1, 1], + [1, 0, 1], + [1, 1, 1], +] + +grid_2_raw = [ + [0, 1, 0], + [1, 0, 1], + [0, 1, 0], +] + +grid_3_raw = [ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1], +] + +grid_4_raw = [ + [0, 1, 0], + [1, 1, 1], + [0, 1, 0], +] + +grid_1 = torch.tensor([grid_1_raw]) + +grid_2 = torch.tensor([grid_2_raw]) + +grid_3 = torch.tensor([grid_3_raw]) + +grid_4 = torch.tensor([grid_4_raw]) + +grid_5 = torch.tensor( + [ + [ + [1, 1, 1], + [1, 0, 0], + [1, 1, 1], + ] + ] +) + +grid_6 = torch.tensor( + [ + [ + [1, 1, 2, 2, 2], + [1, 0, 2, 0, 2], + [1, 1, 2, 2, 2], + ] + ] +) + +grid_7 = torch.tensor( + [ + [ + [1, 1, 2, 2, 2], + [1, 0, 2, 2, 2], + [1, 1, 2, 2, 2], + ] + ] +) + +TEST_CASE_0 = [ + "enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1, + grid_3, +] + +TEST_CASE_1 = [ + "enclosed_full_connectivity_default_applied_labels", + {"connectivity": 2}, + grid_1, + grid_3, +] + +TEST_CASE_2 = [ + "enclosed_full_connectivity_applied_labels_same_single", + {"connectivity": 2, "applied_labels": 1}, + grid_1, + grid_3, +] + +TEST_CASE_3 = [ + "enclosed_full_connectivity_applied_labels_same_list", + {"connectivity": 2, "applied_labels": [1]}, + grid_1, + grid_3, +] + +TEST_CASE_4 = [ + "enclosed_full_connectivity_applied_labels_other_single", + {"connectivity": 2, "applied_labels": 2}, + grid_1, + grid_1, +] + +TEST_CASE_5 = [ + "enclosed_full_connectivity_applied_labels_other_list", + {"connectivity": 2, "applied_labels": [2]}, + grid_1, + grid_1, +] + +TEST_CASE_6 = [ + "enclosed_full_connectivity_applied_labels_same_and_other", + {"connectivity": 2, "applied_labels": [1, 2]}, + grid_1, + grid_3, +] + +TEST_CASE_7 = [ + "enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_1, + grid_3, +] + +TEST_CASE_8 = [ + "enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_2, + grid_4, +] + +TEST_CASE_9 = [ + "open_full_connectivity_default_applied_labels", + {"connectivity": 2}, + grid_2, + grid_2, +] + +TEST_CASE_10 = [ + "open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_5, + grid_5, +] + +TEST_CASE_11 = [ + "open_to_other_label_connectivity_1_default_applied_labels", + {"connectivity": 1}, + grid_6, + grid_7, +] + +TEST_CASE_12 = [ + "open_to_other_label_connectivity_1_applied_labels_other", + {"connectivity": 1, "applied_labels": 1}, + grid_6, + grid_6, +] + +TEST_CASE_13 = [ + "numpy_enclosed_default_full_connectivity_default_applied_labels", + {}, + grid_1.cpu().numpy(), + grid_3.cpu().numpy(), +] + +TEST_CASE_14 = [ + "3D_enclosed_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]), +] + +TEST_CASE_15 = [ + "3D_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]), +] + +TEST_CASE_16 = [ + "3D_open_full_connectivity_default_applied_labels", + {"connectivity": 3}, + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), +] + +TEST_CASE_17 = [ + "3D_open_to_edge_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), +] + +TEST_CASE_18 = [ + "enclosed_full_connectivity_applied_labels_with_background", + {"connectivity": 2, "applied_labels": [0, 1]}, + grid_1, + grid_3, +] + +TEST_CASE_19 = [ + "enclosed_full_connectivity_applied_labels_only_background", + {"connectivity": 2, "applied_labels": [0]}, + grid_1, + grid_1, +] + +TEST_CASE_20 = [ + "one-hot_enclosed_connectivity_1_default_applied_labels", + {"connectivity": 1}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]), +] + +TEST_CASE_21 = [ + "one-hot_enclosed_connectivity_1_applied_labels_2", + {"connectivity": 1, "applied_labels": [2]}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]), +] + +TEST_CASE_22 = [ + "one-hot_full_connectivity_applied_labels_2", + {"connectivity": 2}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]), +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, + TEST_CASE_4, + TEST_CASE_5, + TEST_CASE_6, + TEST_CASE_7, + TEST_CASE_8, + TEST_CASE_9, + TEST_CASE_10, + TEST_CASE_11, + TEST_CASE_12, + TEST_CASE_13, + TEST_CASE_14, + TEST_CASE_15, + TEST_CASE_16, + TEST_CASE_17, + TEST_CASE_18, + TEST_CASE_19, + TEST_CASE_20, + TEST_CASE_21, + TEST_CASE_22, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestFillHoles(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + converter = FillHoles(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) + else: + result = converter(clone(input_image)) + assert allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = FillHoles(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) + else: + _ = converter(clone(input_image)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index a8835329ba..670dd2d2ee 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -15,6 +15,7 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent +from tests.utils import allclose, clone grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) @@ -322,23 +323,23 @@ class TestKeepLargestConnectedComponent(unittest.TestCase): @parameterized.expand(VALID_CASES) - def test_correct_results(self, _, args, tensor, expected): + def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - result = converter(tensor.clone().cuda()) - assert torch.allclose(result, expected.cuda()) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) else: - result = converter(tensor.clone()) - assert torch.allclose(result, expected) + result = converter(clone(input_image)) + assert allclose(result, expected) @parameterized.expand(INVALID_CASES) - def test_raise_exception(self, _, args, tensor, expected_error): + def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponent(**args) - if torch.cuda.is_available(): - _ = converter(tensor.clone().cuda()) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) else: - _ = converter(tensor.clone()) + _ = converter(clone(input_image).clone()) if __name__ == "__main__": diff --git a/tests/test_label_filter.py b/tests/test_label_filter.py new file mode 100644 index 0000000000..9165fddc40 --- /dev/null +++ b/tests/test_label_filter.py @@ -0,0 +1,127 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import torch +from parameterized import parameterized + +from monai.transforms import LabelFilter +from tests.utils import allclose, clone + +grid_1 = torch.tensor( + [ + [ + [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + ] + ] +) + + +TEST_CASE_0 = [ + "filter_single_label", + {"applied_labels": 3}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + ] + ] + ] + ), +] + + +TEST_CASE_1 = [ + "filter_single_label_list", + {"applied_labels": [3]}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 0, 0], + [0, 0, 0], + ] + ] + ] + ), +] + +TEST_CASE_2 = [ + "filter_multi_label", + {"applied_labels": [3, 5, 8]}, + grid_1, + torch.tensor( + [ + [ + [ + [0, 0, 3], + [0, 5, 0], + [0, 8, 0], + ] + ] + ] + ), +] + +TEST_CASE_3 = [ + "filter_all", + {"applied_labels": [1, 2, 3, 4, 5, 6, 7, 8, 9]}, + grid_1, + grid_1, +] + +VALID_CASES = [ + TEST_CASE_0, + TEST_CASE_1, + TEST_CASE_2, + TEST_CASE_3, +] + +ITEST_CASE_1 = ["invalid_image_data_type", {"applied_labels": 1}, [[[[1, 1, 1]]]], NotImplementedError] + +INVALID_CASES = [ITEST_CASE_1] + + +class TestLabelFilter(unittest.TestCase): + @parameterized.expand(VALID_CASES) + def test_correct_results(self, _, args, input_image, expected): + converter = LabelFilter(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + result = converter(clone(input_image).cuda()) + assert allclose(result, expected.cuda()) + else: + result = converter(clone(input_image)) + assert allclose(result, expected) + + @parameterized.expand(INVALID_CASES) + def test_raise_exception(self, _, args, input_image, expected_error): + with self.assertRaises(expected_error): + converter = LabelFilter(**args) + if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): + _ = converter(clone(input_image).cuda()) + else: + _ = converter(clone(input_image)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index ce280a13f0..c3f604f12e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import datetime import functools import importlib @@ -29,6 +30,7 @@ import torch import torch.distributed as dist +from monai.config import NdarrayTensor from monai.config.deviceconfig import USE_COMPILED from monai.data import create_test_image_2d, create_test_image_3d from monai.utils import ensure_tuple, optional_import, set_determinism @@ -39,6 +41,42 @@ quick_test_var = "QUICKTEST" +def clone(data: NdarrayTensor) -> NdarrayTensor: + """ + Clone data independent of type. + + Args: + data (NdarrayTensor): This can be a Pytorch Tensor or numpy array. + + Returns: + Any: Cloned data object + """ + return copy.deepcopy(data) + + +def allclose(a: NdarrayTensor, b: NdarrayTensor) -> bool: + """ + Check if all values of two data objects are close. + + Note: + This method also checks that both data objects are either Pytorch Tensors or numpy arrays. + + Args: + a (NdarrayTensor): Pytorch Tensor or numpy array for comparison + b (NdarrayTensor): Pytorch Tensor or numpy array to compare against + + Returns: + bool: If both data objects are close. + """ + if isinstance(a, np.ndarray) and isinstance(b, np.ndarray): + return np.allclose(a, b) + + if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor): + return torch.allclose(a, b) + + return False + + def test_pretrained_networks(network, input_param, device): try: net = network(**input_param).to(device)