Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ class MaskIntensity(Transform):

"""

backend = [TransformBackends.NUMPY]
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, mask_data: Optional[NdarrayOrTensor] = None, select_fn: Callable = is_positive) -> None:
self.mask_data = mask_data
Expand Down
66 changes: 37 additions & 29 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@
convert_to_numpy,
convert_to_tensor,
ensure_tuple,
get_equivalent_dtype,
look_up_option,
min_version,
optional_import,
)
from monai.utils.enums import TransformBackends
from monai.utils.misc import is_module_ver_at_least
from monai.utils.type_conversion import convert_to_dst_type
from monai.utils.type_conversion import convert_to_dst_type, get_equivalent_dtype

PILImageImage, has_pil = optional_import("PIL.Image", name="Image")
pil_image_fromarray, _ = optional_import("PIL.Image", name="fromarray")
Expand Down Expand Up @@ -389,11 +388,9 @@ def __call__(self, data: NdarrayOrTensor):
if applicable.

"""
if self.data_type == "tensor":
dtype_ = get_equivalent_dtype(self.dtype, torch.Tensor)
return convert_to_tensor(data, dtype=dtype_, device=self.device)
dtype_ = get_equivalent_dtype(self.dtype, np.ndarray)
return convert_to_numpy(data, dtype=dtype_)
output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray
out, *_ = convert_data_type(data, output_type=output_type, dtype=self.dtype, device=self.device)
return out


class ToNumpy(Transform):
Expand Down Expand Up @@ -880,18 +877,19 @@ class ConvertToMultiChannelBasedOnBratsClasses(Transform):
and ET (Enhancing tumor).
"""

def __call__(self, img: np.ndarray) -> np.ndarray:
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
# if img has channel dim, squeeze it
if img.ndim == 4 and img.shape[0] == 1:
img = np.squeeze(img, axis=0)
img = img.squeeze(0)

result = [np.logical_or(img == 1, img == 4)]
result = [(img == 1) | (img == 4)]
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2))
result.append((img == 1) | (img == 4) | (img == 2))
# label 4 is ET
result.append(img == 4)
return np.stack(result, axis=0)
return torch.stack(result, dim=0) if isinstance(img, torch.Tensor) else np.stack(result, axis=0)


class AddExtremePointsChannel(Randomizable, Transform):
Expand Down Expand Up @@ -966,6 +964,8 @@ class TorchVision:

"""

backend = [TransformBackends.TORCH]

def __init__(self, name: str, *args, **kwargs) -> None:
"""
Args:
Expand All @@ -978,14 +978,16 @@ def __init__(self, name: str, *args, **kwargs) -> None:
transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name)
self.trans = transform(*args, **kwargs)

def __call__(self, img: torch.Tensor):
def __call__(self, img: NdarrayOrTensor):
"""
Args:
img: PyTorch Tensor data for the TorchVision transform.

"""
img, *_ = convert_data_type(img, torch.Tensor) # type: ignore
return self.trans(img)
img_t, *_ = convert_data_type(img, torch.Tensor) # type: ignore
out = self.trans(img_t)
out, *_ = convert_to_dst_type(src=out, dst=img)
return out


class MapLabelValue:
Expand All @@ -997,6 +999,8 @@ class MapLabelValue:

"""

backend = [TransformBackends.NUMPY]

def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeLike = np.float32) -> None:
"""
Args:
Expand All @@ -1012,11 +1016,11 @@ def __init__(self, orig_labels: Sequence, target_labels: Sequence, dtype: DtypeL

self.orig_labels = orig_labels
self.target_labels = target_labels
self.dtype = dtype
self.dtype = get_equivalent_dtype(dtype, data_type=np.ndarray)

def __call__(self, img: np.ndarray):
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
img_flat = img.flatten()
def __call__(self, img: NdarrayOrTensor):
img_np, *_ = convert_data_type(img, np.ndarray)
img_flat = img_np.flatten()
try:
out_flat = np.copy(img_flat).astype(self.dtype)
except ValueError:
Expand All @@ -1028,7 +1032,9 @@ def __call__(self, img: np.ndarray):
continue
np.place(out_flat, img_flat == o, t)

return out_flat.reshape(img.shape)
out = out_flat.reshape(img_np.shape)
out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype)
return out


class IntensityStats(Transform):
Expand All @@ -1050,14 +1056,16 @@ class IntensityStats(Transform):

"""

backend = [TransformBackends.NUMPY]

def __init__(self, ops: Sequence[Union[str, Callable]], key_prefix: str, channel_wise: bool = False) -> None:
self.ops = ensure_tuple(ops)
self.key_prefix = key_prefix
self.channel_wise = channel_wise

def __call__(
self, img: np.ndarray, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, Dict]:
self, img: NdarrayOrTensor, meta_data: Optional[Dict] = None, mask: Optional[np.ndarray] = None
) -> Tuple[NdarrayOrTensor, Dict]:
"""
Compute statistics for the intensity of input image.

Expand All @@ -1068,15 +1076,15 @@ def __call__(
mask must have the same shape as input `img`.

"""
img, *_ = convert_data_type(img, np.ndarray) # type: ignore
img_np: np.ndarray
img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore
if meta_data is None:
meta_data = {}

img_: np.ndarray = img
if mask is not None:
if mask.shape != img.shape or mask.dtype != bool:
if mask.shape != img_np.shape or mask.dtype != bool:
raise TypeError("mask must be bool array with the same shape as input `img`.")
img_ = img[mask]
img_np = img_np[mask]

supported_ops = {
"mean": np.nanmean,
Expand All @@ -1095,9 +1103,9 @@ def _compute(op: Callable, data: np.ndarray):
for o in self.ops:
if isinstance(o, str):
o = look_up_option(o, supported_ops.keys())
meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_) # type: ignore
meta_data[self.key_prefix + "_" + o] = _compute(supported_ops[o], img_np) # type: ignore
elif callable(o):
meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_)
meta_data[self.key_prefix + "_custom_" + str(custom_index)] = _compute(o, img_np)
custom_index += 1
else:
raise ValueError("ops must be key string for predefined operations or callable function.")
Expand Down
22 changes: 16 additions & 6 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,11 +1192,13 @@ class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
and ET (Enhancing tumor).
"""

backend = ConvertToMultiChannelBasedOnBratsClasses.backend

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False):
super().__init__(keys, allow_missing_keys)
self.converter = ConvertToMultiChannelBasedOnBratsClasses()

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.converter(d[key])
Expand Down Expand Up @@ -1280,6 +1282,8 @@ class TorchVisiond(MapTransform):
data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.
"""

backend = TorchVision.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
Expand All @@ -1294,7 +1298,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
super().__init__(keys, allow_missing_keys)
self.trans = TorchVision(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.trans(d[key])
Expand All @@ -1317,6 +1321,8 @@ class RandTorchVisiond(Randomizable, MapTransform):

"""

backend = TorchVision.backend

def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
"""
Args:
Expand All @@ -1331,7 +1337,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
MapTransform.__init__(self, keys, allow_missing_keys)
self.trans = TorchVision(name, *args, **kwargs)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.trans(d[key])
Expand All @@ -1343,6 +1349,8 @@ class MapLabelValued(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`.
"""

backend = MapLabelValue.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1364,7 +1372,7 @@ def __init__(
super().__init__(keys, allow_missing_keys)
self.mapper = MapLabelValue(orig_labels=orig_labels, target_labels=target_labels, dtype=dtype)

def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.mapper(d[key])
Expand Down Expand Up @@ -1406,6 +1414,8 @@ class IntensityStatsd(MapTransform):

"""

backend = IntensityStats.backend

def __init__(
self,
keys: KeysCollection,
Expand All @@ -1425,7 +1435,7 @@ def __init__(
raise ValueError("meta_keys should have the same length as keys.")
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))

def __call__(self, data) -> Dict[Hashable, np.ndarray]:
def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key, mask_key, meta_key, meta_key_postfix in self.key_iterator(
d, self.mask_keys, self.meta_keys, self.meta_key_postfix
Expand All @@ -1442,7 +1452,7 @@ class ToDeviced(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`.
"""

backend = [TransformBackends.TORCH]
backend = ToDevice.backend

def __init__(
self, keys: KeysCollection, device: Union[torch.device, str], allow_missing_keys: bool = False, **kwargs
Expand Down
2 changes: 1 addition & 1 deletion monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def convert_data_type(

output_type = output_type or orig_type

dtype_ = get_equivalent_dtype(dtype or get_dtype(data), output_type)
dtype_ = get_equivalent_dtype(dtype, output_type)

if output_type is torch.Tensor:
data = convert_to_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence)
Expand Down
5 changes: 0 additions & 5 deletions tests/test_convert_data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ def test_convert_data_type(self, in_image, im_out):
def test_neg_stride(self):
_ = convert_data_type(np.array((1, 2))[::-1], torch.Tensor)

def test_ill_arg(self):
with self.assertRaises(ValueError):
convert_data_type(None, torch.Tensor)
convert_data_type(None, np.ndarray)

@parameterized.expand(TESTS_LIST)
def test_convert_list(self, in_image, im_out, wrap):
output_type = type(im_out) if wrap else type(im_out[0])
Expand Down
46 changes: 29 additions & 17 deletions tests/test_convert_to_multi_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,34 +11,46 @@

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms import ConvertToMultiChannelBasedOnBratsClasses
from tests.utils import TEST_NDARRAYS, assert_allclose

TEST_CASE_1 = [
np.array([[0, 1, 2], [1, 2, 4], [0, 1, 4]]),
np.array([[[0, 1, 0], [1, 0, 1], [0, 1, 1]], [[0, 1, 1], [1, 1, 1], [0, 1, 1]], [[0, 0, 0], [0, 0, 1], [0, 0, 1]]]),
]

TEST_CASE_2 = [
np.array([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]),
np.array(
TESTS = []
for p in TEST_NDARRAYS:
TESTS.extend(
[
[[[0, 1], [1, 0]], [[0, 1], [1, 1]]],
[[[0, 1], [1, 1]], [[1, 1], [1, 1]]],
[[[0, 0], [0, 0]], [[0, 1], [1, 1]]],
[
p([[0, 1, 2], [1, 2, 4], [0, 1, 4]]),
p(
[
[[0, 1, 0], [1, 0, 1], [0, 1, 1]],
[[0, 1, 1], [1, 1, 1], [0, 1, 1]],
[[0, 0, 0], [0, 0, 1], [0, 0, 1]],
]
),
],
[
p([[[[0, 1], [1, 2]], [[2, 4], [4, 4]]]]),
p(
[
[[[0, 1], [1, 0]], [[0, 1], [1, 1]]],
[[[0, 1], [1, 1]], [[1, 1], [1, 1]]],
[[[0, 0], [0, 0]], [[0, 1], [1, 1]]],
]
),
],
]
),
]
)


class TestConvertToMultiChannel(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
@parameterized.expand(TESTS)
def test_type_shape(self, data, expected_result):
result = ConvertToMultiChannelBasedOnBratsClasses()(data)
np.testing.assert_equal(result, expected_result)
self.assertEqual(f"{result.dtype}", "bool")
assert_allclose(result, expected_result)
self.assertTrue(result.dtype in (bool, torch.bool))


if __name__ == "__main__":
Expand Down
Loading