From cc53588bc2e8e50f1659c1288d5906b458e87e47 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Feb 2021 23:54:52 +0800 Subject: [PATCH 1/7] [DLMED] fix distirbuted data parallel issue in ClassificationSaver Signed-off-by: Nic Ma --- monai/handlers/__init__.py | 8 ++- monai/handlers/classification_saver.py | 22 +++++-- monai/handlers/metrics_saver.py | 14 +---- monai/handlers/utils.py | 28 ++++++++- .../test_handler_classification_saver_dist.py | 59 +++++++++++++++++++ 5 files changed, 113 insertions(+), 18 deletions(-) create mode 100644 tests/test_handler_classification_saver_dist.py diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 6b190518fb..1224721465 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -25,5 +25,11 @@ from .stats_handler import StatsHandler from .surface_distance import SurfaceDistance from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler -from .utils import evenly_divisible_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, write_metrics_reports +from .utils import ( + evenly_divisible_all_gather, + string_list_all_gather, + stopping_fn_from_loss, + stopping_fn_from_metric, + write_metrics_reports, +) from .validation_handler import ValidationHandler diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 6753cafcb0..d7a9f2de78 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -13,8 +13,10 @@ from typing import TYPE_CHECKING, Callable, Optional from monai.data import CSVSaver +from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import exact_version, optional_import +idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") if TYPE_CHECKING: from ignite.engine import Engine @@ -25,6 +27,8 @@ class ClassificationSaver: """ Event handler triggered on completing every iteration to save the classification predictions as CSV file. + If running in distributed data parallel, only saves CSV file in the specified rank. + """ def __init__( @@ -35,6 +39,7 @@ def __init__( batch_transform: Callable = lambda x: x, output_transform: Callable = lambda x: x, name: Optional[str] = None, + save_rank: int = 0, ) -> None: """ Args: @@ -49,8 +54,11 @@ def __init__( The first dimension of this transform's output will be treated as the batch dimension. Each item in the batch will be saved individually. name: identifier of logging.logger to use, defaulting to `engine.logger`. + save_rank: only the handler on specified rank will save to CSV file in multi-gpus validation, + default to 0. """ + self._expected_rank: bool = idist.get_rank() == save_rank self.saver = CSVSaver(output_dir, filename, overwrite) self.batch_transform = batch_transform self.output_transform = output_transform @@ -67,7 +75,7 @@ def attach(self, engine: Engine) -> None: self.logger = engine.logger if not engine.has_event_handler(self, Events.ITERATION_COMPLETED): engine.add_event_handler(Events.ITERATION_COMPLETED, self) - if not engine.has_event_handler(self.saver.finalize, Events.COMPLETED): + if self._expected_rank and not engine.has_event_handler(self.saver.finalize, Events.COMPLETED): engine.add_event_handler(Events.COMPLETED, lambda engine: self.saver.finalize()) def __call__(self, engine: Engine) -> None: @@ -77,6 +85,12 @@ def __call__(self, engine: Engine) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ - meta_data = self.batch_transform(engine.state.batch) - engine_output = self.output_transform(engine.state.output) - self.saver.save_batch(engine_output, meta_data) + _meta_data = self.batch_transform(engine.state.batch) + if "filename_or_obj" in _meta_data: + # all gather filenames across ranks + _meta_data["filename_or_obj"] = string_list_all_gather(_meta_data["filename_or_obj"]) + # all gather predictions across ranks + _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output)) + + if self._expected_rank: + self.saver.save_batch(_engine_output, _meta_data) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index f9deea35df..17e98a8397 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -11,9 +11,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union -from monai.handlers.utils import write_metrics_reports +from monai.handlers.utils import write_metrics_reports, string_list_all_gather from monai.utils import ensure_tuple, exact_version, optional_import -from monai.utils.module import get_torch_version_tuple Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") @@ -105,15 +104,8 @@ def __call__(self, engine: Engine) -> None: if self.save_rank >= ws: raise ValueError("target rank is greater than the distributed group size.") - _images = self._filenames - if ws > 1: - _filenames = self.deli.join(_images) - if get_torch_version_tuple() > (1, 6, 0): - # all gather across all processes - _filenames = self.deli.join(idist.all_gather(_filenames)) - else: - raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.") - _images = _filenames.split(self.deli) + # all gather file names across ranks + _images = string_list_all_gather(strings=self._filenames) if ws > 1 else self._filenames # only save metrics to file in specified rank if idist.get_rank() == self.save_rank: diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index ef652efe0a..0a67df6346 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -11,12 +11,12 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union, List import numpy as np import torch -from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils import ensure_tuple, exact_version, optional_import, get_torch_version_tuple idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") if TYPE_CHECKING: @@ -28,6 +28,7 @@ "stopping_fn_from_metric", "stopping_fn_from_loss", "evenly_divisible_all_gather", + "string_list_all_gather", "write_metrics_reports", ] @@ -81,6 +82,29 @@ def evenly_divisible_all_gather(data: torch.Tensor) -> torch.Tensor: return torch.cat([data[i * max_len : i * max_len + l, ...] for i, l in enumerate(all_lens)], dim=0) +def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[str]: + """ + Utility function for distributed data parallel to all gather a list of strings. + + Args: + strings: a list of strings to all gather. + delimiter: use the delimiter to join the string list to be a long string, + then all gather across ranks and split to a list. default to "\t". + + """ + if idist.get_world_size() <= 1: + return strings + + _joined = delimiter.join(strings) + if get_torch_version_tuple() > (1, 6, 0): + # all gather across all ranks + _joined = delimiter.join(idist.all_gather(_joined)) + else: + raise RuntimeError("MetricsSaver can not save metric details in distributed mode with PyTorch < 1.7.0.") + + return _joined.split(delimiter) + + def write_metrics_reports( save_dir: str, images: Optional[Sequence[str]], diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py new file mode 100644 index 0000000000..f00f61d238 --- /dev/null +++ b/tests/test_handler_classification_saver_dist.py @@ -0,0 +1,59 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import csv +import os +import tempfile +import unittest + +import numpy as np +import torch +import torch.distributed as dist +from ignite.engine import Engine + +from monai.handlers import ClassificationSaver +from tests.utils import DistCall, DistTestCase + + +class DistributedHandlerClassificationSaver(DistTestCase): + @DistCall(nnodes=1, nproc_per_node=2) + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + rank = dist.get_rank() + + # set up engine + def _train_func(engine, batch): + return torch.zeros(8 + rank * 2) + + engine = Engine(_train_func) + + # set up testing handler + saver = ClassificationSaver(output_dir=tempdir, filename="predictions.csv", save_rank=1) + saver.attach(engine) + + # rank 0 has 8 images, rank 1 has 10 images + data = [{"filename_or_obj": ["testfile" + str(i) for i in range(8 * rank, (8 + rank) * (rank + 1))]}] + engine.run(data, max_epochs=1) + filepath = os.path.join(tempdir, "predictions.csv") + if rank == 1: + self.assertTrue(os.path.exists(filepath)) + with open(filepath, "r") as f: + reader = csv.reader(f) + i = 0 + for row in reader: + self.assertEqual(row[0], "testfile" + str(i)) + self.assertEqual(np.array(row[1:]).astype(np.float32), 0.0) + i += 1 + self.assertEqual(i, 18) + + +if __name__ == "__main__": + unittest.main() From e2ed4514ee12d6addf365e0c7c84ba7367d1f7e3 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Tue, 2 Feb 2021 11:08:49 +0000 Subject: [PATCH 2/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/handlers/__init__.py | 2 +- monai/handlers/metrics_saver.py | 2 +- monai/handlers/utils.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/handlers/__init__.py b/monai/handlers/__init__.py index 1224721465..81c65ed580 100644 --- a/monai/handlers/__init__.py +++ b/monai/handlers/__init__.py @@ -27,9 +27,9 @@ from .tensorboard_handlers import TensorBoardImageHandler, TensorBoardStatsHandler from .utils import ( evenly_divisible_all_gather, - string_list_all_gather, stopping_fn_from_loss, stopping_fn_from_metric, + string_list_all_gather, write_metrics_reports, ) from .validation_handler import ValidationHandler diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index f48674ddcc..9dbcbaa388 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -11,7 +11,7 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union -from monai.handlers.utils import write_metrics_reports, string_list_all_gather +from monai.handlers.utils import string_list_all_gather, write_metrics_reports from monai.utils import ensure_tuple, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") diff --git a/monai/handlers/utils.py b/monai/handlers/utils.py index 31352608e0..2165ad8860 100644 --- a/monai/handlers/utils.py +++ b/monai/handlers/utils.py @@ -11,12 +11,12 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Union, List +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Union import numpy as np import torch -from monai.utils import ensure_tuple, exact_version, optional_import, get_torch_version_tuple +from monai.utils import ensure_tuple, exact_version, get_torch_version_tuple, optional_import idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") if TYPE_CHECKING: @@ -94,7 +94,7 @@ def string_list_all_gather(strings: List[str], delimiter: str = "\t") -> List[st """ if idist.get_world_size() <= 1: return strings - + _joined = delimiter.join(strings) if get_torch_version_tuple() > (1, 6, 0): # all gather across all ranks From e87901ad05982ce1f0e50931ba81dad6e73401be Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Feb 2021 19:27:44 +0800 Subject: [PATCH 3/7] [DLMED] fix min test Signed-off-by: Nic Ma --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 665ead6cc6..0fd6985067 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -103,6 +103,7 @@ def run_testsuit(): "test_handler_metrics_saver", "test_handler_metrics_saver_dist", "test_evenly_divisible_all_gather_dist", + "test_handler_classification_saver_dist", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From 46b33287545beb3c83d0c030e9cd258e5494b913 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Feb 2021 22:32:42 +0800 Subject: [PATCH 4/7] [DLMED] add @SkipIfBeforePyTorchVersion((1, 7)) Signed-off-by: Nic Ma --- tests/test_handler_classification_saver_dist.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_handler_classification_saver_dist.py b/tests/test_handler_classification_saver_dist.py index f00f61d238..275d5b2231 100644 --- a/tests/test_handler_classification_saver_dist.py +++ b/tests/test_handler_classification_saver_dist.py @@ -20,9 +20,10 @@ from ignite.engine import Engine from monai.handlers import ClassificationSaver -from tests.utils import DistCall, DistTestCase +from tests.utils import DistCall, DistTestCase, SkipIfBeforePyTorchVersion +@SkipIfBeforePyTorchVersion((1, 7)) class DistributedHandlerClassificationSaver(DistTestCase): @DistCall(nnodes=1, nproc_per_node=2) def test_saved_content(self): From 3b005bd13e731a90a91ea71e2c9d105a9a85fd57 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 3 Feb 2021 10:25:50 +0800 Subject: [PATCH 5/7] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/csv_saver.py | 3 ++- monai/data/nifti_saver.py | 3 ++- monai/data/png_saver.py | 3 ++- monai/handlers/classification_saver.py | 5 +++-- monai/handlers/metrics_saver.py | 3 ++- monai/transforms/io/array.py | 3 ++- monai/utils/__init__.py | 1 + monai/utils/enums.py | 9 +++++++++ 8 files changed, 23 insertions(+), 7 deletions(-) diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index ec9ec562cd..ee9ab4ef83 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -16,6 +16,7 @@ import numpy as np import torch +from monai.utils import ImageMetaKey as Key class CSVSaver: @@ -73,7 +74,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] meta_data: the meta data information corresponding to the data. """ - save_key = meta_data["filename_or_obj"] if meta_data else str(self._data_index) + save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 if isinstance(data, torch.Tensor): data = data.detach().cpu().numpy() diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index db559f97f4..e699a0ce9b 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -18,6 +18,7 @@ from monai.data.nifti_writer import write_nifti from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils import ImageMetaKey as Key class NiftiSaver: @@ -95,7 +96,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] See Also :py:meth:`monai.data.nifti_writer.write_nifti` """ - filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index) + filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 original_affine = meta_data.get("original_affine", None) if meta_data else None affine = meta_data.get("affine", None) if meta_data else None diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 8ed8b234f4..00d5583cab 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -17,6 +17,7 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import InterpolateMode +from monai.utils import ImageMetaKey as Key class PNGSaver: @@ -82,7 +83,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] :py:meth:`monai.data.png_writer.write_png` """ - filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index) + filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index d7a9f2de78..0e37799bbf 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -15,6 +15,7 @@ from monai.data import CSVSaver from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather from monai.utils import exact_version, optional_import +from monai.utils import ImageMetaKey as Key idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") @@ -86,9 +87,9 @@ def __call__(self, engine: Engine) -> None: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ _meta_data = self.batch_transform(engine.state.batch) - if "filename_or_obj" in _meta_data: + if Key.FILENAME_OR_OBJ in _meta_data: # all gather filenames across ranks - _meta_data["filename_or_obj"] = string_list_all_gather(_meta_data["filename_or_obj"]) + _meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ]) # all gather predictions across ranks _engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output)) diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 9dbcbaa388..2ffc5657ce 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -13,6 +13,7 @@ from monai.handlers.utils import string_list_all_gather, write_metrics_reports from monai.utils import ensure_tuple, exact_version, optional_import +from monai.utils import ImageMetaKey as Key Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") @@ -92,7 +93,7 @@ def _started(self, engine: Engine) -> None: def _get_filenames(self, engine: Engine) -> None: if self.metric_details is not None: - _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"])) + _filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ])) self._filenames += _filenames def __call__(self, engine: Engine) -> None: diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 772c7cf74f..c569f85877 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -21,6 +21,7 @@ from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms.compose import Transform from monai.utils import ensure_tuple, optional_import +from monai.utils import ImageMetaKey as Key nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") @@ -126,5 +127,5 @@ def __call__( if self.image_only: return img_array - meta_data["filename_or_obj"] = ensure_tuple(filename)[0] + meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0] return img_array, meta_data diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index e5567f9f16..f7fe76434d 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -19,6 +19,7 @@ ChannelMatching, GridSampleMode, GridSamplePadMode, + ImageMetaKey, InterpolateMode, LossReduction, Method, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index d1d2d3bcce..5a986d4bb5 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -28,6 +28,7 @@ "ChannelMatching", "SkipMode", "Method", + "ImageMetaKey", ] @@ -214,3 +215,11 @@ class Method(Enum): SYMMETRIC = "symmetric" END = "end" + + +class ImageMetaKey(Enum): + """ + Common key names in the meta data header of images + """ + + FILENAME_OR_OBJ = "filename_or_obj" From fc083081dd3cd71d49abd292498b788bd8524db2 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 3 Feb 2021 11:21:47 +0800 Subject: [PATCH 6/7] [DLMED] fix typo Signed-off-by: Nic Ma --- monai/utils/__init__.py | 2 +- monai/utils/enums.py | 9 --------- monai/utils/misc.py | 9 +++++++++ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index f7fe76434d..14a56a56cd 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -19,7 +19,6 @@ ChannelMatching, GridSampleMode, GridSamplePadMode, - ImageMetaKey, InterpolateMode, LossReduction, Method, @@ -42,6 +41,7 @@ fall_back_tuple, first, get_seed, + ImageMetaKey, is_scalar, is_scalar_tensor, issequenceiterable, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 5a986d4bb5..d1d2d3bcce 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -28,7 +28,6 @@ "ChannelMatching", "SkipMode", "Method", - "ImageMetaKey", ] @@ -215,11 +214,3 @@ class Method(Enum): SYMMETRIC = "symmetric" END = "end" - - -class ImageMetaKey(Enum): - """ - Common key names in the meta data header of images - """ - - FILENAME_OR_OBJ = "filename_or_obj" diff --git a/monai/utils/misc.py b/monai/utils/misc.py index c5e8318db3..f9346340cf 100644 --- a/monai/utils/misc.py +++ b/monai/utils/misc.py @@ -41,6 +41,7 @@ "dtype_numpy_to_torch", "MAX_SEED", "copy_to_device", + "ImageMetaKey", ] _seed = None @@ -349,3 +350,11 @@ def copy_to_device( warnings.warn(f"{fn_name} called with incompatible type: " + f"{type(obj)}. Data will be returned unchanged.") return obj + + +class ImageMetaKey: + """ + Common key names in the meta data header of images + """ + + FILENAME_OR_OBJ = "filename_or_obj" From d1f1638635d6e1d674d94039963505fda531de48 Mon Sep 17 00:00:00 2001 From: monai-bot Date: Wed, 3 Feb 2021 03:25:51 +0000 Subject: [PATCH 7/7] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/data/csv_saver.py | 1 + monai/data/png_saver.py | 2 +- monai/handlers/classification_saver.py | 2 +- monai/handlers/metrics_saver.py | 2 +- monai/transforms/io/array.py | 2 +- monai/utils/__init__.py | 2 +- 6 files changed, 6 insertions(+), 5 deletions(-) diff --git a/monai/data/csv_saver.py b/monai/data/csv_saver.py index ee9ab4ef83..830c6a4f0d 100644 --- a/monai/data/csv_saver.py +++ b/monai/data/csv_saver.py @@ -16,6 +16,7 @@ import numpy as np import torch + from monai.utils import ImageMetaKey as Key diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index 00d5583cab..4c4c847824 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -16,8 +16,8 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename -from monai.utils import InterpolateMode from monai.utils import ImageMetaKey as Key +from monai.utils import InterpolateMode class PNGSaver: diff --git a/monai/handlers/classification_saver.py b/monai/handlers/classification_saver.py index 0e37799bbf..a1c76dd338 100644 --- a/monai/handlers/classification_saver.py +++ b/monai/handlers/classification_saver.py @@ -14,8 +14,8 @@ from monai.data import CSVSaver from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather -from monai.utils import exact_version, optional_import from monai.utils import ImageMetaKey as Key +from monai.utils import exact_version, optional_import idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") diff --git a/monai/handlers/metrics_saver.py b/monai/handlers/metrics_saver.py index 2ffc5657ce..b9ea296821 100644 --- a/monai/handlers/metrics_saver.py +++ b/monai/handlers/metrics_saver.py @@ -12,8 +12,8 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Sequence, Union from monai.handlers.utils import string_list_all_gather, write_metrics_reports -from monai.utils import ensure_tuple, exact_version, optional_import from monai.utils import ImageMetaKey as Key +from monai.utils import ensure_tuple, exact_version, optional_import Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events") idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed") diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index c569f85877..f57b2dd27a 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -20,8 +20,8 @@ from monai.config import DtypeLike from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms.compose import Transform -from monai.utils import ensure_tuple, optional_import from monai.utils import ImageMetaKey as Key +from monai.utils import ensure_tuple, optional_import nib, _ = optional_import("nibabel") Image, _ = optional_import("PIL.Image") diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 14a56a56cd..1e17d44029 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -32,6 +32,7 @@ ) from .misc import ( MAX_SEED, + ImageMetaKey, copy_to_device, dtype_numpy_to_torch, dtype_torch_to_numpy, @@ -41,7 +42,6 @@ fall_back_tuple, first, get_seed, - ImageMetaKey, is_scalar, is_scalar_tensor, issequenceiterable,