Skip to content

Commit

Permalink
[2678] Add transform to fill holes and to filter (#2692)
Browse files Browse the repository at this point in the history
* Add transform to fill holes and to filter (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* Change name of label filter class (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* Change fill holes to growing logic (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* Fix missing entry in min_tests (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* Fix review comments (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* Remove batch dim and add one-hot handling (#2678)

Signed-off-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

Co-authored-by: Sebastian Penhouet <sebastian.penhouet@airamed.de>
  • Loading branch information
Spenhouet and Sebastian Penhouet authored Aug 7, 2021
1 parent 62425d7 commit 945e21c
Show file tree
Hide file tree
Showing 10 changed files with 864 additions and 98 deletions.
24 changes: 24 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,18 @@ Post-processing
:members:
:special-members: __call__

`LabelFilter`
"""""""""""""
.. autoclass:: LabelFilter
:members:
:special-members: __call__

`FillHoles`
"""""""""""
.. autoclass:: FillHoles
:members:
:special-members: __call__

`LabelToContour`
""""""""""""""""
.. autoclass:: LabelToContour
Expand Down Expand Up @@ -955,6 +967,18 @@ Post-processing (Dict)
:members:
:special-members: __call__

`LabelFilterd`
""""""""""""""
.. autoclass:: LabelFilterd
:members:
:special-members: __call__

`FillHolesd`
""""""""""""
.. autoclass:: FillHolesd
:members:
:special-members: __call__

`LabelToContourd`
"""""""""""""""""
.. autoclass:: LabelToContourd
Expand Down
26 changes: 17 additions & 9 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,40 +194,48 @@
from .post.array import (
Activations,
AsDiscrete,
FillHoles,
KeepLargestConnectedComponent,
LabelFilter,
LabelToContour,
MeanEnsemble,
ProbNMS,
VoteEnsemble,
)
from .post.dictionary import (
Activationsd,
ActivationsD,
Activationsd,
ActivationsDict,
AsDiscreted,
AsDiscreteD,
AsDiscreted,
AsDiscreteDict,
Ensembled,
Invertd,
FillHolesD,
FillHolesd,
FillHolesDict,
InvertD,
Invertd,
InvertDict,
KeepLargestConnectedComponentd,
KeepLargestConnectedComponentD,
KeepLargestConnectedComponentd,
KeepLargestConnectedComponentDict,
LabelToContourd,
LabelFilterD,
LabelFilterd,
LabelFilterDict,
LabelToContourD,
LabelToContourd,
LabelToContourDict,
MeanEnsembled,
MeanEnsembleD,
MeanEnsembled,
MeanEnsembleDict,
ProbNMSd,
ProbNMSD,
ProbNMSd,
ProbNMSDict,
SaveClassificationd,
SaveClassificationD,
SaveClassificationd,
SaveClassificationDict,
VoteEnsembled,
VoteEnsembleD,
VoteEnsembled,
VoteEnsembleDict,
)
from .spatial.array import (
Expand Down
140 changes: 137 additions & 3 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,29 @@
"""

import warnings
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Iterable, Optional, Sequence, Union

import numpy as np
import torch
import torch.nn.functional as F

from monai.config import NdarrayTensor
from monai.networks import one_hot
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import get_largest_connected_component_mask
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.utils import ensure_tuple

__all__ = [
"Activations",
"AsDiscrete",
"FillHoles",
"KeepLargestConnectedComponent",
"LabelFilter",
"LabelToContour",
"MeanEnsemble",
"VoteEnsemble",
"ProbNMS",
"VoteEnsemble",
]


Expand Down Expand Up @@ -289,6 +292,137 @@ def __call__(self, img: torch.Tensor) -> torch.Tensor:
return output


class LabelFilter:
"""
This transform filters out labels and can be used as a processing step to view only certain labels.
The list of applied labels defines which labels will be kept.
Note:
All labels which do not match the `applied_labels` are set to the background label (0).
For example:
Use LabelFilter with applied_labels=[1, 5, 9]::
[1, 2, 3] [1, 0, 0]
[4, 5, 6] => [0, 5 ,0]
[7, 8, 9] [0, 0, 9]
"""

def __init__(self, applied_labels: Union[Iterable[int], int]) -> None:
"""
Initialize the LabelFilter class with the labels to filter on.
Args:
applied_labels: Label(s) to filter on.
"""
self.applied_labels = ensure_tuple(applied_labels)

def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
"""
Filter the image on the `applied_labels`.
Args:
img: Pytorch tensor or numpy array of any shape.
Raises:
NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.
Returns:
Pytorch tensor or numpy array of the same shape as the input.
"""
if isinstance(img, np.ndarray):
return np.asarray(np.where(np.isin(img, self.applied_labels), img, 0))
elif isinstance(img, torch.Tensor):
img_arr = img.detach().cpu().numpy()
img_arr = self(img_arr)
return torch.as_tensor(img_arr, device=img.device)
else:
raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")


class FillHoles(Transform):
r"""
This transform fills holes in the image and can be used to remove artifacts inside segments.
An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class.
The definition of enclosed can be defined with the connectivity parameter::
1-connectivity 2-connectivity diagonal connection close-up
[ ] [ ] [ ] [ ] [ ]
| \ | / | <- hop 2
[ ]--[x]--[ ] [ ]--[x]--[ ] [x]--[ ]
| / | \ hop 1
[ ] [ ] [ ] [ ]
It is possible to define for which labels the hole filling should be applied.
The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]].
If C = 1, then the values correspond to expected labels.
If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing.
Note:
The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.
The performance of this method heavily depends on the number of labels.
It is a bit faster if the list of `applied_labels` is provided.
Limiting the number of `applied_labels` results in a big decrease in processing time.
For example:
Use FillHoles with default parameters::
[1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3]
[1, 0, 1, 2, 0, 0, 3, 0] => [1, 1 ,1, 2, 0, 0, 3, 0]
[1, 1, 1, 2, 2, 2, 3, 3] [1, 1, 1, 2, 2, 2, 3, 3]
The hole in label 1 is fully enclosed and therefore filled with label 1.
The background label near label 2 and 3 is not fully enclosed and therefore not filled.
"""

def __init__(
self, applied_labels: Optional[Union[Iterable[int], int]] = None, connectivity: Optional[int] = None
) -> None:
"""
Initialize the connectivity and limit the labels for which holes are filled.
Args:
applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels.
connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
Accepted values are ranging from 1 to input.ndim. Defaults to a full connectivity of ``input.ndim``.
"""
super().__init__()
self.applied_labels = ensure_tuple(applied_labels) if applied_labels else None
self.connectivity = connectivity

def __call__(self, img: NdarrayTensor) -> NdarrayTensor:
"""
Fill the holes in the provided image.
Note:
The value 0 is assumed as background label.
Args:
img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
Raises:
NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.
Returns:
Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
"""
if isinstance(img, np.ndarray):
return fill_holes(img, self.applied_labels, self.connectivity)
elif isinstance(img, torch.Tensor):
img_arr = img.detach().cpu().numpy()
img_arr = self(img_arr)
return torch.as_tensor(img_arr, device=img.device)
else:
raise NotImplementedError(f"{self.__class__} can not handle data of type {type(img)}.")


class LabelToContour(Transform):
"""
Return the contour of binary input images that only compose of 0 and 1, with Laplace kernel
Expand Down
Loading

0 comments on commit 945e21c

Please sign in to comment.