Skip to content

Commit 1232a6b

Browse files
committed
fixes tests
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 7699634 commit 1232a6b

File tree

3 files changed

+11
-17
lines changed

3 files changed

+11
-17
lines changed

monai/data/meta_obj.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,8 @@ def copy_meta_from(self, input_objs, copy_attr=True, keys=None):
126126
return self with the updated ``__dict__``.
127127
"""
128128
first_meta = input_objs if isinstance(input_objs, MetaObj) else first(input_objs, default=self)
129+
if not hasattr(first_meta, "__dict__"):
130+
return self
129131
first_meta = first_meta.__dict__
130132
keys = first_meta.keys() if keys is None else keys
131133
if not copy_attr:

monai/transforms/croppad/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@
4343
map_classes_to_indices,
4444
weighted_patch_samples,
4545
)
46+
from monai.utils import ImageMetaKey as Key
4647
from monai.utils import (
47-
Method,
4848
LazyAttr,
49+
Method,
4950
PytorchPadMode,
5051
TraceKeys,
5152
TransformBackends,
@@ -58,7 +59,6 @@
5859
fall_back_tuple,
5960
look_up_option,
6061
pytorch_after,
61-
ImageMetaKey as Key,
6262
)
6363

6464
__all__ = [

tests/test_traceable_transform.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,18 @@
1313

1414
import unittest
1515

16+
import torch
17+
1618
from monai.transforms.inverse import TraceableTransform
1719

1820

1921
class _TraceTest(TraceableTransform):
2022
def __call__(self, data):
21-
self.push_transform(data)
23+
self.push_transform(data, "image")
2224
return data
2325

2426
def pop(self, data):
25-
self.pop_transform(data)
27+
self.pop_transform(data, "image")
2628
return data
2729

2830

@@ -34,21 +36,11 @@ def test_default(self):
3436

3537
data = {"image": "test"}
3638
data = a(data) # adds to the stack
37-
self.assertTrue(isinstance(data[expected_key], list))
38-
self.assertEqual(data[expected_key][0]["class"], "_TraceTest")
39+
self.assertEqual(data["image"], "test")
3940

41+
data = {"image": torch.tensor(1.0)}
4042
data = a(data) # adds to the stack
41-
self.assertEqual(len(data[expected_key]), 2)
42-
self.assertEqual(data[expected_key][-1]["class"], "_TraceTest")
43-
44-
with self.assertRaises(IndexError):
45-
a.pop({"test": "test"}) # no stack in the data
46-
data = a.pop(data)
47-
data = a.pop(data)
48-
self.assertEqual(data[expected_key], [])
49-
50-
with self.assertRaises(IndexError): # no more items
51-
a.pop(data)
43+
self.assertEqual(data["image"].applied_operations[0]["class"], "_TraceTest")
5244

5345

5446
if __name__ == "__main__":

0 commit comments

Comments
 (0)