Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1542 Add RandLambdad transform #1546

Merged
merged 11 commits into from
Feb 4, 2021
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,12 @@ Utility (Dict)
:members:
:special-members: __call__

`RandLambdad`
"""""""""""""
.. autoclass:: RandLambdad
:members:
:special-members: __call__

`LabelToMaskd`
""""""""""""""
.. autoclass:: LabelToMaskd
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@
Lambdad,
LambdaD,
LambdaDict,
RandLambdad,
RandLambdaD,
RandLambdaDict,
RepeatChanneld,
RepeatChannelD,
RepeatChannelDict,
Expand Down
29 changes: 27 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import copy
import logging
from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -64,10 +64,12 @@
"CopyItemsd",
"ConcatItemsd",
"Lambdad",
"RandLambdad",
"LabelToMaskd",
"FgBgToIndicesd",
"ConvertToMultiChannelBasedOnBratsClassesd",
"AddExtremePointsChanneld",
"TorchVisiond",
"IdentityD",
"IdentityDict",
"AsChannelFirstD",
Expand All @@ -76,6 +78,8 @@
"AsChannelLastDict",
"AddChannelD",
"AddChannelDict",
"RandLambdaD",
"RandLambdaDict",
"RepeatChannelD",
"RepeatChannelDict",
"SplitChannelD",
Expand Down Expand Up @@ -106,7 +110,6 @@
"ConvertToMultiChannelBasedOnBratsClassesDict",
"AddExtremePointsChannelD",
"AddExtremePointsChannelDict",
"TorchVisiond",
"TorchVisionD",
"TorchVisionDict",
]
Expand Down Expand Up @@ -621,6 +624,27 @@ def __call__(self, data):
return d


class RandLambdad(Lambdad, Randomizable):
"""
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
func: Lambda/function to be applied. It also can be a sequence of Callable,
each element corresponds to a key in ``keys``.
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
For more details, please check :py:class:`monai.transforms.Lambdad`.
"""

def randomize(self, data: Any) -> None:
pass


class LabelToMaskd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`.
Expand Down Expand Up @@ -830,3 +854,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
) = ConvertToMultiChannelBasedOnBratsClassesd
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandLambdaD = RandLambdaDict = RandLambdad
48 changes: 48 additions & 0 deletions tests/test_rand_lambdad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from monai.transforms import Randomizable
from monai.transforms.utility.dictionary import RandLambdad


class RandTest(Randomizable):
"""
randomisable transform for testing.
"""

def randomize(self, data=None):
self._a = self.R.random()

def __call__(self, data):
self.randomize()
return data + self._a


class TestRandLambdad(unittest.TestCase):
def test_rand_lambdad_identity(self):
img = np.zeros((10, 10))
data = {"img": img, "prop": 1.0}

test_func = RandTest()
test_func.set_random_state(seed=134)
expected = {"img": test_func(data["img"]), "prop": 1.0}
test_func.set_random_state(seed=134)
ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data)
np.testing.assert_allclose(expected["img"], ret["img"])
np.testing.assert_allclose(expected["prop"], ret["prop"])


if __name__ == "__main__":
unittest.main()