Commit b87375f

Authored by wyli, atbenmurray, binliunls, yiheng-wang-nv, KumoLiu

4855 lazy resampling impl -- Compose (#5860)
Part of #4855; upgrades #4911 to use the latest dev API.

### Description

Example usage for a sequence of spatial transforms:

```py
xforms = [
    mt.LoadImageD(keys, ensure_channel_first=True),
    mt.Orientationd(keys, "RAS"),
    mt.SpacingD(keys, (1.5, 1.5, 1.5)),
    mt.CenterScaleCropD(keys, roi_scale=0.9),
    # mt.CropForegroundD(keys, source_key="seg", k_divisible=5),
    mt.RandRotateD(keys, prob=1.0, range_y=np.pi / 2, range_x=np.pi / 3),
    mt.RandSpatialCropD(keys, roi_size=(76, 87, 73)),
    mt.RandScaleCropD(keys, roi_scale=0.9),
    mt.Resized(keys, (30, 40, 60)),
    # mt.NormalizeIntensityd(keys),
    mt.ZoomD(keys, 1.3, keep_size=False),
    mt.FlipD(keys),
    mt.Rotate90D(keys),
    mt.RandAffined(keys),
    mt.ResizeWithPadOrCropd(keys, spatial_size=(32, 43, 54)),
    mt.DivisiblePadD(keys, k=3),
]
lazy_kwargs = dict(mode=("bilinear", 0), padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8))
xform = mt.Compose(xforms, lazy_evaluation=True, overrides=lazy_kwargs, override_keys=keys)
xform.set_random_state(0)
```

`lazy_evaluation=True` preserves more details:

![Screenshot 2023-01-17 at 00 31 40](https://user-images.githubusercontent.com/831580/212784981-ea39833b-54ab-42fb-bc03-38b012281857.png)

compared with the regular compose:

![Screenshot 2023-01-17 at 00 31 43](https://user-images.githubusercontent.com/831580/212785016-ba3be8ff-f17f-47b4-8025-cd351a637a82.png)

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Signed-off-by: Yiheng Wang <vennw@nvidia.com>
Signed-off-by: KumoLiu <yunl@nvidia.com>
Signed-off-by: Ben Murray <ben.murray@gmail.com>
Co-authored-by: Ben Murray <ben.murray@gmail.com>
Co-authored-by: binliu <binliu@nvidia.com>
Co-authored-by: Yiheng Wang <68361391+yiheng-wang-nv@users.noreply.github.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: Yiheng Wang <vennw@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: KumoLiu <yunl@nvidia.com>
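To make the comparison above concrete, here is a minimal sketch (not part of the commit) that runs the same spatial transforms eagerly and lazily with a shared random seed; the key names, synthetic in-memory tensors, and the reduced transform list are assumptions for illustration only.

```py
import numpy as np
import torch
import monai.transforms as mt
from monai.data import MetaTensor

keys = ("image", "seg")  # assumed key names, stand-ins for a loaded image/label pair
sample = {
    "image": MetaTensor(torch.rand(1, 64, 64, 64)),          # synthetic image
    "seg": MetaTensor((torch.rand(1, 64, 64, 64) > 0.5).float()),  # synthetic label
}
lazy_kwargs = dict(mode=("bilinear", 0), padding_mode=("border", "nearest"), dtype=(torch.float32, torch.uint8))

def make_xforms():
    # separate instances per pipeline, so the lazy Compose does not flip the
    # lazy_evaluation flag on transforms shared with the eager pipeline
    return [
        mt.Orientationd(keys, "RAS"),
        mt.Spacingd(keys, pixdim=(1.5, 1.5, 1.5)),
        mt.RandRotated(keys, prob=1.0, range_x=np.pi / 3),
        mt.Resized(keys, spatial_size=(30, 40, 60)),
    ]

eager = mt.Compose(make_xforms())
lazy = mt.Compose(make_xforms(), lazy_evaluation=True, overrides=lazy_kwargs, override_keys=keys)
eager.set_random_state(0)
lazy.set_random_state(0)

out_eager = eager(dict(sample))
out_lazy = lazy(dict(sample))  # lazy path accumulates the spatial ops and resamples once
```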
1 parent: 1cd0d7b, commit: b87375f

File tree: 16 files changed (+578, -102 lines)


docs/source/transforms.rst (6 additions, 0 deletions)

```diff
@@ -2206,3 +2206,9 @@ Utilities

 .. automodule:: monai.transforms.utils_pytorch_numpy_unification
   :members:
+
+Lazy
+----
+.. automodule:: monai.transforms.lazy
+  :members:
+  :imported-members:
```


monai/data/dataset.py (12 additions, 0 deletions)

```diff
@@ -322,7 +322,9 @@ def _pre_transform(self, item_transformed):
                 break
             # this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
             item_transformed = apply_transform(_xform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         if self.reset_ops_id:
             reset_ops_id(item_transformed)
         return item_transformed
@@ -348,7 +350,9 @@ def _post_transform(self, item_transformed):
                 or not isinstance(_transform, Transform)
             ):
                 start_post_randomize_run = True
+                item_transformed = self.transform.evaluate_with_overrides(item_transformed, _transform)
                 item_transformed = apply_transform(_transform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         return item_transformed

     def _cachecheck(self, item_transformed):
@@ -496,7 +500,9 @@ def _pre_transform(self, item_transformed):
             if i == self.cache_n_trans:
                 break
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item_transformed = self.transform.evaluate_with_overrides(item_transformed, _xform)
             item_transformed = apply_transform(_xform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         reset_ops_id(item_transformed)
         return item_transformed

@@ -514,7 +520,9 @@ def _post_transform(self, item_transformed):
             raise ValueError("transform must be an instance of monai.transforms.Compose.")
         for i, _transform in enumerate(self.transform.transforms):
             if i >= self.cache_n_trans:
+                item_transformed = self.transform.evaluate_with_overrides(item_transformed, item_transformed)
                 item_transformed = apply_transform(_transform, item_transformed)
+        item_transformed = self.transform.evaluate_with_overrides(item_transformed, None)
         return item_transformed


@@ -884,7 +892,9 @@ def _load_cache_item(self, idx: int):
             if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
                 break
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
+            item = self.transform.evaluate_with_overrides(item, _xform)
             item = apply_transform(_xform, item)
+        item = self.transform.evaluate_with_overrides(item, None)
         if self.as_contiguous:
             item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
         return item
@@ -921,7 +931,9 @@ def _transform(self, index: int):
                     start_run = True
                     if self.copy_cache:
                         data = deepcopy(data)
+                data = self.transform.evaluate_with_overrides(data, _transform)
                 data = apply_transform(_transform, data)
+        data = self.transform.evaluate_with_overrides(data, None)
         return data


```

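These hooks only have an effect when the dataset's transform is a `Compose` created with `lazy_evaluation=True`; otherwise `evaluate_with_overrides` returns the data unchanged. A hedged usage sketch follows (the synthetic in-memory samples, key name, and transform choices are assumptions, not from the commit).

```py
import torch
import monai.transforms as mt
from monai.data import CacheDataset, MetaTensor

keys = ("image",)  # assumed key name
datalist = [{"image": MetaTensor(torch.rand(1, 16, 16, 16))} for _ in range(4)]  # synthetic samples

# Deterministic transforms run at caching time and random ones at __getitem__ time;
# in both phases the dataset flushes pending lazy ops via evaluate_with_overrides.
xform = mt.Compose(
    [mt.Spacingd(keys, pixdim=(1.5, 1.5, 1.5)), mt.RandFlipd(keys, prob=0.5)],
    lazy_evaluation=True,
)
ds = CacheDataset(data=datalist, transform=xform, cache_rate=1.0)
print(ds[0]["image"].shape)
```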

monai/data/meta_obj.py (9 additions, 0 deletions)

```diff
@@ -214,6 +214,15 @@ def pending_operations(self) -> list[dict]:
             return self._pending_operations
         return MetaObj.get_default_applied_operations()  # the same default as applied_ops

+    @property
+    def has_pending_operations(self) -> bool:
+        """
+        Determine whether there are pending operations.
+        Returns:
+            True if there are pending operations; False if not
+        """
+        return self.pending_operations is not None and len(self.pending_operations) > 0
+
     def push_pending_operation(self, t: Any) -> None:
         self._pending_operations.append(t)

```

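A short, hedged sketch of the new property. Applying a lazy `Flip` directly (outside `Compose`) is an assumption made for illustration; it relies on `Flip` already supporting lazy evaluation on this dev branch.

```py
import torch
import monai.transforms as mt
from monai.data import MetaTensor

img = MetaTensor(torch.zeros(1, 8, 8, 8))
print(img.has_pending_operations)  # False: nothing has been lazily applied yet

flip = mt.Flip(spatial_axis=0)
flip.lazy_evaluation = True  # assumed: Flip is a LazyTransform on this dev branch
lazily_flipped = flip(img)
print(lazily_flipped.has_pending_operations)  # expected True: the flip has not been resampled yet
```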

monai/data/meta_tensor.py (1 addition, 1 deletion)

```diff
@@ -492,7 +492,7 @@ def peek_pending_affine(self):
                 continue
             res = convert_to_dst_type(res, next_matrix)[0]
             next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
-            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
+            res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)  # type: ignore
         return res

     def peek_pending_rank(self):
```

monai/transforms/compose.py (178 additions, 15 deletions)

```diff
@@ -21,20 +21,95 @@
 import numpy as np

 import monai
+import monai.transforms as mt
+from monai.apps.utils import get_logger
 from monai.transforms.inverse import InvertibleTransform

 # For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
 from monai.transforms.transform import (  # noqa: F401
+    LazyTransform,
     MapTransform,
     Randomizable,
     RandomizableTransform,
     Transform,
     apply_transform,
 )
-from monai.utils import MAX_SEED, ensure_tuple, get_seed
-from monai.utils.enums import TraceKeys
+from monai.utils import MAX_SEED, TraceKeys, ensure_tuple, get_seed
+from monai.utils.misc import to_tuple_of_dictionaries

-__all__ = ["Compose", "OneOf", "RandomOrder"]
+logger = get_logger(__name__)
+
+__all__ = ["Compose", "OneOf", "RandomOrder", "evaluate_with_overrides"]
+
+
+def evaluate_with_overrides(
+    data,
+    upcoming,
+    lazy_evaluation: bool | None = False,
+    overrides: dict | None = None,
+    override_keys: Sequence[str] | None = None,
+    verbose: bool = False,
+):
+    """
+    The previously applied transform may have been lazily applied to MetaTensor `data` and
+    made `data.has_pending_operations` equals to True. Given the upcoming transform ``upcoming``,
+    this function determines whether `data.pending_operations` should be evaluated. If so, it will
+    evaluate the lazily applied transforms.
+
+    Currently, the conditions for evaluation are:
+
+    - ``lazy_evaluation`` is ``True``, AND
+    - the data is a ``MetaTensor`` and has pending operations, AND
+    - the upcoming transform is an instance of ``Identity`` or ``IdentityD`` or ``None``.
+
+    The returned `data` will then be ready for the ``upcoming`` transform.
+
+    Args:
+        data: data to be evaluated.
+        upcoming: the upcoming transform.
+        lazy_evaluation: whether to evaluate the pending operations.
+        override: keyword arguments to apply transforms.
+        override_keys: to which the override arguments are used when apply transforms.
+        verbose: whether to print debugging info when evaluate MetaTensor with pending operations.
+
+    """
+    if not lazy_evaluation:
+        return data  # eager evaluation
+    overrides = (overrides or {}).copy()
+    if isinstance(data, monai.data.MetaTensor):
+        if data.has_pending_operations and ((isinstance(upcoming, (mt.Identityd, mt.Identity))) or upcoming is None):
+            data, _ = mt.apply_transforms(data, None, overrides=overrides)
+            if verbose:
+                next_name = "final output" if upcoming is None else f"'{upcoming.__class__.__name__}'"
+                logger.info(f"Evaluated - '{override_keys}' - up-to-date for - {next_name}")
+        elif verbose:
+            logger.info(
+                f"Lazy - '{override_keys}' - upcoming: '{upcoming.__class__.__name__}'"
+                f"- pending {len(data.pending_operations)}"
+            )
+        return data
+    override_keys = ensure_tuple(override_keys)
+    if isinstance(data, dict):
+        if isinstance(upcoming, MapTransform):
+            applied_keys = {k for k in data if k in upcoming.keys}
+            if not applied_keys:
+                return data
+        else:
+            applied_keys = set(data.keys())
+
+        keys_to_override = {k for k in applied_keys if k in override_keys}
+        # generate a list of dictionaries with the appropriate override value per key
+        dict_overrides = to_tuple_of_dictionaries(overrides, override_keys)
+        for k in data:
+            if k in keys_to_override:
+                dict_for_key = dict_overrides[override_keys.index(k)]
+                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, dict_for_key, k, verbose)
+            else:
+                data[k] = evaluate_with_overrides(data[k], upcoming, lazy_evaluation, None, k, verbose)
+
+    if isinstance(data, (list, tuple)):
+        return [evaluate_with_overrides(v, upcoming, lazy_evaluation, overrides, override_keys, verbose) for v in data]
+    return data


 class Compose(Randomizable, InvertibleTransform):
```

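To make the evaluation conditions above concrete, a minimal sketch (not from the commit) of using `Identityd` as an explicit checkpoint in a lazily evaluated pipeline; the key name and synthetic sample are assumptions for illustration only.

```py
import torch
import monai.transforms as mt
from monai.data import MetaTensor

keys = ("image",)  # assumed key name
sample = {"image": MetaTensor(torch.rand(1, 16, 16, 16))}  # synthetic data

pipeline = mt.Compose(
    [
        mt.Spacingd(keys, pixdim=(1.5, 1.5, 1.5)),
        mt.Flipd(keys, spatial_axis=0),
        mt.Identityd(keys),  # checkpoint: pending lazy ops are evaluated before this transform
        mt.NormalizeIntensityd(keys),
    ],
    lazy_evaluation=True,
    overrides={"mode": "bilinear", "padding_mode": "border"},
    override_keys=keys,
)
out = pipeline(sample)
# after the final flush (upcoming=None) the output should carry no pending operations
print(out["image"].shape, out["image"].pending_operations)
```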
```diff
@@ -114,7 +189,21 @@ class Compose(Randomizable, InvertibleTransform):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If False, transforms will be
+            carried out on a transform by transform basis. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. These each parameter that is compatible with a given transform is then applied
+            to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
+            is True. If lazy_evaluation is False they are ignored.
+            currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
```

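The following is a small, hedged sketch (not part of the diff) of how per-key overrides appear to be distributed, based on the `to_tuple_of_dictionaries` handling in `evaluate_with_overrides` above; the key names and transforms are illustrative assumptions.

```py
# Tuple-valued overrides appear to be split positionally across override_keys, so here
# "image" would get mode="bilinear"/dtype=torch.float32 and "seg" mode=0/dtype=torch.uint8.
import torch
import monai.transforms as mt

keys = ("image", "seg")  # assumed key names
lazy_kwargs = dict(mode=("bilinear", 0), dtype=(torch.float32, torch.uint8))
pipeline = mt.Compose(
    [mt.Orientationd(keys, "RAS"), mt.Spacingd(keys, pixdim=(1.0, 1.0, 1.0)), mt.Flipd(keys)],
    lazy_evaluation=True,
    overrides=lazy_kwargs,
    override_keys=keys,
)
```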
```diff
@@ -123,6 +212,10 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
         if transforms is None:
             transforms = []
@@ -132,6 +225,16 @@ def __init__(
         self.log_stats = log_stats
         self.set_random_state(seed=get_seed())

+        self.lazy_evaluation = lazy_evaluation
+        self.overrides = overrides
+        self.override_keys = override_keys
+        self.verbose = verbose
+
+        if self.lazy_evaluation is not None:
+            for t in self.flatten().transforms:  # TODO: test Compose of Compose/OneOf
+                if isinstance(t, LazyTransform):
+                    t.lazy_evaluation = self.lazy_evaluation
+
     def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose:
         super().set_random_state(seed=seed, state=state)
         for _transform in self.transforms:
```

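A rough illustration of the flag propagation in `__init__` above (a sketch under stated assumptions, not from the commit): only members that are `LazyTransform` instances receive the flag, other transforms are left untouched.

```py
import monai.transforms as mt

keys = ("image",)  # assumed key name
pipe = mt.Compose([mt.Flipd(keys), mt.NormalizeIntensityd(keys)], lazy_evaluation=True)
# Only LazyTransform members get lazy_evaluation set; NormalizeIntensityd has no such attribute.
print([getattr(t, "lazy_evaluation", "n/a") for t in pipe.flatten().transforms])
```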
```diff
@@ -172,9 +275,26 @@ def __len__(self):
         """Return number of transformations."""
         return len(self.flatten().transforms)

+    def evaluate_with_overrides(self, input_, upcoming_xform):
+        """
+        Args:
+            input_: input data to be transformed.
+            upcoming_xform: a transform used to determine whether to evaluate with override
+        """
+        return evaluate_with_overrides(
+            input_,
+            upcoming_xform,
+            lazy_evaluation=self.lazy_evaluation,
+            overrides=self.overrides,
+            override_keys=self.override_keys,
+            verbose=self.verbose,
+        )
+
     def __call__(self, input_):
         for _transform in self.transforms:
+            input_ = self.evaluate_with_overrides(input_, _transform)
             input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
+        input_ = self.evaluate_with_overrides(input_, None)
         return input_

     def inverse(self, data):
```

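For reference, a minimal sketch of the call pattern that `__call__` (and the dataset classes earlier in this commit) follow; the sample data and key name are assumptions made up for illustration.

```py
import torch
import monai.transforms as mt
from monai.data import MetaTensor

keys = ("image",)  # assumed key name
data = {"image": MetaTensor(torch.rand(1, 16, 16, 16))}  # synthetic sample
pipeline = mt.Compose([mt.Spacingd(keys, pixdim=(2.0, 2.0, 2.0)), mt.Flipd(keys)], lazy_evaluation=True)

# The hook is called before each transform and once more at the end (upcoming=None),
# so any remaining pending lazy operations are resampled exactly once.
for t in pipeline.transforms:
    data = pipeline.evaluate_with_overrides(data, t)
    data = mt.apply_transform(t, data)
data = pipeline.evaluate_with_overrides(data, None)
```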
```diff
@@ -204,7 +324,21 @@ class OneOf(Compose):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
+            carried out on a transform by transform basis.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. These each parameter that is compatible with a given transform is then applied
+            to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
+            is True. If lazy_evaluation is False they are ignored.
+            currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
@@ -214,8 +348,14 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
-        super().__init__(transforms, map_items, unpack_items, log_stats)
+        super().__init__(
+            transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
+        )
         if len(self.transforms) == 0:
             weights = []
         elif weights is None or isinstance(weights, float):
@@ -265,8 +405,8 @@ def __call__(self, data):
             self.push_transform(data, extra_info={"index": index})
         elif isinstance(data, Mapping):
             for key in data:  # dictionary not change size during iteration
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
-                    self.push_transform(data, key, extra_info={"index": index})
+                if isinstance(data[key], monai.data.MetaTensor):
+                    self.push_transform(data[key], extra_info={"index": index})
         return data

     def inverse(self, data):
@@ -278,7 +418,7 @@ def inverse(self, data):
             index = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["index"]
         elif isinstance(data, Mapping):
             for key in data:
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                if isinstance(data[key], monai.data.MetaTensor):
                     index = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
         else:
             raise RuntimeError(
@@ -306,7 +446,21 @@ class RandomOrder(Compose):
         log_stats: whether to log the detailed information of data and applied transform when error happened,
             for NumPy array and PyTorch Tensor, log the data shape and value range,
             for other metadata, log the values directly. default to `False`.
-
+        lazy_evaluation: whether to enable lazy evaluation for lazy transforms. If True, all lazy transforms will
+            be executed by accumulating changes and resampling as few times as possible. If False, transforms will be
+            carried out on a transform by transform basis.
+            A `monai.transforms.Identity[D]` transform in the pipeline will trigger the evaluation of
+            the pending operations and make the primary data up-to-date.
+        overrides: this optional parameter allows you to specify a dictionary of parameters that should be overridden
+            when executing a pipeline. These each parameter that is compatible with a given transform is then applied
+            to that transform before it is executed. Note that overrides are currently only applied when lazy_evaluation
+            is True. If lazy_evaluation is False they are ignored.
+            currently supported args are:
+            {``"mode"``, ``"padding_mode"``, ``"dtype"``, ``"align_corners"``, ``"resample_mode"``, ``device``},
+            please see also :py:func:`monai.transforms.lazy.apply_transforms` for more details.
+        override_keys: this optional parameter specifies the keys to which ``overrides`` are to be applied. If
+            ``overrides`` is set, ``override_keys`` must also be set.
+        verbose: whether to print debugging info when lazy_evaluation=True.
     """

     def __init__(
@@ -315,8 +469,14 @@ def __init__(
         map_items: bool = True,
         unpack_items: bool = False,
         log_stats: bool = False,
+        lazy_evaluation: bool | None = None,
+        overrides: dict | None = None,
+        override_keys: Sequence[str] | None = None,
+        verbose: bool = False,
     ) -> None:
-        super().__init__(transforms, map_items, unpack_items, log_stats)
+        super().__init__(
+            transforms, map_items, unpack_items, log_stats, lazy_evaluation, overrides, override_keys, verbose
+        )

     def __call__(self, input_):
         if len(self.transforms) == 0:
@@ -331,8 +491,8 @@ def __call__(self, input_):
             self.push_transform(input_, extra_info={"applied_order": applied_order})
         elif isinstance(input_, Mapping):
             for key in input_:  # dictionary not change size during iteration
-                if isinstance(input_[key], monai.data.MetaTensor) or self.trace_key(key) in input_:
-                    self.push_transform(input_, key, extra_info={"applied_order": applied_order})
+                if isinstance(input_[key], monai.data.MetaTensor):
+                    self.push_transform(input_[key], extra_info={"applied_order": applied_order})
         return input_

     def inverse(self, data):
@@ -344,7 +504,7 @@ def inverse(self, data):
             applied_order = self.pop_transform(data)[TraceKeys.EXTRA_INFO]["applied_order"]
         elif isinstance(data, Mapping):
             for key in data:
-                if isinstance(data[key], monai.data.MetaTensor) or self.trace_key(key) in data:
+                if isinstance(data[key], monai.data.MetaTensor):
                     applied_order = self.pop_transform(data, key)[TraceKeys.EXTRA_INFO]["applied_order"]
         else:
             raise RuntimeError(
@@ -356,5 +516,8 @@ def inverse(self, data):

         # loop backwards over transforms
         for o in reversed(applied_order):
-            data = apply_transform(self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats)
+            if isinstance(self.transforms[o], InvertibleTransform):
+                data = apply_transform(
+                    self.transforms[o].inverse, data, self.map_items, self.unpack_items, self.log_stats
+                )
         return data
```

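A hedged illustration of the guard in the last hunk: inverting a `RandomOrder` that mixes invertible and non-invertible members should now simply skip the non-invertible ones instead of failing; the transform choices and synthetic data are assumptions, not from the commit.

```py
import torch
import monai.transforms as mt
from monai.data import MetaTensor

keys = ("image",)  # assumed key name
sample = {"image": MetaTensor(torch.rand(1, 16, 16, 16))}  # synthetic data
ro = mt.RandomOrder([mt.Flipd(keys), mt.RandGaussianNoised(keys, prob=1.0)])
out = ro(sample)
# RandGaussianNoised is assumed not to be an InvertibleTransform, so only Flipd is undone here
restored = ro.inverse(out)
```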
