Skip to content

Commit 0fbbd78

Browse files
committed
update one_of/random_order tests
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent aa0389e commit 0fbbd78

File tree

2 files changed

+10
-31
lines changed

2 files changed

+10
-31
lines changed

tests/test_one_of.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
RandShiftIntensityd,
2828
Resize,
2929
Resized,
30-
TraceableTransform,
3130
Transform,
3231
)
3332
from monai.transforms.compose import Compose
@@ -113,9 +112,9 @@ def __init__(self, keys):
113112
KEYS = ["x", "y"]
114113
TEST_INVERSES = [
115114
(OneOf((InvA(KEYS), InvB(KEYS))), True, True),
116-
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, True),
117-
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, True),
118-
(OneOf((NonInv(KEYS), NonInv(KEYS))), False, True),
115+
(OneOf((OneOf((InvA(KEYS), InvB(KEYS))), OneOf((InvB(KEYS), InvA(KEYS))))), True, False),
116+
(OneOf((Compose((InvA(KEYS), InvB(KEYS))), Compose((InvB(KEYS), InvA(KEYS))))), True, False),
117+
(OneOf((NonInv(KEYS), NonInv(KEYS))), False, False),
119118
]
120119

121120

@@ -161,11 +160,7 @@ def test_inverse(self, transform, invertible, use_metatensor):
161160

162161
if invertible:
163162
for k in KEYS:
164-
t = (
165-
fwd_data[TraceableTransform.trace_key(k)][-1]
166-
if not use_metatensor
167-
else fwd_data[k].applied_operations[-1]
168-
)
163+
t = fwd_data[k].applied_operations[-1]
169164
# make sure the OneOf index was stored
170165
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
171166
# make sure index exists and is in bounds
@@ -176,12 +171,6 @@ def test_inverse(self, transform, invertible, use_metatensor):
176171

177172
if invertible:
178173
for k in KEYS:
179-
# check transform was removed
180-
if not use_metatensor:
181-
self.assertTrue(
182-
len(fwd_inv_data[TraceableTransform.trace_key(k)])
183-
< len(fwd_data[TraceableTransform.trace_key(k)])
184-
)
185174
# check data is same as original (and different from forward)
186175
self.assertEqual(fwd_inv_data[k], data[k])
187176
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])

tests/test_random_order.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from parameterized import parameterized
1717

1818
from monai.data import MetaTensor
19-
from monai.transforms import RandomOrder, TraceableTransform
19+
from monai.transforms import RandomOrder
2020
from monai.transforms.compose import Compose
2121
from monai.utils import set_determinism
2222
from monai.utils.enums import TraceKeys
@@ -41,10 +41,10 @@ def __init__(self, keys):
4141
KEYS = ["x", "y"]
4242
TEST_INVERSES = [
4343
(RandomOrder((InvC(KEYS), InvD(KEYS))), True, True),
44-
(Compose((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, True),
45-
(RandomOrder((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, True),
46-
(RandomOrder((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, True),
47-
(RandomOrder((NonInv(KEYS), NonInv(KEYS))), False, True),
44+
(Compose((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),
45+
(RandomOrder((RandomOrder((InvC(KEYS), InvD(KEYS))), RandomOrder((InvD(KEYS), InvC(KEYS))))), True, False),
46+
(RandomOrder((Compose((InvC(KEYS), InvD(KEYS))), Compose((InvD(KEYS), InvC(KEYS))))), True, False),
47+
(RandomOrder((NonInv(KEYS), NonInv(KEYS))), False, False),
4848
]
4949

5050

@@ -77,11 +77,7 @@ def test_inverse(self, transform, invertible, use_metatensor):
7777

7878
if invertible:
7979
for k in KEYS:
80-
t = (
81-
fwd_data1[TraceableTransform.trace_key(k)][-1]
82-
if not use_metatensor
83-
else fwd_data1[k].applied_operations[-1]
84-
)
80+
t = fwd_data1[k].applied_operations[-1]
8581
# make sure the RandomOrder applied_order was stored
8682
self.assertEqual(t[TraceKeys.CLASS_NAME], RandomOrder.__name__)
8783

@@ -94,12 +90,6 @@ def test_inverse(self, transform, invertible, use_metatensor):
9490
for i, _fwd_inv_data in enumerate(fwd_inv_data):
9591
if invertible:
9692
for k in KEYS:
97-
# check transform was removed
98-
if not use_metatensor:
99-
self.assertTrue(
100-
len(_fwd_inv_data[TraceableTransform.trace_key(k)])
101-
< len(fwd_data[i][TraceableTransform.trace_key(k)])
102-
)
10393
# check data is same as original (and different from forward)
10494
self.assertEqual(_fwd_inv_data[k], data[k])
10595
self.assertNotEqual(_fwd_inv_data[k], fwd_data[i][k])

0 commit comments

Comments
 (0)