Skip to content

Commit

Permalink
[WIP] OneOf Transform (#2551)
Browse files Browse the repository at this point in the history
* Added OneOf class

Signed-off-by: Lyndon Boone <lyndonboone8@gmail.com>

* Clean up OneOf constructor

Signed-off-by: Lyndon Boone <lyndonboone8@gmail.com>

* add flatten, len and unit test

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* Added unit tests and inverse method

Signed-off-by: Lyndon Boone <lyndonboone8@gmail.com>

* rename test

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* flatten tests

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

* add inverse

Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com>

Co-authored-by: Richard Brown <33289025+rijobro@users.noreply.github.com>
  • Loading branch information
lyndonboone and rijobro authored Aug 12, 2021
1 parent 8726dd5 commit a6cf9b6
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 4 deletions.
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from .adaptors import FunctionSignature, adaptor, apply_alias, to_kwargs
from .compose import Compose
from .compose import Compose, OneOf
from .croppad.array import (
BorderPad,
BoundingRect,
Expand Down
105 changes: 102 additions & 3 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""

import warnings
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Mapping, Optional, Sequence, Union

import numpy as np

Expand All @@ -28,8 +28,9 @@
apply_transform,
)
from monai.utils import MAX_SEED, ensure_tuple, get_seed
from monai.utils.enums import InverseKeys

__all__ = ["Compose"]
__all__ = ["Compose", "OneOf"]


class Compose(Randomizable, InvertibleTransform):
Expand Down Expand Up @@ -143,7 +144,7 @@ def flatten(self):
"""
new_transforms = []
for t in self.transforms:
if isinstance(t, Compose):
if isinstance(t, Compose) and not isinstance(t, OneOf):
new_transforms += t.flatten().transforms
else:
new_transforms.append(t)
Expand All @@ -168,3 +169,101 @@ def inverse(self, data):
for t in reversed(invertible_transforms):
data = apply_transform(t.inverse, data, self.map_items, self.unpack_items)
return data


class OneOf(Compose):
"""
``OneOf`` provides the ability to radomly choose one transform out of a
list of callables with predfined probabilities for each.
Args:
transforms: sequence of callables.
weights: probabilities corresponding to each callable in transforms.
Probabilities are normalized to sum to one.
OneOf inherits from Compose and uses args map_items and unpack_items in
the same way.
"""

def __init__(
self,
transforms: Optional[Union[Sequence[Callable], Callable]] = None,
weights: Optional[Union[Sequence[float], float]] = None,
map_items: bool = True,
unpack_items: bool = False,
) -> None:
super().__init__(transforms, map_items, unpack_items)
if len(self.transforms) == 0:
weights = []
elif weights is None or isinstance(weights, float):
weights = [1.0 / len(self.transforms)] * len(self.transforms)
if len(weights) != len(self.transforms):
raise AssertionError("transforms and weights should be same size if both specified as sequences.")
self.weights = ensure_tuple(self._normalize_probabilities(weights))

def _normalize_probabilities(self, weights):
if len(weights) == 0:
return weights
else:
weights = np.array(weights)
if np.any(weights < 0):
raise AssertionError("Probabilities must be greater than or equal to zero.")
if np.all(weights == 0):
raise AssertionError("At least one probability must be greater than zero.")
weights = weights / weights.sum()
return list(weights)

def flatten(self):
transforms = []
weights = []
for t, w in zip(self.transforms, self.weights):
# if nested, probability is the current weight multiplied by the nested weights,
# and so on recursively
if isinstance(t, OneOf):
tr = t.flatten()
for t_, w_ in zip(tr.transforms, tr.weights):
transforms.append(t_)
weights.append(w_ * w)
else:
transforms.append(t)
weights.append(w)
return OneOf(transforms, weights, self.map_items, self.unpack_items)

def __call__(self, data):
if len(self.transforms) == 0:
return data
else:
index = self.R.multinomial(1, self.weights).argmax()
_transform = self.transforms[index]
data = apply_transform(_transform, data, self.map_items, self.unpack_items)
# if the data is a mapping (dictionary), append the OneOf transform to the end
if isinstance(data, Mapping):
for key in data.keys():
if key + InverseKeys.KEY_SUFFIX in data:
self.push_transform(data, key, extra_info={"index": index})
return data

def inverse(self, data):
if len(self.transforms) == 0:
return data
if not isinstance(data, Mapping):
raise RuntimeError("Inverse only implemented for Mapping (dictionary) data")

# loop until we get an index and then break (since they'll all be the same)
index = None
for key in data.keys():
if key + InverseKeys.KEY_SUFFIX in data:
# get the index of the applied OneOf transform
index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"]
# and then remove the OneOf transform
self.pop_transform(data, key)
if index is None:
raise RuntimeError("No invertible transforms have been applied")

# if applied transform is not InvertibleTransform, throw error
_transform = self.transforms[index]
if not isinstance(_transform, InvertibleTransform):
raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).")

# apply the inverse
return _transform.inverse(data)
181 changes: 181 additions & 0 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from copy import deepcopy

from parameterized import parameterized

from monai.transforms import InvertibleTransform, OneOf, Transform
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform
from monai.utils.enums import InverseKeys


class X(Transform):
def __call__(self, x):
return x


class Y(Transform):
def __call__(self, x):
return x


class A(Transform):
def __call__(self, x):
return x + 1


class B(Transform):
def __call__(self, x):
return x + 2


class C(Transform):
def __call__(self, x):
return x + 3


class MapBase(MapTransform):
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn, self.inv_fn = None, None

def __call__(self, data):
d = deepcopy(dict(data))
for key in self.key_iterator(d):
d[key] = self.fwd_fn(d[key])
return d


class NonInv(MapBase):
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x * 2


class Inv(MapBase, InvertibleTransform):
def __call__(self, data):
d = deepcopy(dict(data))
for key in self.key_iterator(d):
d[key] = self.fwd_fn(d[key])
self.push_transform(d, key)
return d

def inverse(self, data):
d = deepcopy(dict(data))
for key in self.key_iterator(d):
d[key] = self.inv_fn(d[key])
self.pop_transform(d, key)
return d


class InvA(Inv):
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x + 1
self.inv_fn = lambda x: x - 1


class InvB(Inv):
def __init__(self, keys):
super().__init__(keys)
self.fwd_fn = lambda x: x + 100
self.inv_fn = lambda x: x - 100


TESTS = [
((X(), Y(), X()), (1, 2, 1), (0.25, 0.5, 0.25)),
]

KEYS = ["x", "y"]
TEST_INVERSES = [
(OneOf((InvA(KEYS), InvB(KEYS))), True),
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True),
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True),
(OneOf((NonInv(KEYS), NonInv(KEYS))), False),
]


class TestOneOf(unittest.TestCase):
@parameterized.expand(TESTS)
def test_normalize_weights(self, transforms, input_weights, expected_weights):
tr = OneOf(transforms, input_weights)
self.assertTupleEqual(tr.weights, expected_weights)

def test_no_weights_arg(self):
p = OneOf((X(), Y(), X(), Y()))
expected_weights = (0.25,) * 4
self.assertTupleEqual(p.weights, expected_weights)

def test_len_and_flatten(self):
p1 = OneOf((X(), Y()), (1, 3)) # 0.25, 0.75
p2 = OneOf((Y(), Y()), (2, 2)) # 0.5. 0.5
p = OneOf((p1, p2, X()), (1, 2, 1)) # 0.25, 0.5, 0.25
expected_order = (X, Y, Y, Y, X)
expected_weights = (0.25 * 0.25, 0.25 * 0.75, 0.5 * 0.5, 0.5 * 0.5, 0.25)
self.assertEqual(len(p), len(expected_order))
self.assertTupleEqual(p.flatten().weights, expected_weights)

def test_compose_flatten_does_not_affect_one_of(self):
p = Compose([A(), B(), OneOf([C(), Inv(KEYS), Compose([X(), Y()])])])
f = p.flatten()
# in this case the flattened transform should be the same.

def _match(a, b):
self.assertEqual(type(a), type(b))
for a_, b_ in zip(a.transforms, b.transforms):
self.assertEqual(type(a_), type(b_))
if isinstance(a_, (Compose, OneOf)):
_match(a_, b_)

_match(p, f)

@parameterized.expand(TEST_INVERSES)
def test_inverse(self, transform, should_be_ok):
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
fwd_data = transform(data)
if not should_be_ok:
with self.assertRaises(RuntimeError):
transform.inverse(fwd_data)
return

for k in KEYS:
t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1]
# make sure the OneOf index was stored
self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform))

# call the inverse
fwd_inv_data = transform.inverse(fwd_data)

for k in KEYS:
# check transform was removed
self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX]))
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])

def test_one_of(self):
p = OneOf((A(), B(), C()), (1, 2, 1))
counts = [0] * 3
for _i in range(10000):
out = p(1.0)
counts[int(out - 2)] += 1
self.assertAlmostEqual(counts[0] / 10000, 0.25, delta=1.0)
self.assertAlmostEqual(counts[1] / 10000, 0.50, delta=1.0)
self.assertAlmostEqual(counts[2] / 10000, 0.25, delta=1.0)


if __name__ == "__main__":
unittest.main()

0 comments on commit a6cf9b6

Please sign in to comment.