2231 adds backend for largest cc #3239

Merged: 3 commits, Nov 3, 2021
Changes from 2 commits
73 changes: 33 additions & 40 deletions monai/transforms/post/array.py
@@ -263,17 +263,21 @@ class KeepLargestConnectedComponent(Transform):

"""

backend = [TransformBackends.NUMPY]

def __init__(
self, applied_labels: Union[Sequence[int], int], independent: bool = True, connectivity: Optional[int] = None
) -> None:
"""
Args:
applied_labels: Labels for applying the connected component on.
If the data has only one channel, pixels whose values are not in this list will remain unchanged.
If the data is in one-hot format, this is used to determine which channels to apply the transform to.
independent: consider several labels as a whole or independent, default is `True`.
Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
you want this "independent" to be specified as False.
applied_labels: Labels for applying the connected component analysis on.
If the data has only one channel, pixels whose values are in this list will be analyzed.
If the data is in one-hot format, this is used to determine which channels to apply the analysis to.
independent: whether to treat ``applied_labels`` as a union of foreground labels.
If ``True``, the connected component analysis will be performed on each foreground label independently
and return the intersection of the largest component.
If ``False``, the analysis will be performed on the union of foreground labels.
default is `True`.
connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used.
@@ -283,48 +287,37 @@ def __init__(
self.independent = independent
self.connectivity = connectivity

def __call__(self, img: torch.Tensor) -> torch.Tensor:
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Args:
img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).

Returns:
A PyTorch Tensor with shape (C, spatial_dim1[, spatial_dim2, ...]).
An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
"""
if img.shape[0] == 1:
img = torch.squeeze(img, dim=0)

if self.independent:
for i in self.applied_labels:
foreground = (img == i).type(torch.uint8)
mask = get_largest_connected_component_mask(foreground, self.connectivity)
img[foreground != mask] = 0
else:
foreground = torch.zeros_like(img)
for i in self.applied_labels:
foreground += (img == i).type(torch.uint8)
mask = get_largest_connected_component_mask(foreground, self.connectivity)
img[foreground != mask] = 0

output = torch.unsqueeze(img, dim=0)
else:
# one-hot data is assumed to have binary value in each channel
if self.independent:
for i in self.applied_labels:
foreground = img[i, ...].type(torch.uint8)
mask = get_largest_connected_component_mask(foreground, self.connectivity)
img[i, ...][foreground != mask] = 0
else:
applied_img = img[self.applied_labels, ...].type(torch.uint8)
foreground = torch.any(applied_img, dim=0)
is_onehot = img.shape[0] > 1
if self.independent:
for i in self.applied_labels:
foreground = img[i] > 0 if is_onehot else img[0] == i
mask = get_largest_connected_component_mask(foreground, self.connectivity)
background_mask = torch.unsqueeze(foreground != mask, dim=0)
background_mask = torch.repeat_interleave(background_mask, len(self.applied_labels), dim=0)
applied_img[background_mask] = 0
img[self.applied_labels, ...] = applied_img.type(img.type())
output = img

return output
if is_onehot:
img[i][foreground != mask] = 0
else:
img[0][foreground != mask] = 0
return img
if not is_onehot: # not one-hot, union of labels
labels, *_ = convert_to_dst_type(self.applied_labels, dst=img, wrap_sequence=True)
foreground = (img[..., None] == labels).any(-1)[0]
mask = get_largest_connected_component_mask(foreground, self.connectivity)
img[0][foreground != mask] = 0
return img
# one-hot, union of labels
foreground = (img[self.applied_labels, ...] == 1).any(0)
mask = get_largest_connected_component_mask(foreground, self.connectivity)
for i in self.applied_labels:
img[i][foreground != mask] = 0
return img


class LabelFilter:
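The reworked ``__call__`` above branches on ``is_onehot`` (``img.shape[0] > 1``) and on ``independent``, and with the NUMPY backend it now accepts either a torch tensor or a NumPy array. A minimal usage sketch of the single-channel paths, assuming the MONAI version in this PR; the label values and array contents are made up for illustration:

```python
import torch
from monai.transforms import KeepLargestConnectedComponent

# Single-channel label map, shape (1, 4, 5): label 1 has components of size 4 and 1,
# label 2 has components of size 2 and 1.
img = torch.tensor([[[1, 1, 0, 2, 2],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 1],
                     [2, 0, 0, 0, 0]]])

# independent=True: keep the largest component of each applied label separately,
# so the stray pixel of label 1 and the stray pixel of label 2 are both removed.
keep_each = KeepLargestConnectedComponent(applied_labels=[1, 2], independent=True)
print(keep_each(img.clone())[0])

# independent=False: labels 1 and 2 are first merged into one foreground mask and
# only the single largest component of that union survives (label 2 disappears here).
keep_union = KeepLargestConnectedComponent(applied_labels=[1, 2], independent=False)
print(keep_union(img.clone())[0])

# The same transform also accepts NumPy input and returns a NumPy array.
print(type(keep_each(img.numpy().copy())))  # <class 'numpy.ndarray'>
```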
12 changes: 8 additions & 4 deletions monai/transforms/post/dictionary.py
@@ -192,6 +192,8 @@ class KeepLargestConnectedComponentd(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`.
"""

backend = KeepLargestConnectedComponent.backend

def __init__(
self,
keys: KeysCollection,
@@ -207,9 +209,11 @@ def __init__(
applied_labels: Labels for applying the connected component analysis on.
If the data has only one channel, pixels whose values are not in this list will remain unchanged.
If the data is in one-hot format, these are the channel indices to apply the transform to.
independent: consider several labels as a whole or independent, default is `True`.
Example use case would be segment label 1 is liver and label 2 is liver tumor, in that case
you want this "independent" to be specified as False.
independent: whether to treat ``applied_labels`` as a union of foreground labels.
If ``True``, the connected component analysis will be performed on each foreground label independently
and return the intersection of the largest component.
If ``False``, the analysis will be performed on the union of foreground labels.
default is `True`.
connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used.
@@ -219,7 +223,7 @@
super().__init__(keys, allow_missing_keys)
self.converter = KeepLargestConnectedComponent(applied_labels, independent, connectivity)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
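The dictionary wrapper just delegates to the array transform per key, so the same inputs work through a data dictionary. A small sketch; the key name ``pred`` and the array contents are illustrative only:

```python
import numpy as np
from monai.transforms import KeepLargestConnectedComponentd

data = {
    "pred": np.array([[[0, 1, 1, 0],
                       [0, 1, 1, 0],
                       [0, 0, 0, 0],
                       [1, 0, 0, 0]]]),  # (1, 4, 4) single-channel label map
}

keep = KeepLargestConnectedComponentd(keys="pred", applied_labels=[1])
out = keep(data)
# The 2x2 component of label 1 is kept; the isolated pixel at (3, 0) is zeroed.
print(out["pred"][0])
```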
10 changes: 5 additions & 5 deletions monai/transforms/utils.py
@@ -917,7 +917,7 @@ def generate_spatial_bounding_box(
return box_start, box_end


def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor:
def get_largest_connected_component_mask(img: NdarrayOrTensor, connectivity: Optional[int] = None) -> NdarrayOrTensor:
"""
Gets the largest connected component mask of an image.

@@ -927,13 +927,13 @@ def get_largest_connected_component_mask(img: torch.Tensor, connectivity: Optional[int] = None) -> torch.Tensor:
Accepted values are ranging from 1 to input.ndim. If ``None``, a full
connectivity of ``input.ndim`` is used.
"""
img_arr = img.detach().cpu().numpy()
largest_cc = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype)
img_arr: np.ndarray = convert_data_type(img, np.ndarray)[0] # type: ignore
largest_cc: np.ndarray = np.zeros(shape=img_arr.shape, dtype=img_arr.dtype)
img_arr = measure.label(img_arr, connectivity=connectivity)
if img_arr.max() != 0:
largest_cc[...] = img_arr == (np.argmax(np.bincount(img_arr.flat)[1:]) + 1)

return torch.as_tensor(largest_cc, device=img.device)
largest_cc = convert_to_dst_type(largest_cc, dst=img, dtype=largest_cc.dtype)[0] # type: ignore
return largest_cc


def fill_holes(
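With the change above, ``get_largest_connected_component_mask`` still does the labelling in NumPy via scikit-image's ``measure.label``, but it converts the result back to the input's container type instead of always returning a torch tensor. A minimal sketch of the round trip:

```python
import numpy as np
import torch
from monai.transforms.utils import get_largest_connected_component_mask

foreground = np.array([[1, 1, 0, 0],
                       [1, 0, 0, 0],
                       [0, 0, 0, 1]], dtype=np.uint8)

# NumPy in -> NumPy out: the 3-pixel component is kept, the isolated pixel is dropped.
mask_np = get_largest_connected_component_mask(foreground)
print(type(mask_np), mask_np)

# Torch in -> torch out, on the same device as the input.
mask_t = get_largest_connected_component_mask(torch.as_tensor(foreground))
print(type(mask_t))
```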