Skip to content

Commit

Permalink
[DLMED] add dict version shuffle (#2918)
Browse files Browse the repository at this point in the history
Signed-off-by: Nic Ma <nma@nvidia.com>

Co-authored-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
Nic-Ma and wyli authored Sep 10, 2021
1 parent ef09811 commit e2965db
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 1 deletion.
6 changes: 6 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -977,6 +977,12 @@ Intensity (Dict)
:members:
:special-members: __call__

`RandCoarseShuffled`
""""""""""""""""""""
.. autoclass:: RandCoarseShuffled
:members:
:special-members: __call__

`HistogramNormalized`
"""""""""""""""""""""
.. autoclass:: HistogramNormalized
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@
RandCoarseDropoutd,
RandCoarseDropoutD,
RandCoarseDropoutDict,
RandCoarseShuffled,
RandCoarseShuffleD,
RandCoarseShuffleDict,
RandGaussianNoised,
RandGaussianNoiseD,
RandGaussianNoiseDict,
Expand Down
15 changes: 15 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1786,6 +1786,21 @@ class RandCoarseShuffle(RandCoarseTransform):
Kang, Guoliang, et al. "Patchshuffle regularization." arXiv preprint arXiv:1707.07103 (2017).
https://arxiv.org/abs/1707.07103
Args:
holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
randomly select the expected number of regions.
spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
as the minimum spatial size to randomly select size for every region.
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
max_holes: if not None, define the maximum number to randomly select the expected number of regions.
max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
if some components of the `max_spatial_size` are non-positive values, the transform will use the
corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
prob: probability of applying the transform.
"""

def _transform_holes(self, img: np.ndarray):
Expand Down
78 changes: 78 additions & 0 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
NormalizeIntensity,
RandBiasField,
RandCoarseDropout,
RandCoarseShuffle,
RandGaussianNoise,
RandKSpaceSpikeNoise,
RandRicianNoise,
Expand Down Expand Up @@ -75,6 +76,7 @@
"RandKSpaceSpikeNoised",
"RandHistogramShiftd",
"RandCoarseDropoutd",
"RandCoarseShuffled",
"HistogramNormalized",
"RandGaussianNoiseD",
"RandGaussianNoiseDict",
Expand Down Expand Up @@ -126,6 +128,8 @@
"RandRicianNoiseDict",
"RandCoarseDropoutD",
"RandCoarseDropoutDict",
"RandCoarseShuffleD",
"RandCoarseShuffleDict",
"HistogramNormalizeD",
"HistogramNormalizeDict",
]
Expand Down Expand Up @@ -1478,6 +1482,13 @@ def __init__(
prob=prob,
)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "RandCoarseDropoutd":
self.dropper.set_random_state(seed, state)
super().set_random_state(seed, state)
return self

def randomize(self, img_size: Sequence[int]) -> None:
self.dropper.randomize(img_size=img_size)

Expand All @@ -1492,6 +1503,72 @@ def __call__(self, data):
return d


class RandCoarseShuffled(Randomizable, MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.RandCoarseShuffle`.
Expect all the data specified by `keys` have same spatial shape and will randomly dropout the same regions
for every key, if want to shuffle different regions for every key, please use this transform separately.
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
holes: number of regions to dropout, if `max_holes` is not None, use this arg as the minimum number to
randomly select the expected number of regions.
spatial_size: spatial size of the regions to dropout, if `max_spatial_size` is not None, use this arg
as the minimum spatial size to randomly select size for every region.
if some components of the `spatial_size` are non-positive values, the transform will use the
corresponding components of input img size. For example, `spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
max_holes: if not None, define the maximum number to randomly select the expected number of regions.
max_spatial_size: if not None, define the maximum spatial size to randomly select size for every region.
if some components of the `max_spatial_size` are non-positive values, the transform will use the
corresponding components of input img size. For example, `max_spatial_size=(32, -1)` will be adapted
to `(32, 64)` if the second spatial dimension size of img is `64`.
prob: probability of applying the transform.
allow_missing_keys: don't raise exception if key is missing.
"""

def __init__(
self,
keys: KeysCollection,
holes: int,
spatial_size: Union[Sequence[int], int],
max_holes: Optional[int] = None,
max_spatial_size: Optional[Union[Sequence[int], int]] = None,
prob: float = 0.1,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
self.shuffle = RandCoarseShuffle(
holes=holes,
spatial_size=spatial_size,
max_holes=max_holes,
max_spatial_size=max_spatial_size,
prob=prob,
)

def set_random_state(
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
) -> "RandCoarseShuffled":
self.shuffle.set_random_state(seed, state)
super().set_random_state(seed, state)
return self

def randomize(self, img_size: Sequence[int]) -> None:
self.shuffle.randomize(img_size=img_size)

def __call__(self, data):
d = dict(data)
# expect all the specified keys have same spatial shape
self.randomize(d[self.keys[0]].shape[1:])
if self.shuffle._do_transform:
for key in self.key_iterator(d):
d[key] = self.shuffle(img=d[key])

return d


class HistogramNormalized(MapTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.HistogramNormalize`.
Expand Down Expand Up @@ -1562,3 +1639,4 @@ def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.nda
RandKSpaceSpikeNoiseD = RandKSpaceSpikeNoiseDict = RandKSpaceSpikeNoised
RandCoarseDropoutD = RandCoarseDropoutDict = RandCoarseDropoutd
HistogramNormalizeD = HistogramNormalizeDict = HistogramNormalized
RandCoarseShuffleD = RandCoarseShuffleDict = RandCoarseShuffled
2 changes: 1 addition & 1 deletion tests/test_rand_coarse_shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

class TestRandCoarseShuffle(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_local_patch_shuffle(self, input_param, input_data, expected_val):
def test_shuffle(self, input_param, input_data, expected_val):
g = RandCoarseShuffle(**input_param)
g.set_random_state(seed=12)
result = g(**input_data)
Expand Down
56 changes: 56 additions & 0 deletions tests/test_rand_coarse_shuffled.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from parameterized import parameterized

from monai.transforms import RandCoarseShuffled

TEST_CASES = [
[
{"keys": "img", "holes": 5, "spatial_size": 1, "max_spatial_size": -1, "prob": 0.0},
{"img": np.arange(8).reshape((1, 2, 2, 2))},
np.arange(8).reshape((1, 2, 2, 2)),
],
[
{"keys": "img", "holes": 10, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0},
{"img": np.arange(27).reshape((1, 3, 3, 3))},
np.asarray(
[
[
[[13, 17, 5], [6, 16, 25], [12, 15, 22]],
[[24, 7, 3], [9, 2, 23], [0, 4, 26]],
[[19, 11, 14], [1, 20, 8], [18, 10, 21]],
]
]
),
],
[
{"keys": "img", "holes": 2, "spatial_size": 1, "max_spatial_size": -1, "prob": 1.0},
{"img": np.arange(16).reshape((2, 2, 2, 2))},
np.asarray([[[[7, 2], [1, 4]], [[5, 0], [3, 6]]], [[[8, 13], [10, 15]], [[14, 12], [11, 9]]]]),
],
]


class TestRandCoarseShuffled(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shuffle(self, input_param, input_data, expected_val):
g = RandCoarseShuffled(**input_param)
g.set_random_state(seed=12)
result = g(input_data)
np.testing.assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
unittest.main()

0 comments on commit e2965db

Please sign in to comment.