diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 550f5d67ad..a33fce785e 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -310,7 +310,7 @@ class LabelFilter: [7, 8, 9] [0, 0, 9] """ - def __init__(self, applied_labels: Union[Sequence[int], int]) -> None: + def __init__(self, applied_labels: Union[Iterable[int], int]) -> None: """ Initialize the LabelFilter class with the labels to filter on. @@ -358,8 +358,9 @@ class FillHoles(Transform): [ ] [ ] [ ] [ ] It is possible to define for which labels the hole filling should be applied. - The input image is assumed to be a PyTorch Tensor or numpy array - with shape [batch_size, 1, spatial_dim1[, spatial_dim2, ...]] and the values correspond to expected labels. + The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]]. + If C = 1, then the values correspond to expected labels. + If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing. Note: @@ -404,20 +405,16 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor: The value 0 is assumed as background label. Args: - img: Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]]. + img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. Raises: NotImplementedError: The provided image was not a Pytorch Tensor or numpy array. Returns: - Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]]. + Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. """ if isinstance(img, np.ndarray): - channel_axis = 1 - img_arr = np.squeeze(img, axis=channel_axis) - output = fill_holes(img_arr, self.applied_labels, self.connectivity) - output = np.expand_dims(output, axis=channel_axis) - return output + return fill_holes(img, self.applied_labels, self.connectivity) elif isinstance(img, torch.Tensor): img_arr = img.detach().cpu().numpy() img_arr = self(img_arr) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index bc067b897d..366e2d245e 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -745,43 +745,50 @@ def fill_holes( Holes on the edge are always considered to be open (not enclosed). Note: + The performance of this method heavily depends on the number of labels. It is a bit faster if the list of `applied_labels` is provided. Limiting the number of `applied_labels` results in a big decrease in processing time. + If the image is one-hot-encoded, then the `applied_labels` need to match the channel index. + Args: - img_arr: numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]]. + img_arr: numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels. - connectivity: connectivity (int, optional): Maximum number of orthogonal hops to + connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim. Defaults to a full connectivity of ``input.ndim``. Returns: - numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]]. + numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]]. """ - # Ignore batch dimension in structure (window for dilation steps) + channel_axis = 0 + num_channels = img_arr.shape[channel_axis] + is_one_hot = num_channels > 1 spatial_dims = img_arr.ndim - 1 - structure = np.zeros((3, *[3] * spatial_dims)) - structure[1, ...] = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims) + structure = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims) # Get labels if not provided. Exclude background label. - applied_labels = set(applied_labels or np.unique(img_arr)) + applied_labels = set(applied_labels or (range(num_channels) if is_one_hot else np.unique(img_arr))) background_label = 0 applied_labels.discard(background_label) for label in applied_labels: - tmp = np.zeros(img_arr.shape, dtype=bool) + tmp = np.zeros(img_arr.shape[1:], dtype=bool) ndimage.binary_dilation( tmp, structure=structure, iterations=-1, - mask=img_arr != label, + mask=np.logical_not(img_arr[label]) if is_one_hot else img_arr[0] != label, origin=0, border_value=1, output=tmp, ) - img_arr[np.logical_not(tmp)] = label + if is_one_hot: + img_arr[label] = np.logical_not(tmp) + else: + img_arr[0, np.logical_not(tmp)] = label return img_arr diff --git a/tests/test_fill_holes.py b/tests/test_fill_holes.py index 2c239d1e29..0813d53b35 100644 --- a/tests/test_fill_holes.py +++ b/tests/test_fill_holes.py @@ -42,22 +42,20 @@ [0, 1, 0], ] -grid_1 = torch.tensor([[grid_1_raw]]) +grid_1 = torch.tensor([grid_1_raw]) -grid_2 = torch.tensor([[grid_2_raw]]) +grid_2 = torch.tensor([grid_2_raw]) -grid_3 = torch.tensor([[grid_3_raw]]) +grid_3 = torch.tensor([grid_3_raw]) -grid_4 = torch.tensor([[grid_4_raw]]) +grid_4 = torch.tensor([grid_4_raw]) grid_5 = torch.tensor( [ [ - [ - [1, 1, 1], - [1, 0, 0], - [1, 1, 1], - ] + [1, 1, 1], + [1, 0, 0], + [1, 1, 1], ] ] ) @@ -65,11 +63,9 @@ grid_6 = torch.tensor( [ [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 0, 2], - [1, 1, 2, 2, 2], - ] + [1, 1, 2, 2, 2], + [1, 0, 2, 0, 2], + [1, 1, 2, 2, 2], ] ] ) @@ -77,11 +73,9 @@ grid_7 = torch.tensor( [ [ - [ - [1, 1, 2, 2, 2], - [1, 0, 2, 2, 2], - [1, 1, 2, 2, 2], - ] + [1, 1, 2, 2, 2], + [1, 0, 2, 2, 2], + [1, 1, 2, 2, 2], ] ] ) @@ -187,29 +181,29 @@ TEST_CASE_14 = [ "3D_enclosed_full_connectivity_default_applied_labels", {"connectivity": 3}, - torch.tensor([[[grid_3_raw, grid_1_raw, grid_3_raw]]]), - torch.tensor([[[grid_3_raw, grid_3_raw, grid_3_raw]]]), + torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]), ] TEST_CASE_15 = [ "3D_enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, - torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), - torch.tensor([[[grid_4_raw, grid_4_raw, grid_4_raw]]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]), ] TEST_CASE_16 = [ "3D_open_full_connectivity_default_applied_labels", {"connectivity": 3}, - torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), - torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), + torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]), ] TEST_CASE_17 = [ "3D_open_to_edge_connectivity_1_default_applied_labels", {"connectivity": 1}, - torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]), - torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]), + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), + torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]), ] TEST_CASE_18 = [ @@ -227,10 +221,24 @@ ] TEST_CASE_20 = [ - "batch_enclosed_connectivity_1_default_applied_labels", + "one-hot_enclosed_connectivity_1_default_applied_labels", {"connectivity": 1}, - torch.tensor([[grid_1_raw], [grid_2_raw]]), - torch.tensor([[grid_3_raw], [grid_4_raw]]), + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]), +] + +TEST_CASE_21 = [ + "one-hot_enclosed_connectivity_1_applied_labels_2", + {"connectivity": 1, "applied_labels": [2]}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]), +] + +TEST_CASE_22 = [ + "one-hot_full_connectivity_applied_labels_2", + {"connectivity": 2}, + torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]), + torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]), ] VALID_CASES = [ @@ -255,6 +263,8 @@ TEST_CASE_18, TEST_CASE_19, TEST_CASE_20, + TEST_CASE_21, + TEST_CASE_22, ] ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError]