From 51028f1b5b91c312d3a8a2c4f8430fb6e4d553d6 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 2 Nov 2021 14:01:02 +0000 Subject: [PATCH 1/2] adds backend for largest cc Signed-off-by: Wenqi Li --- monai/transforms/post/array.py | 73 +-- monai/transforms/post/dictionary.py | 12 +- monai/transforms/utils.py | 10 +- .../test_keep_largest_connected_component.py | 568 ++++++++--------- .../test_keep_largest_connected_componentd.py | 571 +++++++++--------- 5 files changed, 618 insertions(+), 616 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 7d97bedd3b..5251ce880b 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -263,17 +263,21 @@ class KeepLargestConnectedComponent(Transform): """ + backend = [TransformBackends.NUMPY] + def __init__( self, applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None ) -> None: """ Args: - applied_labels: Labels for applying the connected component on. - If only one channel. The pixel whose value is not in this list will remain unchanged. - If the data is in one-hot format, this is used to determine what channels to apply. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + applied_labels: Labels for applying the connected component analysis on. + If only one channel. The pixel whose value is in this list will be analyzed. + If the data is in one-hot format, this is used to determine which channels to apply. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest component. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. @@ -283,48 +287,37 @@ def __init__( self.independent = independent self.connectivity = connectivity - def __call__(self, img: torch.Tensor) -> torch.Tensor: + def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Args: img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Returns: - A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]). + An array with shape (C, spatial_dim1[, spatial_dim2, ...]). """ - if img.shape[0] == 1: - img = torch.squeeze(img, dim=0) - - if self.independent: - for i in self.applied_labels: - foreground = (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - else: - foreground = torch.zeros_like(img) - for i in self.applied_labels: - foreground += (img == i).type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[foreground != mask] = 0 - output = torch.unsqueeze(img, dim=0) - else: - # one-hot data is assumed to have binary value in each channel - if self.independent: - for i in self.applied_labels: - foreground = img[i, ...].type(torch.uint8) - mask = get_largest_connected_component_mask(foreground, self.connectivity) - img[i, ...][foreground != mask] = 0 - else: - applied_img = img[self.applied_labels, ...].type(torch.uint8) - foreground = torch.any(applied_img, dim=0) + is_onehot = img.shape[0] > 1 + if self.independent: + for i in self.applied_labels: + foreground = img[i] > 0 if is_onehot else img[0] == i mask = get_largest_connected_component_mask(foreground, self.connectivity) - background_mask = torch.unsqueeze(foreground != mask, dim=0) - background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0) - applied_img[background_mask] = 0 - img[self.applied_labels, ...] = applied_img.type(img.type()) - output = img - - return output + if is_onehot: + img[i][foreground != mask] = 0 + else: + img[0][foreground != mask] = 0 + return img + if not is_onehot: # not one-hot, union of labels + labels, *_ = convert_to_dst_type(self.applied_labels, dst=img, wrap_sequence=True) + foreground = (img[..., None] == labels).any(-1)[0] + mask = get_largest_connected_component_mask(foreground, self.connectivity) + img[0][foreground != mask] = 0 + return img + # one-hot, union of labels + foreground = img[self.applied_labels, ...].any(0) + mask = get_largest_connected_component_mask(foreground, self.connectivity) + for i in self.applied_labels: + img[i][foreground != mask] = 0 + return img class LabelFilter: diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 486483577f..19a7bc9359 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -192,6 +192,8 @@ class KeepLargestConnectedComponentd(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`. """ + backend = KeepLargestConnectedComponent.backend + def __init__( self, keys: KeysCollection, @@ -207,9 +209,11 @@ def __init__( applied_labels: Labels for applying the connected component on. If only one channel. The pixel whose value is not in this list will remain unchanged. If the data is in one-hot format, this is the channel indices to apply transform. - independent: consider several labels as a whole or independent, default is `True`. - Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case - you want this "independent" to be specified as False. + independent: whether to treat ``applied_labels`` as a union of foreground labels. + If ``True``, the connected component analysis will be performed on each foreground label independently + and return the intersection of the largest component. + If ``False``, the analysis will be performed on the union of foreground labels. + default is `True`. connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. @@ -219,7 +223,7 @@ def __init__( super().__init__(keys, allow_missing_keys) self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity) - def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) for key in self.key_iterator(d): d[key] = self.converter(d[key]) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1d3204b7a0..98341c9219 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -917,7 +917,7 @@ def generate_spatial_bounding_box( return box_start, box_end -def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor: +def get_largest_connected_component_mask(img: NdarrayOrTensor, connectivity: Optional[int] = None) -> NdarrayOrTensor: """ Gets the largest connected component mask of an image. @@ -927,13 +927,13 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Option Accepted values are ranging from 1 to input.ndim. If ``None``, a full connectivity of ``input.ndim`` is used. """ - img_arr = img.detach().cpu().numpy() - largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) + img_arr: np.ndarray = convert_data_type(img, np.ndarray)[0] # type: ignore + largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype) img_arr = measure.label(img_arr, connectivity=connectivity) if img_arr.max() != 0: largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1) - - return torch.as_tensor(largest_cc, device=img.device) + largest_cc = convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] # type: ignore + return largest_cc def fill_holes( diff --git a/tests/test_keep_largest_connected_component.py b/tests/test_keep_largest_connected_component.py index 50caa0bb31..5a7fc80e2e 100644 --- a/tests/test_keep_largest_connected_component.py +++ b/tests/test_keep_largest_connected_component.py @@ -15,331 +15,333 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponent -from tests.utils import assert_allclose, clone +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]) -grid_2 = torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]) -grid_3 = torch.tensor( +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], - ] -) -grid_4 = torch.tensor( + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - ] -) - - -TEST_CASE_1 = [ - "value_1", - {"independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], ] - -TEST_CASE_2 = [ - "value_2", - {"independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], ] -TEST_CASE_3 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), -] -TEST_CASE_4 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] +TESTS = [] +for p in TEST_NDARRAYS: + TESTS.append( + [ + "value_1", + {"independent": False, "applied_labels": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_5 = [ - "value_1", - {"independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "value_2", + {"independent": False, "applied_labels": [2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_6 = [ - "independent_value_1_2", - {"independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2]}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + TESTS.append( + [ + "value_1", + {"independent": True, "applied_labels": [1]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "independent_value_1_2", + {"independent": True, "applied_labels": [1, 2]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + TESTS.append( + [ + "dependent_value_1_2", + {"independent": False, "applied_labels": [1, 2]}, + p(grid_2), + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + "value_1_connect_1", + {"independent": False, "applied_labels": [1], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), ] - ), -] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + "independent_value_1_2_connect_1", + {"independent": True, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] - ), -] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + "dependent_value_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_1), + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] - ), -] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + "onehot_independent_batch_2_apply_label_1_connect_1", + {"independent": True, "applied_labels": [1], "connectivity": 1}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), ] - ), -] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + TESTS.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + "onehot_independent_batch_2_apply_label_1_connect_2", + {"independent": True, "applied_labels": [1], "connectivity": 2}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), ] - ), -] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + TESTS.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"independent": True, "applied_labels": [1, 2], "connectivity": 2}, + p(grid_3), + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"independent": False}, grid_1, TypeError] + TESTS.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"independent": False, "applied_labels": [1, 2], "connectivity": 2}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"independent": False}, grid_3, TypeError] + TESTS.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_1", + {"independent": False, "applied_labels": [1, 2], "connectivity": 1}, + p(grid_4), + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] +INVALID_CASES = [] +for p in TEST_NDARRAYS: + INVALID_CASES.append(["no_applied_labels_for_single_channel", {"independent": False}, p(grid_1), TypeError]) + INVALID_CASES.append(["no_applied_labels_for_multi_channel", {"independent": False}, p(grid_3), TypeError]) class TestKeepLargestConnectedComponent(unittest.TestCase): - @parameterized.expand(VALID_CASES) + @parameterized.expand(TESTS) def test_correct_results(self, _, args, input_image, expected): converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - result = converter(clone(input_image).cuda()) - - else: - result = converter(clone(input_image)) - assert_allclose(result, expected) + result = converter(input_image) + assert_allclose(result, expected, type_test=False) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_image, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponent(**args) - if isinstance(input_image, torch.Tensor) and torch.cuda.is_available(): - _ = converter(clone(input_image).cuda()) - else: - _ = converter(clone(input_image).clone()) + _ = converter(input_image) if __name__ == "__main__": diff --git a/tests/test_keep_largest_connected_componentd.py b/tests/test_keep_largest_connected_componentd.py index 96a8154b65..097787bd3f 100644 --- a/tests/test_keep_largest_connected_componentd.py +++ b/tests/test_keep_largest_connected_componentd.py @@ -15,332 +15,335 @@ from parameterized import parameterized from monai.transforms import KeepLargestConnectedComponentd +from tests.utils import TEST_NDARRAYS, assert_allclose -grid_1 = {"img": torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]])} -grid_2 = {"img": torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]])} -grid_3 = { - "img": torch.tensor( +grid_1 = [[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]] +grid_2 = [[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [1, 0, 1, 1, 2], [1, 0, 1, 2, 2], [0, 0, 0, 0, 1]]] +grid_3 = [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], +] +grid_4 = [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [1.0, 0.0, 1.0, 1.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], +] + +VALID_CASES = [] +for p in TEST_NDARRAYS: + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + "value_1", + {"keys": ["img"], "independent": False, "applied_labels": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), ] ) -} -grid_4 = { - "img": torch.tensor( + + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [1.0, 0.0, 1.0, 1.0, 0.0], - [1.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + "value_2", + {"keys": ["img"], "independent": False, "applied_labels": [2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), ] ) -} - -TEST_CASE_1 = [ - "value_1", - {"keys": ["img"], "independent": False, "applied_labels": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] - -TEST_CASE_2 = [ - "value_2", - {"keys": ["img"], "independent": False, "applied_labels": [2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), -] - -TEST_CASE_3 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), -] - -TEST_CASE_4 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), -] - -TEST_CASE_5 = [ - "value_1", - {"keys": ["img"], "independent": True, "applied_labels": [1]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] -TEST_CASE_6 = [ - "independent_value_1_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 1, 0], [2, 2, 0, 0, 0]]]), + ] + ) -TEST_CASE_7 = [ - "dependent_value_1_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, - grid_2, - torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 1, 0], [2, 2, 0, 0, 2]]]), + ] + ) -TEST_CASE_8 = [ - "value_1_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), -] + VALID_CASES.append( + [ + "value_1", + {"keys": ["img"], "independent": True, "applied_labels": [1]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_9 = [ - "independent_value_1_2_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "independent_value_1_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 0]]]), + ] + ) -TEST_CASE_10 = [ - "dependent_value_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_1, - torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), -] + VALID_CASES.append( + [ + "dependent_value_1_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2]}, + {"img": p(grid_2)}, + torch.tensor([[[0, 0, 0, 0, 1], [0, 0, 1, 1, 1], [0, 0, 1, 1, 2], [0, 0, 1, 2, 2], [0, 0, 0, 0, 1]]]), + ] + ) -TEST_CASE_11 = [ - "onehot_independent_batch_2_apply_label_1_connect_1", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + "value_1_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 2]]]), ] - ), -] + ) -TEST_CASE_12 = [ - "onehot_independent_batch_2_apply_label_1_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 1.0], - ], + "independent_value_1_2_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [0, 2, 1, 0, 0], [0, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] - ), -] + ) -TEST_CASE_13 = [ - "onehot_independent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2}, - grid_3, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 0.0, 1.0, 1.0], - [1.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - ], + "dependent_value_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_1)}, + torch.tensor([[[0, 0, 1, 0, 0], [0, 2, 1, 1, 1], [1, 2, 1, 0, 0], [1, 2, 0, 0, 0], [2, 2, 0, 0, 0]]]), ] - ), -] + ) -TEST_CASE_14 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_2", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + "onehot_independent_batch_2_apply_label_1_connect_1", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 1}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), ] - ), -] + ) -TEST_CASE_15 = [ - "onehot_dependent_batch_2_apply_label_1_2_connect_1", - {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, - grid_4, - torch.tensor( + VALID_CASES.append( [ - [ - [1.0, 1.0, 1.0, 1.0, 0.0], - [1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 0.0], - [1.0, 1.0, 1.0, 1.0, 0.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 1.0], - [0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - ], - [ - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 1.0], - [0.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0], - ], + "onehot_independent_batch_2_apply_label_1_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1], "connectivity": 2}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 1.0], + ], + ] + ), ] - ), -] + ) -VALID_CASES = [ - TEST_CASE_1, - TEST_CASE_2, - TEST_CASE_3, - TEST_CASE_4, - TEST_CASE_5, - TEST_CASE_6, - TEST_CASE_7, - TEST_CASE_8, - TEST_CASE_9, - TEST_CASE_10, - TEST_CASE_11, - TEST_CASE_12, - TEST_CASE_13, - TEST_CASE_14, - TEST_CASE_15, -] + VALID_CASES.append( + [ + "onehot_independent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": True, "applied_labels": [1, 2], "connectivity": 2}, + {"img": p(grid_3)}, + torch.tensor( + [ + [ + [1.0, 1.0, 0.0, 1.0, 1.0], + [1.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 1.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_1 = ["no_applied_labels_for_single_channel", {"keys": ["img"], "independent": False}, grid_1, TypeError] + VALID_CASES.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_2", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 2}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -ITEST_CASE_2 = ["no_applied_labels_for_multi_channel", {"keys": ["img"], "independent": False}, grid_3, TypeError] + VALID_CASES.append( + [ + "onehot_dependent_batch_2_apply_label_1_2_connect_1", + {"keys": ["img"], "independent": False, "applied_labels": [1, 2], "connectivity": 1}, + {"img": p(grid_4)}, + torch.tensor( + [ + [ + [1.0, 1.0, 1.0, 1.0, 0.0], + [1.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0, 0.0], + [1.0, 1.0, 1.0, 1.0, 0.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0, 0.0], + [0.0, 0.0, 1.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + ], + [ + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0, 1.0], + [0.0, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + ] + ), + ] + ) -INVALID_CASES = [ITEST_CASE_1, ITEST_CASE_2] +INVALID_CASES = [] +for p in TEST_NDARRAYS: + INVALID_CASES.append( + ["no_applied_labels_for_single_channel", {"keys": ["img"], "independent": False}, {"img": p(grid_1)}, TypeError] + ) + INVALID_CASES.append( + ["no_applied_labels_for_multi_channel", {"keys": ["img"], "independent": False}, {"img": p(grid_3)}, TypeError] + ) class TestKeepLargestConnectedComponentd(unittest.TestCase): @parameterized.expand(VALID_CASES) def test_correct_results(self, _, args, input_dict, expected): converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() - result = converter(input_dict) - torch.allclose(result["img"], expected.cuda()) - else: - result = converter(input_dict) - torch.allclose(result["img"], expected) + result = converter(input_dict) + assert_allclose(result["img"], expected, type_test=False) @parameterized.expand(INVALID_CASES) def test_raise_exception(self, _, args, input_dict, expected_error): with self.assertRaises(expected_error): converter = KeepLargestConnectedComponentd(**args) - if torch.cuda.is_available(): - input_dict["img"] = input_dict["img"].cuda() _ = converter(input_dict) From 26261bf6d1fe8d1f6f58a4b8f5175d0cf339d2da Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 2 Nov 2021 15:05:49 +0000 Subject: [PATCH 2/2] compatibility Signed-off-by: Wenqi Li --- monai/transforms/post/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 5251ce880b..9498087476 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -313,7 +313,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img[0][foreground != mask] = 0 return img # one-hot, union of labels - foreground = img[self.applied_labels, ...].any(0) + foreground = (img[self.applied_labels, ...] == 1).any(0) mask = get_largest_connected_component_mask(foreground, self.connectivity) for i in self.applied_labels: img[i][foreground != mask] = 0