Skip to content

Commit

Permalink
1542 Add RandLambdad transform (#1546)
Browse files Browse the repository at this point in the history
* [DLMED] add RandLambdad transform

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] add doc-strings

Signed-off-by: Nic Ma <nma@nvidia.com>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] fix typo

Signed-off-by: Nic Ma <nma@nvidia.com>

* [DLMED] change to rtol=1e-05

Signed-off-by: Nic Ma <nma@nvidia.com>

* [MONAI] python code formatting

Signed-off-by: monai-bot <monai.miccai2019@gmail.com>

* fixes seeds

Signed-off-by: Wenqi Li <wenqil@nvidia.com>

Co-authored-by: monai-bot <monai.miccai2019@gmail.com>
Co-authored-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
3 people authored Feb 4, 2021
1 parent 2021f24 commit bcdee8c
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 2 deletions.
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,12 @@ Utility (Dict)
:members:
:special-members: __call__

`RandLambdad`
"""""""""""""
.. autoclass:: RandLambdad
:members:
:special-members: __call__

`LabelToMaskd`
""""""""""""""
.. autoclass:: LabelToMaskd
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,9 @@
Lambdad,
LambdaD,
LambdaDict,
RandLambdad,
RandLambdaD,
RandLambdaDict,
RepeatChanneld,
RepeatChannelD,
RepeatChannelDict,
Expand Down
29 changes: 27 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import copy
import logging
from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -64,10 +64,12 @@
"CopyItemsd",
"ConcatItemsd",
"Lambdad",
"RandLambdad",
"LabelToMaskd",
"FgBgToIndicesd",
"ConvertToMultiChannelBasedOnBratsClassesd",
"AddExtremePointsChanneld",
"TorchVisiond",
"IdentityD",
"IdentityDict",
"AsChannelFirstD",
Expand All @@ -76,6 +78,8 @@
"AsChannelLastDict",
"AddChannelD",
"AddChannelDict",
"RandLambdaD",
"RandLambdaDict",
"RepeatChannelD",
"RepeatChannelDict",
"SplitChannelD",
Expand Down Expand Up @@ -106,7 +110,6 @@
"ConvertToMultiChannelBasedOnBratsClassesDict",
"AddExtremePointsChannelD",
"AddExtremePointsChannelDict",
"TorchVisiond",
"TorchVisionD",
"TorchVisionDict",
]
Expand Down Expand Up @@ -621,6 +624,27 @@ def __call__(self, data):
return d


class RandLambdad(Lambdad, Randomizable):
"""
Randomizable version :py:class:`monai.transforms.Lambdad`, the input `func` contains random logic.
It's a randomizable transform so `CacheDataset` will not execute it and cache the results.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
func: Lambda/function to be applied. It also can be a sequence of Callable,
each element corresponds to a key in ``keys``.
overwrite: whether to overwrite the original data in the input dictionary with lamdbda function output.
default to True. it also can be a sequence of bool, each element corresponds to a key in ``keys``.
For more details, please check :py:class:`monai.transforms.Lambdad`.
"""

def randomize(self, data: Any) -> None:
pass


class LabelToMaskd(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.LabelToMask`.
Expand Down Expand Up @@ -830,3 +854,4 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
) = ConvertToMultiChannelBasedOnBratsClassesd
AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld
TorchVisionD = TorchVisionDict = TorchVisiond
RandLambdaD = RandLambdaDict = RandLambdad
48 changes: 48 additions & 0 deletions tests/test_rand_lambdad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from monai.transforms import Randomizable
from monai.transforms.utility.dictionary import RandLambdad


class RandTest(Randomizable):
"""
randomisable transform for testing.
"""

def randomize(self, data=None):
self._a = self.R.random()

def __call__(self, data):
self.randomize()
return data + self._a


class TestRandLambdad(unittest.TestCase):
def test_rand_lambdad_identity(self):
img = np.zeros((10, 10))
data = {"img": img, "prop": 1.0}

test_func = RandTest()
test_func.set_random_state(seed=134)
expected = {"img": test_func(data["img"]), "prop": 1.0}
test_func.set_random_state(seed=134)
ret = RandLambdad(keys=["img", "prop"], func=test_func, overwrite=[True, False])(data)
np.testing.assert_allclose(expected["img"], ret["img"])
np.testing.assert_allclose(expected["prop"], ret["prop"])


if __name__ == "__main__":
unittest.main()

0 comments on commit bcdee8c

Please sign in to comment.