Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e0cda55
added list extend to MultiSampleTrait
lukas-folle-snkeos Aug 8, 2025
1ad24af
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
35658f2
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
eeb7e12
fixed type errors
lukas-folle-snkeos Aug 8, 2025
c011103
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
6bb6110
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
e7a9185
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
a5d2261
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
b0dd089
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
7560a37
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
77c138d
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
7df8cb9
avoided breaking map_item functionality
lukas-folle-snkeos Aug 8, 2025
be46018
fixed wrong type annotation
lukas-folle-snkeos Aug 8, 2025
3aa1288
Merge branch 'dev' into dev
ericspod Aug 11, 2025
2d58774
added test for many multisample transforms; refactored code
lukas-folle-snkeos Sep 16, 2025
ee74761
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
2c18f36
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Sep 16, 2025
7ae8a26
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
1d04028
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
9377b63
added slight cleanup and additional test
lukas-folle-snkeos Sep 16, 2025
a56c0c3
Merge branch 'dev' into dev
lukas-folle-snkeos Sep 16, 2025
a8c9e24
Merge branch 'dev' into dev
lukas-folle-snkeos Oct 10, 2025
5135fb4
changed compose to explicit flattening
lukas-folle-snkeos Oct 10, 2025
fee6cd3
added documentation
lukas-folle-snkeos Oct 10, 2025
416584d
fixed doc build; fixed isort
lukas-folle-snkeos Oct 10, 2025
a8f3fe9
added type hints and fixed potential bug
lukas-folle-snkeos Oct 10, 2025
c707a2c
formatted
lukas-folle-snkeos Oct 10, 2025
2644f47
ignored mypy error
lukas-folle-snkeos Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ Generic Interfaces
.. autoclass:: MultiSampleTrait
:members:

`ReduceTrait`
^^^^^^^^^^^^^^^^^^
.. autoclass:: ReduceTrait
:members:

`Randomizable`
^^^^^^^^^^^^^^
.. autoclass:: Randomizable
Expand Down Expand Up @@ -1252,6 +1257,12 @@ Utility
:members:
:special-members: __call__

`FlattenSequence`
""""""""""""""""""""""""
.. autoclass:: FlattenSequence
:members:
:special-members: __call__

Dictionary Transforms
---------------------

Expand Down Expand Up @@ -2337,6 +2348,12 @@ Utility (Dict)
:members:
:special-members: __call__

`FlattenSequenced`
"""""""""""""""""""""""""
.. autoclass:: FlattenSequenced
:members:
:special-members: __call__


MetaTensor
^^^^^^^^^^
Expand Down
6 changes: 5 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@
ZoomDict,
)
from .spatial.functional import spatial_resample
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe
from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
from .utility.array import (
AddCoordinateChannels,
Expand All @@ -521,6 +521,7 @@
EnsureChannelFirst,
EnsureType,
FgBgToIndices,
FlattenSequence,
Identity,
ImageFilter,
IntensityStats,
Expand Down Expand Up @@ -593,6 +594,9 @@
FgBgToIndicesd,
FgBgToIndicesD,
FgBgToIndicesDict,
FlattenSequenced,
FlattenSequenceD,
FlattenSequenceDict,
FlattenSubKeysd,
FlattenSubKeysD,
FlattenSubKeysDict,
Expand Down
13 changes: 12 additions & 1 deletion monai/transforms/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe"]
__all__ = ["LazyTrait", "InvertibleTrait", "RandomizableTrait", "MultiSampleTrait", "ThreadUnsafe", "ReduceTrait"]

from typing import Any

Expand Down Expand Up @@ -99,3 +99,14 @@ class ThreadUnsafe:
"""

pass


class ReduceTrait:
"""
An interface to indicate that the transform has the capability to reduce multiple samples
into a single sample.
This interface can be extended from by people adapting transforms to the MONAI framework as well
as by implementors of MONAI transforms.
"""

pass
7 changes: 3 additions & 4 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai import config, transforms
from monai.config import KeysCollection
from monai.data.meta_tensor import MetaTensor
from monai.transforms.traits import LazyTrait, RandomizableTrait, ThreadUnsafe
from monai.transforms.traits import LazyTrait, RandomizableTrait, ReduceTrait, ThreadUnsafe
from monai.utils import MAX_SEED, ensure_tuple, first
from monai.utils.enums import TransformBackends
from monai.utils.misc import MONAIEnvVars
Expand Down Expand Up @@ -142,7 +142,7 @@ def apply_transform(
"""
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
if isinstance(data, (list, tuple)) and map_items_ > 0 and not isinstance(transform, ReduceTrait):
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
Expand Down Expand Up @@ -482,8 +482,7 @@ def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable
yield (key,) + tuple(_ex_iters) if extra_iterables else key
elif not self.allow_missing_keys:
raise KeyError(
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
" and allow_missing_keys==False."
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False."
)

def first_key(self, data: dict[Hashable, Any]):
Expand Down
39 changes: 38 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
median_filter,
)
from monai.transforms.inverse import InvertibleTransform, TraceableTransform
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.traits import MultiSampleTrait, ReduceTrait
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
from monai.transforms.utils import (
apply_affine_to_points,
Expand Down Expand Up @@ -110,6 +110,7 @@
"ImageFilter",
"RandImageFilter",
"ApplyTransformToPoints",
"FlattenSequence",
]


Expand Down Expand Up @@ -1950,3 +1951,39 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])

return data


class FlattenSequence(Transform, ReduceTrait):
"""
Flatten a nested sequence (list or tuple) by one level.
If the input is a sequence of sequences, it will flatten them into a single sequence.
Non-nested sequences and other data types are returned unchanged.

For example:

.. code-block:: python

flatten = FlattenSequence()
data = [[1, 2], [3, 4], [5, 6]]
print(flatten(data))
[1, 2, 3, 4, 5, 6]

"""

def __init__(self):
super().__init__()

def __call__(self, data: list | tuple | Any) -> list | tuple | Any:
"""
Flatten a nested sequence by one level.
Args:
data: Input data, can be a nested sequence.
Returns:
Flattened list if input is a nested sequence, otherwise returns data unchanged.
"""
if isinstance(data, (list, tuple)):
if len(data) == 0:
return data
if all(isinstance(item, (list, tuple)) for item in data):
return [item for sublist in data for item in sublist]
return data
29 changes: 28 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from monai.data.meta_tensor import MetaObj, MetaTensor
from monai.data.utils import no_collation
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.traits import MultiSampleTrait, RandomizableTrait
from monai.transforms.traits import MultiSampleTrait, RandomizableTrait, ReduceTrait
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
from monai.transforms.utility.array import (
AddCoordinateChannels,
Expand All @@ -45,6 +45,7 @@
EnsureChannelFirst,
EnsureType,
FgBgToIndices,
FlattenSequence,
Identity,
ImageFilter,
IntensityStats,
Expand Down Expand Up @@ -191,6 +192,9 @@
"ApplyTransformToPointsd",
"ApplyTransformToPointsD",
"ApplyTransformToPointsDict",
"FlattenSequenced",
"FlattenSequenceD",
"FlattenSequenceDict",
]

DEFAULT_POST_FIX = PostFix.meta()
Expand Down Expand Up @@ -1906,6 +1910,28 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
return d


class FlattenSequenced(MapTransform, ReduceTrait):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.FlattenSequence`.

Args:
keys: keys of the corresponding items to be transformed.
See also: monai.transforms.MapTransform
allow_missing_keys:
Don't raise exception if key is missing.
"""

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False, **kwargs) -> None:
super().__init__(keys, allow_missing_keys)
self.flatten_sequence = FlattenSequence(**kwargs)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
d = dict(data)
for key in self.key_iterator(d):
d[key] = self.flatten_sequence(d[key]) # type: ignore[assignment]
return d


RandImageFilterD = RandImageFilterDict = RandImageFilterd
ImageFilterD = ImageFilterDict = ImageFilterd
IdentityD = IdentityDict = Identityd
Expand Down Expand Up @@ -1949,3 +1975,4 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch
AddCoordinateChannelsD = AddCoordinateChannelsDict = AddCoordinateChannelsd
FlattenSubKeysD = FlattenSubKeysDict = FlattenSubKeysd
ApplyTransformToPointsD = ApplyTransformToPointsDict = ApplyTransformToPointsd
FlattenSequenceD = FlattenSequenceDict = FlattenSequenced
34 changes: 34 additions & 0 deletions tests/transforms/compose/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,40 @@ def test_flatten_and_len(self):
def test_backwards_compatible_imports(self):
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401

def test_list_extend_multi_sample_trait(self):
center_crop = mt.CenterSpatialCrop([128, 128])
multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)
flatten_sequence_transform = mt.FlattenSequence()

img = torch.zeros([1, 512, 512])

self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))
single_multi_sample_trait_result = execute_compose(
img, [multi_sample_transform, center_crop, flatten_sequence_transform]
)
self.assertIsInstance(single_multi_sample_trait_result, list)
self.assertEqual(len(single_multi_sample_trait_result), 1)
self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

double_multi_sample_trait_result = execute_compose(
img, [multi_sample_transform, multi_sample_transform, flatten_sequence_transform, center_crop]
)
self.assertIsInstance(double_multi_sample_trait_result, list)
self.assertEqual(len(double_multi_sample_trait_result), 1)
self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))

def test_multi_sample_trait_cardinality(self):
img = torch.zeros([1, 128, 128])
t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)
flatten_sequence_transform = mt.FlattenSequence()

# chaining should multiply counts: 2 x 2 = 4, flattened
res = execute_compose(img, [t2, t2, flatten_sequence_transform])
self.assertIsInstance(res, list)
self.assertEqual(len(res), 4)
for r in res:
self.assertEqual(r.shape, torch.Size([1, 32, 32]))


TEST_COMPOSE_EXECUTE_TEST_CASES = [
[None, tuple()],
Expand Down
Loading