Fix Lazy=True ignored when using Dataset call (#6975)
Fixes #6974.

### Description
Change the default value of `lazy` in `apply_transform` from `False` to `None`, so that a `Compose` configured with `lazy=True` is no longer forced into eager (non-lazy) execution when it is invoked through `Dataset`.
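
As a minimal sketch of the scenario this addresses, mirroring the new test added to `tests/test_dataset.py` below (the `Flip` pipeline and toy array are illustrative, not taken from this commit):

```python
import numpy as np

from monai.data import Dataset
from monai.transforms import Compose, Flip

# A Compose that requests lazy execution of its lazy-capable transforms.
pipeline = Compose([Flip(spatial_axis=0)], lazy=True)

# Toy channel-first image; real data would come from e.g. LoadImaged.
data = np.zeros((1, 5, 5), dtype=np.float32)

# Per the linked issue, Dataset.__getitem__ previously reached apply_transform
# with its lazy=False default, overriding the Compose's lazy=True request.
ds = Dataset([data], transform=pipeline)
sample = ds[0]  # with lazy=None as the new default, the Compose's own setting is used
```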

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <yunl@nvidia.com>
KumoLiu authored Sep 13, 2023
1 parent 392c5c1 commit d4dc055
Showing 3 changed files with 45 additions and 2 deletions.
4 changes: 2 additions & 2 deletions monai/transforms/transform.py
@@ -104,7 +104,7 @@ def apply_transform(
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool | str = False,
- lazy: bool | None = False,
+ lazy: bool | None = None,
overrides: dict | None = None,
) -> list[ReturnType] | ReturnType:
"""
@@ -124,7 +124,7 @@ def apply_transform(
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
default logger name. Setting it to a string specifies the logger to which errors should be logged.
lazy: whether to execute in lazy mode or not. See the :ref:`Lazy Resampling topic<lazy_resampling> for more
- information about lazy resampling.
+ information about lazy resampling. Defaults to None.
overrides: optional overrides to apply to transform parameters. This parameter is ignored unless transforms
are being executed lazily. See the :ref:`Lazy Resampling topic<lazy_resampling> for more details and
examples of its usage.
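
For reference, a rough sketch of what the new default means at the `apply_transform` level, under my reading of MONAI's lazy resampling behaviour (the `Flip` usage and the pending-operation outcome are assumptions, not part of this commit): `lazy=None` defers to the lazy setting carried by the transform or `Compose` being executed, while an explicit `True`/`False` still overrides it.

```python
import numpy as np

from monai.transforms import Flip
from monai.transforms.transform import apply_transform

flip = Flip(spatial_axis=0)
flip.lazy = True  # the transform itself asks for lazy evaluation

img = np.zeros((1, 4, 4), dtype=np.float32)

# lazy=None (the new default): the transform's own lazy=True is respected,
# so the flip is expected to stay pending on the returned tensor.
pending = apply_transform(flip, img)

# An explicit lazy=False still forces eager execution, as before.
eager = apply_transform(flip, img, lazy=False)
```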
6 changes: 6 additions & 0 deletions tests/test_compose.py
@@ -607,6 +607,12 @@ def test_compose_with_logger(self, keys, pipeline):
"INFO - Pending transforms applied: applied_operations: 1\n"
),
],
[
mt.OneOf,
(mt.Flip(0),),
False,
("INFO - Apply pending transforms - lazy: False, pending: 0, " "upcoming 'Flip', transform.lazy: False\n"),
],
]


37 changes: 37 additions & 0 deletions tests/test_dataset.py
@@ -11,16 +11,20 @@

from __future__ import annotations

import logging
import os
import tempfile
import unittest
from copy import deepcopy
from io import StringIO

import nibabel as nib
import numpy as np
from parameterized import parameterized

from monai.data import Dataset
from monai.transforms import Compose, LoadImaged, SimulateDelayd
from tests.test_compose import TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES, data_from_keys

TEST_CASE_1 = [(128, 128, 128)]

@@ -89,6 +93,39 @@ def test_shape(self, expected_shape):
for d in data4_list:
self.assertTupleEqual(d["image"].shape, expected_shape)

def test_dataset_lazy_on_call(self):
data = np.zeros((1, 5, 5))
data[0, 0:2, 0:2] = 1


class TestDatsesetWithLazy(unittest.TestCase):
LOGGER_NAME = "a_logger_name"

def init_logger(self, name=LOGGER_NAME):
stream = StringIO()
handler = logging.StreamHandler(stream)
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
while len(logger.handlers) > 0:
logger.removeHandler(logger.handlers[-1])
logger.addHandler(handler)
return handler, stream

@parameterized.expand(TEST_COMPOSE_LAZY_ON_CALL_LOGGING_TEST_CASES)
def test_dataset_lazy_with_logging(self, compose_type, pipeline, lazy, expected):
handler, stream = self.init_logger(name=self.LOGGER_NAME)

data = data_from_keys(None, 12, 16)
c = compose_type(deepcopy(pipeline), log_stats=self.LOGGER_NAME, lazy=lazy)
ds = Dataset([data], transform=c)
ds[0]

handler.flush()
actual = stream.getvalue()
self.assertEqual(actual, expected)


if __name__ == "__main__":
unittest.main()
