2121import numpy as np
2222
2323import monai
24+ import monai .transforms as mt
25+ from monai .apps .utils import get_logger
2426from monai .transforms .inverse import InvertibleTransform
2527
2628# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
2729from monai .transforms .transform import ( # noqa: F401
30+ LazyTransform ,
2831 MapTransform ,
2932 Randomizable ,
3033 RandomizableTransform ,
3134 Transform ,
3235 apply_transform ,
3336)
34- from monai .utils import MAX_SEED , ensure_tuple , get_seed
35- from monai .utils .enums import TraceKeys
37+ from monai .utils import MAX_SEED , TraceKeys , ensure_tuple , get_seed
38+ from monai .utils .misc import to_tuple_of_dictionaries
3639
37- __all__ = ["Compose" , "OneOf" , "RandomOrder" ]
40+ logger = get_logger (__name__ )
41+
42+ __all__ = ["Compose" , "OneOf" , "RandomOrder" , "evaluate_with_overrides" ]
43+
44+
45+ def evaluate_with_overrides (
46+ data ,
47+ upcoming ,
48+ lazy_evaluation : bool | None = False ,
49+ overrides : dict | None = None ,
50+ override_keys : Sequence [str ] | None = None ,
51+ verbose : bool = False ,
52+ ):
53+ """
54+ The previously applied transform may have been lazily applied to MetaTensor `data` and
55+ made `data.has_pending_operations` equals to True. Given the upcoming transform ``upcoming``,
56+ this function determines whether `data.pending_operations` should be evaluated. If so, it will
57+ evaluate the lazily applied transforms.
58+
59+ Currently, the conditions for evaluation are:
60+
61+ - ``lazy_evaluation`` is ``True``, AND
62+ - the data is a ``MetaTensor`` and has pending operations, AND
63+ - the upcoming transform is an instance of ``Identity`` or ``IdentityD`` or ``None``.
64+
65+ The returned `data` will then be ready for the ``upcoming`` transform.
66+
67+ Args:
68+ data: data to be evaluated.
69+ upcoming: the upcoming transform.
70+ lazy_evaluation: whether to evaluate the pending operations.
71+ override: keyword arguments to apply transforms.
72+ override_keys: to which the override arguments are used when apply transforms.
73+ verbose: whether to print debugging info when evaluate MetaTensor with pending operations.
74+
75+ """
76+ if not lazy_evaluation :
77+ return data # eager evaluation
78+ overrides = (overrides or {}).copy ()
79+ if isinstance (data , monai .data .MetaTensor ):
80+ if data .has_pending_operations and ((isinstance (upcoming , (mt .Identityd , mt .Identity ))) or upcoming is None ):
81+ data , _ = mt .apply_transforms (data , None , overrides = overrides )
82+ if verbose :
83+ next_name = "final output" if upcoming is None else f"'{ upcoming .__class__ .__name__ } '"
84+ logger .info (f"Evaluated - '{ override_keys } ' - up-to-date for - { next_name } " )
85+ elif verbose :
86+ logger .info (
87+ f"Lazy - '{ override_keys } ' - upcoming: '{ upcoming .__class__ .__name__ } '"
88+ f"- pending { len (data .pending_operations )} "
89+ )
90+ return data
91+ override_keys = ensure_tuple (override_keys )
92+ if isinstance (data , dict ):
93+ if isinstance (upcoming , MapTransform ):
94+ applied_keys = {k for k in data if k in upcoming .keys }
95+ if not applied_keys :
96+ return data
97+ else :
98+ applied_keys = set (data .keys ())
99+
100+ keys_to_override = {k for k in applied_keys if k in override_keys }
101+ # generate a list of dictionaries with the appropriate override value per key
102+ dict_overrides = to_tuple_of_dictionaries (overrides , override_keys )
103+ for k in data :
104+ if k in keys_to_override :
105+ dict_for_key = dict_overrides [override_keys .index (k )]
106+ data [k ] = evaluate_with_overrides (data [k ], upcoming , lazy_evaluation , dict_for_key , k , verbose )
107+ else :
108+ data [k ] = evaluate_with_overrides (data [k ], upcoming , lazy_evaluation , None , k , verbose )
109+
110+ if isinstance (data , (list , tuple )):
111+ return [evaluate_with_overrides (v , upcoming , lazy_evaluation , overrides , override_keys , verbose ) for v in data ]
112+ return data
38113
39114
40115class Compose (Randomizable , InvertibleTransform ):
@@ -114,7 +189,21 @@ class Compose(Randomizable, InvertibleTransform):
114189 log_stats: whether to log the detailed information of data and applied transform when error happened,
115190 for NumPy array and PyTorch Tensor, log the data shape and value range,
116191 for other metadata, log the values directly. default to `False`.
117-
192+ lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
193+ carried out on a transform by transform basis. If True, all lazy transforms will
194+ be executed by accumulating changes and resampling as few times as possible.
195+ A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
196+ the pending operations and make the primary data up-to-date.
197+ overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
198+ when executing a pipeline. These each parameter that is compatible with a given transform is then applied
199+ to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
200+ is True. If lazy_evaluation is False they are ignored.
201+ currently supported args are:
202+ {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
203+ please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
204+ override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
205+ ``overrides`` is set, ``override_keys`` must also be set.
206+ verbose: whether to print debugging info when lazy_evaluation=True.
118207 """
119208
120209 def __init__ (
@@ -123,6 +212,10 @@ def __init__(
123212 map_items : bool = True ,
124213 unpack_items : bool = False ,
125214 log_stats : bool = False ,
215+ lazy_evaluation : bool | None = None ,
216+ overrides : dict | None = None ,
217+ override_keys : Sequence [str ] | None = None ,
218+ verbose : bool = False ,
126219 ) -> None :
127220 if transforms is None :
128221 transforms = []
@@ -132,6 +225,16 @@ def __init__(
132225 self .log_stats = log_stats
133226 self .set_random_state (seed = get_seed ())
134227
228+ self .lazy_evaluation = lazy_evaluation
229+ self .overrides = overrides
230+ self .override_keys = override_keys
231+ self .verbose = verbose
232+
233+ if self .lazy_evaluation is not None :
234+ for t in self .flatten ().transforms : # TODO: test Compose of Compose/OneOf
235+ if isinstance (t , LazyTransform ):
236+ t .lazy_evaluation = self .lazy_evaluation
237+
135238 def set_random_state (self , seed : int | None = None , state : np .random .RandomState | None = None ) -> Compose :
136239 super ().set_random_state (seed = seed , state = state )
137240 for _transform in self .transforms :
@@ -172,9 +275,26 @@ def __len__(self):
172275 """Return number of transformations."""
173276 return len (self .flatten ().transforms )
174277
278+ def evaluate_with_overrides (self , input_ , upcoming_xform ):
279+ """
280+ Args:
281+ input_: input data to be transformed.
282+ upcoming_xform: a transform used to determine whether to evaluate with override
283+ """
284+ return evaluate_with_overrides (
285+ input_ ,
286+ upcoming_xform ,
287+ lazy_evaluation = self .lazy_evaluation ,
288+ overrides = self .overrides ,
289+ override_keys = self .override_keys ,
290+ verbose = self .verbose ,
291+ )
292+
175293 def __call__ (self , input_ ):
176294 for _transform in self .transforms :
295+ input_ = self .evaluate_with_overrides (input_ , _transform )
177296 input_ = apply_transform (_transform , input_ , self .map_items , self .unpack_items , self .log_stats )
297+ input_ = self .evaluate_with_overrides (input_ , None )
178298 return input_
179299
180300 def inverse (self , data ):
@@ -204,7 +324,21 @@ class OneOf(Compose):
204324 log_stats: whether to log the detailed information of data and applied transform when error happened,
205325 for NumPy array and PyTorch Tensor, log the data shape and value range,
206326 for other metadata, log the values directly. default to `False`.
207-
327+ lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
328+ be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
329+ carried out on a transform by transform basis.
330+ A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
331+ the pending operations and make the primary data up-to-date.
332+ overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
333+ when executing a pipeline. These each parameter that is compatible with a given transform is then applied
334+ to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
335+ is True. If lazy_evaluation is False they are ignored.
336+ currently supported args are:
337+ {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
338+ please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
339+ override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
340+ ``overrides`` is set, ``override_keys`` must also be set.
341+ verbose: whether to print debugging info when lazy_evaluation=True.
208342 """
209343
210344 def __init__ (
@@ -214,8 +348,14 @@ def __init__(
214348 map_items : bool = True ,
215349 unpack_items : bool = False ,
216350 log_stats : bool = False ,
351+ lazy_evaluation : bool | None = None ,
352+ overrides : dict | None = None ,
353+ override_keys : Sequence [str ] | None = None ,
354+ verbose : bool = False ,
217355 ) -> None :
218- super ().__init__ (transforms , map_items , unpack_items , log_stats )
356+ super ().__init__ (
357+ transforms , map_items , unpack_items , log_stats , lazy_evaluation , overrides , override_keys , verbose
358+ )
219359 if len (self .transforms ) == 0 :
220360 weights = []
221361 elif weights is None or isinstance (weights , float ):
@@ -265,8 +405,8 @@ def __call__(self, data):
265405 self .push_transform (data , extra_info = {"index" : index })
266406 elif isinstance (data , Mapping ):
267407 for key in data : # dictionary not change size during iteration
268- if isinstance (data [key ], monai .data .MetaTensor ) or self . trace_key ( key ) in data :
269- self .push_transform (data , key , extra_info = {"index" : index })
408+ if isinstance (data [key ], monai .data .MetaTensor ):
409+ self .push_transform (data [ key ] , extra_info = {"index" : index })
270410 return data
271411
272412 def inverse (self , data ):
@@ -278,7 +418,7 @@ def inverse(self, data):
278418 index = self .pop_transform (data )[TraceKeys .EXTRA_INFO ]["index" ]
279419 elif isinstance (data , Mapping ):
280420 for key in data :
281- if isinstance (data [key ], monai .data .MetaTensor ) or self . trace_key ( key ) in data :
421+ if isinstance (data [key ], monai .data .MetaTensor ):
282422 index = self .pop_transform (data , key )[TraceKeys .EXTRA_INFO ]["index" ]
283423 else :
284424 raise RuntimeError (
@@ -306,7 +446,21 @@ class RandomOrder(Compose):
306446 log_stats: whether to log the detailed information of data and applied transform when error happened,
307447 for NumPy array and PyTorch Tensor, log the data shape and value range,
308448 for other metadata, log the values directly. default to `False`.
309-
449+ lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
450+ be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
451+ carried out on a transform by transform basis.
452+ A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
453+ the pending operations and make the primary data up-to-date.
454+ overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
455+ when executing a pipeline. These each parameter that is compatible with a given transform is then applied
456+ to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
457+ is True. If lazy_evaluation is False they are ignored.
458+ currently supported args are:
459+ {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
460+ please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
461+ override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
462+ ``overrides`` is set, ``override_keys`` must also be set.
463+ verbose: whether to print debugging info when lazy_evaluation=True.
310464 """
311465
312466 def __init__ (
@@ -315,8 +469,14 @@ def __init__(
315469 map_items : bool = True ,
316470 unpack_items : bool = False ,
317471 log_stats : bool = False ,
472+ lazy_evaluation : bool | None = None ,
473+ overrides : dict | None = None ,
474+ override_keys : Sequence [str ] | None = None ,
475+ verbose : bool = False ,
318476 ) -> None :
319- super ().__init__ (transforms , map_items , unpack_items , log_stats )
477+ super ().__init__ (
478+ transforms , map_items , unpack_items , log_stats , lazy_evaluation , overrides , override_keys , verbose
479+ )
320480
321481 def __call__ (self , input_ ):
322482 if len (self .transforms ) == 0 :
@@ -331,8 +491,8 @@ def __call__(self, input_):
331491 self .push_transform (input_ , extra_info = {"applied_order" : applied_order })
332492 elif isinstance (input_ , Mapping ):
333493 for key in input_ : # dictionary not change size during iteration
334- if isinstance (input_ [key ], monai .data .MetaTensor ) or self . trace_key ( key ) in input_ :
335- self .push_transform (input_ , key , extra_info = {"applied_order" : applied_order })
494+ if isinstance (input_ [key ], monai .data .MetaTensor ):
495+ self .push_transform (input_ [ key ] , extra_info = {"applied_order" : applied_order })
336496 return input_
337497
338498 def inverse (self , data ):
@@ -344,7 +504,7 @@ def inverse(self, data):
344504 applied_order = self .pop_transform (data )[TraceKeys .EXTRA_INFO ]["applied_order" ]
345505 elif isinstance (data , Mapping ):
346506 for key in data :
347- if isinstance (data [key ], monai .data .MetaTensor ) or self . trace_key ( key ) in data :
507+ if isinstance (data [key ], monai .data .MetaTensor ):
348508 applied_order = self .pop_transform (data , key )[TraceKeys .EXTRA_INFO ]["applied_order" ]
349509 else :
350510 raise RuntimeError (
@@ -356,5 +516,8 @@ def inverse(self, data):
356516
357517 # loop backwards over transforms
358518 for o in reversed (applied_order ):
359- data = apply_transform (self .transforms [o ].inverse , data , self .map_items , self .unpack_items , self .log_stats )
519+ if isinstance (self .transforms [o ], InvertibleTransform ):
520+ data = apply_transform (
521+ self .transforms [o ].inverse , data , self .map_items , self .unpack_items , self .log_stats
522+ )
360523 return data
0 commit comments