
Commit 5135fb4

changed compose to explicit flattening
Signed-off-by: Lukas Folle <lukas.folle@snke.com>
1 parent a8c9e24 commit 5135fb4

6 files changed: +67 −18 lines

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -546,6 +546,7 @@
     TorchVision,
     ToTensor,
     Transpose,
+    FlattenSequence,
 )
 from .utility.dictionary import (
     AddCoordinateChannelsd,
@@ -671,6 +672,9 @@
     Transposed,
     TransposeD,
     TransposeDict,
+    FlattenSequenced,
+    FlattenSequenceD,
+    FlattenSequenceDict,
 )
 from .utils import (
     Fourier,
```
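With these re-exports in place, the new transforms can be imported directly from the package root, for example (assuming a MONAI build that includes this commit):

```python
from monai.transforms import FlattenSequence, FlattenSequenced, FlattenSequenceD, FlattenSequenceDict
```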

monai/transforms/traits.py

Lines changed: 12 additions & 1 deletion
```diff
@@ -14,7 +14,7 @@

 from __future__ import annotations

-__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"]
+__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe", "ReduceTrait"]

 from typing import Any

@@ -99,3 +99,14 @@ class ThreadUnsafe:
     """

     pass
+
+
+class ReduceTrait:
+    """
+    An interface to indicate that the transform has the capability to reduce multiple samples
+    into a single sample.
+    This interface can be extended from by people adapting transforms to the MONAI framework as well
+    as by implementors of MONAI transforms.
+    """
+
+    pass
```
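As a sketch of how the trait is meant to be used, a custom transform can mix in `ReduceTrait` to signal that it consumes a whole list of samples and returns a single result. The `MeanOfSamples` class below is purely illustrative and not part of this commit:

```python
import torch

from monai.transforms.traits import ReduceTrait
from monai.transforms.transform import Transform


class MeanOfSamples(Transform, ReduceTrait):
    """Hypothetical reduce transform: average a list of tensors into one sample."""

    def __call__(self, samples):
        # Receives the entire list because ReduceTrait disables per-item mapping.
        return torch.stack(list(samples)).mean(dim=0)
```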

monai/transforms/transform.py

Lines changed: 6 additions & 12 deletions
```diff
@@ -25,7 +25,7 @@
 from monai import config, transforms
 from monai.config import KeysCollection
 from monai.data.meta_tensor import MetaTensor
-from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe
+from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe, ReduceTrait
 from monai.utils import MAX_SEED, ensure_tuple, first
 from monai.utils.enums import TransformBackends
 from monai.utils.misc import MONAIEnvVars
@@ -142,17 +142,11 @@ def apply_transform(
     """
     try:
         map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
-        if isinstance(data, (list, tuple)) and map_items_ > 0:
-            res: list[Any] = []
-            for item in data:
-                res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
-                # Only extend if we're at the leaf level (map_items_ == 1) and the transform
-                # actually returned a list (not preserving nested structure)
-                if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)):
-                    res.extend(res_item)
-                else:
-                    res.append(res_item)
-            return res
+        if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
+            return [
+                apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
+                for item in data
+            ]
         return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
     except Exception as e:
         # if in debug mode, don't swallow exception so that the breakpoint
```
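The effect of the new dispatch is that list or tuple inputs are still mapped item by item for ordinary transforms, while a transform implementing `ReduceTrait` receives the whole sequence in a single call. A minimal sketch (the `ToyDouble`/`ToyFlatten` classes are illustrative only, and the `ReduceTrait` import assumes this commit is available):

```python
from monai.transforms.traits import ReduceTrait
from monai.transforms.transform import Transform, apply_transform


class ToyDouble(Transform):
    def __call__(self, x):
        return x * 2  # ordinary transform: applied to each list item separately


class ToyFlatten(Transform, ReduceTrait):
    def __call__(self, seq):
        return [item for sub in seq for item in sub]  # sees the whole list at once


print(apply_transform(ToyDouble(), [1, 2, 3]))       # [2, 4, 6]
print(apply_transform(ToyFlatten(), [[1, 2], [3]]))  # [1, 2, 3]
```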

monai/transforms/utility/array.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -43,7 +43,7 @@
     median_filter,
 )
 from monai.transforms.inverse import InvertibleTransform, TraceableTransform
-from monai.transforms.traits import MultiSampleTrait
+from monai.transforms.traits import MultiSampleTrait, ReduceTrait
 from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
 from monai.transforms.utils import (
     apply_affine_to_points,
@@ -110,6 +110,7 @@
     "ImageFilter",
     "RandImageFilter",
     "ApplyTransformToPoints",
+    "FlattenSequence"
 ]


@@ -1950,3 +1951,16 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
             data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])

         return data
+
+
+class FlattenSequence(Transform, ReduceTrait):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, data):
+        if isinstance(data, (list, tuple)):
+            if len(data) == 0:
+                return data
+            if isinstance(data[0], (list, tuple)):
+                return [item for sublist in data for item in sublist]
+        return data
```
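A usage sketch for the new transform (assuming this commit is installed): `FlattenSequence` removes one level of nesting from a list of lists, as produced by chained multi-sample transforms, and passes any other input through unchanged.

```python
import torch

from monai.transforms import FlattenSequence

# e.g. the nested output of two chained multi-sample crops
nested = [[torch.zeros(1, 64, 64)], [torch.zeros(1, 64, 64)]]
flat = FlattenSequence()(nested)
print(len(flat), flat[0].shape)  # 2 torch.Size([1, 64, 64])

# non-nested or non-sequence inputs are returned as-is
print(FlattenSequence()(torch.zeros(3)).shape)  # torch.Size([3])
```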

monai/transforms/utility/dictionary.py

Lines changed: 23 additions & 1 deletion
```diff
@@ -30,7 +30,7 @@
 from monai.data.meta_tensor import MetaObj, MetaTensor
 from monai.data.utils import no_collation
 from monai.transforms.inverse import InvertibleTransform
-from monai.transforms.traits import MultiSampleTrait, RandomizableTrait
+from monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait
 from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
 from monai.transforms.utility.array import (
     AddCoordinateChannels,
@@ -64,6 +64,7 @@
     TorchVision,
     ToTensor,
     Transpose,
+    FlattenSequence
 )
 from monai.transforms.utils import extreme_points_to_image, get_extreme_points
 from monai.transforms.utils_pytorch_numpy_unification import concatenate
@@ -191,6 +192,9 @@
     "ApplyTransformToPointsd",
     "ApplyTransformToPointsD",
     "ApplyTransformToPointsDict",
+    "FlattenSequenced",
+    "FlattenSequenceD",
+    "FlattenSequenceDict"
 ]

 DEFAULT_POST_FIX = PostFix.meta()
@@ -1906,6 +1910,23 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
         return d


+class FlattenSequenced(MapTransform, ReduceTrait):
+    def __init__(
+        self,
+        keys: KeysCollection,
+        allow_missing_keys: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(keys, allow_missing_keys)
+        self.flatten_sequence = FlattenSequence(**kwargs)
+
+    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
+        d = dict(data)
+        for key in self.key_iterator(d):
+            d[key] = self.flatten_sequence(d[key])
+        return d
+
+
 RandImageFilterD = RandImageFilterDict = RandImageFilterd
 ImageFilterD = ImageFilterDict = ImageFilterd
 IdentityD = IdentityDict = Identityd
@@ -1949,3 +1970,4 @@
 AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
 FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
 ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
+FlattenSequenceD = FlattenSequenceDict = FlattenSequenced
```
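The dictionary version applies the same flattening to each selected key; a short sketch (the "image" key is illustrative, and the import assumes this commit is installed):

```python
import torch

from monai.transforms import FlattenSequenced

sample = {
    "image": [
        [torch.zeros(1, 32, 32), torch.zeros(1, 32, 32)],
        [torch.zeros(1, 32, 32), torch.zeros(1, 32, 32)],
    ]
}
flattened = FlattenSequenced(keys="image")(sample)
print(len(flattened["image"]))  # 4
```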

tests/transforms/compose/test_compose.py

Lines changed: 7 additions & 3 deletions
```diff
@@ -285,26 +285,30 @@ def test_backwards_compatible_imports(self):
     def test_list_extend_multi_sample_trait(self):
         center_crop = mt.CenterSpatialCrop([128, 128])
         multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)
+        flatten_sequence_transform = mt.FlattenSequence()

         img = torch.zeros([1, 512, 512])

         self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))
-        single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop])
+        single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop, flatten_sequence_transform])
         self.assertIsInstance(single_multi_sample_trait_result, list)
         self.assertEqual(len(single_multi_sample_trait_result), 1)
         self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

-        double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop])
+        double_multi_sample_trait_result = execute_compose(img, [
+            multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop
+        ])
         self.assertIsInstance(double_multi_sample_trait_result, list)
         self.assertEqual(len(double_multi_sample_trait_result), 1)
         self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

     def test_multi_sample_trait_cardinality(self):
         img = torch.zeros([1, 128, 128])
         t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)
+        flatten_sequence_transform = mt.FlattenSequence()

         # chaining should multiply counts: 2 x 2 = 4, flattened
-        res = execute_compose(img, [t2, t2])
+        res = execute_compose(img, [t2, t2, flatten_sequence_transform])
         self.assertIsInstance(res, list)
         self.assertEqual(len(res), 4)
         for r in res:
```
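Written as a regular `Compose` pipeline, the behaviour these tests exercise looks roughly like the following sketch; the sample counts assume `RandSpatialCropSamples` yields `num_samples` crops per input, and flattening now has to be requested explicitly:

```python
import torch

import monai.transforms as mt

pipeline = mt.Compose([
    mt.RandSpatialCropSamples([32, 32], num_samples=2),
    mt.RandSpatialCropSamples([16, 16], num_samples=2),
    mt.FlattenSequence(),  # without this, the result stays nested as [[..., ...], [..., ...]]
])
out = pipeline(torch.zeros(1, 128, 128))
print(len(out))  # 4 = 2 x 2 samples, flattened explicitly
```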
