Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

2681 Enhance MaskIntensity to select mask values #2726

Merged
merged 6 commits into from
Aug 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""

from collections.abc import Iterable
from typing import Any, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
from warnings import warn

import numpy as np
Expand All @@ -24,7 +24,7 @@
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.networks.layers import GaussianFilter, HilbertTransform, SavitzkyGolayFilter
from monai.transforms.transform import Fourier, RandomizableTransform, Transform
from monai.transforms.utils import rescale_array
from monai.transforms.utils import is_positive, rescale_array
from monai.utils import (
PT_BEFORE_1_7,
InvalidPyTorchVersionError,
Expand Down Expand Up @@ -789,19 +789,23 @@ class MaskIntensity(Transform):
"""
Mask the intensity values of input image with the specified mask data.
Mask data must have the same spatial size as the input image, and all
the intensity values of input image corresponding to `0` in the mask
data will be set to `0`, others will keep the original value.
the intensity values of input image corresponding to the selected values
in the mask data will keep the original value, others will be set to `0`.

Args:
mask_data: if `mask_data` is single channel, apply to every channel
of input image. if multiple channels, the number of channels must
match the input data. `mask_data` will be converted to `bool` values
by `mask_data > 0` before applying transform to input image.
match the input data. the intensity values of input image corresponding
to the selected values in the mask data will keep the original value,
others will be set to `0`.
select_fn: function to select valid values of the `mask_data`, default is
to select `values > 0`.

"""

def __init__(self, mask_data: Optional[np.ndarray]) -> None:
def __init__(self, mask_data: Optional[np.ndarray], select_fn: Callable = is_positive) -> None:
self.mask_data = mask_data
self.select_fn = select_fn

def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> np.ndarray:
"""
Expand All @@ -816,21 +820,18 @@ def __call__(self, img: np.ndarray, mask_data: Optional[np.ndarray] = None) -> n
- ValueError: When ``mask_data`` and ``img`` channels differ and ``mask_data`` is not single channel.

"""
if self.mask_data is None and mask_data is None:
raise ValueError("Unknown mask_data.")
mask_data_ = np.array([[1]])
if self.mask_data is not None and mask_data is None:
mask_data_ = self.mask_data > 0
if mask_data is not None:
mask_data_ = mask_data > 0
mask_data_ = np.asarray(mask_data_)
if mask_data_.shape[0] != 1 and mask_data_.shape[0] != img.shape[0]:
mask_data = self.mask_data if mask_data is None else mask_data
if mask_data is None:
raise ValueError("must provide the mask_data when initializing the transform or at runtime.")

mask_data = np.asarray(self.select_fn(mask_data))
if mask_data.shape[0] != 1 and mask_data.shape[0] != img.shape[0]:
raise ValueError(
"When mask_data is not single channel, mask_data channels must match img, "
f"got img={img.shape[0]} mask_data={mask_data_.shape[0]}."
f"got img channels={img.shape[0]} mask_data channels={mask_data.shape[0]}."
)

return np.asarray(img * mask_data_)
return np.asarray(img * mask_data)


class SavitzkyGolaySmooth(Transform):
Expand Down
15 changes: 10 additions & 5 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
"""

from collections.abc import Iterable
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand All @@ -42,6 +42,7 @@
ThresholdIntensity,
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import dtype_torch_to_numpy, ensure_tuple, ensure_tuple_rep, ensure_tuple_size, fall_back_tuple

__all__ = [
Expand Down Expand Up @@ -808,11 +809,14 @@ class MaskIntensityd(MapTransform):
See also: :py:class:`monai.transforms.compose.MapTransform`
mask_data: if mask data is single channel, apply to every channel
of input image. if multiple channels, the channel number must
match input data. mask_data will be converted to `bool` values
by `mask_data > 0` before applying transform to input image.
if None, will extract the mask data from input data based on `mask_key`.
match input data. the intensity values of input image corresponding
to the selected values in the mask data will keep the original value,
others will be set to `0`. if None, will extract the mask data from
input data based on `mask_key`.
mask_key: the key to extract mask data from input dictionary, only works
when `mask_data` is None.
select_fn: function to select valid values of the `mask_data`, default is
to select `values > 0`.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand All @@ -822,10 +826,11 @@ def __init__(
keys: KeysCollection,
mask_data: Optional[np.ndarray] = None,
mask_key: Optional[str] = None,
select_fn: Callable = is_positive,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.converter = MaskIntensity(mask_data)
self.converter = MaskIntensity(mask_data=mask_data, select_fn=select_fn)
self.mask_key = mask_key if mask_data is None else None

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
Expand Down
11 changes: 10 additions & 1 deletion tests/test_mask_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,18 @@
np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),
]

TEST_CASE_4 = [
{
"mask_data": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),
"select_fn": lambda x: np.where((x > 3) & (x < 7), True, False),
},
np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]]),
np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]),
]


class TestMaskIntensity(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
def test_value(self, argments, image, expected_data):
result = MaskIntensity(**argments)(image)
np.testing.assert_allclose(result, expected_data)
Expand Down
12 changes: 11 additions & 1 deletion tests/test_mask_intensityd.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,19 @@
np.array([[[0, 0, 0], [0, 2, 0], [0, 0, 0]], [[0, 4, 0], [0, 5, 0], [0, 6, 0]]]),
]

TEST_CASE_5 = [
{
"keys": "img",
"mask_data": np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),
"select_fn": lambda x: np.where((x > 3) & (x < 7), True, False),
},
{"img": np.array([[[1, 1, 1], [2, 2, 2], [3, 3, 3]], [[4, 4, 4], [5, 5, 5], [6, 6, 6]]])},
np.array([[[0, 0, 0], [2, 2, 2], [0, 0, 0]], [[0, 0, 0], [5, 5, 5], [0, 0, 0]]]),
]


class TestMaskIntensityd(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_value(self, argments, image, expected_data):
result = MaskIntensityd(**argments)(image)
np.testing.assert_allclose(result["img"], expected_data)
Expand Down