diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 33d7fba26e..99d9c2b8b8 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -10,7 +10,7 @@ # limitations under the License. from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs -from .compose import Compose +from .compose import Compose, OneOf from .croppad.array import ( BorderPad, BoundingRect, diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b380f7d42a..8737abd0fa 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,7 +13,7 @@ """ import warnings -from typing import Any, Callable, Optional, Sequence, Union +from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np @@ -28,8 +28,9 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed +from monai.utils.enums import InverseKeys -__all__ = ["Compose"] +__all__ = ["Compose", "OneOf"] class Compose(Randomizable, InvertibleTransform): @@ -143,7 +144,7 @@ def flatten(self): """ new_transforms = [] for t in self.transforms: - if isinstance(t, Compose): + if isinstance(t, Compose) and not isinstance(t, OneOf): new_transforms += t.flatten().transforms else: new_transforms.append(t) @@ -168,3 +169,101 @@ def inverse(self, data): for t in reversed(invertible_transforms): data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data + + +class OneOf(Compose): + """ + ``OneOf`` provides the ability to radomly choose one transform out of a + list of callables with predfined probabilities for each. + + Args: + transforms: sequence of callables. + weights: probabilities corresponding to each callable in transforms. + Probabilities are normalized to sum to one. + + OneOf inherits from Compose and uses args map_items and unpack_items in + the same way. + """ + + def __init__( + self, + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + weights: Optional[Union[Sequence[float], float]] = None, + map_items: bool = True, + unpack_items: bool = False, + ) -> None: + super().__init__(transforms, map_items, unpack_items) + if len(self.transforms) == 0: + weights = [] + elif weights is None or isinstance(weights, float): + weights = [1.0 / len(self.transforms)] * len(self.transforms) + if len(weights) != len(self.transforms): + raise AssertionError("transforms and weights should be same size if both specified as sequences.") + self.weights = ensure_tuple(self._normalize_probabilities(weights)) + + def _normalize_probabilities(self, weights): + if len(weights) == 0: + return weights + else: + weights = np.array(weights) + if np.any(weights < 0): + raise AssertionError("Probabilities must be greater than or equal to zero.") + if np.all(weights == 0): + raise AssertionError("At least one probability must be greater than zero.") + weights = weights / weights.sum() + return list(weights) + + def flatten(self): + transforms = [] + weights = [] + for t, w in zip(self.transforms, self.weights): + # if nested, probability is the current weight multiplied by the nested weights, + # and so on recursively + if isinstance(t, OneOf): + tr = t.flatten() + for t_, w_ in zip(tr.transforms, tr.weights): + transforms.append(t_) + weights.append(w_ * w) + else: + transforms.append(t) + weights.append(w) + return OneOf(transforms, weights, self.map_items, self.unpack_items) + + def __call__(self, data): + if len(self.transforms) == 0: + return data + else: + index = self.R.multinomial(1, self.weights).argmax() + _transform = self.transforms[index] + data = apply_transform(_transform, data, self.map_items, self.unpack_items) + # if the data is a mapping (dictionary), append the OneOf transform to the end + if isinstance(data, Mapping): + for key in data.keys(): + if key + InverseKeys.KEY_SUFFIX in data: + self.push_transform(data, key, extra_info={"index": index}) + return data + + def inverse(self, data): + if len(self.transforms) == 0: + return data + if not isinstance(data, Mapping): + raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") + + # loop until we get an index and then break (since they'll all be the same) + index = None + for key in data.keys(): + if key + InverseKeys.KEY_SUFFIX in data: + # get the index of the applied OneOf transform + index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"] + # and then remove the OneOf transform + self.pop_transform(data, key) + if index is None: + raise RuntimeError("No invertible transforms have been applied") + + # if applied transform is not InvertibleTransform, throw error + _transform = self.transforms[index] + if not isinstance(_transform, InvertibleTransform): + raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") + + # apply the inverse + return _transform.inverse(data) diff --git a/tests/test_one_of.py b/tests/test_one_of.py new file mode 100644 index 0000000000..d45d0f3f61 --- /dev/null +++ b/tests/test_one_of.py @@ -0,0 +1,181 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from copy import deepcopy + +from parameterized import parameterized + +from monai.transforms import InvertibleTransform, OneOf, Transform +from monai.transforms.compose import Compose +from monai.transforms.transform import MapTransform +from monai.utils.enums import InverseKeys + + +class X(Transform): + def __call__(self, x): + return x + + +class Y(Transform): + def __call__(self, x): + return x + + +class A(Transform): + def __call__(self, x): + return x + 1 + + +class B(Transform): + def __call__(self, x): + return x + 2 + + +class C(Transform): + def __call__(self, x): + return x + 3 + + +class MapBase(MapTransform): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn, self.inv_fn = None, None + + def __call__(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.fwd_fn(d[key]) + return d + + +class NonInv(MapBase): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x * 2 + + +class Inv(MapBase, InvertibleTransform): + def __call__(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.fwd_fn(d[key]) + self.push_transform(d, key) + return d + + def inverse(self, data): + d = deepcopy(dict(data)) + for key in self.key_iterator(d): + d[key] = self.inv_fn(d[key]) + self.pop_transform(d, key) + return d + + +class InvA(Inv): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x + 1 + self.inv_fn = lambda x: x - 1 + + +class InvB(Inv): + def __init__(self, keys): + super().__init__(keys) + self.fwd_fn = lambda x: x + 100 + self.inv_fn = lambda x: x - 100 + + +TESTS = [ + ((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)), +] + +KEYS = ["x", "y"] +TEST_INVERSES = [ + (OneOf((InvA(KEYS), InvB(KEYS))), True), + (OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True), + (OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True), + (OneOf((NonInv(KEYS), NonInv(KEYS))), False), +] + + +class TestOneOf(unittest.TestCase): + @parameterized.expand(TESTS) + def test_normalize_weights(self, transforms, input_weights, expected_weights): + tr = OneOf(transforms, input_weights) + self.assertTupleEqual(tr.weights, expected_weights) + + def test_no_weights_arg(self): + p = OneOf((X(), Y(), X(), Y())) + expected_weights = (0.25,) * 4 + self.assertTupleEqual(p.weights, expected_weights) + + def test_len_and_flatten(self): + p1 = OneOf((X(), Y()), (1, 3)) # 0.25, 0.75 + p2 = OneOf((Y(), Y()), (2, 2)) # 0.5. 0.5 + p = OneOf((p1, p2, X()), (1, 2, 1)) # 0.25, 0.5, 0.25 + expected_order = (X, Y, Y, Y, X) + expected_weights = (0.25 * 0.25, 0.25 * 0.75, 0.5 * 0.5, 0.5 * 0.5, 0.25) + self.assertEqual(len(p), len(expected_order)) + self.assertTupleEqual(p.flatten().weights, expected_weights) + + def test_compose_flatten_does_not_affect_one_of(self): + p = Compose([A(), B(), OneOf([C(), Inv(KEYS), Compose([X(), Y()])])]) + f = p.flatten() + # in this case the flattened transform should be the same. + + def _match(a, b): + self.assertEqual(type(a), type(b)) + for a_, b_ in zip(a.transforms, b.transforms): + self.assertEqual(type(a_), type(b_)) + if isinstance(a_, (Compose, OneOf)): + _match(a_, b_) + + _match(p, f) + + @parameterized.expand(TEST_INVERSES) + def test_inverse(self, transform, should_be_ok): + data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} + fwd_data = transform(data) + if not should_be_ok: + with self.assertRaises(RuntimeError): + transform.inverse(fwd_data) + return + + for k in KEYS: + t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1] + # make sure the OneOf index was stored + self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform)) + + # call the inverse + fwd_inv_data = transform.inverse(fwd_data) + + for k in KEYS: + # check transform was removed + self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX])) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + + def test_one_of(self): + p = OneOf((A(), B(), C()), (1, 2, 1)) + counts = [0] * 3 + for _i in range(10000): + out = p(1.0) + counts[int(out - 2)] += 1 + self.assertAlmostEqual(counts[0] / 10000, 0.25, delta=1.0) + self.assertAlmostEqual(counts[1] / 10000, 0.50, delta=1.0) + self.assertAlmostEqual(counts[2] / 10000, 0.25, delta=1.0) + + +if __name__ == "__main__": + unittest.main()