Skip to content

Commit

Permalink
Add track_meta option for Lambda and derived transforms (#6385)
Browse files Browse the repository at this point in the history
Fixes #6379 

### Description
`track_meta` flag added to `Lambda` and derived transforms to allow type
conversion to `np.ndarray` and `torch.Tensor` based on user-defined
function

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Suraj Pai <bspai@bwh.harvard.edu>
  • Loading branch information
surajpaib authored Apr 18, 2023
1 parent e18097d commit 30aa410
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 6 deletions.
23 changes: 18 additions & 5 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,8 @@ class Lambda(InvertibleTransform):
Args:
func: Lambda/function to be applied.
inv_func: Lambda/function of inverse operation, default to `lambda x: x`.
track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
as opposed to MONAI's enhanced objects. By default, this is `True`.
Raises:
TypeError: When ``func`` is not an ``Optional[Callable]``.
Expand All @@ -838,11 +840,14 @@ class Lambda(InvertibleTransform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, func: Callable | None = None, inv_func: Callable = no_collation) -> None:
def __init__(
self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True
) -> None:
if func is not None and not callable(func):
raise TypeError(f"func must be None or callable but is {type(func).__name__}.")
self.func = func
self.inv_func = inv_func
self.track_meta = track_meta

def __call__(self, img: NdarrayOrTensor, func: Callable | None = None):
"""
Expand All @@ -860,7 +865,7 @@ def __call__(self, img: NdarrayOrTensor, func: Callable | None = None):
raise TypeError(f"func must be None or callable but is {type(fn).__name__}.")
out = fn(img)
# convert to MetaTensor if necessary
if isinstance(out, (np.ndarray, torch.Tensor)) and not isinstance(out, MetaTensor) and get_track_meta():
if isinstance(out, (np.ndarray, torch.Tensor)) and not isinstance(out, MetaTensor) and self.track_meta:
out = MetaTensor(out)
if isinstance(out, MetaTensor):
self.push_transform(out)
Expand All @@ -881,21 +886,29 @@ class RandLambda(Lambda, RandomizableTransform):
func: Lambda/function to be applied.
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
inv_func: Lambda/function of inverse operation, default to `lambda x: x`.
track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
as opposed to MONAI's enhanced objects. By default, this is `True`.
For more details, please check :py:class:`monai.transforms.Lambda`.
"""

backend = Lambda.backend

def __init__(self, func: Callable | None = None, prob: float = 1.0, inv_func: Callable = no_collation) -> None:
Lambda.__init__(self=self, func=func, inv_func=inv_func)
def __init__(
self,
func: Callable | None = None,
prob: float = 1.0,
inv_func: Callable = no_collation,
track_meta: bool = True,
) -> None:
Lambda.__init__(self=self, func=func, inv_func=inv_func, track_meta=track_meta)
RandomizableTransform.__init__(self=self, prob=prob)

def __call__(self, img: NdarrayOrTensor, func: Callable | None = None):
self.randomize(img)
out = deepcopy(super().__call__(img, func) if self._do_transform else img)
# convert to MetaTensor if necessary
if not isinstance(out, MetaTensor) and get_track_meta():
if not isinstance(out, MetaTensor) and self.track_meta:
out = MetaTensor(out)
if isinstance(out, MetaTensor):
lambda_info = self.pop_transform(out) if self._do_transform else {}
Expand Down
9 changes: 8 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,8 @@ class Lambdad(MapTransform, InvertibleTransform):
each element corresponds to a key in ``keys``.
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
as opposed to MONAI's enhanced objects. By default, this is `True`.
overwrite: whether to overwrite the original data in the input dictionary with lambda function output. it
can be bool or str, when setting to str, it will create a new key for the output and keep the value of
key intact. default to True. it also can be a sequence of bool or str, each element corresponds to a key
Expand All @@ -1106,14 +1108,15 @@ def __init__(
keys: KeysCollection,
func: Sequence[Callable] | Callable,
inv_func: Sequence[Callable] | Callable = no_collation,
track_meta: bool = True,
overwrite: Sequence[bool] | bool | Sequence[str] | str = True,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.func = ensure_tuple_rep(func, len(self.keys))
self.inv_func = ensure_tuple_rep(inv_func, len(self.keys))
self.overwrite = ensure_tuple_rep(overwrite, len(self.keys))
self._lambd = Lambda()
self._lambd = Lambda(track_meta=track_meta)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d = dict(data)
Expand Down Expand Up @@ -1146,6 +1149,8 @@ class RandLambdad(Lambdad, RandomizableTransform):
each element corresponds to a key in ``keys``.
inv_func: Lambda/function of inverse operation if want to invert transforms, default to `lambda x: x`.
It also can be a sequence of Callable, each element corresponds to a key in ``keys``.
track_meta: If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
as opposed to MONAI's enhanced objects. By default, this is `True`.
overwrite: whether to overwrite the original data in the input dictionary with lambda function output.
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
Expand All @@ -1165,6 +1170,7 @@ def __init__(
keys: KeysCollection,
func: Sequence[Callable] | Callable,
inv_func: Sequence[Callable] | Callable = no_collation,
track_meta: bool = True,
overwrite: Sequence[bool] | bool = True,
prob: float = 1.0,
allow_missing_keys: bool = False,
Expand All @@ -1174,6 +1180,7 @@ def __init__(
keys=keys,
func=func,
inv_func=inv_func,
track_meta=track_meta,
overwrite=overwrite,
allow_missing_keys=allow_missing_keys,
)
Expand Down
22 changes: 22 additions & 0 deletions tests/test_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

import unittest

from numpy import ndarray
from torch import Tensor

from monai.data.meta_tensor import MetaTensor
from monai.transforms.utility.array import Lambda
from monai.utils.type_conversion import convert_to_numpy, convert_to_tensor
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


Expand Down Expand Up @@ -44,6 +48,24 @@ def slice_func(x):
out = lambd.inverse(out)
self.assertEqual(len(out.applied_operations), 0)

def test_lambda_track_meta_false(self):
for p in TEST_NDARRAYS:
img = p(self.imt)

def to_numpy(x):
return convert_to_numpy(x)

lambd = Lambda(func=to_numpy, track_meta=False)
out = lambd(img)
self.assertIsInstance(out, ndarray)

def to_tensor(x):
return convert_to_tensor(x)

lambd = Lambda(func=to_tensor, track_meta=False)
out = lambd(img)
self.assertIsInstance(out, Tensor)


if __name__ == "__main__":
unittest.main()
25 changes: 25 additions & 0 deletions tests/test_lambdad.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@

import unittest

from numpy import ndarray
from torch import Tensor

from monai.data.meta_tensor import MetaTensor
from monai.transforms.utility.dictionary import Lambdad
from monai.utils.type_conversion import convert_to_numpy, convert_to_tensor
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose


Expand Down Expand Up @@ -53,6 +57,27 @@ def slice_func(x):
self.assertIsInstance(inv_img, MetaTensor)
self.assertEqual(len(inv_img.applied_operations), 0)

def test_lambdad_track_meta_false(self):
for p in TEST_NDARRAYS:
img = p(self.imt)
data = {"img": img}

def to_numpy(x):
return convert_to_numpy(x)

lambd = Lambdad(keys=data.keys(), func=to_numpy, track_meta=False)
out = lambd(data)
out_img = out["img"]
self.assertIsInstance(out_img, ndarray)

def to_tensor(x):
return convert_to_tensor(x)

lambd = Lambdad(keys=data.keys(), func=to_tensor, track_meta=False)
out = lambd(data)
out_img = out["img"]
self.assertIsInstance(out_img, Tensor)


if __name__ == "__main__":
unittest.main()

0 comments on commit 30aa410

Please sign in to comment.