Skip to content

3525 Fix invertible issue in OneOf compose #3530

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Dec 22, 2021
Merged
9 changes: 3 additions & 6 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,9 @@ def inverse(self, data):
# and then remove the OneOf transform
self.pop_transform(data, key)
if index is None:
raise RuntimeError("No invertible transforms have been applied")
# no invertible transforms have been applied
return data

# if applied transform is not InvertibleTransform, throw error
_transform = self.transforms[index]
if not isinstance(_transform, InvertibleTransform):
raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).")

# apply the inverse
return _transform.inverse(data)
return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data
71 changes: 50 additions & 21 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@
import unittest
from copy import deepcopy

import numpy as np
from parameterized import parameterized

from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform
from monai.transforms import (
InvertibleTransform,
OneOf,
RandScaleIntensityd,
RandShiftIntensityd,
Resized,
TraceableTransform,
Transform,
)
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform
from monai.utils.enums import TraceKeys
Expand Down Expand Up @@ -139,32 +148,52 @@ def _match(a, b):
_match(p, f)

@parameterized.expand(TEST_INVERSES)
def test_inverse(self, transform, should_be_ok):
def test_inverse(self, transform, invertible):
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
fwd_data = transform(data)
if not should_be_ok:
with self.assertRaises(RuntimeError):
transform.inverse(fwd_data)
return

for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))

if invertible:
for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))

# call the inverse
fwd_inv_data = transform.inverse(fwd_data)

for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
if invertible:
for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
else:
# if not invertible, should not change the data
self.assertDictEqual(fwd_data, fwd_inv_data)

def test_inverse_compose(self):
transform = Compose(
[
Resized(keys="img", spatial_size=[100, 100, 100]),
OneOf(
[
RandScaleIntensityd(keys="img", factors=0.5, prob=1.0),
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
]
),
]
)
transform.set_random_state(seed=0)
result = transform({"img": np.ones((1, 101, 102, 103))})

result = transform.inverse(result)
# invert to the original spatial shape
self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103))

def test_one_of(self):
p = OneOf((A(), B(), C()), (1, 2, 1))
Expand Down