Skip to content

Commit

Permalink
[DLMED] update according to comments
Browse files Browse the repository at this point in the history
Signed-off-by: Nic Ma <nma@nvidia.com>
  • Loading branch information
Nic-Ma committed Feb 3, 2021
1 parent b813c73 commit 3b005bd
Show file tree
Hide file tree
Showing 8 changed files with 23 additions and 7 deletions.
3 changes: 2 additions & 1 deletion monai/data/csv_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import numpy as np
import torch
from monai.utils import ImageMetaKey as Key


class CSVSaver:
Expand Down Expand Up @@ -73,7 +74,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
meta_data: the meta data information corresponding to the data.
"""
save_key = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
save_key = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
if isinstance(data, torch.Tensor):
data = data.detach().cpu().numpy()
Expand Down
3 changes: 2 additions & 1 deletion monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from monai.data.nifti_writer import write_nifti
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key


class NiftiSaver:
Expand Down Expand Up @@ -95,7 +96,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
See Also
:py:meth:`monai.data.nifti_writer.write_nifti`
"""
filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
original_affine = meta_data.get("original_affine", None) if meta_data else None
affine = meta_data.get("affine", None) if meta_data else None
Expand Down
3 changes: 2 additions & 1 deletion monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from monai.data.png_writer import write_png
from monai.data.utils import create_file_basename
from monai.utils import InterpolateMode
from monai.utils import ImageMetaKey as Key


class PNGSaver:
Expand Down Expand Up @@ -82,7 +83,7 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict]
:py:meth:`monai.data.png_writer.write_png`
"""
filename = meta_data["filename_or_obj"] if meta_data else str(self._data_index)
filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
self._data_index += 1
spatial_shape = meta_data.get("spatial_shape", None) if meta_data and self.resample else None

Expand Down
5 changes: 3 additions & 2 deletions monai/handlers/classification_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from monai.data import CSVSaver
from monai.handlers.utils import evenly_divisible_all_gather, string_list_all_gather
from monai.utils import exact_version, optional_import
from monai.utils import ImageMetaKey as Key

idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
Expand Down Expand Up @@ -86,9 +87,9 @@ def __call__(self, engine: Engine) -> None:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
"""
_meta_data = self.batch_transform(engine.state.batch)
if "filename_or_obj" in _meta_data:
if Key.FILENAME_OR_OBJ in _meta_data:
# all gather filenames across ranks
_meta_data["filename_or_obj"] = string_list_all_gather(_meta_data["filename_or_obj"])
_meta_data[Key.FILENAME_OR_OBJ] = string_list_all_gather(_meta_data[Key.FILENAME_OR_OBJ])
# all gather predictions across ranks
_engine_output = evenly_divisible_all_gather(self.output_transform(engine.state.output))

Expand Down
3 changes: 2 additions & 1 deletion monai/handlers/metrics_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from monai.handlers.utils import string_list_all_gather, write_metrics_reports
from monai.utils import ensure_tuple, exact_version, optional_import
from monai.utils import ImageMetaKey as Key

Events, _ = optional_import("ignite.engine", "0.4.2", exact_version, "Events")
idist, _ = optional_import("ignite", "0.4.2", exact_version, "distributed")
Expand Down Expand Up @@ -92,7 +93,7 @@ def _started(self, engine: Engine) -> None:

def _get_filenames(self, engine: Engine) -> None:
if self.metric_details is not None:
_filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)["filename_or_obj"]))
_filenames = list(ensure_tuple(self.batch_transform(engine.state.batch)[Key.FILENAME_OR_OBJ]))
self._filenames += _filenames

def __call__(self, engine: Engine) -> None:
Expand Down
3 changes: 2 additions & 1 deletion monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.transforms.compose import Transform
from monai.utils import ensure_tuple, optional_import
from monai.utils import ImageMetaKey as Key

nib, _ = optional_import("nibabel")
Image, _ = optional_import("PIL.Image")
Expand Down Expand Up @@ -126,5 +127,5 @@ def __call__(

if self.image_only:
return img_array
meta_data["filename_or_obj"] = ensure_tuple(filename)[0]
meta_data[Key.FILENAME_OR_OBJ] = ensure_tuple(filename)[0]
return img_array, meta_data
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ChannelMatching,
GridSampleMode,
GridSamplePadMode,
ImageMetaKey,
InterpolateMode,
LossReduction,
Method,
Expand Down
9 changes: 9 additions & 0 deletions monai/utils/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"ChannelMatching",
"SkipMode",
"Method",
"ImageMetaKey",
]


Expand Down Expand Up @@ -214,3 +215,11 @@ class Method(Enum):

SYMMETRIC = "symmetric"
END = "end"


class ImageMetaKey(Enum):
"""
Common key names in the meta data header of images
"""

FILENAME_OR_OBJ = "filename_or_obj"

0 comments on commit 3b005bd

Please sign in to comment.