Skip to content

Commit

Permalink
Remove batch dim and add one-hot handling (Project-MONAI#2678)
Browse files Browse the repository at this point in the history
Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>
  • Loading branch information
Sebastian Penhouet committed Aug 6, 2021
1 parent a4a6bde commit 8454770
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 50 deletions.
17 changes: 7 additions & 10 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class LabelFilter:
[7, 8, 9] [0, 0, 9]
"""

def __init__(self, applied_labels: Union[Sequence[int], int]) -> None:
def __init__(self, applied_labels: Union[Iterable[int], int]) -> None:
"""
Initialize the LabelFilter class with the labels to filter on.
Expand Down Expand Up @@ -358,8 +358,9 @@ class FillHoles(Transform):
[ ] [ ] [ ] [ ]
It is possible to define for which labels the hole filling should be applied.
The input image is assumed to be a PyTorch Tensor or numpy array
with shape [batch_size, 1, spatial_dim1[, spatial_dim2, ...]] and the values correspond to expected labels.
The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]].
If C = 1, then the values correspond to expected labels.
If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing.
Note:
Expand Down Expand Up @@ -404,20 +405,16 @@ def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
The value 0 is assumed as background label.
Args:
img: Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]].
img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
Raises:
NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.
Returns:
Pytorch Tensor or numpy array of shape [batch_size, num_channel, spatial_dim1[, spatial_dim2, ...]].
Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
"""
if isinstance(img, np.ndarray):
channel_axis = 1
img_arr = np.squeeze(img, axis=channel_axis)
output = fill_holes(img_arr, self.applied_labels, self.connectivity)
output = np.expand_dims(output, axis=channel_axis)
return output
return fill_holes(img, self.applied_labels, self.connectivity)
elif isinstance(img, torch.Tensor):
img_arr = img.detach().cpu().numpy()
img_arr = self(img_arr)
Expand Down
27 changes: 17 additions & 10 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,43 +745,50 @@ def fill_holes(
Holes on the edge are always considered to be open (not enclosed).
Note:
The performance of this method heavily depends on the number of labels.
It is a bit faster if the list of `applied_labels` is provided.
Limiting the number of `applied_labels` results in a big decrease in processing time.
If the image is one-hot-encoded, then the `applied_labels` need to match the channel index.
Args:
img_arr: numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]].
img_arr: numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
applied_labels: Labels for which to fill holes. Defaults to None,
that is filling holes for all labels.
connectivity: connectivity (int, optional): Maximum number of orthogonal hops to
connectivity: Maximum number of orthogonal hops to
consider a pixel/voxel as a neighbor. Accepted values are ranging from 1 to input.ndim.
Defaults to a full connectivity of ``input.ndim``.
Returns:
numpy array of shape [batch_size, spatial_dim1[, spatial_dim2, ...]].
numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
"""
# Ignore batch dimension in structure (window for dilation steps)
channel_axis = 0
num_channels = img_arr.shape[channel_axis]
is_one_hot = num_channels > 1
spatial_dims = img_arr.ndim - 1
structure = np.zeros((3, *[3] * spatial_dims))
structure[1, ...] = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims)
structure = ndimage.generate_binary_structure(spatial_dims, connectivity or spatial_dims)

# Get labels if not provided. Exclude background label.
applied_labels = set(applied_labels or np.unique(img_arr))
applied_labels = set(applied_labels or (range(num_channels) if is_one_hot else np.unique(img_arr)))
background_label = 0
applied_labels.discard(background_label)

for label in applied_labels:
tmp = np.zeros(img_arr.shape, dtype=bool)
tmp = np.zeros(img_arr.shape[1:], dtype=bool)
ndimage.binary_dilation(
tmp,
structure=structure,
iterations=-1,
mask=img_arr != label,
mask=np.logical_not(img_arr[label]) if is_one_hot else img_arr[0] != label,
origin=0,
border_value=1,
output=tmp,
)
img_arr[np.logical_not(tmp)] = label
if is_one_hot:
img_arr[label] = np.logical_not(tmp)
else:
img_arr[0, np.logical_not(tmp)] = label

return img_arr

Expand Down
70 changes: 40 additions & 30 deletions tests/test_fill_holes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,46 +42,40 @@
[0, 1, 0],
]

grid_1 = torch.tensor([[grid_1_raw]])
grid_1 = torch.tensor([grid_1_raw])

grid_2 = torch.tensor([[grid_2_raw]])
grid_2 = torch.tensor([grid_2_raw])

grid_3 = torch.tensor([[grid_3_raw]])
grid_3 = torch.tensor([grid_3_raw])

grid_4 = torch.tensor([[grid_4_raw]])
grid_4 = torch.tensor([grid_4_raw])

grid_5 = torch.tensor(
[
[
[
[1, 1, 1],
[1, 0, 0],
[1, 1, 1],
]
[1, 1, 1],
[1, 0, 0],
[1, 1, 1],
]
]
)

grid_6 = torch.tensor(
[
[
[
[1, 1, 2, 2, 2],
[1, 0, 2, 0, 2],
[1, 1, 2, 2, 2],
]
[1, 1, 2, 2, 2],
[1, 0, 2, 0, 2],
[1, 1, 2, 2, 2],
]
]
)

grid_7 = torch.tensor(
[
[
[
[1, 1, 2, 2, 2],
[1, 0, 2, 2, 2],
[1, 1, 2, 2, 2],
]
[1, 1, 2, 2, 2],
[1, 0, 2, 2, 2],
[1, 1, 2, 2, 2],
]
]
)
Expand Down Expand Up @@ -187,29 +181,29 @@
TEST_CASE_14 = [
"3D_enclosed_full_connectivity_default_applied_labels",
{"connectivity": 3},
torch.tensor([[[grid_3_raw, grid_1_raw, grid_3_raw]]]),
torch.tensor([[[grid_3_raw, grid_3_raw, grid_3_raw]]]),
torch.tensor([[grid_3_raw, grid_1_raw, grid_3_raw]]),
torch.tensor([[grid_3_raw, grid_3_raw, grid_3_raw]]),
]

TEST_CASE_15 = [
"3D_enclosed_connectivity_1_default_applied_labels",
{"connectivity": 1},
torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]),
torch.tensor([[[grid_4_raw, grid_4_raw, grid_4_raw]]]),
torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),
torch.tensor([[grid_4_raw, grid_4_raw, grid_4_raw]]),
]

TEST_CASE_16 = [
"3D_open_full_connectivity_default_applied_labels",
{"connectivity": 3},
torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]),
torch.tensor([[[grid_4_raw, grid_2_raw, grid_4_raw]]]),
torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),
torch.tensor([[grid_4_raw, grid_2_raw, grid_4_raw]]),
]

TEST_CASE_17 = [
"3D_open_to_edge_connectivity_1_default_applied_labels",
{"connectivity": 1},
torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]),
torch.tensor([[[grid_1_raw, grid_1_raw, grid_3_raw]]]),
torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),
torch.tensor([[grid_1_raw, grid_1_raw, grid_3_raw]]),
]

TEST_CASE_18 = [
Expand All @@ -227,10 +221,24 @@
]

TEST_CASE_20 = [
"batch_enclosed_connectivity_1_default_applied_labels",
"one-hot_enclosed_connectivity_1_default_applied_labels",
{"connectivity": 1},
torch.tensor([[grid_1_raw], [grid_2_raw]]),
torch.tensor([[grid_3_raw], [grid_4_raw]]),
torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),
torch.tensor([grid_1_raw, grid_3_raw, grid_4_raw]),
]

TEST_CASE_21 = [
"one-hot_enclosed_connectivity_1_applied_labels_2",
{"connectivity": 1, "applied_labels": [2]},
torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),
torch.tensor([grid_1_raw, grid_1_raw, grid_4_raw]),
]

TEST_CASE_22 = [
"one-hot_full_connectivity_applied_labels_2",
{"connectivity": 2},
torch.tensor([grid_1_raw, grid_1_raw, grid_2_raw]),
torch.tensor([grid_1_raw, grid_3_raw, grid_2_raw]),
]

VALID_CASES = [
Expand All @@ -255,6 +263,8 @@
TEST_CASE_18,
TEST_CASE_19,
TEST_CASE_20,
TEST_CASE_21,
TEST_CASE_22,
]

ITEST_CASE_1 = ["invalid_image_data_type", {}, [[[[1, 1, 1]]]], NotImplementedError]
Expand Down

0 comments on commit 8454770

Please sign in to comment.