From 8866706e35da6e0a251f7bc5be013f306ec1f462 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Tue, 18 Apr 2023 05:15:29 -0400 Subject: [PATCH] Add track_meta option for Lambda and derived transforms Signed-off-by: Suraj Pai --- monai/transforms/utility/array.py | 23 ++++++++++++++++++----- monai/transforms/utility/dictionary.py | 9 ++++++++- tests/test_lambda.py | 22 ++++++++++++++++++++++ tests/test_lambdad.py | 25 +++++++++++++++++++++++++ 4 files changed, 73 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index e982f3ced1..c6ac8abc48 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -830,6 +830,8 @@ class Lambda(InvertibleTransform): Args: func: Lambda/function to be applied. inv_func: Lambda/function of inverse operation, default to `lambda x: x`. + track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`) + as opposed to MONAI's enhanced objects. By default, this is `True`. Raises: TypeError: When ``func`` is not an ``Optional[Callable]``. @@ -838,11 +840,14 @@ class Lambda(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, func: Callable | None = None, inv_func: Callable = no_collation) -> None: + def __init__( + self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True + ) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func self.inv_func = inv_func + self.track_meta = track_meta def __call__(self, img: NdarrayOrTensor, func: Callable | None = None): """ @@ -860,7 +865,7 @@ def __call__(self, img: NdarrayOrTensor, func: Callable | None = None): raise TypeError(f"func must be None or callable but is {type(fn).__name__}.") out = fn(img) # convert to MetaTensor if necessary - if isinstance(out, (np.ndarray, torch.Tensor)) and not isinstance(out, MetaTensor) and get_track_meta(): + if isinstance(out, (np.ndarray, torch.Tensor)) and not isinstance(out, MetaTensor) and self.track_meta: out = MetaTensor(out) if isinstance(out, MetaTensor): self.push_transform(out) @@ -881,21 +886,29 @@ class RandLambda(Lambda, RandomizableTransform): func: Lambda/function to be applied. prob: probability of executing the random function, default to 1.0, with 100% probability to execute. inv_func: Lambda/function of inverse operation, default to `lambda x: x`. + track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`) + as opposed to MONAI's enhanced objects. By default, this is `True`. For more details, please check :py:class:`monai.transforms.Lambda`. """ backend = Lambda.backend - def __init__(self, func: Callable | None = None, prob: float = 1.0, inv_func: Callable = no_collation) -> None: - Lambda.__init__(self=self, func=func, inv_func=inv_func) + def __init__( + self, + func: Callable | None = None, + prob: float = 1.0, + inv_func: Callable = no_collation, + track_meta: bool = True, + ) -> None: + Lambda.__init__(self=self, func=func, inv_func=inv_func, track_meta=track_meta) RandomizableTransform.__init__(self=self, prob=prob) def __call__(self, img: NdarrayOrTensor, func: Callable | None = None): self.randomize(img) out = deepcopy(super().__call__(img, func) if self._do_transform else img) # convert to MetaTensor if necessary - if not isinstance(out, MetaTensor) and get_track_meta(): + if not isinstance(out, MetaTensor) and self.track_meta: out = MetaTensor(out) if isinstance(out, MetaTensor): lambda_info = self.pop_transform(out) if self._do_transform else {} diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 40b1527443..e64268cac6 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1088,6 +1088,8 @@ class Lambdad(MapTransform, InvertibleTransform): each element corresponds to a key in ``keys``. inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`) + as opposed to MONAI's enhanced objects. By default, this is `True`. overwrite: whether to overwrite the original data in the input dictionary with lambda function output. it can be bool or str, when setting to str, it will create a new key for the output and keep the value of key intact. default to True. it also can be a sequence of bool or str, each element corresponds to a key @@ -1106,6 +1108,7 @@ def __init__( keys: KeysCollection, func: Sequence[Callable] | Callable, inv_func: Sequence[Callable] | Callable = no_collation, + track_meta: bool = True, overwrite: Sequence[bool] | bool | Sequence[str] | str = True, allow_missing_keys: bool = False, ) -> None: @@ -1113,7 +1116,7 @@ def __init__( self.func = ensure_tuple_rep(func, len(self.keys)) self.inv_func = ensure_tuple_rep(inv_func, len(self.keys)) self.overwrite = ensure_tuple_rep(overwrite, len(self.keys)) - self._lambd = Lambda() + self._lambd = Lambda(track_meta=track_meta) def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]: d = dict(data) @@ -1146,6 +1149,8 @@ class RandLambdad(Lambdad, RandomizableTransform): each element corresponds to a key in ``keys``. inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`. It also can be a sequence of Callable, each element corresponds to a key in ``keys``. + track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`) + as opposed to MONAI's enhanced objects. By default, this is `True`. overwrite: whether to overwrite the original data in the input dictionary with lambda function output. default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``. prob: probability of executing the random function, default to 1.0, with 100% probability to execute. @@ -1165,6 +1170,7 @@ def __init__( keys: KeysCollection, func: Sequence[Callable] | Callable, inv_func: Sequence[Callable] | Callable = no_collation, + track_meta: bool = True, overwrite: Sequence[bool] | bool = True, prob: float = 1.0, allow_missing_keys: bool = False, @@ -1174,6 +1180,7 @@ def __init__( keys=keys, func=func, inv_func=inv_func, + track_meta=track_meta, overwrite=overwrite, allow_missing_keys=allow_missing_keys, ) diff --git a/tests/test_lambda.py b/tests/test_lambda.py index 91678c0b81..e2276d671c 100644 --- a/tests/test_lambda.py +++ b/tests/test_lambda.py @@ -13,8 +13,12 @@ import unittest +from numpy import ndarray +from torch import Tensor + from monai.data.meta_tensor import MetaTensor from monai.transforms.utility.array import Lambda +from monai.utils.type_conversion import convert_to_numpy, convert_to_tensor from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -44,6 +48,24 @@ def slice_func(x): out = lambd.inverse(out) self.assertEqual(len(out.applied_operations), 0) + def test_lambda_track_meta_false(self): + for p in TEST_NDARRAYS: + img = p(self.imt) + + def to_numpy(x): + return convert_to_numpy(x) + + lambd = Lambda(func=to_numpy, track_meta=False) + out = lambd(img) + self.assertIsInstance(out, ndarray) + + def to_tensor(x): + return convert_to_tensor(x) + + lambd = Lambda(func=to_tensor, track_meta=False) + out = lambd(img) + self.assertIsInstance(out, Tensor) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_lambdad.py b/tests/test_lambdad.py index 55df819fa9..02e4423b74 100644 --- a/tests/test_lambdad.py +++ b/tests/test_lambdad.py @@ -13,8 +13,12 @@ import unittest +from numpy import ndarray +from torch import Tensor + from monai.data.meta_tensor import MetaTensor from monai.transforms.utility.dictionary import Lambdad +from monai.utils.type_conversion import convert_to_numpy, convert_to_tensor from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose @@ -53,6 +57,27 @@ def slice_func(x): self.assertIsInstance(inv_img, MetaTensor) self.assertEqual(len(inv_img.applied_operations), 0) + def test_lambdad_track_meta_false(self): + for p in TEST_NDARRAYS: + img = p(self.imt) + data = {"img": img} + + def to_numpy(x): + return convert_to_numpy(x) + + lambd = Lambdad(keys=data.keys(), func=to_numpy, track_meta=False) + out = lambd(data) + out_img = out["img"] + self.assertIsInstance(out_img, ndarray) + + def to_tensor(x): + return convert_to_tensor(x) + + lambd = Lambdad(keys=data.keys(), func=to_tensor, track_meta=False) + out = lambd(data) + out_img = out["img"] + self.assertIsInstance(out_img, Tensor) + if __name__ == "__main__": unittest.main()